Improve database, refactor artist and label

This commit is contained in:
Nathan Thomas 2023-11-23 18:40:50 -08:00
parent f9b263a718
commit 1964a0e488
12 changed files with 258 additions and 145 deletions

48
streamrip/album_list.py Normal file
View file

@ -0,0 +1,48 @@
import asyncio
from dataclasses import dataclass
from .album import PendingAlbum
from .client import Client
from .config import Config
from .media import Media
@dataclass(slots=True)
class AlbumList(Media):
"""Represents a list of albums. Used by Artist and Label classes."""
name: str
albums: list[PendingAlbum]
client: Client
config: Config
async def preprocess(self):
pass
async def download(self):
# Resolve only 3 albums at a time to avoid
# initial latency of resolving ALL albums and tracks
# before any downloads
album_resolve_chunk_size = 10
async def _resolve_download(item: PendingAlbum):
album = await item.resolve()
if album is None:
return
await album.rip()
batches = self.batch(
[_resolve_download(album) for album in self.albums],
album_resolve_chunk_size,
)
for batch in batches:
await asyncio.gather(*batch)
async def postprocess(self):
pass
@staticmethod
def batch(iterable, n=1):
l = len(iterable)
for ndx in range(0, l, n):
yield iterable[ndx : min(ndx + n, l)]

View file

@ -1,15 +1,28 @@
from .album import Album, PendingAlbum from dataclasses import dataclass
from .album import PendingAlbum
from .album_list import AlbumList
from .client import Client from .client import Client
from .config import Config from .config import Config
from .media import Media, Pending from .media import Pending
from .metadata import ArtistMetadata
class Artist(Media): class Artist(AlbumList):
name: str pass
albums: list[PendingAlbum]
config: Config
@dataclass(slots=True)
class PendingArtist(Pending): class PendingArtist(Pending):
id: str id: str
client: Client client: Client
config: Config
async def resolve(self) -> Artist:
resp = await self.client.get_metadata(self.id, "artist")
meta = ArtistMetadata.from_resp(resp, self.client.source)
albums = [
PendingAlbum(album_id, self.client, self.config)
for album_id in meta.album_ids()
]
return Artist(meta.name, albums, self.client, self.config)

View file

@ -45,20 +45,34 @@ def coro(f):
) )
@click.version_option(version="2.0") @click.version_option(version="2.0")
@click.option( @click.option(
"-c", "--config-path", default=CONFIG_PATH, help="Path to the configuration file" "--config-path", default=CONFIG_PATH, help="Path to the configuration file"
)
@click.option("-f", "--folder", help="The folder to download items into.")
@click.option(
"-ndb",
"--no-db",
help="Download items even if they have been logged in the database",
default=False,
is_flag=True,
)
@click.option("-q", "--quality", help="The maximum quality allowed to download")
@click.option(
"-c",
"--convert",
help="Convert the downloaded files to an audio codec (ALAC, FLAC, MP3, AAC, or OGG)",
) )
@click.option( @click.option(
"-v", "--verbose", help="Enable verbose output (debug mode)", is_flag=True "-v", "--verbose", help="Enable verbose output (debug mode)", is_flag=True
) )
@click.pass_context @click.pass_context
def rip(ctx, config_path, verbose): def rip(ctx, config_path, folder, no_db, quality, convert, verbose):
""" """
Streamrip: the all in one music downloader. Streamrip: the all in one music downloader.
""" """
print(ctx, config_path, folder, no_db, quality, convert, verbose)
global logger global logger
FORMAT = "%(message)s"
logging.basicConfig( logging.basicConfig(
level="WARNING", format=FORMAT, datefmt="[%X]", handlers=[RichHandler()] level="WARNING", format="%(message)s", datefmt="[%X]", handlers=[RichHandler()]
) )
logger = logging.getLogger("streamrip") logger = logging.getLogger("streamrip")
if verbose: if verbose:
@ -74,14 +88,29 @@ def rip(ctx, config_path, verbose):
install(console=console, suppress=[click, asyncio], max_frames=1) install(console=console, suppress=[click, asyncio], max_frames=1)
logger.setLevel(logging.WARNING) logger.setLevel(logging.WARNING)
ctx.ensure_object(dict)
if not os.path.isfile(config_path): if not os.path.isfile(config_path):
echo_i(f"No file found at {config_path}, creating default config.") echo_i(f"No file found at {config_path}, creating default config.")
shutil.copy(BLANK_CONFIG_PATH, config_path) shutil.copy(BLANK_CONFIG_PATH, config_path)
set_user_defaults(config_path) set_user_defaults(config_path)
ctx.obj["config_path"] = config_path # pass to subcommands
ctx.obj["verbose"] = verbose ctx.ensure_object(dict)
c = Config(config_path)
# set session config values to command line args
if folder is not None:
c.session.downloads.folder = folder
c.session.database.downloads_enabled = not no_db
c.session.qobuz.quality = quality
c.session.tidal.quality = quality
c.session.deezer.quality = quality
c.session.soundcloud.quality = quality
if convert is not None:
c.session.conversion.enabled = True
assert convert.upper() in ("ALAC", "FLAC", "OGG", "MP3", "AAC")
c.session.conversion.codec = convert.upper()
ctx.obj["config"] = c
@rip.command() @rip.command()
@ -95,8 +124,7 @@ async def url(ctx, urls):
rip url TODO: find url rip url TODO: find url
""" """
config_path = ctx.obj["config_path"] with ctx.obj["config"] as cfg:
with Config(config_path) as cfg:
main = Main(cfg) main = Main(cfg)
for u in urls: for u in urls:
await main.add(u) await main.add(u)
@ -115,8 +143,7 @@ async def file(ctx, path):
rip file urls.txt rip file urls.txt
""" """
config_path = ctx.obj["config_path"] with ctx.obj["config"] as cfg:
with Config(config_path) as cfg:
main = Main(cfg) main = Main(cfg)
with open(path) as f: with open(path) as f:
for url in f: for url in f:

View file

