diff --git a/streamrip/clients.py b/streamrip/clients.py index 79ae50c..4f11bbe 100644 --- a/streamrip/clients.py +++ b/streamrip/clients.py @@ -26,7 +26,7 @@ from .exceptions import ( InvalidQuality, ) from .spoofbuz import Spoofer -from .utils import get_quality +from .utils import gen_threadsafe_session, get_quality urllib3.disable_warnings() requests.adapters.DEFAULT_RETRIES = 5 @@ -149,15 +149,8 @@ class QobuzClient(ClientInterface): self.app_id = str(kwargs["app_id"]) # Ensure it is a string self.secrets = kwargs["secrets"] - self.session = requests.Session() - # for multithreading - adapter = requests.adapters.HTTPAdapter(pool_connections=100, pool_maxsize=100) - self.session.mount("https://", adapter) - self.session.headers.update( - { - "User-Agent": AGENT, - "X-App-Id": self.app_id, - } + self.session = gen_threadsafe_session( + headers={"User-Agent": AGENT, "X-App-Id": self.app_id} ) self._api_login(email, pwd) @@ -376,10 +369,7 @@ class DeezerClient(ClientInterface): max_quality = 2 def __init__(self): - self.session = requests.Session() - # for multithreading - adapter = requests.adapters.HTTPAdapter(pool_connections=300, pool_maxsize=300) - self.session.mount("https://", adapter) + self.session = gen_threadsafe_session() # no login required self.logged_in = True @@ -452,10 +442,7 @@ class TidalClient(ClientInterface): self.refresh_token = None self.expiry = None - self.session = requests.Session() - # for multithreading - adapter = requests.adapters.HTTPAdapter(pool_connections=200, pool_maxsize=200) - self.session.mount("https://", adapter) + self.session = gen_threadsafe_session() def login( self, @@ -598,7 +585,9 @@ class TidalClient(ClientInterface): headers = { "authorization": f"Bearer {token}", } - r = requests.get("https://api.tidal.com/v1/sessions", headers=headers).json() + r = self.session.get( + "https://api.tidal.com/v1/sessions", headers=headers + ).json() if r.status != 200: raise Exception("Login failed") @@ -627,8 +616,10 @@ class TidalClient(ClientInterface): self._update_authorization() def _login_by_access_token(self, token, user_id=None): - headers = {"authorization": f"Bearer {token}"} - resp = requests.get("https://api.tidal.com/v1/sessions", headers=headers).json() + headers = {"authorization": f"Bearer {token}"} # temporary + resp = self.session.get( + "https://api.tidal.com/v1/sessions", headers=headers + ).json() if resp.get("status", 200) != 200: raise Exception(f"Login failed {resp}") @@ -666,14 +657,13 @@ class TidalClient(ClientInterface): if params is None: params = {} - headers = {"authorization": f"Bearer {self.access_token}"} params["countryCode"] = self.country_code params["limit"] = 100 - r = requests.get(f"{TIDAL_BASE}/{path}", headers=headers, params=params).json() + r = self.session.get(f"{TIDAL_BASE}/{path}", params=params).json() return r def _api_post(self, url, data, auth=None): - r = requests.post(url, data=data, auth=auth, verify=False).json() + r = self.session.post(url, data=data, auth=auth, verify=False).json() return r def _update_authorization(self): @@ -685,6 +675,9 @@ class SoundCloudClient(ClientInterface): max_quality = 0 logged_in = True + def __init__(self): + self.session = gen_threadsafe_session(headers={"User-Agent": AGENT}) + def login(self): raise NotImplementedError @@ -736,7 +729,7 @@ class SoundCloudClient(ClientInterface): url = f"{SOUNDCLOUD_BASE}/{path}" logger.debug(f"Fetching url {url}") - r = requests.get(url, params=params) + r = self.session.get(url, params=params) if resp_obj: return r diff --git a/streamrip/core.py b/streamrip/core.py index 669ce67..63d3846 100644 --- a/streamrip/core.py +++ b/streamrip/core.py @@ -173,11 +173,13 @@ class MusicDL(list): } logger.debug("Arguments from config: %s", arguments) - source_subdirs = self.config.session['downloads']['source_subdirectories'] + source_subdirs = self.config.session["downloads"]["source_subdirectories"] for item in self: if source_subdirs: - arguments['parent_folder'] = self.__get_source_subdir(item.client.source) + arguments["parent_folder"] = self.__get_source_subdir( + item.client.source + ) arguments["quality"] = self.config.session[item.client.source]["quality"] if isinstance(item, Artist): @@ -486,5 +488,5 @@ class MusicDL(list): return playlist_title, info def __get_source_subdir(self, source: str) -> str: - path = self.config.session['downloads']['folder'] + path = self.config.session["downloads"]["folder"] return os.path.join(path, capitalize(source)) diff --git a/streamrip/downloader.py b/streamrip/downloader.py index c383ac1..4cd9f37 100644 --- a/streamrip/downloader.py +++ b/streamrip/downloader.py @@ -230,8 +230,6 @@ class Track: self.sampling_rate = dl_info.get("sampling_rate") self.bit_depth = dl_info.get("bit_depth") - # click.secho(f"\nDownloading {self!s}", fg="blue") - # --------- Download Track ---------- if self.client.source in ("qobuz", "tidal"): logger.debug("Downloadable URL found: %s", dl_info.get("url")) @@ -737,7 +735,7 @@ class Tracklist(list): def download_message(self): click.secho( - f"\nDownloading {self.title} ({self.__class__.__name__})\n", + f"\n\nDownloading {self.title} ({self.__class__.__name__})\n", fg="blue", ) diff --git a/streamrip/utils.py b/streamrip/utils.py index ca810e7..8bc6fd9 100644 --- a/streamrip/utils.py +++ b/streamrip/utils.py @@ -17,10 +17,6 @@ from .exceptions import InvalidSourceError, NonStreamable urllib3.disable_warnings() logger = logging.getLogger(__name__) -session = requests.Session() -adapter = requests.adapters.HTTPAdapter(pool_connections=100, pool_maxsize=100) -session.mount("https://", adapter) - def safe_get(d: dict, *keys: Hashable, default=None): """A replacement for chained `get()` statements on dicts: @@ -113,6 +109,7 @@ def tqdm_download(url: str, filepath: str, params: dict = None): if params is None: params = {} + session = gen_threadsafe_session() r = session.get(url, allow_redirects=True, stream=True, params=params) total = int(r.headers.get("content-length", 0)) logger.debug(f"File size = {total}") @@ -221,3 +218,16 @@ def ext(quality: int, source: str): return ".mp3" else: return ".flac" + + +def gen_threadsafe_session( + headers: dict = None, pool_connections: int = 100, pool_maxsize: int = 100 +) -> requests.Session: + if headers is None: + headers = {} + + session = requests.Session() + adapter = requests.adapters.HTTPAdapter(pool_connections=100, pool_maxsize=100) + session.mount("https://", adapter) + session.headers.update(headers) + return session