Mirror of https://github.com/nathom/streamrip.git
Create custom async downloader for HLS streams
commit 9f5cd49aab (parent 5b2aaf5ad2)
7 changed files with 364 additions and 184 deletions
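In short: instead of handing the HLS playlist URL straight to ffmpeg, streamrip now parses the playlist with m3u8, fetches every segment concurrently through the new DownloadPool (aiohttp plus aiofiles), and only then invokes ffmpeg to join the parts. A condensed sketch of the new flow, assuming a hypothetical playlist_url (the real call sites are in the media.py hunks below):

    import subprocess

    import m3u8
    import requests

    from streamrip.downloadtools import DownloadPool

    playlist_url = "https://example.com/stream.m3u8"  # hypothetical stand-in
    parsed_m3u = m3u8.loads(requests.get(playlist_url).text)

    # Download all segments concurrently, then let ffmpeg byte-join the parts.
    with DownloadPool(segment.uri for segment in parsed_m3u.segments) as pool:
        pool.download()
        subprocess.call(
            ["ffmpeg", "-i", f"concat:{'|'.join(pool.files)}",
             "-acodec", "copy", "-loglevel", "panic", "track.mp3"]
        )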
.gitignore (vendored)

@@ -19,3 +19,4 @@ StreamripDownloads
 *.pyc
 *test.py
 /.mypy_cache
+.DS_Store
@@ -42,3 +42,12 @@ ignore_missing_imports = True
 
 [mypy-appdirs.*]
 ignore_missing_imports = True
+
+[mypy-m3u8.*]
+ignore_missing_imports = True
+
+[mypy-aiohttp.*]
+ignore_missing_imports = True
+
+[mypy-aiofiles.*]
+ignore_missing_imports = True
streamrip/__init__.py

@@ -2,4 +2,4 @@
 
 __version__ = "1.4"
 
-from . import clients, constants, converter, media
+from . import clients, constants, converter, downloadtools, media
streamrip/downloadtools.py (new file, 241 lines)

@@ -0,0 +1,241 @@
import asyncio
import functools
import hashlib
import logging
import os
import re
from tempfile import gettempdir
from typing import Callable, Dict, Generator, Iterator, List, Optional

import aiofiles
import aiohttp
from Cryptodome.Cipher import Blowfish

from .exceptions import NonStreamable
from .utils import gen_threadsafe_session

logger = logging.getLogger("streamrip")


class DownloadStream:
    """An iterator over chunks of a stream.

    Usage:

    >>> stream = DownloadStream('https://google.com', None)
    >>> with open('google.html', 'wb') as file:
    >>>     for chunk in stream:
    >>>         file.write(chunk)

    """

    is_encrypted = re.compile("/m(?:obile|edia)/")

    def __init__(
        self,
        url: str,
        source: str = None,
        params: dict = None,
        headers: dict = None,
        item_id: str = None,
    ):
        """Create an iterable DownloadStream of a URL.

        :param url: The url to download
        :type url: str
        :param source: Only applicable for Deezer
        :type source: str
        :param params: Parameters to pass in the request
        :type params: dict
        :param headers: Headers to pass in the request
        :type headers: dict
        :param item_id: (Only for Deezer) the ID of the track
        :type item_id: str
        """
        self.source = source
        self.session = gen_threadsafe_session(headers=headers)

        self.id = item_id
        if isinstance(self.id, int):
            self.id = str(self.id)

        if params is None:
            params = {}

        self.request = self.session.get(
            url, allow_redirects=True, stream=True, params=params
        )
        self.file_size = int(self.request.headers.get("Content-Length", 0))

        if self.file_size < 20000 and not self.url.endswith(".jpg"):
            import json

            try:
                info = self.request.json()
                try:
                    # Usually happens with deezloader downloads
                    raise NonStreamable(
                        f"{info['error']} -- {info['message']}"
                    )
                except KeyError:
                    raise NonStreamable(info)

            except json.JSONDecodeError:
                raise NonStreamable("File not found.")

    def __iter__(self) -> Iterator:
        """Iterate through chunks of the stream.

        :rtype: Iterator
        """
        if (
            self.source == "deezer"
            and self.is_encrypted.search(self.url) is not None
        ):
            assert isinstance(self.id, str), self.id

            blowfish_key = self._generate_blowfish_key(self.id)
            # decryptor = self._create_deezer_decryptor(blowfish_key)
            # Only the first 2048 bytes of each 6144-byte chunk are encrypted
            CHUNK_SIZE = 2048 * 3
            return (
                # (decryptor.decrypt(chunk[:2048]) + chunk[2048:])
                (
                    self._decrypt_chunk(blowfish_key, chunk[:2048])
                    + chunk[2048:]
                )
                if len(chunk) >= 2048
                else chunk
                for chunk in self.request.iter_content(CHUNK_SIZE)
            )

        return self.request.iter_content(chunk_size=1024)

    @property
    def url(self):
        """Return the requested url."""
        return self.request.url

    def __len__(self) -> int:
        """Return the value of the "Content-Length" header.

        :rtype: int
        """
        return self.file_size

    def _create_deezer_decryptor(self, key) -> Blowfish:
        return Blowfish.new(
            key, Blowfish.MODE_CBC, b"\x00\x01\x02\x03\x04\x05\x06\x07"
        )

    @staticmethod
    def _generate_blowfish_key(track_id: str):
        """Generate the blowfish key for Deezer downloads.

        :param track_id:
        :type track_id: str
        """
        SECRET = "g4el58wc0zvf9na1"
        md5_hash = hashlib.md5(track_id.encode()).hexdigest()
        # good luck :)
        return "".join(
            chr(functools.reduce(lambda x, y: x ^ y, map(ord, t)))
            for t in zip(md5_hash[:16], md5_hash[16:], SECRET)
        ).encode()

    @staticmethod
    def _decrypt_chunk(key, data):
        """Decrypt a chunk of a Deezer stream.

        :param key:
        :param data:
        """
        return Blowfish.new(
            key,
            Blowfish.MODE_CBC,
            b"\x00\x01\x02\x03\x04\x05\x06\x07",
        ).decrypt(data)


class DownloadPool:
    """Asynchronously download a set of urls."""

    def __init__(
        self,
        urls: Generator,
        tempdir: str = None,
        chunk_callback: Optional[Callable] = None,
    ):
        self.finished: bool = False
        # Enumerate urls to know the order
        self.urls = dict(enumerate(urls))
        self._downloaded_urls: List[str] = []
        # {url: path}
        self._paths: Dict[str, str] = {}
        self.task: Optional[asyncio.Task] = None
        self.callback = chunk_callback

        if tempdir is None:
            tempdir = gettempdir()
        self.tempdir = tempdir

    async def getfn(self, url):
        path = os.path.join(
            self.tempdir, f"__streamrip_partial_{abs(hash(url))}"
        )
        self._paths[url] = path
        return path

    async def _download_urls(self):
        async with aiohttp.ClientSession() as session:
            tasks = [
                asyncio.ensure_future(self._download_url(session, url))
                for url in self.urls.values()
            ]
            await asyncio.gather(*tasks)

    async def _download_url(self, session, url):
        filename = await self.getfn(url)
        logger.debug("Downloading %s", url)
        async with session.get(url) as response, aiofiles.open(
            filename, "wb"
        ) as f:
            # without aiofiles 3.6632679780000004s
            # with aiofiles    2.504482839s
            await f.write(await response.content.read())

        if self.callback:
            self.callback()

        logger.debug("Finished %s", url)

    def download(self):
        asyncio.run(self._download_urls())

    @property
    def files(self):
        if len(self._paths) != len(self.urls):
            # Not all of them have downloaded
            raise Exception(
                "Must run DownloadPool.download() before accessing files"
            )

        # _paths values are already absolute paths
        return [
            self._paths[self.urls[i]]
            for i in range(len(self.urls))
        ]

    def __len__(self):
        return len(self.urls)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        logger.debug("Removing tempfiles %s", self._paths)
        for file in self._paths.values():
            try:
                os.remove(file)
            except FileNotFoundError:
                pass

        return False
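For reference, the derivation behind the "# good luck :)" line XORs byte i of the first MD5 half against byte i of the second half and of the secret; since every input character is ASCII, each XOR stays below 0x80 and the key is exactly 16 bytes. A standalone check of that property (the track ID here is hypothetical):

    import hashlib

    SECRET = "g4el58wc0zvf9na1"
    track_id = "3135556"  # hypothetical Deezer track ID

    md5_hash = hashlib.md5(track_id.encode()).hexdigest()
    # Equivalent to _generate_blowfish_key:
    # key[i] = md5[i] ^ md5[i + 16] ^ SECRET[i]
    key = bytes(
        ord(a) ^ ord(b) ^ ord(c)
        for a, b, c in zip(md5_hash[:16], md5_hash[16:], SECRET)
    )
    assert len(key) == 16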
streamrip/media.py

@@ -39,6 +39,7 @@ from .constants import (
     FOLDER_FORMAT,
     TRACK_FORMAT,
 )
+from .downloadtools import DownloadPool, DownloadStream
 from .exceptions import (
     InvalidQuality,
     InvalidSourceError,

@@ -49,7 +50,6 @@ from .exceptions import (
 )
 from .metadata import TrackMetadata
 from .utils import (
-    DownloadStream,
     clean_filename,
     clean_format,
     decrypt_mqa_file,
@@ -450,21 +450,44 @@ class Track(Media):
         """
         logger.debug("dl_info: %s", dl_info)
         if dl_info["type"] == "mp3":
+            import m3u8
+            import requests
+
+            parsed_m3u = m3u8.loads(requests.get(dl_info["url"]).text)
             self.path += ".mp3"
-            # convert hls stream to mp3
-            subprocess.call(
-                [
-                    "ffmpeg",
-                    "-i",
-                    dl_info["url"],
-                    "-c",
-                    "copy",
-                    "-y",
-                    self.path,
-                    "-loglevel",
-                    "fatal",
-                ]
-            )
+
+            with DownloadPool(
+                segment.uri for segment in parsed_m3u.segments
+            ) as pool:
+                pool.download()
+                subprocess.call(
+                    [
+                        "ffmpeg",
+                        "-i",
+                        f"concat:{'|'.join(pool.files)}",
+                        "-acodec",
+                        "copy",
+                        "-loglevel",
+                        "panic",
+                        self.path,
+                    ]
+                )
+
+            # self.path += ".mp3"
+            # # convert hls stream to mp3
+            # subprocess.call(
+            #     [
+            #         "ffmpeg",
+            #         "-i",
+            #         dl_info["url"],
+            #         "-c",
+            #         "copy",
+            #         "-y",
+            #         self.path,
+            #         "-loglevel",
+            #         "fatal",
+            #     ]
+            # )
         elif dl_info["type"] == "original":
             _quick_download(
                 dl_info["url"], self.path, desc=self._progress_desc
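Note the two different ffmpeg concatenation styles in this file: the mp3 branch above uses the concat: protocol, a raw byte-level join that works for self-delimiting formats like MPEG audio, while the Video hunk below feeds a file list to the concat demuxer (with -safe 0 so absolute paths are accepted). A minimal side-by-side sketch with hypothetical segment paths:

    import subprocess

    parts = ["/tmp/seg0.ts", "/tmp/seg1.ts"]  # hypothetical segment files

    # concat protocol: byte-level join of the inputs
    subprocess.call(
        ["ffmpeg", "-i", "concat:" + "|".join(parts),
         "-acodec", "copy", "out.mp3"]
    )

    # concat demuxer: reads "file '<path>'" lines; -safe 0 permits absolute paths
    with open("/tmp/files.txt", "w") as f:
        f.write("\n".join(f"file '{p}'" for p in parts))
    subprocess.call(
        ["ffmpeg", "-f", "concat", "-safe", "0", "-i", "/tmp/files.txt",
         "-c", "copy", "out.mp4"]
    )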
@@ -857,6 +880,9 @@ class Video(Media):
 
         :param kwargs:
         """
+        import m3u8
+        import requests
+
         secho(
             f"Downloading {self.title} (Video). This may take a while.",
             fg="blue",

@@ -864,19 +890,41 @@ class Video(Media):
 
         self.parent_folder = kwargs.get("parent_folder", "StreamripDownloads")
         url = self.client.get_file_url(self.id, video=True)
-        # it's more convenient to have ffmpeg download the hls
-        command = [
-            "ffmpeg",
-            "-i",
-            url,
-            "-c",
-            "copy",
-            "-loglevel",
-            "panic",
-            self.path,
-        ]
-        p = subprocess.Popen(command)
-        p.wait()  # remove this?
+
+        parsed_m3u = m3u8.loads(requests.get(url).text)
+        # Asynchronously download the streams
+        with DownloadPool(
+            segment.uri for segment in parsed_m3u.segments
+        ) as pool:
+            pool.download()
+
+            # Put the filenames in a tempfile that ffmpeg
+            # can read from
+            file_list_path = os.path.join(
+                gettempdir(), "__streamrip_video_files"
+            )
+            with open(file_list_path, "w") as file_list:
+                text = "\n".join(f"file '{path}'" for path in pool.files)
+                file_list.write(text)
+
+            # Use ffmpeg to concat the files
+            p = subprocess.Popen(
+                [
+                    "ffmpeg",
+                    "-f",
+                    "concat",
+                    "-safe",
+                    "0",
+                    "-i",
+                    file_list_path,
+                    "-c",
+                    "copy",
+                    "-loglevel",
+                    "panic",
+                    self.path,
+                ]
+            )
+            p.wait()
+
+        os.remove(file_list_path)
 
     def tag(self, *args, **kwargs):
         """Return False.
@@ -1396,12 +1444,11 @@ class Tracklist(list):
 class Album(Tracklist, Media):
     """Represents a downloadable album.
 
     Usage:
 
     >>> resp = client.get('fleetwood mac rumours', 'album')
     >>> album = Album.from_api(resp['items'][0], client)
     >>> album.load_meta()
     >>> album.download()
 
     """
 
     downloaded_ids: set = set()
streamrip/utils.py

@@ -3,166 +3,23 @@
 
 from __future__ import annotations
 
 import base64
-import functools
-import hashlib
 import logging
 import re
 from string import Formatter
 from typing import Dict, Hashable, Iterator, Optional, Tuple, Union
 
 import requests
 from click import secho, style
-from Cryptodome.Cipher import Blowfish
 from pathvalidate import sanitize_filename
 from requests.packages import urllib3
 from tqdm import tqdm
 
 from .constants import COVER_SIZES, TIDAL_COVER_URL
-from .exceptions import InvalidQuality, InvalidSourceError, NonStreamable
+from .exceptions import InvalidQuality, InvalidSourceError
 
 urllib3.disable_warnings()
 logger = logging.getLogger("streamrip")
 
 
-class DownloadStream:
-    """An iterator over chunks of a stream.
-
-    Usage:
-
-    >>> stream = DownloadStream('https://google.com', None)
-    >>> with open('google.html', 'wb') as file:
-    >>>     for chunk in stream:
-    >>>         file.write(chunk)
-
-    """
-
-    is_encrypted = re.compile("/m(?:obile|edia)/")
-
-    def __init__(
-        self,
-        url: str,
-        source: str = None,
-        params: dict = None,
-        headers: dict = None,
-        item_id: str = None,
-    ):
-        """Create an iterable DownloadStream of a URL.
-
-        :param url: The url to download
-        :type url: str
-        :param source: Only applicable for Deezer
-        :type source: str
-        :param params: Parameters to pass in the request
-        :type params: dict
-        :param headers: Headers to pass in the request
-        :type headers: dict
-        :param item_id: (Only for Deezer) the ID of the track
-        :type item_id: str
-        """
-        self.source = source
-        self.session = gen_threadsafe_session(headers=headers)
-
-        self.id = item_id
-        if isinstance(self.id, int):
-            self.id = str(self.id)
-
-        if params is None:
-            params = {}
-
-        self.request = self.session.get(
-            url, allow_redirects=True, stream=True, params=params
-        )
-        self.file_size = int(self.request.headers.get("Content-Length", 0))
-
-        if self.file_size < 20000 and not self.url.endswith(".jpg"):
-            import json
-
-            try:
-                info = self.request.json()
-                try:
-                    # Usually happens with deezloader downloads
-                    raise NonStreamable(
-                        f"{info['error']} -- {info['message']}"
-                    )
-                except KeyError:
-                    raise NonStreamable(info)
-
-            except json.JSONDecodeError:
-                raise NonStreamable("File not found.")
-
-    def __iter__(self) -> Iterator:
-        """Iterate through chunks of the stream.
-
-        :rtype: Iterator
-        """
-        if (
-            self.source == "deezer"
-            and self.is_encrypted.search(self.url) is not None
-        ):
-            assert isinstance(self.id, str), self.id
-
-            blowfish_key = self._generate_blowfish_key(self.id)
-            # decryptor = self._create_deezer_decryptor(blowfish_key)
-            CHUNK_SIZE = 2048 * 3
-            return (
-                # (decryptor.decrypt(chunk[:2048]) + chunk[2048:])
-                (
-                    self._decrypt_chunk(blowfish_key, chunk[:2048])
-                    + chunk[2048:]
-                )
-                if len(chunk) >= 2048
-                else chunk
-                for chunk in self.request.iter_content(CHUNK_SIZE)
-            )
-
-        return self.request.iter_content(chunk_size=1024)
-
-    @property
-    def url(self):
-        """Return the requested url."""
-        return self.request.url
-
-    def __len__(self) -> int:
-        """Return the value of the "Content-Length" header.
-
-        :rtype: int
-        """
-        return self.file_size
-
-    def _create_deezer_decryptor(self, key) -> Blowfish:
-        return Blowfish.new(
-            key, Blowfish.MODE_CBC, b"\x00\x01\x02\x03\x04\x05\x06\x07"
-        )
-
-    @staticmethod
-    def _generate_blowfish_key(track_id: str):
-        """Generate the blowfish key for Deezer downloads.
-
-        :param track_id:
-        :type track_id: str
-        """
-        SECRET = "g4el58wc0zvf9na1"
-        md5_hash = hashlib.md5(track_id.encode()).hexdigest()
-        # good luck :)
-        return "".join(
-            chr(functools.reduce(lambda x, y: x ^ y, map(ord, t)))
-            for t in zip(md5_hash[:16], md5_hash[16:], SECRET)
-        ).encode()
-
-    @staticmethod
-    def _decrypt_chunk(key, data):
-        """Decrypt a chunk of a Deezer stream.
-
-        :param key:
-        :param data:
-        """
-        return Blowfish.new(
-            key,
-            Blowfish.MODE_CBC,
-            b"\x00\x01\x02\x03\x04\x05\x06\x07",
-        ).decrypt(data)
-
-
 def safe_get(d: dict, *keys: Hashable, default=None):
     """Traverse dict layers safely.
 
@@ -567,9 +424,7 @@ def set_progress_bar_theme(theme: str):
     TQDM_BAR_FORMAT = TQDM_THEMES[theme]
 
 
-def tqdm_stream(
-    iterator: DownloadStream, desc: Optional[str] = None
-) -> Iterator[bytes]:
+def tqdm_stream(iterator, desc: Optional[str] = None) -> Iterator[bytes]:
     """Return a tqdm bar with presets appropriate for downloading large files.
 
     :param iterator:
@@ -578,15 +433,19 @@ def tqdm_stream(
     :type desc: Optional[str]
     :rtype: Iterator
     """
-    with tqdm(
-        total=len(iterator),
+    with get_tqdm_bar(len(iterator), desc=desc) as bar:
+        for chunk in iterator:
+            bar.update(len(chunk))
+            yield chunk
+
+
+def get_tqdm_bar(total, desc: Optional[str] = None):
+    return tqdm(
+        total=total,
         unit="B",
         unit_scale=True,
         unit_divisor=1024,
         desc=desc,
         dynamic_ncols=True,
         bar_format=TQDM_BAR_FORMAT,
-    ) as bar:
-        for chunk in iterator:
-            bar.update(len(chunk))
-            yield chunk
+    )
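DownloadPool's chunk_callback hook is not exercised anywhere in this commit; one plausible use is driving the refactored progress helper above. A hedged sketch (note that, as written, the callback fires once per completed URL rather than per chunk, and the URL list is hypothetical):

    from streamrip.downloadtools import DownloadPool
    from streamrip.utils import get_tqdm_bar

    urls = [f"https://example.com/seg{i}.ts" for i in range(100)]  # hypothetical

    # get_tqdm_bar formats byte counts; here it simply counts finished segments.
    bar = get_tqdm_bar(total=len(urls), desc="Segments")
    with DownloadPool(iter(urls), chunk_callback=lambda: bar.update(1)) as pool:
        pool.download()
    bar.close()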
tests/test_download.py (new file, 23 lines)

@@ -0,0 +1,23 @@
import os
import time
from pprint import pprint

from streamrip.downloadtools import DownloadPool


def test_downloadpool(tmpdir):
    start = time.perf_counter()
    with DownloadPool(
        (
            f"https://pokeapi.co/api/v2/pokemon/{number}"
            for number in range(1, 152)
        ),
        tempdir=tmpdir,
    ) as pool:
        pool.download()
        assert len(os.listdir(tmpdir)) == 151

    # the tempfiles should be removed at this point
    assert len(os.listdir(tmpdir)) == 0

    print(f"Finished in {time.perf_counter() - start}s")