Start comprehensive typing

This commit is contained in:
nathom 2021-04-28 00:24:17 -07:00
parent dad58d8d22
commit e6a5d2cd39
8 changed files with 157 additions and 115 deletions

20
.mypy.ini Normal file
View file

@ -0,0 +1,20 @@
[mypy-mutagen.*]
ignore_missing_imports = True
[mypy-tqdm.*]
ignore_missing_imports = True
[mypy-pathvalidate.*]
ignore_missing_imports = True
[mypy-packaging.*]
ignore_missing_imports = True
[mypy-ruamel.yaml.*]
ignore_missing_imports = True
[mypy-pick.*]
ignore_missing_imports = True
[mypy-simple_term_menu.*]
ignore_missing_imports = True

View file

@ -86,6 +86,10 @@ class Track:
self.downloaded = False
self.tagged = False
self.converted = False
self.final_path: str
self.container: str
# TODO: find better solution
for attr in ("quality", "folder", "meta"):
setattr(self, attr, None)
@ -236,12 +240,10 @@ class Track:
if not kwargs.get("stay_temp", False):
self.move(self.final_path)
try:
database = kwargs.get("database")
database = kwargs.get("database")
if database is not None:
database.add(self.id)
logger.debug(f"{self.id} added to database")
except AttributeError: # assume database=None was passed
pass
logger.debug("Downloaded: %s -> %s", self.path, self.final_path)
@ -273,7 +275,7 @@ class Track:
shutil.move(self.path, path)
self.path = path
def _soundcloud_download(self, dl_info: dict) -> str:
def _soundcloud_download(self, dl_info: dict):
"""Downloads a soundcloud track. This requires a seperate function
because there are three methods that can be used to download a track:
* original file downloads
@ -708,6 +710,9 @@ class Booklet:
:param resp:
:type resp: dict
"""
self.url: str
self.description: str
self.__dict__.update(resp)
def download(self, parent_folder: str, **kwargs):
@ -861,9 +866,7 @@ class Tracklist(list):
return cls(client=client, **info)
@staticmethod
def get_cover_obj(
cover_path: str, container: str, source: str
) -> Union[Picture, APIC]:
def get_cover_obj(cover_path: str, container: str, source: str):
"""Given the path to an image and a quality id, return an initialized
cover object that can be used for every track in the album.
@ -907,7 +910,7 @@ class Tracklist(list):
with open(cover_path, "rb") as img:
return cover(img.read(), imageformat=MP4Cover.FORMAT_JPEG)
def download_message(self) -> str:
def download_message(self):
"""The message to display after calling `Tracklist.download`.
:rtype: str
@ -938,14 +941,14 @@ class Tracklist(list):
return album
def __getitem__(self, key: Union[str, int]):
def __getitem__(self, key):
if isinstance(key, str):
return getattr(self, key)
if isinstance(key, int):
return super().__getitem__(key)
def __setitem__(self, key: Union[str, int], val: Any):
def __setitem__(self, key, val):
if isinstance(key, str):
setattr(self, key, val)
@ -990,7 +993,7 @@ class YoutubeVideo:
)
if download_youtube_videos:
click.secho("Downloading video stream", fg='blue')
click.secho("Downloading video stream", fg="blue")
pv = subprocess.Popen(
[
"youtube-dl",

View file

@ -13,6 +13,7 @@ from .constants import (
AGENT,
AVAILABLE_QUALITY_IDS,
DEEZER_BASE,
DEEZER_DL,
DEEZER_MAX_Q,
QOBUZ_BASE,
QOBUZ_FEATURED_KEYS,
@ -43,6 +44,10 @@ class Client(ABC):
it is merely a template.
"""
source: str
max_quality: int
logged_in: bool
@abstractmethod
def login(self, **kwargs):
"""Authenticate the client.
@ -71,25 +76,13 @@ class Client(ABC):
pass
@abstractmethod
def get_file_url(self, track_id, quality=3) -> Union[dict]:
def get_file_url(self, track_id, quality=3) -> Union[dict, str]:
"""Get the direct download url dict for a file.
:param track_id: id of the track
"""
pass
@property
@abstractmethod
def source(self):
"""Source from which the Client retrieves data."""
pass
@property
@abstractmethod
def max_quality(self):
"""The maximum quality that the Client supports."""
pass
class QobuzClient(Client):
source = "qobuz"
@ -99,7 +92,7 @@ class QobuzClient(Client):
def __init__(self):
self.logged_in = False
def login(self, email: str, pwd: str, **kwargs):
def login(self, **kwargs):
"""Authenticate the QobuzClient. Must have a paid membership.
If `app_id` and `secrets` are not provided, this will run the
@ -113,6 +106,8 @@ class QobuzClient(Client):
:param kwargs: app_id: str, secrets: list, return_secrets: bool
"""
click.secho(f"Logging into {self.source}", fg="green")
email: str = kwargs["email"]
pwd: str = kwargs["pwd"]
if self.logged_in:
logger.debug("Already logged in")
return
@ -184,7 +179,7 @@ class QobuzClient(Client):
# ---------- Private Methods ---------------
def _gen_pages(self, epoint: str, params: dict) -> dict:
def _gen_pages(self, epoint: str, params: dict) -> Generator:
"""When there are multiple pages of results, this lazily
yields them.
@ -352,7 +347,7 @@ class QobuzClient(Client):
else:
raise InvalidAppSecretError("Cannot find app secret")
quality = get_quality(quality, self.source)
quality = int(get_quality(quality, self.source))
r_sig = f"trackgetFileUrlformat_id{quality}intentstreamtrack_id{track_id}{unix_ts}{secret}"
logger.debug("Raw request signature: %s", r_sig)
r_sig_hashed = hashlib.md5(r_sig.encode("utf-8")).hexdigest()
@ -857,7 +852,7 @@ class SoundCloudClient(Client):
return resp
def get_file_url(self, track: dict, quality) -> dict:
def get_file_url(self, track, quality):
"""Get the streamable file url from soundcloud.
It will most likely be an hls stream, which will have to be manually
@ -868,6 +863,9 @@ class SoundCloudClient(Client):
:param quality:
:rtype: dict
"""
# TODO: find better solution for typing
assert isinstance(track, dict)
if not track["streamable"] or track["policy"] == "BLOCK":
raise Exception

View file

@ -5,6 +5,7 @@ import os
import re
from functools import cache
from pprint import pformat
from typing import Any, Dict
from ruamel.yaml import YAML
@ -45,7 +46,7 @@ class Config:
values.
"""
defaults = {
defaults: Dict[str, Any] = {
"qobuz": {
"quality": 3,
"download_booklets": True,
@ -107,8 +108,8 @@ class Config:
def __init__(self, path: str = None):
# to access settings loaded from yaml file
self.file = copy.deepcopy(self.defaults)
self.session = copy.deepcopy(self.defaults)
self.file: Dict[str, Any] = copy.deepcopy(self.defaults)
self.session: Dict[str, Any] = copy.deepcopy(self.defaults)
if path is None:
self._path = CONFIG_PATH

View file

@ -15,11 +15,11 @@ SAMPLING_RATES = (44100, 48000, 88200, 96000, 176400, 192000)
class Converter:
"""Base class for audio codecs."""
codec_name = None
codec_lib = None
container = None
lossless = False
default_ffmpeg_arg = ""
codec_name: str
codec_lib: str
container: str
lossless: bool = False
default_ffmpeg_arg: str = ""
def __init__(
self,

View file

@ -6,7 +6,7 @@ import sys
from getpass import getpass
from hashlib import md5
from string import Formatter
from typing import Generator, Optional, Tuple, Union
from typing import Dict, Generator, List, Optional, Tuple, Type, Union
import click
import requests
@ -19,11 +19,11 @@ from .constants import (
CONFIG_PATH,
DB_PATH,
LASTFM_URL_REGEX,
YOUTUBE_URL_REGEX,
MEDIA_TYPES,
QOBUZ_INTERPRETER_URL_REGEX,
SOUNDCLOUD_URL_REGEX,
URL_REGEX,
YOUTUBE_URL_REGEX,
)
from .db import MusicDB
from .exceptions import (
@ -38,7 +38,10 @@ from .utils import extract_interpreter_url
logger = logging.getLogger(__name__)
MEDIA_CLASS = {
Media = Union[
Type[Album], Type[Playlist], Type[Artist], Type[Track], Type[Label], Type[Video]
]
MEDIA_CLASS: Dict[str, Media] = {
"album": Album,
"playlist": Playlist,
"artist": Artist,
@ -46,7 +49,6 @@ MEDIA_CLASS = {
"label": Label,
"video": Video,
}
Media = Union[Album, Playlist, Artist, Track]
class MusicDL(list):
@ -61,9 +63,11 @@ class MusicDL(list):
self.interpreter_url_parse = re.compile(QOBUZ_INTERPRETER_URL_REGEX)
self.youtube_url_parse = re.compile(YOUTUBE_URL_REGEX)
self.config = config
if self.config is None:
self.config: Config
if config is None:
self.config = Config(CONFIG_PATH)
else:
self.config = config
self.clients = {
"qobuz": QobuzClient(),
@ -72,13 +76,14 @@ class MusicDL(list):
"soundcloud": SoundCloudClient(),
}
if config.session["database"]["enabled"]:
if config.session["database"]["path"] is not None:
self.db = MusicDB(config.session["database"]["path"])
self.db: Union[MusicDB, list]
if self.config.session["database"]["enabled"]:
if self.config.session["database"]["path"] is not None:
self.db = MusicDB(self.config.session["database"]["path"])
else:
self.db = MusicDB(DB_PATH)
config.file["database"]["path"] = DB_PATH
config.save()
self.config.file["database"]["path"] = DB_PATH
self.config.save()
else:
self.db = []
@ -175,7 +180,7 @@ class MusicDL(list):
)
click.secho("rip config --reset ", fg="yellow", nl=False)
click.secho("to reset it. You will need to log in again.", fg="red")
click.secho(err, fg='red')
click.secho(err, fg="red")
exit()
logger.debug("Arguments from config: %s", arguments)
@ -247,7 +252,7 @@ class MusicDL(list):
self.config.file["tidal"].update(client.get_tokens())
self.config.save()
def parse_urls(self, url: str) -> Tuple[str, str]:
def parse_urls(self, url: str) -> List[Tuple[str, str, str]]:
"""Returns the type of the url and the id.
Compatible with urls of the form:
@ -262,7 +267,7 @@ class MusicDL(list):
:raises exceptions.ParsingError
"""
parsed = []
parsed: List[Tuple[str, str, str]] = []
interpreter_urls = self.interpreter_url_parse.findall(url)
if interpreter_urls:
@ -291,14 +296,15 @@ class MusicDL(list):
return parsed
def handle_lastfm_urls(self, urls):
# For testing:
# https://www.last.fm/user/nathan3895/playlists/12058911
user_regex = re.compile(r"https://www\.last\.fm/user/([^/]+)/playlists/\d+")
lastfm_urls = self.lastfm_url_parse.findall(urls)
lastfm_source = self.config.session["lastfm"]["source"]
tracks_not_found = 0
def search_query(query: str, playlist: Playlist):
global tracks_not_found
def search_query(query: str, playlist: Playlist) -> bool:
"""Search for a query and add the first result to the given
Playlist object."""
try:
track = next(self.search(lastfm_source, query, media_type="track"))
if self.config.session["metadata"]["set_playlist_to_album"]:
@ -307,29 +313,33 @@ class MusicDL(list):
track.meta.version = track.meta.work = None
playlist.append(track)
return True
except NoResultsFound:
tracks_not_found += 1
return
return False
for purl in lastfm_urls:
click.secho(f"Fetching playlist at {purl}", fg="blue")
title, queries = self.get_lastfm_playlist(purl)
pl = Playlist(client=self.get_client(lastfm_source), name=title)
pl.creator = user_regex.search(purl).group(1)
creator_match = user_regex.search(purl)
if creator_match is not None:
pl.creator = creator_match.group(1)
tracks_not_found: int = 0
with concurrent.futures.ThreadPoolExecutor(max_workers=15) as executor:
futures = [
executor.submit(search_query, f"{title} {artist}", pl)
for title, artist in queries
]
# only for the progress bar
for f in tqdm(
for search_attempt in tqdm(
concurrent.futures.as_completed(futures),
total=len(futures),
desc="Searching",
):
pass
if not search_attempt.result():
tracks_not_found += 1
pl.loaded = True
click.secho(f"{tracks_not_found} tracks not found.", fg="yellow")
@ -362,7 +372,7 @@ class MusicDL(list):
else page["albums"]["items"]
)
for item in tracklist:
yield MEDIA_CLASS[
yield MEDIA_CLASS[ # type: ignore
media_type if media_type != "featured" else "album"
].from_api(item, client)
i += 1
@ -376,7 +386,7 @@ class MusicDL(list):
raise NoResultsFound(query)
for item in items:
yield MEDIA_CLASS[media_type].from_api(item, client)
yield MEDIA_CLASS[media_type].from_api(item, client) # type: ignore
i += 1
if i > limit:
return
@ -408,7 +418,7 @@ class MusicDL(list):
ret = fmt.format(**{k: media.get(k, default="Unknown") for k in fields})
return ret
def interactive_search(
def interactive_search( # noqa
self, query: str, source: str = "qobuz", media_type: str = "album"
):
results = tuple(self.search(source, query, media_type, limit=50))
@ -506,13 +516,21 @@ class MusicDL(list):
r = requests.get(url)
get_titles(r.text)
remaining_tracks = (
int(re.search(r'data-playlisting-entry-count="(\d+)"', r.text).group(1))
- 50
remaining_tracks_match = re.search(
r'data-playlisting-entry-count="(\d+)"', r.text
)
playlist_title = re.search(
if remaining_tracks_match is not None:
remaining_tracks = int(remaining_tracks_match.group(1)) - 50
else:
raise Exception("Error parsing lastfm page")
playlist_title_match = re.search(
r'<h1 class="playlisting-playlist-header-title">([^<]+)</h1>', r.text
).group(1)
)
if playlist_title_match is not None:
playlist_title = playlist_title_match.group(1)
else:
raise Exception("Error finding title from response")
page = 1
while remaining_tracks > 0:

View file

@ -2,7 +2,7 @@
import logging
import re
from collections import OrderedDict
from typing import Generator, Hashable, Optional, Tuple, Union
from typing import Generator, Hashable, Iterable, Optional, Union
from .constants import (
COPYRIGHT,
@ -59,34 +59,37 @@ class TrackMetadata:
:type album: Optional[dict]
"""
# embedded information
self.title = None
self.album = None
self.albumartist = None
self.composer = None
self.comment = None
self.description = None
self.purchase_date = None
self.grouping = None
self.lyrics = None
self.encoder = None
self.compilation = None
self.cover = None
self.tracktotal = None
self.tracknumber = None
self.discnumber = None
self.disctotal = None
self.title: str
self.album: str
self.albumartist: str
self.composer: str
self.comment: Optional[str]
self.description: Optional[str]
self.purchase_date: Optional[str]
self.grouping: Optional[str]
self.lyrics: Optional[str]
self.encoder: Optional[str]
self.compilation: Optional[str]
self.cover: str
self.tracktotal: int
self.tracknumber: int
self.discnumber: int
self.disctotal: int
# not included in tags
self.explicit = False
self.quality = None
self.sampling_rate = None
self.bit_depth = None
self.explicit: Optional[bool] = False
self.quality: Optional[int] = None
self.sampling_rate: Optional[int] = None
self.bit_depth: Optional[int] = None
self.booklets = None
self.cover_urls = Optional[OrderedDict]
self.work: Optional[str]
self.id: Optional[str]
# Internals
self._artist = None
self._copyright = None
self._genres = None
self._artist: Optional[str] = None
self._copyright: Optional[str] = None
self._genres: Optional[Iterable] = None
self.__source = source
@ -121,7 +124,7 @@ class TrackMetadata:
"""
if self.__source == "qobuz":
# Tags
self.album = resp.get("title")
self.album = resp.get("title", "Unknown Album")
self.tracktotal = resp.get("tracks_count", 1)
self.genre = resp.get("genres_list") or resp.get("genre")
self.date = resp.get("release_date_original") or resp.get("release_date")
@ -144,7 +147,7 @@ class TrackMetadata:
# Non-embedded information
self.version = resp.get("version")
self.cover_urls = OrderedDict(resp.get("image"))
self.cover_urls = OrderedDict(resp["image"])
self.cover_urls["original"] = self.cover_urls["large"].replace("600", "org")
self.streamable = resp.get("streamable", False)
self.bit_depth = resp.get("maximum_bit_depth")
@ -156,14 +159,14 @@ class TrackMetadata:
self.sampling_rate *= 1000
elif self.__source == "tidal":
self.album = resp.get("title")
self.album = resp.get("title", "Unknown Album")
self.tracktotal = resp.get("numberOfTracks", 1)
# genre not returned by API
self.date = resp.get("releaseDate")
self.copyright = resp.get("copyright")
self.albumartist = safe_get(resp, "artist", "name")
self.disctotal = resp.get("numberOfVolumes")
self.disctotal = resp.get("numberOfVolumes", 1)
self.isrc = resp.get("isrc")
# label not returned by API
@ -185,8 +188,8 @@ class TrackMetadata:
self.sampling_rate = 44100
elif self.__source == "deezer":
self.album = resp.get("title")
self.tracktotal = resp.get("track_total") or resp.get("nb_tracks")
self.album = resp.get("title", "Unknown Album")
self.tracktotal = resp.get("track_total", 0) or resp.get("nb_tracks", 0)
self.disctotal = (
max(track.get("disk_number") for track in resp.get("tracks", [{}])) or 1
)
@ -224,7 +227,7 @@ class TrackMetadata:
:param track:
"""
if self.__source == "qobuz":
self.title = track.get("title").strip()
self.title = track["title"].strip()
self._mod_title(track.get("version"), track.get("work"))
self.composer = track.get("composer", {}).get("name")
@ -235,24 +238,23 @@ class TrackMetadata:
self.artist = self.get("albumartist")
elif self.__source == "tidal":
self.title = track.get("title").strip()
self.title = track["title"].strip()
self._mod_title(track.get("version"), None)
self.tracknumber = track.get("trackNumber", 1)
self.discnumber = track.get("volumeNumber")
self.discnumber = track.get("volumeNumber", 1)
self.artist = track.get("artist", {}).get("name")
elif self.__source == "deezer":
self.title = track.get("title").strip()
self.title = track["title"].strip()
self._mod_title(track.get("version"), None)
self.tracknumber = track.get("track_position", 1)
self.discnumber = track.get("disk_number")
self.discnumber = track.get("disk_number", 1)
self.artist = track.get("artist", {}).get("name")
elif self.__source == "soundcloud":
self.title = track["title"].strip()
self.genre = track["genre"]
self.artist = track["user"]["username"]
self.albumartist = self.artist
self.artist = self.albumartist = track["user"]["username"]
self.year = track["created_at"][:4]
self.label = track["label_name"]
self.description = track["description"]
@ -287,7 +289,7 @@ class TrackMetadata:
return album
@album.setter
def album(self, val) -> str:
def album(self, val):
self._album = val
@property
@ -331,7 +333,7 @@ class TrackMetadata:
if isinstance(self._genres, list):
if self.__source == "qobuz":
genres = re.findall(r"([^\u2192\/]+)", "/".join(self._genres))
genres: Iterable = re.findall(r"([^\u2192\/]+)", "/".join(self._genres))
genres = set(genres)
return ", ".join(genres)
@ -342,7 +344,7 @@ class TrackMetadata:
raise TypeError(f"Genre must be list or str, not {type(self._genres)}")
@genre.setter
def genre(self, val: Union[str, list]):
def genre(self, val: Union[Iterable, dict]):
"""Sets the internal `genre` field to the given list.
It is not formatted until it is requested with `meta.genre`.
@ -352,7 +354,7 @@ class TrackMetadata:
self._genres = val
@property
def copyright(self) -> Union[str, None]:
def copyright(self) -> Optional[str]:
"""Formats the copyright string to use nice-looking unicode
characters.
@ -361,11 +363,11 @@ class TrackMetadata:
if hasattr(self, "_copyright"):
if self._copyright is None:
return None
copyright = re.sub(r"(?i)\(P\)", PHON_COPYRIGHT, self._copyright)
copyright: str = re.sub(r"(?i)\(P\)", PHON_COPYRIGHT, self._copyright)
copyright = re.sub(r"(?i)\(C\)", COPYRIGHT, copyright)
return copyright
logger.debug("Accessed copyright tag before setting, return None")
logger.debug("Accessed copyright tag before setting, returning None")
return None
@copyright.setter
@ -440,7 +442,7 @@ class TrackMetadata:
raise InvalidContainerError(f"Invalid container {container}")
def __gen_flac_tags(self) -> Tuple[str, str]:
def __gen_flac_tags(self) -> Generator:
"""Generate key, value pairs to tag FLAC files.
:rtype: Tuple[str, str]
@ -454,7 +456,7 @@ class TrackMetadata:
logger.debug("Adding tag %s: %s", v, tag)
yield (v, str(tag))
def __gen_mp3_tags(self) -> Tuple[str, str]:
def __gen_mp3_tags(self) -> Generator:
"""Generate key, value pairs to tag MP3 files.
:rtype: Tuple[str, str]
@ -470,7 +472,7 @@ class TrackMetadata:
if text is not None and v is not None:
yield (v.__name__, v(encoding=3, text=text))
def __gen_mp4_tags(self) -> Tuple[str, Union[str, int, tuple]]:
def __gen_mp4_tags(self) -> Generator:
"""Generate key, value pairs to tag ALAC or AAC files in
an MP4 container.
@ -510,7 +512,7 @@ class TrackMetadata:
"""
return getattr(self, key)
def get(self, key, default=None) -> str:
def get(self, key, default=None):
"""Returns the requested attribute of the object, with
a default value.

View file

@ -171,7 +171,7 @@ class Album(Tracklist):
return True
@staticmethod
def _parse_get_resp(resp: dict, client: Client) -> dict:
def _parse_get_resp(resp: dict, client: Client) -> TrackMetadata:
"""Parse information from a client.get(query, 'album') call.
:param resp: