Databases working, add no progress flag

This commit is contained in:
Nathan Thomas 2023-11-24 18:22:50 -08:00
parent 1964a0e488
commit 3e6284b04d
12 changed files with 197 additions and 112 deletions

View file

@ -7,6 +7,7 @@ from . import progress
from .artwork import download_artwork from .artwork import download_artwork
from .client import Client from .client import Client
from .config import Config from .config import Config
from .db import Database
from .exceptions import NonStreamable from .exceptions import NonStreamable
from .media import Media, Pending from .media import Media, Pending
from .metadata import AlbumMetadata from .metadata import AlbumMetadata
@ -23,6 +24,7 @@ class Album(Media):
config: Config config: Config
# folder where the tracks will be downloaded # folder where the tracks will be downloaded
folder: str folder: str
db: Database
async def preprocess(self): async def preprocess(self):
progress.add_title(self.meta.album) progress.add_title(self.meta.album)
@ -45,6 +47,7 @@ class PendingAlbum(Pending):
id: str id: str
client: Client client: Client
config: Config config: Config
db: Database
async def resolve(self) -> Album | None: async def resolve(self) -> Album | None:
resp = await self.client.get_metadata(self.id, "album") resp = await self.client.get_metadata(self.id, "album")
@ -75,12 +78,13 @@ class PendingAlbum(Pending):
client=self.client, client=self.client,
config=self.config, config=self.config,
folder=album_folder, folder=album_folder,
db=self.db,
cover_path=embed_cover, cover_path=embed_cover,
) )
for id in tracklist for id in tracklist
] ]
logger.debug("Pending tracks: %s", pending_tracks) logger.debug("Pending tracks: %s", pending_tracks)
return Album(meta, pending_tracks, self.config, album_folder) return Album(meta, pending_tracks, self.config, album_folder, self.db)
def _album_folder(self, parent: str, meta: AlbumMetadata) -> str: def _album_folder(self, parent: str, meta: AlbumMetadata) -> str:
formatter = self.config.session.filepaths.folder_format formatter = self.config.session.filepaths.folder_format

View file

@ -4,6 +4,7 @@ from .album import PendingAlbum
from .album_list import AlbumList from .album_list import AlbumList
from .client import Client from .client import Client
from .config import Config from .config import Config
from .db import Database
from .media import Pending from .media import Pending
from .metadata import ArtistMetadata from .metadata import ArtistMetadata
@ -17,12 +18,13 @@ class PendingArtist(Pending):
id: str id: str
client: Client client: Client
config: Config config: Config
db: Database
async def resolve(self) -> Artist: async def resolve(self) -> Artist:
resp = await self.client.get_metadata(self.id, "artist") resp = await self.client.get_metadata(self.id, "artist")
meta = ArtistMetadata.from_resp(resp, self.client.source) meta = ArtistMetadata.from_resp(resp, self.client.source)
albums = [ albums = [
PendingAlbum(album_id, self.client, self.config) PendingAlbum(album_id, self.client, self.config, self.db)
for album_id in meta.album_ids() for album_id in meta.album_ids()
] ]
return Artist(meta.name, albums, self.client, self.config) return Artist(meta.name, albums, self.client, self.config)

View file

@ -6,7 +6,6 @@ import subprocess
from functools import wraps from functools import wraps
import click import click
from click import secho
from click_help_colors import HelpColorsGroup # type: ignore from click_help_colors import HelpColorsGroup # type: ignore
from rich.logging import RichHandler from rich.logging import RichHandler
from rich.prompt import Confirm from rich.prompt import Confirm
@ -15,19 +14,7 @@ from rich.traceback import install
from .config import Config, set_user_defaults from .config import Config, set_user_defaults
from .console import console from .console import console
from .main import Main from .main import Main
from .user_paths import BLANK_CONFIG_PATH, CONFIG_PATH from .user_paths import BLANK_CONFIG_PATH, DEFAULT_CONFIG_PATH
def echo_i(msg, **kwargs):
secho(msg, fg="green", **kwargs)
def echo_w(msg, **kwargs):
secho(msg, fg="yellow", **kwargs)
def echo_e(msg, **kwargs):
secho(msg, fg="yellow", **kwargs)
def coro(f): def coro(f):
@ -45,7 +32,7 @@ def coro(f):
) )
@click.version_option(version="2.0") @click.version_option(version="2.0")
@click.option( @click.option(
"--config-path", default=CONFIG_PATH, help="Path to the configuration file" "--config-path", default=DEFAULT_CONFIG_PATH, help="Path to the configuration file"
) )
@click.option("-f", "--folder", help="The folder to download items into.") @click.option("-f", "--folder", help="The folder to download items into.")
@click.option( @click.option(
@ -61,18 +48,20 @@ def coro(f):
"--convert", "--convert",
help="Convert the downloaded files to an audio codec (ALAC, FLAC, MP3, AAC, or OGG)", help="Convert the downloaded files to an audio codec (ALAC, FLAC, MP3, AAC, or OGG)",
) )
@click.option(
"--no-progress", help="Do not show progress bars", is_flag=True, default=False
)
@click.option( @click.option(
"-v", "--verbose", help="Enable verbose output (debug mode)", is_flag=True "-v", "--verbose", help="Enable verbose output (debug mode)", is_flag=True
) )
@click.pass_context @click.pass_context
def rip(ctx, config_path, folder, no_db, quality, convert, verbose): def rip(ctx, config_path, folder, no_db, quality, convert, no_progress, verbose):
""" """
Streamrip: the all in one music downloader. Streamrip: the all in one music downloader.
""" """
print(ctx, config_path, folder, no_db, quality, convert, verbose)
global logger global logger
logging.basicConfig( logging.basicConfig(
level="WARNING", format="%(message)s", datefmt="[%X]", handlers=[RichHandler()] level="INFO", format="%(message)s", datefmt="[%X]", handlers=[RichHandler()]
) )
logger = logging.getLogger("streamrip") logger = logging.getLogger("streamrip")
if verbose: if verbose:
@ -89,7 +78,9 @@ def rip(ctx, config_path, folder, no_db, quality, convert, verbose):
logger.setLevel(logging.WARNING) logger.setLevel(logging.WARNING)
if not os.path.isfile(config_path): if not os.path.isfile(config_path):
echo_i(f"No file found at {config_path}, creating default config.") console.print(
f"No file found at [bold cyan]{config_path}[/bold cyan], creating default config."
)
shutil.copy(BLANK_CONFIG_PATH, config_path) shutil.copy(BLANK_CONFIG_PATH, config_path)
set_user_defaults(config_path) set_user_defaults(config_path)
@ -98,18 +89,24 @@ def rip(ctx, config_path, folder, no_db, quality, convert, verbose):
c = Config(config_path) c = Config(config_path)
# set session config values to command line args # set session config values to command line args
c.session.database.downloads_enabled = not no_db
if folder is not None: if folder is not None:
c.session.downloads.folder = folder c.session.downloads.folder = folder
c.session.database.downloads_enabled = not no_db
c.session.qobuz.quality = quality if quality is not None:
c.session.tidal.quality = quality c.session.qobuz.quality = quality
c.session.deezer.quality = quality c.session.tidal.quality = quality
c.session.soundcloud.quality = quality c.session.deezer.quality = quality
c.session.soundcloud.quality = quality
if convert is not None: if convert is not None:
c.session.conversion.enabled = True c.session.conversion.enabled = True
assert convert.upper() in ("ALAC", "FLAC", "OGG", "MP3", "AAC") assert convert.upper() in ("ALAC", "FLAC", "OGG", "MP3", "AAC")
c.session.conversion.codec = convert.upper() c.session.conversion.codec = convert.upper()
if no_progress:
c.session.cli.progress_bars = False
ctx.obj["config"] = c ctx.obj["config"] = c
@ -118,16 +115,10 @@ def rip(ctx, config_path, folder, no_db, quality, convert, verbose):
@click.pass_context @click.pass_context
@coro @coro
async def url(ctx, urls): async def url(ctx, urls):
"""Download content from URLs. """Download content from URLs."""
Example usage:
rip url TODO: find url
"""
with ctx.obj["config"] as cfg: with ctx.obj["config"] as cfg:
main = Main(cfg) main = Main(cfg)
for u in urls: await main.add_all(urls)
await main.add(u)
await main.resolve() await main.resolve()
await main.rip() await main.rip()
@ -146,8 +137,7 @@ async def file(ctx, path):
with ctx.obj["config"] as cfg: with ctx.obj["config"] as cfg:
main = Main(cfg) main = Main(cfg)
with open(path) as f: with open(path) as f:
for url in f: await main.add_all([line for line in f])
await main.add(url)
await main.resolve() await main.resolve()
await main.rip() await main.rip()
@ -164,7 +154,7 @@ def config():
def config_open(ctx, vim): def config_open(ctx, vim):
"""Open the config file in a text editor.""" """Open the config file in a text editor."""
config_path = ctx.obj["config_path"] config_path = ctx.obj["config_path"]
echo_i(f"Opening file at {config_path}") console.log(f"Opening file at [bold cyan]{config_path}")
if vim: if vim:
if shutil.which("nvim") is not None: if shutil.which("nvim") is not None:
subprocess.run(["nvim", config_path]) subprocess.run(["nvim", config_path])
@ -189,7 +179,7 @@ def config_reset(ctx, yes):
shutil.copy(BLANK_CONFIG_PATH, config_path) shutil.copy(BLANK_CONFIG_PATH, config_path)
set_user_defaults(config_path) set_user_defaults(config_path)
echo_i(f"Reset the config file at {config_path}!") console.print(f"Reset the config file at [bold cyan]{config_path}!")
@rip.command() @rip.command()
@ -199,14 +189,15 @@ def config_reset(ctx, yes):
async def search(query, source): async def search(query, source):
""" """
Search for content using a specific source. Search for content using a specific source.
""" """
echo_i(f'Searching for "{query}" in source: {source}') raise NotImplementedError
@rip.command() @rip.command()
@click.argument("url", required=True) @click.argument("url", required=True)
def lastfm(url): def lastfm(url):
pass raise NotImplementedError
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -2,13 +2,13 @@
import copy import copy
import logging import logging
import os
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from tomlkit.api import dumps, parse from tomlkit.api import dumps, parse
from tomlkit.toml_document import TOMLDocument from tomlkit.toml_document import TOMLDocument
from .user_paths import ( from .user_paths import (
DEFAULT_CONFIG_PATH,
DEFAULT_DOWNLOADS_DB_PATH, DEFAULT_DOWNLOADS_DB_PATH,
DEFAULT_DOWNLOADS_FOLDER, DEFAULT_DOWNLOADS_FOLDER,
DEFAULT_FAILED_DOWNLOADS_DB_PATH, DEFAULT_FAILED_DOWNLOADS_DB_PATH,
@ -19,8 +19,6 @@ logger = logging.getLogger("streamrip")
CURRENT_CONFIG_VERSION = "2.0" CURRENT_CONFIG_VERSION = "2.0"
DEFAULT_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "config.toml")
@dataclass(slots=True) @dataclass(slots=True)
class QobuzConfig: class QobuzConfig:

View file

@ -4,13 +4,12 @@ import logging
import os import os
import sqlite3 import sqlite3
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass
logger = logging.getLogger("streamrip") logger = logging.getLogger("streamrip")
# apologies to anyone reading this file
class DatabaseInterface(ABC):
class Database(ABC):
@abstractmethod @abstractmethod
def create(self): def create(self):
pass pass
@ -27,8 +26,31 @@ class Database(ABC):
def remove(self, kvs): def remove(self, kvs):
pass pass
@abstractmethod
def all(self) -> list:
pass
class DatabaseBase(Database):
class Dummy(DatabaseInterface):
"""This exists as a mock to use in case databases are disabled."""
def create(self):
pass
def contains(self, **_):
return False
def add(self, *_):
pass
def remove(self, *_):
pass
def all(self):
return []
class DatabaseBase(DatabaseInterface):
"""A wrapper for an sqlite database.""" """A wrapper for an sqlite database."""
structure: dict structure: dict
@ -41,6 +63,7 @@ class DatabaseBase(Database):
""" """
assert self.structure != {} assert self.structure != {}
assert self.name assert self.name
assert path
self.path = path self.path = path
@ -124,10 +147,10 @@ class DatabaseBase(Database):
logger.debug(command) logger.debug(command)
conn.execute(command, tuple(items.values())) conn.execute(command, tuple(items.values()))
def __iter__(self): def all(self):
"""Iterate through the rows of the table.""" """Iterate through the rows of the table."""
with sqlite3.connect(self.path) as conn: with sqlite3.connect(self.path) as conn:
return conn.execute(f"SELECT * FROM {self.name}") return list(conn.execute(f"SELECT * FROM {self.name}"))
def reset(self): def reset(self):
"""Delete the database file.""" """Delete the database file."""
@ -137,20 +160,6 @@ class DatabaseBase(Database):
pass pass
class Dummy(Database):
def create(self):
pass
def contains(self):
return False
def add(self):
pass
def remove(self):
pass
class Downloads(DatabaseBase): class Downloads(DatabaseBase):
"""A table that stores the downloaded IDs.""" """A table that stores the downloaded IDs."""
@ -160,7 +169,7 @@ class Downloads(DatabaseBase):
} }
class FailedDownloads(DatabaseBase): class Failed(DatabaseBase):
"""A table that stores information about failed downloads.""" """A table that stores information about failed downloads."""
name = "failed_downloads" name = "failed_downloads"
@ -169,3 +178,21 @@ class FailedDownloads(DatabaseBase):
"media_type": ["text"], "media_type": ["text"],
"id": ["text", "unique"], "id": ["text", "unique"],
} }
@dataclass(slots=True)
class Database:
downloads: DatabaseInterface
failed: DatabaseInterface
def downloaded(self, item_id: str) -> bool:
return self.downloads.contains(id=item_id)
def set_downloaded(self, item_id: str):
self.downloads.add((item_id,))
def get_failed_downloads(self) -> list[tuple[str, str, str]]:
return self.failed.all()
def set_failed(self, source: str, media_type: str, id: str):
self.failed.add((source, media_type, id))