@ -3,39 +3,52 @@
import logging import logging
import os import os
import sqlite3 import sqlite3
from abc import ABC, abstractmethod
logger = logging.getLogger("streamrip") logger = logging.getLogger("streamrip")
# apologies to anyone reading this file
class Database:
class Database(ABC):
@abstractmethod
def create(self):
pass
@abstractmethod
def contains(self, **items) -> bool:
pass
@abstractmethod
def add(self, kvs):
pass
@abstractmethod
def remove(self, kvs):
pass
class DatabaseBase(Database):
"""A wrapper for an sqlite database.""" """A wrapper for an sqlite database."""
structure: dict structure: dict
name: str name: str
def __init__(self, path: str, dummy: bool = False): def __init__(self, path: str):
"""Create a Database instance. """Create a Database instance.
:param path: Path to the database file. :param path: Path to the database file.
:param dummy: Make the database empty.
""" """
assert self.structure != [] assert self.structure != {}
assert self.name assert self.name
self.path = path self.path = path
self.is_dummy = dummy
if self.is_dummy:
return
if not os.path.exists(self.path): if not os.path.exists(self.path):
self.create() self.create()
def create(self): def create(self):
"""Create a database.""" """Create a database."""
if self.is_dummy:
return
with sqlite3.connect(self.path) as conn: with sqlite3.connect(self.path) as conn:
params = ", ".join( params = ", ".join(
f"{key} {' '.join(map(str.upper, props))} NOT NULL" f"{key} {' '.join(map(str.upper, props))} NOT NULL"
@ -57,8 +70,6 @@ class Database:
:param items: a dict of column-name + expected value :param items: a dict of column-name + expected value
:rtype: bool :rtype: bool
""" """
if self.is_dummy:
return False
allowed_keys = set(self.structure.keys()) allowed_keys = set(self.structure.keys())
assert all( assert all(
@ -75,43 +86,12 @@ class Database:
return bool(conn.execute(command, tuple(items.values())).fetchone()[0]) return bool(conn.execute(command, tuple(items.values())).fetchone()[0])
def __contains__(self, keys: str | dict) -> bool: def add(self, items: tuple[str]):
"""Check whether a key-value pair exists in the database.
:param keys: Either a dict with the structure {key: value_to_search_for, ...},
or if there is only one key in the table, value_to_search_for can be
passed in by itself.
:type keys: Union[str, dict]
:rtype: bool
"""
if isinstance(keys, dict):
return self.contains(**keys)
if isinstance(keys, str) and len(self.structure) == 1:
only_key = tuple(self.structure.keys())[0]
query = {only_key: keys}
logger.debug("Searching for %s in database", query)
return self.contains(**query)
raise TypeError(keys)
def add(self, items: str | tuple[str]):
"""Add a row to the table. """Add a row to the table.
:param items: Column-name + value. Values must be provided for all cols. :param items: Column-name + value. Values must be provided for all cols.
:type items: Tuple[str] :type items: Tuple[str]
""" """
if self.is_dummy:
return
if isinstance(items, str):
if len(self.structure) == 1:
items = (items,)
else:
raise TypeError(
"Only tables with 1 column can have string inputs. Use a list "
"where len(list) == len(structure)."
)
assert len(items) == len(self.structure) assert len(items) == len(self.structure)
@ -136,9 +116,6 @@ class Database:
:param items: :param items:
""" """
# not in use currently
if self.is_dummy:
return
conditions = " AND ".join(f"{key}=?" for key in items.keys()) conditions = " AND ".join(f"{key}=?" for key in items.keys())
command = f"DELETE FROM {self.name} WHERE {conditions}" command = f"DELETE FROM {self.name} WHERE {conditions}"
@ -149,9 +126,6 @@ class Database:
def __iter__(self): def __iter__(self):
"""Iterate through the rows of the table.""" """Iterate through the rows of the table."""
if self.is_dummy:
return ()
with sqlite3.connect(self.path) as conn: with sqlite3.connect(self.path) as conn:
return conn.execute(f"SELECT * FROM {self.name}") return conn.execute(f"SELECT * FROM {self.name}")
@ -163,7 +137,21 @@ class Database:
pass pass
class Downloads(Database): class Dummy(Database):
def create(self):
pass
def contains(self):
return False
def add(self):
pass
def remove(self):
pass
class Downloads(DatabaseBase):
"""A table that stores the downloaded IDs.""" """A table that stores the downloaded IDs."""
name = "downloads" name = "downloads"
@ -172,7 +160,7 @@ class Downloads(Database):
} }
class FailedDownloads(Database): class FailedDownloads(DatabaseBase):
"""A table that stores information about failed downloads.""" """A table that stores information about failed downloads."""
name = "failed_downloads" name = "failed_downloads"

29
streamrip/label.py Normal file
View file

@ -0,0 +1,29 @@
import asyncio
from dataclasses import dataclass
from .album import PendingAlbum
from .album_list import AlbumList
from .client import Client
from .config import Config
from .media import Pending
from .metadata import LabelMetadata
class Label(AlbumList):
pass
@dataclass(slots=True)
class PendingLabel(Pending):
id: str
client: Client
config: Config
async def resolve(self) -> Label:
resp = await self.client.get_metadata(self.id, "label")
meta = LabelMetadata.from_resp(resp, self.client.source)
albums = [
PendingAlbum(album_id, self.client, self.config)
for album_id in meta.album_ids()
]
return Label(meta.name, albums, self.client, self.config)

View file

@ -1,12 +1,16 @@
"""Manages the information that will be embeded in the audio file.""" """Manages the information that will be embeded in the audio file."""
from . import util from . import util
from .album_metadata import AlbumMetadata from .album_metadata import AlbumMetadata
from .artist_metadata import ArtistMetadata
from .covers import Covers from .covers import Covers
from .label_metadata import LabelMetadata
from .playlist_metadata import PlaylistMetadata from .playlist_metadata import PlaylistMetadata
from .track_metadata import TrackMetadata from .track_metadata import TrackMetadata
__all__ = [ __all__ = [
"AlbumMetadata", "AlbumMetadata",
"ArtistMetadata",
"LabelMetadata",
"TrackMetadata", "TrackMetadata",
"PlaylistMetadata", "PlaylistMetadata",
"Covers", "Covers",

View file

@ -5,7 +5,6 @@ import re
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
from ..exceptions import NonStreamable
from .covers import Covers from .covers import Covers
from .util import get_quality_id, safe_get, typed from .util import get_quality_id, safe_get, typed

View file

@ -0,0 +1,27 @@
from __future__ import annotations
import logging
from dataclasses import dataclass
logger = logging.getLogger("streamrip")
@dataclass(slots=True)
class ArtistMetadata:
name: str
ids: list[str]
def album_ids(self):
return self.ids
@classmethod
def from_resp(cls, resp: dict, source: str) -> ArtistMetadata:
logger.debug(resp)
if source == "qobuz":
return cls(resp["name"], [a["id"] for a in resp["albums"]["items"]])
elif source == "tidal":
return cls(resp["name"], [a["id"] for a in resp["albums"]])
elif source == "deezer":
return cls(resp["name"], [a["id"] for a in resp["albums"]])
else:
raise NotImplementedError

View file

@ -0,0 +1,27 @@
from __future__ import annotations
import logging
from dataclasses import dataclass
logger = logging.getLogger("streamrip")
@dataclass(slots=True)
class LabelMetadata:
name: str
ids: list[str]
def album_ids(self):
return self.ids
@classmethod
def from_resp(cls, resp: dict, source: str) -> LabelMetadata:
logger.debug(resp)
if source == "qobuz":
return cls(resp["name"], [a["id"] for a in resp["albums"]["items"]])
elif source == "tidal":
return cls(resp["name"], [a["id"] for a in resp["albums"]])
elif source == "deezer":
return cls(resp["name"], [a["id"] for a in resp["albums"]])
else:
raise NotImplementedError

View file

@ -27,10 +27,12 @@ class PendingPlaylistTrack(Pending):
async def resolve(self) -> Track | None: async def resolve(self) -> Track | None:
resp = await self.client.get_metadata(self.id, "track") resp = await self.client.get_metadata(self.id, "track")
album = AlbumMetadata.from_resp(resp, self.client.source) album = AlbumMetadata.from_track_resp(resp, self.client.source)
meta = TrackMetadata.from_resp(album, self.client.source, resp) meta = TrackMetadata.from_resp(album, self.client.source, resp)
if meta is None: if meta is None:
logger.error(f"Cannot stream track ({self.id}) on {self.client.source}") logger.error(
f"Track ({self.id}) not available for stream on {self.client.source}"
)
return None return None
c = self.config.session.metadata c = self.config.session.metadata

View file

@ -4,6 +4,7 @@ from typing import Callable
from rich.console import Group from rich.console import Group
from rich.live import Live from rich.live import Live
from rich.progress import Progress from rich.progress import Progress
from rich.rule import Rule
from rich.text import Text from rich.text import Text
from .console import console from .console import console
@ -38,18 +39,20 @@ class ProgressManager:
self.live.stop() self.live.stop()
def add_title(self, title: str): def add_title(self, title: str):
self.task_titles.append(title) self.task_titles.append(title.strip())
def remove_title(self, title: str): def remove_title(self, title: str):
self.task_titles.remove(title) self.task_titles.remove(title.strip())
def get_title_text(self) -> Text: def get_title_text(self) -> Rule:
t = self.prefix + Text(", ".join(self.task_titles)) titles = ", ".join(self.task_titles[:3])
t.overflow = "ellipsis" if len(self.task_titles) > 3:
return t titles += "..."
t = self.prefix + Text(titles)
return Rule(t)
@dataclass @dataclass(slots=True)
class Handle: class Handle:
update: Callable[[int], None] update: Callable[[int], None]
done: Callable[[], None] done: Callable[[], None]
@ -66,18 +69,22 @@ _p = ProgressManager()
def get_progress_callback(enabled: bool, total: int, desc: str) -> Handle: def get_progress_callback(enabled: bool, total: int, desc: str) -> Handle:
global _p
if not enabled: if not enabled:
return Handle(lambda _: None, lambda: None) return Handle(lambda _: None, lambda: None)
return _p.get_callback(total, desc) return _p.get_callback(total, desc)
def add_title(title: str): def add_title(title: str):
global _p
_p.add_title(title) _p.add_title(title)
def remove_title(title: str): def remove_title(title: str):
global _p
_p.remove_title(title) _p.remove_title(title)
def clear_progress(): def clear_progress():
global _p
_p.cleanup() _p.cleanup()

View file

@ -4,8 +4,10 @@ import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from .album import PendingAlbum from .album import PendingAlbum
from .artist import PendingArtist
from .client import Client from .client import Client
from .config import Config from .config import Config
from .label import PendingLabel
from .media import Pending from .media import Pending
from .playlist import PendingPlaylist from .playlist import PendingPlaylist
from .soundcloud_client import SoundcloudClient from .soundcloud_client import SoundcloudClient
@ -56,6 +58,10 @@ class GenericURL(URL):
return PendingAlbum(item_id, client, config) return PendingAlbum(item_id, client, config)
elif media_type == "playlist": elif media_type == "playlist":
return PendingPlaylist(item_id, client, config) return PendingPlaylist(item_id, client, config)
elif media_type == "artist":
return PendingArtist(item_id, client, config)
elif media_type == "label":
return PendingLabel(item_id, client, config)
else: else:
raise NotImplementedError raise NotImplementedError
@ -73,8 +79,7 @@ class QobuzInterpreterURL(URL):
async def into_pending(self, client: Client, config: Config) -> Pending: async def into_pending(self, client: Client, config: Config) -> Pending:
url = self.match.group(0) url = self.match.group(0)
artist_id = await self.extract_interpreter_url(url, client) artist_id = await self.extract_interpreter_url(url, client)
raise NotImplementedError return PendingArtist(artist_id, client, config)
# return PendingArtist()
@staticmethod @staticmethod
async def extract_interpreter_url(url: str, client: Client) -> str: async def extract_interpreter_url(url: str, client: Client) -> str:
@ -147,66 +152,3 @@ def parse_url(url: str) -> URL | None:
# TODO: the rest of the url types # TODO: the rest of the url types
] ]
return next((u for u in parsed_urls if u is not None), None) return next((u for u in parsed_urls if u is not None), None)
# TODO: recycle this class
class UniversalURL:
"""
>>> u = UniversalURL.from_str('https://sampleurl.com')
>>> if u is not None:
>>> pending = await u.into_pending_item()
"""
source: str
media_type: str | None
match: re.Match | None
def __init__(self, url: str):
url = url.strip()
qobuz_interpreter_url = QOBUZ_INTERPRETER_URL_REGEX.match(url)
if qobuz_interpreter_url is not None:
self.source = "qobuz"
self.media_type = "artist"
self.url_type = "interpreter"
self.match = qobuz_interpreter_url
return
deezer_dynamic_url = DEEZER_DYNAMIC_LINK_REGEX.match(url)
if deezer_dynamic_url is not None:
self.match = deezer_dynamic_url
self.source = "deezer"
self.media_type = None
self.url_type = "deezer_dynamic"
return
soundcloud_url = SOUNDCLOUD_URL_REGEX.match(url)
if soundcloud_url is not None:
self.match = soundcloud_url
self.source = "soundcloud"
self.media_type = None
self.url_type = "soundcloud"
return
generic_url = URL_REGEX.match(url)
if generic_url is not None:
self.match = generic_url
self.source = self.match.group(1)
self.media_type = self.match.group(2)
self.url_type = "generic"
async def into_pending_item(self, client: Client, config: Config) -> Pending | None:
if self.url_type == "generic":
assert self.match is not None
item_id = self.match.group(3)
assert isinstance(item_id, str)
assert client.source == self.source
if self.media_type == "track":
return PendingSingle(item_id, client, config)
elif self.media_type == "album":
return PendingAlbum(item_id, client, config)
else:
raise NotImplementedError
else:
raise NotImplementedError