diff --git a/streamrip/client/client.py b/streamrip/client/client.py
index c25256c..80b4142 100644
--- a/streamrip/client/client.py
+++ b/streamrip/client/client.py
@@ -30,7 +30,7 @@ class Client(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    async def search(self, query: str, media_type: str, limit: int = 500) -> list[dict]:
+    async def search(self, media_type: str, query: str, limit: int = 500) -> list[dict]:
         raise NotImplementedError
 
     @abstractmethod
diff --git a/streamrip/client/deezer.py b/streamrip/client/deezer.py
index 0d185c1..c07a602 100644
--- a/streamrip/client/deezer.py
+++ b/streamrip/client/deezer.py
@@ -51,7 +51,7 @@ class DeezerClient(Client):
 
         return item
 
-    async def search(self, query: str, media_type: str, limit: int = 200):
+    async def search(self, media_type: str, query: str, limit: int = 200):
         # TODO: use limit parameter
         if media_type == "featured":
             try:
diff --git a/streamrip/client/qobuz.py b/streamrip/client/qobuz.py
index ef121d1..1316c49 100644
--- a/streamrip/client/qobuz.py
+++ b/streamrip/client/qobuz.py
@@ -226,7 +226,9 @@ class QobuzClient(Client):
         status, resp = await self._api_request(epoint, params)
 
         if status != 200:
-            raise Exception(f'Error fetching metadata. "{resp["message"]}"')
+            raise NonStreamable(
+                f'Error fetching metadata. Message: "{resp["message"]}"'
+            )
 
         return resp
 
diff --git a/streamrip/config.py b/streamrip/config.py
index aa958b5..13e513f 100644
--- a/streamrip/config.py
+++ b/streamrip/config.py
@@ -12,7 +12,7 @@ from tomlkit.toml_document import TOMLDocument
 
 logger = logging.getLogger("streamrip")
 
-APP_DIR = click.get_app_dir("streamrip", force_posix=True)
+APP_DIR = click.get_app_dir("streamrip")
 DEFAULT_CONFIG_PATH = os.path.join(APP_DIR, "config.toml")
 
 CURRENT_CONFIG_VERSION = "2.0"
@@ -206,6 +206,8 @@ class CliConfig:
     text_output: bool
     # Show resolve, download progress bars
     progress_bars: bool
+    # The maximum number of search results to show in the interactive menu
+    max_search_results: int
 
 
 @dataclass(slots=True)
diff --git a/streamrip/config.toml b/streamrip/config.toml
index aa0990e..841dade 100644
--- a/streamrip/config.toml
+++ b/streamrip/config.toml
@@ -81,11 +81,13 @@ download_videos = false
 # The path to download the videos to
 video_downloads_folder = ""
 
-# This stores a list of item IDs so that repeats are not downloaded.
 [database]
+# Create a database that contains all the track IDs downloaded so far
+# Any time a track logged in the database is requested, it is skipped
+# This can be disabled temporarily with the --no-db flag
 downloads_enabled = true
+# Path to the downloads database
 downloads_path = ""
-
 # If a download fails, the item ID is stored here. Then, `rip repair` can be
 # called to retry the downloads
 failed_downloads_enabled = true
@@ -171,7 +173,7 @@ truncate_to = 120
 source = "qobuz"
 # If no results were found with the primary source, the item is searched for
 # on this one.
-fallback_source = "deezer"
+fallback_source = ""
 
 [cli]
 # Print "Downloading {Album name}" etc. to screen
diff --git a/streamrip/media/__init__.py b/streamrip/media/__init__.py
index ab458dd..4c478f7 100644
--- a/streamrip/media/__init__.py
+++ b/streamrip/media/__init__.py
@@ -3,7 +3,12 @@ from .artist import Artist, PendingArtist
 from .artwork import remove_artwork_tempdirs
 from .label import Label, PendingLabel
 from .media import Media, Pending
-from .playlist import PendingPlaylist, PendingPlaylistTrack, Playlist
+from .playlist import (
+    PendingLastfmPlaylist,
+    PendingPlaylist,
+    PendingPlaylistTrack,
+    Playlist,
+)
 from .track import PendingSingle, PendingTrack, Track
 
 __all__ = [
@@ -17,6 +22,7 @@ __all__ = [
     "PendingLabel",
     "Playlist",
     "PendingPlaylist",
+    "PendingLastfmPlaylist",
     "Track",
     "PendingTrack",
     "PendingPlaylistTrack",
diff --git a/streamrip/media/playlist.py b/streamrip/media/playlist.py
index 83fe3fb..f901b5e 100644
--- a/streamrip/media/playlist.py
+++ b/streamrip/media/playlist.py
@@ -1,14 +1,28 @@
 import asyncio
+import html
 import logging
 import os
+import random
+import re
+from contextlib import ExitStack
 from dataclasses import dataclass
 
+import aiohttp
+from rich.text import Text
+
 from .. import progress
 from ..client import Client
 from ..config import Config
+from ..console import console
 from ..db import Database
 from ..filepath_utils import clean_filename
-from ..metadata import AlbumMetadata, Covers, PlaylistMetadata, TrackMetadata
+from ..metadata import (
+    AlbumMetadata,
+    Covers,
+    PlaylistMetadata,
+    SearchResults,
+    TrackMetadata,
+)
 from .artwork import download_artwork
 from .media import Media, Pending
 from .track import Track
@@ -75,22 +89,32 @@ class Playlist(Media):
     tracks: list[PendingPlaylistTrack]
 
     async def preprocess(self):
-        pass
-
-    async def download(self):
         progress.add_title(self.name)
 
-        async def _resolve_and_download(pending: PendingPlaylistTrack):
-            track = await pending.resolve()
+    async def postprocess(self):
+        progress.remove_title(self.name)
+
+    async def download(self):
+        track_resolve_chunk_size = 20
+
+        async def _resolve_download(item: PendingPlaylistTrack):
+            track = await item.resolve()
             if track is None:
                 return
             await track.rip()
 
-        await asyncio.gather(*[_resolve_and_download(p) for p in self.tracks])
-        progress.remove_title(self.name)
+        batches = self.batch(
+            [_resolve_download(track) for track in self.tracks],
+            track_resolve_chunk_size,
+        )
+        for batch in batches:
+            await asyncio.gather(*batch)
 
-    async def postprocess(self):
-        pass
+    @staticmethod
+    def batch(iterable, n=1):
+        l = len(iterable)
+        for ndx in range(0, l, n):
+            yield iterable[ndx : min(ndx + n, l)]
 
 
 @dataclass(slots=True)
@@ -113,3 +137,199 @@ class PendingPlaylist(Pending):
                 for position, id in enumerate(meta.ids())
             ]
         return Playlist(name, self.config, self.client, tracks)
+
+
+@dataclass(slots=True)
+class PendingLastfmPlaylist(Pending):
+    lastfm_url: str
+    client: Client
+    fallback_client: Client | None
+    config: Config
+    db: Database
+
+    @dataclass(slots=True)
+    class Status:
+        found: int
+        failed: int
+        total: int
+
+        def text(self) -> Text:
+            return Text.assemble(
+                "Searching for last.fm tracks (",
+                (f"{self.found} found", "bold green"),
+                ", ",
+                (f"{self.failed} failed", "bold red"),
+                ", ",
+                (f"{self.total} total", "bold"),
+                ")",
+            )
+
+    async def resolve(self) -> Playlist | None:
+        try:
+            playlist_title, titles_artists = await self._parse_lastfm_playlist(
+                self.lastfm_url
+            )
+        except Exception as e:
+            logger.error("Error occurred while parsing last.fm page: %s", e)
+            return None
+
+        requests = []
+
+        s = self.Status(0, 0, len(titles_artists))
+        if self.config.session.cli.progress_bars:
+            with console.status(s.text(), spinner="moon") as status:
+                callback = lambda: status.update(s.text())
+                for title, artist in titles_artists:
+                    requests.append(self._make_query(f"{title} {artist}", s, callback))
+                results: list[tuple[str | None, bool]] = await asyncio.gather(*requests)
+        else:
+            callback = lambda: None
+            for title, artist in titles_artists:
+                requests.append(self._make_query(f"{title} {artist}", s, callback))
+            results: list[tuple[str | None, bool]] = await asyncio.gather(*requests)
+
+        parent = self.config.session.downloads.folder
+        folder = os.path.join(parent, clean_filename(playlist_title))
+
+        pending_tracks = []
+        for pos, (id, from_fallback) in enumerate(results, start=1):
+            if id is None:
+                logger.warning(f"No results found for {titles_artists[pos-1]}")
+                continue
+
+            if from_fallback:
+                assert self.fallback_client is not None
+                client = self.fallback_client
+            else:
+                client = self.client
+
+            pending_tracks.append(
+                PendingPlaylistTrack(
+                    id,
+                    client,
+                    self.config,
+                    folder,
+                    playlist_title,
+                    pos,
+                    self.db,
+                )
+            )
+
+        return Playlist(playlist_title, self.config, self.client, pending_tracks)
+
+    async def _make_query(
+        self, query: str, s: Status, callback
+    ) -> tuple[str | None, bool]:
+        """Try searching for `query` with main source. If that fails, try with next source.
+
+        If both fail, return None.
+        """
+        with ExitStack() as stack:
+            # ensure `callback` is always called
+            stack.callback(callback)
+            pages = await self.client.search("track", query, limit=1)
+            if len(pages) > 0:
+                logger.debug(f"Found result for {query} on {self.client.source}")
+                s.found += 1
+                return (
+                    SearchResults.from_pages(self.client.source, "track", pages)
+                    .results[0]
+                    .id
+                ), False
+
+            if self.fallback_client is None:
+                logger.debug(f"No result found for {query} on {self.client.source}")
+                s.failed += 1
+                return None, False
+
+            pages = await self.fallback_client.search("track", query, limit=1)
+            if len(pages) > 0:
+                logger.debug(f"Found result for {query} on {self.client.source}")
+                s.found += 1
+                return (
+                    SearchResults.from_pages(
+                        self.fallback_client.source, "track", pages
+                    )
+                    .results[0]
+                    .id
+                ), True
+
+            logger.debug(f"No result found for {query} on {self.client.source}")
+            s.failed += 1
+            return None, True
+
+    async def _parse_lastfm_playlist(
+        self, playlist_url: str
+    ) -> tuple[str, list[tuple[str, str]]]:
+        """From a last.fm url, return the playlist title, and a list of
+        track titles and artist names.
+
+        Each page contains 50 results, so `num_tracks // 50 + 1` requests
+        are sent per playlist.
+
+        :param url:
+        :type url: str
+        :rtype: tuple[str, list[tuple[str, str]]]
+        """
+        logger.debug("Fetching lastfm playlist")
+
+        title_tags = re.compile(r'<a\s+href="[^"]+"\s+title="([^"]+)"')
+        re_total_tracks = re.compile(r'data-playlisting-entry-count="(\d+)"')
+        re_playlist_title_match = re.compile(
+            r'<h1 class="playlisting-playlist-header-title">([^<]+)</h1>'
+        )
+
+        def find_title_artist_pairs(page_text):
+            info: list[tuple[str, str]] = []
+            titles = title_tags.findall(page_text)  # [2:]
+            for i in range(0, len(titles) - 1, 2):
+                info.append((html.unescape(titles[i]), html.unescape(titles[i + 1])))
+            return info
+
+        async def fetch(session: aiohttp.ClientSession, url, **kwargs):
+            async with session.get(url, **kwargs) as resp:
+                return await resp.text("utf-8")
+
+        # Create new session so we're not bound by rate limit
+        async with aiohttp.ClientSession() as session:
+            page = await fetch(session, playlist_url)
+            playlist_title_match = re_playlist_title_match.search(page)
+            if playlist_title_match is None:
+                raise Exception("Error finding title from response")
+
+            playlist_title: str = html.unescape(playlist_title_match.group(1))
+
+            title_artist_pairs: list[tuple[str, str]] = find_title_artist_pairs(page)
+
+            total_tracks_match = re_total_tracks.search(page)
+            if total_tracks_match is None:
+                raise Exception("Error parsing lastfm page: %s", page)
+            total_tracks = int(total_tracks_match.group(1))
+
+            remaining_tracks = total_tracks - 50  # already got 50 from 1st page
+            if remaining_tracks <= 0:
+                return playlist_title, title_artist_pairs
+
+            last_page = (
+                1 + int(remaining_tracks // 50) + int(remaining_tracks % 50 != 0)
+            )
+            requests = []
+            for page in range(2, last_page + 1):
+                requests.append(fetch(session, playlist_url, params={"page": page}))
+            results = await asyncio.gather(*requests)
+
+            for page in results:
+                title_artist_pairs.extend(find_title_artist_pairs(page))
+
+        return playlist_title, title_artist_pairs
+
+    async def _make_query_mock(
+        self, _: str, s: Status, callback
+    ) -> tuple[str | None, bool]:
+        await asyncio.sleep(random.uniform(1, 20))
+        if random.randint(0, 4) >= 1:
+            s.found += 1
+        else:
+            s.failed += 1
+        callback()
+        return None, False
diff --git a/streamrip/media/semaphore.py b/streamrip/media/semaphore.py
index 6c59789..2aa2e88 100644
--- a/streamrip/media/semaphore.py
+++ b/streamrip/media/semaphore.py
@@ -1,27 +1,16 @@
 import asyncio
+from contextlib import nullcontext
 
 from ..config import DownloadsConfig
 
 INF = 9999
 
 
-class UnlimitedSemaphore:
-    """Can be swapped out for a real semaphore when no semaphore is needed."""
-
-    async def __aenter__(self):
-        return self
-
-    async def __aexit__(self, *_):
-        pass
-
-
-_unlimited = UnlimitedSemaphore()
+_unlimited = nullcontext()
 _global_semaphore: None | tuple[int, asyncio.Semaphore] = None
 
 
-def global_download_semaphore(
-    c: DownloadsConfig,
-) -> UnlimitedSemaphore | asyncio.Semaphore:
+def global_download_semaphore(c: DownloadsConfig) -> asyncio.Semaphore | nullcontext:
     """A global semaphore that limit the number of total tracks
     being downloaded at once.
 
diff --git a/streamrip/media/track.py b/streamrip/media/track.py
index b69766d..a471663 100644
--- a/streamrip/media/track.py
+++ b/streamrip/media/track.py
@@ -7,6 +7,7 @@ from .. import converter
 from ..client import Client, Downloadable
 from ..config import Config
 from ..db import Database
+from ..exceptions import NonStreamable
 from ..filepath_utils import clean_filename
 from ..metadata import AlbumMetadata, Covers, TrackMetadata, tag_file
 from ..progress import add_title, get_progress_callback, remove_title
@@ -129,7 +130,11 @@ class PendingSingle(Pending):
     db: Database
 
     async def resolve(self) -> Track | None:
-        resp = await self.client.get_metadata(self.id, "track")
+        try:
+            resp = await self.client.get_metadata(self.id, "track")
+        except NonStreamable as e:
+            logger.error(f"Error fetching track {self.id}: {e}")
+            return None
         # Patch for soundcloud
         # self.id = resp["id"]
         album = AlbumMetadata.from_track_resp(resp, self.client.source)
diff --git a/streamrip/metadata/search_results.py b/streamrip/metadata/search_results.py
index d2b534a..34dc2b6 100644
--- a/streamrip/metadata/search_results.py
+++ b/streamrip/metadata/search_results.py
@@ -3,7 +3,6 @@ import re
 import textwrap
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from pprint import pprint
 
 
 class Summary(ABC):
diff --git a/streamrip/rip/cli.py b/streamrip/rip/cli.py
index 3db5305..d09b886 100644
--- a/streamrip/rip/cli.py
+++ b/streamrip/rip/cli.py
@@ -11,6 +11,7 @@ from rich.logging import RichHandler
 from rich.prompt import Confirm
 from rich.traceback import install
 
+from .. import db
 from ..config import DEFAULT_CONFIG_PATH, Config, set_user_defaults
 from ..console import console
 from .main import Main
@@ -85,8 +86,18 @@ def rip(ctx, config_path, folder, no_db, quality, convert, no_progress, verbose)
 
     # pass to subcommands
     ctx.ensure_object(dict)
+    ctx.obj["config_path"] = config_path
+
+    try:
+        c = Config(config_path)
+    except Exception as e:
+        console.print(
+            f"Error loading config from [bold cyan]{config_path}[/bold cyan]: {e}\n"
+            "Try running [bold]rip config reset[/bold]"
+        )
+        ctx.obj["config"] = None
+        return
 
-    c = Config(config_path)
     # set session config values to command line args
     c.session.database.downloads_enabled = not no_db
     if folder is not None:
@@ -144,7 +155,6 @@ async def file(ctx, path):
 @rip.group()
 def config():
     """Manage configuration files."""
-    pass
 
 
 @config.command("open")
@@ -153,7 +163,8 @@ def config():
 def config_open(ctx, vim):
     """Open the config file in a text editor."""
     config_path = ctx.obj["config"].path
-    console.log(f"Opening file at [bold cyan]{config_path}")
+
+    console.print(f"Opening file at [bold cyan]{config_path}")
     if vim:
         if shutil.which("nvim") is not None:
             subprocess.run(["nvim", config_path])
@@ -168,7 +179,7 @@ def config_open(ctx, vim):
 @click.pass_context
 def config_reset(ctx, yes):
     """Reset the config file."""
-    config_path = ctx.obj["config"].path
+    config_path = ctx.obj["config_path"]
     if not yes:
         if not Confirm.ask(
             f"Are you sure you want to reset the config file at {config_path}?"
@@ -180,6 +191,61 @@ def config_reset(ctx, yes):
     console.print(f"Reset the config file at [bold cyan]{config_path}!")
 
 
+@config.command("path")
+@click.pass_context
+def config_path(ctx):
+    """Display the path of the config file."""
+    config_path = ctx.obj["config_path"]
+    console.print(f"Config path: [bold cyan]'{config_path}'")
+
+
+@rip.group()
+def database():
+    """View and modify the downloads and failed downloads databases."""
+
+
+@database.command("browse")
+@click.argument("table")
+@click.pass_context
+def database_browse(ctx, table):
+    """Browse the contents of a table.
+
+    Available tables:
+
+        * Downloads
+
+        * Failed
+    """
+    from rich.table import Table
+
+    cfg: Config = ctx.obj["config"]
+
+    if table.lower() == "downloads":
+        downloads = db.Downloads(cfg.session.database.downloads_path)
+        t = Table(title="Downloads database")
+        t.add_column("Row")
+        t.add_column("ID")
+        for i, row in enumerate(downloads.all()):
+            t.add_row(f"{i:02}", *row)
+        console.print(t)
+
+    elif table.lower() == "failed":
+        failed = db.Failed(cfg.session.database.failed_downloads_path)
+        t = Table(title="Failed downloads database")
+        t.add_column("Source")
+        t.add_column("Media Type")
+        t.add_column("ID")
+        for i, row in enumerate(failed.all()):
+            t.add_row(f"{i:02}", *row)
+        console.print(t)
+
+    else:
+        console.print(
+            f"[red]Invalid database[/red] [bold]{table}[/bold]. [red]Choose[/red] [bold]downloads "
+            "[red]or[/red] failed[/bold]."
+        )
+
+
 @rip.command()
 @click.option(
     "-f",
     "--first",
     help="Automatically download the first search result without showing the menu.",
     is_flag=True,
 )
@@ -211,10 +277,42 @@ async def search(ctx, first, source, media_type, query):
 
 
 @rip.command()
+@click.option("-s", "--source", help="The source to search tracks on.")
+@click.option(
+    "-fs",
+    "--fallback-source",
+    help="The source to search tracks on if no results were found with the main source.",
+)
 @click.argument("url", required=True)
-def lastfm(url):
+@click.pass_context
+@coro
+async def lastfm(ctx, source, fallback_source, url):
     """Download tracks from a last.fm playlist using a supported source."""
-    raise NotImplementedError
+
+    config = ctx.obj["config"]
+    if source is not None:
+        config.session.lastfm.source = source
+    if fallback_source is not None:
+        config.session.lastfm.fallback_source = fallback_source
+    with config as cfg:
+        async with Main(cfg) as main:
+            await main.resolve_lastfm(url)
+            await main.rip()
+
+
+@rip.command()
+@click.argument("source")
+@click.argument("media-type")
+@click.argument("id")
+@click.pass_context
+@coro
+async def id(ctx, source, media_type, id):
+    """Download an item by ID."""
+    with ctx.obj["config"] as cfg:
+        async with Main(cfg) as main:
+            await main.add_by_id(source, media_type, id)
+            await main.resolve()
+            await main.rip()
 
 
 if __name__ == "__main__":
diff --git a/streamrip/rip/main.py b/streamrip/rip/main.py
index c3bc927..004af37 100644
--- a/streamrip/rip/main.py
+++ b/streamrip/rip/main.py
@@ -6,7 +6,7 @@ from .. import db
 from ..client import Client, QobuzClient, SoundcloudClient
 from ..config import Config
 from ..console import console
-from ..media import Media, Pending, remove_artwork_tempdirs
+from ..media import Media, Pending, PendingLastfmPlaylist, remove_artwork_tempdirs
 from ..metadata import SearchResults
 from ..progress import clear_progress
 from .parse_url import parse_url
@@ -71,26 +71,30 @@ class Main:
 
     async def add_all(self, urls: list[str]):
         """Add multiple urls concurrently as pending items."""
         parsed = [parse_url(url) for url in urls]
-        url_w_client = []
+        url_client_pairs = []
         for i, p in enumerate(parsed):
             if p is None:
                 console.print(
                     f"[red]Found invalid url [cyan]{urls[i]}[/cyan], skipping."
                 )
                 continue
-            url_w_client.append((p, await self.get_logged_in_client(p.source)))
+            url_client_pairs.append((p, await self.get_logged_in_client(p.source)))
         pendings = await asyncio.gather(
             *[
                 url.into_pending(client, self.config, self.database)
-                for url, client in url_w_client
+                for url, client in url_client_pairs
             ]
         )
         self.pending.extend(pendings)
 
     async def get_logged_in_client(self, source: str):
         """Return a functioning client instance for `source`."""
-        client = self.clients[source]
+        client = self.clients.get(source)
+        if client is None:
+            raise Exception(
+                f"No client named {source} available. Only have {self.clients.keys()}"
+            )
         if not client.logged_in:
             prompter = get_prompter(client, self.config)
             if not prompter.has_creds():
@@ -110,7 +114,9 @@ class Main:
         """Resolve all currently pending items."""
         with console.status("Resolving URLs...", spinner="dots"):
             coros = [p.resolve() for p in self.pending]
-            new_media: list[Media] = await asyncio.gather(*coros)
+            new_media: list[Media] = [
+                m for m in await asyncio.gather(*coros) if m is not None
+            ]
 
         self.media.extend(new_media)
         self.pending.clear()
@@ -129,7 +135,7 @@ class Main:
             return
 
         search_results = SearchResults.from_pages(source, media_type, pages)
-        if os.name == "nt" or True:
+        if os.name == "nt":
             from pick import pick
 
             choices = pick(
@@ -186,6 +192,24 @@ class Main:
         first = search_results.results[0]
         await self.add(f"http://{source}.com/{first.media_type()}/{first.id}")
 
+    async def resolve_lastfm(self, playlist_url: str):
+        """Resolve a last.fm playlist."""
+        c = self.config.session.lastfm
+        client = await self.get_logged_in_client(c.source)
+
+        if len(c.fallback_source) > 0:
+            fallback_client = await self.get_logged_in_client(c.fallback_source)
+        else:
+            fallback_client = None
+
+        pending_playlist = PendingLastfmPlaylist(
+            playlist_url, client, fallback_client, self.config, self.database
+        )
+        playlist = await pending_playlist.resolve()
+
+        if playlist is not None:
+            self.media.append(playlist)
+
     async def __aenter__(self):
         return self
 
@@ -201,3 +225,6 @@ class Main:
         # may be able to share downloaded artwork in the same `rip` session
         # We don't know that a cover will not be used again until end of execution
         remove_artwork_tempdirs()
+
+    async def add_by_id(self, source: str, media_type: str, id: str):
+        await self.add(f"http://{source}.com/{media_type}/{id}")
diff --git a/streamrip/rip/prompter.py b/streamrip/rip/prompter.py
index a4b992b..b5b9f1f 100644
--- a/streamrip/rip/prompter.py
+++ b/streamrip/rip/prompter.py
@@ -1,14 +1,18 @@
 import hashlib
+import logging
 import time
 from abc import ABC, abstractmethod
-from getpass import getpass
 
-from click import launch, secho, style
+from click import launch
+from rich.prompt import Prompt
 
 from ..client import Client, DeezerClient, QobuzClient, SoundcloudClient, TidalClient
 from ..config import Config
+from ..console import console
 from ..exceptions import AuthenticationError, MissingCredentials
 
+logger = logging.getLogger("streamrip")
+
 
 class CredentialPrompter(ABC):
     client: Client
@@ -53,19 +57,18 @@ class QobuzPrompter(CredentialPrompter):
                 await self.client.login()
                 break
             except AuthenticationError:
-                secho("Invalid credentials, try again.", fg="yellow")
+                console.print("[yellow]Invalid credentials, try again.")
                 self._prompt_creds_and_set_session_config()
             except MissingCredentials:
                 self._prompt_creds_and_set_session_config()
 
     def _prompt_creds_and_set_session_config(self):
-        secho("Enter Qobuz email: ", fg="green", nl=False)
-        email = input()
-        secho("Enter Qobuz password (will not show on screen): ", fg="green", nl=False)
-        pwd = hashlib.md5(getpass(prompt="").encode("utf-8")).hexdigest()
-        secho(
-            f'Credentials saved to config file at "{self.config.path}"',
-            fg="green",
+        email = Prompt.ask("Enter your Qobuz email")
+        pwd_input = Prompt.ask("Enter your Qobuz password (invisible)", password=True)
+
+        pwd = hashlib.md5(pwd_input.encode("utf-8")).hexdigest()
+        console.print(
+            f"[green]Credentials saved to config file at [bold cyan]{self.config.path}"
         )
         c = self.config.session.qobuz
         c.use_auth_token = False
@@ -96,9 +99,8 @@ class TidalPrompter(CredentialPrompter):
         device_code = await self.client._get_device_code()
         login_link = f"https://{device_code}"
 
-        secho(
-            f"Go to {login_link} to log into Tidal within 5 minutes.",
-            fg="blue",
+        console.print(
+            f"Go to [blue underline]{login_link}[/blue underline] to log into Tidal within 5 minutes.",
         )
         launch(login_link)
 
@@ -158,33 +160,25 @@ class DeezerPrompter(CredentialPrompter):
                 await self.client.login()
                 break
             except AuthenticationError:
-                secho("Invalid arl, try again.", fg="yellow")
+                console.print("[yellow]Invalid arl, try again.")
                 self._prompt_creds_and_set_session_config()
         self.save()
 
     def _prompt_creds_and_set_session_config(self):
-        secho(
+        console.print(
             "If you're not sure how to find the ARL cookie, see the instructions at ",
-            nl=False,
-            dim=True,
+            "[blue underline]https://github.com/nathom/streamrip/wiki/Finding-your-Deezer-ARL-Cookie",
         )
-        secho(
-            "https://github.com/nathom/streamrip/wiki/Finding-your-Deezer-ARL-Cookie",
-            underline=True,
-            fg="blue",
-        )
-
         c = self.config.session.deezer
-        c.arl = input(style("ARL: ", fg="green"))
+        c.arl = Prompt.ask("Enter your [bold]ARL")
 
     def save(self):
         c = self.config.session.deezer
         cf = self.config.file.deezer
         cf.arl = c.arl
         self.config.file.set_modified()
-        secho(
-            f'Credentials saved to config file at "{self.config.path}"',
-            fg="green",
+        console.print(
+            f"[green]Credentials saved to config file at [bold cyan]{self.config.path}",
         )
 
     def type_check_client(self, client) -> DeezerClient: