diff --git a/rip/db.py b/rip/db.py index b83e40b..f9db66d 100644 --- a/rip/db.py +++ b/rip/db.py @@ -3,76 +3,96 @@ import logging import os import sqlite3 -from typing import Union +from typing import Union, List +import abc logger = logging.getLogger("streamrip") -class MusicDB: - """Simple interface for the downloaded track database.""" +class Database: + # list of table column names + structure: list + # name of table + name: str - def __init__(self, db_path: Union[str, os.PathLike], empty=False): - """Create a MusicDB object. + def __init__(self, path, empty=False): + assert self.structure != [] + assert self.name - :param db_path: filepath of the database - :type db_path: Union[str, os.PathLike] - """ if empty: self.path = None return - self.path = db_path + self.path = path if not os.path.exists(self.path): self.create() def create(self): - """Create a database at `self.path`.""" - with sqlite3.connect(self.path) as conn: - try: - conn.execute("CREATE TABLE downloads (id TEXT UNIQUE NOT NULL);") - logger.debug("Download-IDs database created: %s", self.path) - except sqlite3.OperationalError: - pass - - return self.path - - def __contains__(self, item_id: Union[str, int]) -> bool: - """Check whether the database contains an id. - - :param item_id: the id to check - :type item_id: str - :rtype: bool - """ - if self.path is None: - return False - - logger.debug("Checking database for ID %s", item_id) - with sqlite3.connect(self.path) as conn: - return ( - conn.execute( - "SELECT id FROM downloads where id=?", (item_id,) - ).fetchone() - is not None - ) - - def add(self, item_id: str): - """Add an id to the database. - - :param item_id: - :type item_id: str - """ - logger.debug("Adding ID %s", item_id) - if self.path is None: return with sqlite3.connect(self.path) as conn: try: - conn.execute( - "INSERT INTO downloads (id) VALUES (?)", - (item_id,), + params = ", ".join( + f"{key} TEXT UNIQUE NOT NULL" for key in self.structure ) - conn.commit() - except sqlite3.Error as err: - if "UNIQUE" not in str(err): - raise + command = f"CREATE TABLE {self.name} ({params});" + + logger.debug(f"executing {command}") + + conn.execute(command) + except sqlite3.OperationalError: + pass + + def keys(self): + return self.structure + + def contains(self, **items): + allowed_keys = set(self.structure) + assert all( + key in allowed_keys for key in items.keys() + ), f"Invalid key. Valid keys: {self.structure}" + + items = {k: str(v) for k, v in items.items()} + + if self.path is None: + return False + + with sqlite3.connect(self.path) as conn: + conditions = " AND ".join(f"{key}=?" for key in items.keys()) + command = f"SELECT {self.structure[0]} FROM {self.name} WHERE {conditions}" + + logger.debug(f"executing {command}") + + return conn.execute(command, tuple(items.values())).fetchone() is not None + + def __contains__(self, keys: dict) -> bool: + return self.contains(**keys) + + def add(self, items: List[str]): + assert len(items) == len(self.structure) + if self.path is None: + return + + params = ", ".join(self.structure) + question_marks = ", ".join("?" for _ in items) + command = f"INSERT INTO {self.name} ({params}) VALUES ({question_marks})" + + logger.debug(f"executing {command}") + + with sqlite3.connect(self.path) as conn: + conn.execute(command, tuple(items)) + + def __iter__(self): + with sqlite3.connect(self.path) as conn: + return conn.execute(f"SELECT * FROM {self.name}") + + +class Downloads(Database): + structure = ["id"] + name = "downloads" + + +class FailedDownloads(Database): + structure = ["source", "type", "id"] + name = "failed_downloads"