class DownloadStream:
    """An iterator over chunks of a stream.

    Usage:

    >>> stream = DownloadStream('https://google.com', None)
    >>> with open('google.html', 'wb') as file:
    >>>    for chunk in stream:
    >>>       file.write(chunk)

    """

    # Deezer URLs that match this pattern serve Blowfish-encrypted audio.
    is_encrypted = re.compile("/m(?:obile|edia)/")

    def __init__(
        self,
        url: str,
        source: Optional[str] = None,
        params: Optional[dict] = None,
        headers: Optional[dict] = None,
        item_id: Optional[str] = None,
    ):
        """Create an iterable DownloadStream of a URL.

        :param url: The url to download
        :type url: str
        :param source: Only applicable for Deezer
        :type source: str
        :param params: Parameters to pass in the request
        :type params: dict
        :param headers: Headers to pass in the request
        :type headers: dict
        :param item_id: (Only for Deezer) the ID of the track
        :type item_id: str
        """
        self.source = source
        self.session = gen_threadsafe_session(headers=headers)

        # Deezer track IDs may arrive as ints; key derivation needs a str.
        self.id = item_id
        if isinstance(self.id, int):
            self.id = str(self.id)

        self.request = self.session.get(
            url, allow_redirects=True, stream=True, params=params or {}
        )
        self.file_size = int(self.request.headers.get("Content-Length", 0))

        # A tiny non-image response is an API error payload, not audio.
        # Surface it as NonStreamable instead of writing junk to disk.
        if self.file_size < 20000 and not self.url.endswith(".jpg"):
            try:
                info = self.request.json()
            except ValueError:
                # Covers json.JSONDecodeError and the simplejson variant
                # that requests may raise instead — both subclass ValueError.
                raise NonStreamable("File not found.")

            if isinstance(info, dict) and "error" in info and "message" in info:
                # Usually happens with deezloader downloads
                raise NonStreamable(f"{info['error']} -- {info['message']}")
            raise NonStreamable(info)

    def __iter__(self) -> Iterator:
        """Iterate through chunks of the stream.

        :rtype: Iterator
        """
        if (
            self.source == "deezer"
            and self.is_encrypted.search(self.url) is not None
        ):
            assert isinstance(self.id, str), self.id

            blowfish_key = self._generate_blowfish_key(self.id)
            # Stream arrives in 3*2048-byte chunks; only the first 2048
            # bytes of each full chunk are Blowfish-encrypted.
            CHUNK_SIZE = 2048 * 3
            return (
                (
                    self._decrypt_chunk(blowfish_key, chunk[:2048])
                    + chunk[2048:]
                )
                if len(chunk) >= 2048
                else chunk
                for chunk in self.request.iter_content(CHUNK_SIZE)
            )

        return self.request.iter_content(chunk_size=1024)

    @property
    def url(self):
        """Return the requested url."""
        return self.request.url

    def __len__(self) -> int:
        """Return the value of the "Content-Length" header.

        :rtype: int
        """
        return self.file_size

    def _create_deezer_decryptor(self, key):
        """Return a Blowfish-CBC cipher for ``key`` (kept for reference; unused).

        Note: the ``-> Blowfish`` annotation was removed — ``Blowfish`` is the
        module, not the cipher type, and evaluating it at class-creation time
        hard-required Cryptodome even when decryption is never used.
        """
        return Blowfish.new(
            key, Blowfish.MODE_CBC, b"\x00\x01\x02\x03\x04\x05\x06\x07"
        )

    @staticmethod
    def _generate_blowfish_key(track_id: str):
        """Generate the blowfish key for Deezer downloads.

        The key is the byte-wise XOR of the two halves of the track ID's
        MD5 hex digest with a fixed secret.

        :param track_id:
        :type track_id: str
        """
        SECRET = "g4el58wc0zvf9na1"
        md5_hash = hashlib.md5(track_id.encode()).hexdigest()
        # good luck :)
        return "".join(
            chr(functools.reduce(lambda x, y: x ^ y, map(ord, t)))
            for t in zip(md5_hash[:16], md5_hash[16:], SECRET)
        ).encode()

    @staticmethod
    def _decrypt_chunk(key, data):
        """Decrypt a chunk of a Deezer stream.

        :param key: Blowfish key from :meth:`_generate_blowfish_key`.
        :param data: the encrypted 2048-byte prefix of a chunk.
        """
        return Blowfish.new(
            key,
            Blowfish.MODE_CBC,
            b"\x00\x01\x02\x03\x04\x05\x06\x07",
        ).decrypt(data)
class DownloadPool:
    """Asynchronously download a set of urls."""

    def __init__(
        self,
        urls: Generator,
        tempdir: Optional[str] = None,
        chunk_callback: Optional[Callable] = None,
    ):
        """Create a DownloadPool.

        :param urls: the urls to download
        :param tempdir: directory for the partial files (system tempdir
            when None)
        :param chunk_callback: called once after each url finishes
        """
        self.finished: bool = False
        # Enumerate urls so `files` can restore the original order later.
        self.urls = dict(enumerate(urls))
        self._downloaded_urls: List[str] = []
        # {url: path}
        self._paths: Dict[str, str] = {}
        self.task: Optional[asyncio.Task] = None
        # BUG FIX: this was `self.callack`, so _download_url crashed with
        # AttributeError when it read `self.callback`.
        self.callback = chunk_callback

        if tempdir is None:
            tempdir = gettempdir()
        self.tempdir = tempdir

    async def getfn(self, url):
        """Return (and record in self._paths) the temp path for *url*."""
        path = os.path.join(
            self.tempdir, f"__streamrip_partial_{abs(hash(url))}"
        )
        self._paths[url] = path
        return path

    async def _download_urls(self):
        """Download every url concurrently inside one HTTP session."""
        async with aiohttp.ClientSession() as session:
            tasks = [
                asyncio.ensure_future(self._download_url(session, url))
                for url in self.urls.values()
            ]
            await asyncio.gather(*tasks)

    async def _download_url(self, session, url):
        """Stream one url into its temp file, then fire the callback."""
        filename = await self.getfn(url)
        logger.debug("Downloading %s", url)
        async with session.get(url) as response, aiofiles.open(
            filename, "wb"
        ) as f:
            # without aiofiles 3.6632679780000004s
            # with aiofiles 2.504482839s
            await f.write(await response.content.read())

        if self.callback:
            self.callback()

        logger.debug("Finished %s", url)

    def download(self):
        """Download all urls, blocking until every one has finished."""
        asyncio.run(self._download_urls())

    @property
    def files(self):
        """Return the downloaded file paths in the original url order.

        :raises Exception: if download() has not been run yet
        """
        if len(self._paths) != len(self.urls):
            # Not all of them have downloaded
            raise Exception(
                "Must run DownloadPool.download() before accessing files"
            )

        # BUG FIX: the stored paths are already rooted at self.tempdir
        # (see getfn); joining tempdir a second time duplicated the
        # directory component whenever tempdir was a relative path.
        return [self._paths[self.urls[i]] for i in range(len(self.urls))]

    def __len__(self):
        return len(self.urls)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Best-effort cleanup of the partial files; never suppress the
        # in-flight exception (always returns False).
        logger.debug("Removing tempfiles %s", self._paths)
        for file in self._paths.values():
            try:
                os.remove(file)
            except FileNotFoundError:
                pass

        return False
stream to mp3 + # subprocess.call( + # [ + # "ffmpeg", + # "-i", + # dl_info["url"], + # "-c", + # "copy", + # "-y", + # self.path, + # "-loglevel", + # "fatal", + # ] + # ) elif dl_info["type"] == "original": _quick_download( dl_info["url"], self.path, desc=self._progress_desc @@ -857,6 +880,9 @@ class Video(Media): :param kwargs: """ + import m3u8 + import requests + secho( f"Downloading {self.title} (Video). This may take a while.", fg="blue", @@ -864,19 +890,41 @@ class Video(Media): self.parent_folder = kwargs.get("parent_folder", "StreamripDownloads") url = self.client.get_file_url(self.id, video=True) - # it's more convenient to have ffmpeg download the hls - command = [ - "ffmpeg", - "-i", - url, - "-c", - "copy", - "-loglevel", - "panic", - self.path, - ] - p = subprocess.Popen(command) - p.wait() # remove this? + + parsed_m3u = m3u8.loads(requests.get(url).text) + # Asynchronously download the streams + with DownloadPool( + segment.uri for segment in parsed_m3u.segments + ) as pool: + pool.download() + + # Put the filenames in a tempfile that ffmpeg + # can read from + file_list_path = os.path.join( + gettempdir(), "__streamrip_video_files" + ) + with open(file_list_path, "w") as file_list: + text = "\n".join(f"file '{path}'" for path in pool.files) + file_list.write(text) + + # Use ffmpeg to concat the files + p = subprocess.Popen( + [ + "ffmpeg", + "-f", + "concat", + "-safe", + "0", + "-i", + file_list_path, + "-c", + "copy", + self.path, + ] + ) + p.wait() + + os.remove(file_list_path) def tag(self, *args, **kwargs): """Return False. @@ -1396,12 +1444,11 @@ class Tracklist(list): class Album(Tracklist, Media): """Represents a downloadable album. 
- Usage: - >>> resp = client.get('fleetwood mac rumours', 'album') >>> album = Album.from_api(resp['items'][0], client) >>> album.load_meta() >>> album.download() + """ downloaded_ids: set = set() diff --git a/streamrip/utils.py b/streamrip/utils.py index 2ccc2a7..4d3bab7 100644 --- a/streamrip/utils.py +++ b/streamrip/utils.py @@ -3,166 +3,23 @@ from __future__ import annotations import base64 -import functools -import hashlib import logging -import re from string import Formatter from typing import Dict, Hashable, Iterator, Optional, Tuple, Union import requests from click import secho, style -from Cryptodome.Cipher import Blowfish from pathvalidate import sanitize_filename from requests.packages import urllib3 from tqdm import tqdm from .constants import COVER_SIZES, TIDAL_COVER_URL -from .exceptions import InvalidQuality, InvalidSourceError, NonStreamable +from .exceptions import InvalidQuality, InvalidSourceError urllib3.disable_warnings() logger = logging.getLogger("streamrip") -class DownloadStream: - """An iterator over chunks of a stream. - - Usage: - - >>> stream = DownloadStream('https://google.com', None) - >>> with open('google.html', 'wb') as file: - >>> for chunk in stream: - >>> file.write(chunk) - - """ - - is_encrypted = re.compile("/m(?:obile|edia)/") - - def __init__( - self, - url: str, - source: str = None, - params: dict = None, - headers: dict = None, - item_id: str = None, - ): - """Create an iterable DownloadStream of a URL. 
- - :param url: The url to download - :type url: str - :param source: Only applicable for Deezer - :type source: str - :param params: Parameters to pass in the request - :type params: dict - :param headers: Headers to pass in the request - :type headers: dict - :param item_id: (Only for Deezer) the ID of the track - :type item_id: str - """ - self.source = source - self.session = gen_threadsafe_session(headers=headers) - - self.id = item_id - if isinstance(self.id, int): - self.id = str(self.id) - - if params is None: - params = {} - - self.request = self.session.get( - url, allow_redirects=True, stream=True, params=params - ) - self.file_size = int(self.request.headers.get("Content-Length", 0)) - - if self.file_size < 20000 and not self.url.endswith(".jpg"): - import json - - try: - info = self.request.json() - try: - # Usually happens with deezloader downloads - raise NonStreamable( - f"{info['error']} -- {info['message']}" - ) - except KeyError: - raise NonStreamable(info) - - except json.JSONDecodeError: - raise NonStreamable("File not found.") - - def __iter__(self) -> Iterator: - """Iterate through chunks of the stream. - - :rtype: Iterator - """ - if ( - self.source == "deezer" - and self.is_encrypted.search(self.url) is not None - ): - assert isinstance(self.id, str), self.id - - blowfish_key = self._generate_blowfish_key(self.id) - # decryptor = self._create_deezer_decryptor(blowfish_key) - CHUNK_SIZE = 2048 * 3 - return ( - # (decryptor.decrypt(chunk[:2048]) + chunk[2048:]) - ( - self._decrypt_chunk(blowfish_key, chunk[:2048]) - + chunk[2048:] - ) - if len(chunk) >= 2048 - else chunk - for chunk in self.request.iter_content(CHUNK_SIZE) - ) - - return self.request.iter_content(chunk_size=1024) - - @property - def url(self): - """Return the requested url.""" - return self.request.url - - def __len__(self) -> int: - """Return the value of the "Content-Length" header. 
- - :rtype: int - """ - return self.file_size - - def _create_deezer_decryptor(self, key) -> Blowfish: - return Blowfish.new( - key, Blowfish.MODE_CBC, b"\x00\x01\x02\x03\x04\x05\x06\x07" - ) - - @staticmethod - def _generate_blowfish_key(track_id: str): - """Generate the blowfish key for Deezer downloads. - - :param track_id: - :type track_id: str - """ - SECRET = "g4el58wc0zvf9na1" - md5_hash = hashlib.md5(track_id.encode()).hexdigest() - # good luck :) - return "".join( - chr(functools.reduce(lambda x, y: x ^ y, map(ord, t))) - for t in zip(md5_hash[:16], md5_hash[16:], SECRET) - ).encode() - - @staticmethod - def _decrypt_chunk(key, data): - """Decrypt a chunk of a Deezer stream. - - :param key: - :param data: - """ - return Blowfish.new( - key, - Blowfish.MODE_CBC, - b"\x00\x01\x02\x03\x04\x05\x06\x07", - ).decrypt(data) - - def safe_get(d: dict, *keys: Hashable, default=None): """Traverse dict layers safely. @@ -567,9 +424,7 @@ def set_progress_bar_theme(theme: str): TQDM_BAR_FORMAT = TQDM_THEMES[theme] -def tqdm_stream( - iterator: DownloadStream, desc: Optional[str] = None -) -> Iterator[bytes]: +def tqdm_stream(iterator, desc: Optional[str] = None) -> Iterator[bytes]: """Return a tqdm bar with presets appropriate for downloading large files. 
def test_downloadpool(tmpdir):
    """Smoke-test DownloadPool against a live API and verify cleanup.

    :param tmpdir: pytest-provided temporary directory
    """
    start = time.perf_counter()
    with DownloadPool(
        (
            f"https://pokeapi.co/api/v2/pokemon/{number}"
            for number in range(1, 151)
        ),
        tempdir=tmpdir,
    ) as pool:
        pool.download()
        # BUG FIX: range(1, 151) yields 150 urls, so 150 files are
        # expected (the assertion previously demanded 151 and could
        # never pass).
        assert len(os.listdir(tmpdir)) == 150

    # the tempfiles should be removed at this point
    assert len(os.listdir(tmpdir)) == 0

    print(f"Finished in {time.perf_counter() - start}s")