mirror of
https://github.com/nathom/streamrip.git
synced 2025-05-20 10:15:23 -04:00
535 lines
14 KiB
Python
535 lines
14 KiB
Python
"""Miscellaneous utility functions."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
import logging
|
|
import os
|
|
from string import Formatter
|
|
from typing import (
|
|
Dict,
|
|
Hashable,
|
|
Optional,
|
|
Tuple,
|
|
Union,
|
|
Generator,
|
|
)
|
|
from collections import OrderedDict
|
|
import functools
|
|
from Cryptodome.Cipher import Blowfish
|
|
import hashlib
|
|
import re
|
|
from json import JSONDecodeError
|
|
|
|
import click
|
|
import requests
|
|
from pathvalidate import sanitize_filename
|
|
from requests.packages import urllib3
|
|
from tqdm import tqdm
|
|
|
|
from .constants import TIDAL_COVER_URL, COVER_SIZES
|
|
from .exceptions import InvalidQuality, InvalidSourceError, NonStreamable
|
|
|
|
urllib3.disable_warnings()
|
|
logger = logging.getLogger("streamrip")
|
|
|
|
|
|
def safe_get(d: dict, *keys: Hashable, default=None):
|
|
"""Traverse dict layers safely.
|
|
|
|
Usage:
|
|
>>> d = {'foo': {'bar': 'baz'}}
|
|
>>> safe_get(d, 'baz')
|
|
None
|
|
>>> safe_get(d, 'foo', 'bar')
|
|
'baz'
|
|
|
|
:param d:
|
|
:type d: dict
|
|
:param keys:
|
|
:type keys: Hashable
|
|
:param default: the default value to use if a key isn't found
|
|
"""
|
|
curr = d
|
|
res = default
|
|
for key in keys:
|
|
res = curr.get(key, default)
|
|
if res == default or not hasattr(res, "__getitem__"):
|
|
return res
|
|
else:
|
|
curr = res
|
|
return res
|
|
|
|
"""
|
|
FLAC = 9
|
|
MP3_320 = 3
|
|
MP3_128 = 1
|
|
MP4_RA3 = 15
|
|
MP4_RA2 = 14
|
|
MP4_RA1 = 13
|
|
DEFAULT = 8
|
|
LOCAL = 0
|
|
|
|
|
|
"""
|
|
|
|
|
|
__QUALITY_MAP: Dict[str, Dict[int, Union[int, str, Tuple[int, str]]]] = {
|
|
"qobuz": {
|
|
1: 5,
|
|
2: 6,
|
|
3: 7,
|
|
4: 27,
|
|
},
|
|
"deezer": {
|
|
0: (9, "MP3_128"),
|
|
1: (3, "MP3_320"),
|
|
2: (1, "FLAC"),
|
|
},
|
|
"tidal": {
|
|
0: "LOW", # AAC
|
|
1: "HIGH", # AAC
|
|
2: "LOSSLESS", # CD Quality
|
|
3: "HI_RES", # MQA
|
|
},
|
|
}
|
|
|
|
|
|
def get_quality(quality_id: int, source: str) -> Union[str, int, Tuple[int, str]]:
|
|
"""Get the source-specific quality id.
|
|
|
|
:param quality_id: the universal quality id (0, 1, 2, 4)
|
|
:type quality_id: int
|
|
:param source: qobuz, tidal, or deezer
|
|
:type source: str
|
|
:rtype: Union[str, int]
|
|
"""
|
|
|
|
return __QUALITY_MAP[source][quality_id]
|
|
|
|
|
|
def get_quality_id(bit_depth: Optional[int], sampling_rate: Optional[int]):
|
|
"""Get the universal quality id from bit depth and sampling rate.
|
|
|
|
:param bit_depth:
|
|
:type bit_depth: Optional[int]
|
|
:param sampling_rate:
|
|
:type sampling_rate: Optional[int]
|
|
"""
|
|
# XXX: Should `0` quality be supported?
|
|
if bit_depth is None or sampling_rate is None: # is lossy
|
|
return 1
|
|
|
|
if bit_depth == 16:
|
|
return 2
|
|
|
|
if bit_depth == 24:
|
|
if sampling_rate <= 96:
|
|
return 3
|
|
|
|
return 4
|
|
|
|
|
|
def get_stats_from_quality(
|
|
quality_id: int,
|
|
) -> Tuple[Optional[int], Optional[int]]:
|
|
"""Get bit depth and sampling rate based on the quality id.
|
|
|
|
:param quality_id:
|
|
:type quality_id: int
|
|
:rtype: Tuple[Optional[int], Optional[int]]
|
|
"""
|
|
if quality_id <= 1:
|
|
return (None, None)
|
|
elif quality_id == 2:
|
|
return (16, 44100)
|
|
elif quality_id == 3:
|
|
return (24, 96000)
|
|
elif quality_id == 4:
|
|
return (24, 192000)
|
|
else:
|
|
raise InvalidQuality(quality_id)
|
|
|
|
|
|
def tqdm_download(url: str, filepath: str, params: dict = None, desc: str = None):
|
|
"""Download a file with a progress bar.
|
|
|
|
:param url: url to direct download
|
|
:param filepath: file to write
|
|
:type url: str
|
|
:type filepath: str
|
|
"""
|
|
logger.debug(f"Downloading {url} to {filepath} with params {params}")
|
|
if params is None:
|
|
params = {}
|
|
|
|
session = gen_threadsafe_session()
|
|
r = session.get(url, allow_redirects=True, stream=True, params=params)
|
|
total = int(r.headers.get("content-length", 0))
|
|
logger.debug("File size = %s", total)
|
|
if total < 1000 and not url.endswith("jpg") and not url.endswith("png"):
|
|
logger.debug("Response text: %s", r.text)
|
|
try:
|
|
raise NonStreamable(r.json()["error"])
|
|
except JSONDecodeError:
|
|
raise NonStreamable("Resource not found.")
|
|
|
|
try:
|
|
with open(filepath, "wb") as file, tqdm(
|
|
total=total,
|
|
unit="iB",
|
|
unit_scale=True,
|
|
unit_divisor=1024,
|
|
desc=desc,
|
|
dynamic_ncols=True,
|
|
# leave=False,
|
|
) as bar:
|
|
for data in r.iter_content(chunk_size=1024):
|
|
size = file.write(data)
|
|
bar.update(size)
|
|
except Exception:
|
|
try:
|
|
os.remove(filepath)
|
|
except OSError:
|
|
pass
|
|
raise
|
|
|
|
|
|
class DownloadStream:
|
|
"""An iterator over chunks of a stream.
|
|
|
|
Usage:
|
|
|
|
>>> stream = DownloadStream('https://google.com', None)
|
|
>>> with open('google.html', 'wb') as file:
|
|
>>> for chunk in stream:
|
|
>>> file.write(chunk)
|
|
|
|
"""
|
|
|
|
is_encrypted = re.compile("/m(?:obile|edia)/")
|
|
|
|
def __init__(
|
|
self,
|
|
url: str,
|
|
source: str = None,
|
|
params: dict = None,
|
|
headers: dict = None,
|
|
item_id: str = None,
|
|
):
|
|
self.source = source
|
|
self.session = gen_threadsafe_session(headers=headers)
|
|
|
|
self.id = item_id
|
|
if isinstance(self.id, int):
|
|
self.id = str(self.id)
|
|
|
|
if params is None:
|
|
params = {}
|
|
|
|
self.request = self.session.get(
|
|
url, allow_redirects=True, stream=True, params=params
|
|
)
|
|
self.file_size = int(self.request.headers.get("Content-Length", 0))
|
|
|
|
if self.file_size == 0:
|
|
raise NonStreamable
|
|
|
|
def __iter__(self) -> Generator:
|
|
if self.source == "deezer" and self.is_encrypted.search(self.url) is not None:
|
|
assert isinstance(self.id, str), self.id
|
|
|
|
blowfish_key = self._generate_blowfish_key(self.id)
|
|
return (
|
|
(self._decrypt_chunk(blowfish_key, chunk[:2048]) + chunk[2048:])
|
|
if len(chunk) >= 2048
|
|
else chunk
|
|
for chunk in self.request.iter_content(2048 * 3)
|
|
)
|
|
|
|
return self.request.iter_content(chunk_size=1024)
|
|
|
|
@property
|
|
def url(self):
|
|
return self.request.url
|
|
|
|
def __len__(self):
|
|
return self.file_size
|
|
|
|
@staticmethod
|
|
def _generate_blowfish_key(track_id: str):
|
|
SECRET = "g4el58wc0zvf9na1"
|
|
md5_hash = hashlib.md5(track_id.encode()).hexdigest()
|
|
# good luck :)
|
|
return "".join(
|
|
chr(functools.reduce(lambda x, y: x ^ y, map(ord, t)))
|
|
for t in zip(md5_hash[:16], md5_hash[16:], SECRET)
|
|
).encode()
|
|
|
|
@staticmethod
|
|
def _decrypt_chunk(key, data):
|
|
return Blowfish.new(
|
|
key, Blowfish.MODE_CBC, b"\x00\x01\x02\x03\x04\x05\x06\x07"
|
|
).decrypt(data)
|
|
|
|
|
|
def clean_format(formatter: str, format_info):
|
|
"""Format track or folder names sanitizing every formatter key.
|
|
|
|
:param formatter:
|
|
:type formatter: str
|
|
:param kwargs:
|
|
"""
|
|
fmt_keys = [i[1] for i in Formatter().parse(formatter) if i[1] is not None]
|
|
|
|
logger.debug("Formatter keys: %s", fmt_keys)
|
|
|
|
clean_dict = dict()
|
|
for key in fmt_keys:
|
|
if isinstance(format_info.get(key), (str, float)):
|
|
clean_dict[key] = sanitize_filename(str(format_info[key]))
|
|
elif isinstance(format_info.get(key), int): # track/discnumber
|
|
clean_dict[key] = f"{format_info[key]:02}"
|
|
else:
|
|
clean_dict[key] = "Unknown"
|
|
|
|
return formatter.format(**clean_dict)
|
|
|
|
|
|
def tidal_cover_url(uuid, size):
|
|
"""Generate a tidal cover url.
|
|
|
|
:param uuid:
|
|
:param size:
|
|
"""
|
|
possibles = (80, 160, 320, 640, 1280)
|
|
assert size in possibles, f"size must be in {possibles}"
|
|
|
|
return TIDAL_COVER_URL.format(uuid=uuid.replace("-", "/"), height=size, width=size)
|
|
|
|
|
|
def init_log(path: Optional[str] = None, level: str = "DEBUG"):
|
|
"""Create a log.
|
|
|
|
:param path:
|
|
:type path: Optional[str]
|
|
:param level:
|
|
:type level: str
|
|
:param rotate:
|
|
:type rotate: str
|
|
"""
|
|
# path = os.path.join(LOG_DIR, "streamrip.log")
|
|
level = logging.getLevelName(level)
|
|
logging.basicConfig(level=level)
|
|
|
|
|
|
def decrypt_mqa_file(in_path, out_path, encryption_key):
|
|
"""Decrypt an MQA file.
|
|
|
|
:param in_path:
|
|
:param out_path:
|
|
:param encryption_key:
|
|
"""
|
|
try:
|
|
from Crypto.Cipher import AES
|
|
from Crypto.Util import Counter
|
|
except (ImportError, ModuleNotFoundError):
|
|
click.secho(
|
|
"To download this item in MQA, you need to run ",
|
|
fg="yellow",
|
|
nl=False,
|
|
)
|
|
click.secho("pip3 install pycryptodome --upgrade", fg="blue", nl=False)
|
|
click.secho(".")
|
|
raise click.Abort
|
|
|
|
# Do not change this
|
|
master_key = "UIlTTEMmmLfGowo/UC60x2H45W6MdGgTRfo/umg4754="
|
|
|
|
# Decode the base64 strings to ascii strings
|
|
master_key = base64.b64decode(master_key)
|
|
security_token = base64.b64decode(encryption_key)
|
|
|
|
# Get the IV from the first 16 bytes of the securityToken
|
|
iv = security_token[:16]
|
|
encrypted_st = security_token[16:]
|
|
|
|
# Initialize decryptor
|
|
decryptor = AES.new(master_key, AES.MODE_CBC, iv)
|
|
|
|
# Decrypt the security token
|
|
decrypted_st = decryptor.decrypt(encrypted_st)
|
|
|
|
# Get the audio stream decryption key and nonce from the decrypted security token
|
|
key = decrypted_st[:16]
|
|
nonce = decrypted_st[16:24]
|
|
|
|
counter = Counter.new(64, prefix=nonce, initial_value=0)
|
|
decryptor = AES.new(key, AES.MODE_CTR, counter=counter)
|
|
|
|
with open(in_path, "rb") as enc_file:
|
|
dec_bytes = decryptor.decrypt(enc_file.read())
|
|
with open(out_path, "wb") as dec_file:
|
|
dec_file.write(dec_bytes)
|
|
|
|
|
|
def ext(quality: int, source: str):
|
|
"""Get the extension of an audio file.
|
|
|
|
:param quality:
|
|
:type quality: int
|
|
:param source:
|
|
:type source: str
|
|
"""
|
|
if quality <= 1:
|
|
if source == "tidal":
|
|
return ".m4a"
|
|
else:
|
|
return ".mp3"
|
|
else:
|
|
return ".flac"
|
|
|
|
|
|
def gen_threadsafe_session(
|
|
headers: dict = None, pool_connections: int = 100, pool_maxsize: int = 100
|
|
) -> requests.Session:
|
|
"""Create a new Requests session with a large poolsize.
|
|
|
|
:param headers:
|
|
:type headers: dict
|
|
:param pool_connections:
|
|
:type pool_connections: int
|
|
:param pool_maxsize:
|
|
:type pool_maxsize: int
|
|
:rtype: requests.Session
|
|
"""
|
|
if headers is None:
|
|
headers = {}
|
|
|
|
session = requests.Session()
|
|
adapter = requests.adapters.HTTPAdapter(pool_connections=100, pool_maxsize=100)
|
|
session.mount("https://", adapter)
|
|
session.headers.update(headers)
|
|
return session
|
|
|
|
|
|
def decho(message, fg=None):
|
|
"""Debug echo the message.
|
|
|
|
:param message:
|
|
:param fg: ANSI color with which to display the message on the
|
|
screen
|
|
"""
|
|
click.secho(message, fg=fg)
|
|
logger.debug(message)
|
|
|
|
|
|
def get_container(quality: int, source: str) -> str:
|
|
"""Get the file container given the quality.
|
|
|
|
`container` can also be the the codec; both work.
|
|
|
|
:param quality: quality id
|
|
:type quality: int
|
|
:param source:
|
|
:type source: str
|
|
:rtype: str
|
|
"""
|
|
if quality >= 2:
|
|
return "FLAC"
|
|
|
|
if source == "tidal":
|
|
return "AAC"
|
|
|
|
return "MP3"
|
|
|
|
|
|
def get_cover_urls(resp: dict, source: str) -> dict:
|
|
if source == "qobuz":
|
|
cover_urls = OrderedDict(resp["image"])
|
|
cover_urls["original"] = cover_urls["large"].replace("600", "org")
|
|
return cover_urls
|
|
|
|
if source == "tidal":
|
|
uuid = resp["cover"]
|
|
return OrderedDict(
|
|
{
|
|
sk: tidal_cover_url(uuid, size)
|
|
for sk, size in zip(COVER_SIZES, (160, 320, 640, 1280))
|
|
}
|
|
)
|
|
|
|
if source == "deezer":
|
|
cover_urls = OrderedDict(
|
|
{
|
|
sk: resp.get(rk) # size key, resp key
|
|
for sk, rk in zip(
|
|
COVER_SIZES,
|
|
("cover", "cover_medium", "cover_large", "cover_xl"),
|
|
)
|
|
}
|
|
)
|
|
if cover_urls["large"] is None and resp.get("cover_big") is not None:
|
|
cover_urls["large"] = resp["cover_big"]
|
|
|
|
return cover_urls
|
|
|
|
raise InvalidSourceError(source)
|
|
|
|
|
|
def downsize_image(filepath: str, width: int, height: int):
|
|
"""Downsize an image. If either the width or the height is greater
|
|
than the image's width or height, that dimension will not be changed.
|
|
|
|
:param filepath:
|
|
:type filepath: str
|
|
:param width:
|
|
:type width: int
|
|
:param height:
|
|
:type height: int
|
|
:raises: ValueError
|
|
"""
|
|
from PIL import Image
|
|
|
|
image = Image.open(filepath)
|
|
|
|
width = min(width, image.width)
|
|
height = min(height, image.height)
|
|
|
|
resized_image = image.resize((width, height))
|
|
resized_image.save(filepath)
|
|
|
|
|
|
TQDM_THEMES = {
|
|
"plain": None,
|
|
"dainty": (
|
|
"{desc} |{bar}| "
|
|
+ click.style("{remaining}", fg="magenta")
|
|
+ " left at "
|
|
+ click.style("{rate_fmt}{postfix} ", fg="cyan", bold=True)
|
|
),
|
|
}
|
|
|
|
TQDM_DEFAULT_THEME = "dainty"
|
|
|
|
TQDM_BAR_FORMAT = TQDM_THEMES["dainty"]
|
|
|
|
|
|
def set_progress_bar_theme(theme: str):
|
|
global TQDM_BAR_FORMAT
|
|
TQDM_BAR_FORMAT = TQDM_THEMES[theme]
|
|
|
|
|
|
def tqdm_stream(iterator: DownloadStream, desc: Optional[str] = None) -> Generator:
|
|
with tqdm(
|
|
total=len(iterator),
|
|
unit="B",
|
|
unit_scale=True,
|
|
unit_divisor=1024,
|
|
desc=desc,
|
|
dynamic_ncols=True,
|
|
bar_format=TQDM_BAR_FORMAT,
|
|
) as bar:
|
|
for chunk in iterator:
|
|
bar.update(len(chunk))
|
|
yield chunk
|