Add dosctrings

This commit is contained in:
nathom 2021-07-29 11:20:49 -07:00
parent 8d0dc7fb7f
commit e73bff8d6b
12 changed files with 295 additions and 106 deletions

View file

@ -22,8 +22,11 @@ ignore_missing_imports = True
[mypy-tomlkit.*] [mypy-tomlkit.*]
ignore_missing_imports = True ignore_missing_imports = True
[mypy-Crypto.*] [mypy-Cryptodome.*]
ignore_missing_imports = True ignore_missing_imports = True
[mypy-click.*] [mypy-click.*]
ignore_missing_imports = True ignore_missing_imports = True
[mypy-PIL.*]
ignore_missing_imports = True

View file

@ -0,0 +1 @@
"""Rip: an easy to use command line utility for downloading audio streams."""

View file

@ -1,3 +1,4 @@
"""Run the rip program."""
from .cli import main from .cli import main
main() main()

View file

@ -1,3 +1,5 @@
"""Various constant values that are used by RipCore."""
import os import os
import re import re
from pathlib import Path from pathlib import Path

View file

@ -240,6 +240,10 @@ class RipCore(list):
} }
def repair(self, max_items=None): def repair(self, max_items=None):
"""Iterate through the failed_downloads database and retry them.
:param max_items: The maximum number of items to download.
"""
if max_items is None: if max_items is None:
max_items = float("inf") max_items = float("inf")
@ -331,6 +335,11 @@ class RipCore(list):
item.convert(**arguments["conversion"]) item.convert(**arguments["conversion"])
def scrape(self, featured_list: str): def scrape(self, featured_list: str):
"""Download all of the items in a Qobuz featured list.
:param featured_list: The name of the list. See `rip discover --help`.
:type featured_list: str
"""
self.extend(self.search("qobuz", featured_list, "featured", limit=500)) self.extend(self.search("qobuz", featured_list, "featured", limit=500))
def get_client(self, source: str) -> Client: def get_client(self, source: str) -> Client:

View file