View file

@ -1,10 +1,10 @@
import asyncio
from dataclasses import dataclass from dataclasses import dataclass
from .album import PendingAlbum from .album import PendingAlbum
from .album_list import AlbumList from .album_list import AlbumList
from .client import Client from .client import Client
from .config import Config from .config import Config
from .db import Database
from .media import Pending from .media import Pending
from .metadata import LabelMetadata from .metadata import LabelMetadata
@ -18,12 +18,13 @@ class PendingLabel(Pending):
id: str id: str
client: Client client: Client
config: Config config: Config
db: Database
async def resolve(self) -> Label: async def resolve(self) -> Label:
resp = await self.client.get_metadata(self.id, "label") resp = await self.client.get_metadata(self.id, "label")
meta = LabelMetadata.from_resp(resp, self.client.source) meta = LabelMetadata.from_resp(resp, self.client.source)
albums = [ albums = [
PendingAlbum(album_id, self.client, self.config) PendingAlbum(album_id, self.client, self.config, self.db)
for album_id in meta.album_ids() for album_id in meta.album_ids()
] ]
return Label(meta.name, albums, self.client, self.config) return Label(meta.name, albums, self.client, self.config)

View file

@ -1,6 +1,7 @@
import asyncio import asyncio
import logging import logging
from . import db
from .artwork import remove_artwork_tempdirs from .artwork import remove_artwork_tempdirs
from .client import Client from .client import Client
from .config import Config from .config import Config
@ -26,9 +27,8 @@ class Main:
""" """
def __init__(self, config: Config): def __init__(self, config: Config):
# Pipeline: # Data pipeline:
# input URL -> (URL) -> (Pending) -> (Media) -> (Downloadable) # input URL -> (URL) -> (Pending) -> (Media) -> (Downloadable) -> audio file
# -> downloaded audio file
self.pending: list[Pending] = [] self.pending: list[Pending] = []
self.media: list[Media] = [] self.media: list[Media] = []
self.config = config self.config = config
@ -37,20 +37,55 @@ class Main:
# "tidal": TidalClient(config), # "tidal": TidalClient(config),
# "deezer": DeezerClient(config), # "deezer": DeezerClient(config),
"soundcloud": SoundcloudClient(config), "soundcloud": SoundcloudClient(config),
# "deezloader": DeezloaderClient(config),
} }
self.database: db.Database
c = self.config.session.database
if c.downloads_enabled:
downloads_db = db.Downloads(c.downloads_path)
else:
downloads_db = db.Dummy()
if c.failed_downloads_enabled:
failed_downloads_db = db.Failed(c.failed_downloads_path)
else:
failed_downloads_db = db.Dummy()
self.database = db.Database(downloads_db, failed_downloads_db)
async def add(self, url: str): async def add(self, url: str):
"""Add url as a pending item. Do not `asyncio.gather` calls to this!""" """Add url as a pending item.
Do not `asyncio.gather` calls to this! Use `add_all` for concurrency.
"""
parsed = parse_url(url) parsed = parse_url(url)
if parsed is None: if parsed is None:
raise Exception(f"Unable to parse url {url}") raise Exception(f"Unable to parse url {url}")
client = await self.get_logged_in_client(parsed.source) client = await self.get_logged_in_client(parsed.source)
self.pending.append(await parsed.into_pending(client, self.config)) self.pending.append(
await parsed.into_pending(client, self.config, self.database)
)
logger.debug("Added url=%s", url) logger.debug("Added url=%s", url)
async def add_all(self, urls: list[str]):
parsed = [parse_url(url) for url in urls]
url_w_client = [
(p, await self.get_logged_in_client(p.source))
for p in parsed
if p is not None
]
pendings = await asyncio.gather(
*[
url.into_pending(client, self.config, self.database)
for url, client in url_w_client
]
)
self.pending.extend(pendings)
async def get_logged_in_client(self, source: str): async def get_logged_in_client(self, source: str):
"""Return a functioning client instance for `source`."""
client = self.clients[source] client = self.clients[source]
if not client.logged_in: if not client.logged_in:
prompter = get_prompter(client, self.config) prompter = get_prompter(client, self.config)
@ -81,5 +116,9 @@ class Main:
if hasattr(client, "session"): if hasattr(client, "session"):
await client.session.close() await client.session.close()
# close global progress bar manager
clear_progress() clear_progress()
# We remove artwork tempdirs here because multiple singles
# may be able to share downloaded artwork in the same `rip` session
# We don't know that a cover will not be used again until end of execution
remove_artwork_tempdirs() remove_artwork_tempdirs()

View file

@ -7,6 +7,7 @@ from . import progress
from .artwork import download_artwork from .artwork import download_artwork
from .client import Client from .client import Client
from .config import Config from .config import Config
from .db import Database
from .filepath_utils import clean_filename from .filepath_utils import clean_filename
from .media import Media, Pending from .media import Media, Pending
from .metadata import AlbumMetadata, Covers, PlaylistMetadata, TrackMetadata from .metadata import AlbumMetadata, Covers, PlaylistMetadata, TrackMetadata
@ -23,8 +24,12 @@ class PendingPlaylistTrack(Pending):
folder: str folder: str
playlist_name: str playlist_name: str
position: int position: int
db: Database
async def resolve(self) -> Track | None: async def resolve(self) -> Track | None:
if self.db.downloaded(self.id):
logger.info(f"Track ({self.id}) already logged in database. Skipping.")
return None
resp = await self.client.get_metadata(self.id, "track") resp = await self.client.get_metadata(self.id, "track")
album = AlbumMetadata.from_track_resp(resp, self.client.source) album = AlbumMetadata.from_track_resp(resp, self.client.source)
@ -33,6 +38,7 @@ class PendingPlaylistTrack(Pending):
logger.error( logger.error(
f"Track ({self.id}) not available for stream on {self.client.source}" f"Track ({self.id}) not available for stream on {self.client.source}"
) )
self.db.set_failed(self.client.source, "track", self.id)
return None return None
c = self.config.session.metadata c = self.config.session.metadata
@ -46,7 +52,9 @@ class PendingPlaylistTrack(Pending):
self._download_cover(album.covers, self.folder), self._download_cover(album.covers, self.folder),
self.client.get_downloadable(self.id, quality), self.client.get_downloadable(self.id, quality),
) )
return Track(meta, downloadable, self.config, self.folder, embedded_cover_path) return Track(
meta, downloadable, self.config, self.folder, embedded_cover_path, self.db
)
async def _download_cover(self, covers: Covers, folder: str) -> str | None: async def _download_cover(self, covers: Covers, folder: str) -> str | None:
embed_path, _ = await download_artwork( embed_path, _ = await download_artwork(
@ -90,6 +98,7 @@ class PendingPlaylist(Pending):
id: str id: str
client: Client client: Client
config: Config config: Config
db: Database
async def resolve(self) -> Playlist | None: async def resolve(self) -> Playlist | None:
resp = await self.client.get_metadata(self.id, "playlist") resp = await self.client.get_metadata(self.id, "playlist")
@ -99,7 +108,7 @@ class PendingPlaylist(Pending):
folder = os.path.join(parent, clean_filename(name)) folder = os.path.join(parent, clean_filename(name))
tracks = [ tracks = [
PendingPlaylistTrack( PendingPlaylistTrack(
id, self.client, self.config, folder, name, position + 1 id, self.client, self.config, folder, name, position + 1, self.db
) )
for position, id in enumerate(meta.ids()) for position, id in enumerate(meta.ids())
] ]

View file

@ -14,9 +14,11 @@ class ProgressManager:
def __init__(self): def __init__(self):
self.started = False self.started = False
self.progress = Progress(console=console) self.progress = Progress(console=console)
self.prefix = Text.assemble(("Downloading ", "bold cyan"), overflow="ellipsis")
self.live = Live(Group(self.prefix, self.progress), refresh_per_second=10)
self.task_titles = [] self.task_titles = []
self.prefix = Text.assemble(("Downloading ", "bold cyan"), overflow="ellipsis")
self.live = Live(
Group(self.get_title_text(), self.progress), refresh_per_second=10
)
def get_callback(self, total: int, desc: str): def get_callback(self, total: int, desc: str):
if not self.started: if not self.started:

View file

@ -7,8 +7,8 @@ from . import converter
from .artwork import download_artwork from .artwork import download_artwork
from .client import Client from .client import Client
from .config import Config from .config import Config
from .db import Database
from .downloadable import Downloadable from .downloadable import Downloadable
from .exceptions import NonStreamable
from .filepath_utils import clean_filename from .filepath_utils import clean_filename
from .media import Media, Pending from .media import Media, Pending
from .metadata import AlbumMetadata, Covers, TrackMetadata from .metadata import AlbumMetadata, Covers, TrackMetadata
@ -27,6 +27,7 @@ class Track(Media):
folder: str folder: str
# Is None if a cover doesn't exist for the track # Is None if a cover doesn't exist for the track
cover_path: str | None cover_path: str | None
db: Database
# change? # change?
download_path: str = "" download_path: str = ""
@ -45,15 +46,11 @@ class Track(Media):
await self.downloadable.download(self.download_path, callback) await self.downloadable.download(self.download_path, callback)
async def postprocess(self): async def postprocess(self):
await self._tag() await tag_file(self.download_path, self.meta, self.cover_path)
if self.config.session.conversion.enabled: if self.config.session.conversion.enabled:
await self._convert() await self._convert()
# if self.cover_path is not None: self.db.set_downloaded(self.meta.info.id)
# os.remove(self.cover_path)
async def _tag(self):
await tag_file(self.download_path, self.meta, self.cover_path)
async def _convert(self): async def _convert(self):
c = self.config.session.conversion c = self.config.session.conversion
@ -88,22 +85,30 @@ class PendingTrack(Pending):
client: Client client: Client
config: Config config: Config
folder: str folder: str
db: Database
# cover_path is None <==> Artwork for this track doesn't exist in API # cover_path is None <==> Artwork for this track doesn't exist in API
cover_path: str | None cover_path: str | None
async def resolve(self) -> Track | None: async def resolve(self) -> Track | None:
resp = await self.client.get_metadata(self.id, "track") if self.db.downloaded(self.id):
meta = TrackMetadata.from_resp(self.album, self.client.source, resp) logger.info(
if meta is None: f"Skipping track {self.id}. Marked as downloaded in the database."
logger.error(
f"Track {self.id} not available for stream on {self.client.source}"
) )
return None return None
quality = getattr(self.config.session, self.client.source).quality resp = await self.client.get_metadata(self.id, "track")
assert isinstance(quality, int) source = self.client.source
meta = TrackMetadata.from_resp(self.album, source, resp)
if meta is None:
logger.error(f"Track {self.id} not available for stream on {source}")
self.db.set_failed(source, "track", self.id)
return None
quality = self.config.session.get_source(source).quality
downloadable = await self.client.get_downloadable(self.id, quality) downloadable = await self.client.get_downloadable(self.id, quality)
return Track(meta, downloadable, self.config, self.folder, self.cover_path) return Track(
meta, downloadable, self.config, self.folder, self.cover_path, self.db
)
@dataclass(slots=True) @dataclass(slots=True)
@ -117,6 +122,7 @@ class PendingSingle(Pending):
id: str id: str
client: Client client: Client
config: Config config: Config
db: Database
async def resolve(self) -> Track | None: async def resolve(self) -> Track | None:
resp = await self.client.get_metadata(self.id, "track") resp = await self.client.get_metadata(self.id, "track")
@ -126,6 +132,7 @@ class PendingSingle(Pending):
meta = TrackMetadata.from_resp(album, self.client.source, resp) meta = TrackMetadata.from_resp(album, self.client.source, resp)
if meta is None: if meta is None:
self.db.set_failed(self.client.source, "track", self.id)
logger.error(f"Cannot stream track ({self.id}) on {self.client.source}") logger.error(f"Cannot stream track ({self.id}) on {self.client.source}")
return None return None
@ -140,7 +147,9 @@ class PendingSingle(Pending):
self._download_cover(album.covers, folder), self._download_cover(album.covers, folder),
self.client.get_downloadable(self.id, quality), self.client.get_downloadable(self.id, quality),
) )
return Track(meta, downloadable, self.config, folder, embedded_cover_path) return Track(
meta, downloadable, self.config, folder, embedded_cover_path, self.db
)
def _format_folder(self, meta: AlbumMetadata) -> str: def _format_folder(self, meta: AlbumMetadata) -> str:
c = self.config.session c = self.config.session

View file

@ -7,6 +7,7 @@ from .album import PendingAlbum
from .artist import PendingArtist from .artist import PendingArtist
from .client import Client from .client import Client
from .config import Config from .config import Config
from .db import Database
from .label import PendingLabel from .label import PendingLabel
from .media import Pending from .media import Pending
from .playlist import PendingPlaylist from .playlist import PendingPlaylist
@ -23,7 +24,6 @@ from .validation_regexps import (
class URL(ABC): class URL(ABC):
match: re.Match
source: str source: str
def __init__(self, match: re.Match, source: str): def __init__(self, match: re.Match, source: str):
@ -35,7 +35,9 @@ class URL(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
async def into_pending(self, client: Client, config: Config) -> Pending: async def into_pending(
self, client: Client, config: Config, db: Database
) -> Pending:
raise NotImplementedError raise NotImplementedError
@ -48,20 +50,22 @@ class GenericURL(URL):
source = generic_url.group(1) source = generic_url.group(1)
return cls(generic_url, source) return cls(generic_url, source)
async def into_pending(self, client: Client, config: Config) -> Pending: async def into_pending(
self, client: Client, config: Config, db: Database
) -> Pending:
source, media_type, item_id = self.match.groups() source, media_type, item_id = self.match.groups()
assert client.source == source assert client.source == source
if media_type == "track": if media_type == "track":
return PendingSingle(item_id, client, config) return PendingSingle(item_id, client, config, db)
elif media_type == "album": elif media_type == "album":
return PendingAlbum(item_id, client, config) return PendingAlbum(item_id, client, config, db)
elif media_type == "playlist": elif media_type == "playlist":
return PendingPlaylist(item_id, client, config) return PendingPlaylist(item_id, client, config, db)
elif media_type == "artist": elif media_type == "artist":
return PendingArtist(item_id, client, config) return PendingArtist(item_id, client, config, db)
elif media_type == "label": elif media_type == "label":
return PendingLabel(item_id, client, config) return PendingLabel(item_id, client, config, db)
else: else:
raise NotImplementedError raise NotImplementedError
@ -76,10 +80,12 @@ class QobuzInterpreterURL(URL):
return None return None
return cls(qobuz_interpreter_url, "qobuz") return cls(qobuz_interpreter_url, "qobuz")
async def into_pending(self, client: Client, config: Config) -> Pending: async def into_pending(
self, client: Client, config: Config, db: Database
) -> Pending:
url = self.match.group(0) url = self.match.group(0)
artist_id = await self.extract_interpreter_url(url, client) artist_id = await self.extract_interpreter_url(url, client)
return PendingArtist(artist_id, client, config) return PendingArtist(artist_id, client, config, db)
@staticmethod @staticmethod
async def extract_interpreter_url(url: str, client: Client) -> str: async def extract_interpreter_url(url: str, client: Client) -> str:
@ -113,14 +119,16 @@ class SoundcloudURL(URL):
def __init__(self, url: str): def __init__(self, url: str):
self.url = url self.url = url
async def into_pending(self, client: SoundcloudClient, config: Config) -> Pending: async def into_pending(
self, client: SoundcloudClient, config: Config, db: Database
) -> Pending:
resolved = await client._resolve_url(self.url) resolved = await client._resolve_url(self.url)
media_type = resolved["kind"] media_type = resolved["kind"]
item_id = str(resolved["id"]) item_id = str(resolved["id"])
if media_type == "track": if media_type == "track":
return PendingSingle(item_id, client, config) return PendingSingle(item_id, client, config, db)
elif media_type == "playlist": elif media_type == "playlist":
return PendingPlaylist(item_id, client, config) return PendingPlaylist(item_id, client, config, db)
else: else:
raise NotImplementedError(media_type) raise NotImplementedError(media_type)

View file

@ -8,10 +8,7 @@ APP_DIR = user_config_dir(APPNAME)
HOME = Path.home() HOME = Path.home()
LOG_DIR = CACHE_DIR = CONFIG_DIR = APP_DIR LOG_DIR = CACHE_DIR = CONFIG_DIR = APP_DIR
DEFAULT_CONFIG_PATH = os.path.join(CONFIG_DIR, "config.toml")
CONFIG_PATH = os.path.join(CONFIG_DIR, "config.toml")
DB_PATH = os.path.join(LOG_DIR, "downloads.db")
FAILED_DB_PATH = os.path.join(LOG_DIR, "failed_downloads.db")
DOWNLOADS_DIR = os.path.join(HOME, "StreamripDownloads") DOWNLOADS_DIR = os.path.join(HOME, "StreamripDownloads")
# file shipped with script # file shipped with script
@ -20,6 +17,4 @@ BLANK_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "config.toml")
DEFAULT_DOWNLOADS_FOLDER = os.path.join(HOME, "StreamripDownloads") DEFAULT_DOWNLOADS_FOLDER = os.path.join(HOME, "StreamripDownloads")
DEFAULT_DOWNLOADS_DB_PATH = os.path.join(LOG_DIR, "downloads.db") DEFAULT_DOWNLOADS_DB_PATH = os.path.join(LOG_DIR, "downloads.db")
DEFAULT_FAILED_DOWNLOADS_DB_PATH = os.path.join(LOG_DIR, "failed_downloads.db") DEFAULT_FAILED_DOWNLOADS_DB_PATH = os.path.join(LOG_DIR, "failed_downloads.db")
DEFAULT_YOUTUBE_VIDEO_DOWNLOADS_FOLDER = os.path.join( DEFAULT_YOUTUBE_VIDEO_DOWNLOADS_FOLDER = os.path.join(DOWNLOADS_DIR, "YouTubeVideos")
HOME, "StreamripDownloads", "YouTubeVideos"
)