@ -15,6 +15,11 @@ class Database:
name: str name: str
def __init__(self, path, dummy=False): def __init__(self, path, dummy=False):
"""Create a Database instance.
:param path: Path to the database file.
:param dummy: Make the database empty.
"""
assert self.structure != [] assert self.structure != []
assert self.name assert self.name
@ -72,7 +77,15 @@ class Database:
return bool(conn.execute(command, tuple(items.values())).fetchone()[0]) return bool(conn.execute(command, tuple(items.values())).fetchone()[0])
def __contains__(self, keys: dict) -> bool: def __contains__(self, keys: Union[str, dict]) -> bool:
"""Check whether a key-value pair exists in the database.
:param keys: Either a dict with the structure {key: value_to_search_for, ...},
or if there is only one key in the table, value_to_search_for can be
passed in by itself.
:type keys: Union[str, dict]
:rtype: bool
"""
if isinstance(keys, dict): if isinstance(keys, dict):
return self.contains(**keys) return self.contains(**keys)
@ -119,6 +132,12 @@ class Database:
logger.debug(e) logger.debug(e)
def remove(self, **items): def remove(self, **items):
"""Remove items from a table.
Warning: NOT TESTED!
:param items:
"""
# not in use currently # not in use currently
if self.is_dummy: if self.is_dummy:
return return
@ -131,6 +150,7 @@ class Database:
conn.execute(command, tuple(items.values())) conn.execute(command, tuple(items.values()))
def __iter__(self): def __iter__(self):
"""Iterate through the rows of the table."""
if self.is_dummy: if self.is_dummy:
return () return ()
@ -138,6 +158,7 @@ class Database:
return conn.execute(f"SELECT * FROM {self.name}") return conn.execute(f"SELECT * FROM {self.name}")
def reset(self): def reset(self):
"""Delete the database file."""
try: try:
os.remove(self.path) os.remove(self.path)
except FileNotFoundError: except FileNotFoundError:
@ -145,6 +166,8 @@ class Database:
class Downloads(Database): class Downloads(Database):
"""A table that stores the downloaded IDs."""
name = "downloads" name = "downloads"
structure = { structure = {
"id": ["text", "unique"], "id": ["text", "unique"],
@ -152,6 +175,8 @@ class Downloads(Database):
class FailedDownloads(Database): class FailedDownloads(Database):
"""A table that stores information about failed downloads."""
name = "failed_downloads" name = "failed_downloads"
structure = { structure = {
"source": ["text"], "source": ["text"],

View file

@ -1,2 +1,5 @@
"""Exceptions used by RipCore."""
class DeezloaderFallback(Exception): class DeezloaderFallback(Exception):
pass """Raise if Deezer account isn't logged in and rip is falling back to Deezloader."""

View file

@ -1,3 +1,5 @@
"""Utility functions for RipCore."""
import re import re
from typing import Tuple from typing import Tuple

View file

@ -472,6 +472,10 @@ class DeezerClient(Client):
return response return response
def login(self, **kwargs): def login(self, **kwargs):
"""Log into Deezer.
:param kwargs:
"""
try: try:
arl = kwargs["arl"] arl = kwargs["arl"]
except KeyError: except KeyError:
@ -491,7 +495,6 @@ class DeezerClient(Client):
:param type_: :param type_:
:type type_: str :type type_: str
""" """
GET_FUNCTIONS = { GET_FUNCTIONS = {
"track": self.client.api.get_track, "track": self.client.api.get_track,
"album": self.client.api.get_album, "album": self.client.api.get_album,
@ -630,11 +633,13 @@ class DeezerClient(Client):
class DeezloaderClient(Client): class DeezloaderClient(Client):
"""DeezloaderClient."""
source = "deezer" source = "deezer"
max_quality = 1 max_quality = 1
def __init__(self): def __init__(self):
"""Create a DeezloaderClient."""
self.session = gen_threadsafe_session() self.session = gen_threadsafe_session()
# no login required # no login required
@ -1037,7 +1042,7 @@ class TidalClient(Client):
params["countryCode"] = self.country_code params["countryCode"] = self.country_code
params["limit"] = 100 params["limit"] = 100
r = self.session.get(f"{TIDAL_BASE}/{path}", params=params) r = self.session.get(f"{TIDAL_BASE}/{path}", params=params)
r.raise_for_status() # r.raise_for_status()
return r.json() return r.json()
def _get_video_stream_url(self, video_id: str) -> str: def _get_video_stream_url(self, video_id: str) -> str:

View file

@ -1,41 +1,64 @@
"""Streamrip specific exceptions."""
from typing import List from typing import List
import click import click
class AuthenticationError(Exception): class AuthenticationError(Exception):
pass """AuthenticationError."""
class MissingCredentials(Exception): class MissingCredentials(Exception):
pass """MissingCredentials."""
class IneligibleError(Exception): class IneligibleError(Exception):
pass """IneligibleError.
Raised when the account is not eligible to stream a track.
"""
class InvalidAppIdError(Exception): class InvalidAppIdError(Exception):
pass """InvalidAppIdError."""
class InvalidAppSecretError(Exception): class InvalidAppSecretError(Exception):
pass """InvalidAppSecretError."""
class InvalidQuality(Exception): class InvalidQuality(Exception):
pass """InvalidQuality."""
class NonStreamable(Exception): class NonStreamable(Exception):
"""Item is not streamable.
A versatile error that can have many causes.
"""
def __init__(self, message=None): def __init__(self, message=None):
"""Create a NonStreamable exception.
:param message:
"""
self.message = message self.message = message
super().__init__(self.message) super().__init__(self.message)
def print(self, item): def print(self, item):
"""Print a readable version of the exception.
:param item:
"""
click.echo(self.print_msg(item)) click.echo(self.print_msg(item))
def print_msg(self, item) -> str: def print_msg(self, item) -> str:
"""Return a generic readable message.
:param item:
:type item: Media
:rtype: str
"""
base_msg = [click.style(f"Unable to stream {item!s}.", fg="yellow")] base_msg = [click.style(f"Unable to stream {item!s}.", fg="yellow")]
if self.message: if self.message:
base_msg.extend( base_msg.extend(
@ -49,38 +72,45 @@ class NonStreamable(Exception):
class InvalidContainerError(Exception): class InvalidContainerError(Exception):
pass """InvalidContainerError."""
class InvalidSourceError(Exception): class InvalidSourceError(Exception):
pass """InvalidSourceError."""
class ParsingError(Exception): class ParsingError(Exception):
pass """ParsingError."""
class TooLargeCoverArt(Exception): class TooLargeCoverArt(Exception):
pass """TooLargeCoverArt."""
class BadEncoderOption(Exception): class BadEncoderOption(Exception):
pass """BadEncoderOption."""
class ConversionError(Exception): class ConversionError(Exception):
pass """ConversionError."""
class NoResultsFound(Exception): class NoResultsFound(Exception):
pass """NoResultsFound."""
class ItemExists(Exception): class ItemExists(Exception):
pass """ItemExists."""
class PartialFailure(Exception): class PartialFailure(Exception):
"""Raise if part of a tracklist fails to download."""
def __init__(self, failed_items: List): def __init__(self, failed_items: List):
"""Create a PartialFailure exception.
:param failed_items:
:type failed_items: List
"""
self.failed_items = failed_items self.failed_items = failed_items
super().__init__() super().__init__()

View file

@ -63,38 +63,60 @@ TYPE_REGEXES = {
class Media(abc.ABC): class Media(abc.ABC):
"""An interface for a downloadable item."""
@abc.abstractmethod @abc.abstractmethod
def download(self, **kwargs): def download(self, **kwargs):
"""Download the item.
:param kwargs:
"""
pass pass
@abc.abstractmethod @abc.abstractmethod
def load_meta(self, **kwargs): def load_meta(self, **kwargs):
"""Load all of the metadata for an item.
:param kwargs:
"""
pass pass
@abc.abstractmethod @abc.abstractmethod
def tag(self, **kwargs): def tag(self, **kwargs):
"""Tag this item with metadata, if applicable.
:param kwargs:
"""
pass pass
@abc.abstractmethod @abc.abstractmethod
def convert(self, **kwargs): def convert(self, **kwargs):
"""Convert this item between file formats.
:param kwargs:
"""
pass pass
@abc.abstractmethod @abc.abstractmethod
def __repr__(self): def __repr__(self):
"""Return a string representation of the item."""
pass pass
@abc.abstractmethod @abc.abstractmethod
def __str__(self): def __str__(self):
"""Get a readable representation of the item."""
pass pass
@property @property
@abc.abstractmethod @abc.abstractmethod
def type(self): def type(self):
"""Return the type of the item."""
pass pass
@property @property
@abc.abstractmethod @abc.abstractmethod
def downloaded_ids(self): def downloaded_ids(self):
"""If the item is a collection, this is a set of downloaded IDs."""
pass pass
@downloaded_ids.setter @downloaded_ids.setter
@ -268,8 +290,8 @@ class Track(Media):
try: try:
dl_info = self.client.get_file_url(url_id, self.quality) dl_info = self.client.get_file_url(url_id, self.quality)
except Exception as e: except Exception as e:
# click.secho(f"Unable to download track. {e}", fg="red") # raise NonStreamable(repr(e))
raise NonStreamable(repr(e)) raise NonStreamable(e)
if self.client.source == "qobuz": if self.client.source == "qobuz":
if not self.__validate_qobuz_dl_info(dl_info): if not self.__validate_qobuz_dl_info(dl_info):
@ -429,6 +451,10 @@ class Track(Media):
@property @property
def type(self) -> str: def type(self) -> str:
"""Return "track".
:rtype: str
"""
return "track" return "track"
@property @property
@ -754,6 +780,7 @@ class Track(Media):
return f"{self['artist']} - {self['title']}" return f"{self['artist']} - {self['title']}"
def __bool__(self): def __bool__(self):
"""Return True."""
return True return True
@ -835,6 +862,13 @@ class Video(Media):
) )
def convert(self, *args, **kwargs): def convert(self, *args, **kwargs):
"""Return None.
Dummy method.
:param args:
:param kwargs:
"""
pass pass
@property @property
@ -854,6 +888,10 @@ class Video(Media):
@property @property
def type(self) -> str: def type(self) -> str:
"""Return "video".
:rtype: str
"""
return "video" return "video"
def __str__(self) -> str: def __str__(self) -> str:
@ -871,6 +909,7 @@ class Video(Media):
return f"<Video - {self.title}>" return f"<Video - {self.title}>"
def __bool__(self): def __bool__(self):
"""Return True."""
return True return True
@ -966,6 +1005,7 @@ class YoutubeVideo(Media):
pass pass
def __bool__(self): def __bool__(self):
"""Return True."""
return True return True
@ -1001,9 +1041,14 @@ class Booklet:
_quick_download(self.url, filepath, "Booklet") _quick_download(self.url, filepath, "Booklet")
def type(self) -> str: def type(self) -> str:
"""Return "booklet".
:rtype: str
"""
return "booklet" return "booklet"
def __bool__(self): def __bool__(self):
"""Return True."""
return True return True
@ -1039,13 +1084,13 @@ class Tracklist(list):
# TODO: make this function return the items that have not been downloaded # TODO: make this function return the items that have not been downloaded
failed_downloads: List[Tuple[str, str, str]] = [] failed_downloads: List[Tuple[str, str, str]] = []
if kwargs.get("concurrent_downloads", True): if kwargs.get("concurrent_downloads", True):
click.echo() # To separate cover progress bars and the rest
with concurrent.futures.ThreadPoolExecutor( with concurrent.futures.ThreadPoolExecutor(
kwargs.get("max_connections", 3) kwargs.get("max_connections", 3)
) as executor: ) as executor:
future_map = { future_map = {
executor.submit(target, item, **kwargs): item for item in self executor.submit(target, item, **kwargs): item for item in self
} }
# futures = [executor.submit(target, item, **kwargs) for item in self]
try: try:
concurrent.futures.wait(future_map.keys()) concurrent.futures.wait(future_map.keys())
for future in future_map.keys(): for future in future_map.keys():
@ -1248,10 +1293,15 @@ class Tracklist(list):
@property @property
def type(self) -> str: def type(self) -> str:
return self.__class__.__name__.lower() """Return "booklet".
:rtype: str
"""
return "booklet"
@property @property
def downloaded_ids(self): def downloaded_ids(self):
"""Return the IDs of tracks that have been downloaded."""
raise NotImplementedError raise NotImplementedError
def __getitem__(self, key): def __getitem__(self, key):
@ -1278,6 +1328,7 @@ class Tracklist(list):
super().__setitem__(key, val) super().__setitem__(key, val)
def __bool__(self): def __bool__(self):
"""Return True."""
return True return True
@ -1809,6 +1860,7 @@ class Playlist(Tracklist, Media):
return f"<Playlist: {self.name}>" return f"<Playlist: {self.name}>"
def tag(self): def tag(self):
"""Raise NotImplementedError."""
raise NotImplementedError raise NotImplementedError
def __str__(self) -> str: def __str__(self) -> str:

View file

@ -11,7 +11,7 @@ import re
from collections import OrderedDict from collections import OrderedDict
from json import JSONDecodeError from json import JSONDecodeError
from string import Formatter from string import Formatter
from typing import Dict, Generator, Hashable, Optional, Tuple, Union from typing import Dict, Hashable, Iterator, Optional, Tuple, Union
import click import click
import requests import requests
@ -27,6 +27,116 @@ urllib3.disable_warnings()
logger = logging.getLogger("streamrip") logger = logging.getLogger("streamrip")
class DownloadStream:
"""An iterator over chunks of a stream.
Usage:
>>> stream = DownloadStream('https://google.com', None)
>>> with open('google.html', 'wb') as file:
>>> for chunk in stream:
>>> file.write(chunk)
"""
is_encrypted = re.compile("/m(?:obile|edia)/")
def __init__(
self,
url: str,
source: str = None,
params: dict = None,
headers: dict = None,
item_id: str = None,
):
"""Create an iterable DownloadStream of a URL.
:param url: The url to download
:type url: str
:param source: Only applicable for Deezer
:type source: str
:param params: Parameters to pass in the request
:type params: dict
:param headers: Headers to pass in the request
:type headers: dict
:param item_id: (Only for Deezer) the ID of the track
:type item_id: str
"""
self.source = source
self.session = gen_threadsafe_session(headers=headers)
self.id = item_id
if isinstance(self.id, int):
self.id = str(self.id)
if params is None:
params = {}
self.request = self.session.get(
url, allow_redirects=True, stream=True, params=params
)
self.file_size = int(self.request.headers.get("Content-Length", 0))
if self.file_size == 0:
raise NonStreamable
def __iter__(self) -> Iterator:
"""Iterate through chunks of the stream.
:rtype: Iterator
"""
if self.source == "deezer" and self.is_encrypted.search(self.url) is not None:
assert isinstance(self.id, str), self.id
blowfish_key = self._generate_blowfish_key(self.id)
return (
(self._decrypt_chunk(blowfish_key, chunk[:2048]) + chunk[2048:])
if len(chunk) >= 2048
else chunk
for chunk in self.request.iter_content(2048 * 3)
)
return self.request.iter_content(chunk_size=1024)
@property
def url(self):
"""Return the requested url."""
return self.request.url
def __len__(self) -> int:
"""Return the value of the "Content-Length" header.
:rtype: int
"""
return self.file_size
@staticmethod
def _generate_blowfish_key(track_id: str):
"""Generate the blowfish key for Deezer downloads.
:param track_id:
:type track_id: str
"""
SECRET = "g4el58wc0zvf9na1"
md5_hash = hashlib.md5(track_id.encode()).hexdigest()
# good luck :)
return "".join(
chr(functools.reduce(lambda x, y: x ^ y, map(ord, t)))
for t in zip(md5_hash[:16], md5_hash[16:], SECRET)
).encode()
@staticmethod
def _decrypt_chunk(key, data):
"""Decrypt a chunk of a Deezer stream.
:param key:
:param data:
"""
return Blowfish.new(
key, Blowfish.MODE_CBC, b"\x00\x01\x02\x03\x04\x05\x06\x07"
).decrypt(data)
def safe_get(d: dict, *keys: Hashable, default=None): def safe_get(d: dict, *keys: Hashable, default=None):
"""Traverse dict layers safely. """Traverse dict layers safely.
@ -84,7 +194,6 @@ def get_quality(quality_id: int, source: str) -> Union[str, int, Tuple[int, str]
:type source: str :type source: str
:rtype: Union[str, int] :rtype: Union[str, int]
""" """
return __QUALITY_MAP[source][quality_id] return __QUALITY_MAP[source][quality_id]
@ -175,84 +284,6 @@ def tqdm_download(url: str, filepath: str, params: dict = None, desc: str = None
raise raise
class DownloadStream:
"""An iterator over chunks of a stream.
Usage:
>>> stream = DownloadStream('https://google.com', None)
>>> with open('google.html', 'wb') as file:
>>> for chunk in stream:
>>> file.write(chunk)
"""
is_encrypted = re.compile("/m(?:obile|edia)/")
def __init__(
self,
url: str,
source: str = None,
params: dict = None,
headers: dict = None,
item_id: str = None,
):
self.source = source
self.session = gen_threadsafe_session(headers=headers)
self.id = item_id
if isinstance(self.id, int):
self.id = str(self.id)
if params is None:
params = {}
self.request = self.session.get(
url, allow_redirects=True, stream=True, params=params
)
self.file_size = int(self.request.headers.get("Content-Length", 0))
if self.file_size == 0:
raise NonStreamable
def __iter__(self) -> Generator:
if self.source == "deezer" and self.is_encrypted.search(self.url) is not None:
assert isinstance(self.id, str), self.id
blowfish_key = self._generate_blowfish_key(self.id)
return (
(self._decrypt_chunk(blowfish_key, chunk[:2048]) + chunk[2048:])
if len(chunk) >= 2048
else chunk
for chunk in self.request.iter_content(2048 * 3)
)
return self.request.iter_content(chunk_size=1024)
@property
def url(self):
return self.request.url
def __len__(self):
return self.file_size
@staticmethod
def _generate_blowfish_key(track_id: str):
SECRET = "g4el58wc0zvf9na1"
md5_hash = hashlib.md5(track_id.encode()).hexdigest()
# good luck :)
return "".join(
chr(functools.reduce(lambda x, y: x ^ y, map(ord, t)))
for t in zip(md5_hash[:16], md5_hash[16:], SECRET)
).encode()
@staticmethod
def _decrypt_chunk(key, data):
return Blowfish.new(
key, Blowfish.MODE_CBC, b"\x00\x01\x02\x03\x04\x05\x06\x07"
).decrypt(data)
def clean_format(formatter: str, format_info): def clean_format(formatter: str, format_info):
"""Format track or folder names sanitizing every formatter key. """Format track or folder names sanitizing every formatter key.
@ -425,6 +456,14 @@ def get_container(quality: int, source: str) -> str:
def get_cover_urls(resp: dict, source: str) -> dict: def get_cover_urls(resp: dict, source: str) -> dict:
"""Parse a response dict containing cover info according to the source.
:param resp:
:type resp: dict
:param source:
:type source: str
:rtype: dict
"""
if source == "qobuz": if source == "qobuz":
cover_urls = OrderedDict(resp["image"]) cover_urls = OrderedDict(resp["image"])
cover_urls["original"] = cover_urls["large"].replace("600", "org") cover_urls["original"] = cover_urls["large"].replace("600", "org")
@ -458,8 +497,10 @@ def get_cover_urls(resp: dict, source: str) -> dict:
def downsize_image(filepath: str, width: int, height: int): def downsize_image(filepath: str, width: int, height: int):
"""Downsize an image. If either the width or the height is greater """Downsize an image.
than the image's width or height, that dimension will not be changed.
If either the width or the height is greater than the image's width or
height, that dimension will not be changed.
:param filepath: :param filepath:
:type filepath: str :type filepath: str
@ -496,11 +537,26 @@ TQDM_BAR_FORMAT = TQDM_THEMES["dainty"]
def set_progress_bar_theme(theme: str): def set_progress_bar_theme(theme: str):
"""Set the theme of the tqdm progress bar.
:param theme:
:type theme: str
"""
global TQDM_BAR_FORMAT global TQDM_BAR_FORMAT
TQDM_BAR_FORMAT = TQDM_THEMES[theme] TQDM_BAR_FORMAT = TQDM_THEMES[theme]
def tqdm_stream(iterator: DownloadStream, desc: Optional[str] = None) -> Generator: def tqdm_stream(
iterator: DownloadStream, desc: Optional[str] = None
) -> Iterator[bytes]:
"""Return a tqdm bar with presets appropriate for downloading large files.
:param iterator:
:type iterator: DownloadStream
:param desc: Description to add for the progress bar
:type desc: Optional[str]
:rtype: Iterator
"""
with tqdm( with tqdm(
total=len(iterator), total=len(iterator),
unit="B", unit="B",