From ef3475604670aa9f4b3e05303857e8cee1485326 Mon Sep 17 00:00:00 2001 From: Nathan Thomas Date: Tue, 20 Jun 2023 20:35:31 -0700 Subject: [PATCH] Create rate limited requests session --- streamrip/clients.py | 24 +++++++++---------- streamrip/utils.py | 56 ++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 66 insertions(+), 14 deletions(-) diff --git a/streamrip/clients.py b/streamrip/clients.py index dcc1ed4..ccc0cfc 100644 --- a/streamrip/clients.py +++ b/streamrip/clients.py @@ -40,7 +40,7 @@ from .exceptions import ( NonStreamable, ) from .spoofbuz import Spoofer -from .utils import gen_threadsafe_session, get_quality, safe_get +from .utils import SRSession, gen_threadsafe_session, get_quality, safe_get logger = logging.getLogger("streamrip") @@ -134,7 +134,7 @@ class QobuzClient(Client): str(kwargs["app_id"]), kwargs["secrets"], ) - self.session = gen_threadsafe_session( + self.session = SRSession( headers={"User-Agent": AGENT, "X-App-Id": self.app_id} ) self._validate_secrets() @@ -223,7 +223,7 @@ class QobuzClient(Client): if not hasattr(self, "sec"): if not hasattr(self, "session"): - self.session = gen_threadsafe_session( + self.session = SRSession( headers={"User-Agent": AGENT, "X-App-Id": self.app_id} ) self._validate_secrets() @@ -343,7 +343,9 @@ class QobuzClient(Client): return self._gen_pages(epoint, params) - def _api_login(self, use_auth_token: bool, email_or_userid: str, password_or_token: str): + def _api_login( + self, use_auth_token: bool, email_or_userid: str, password_or_token: str + ): """Log into the api to get the user authentication token. :param use_auth_token: @@ -380,7 +382,7 @@ class QobuzClient(Client): raise IneligibleError("Free accounts are not eligible to download tracks.") self.uat = resp["user_auth_token"] - self.session.headers.update({"X-User-Auth-Token": self.uat}) + self.session.update_headers({"X-User-Auth-Token": self.uat}) self.label = resp["user"]["credential"]["parameters"]["short_label"] def _api_get_file_url( @@ -472,7 +474,6 @@ class DeezerClient(Client): def __init__(self): """Create a DeezerClient.""" self.client = deezer.Deezer() - # self.session = gen_threadsafe_session() # no login required self.logged_in = False @@ -645,7 +646,7 @@ class DeezloaderClient(Client): def __init__(self): """Create a DeezloaderClient.""" - self.session = gen_threadsafe_session() + self.session = SRSession() # no login required self.logged_in = True @@ -735,7 +736,7 @@ class TidalClient(Client): self.refresh_token = None self.expiry = None - self.session = gen_threadsafe_session() + self.session = SRSession() def login( self, @@ -994,7 +995,7 @@ class TidalClient(Client): def _update_authorization(self): """Update the requests session headers with the auth token.""" - self.session.headers.update(self.authorization) + self.session.update_headers(self.authorization) @property def authorization(self): @@ -1094,8 +1095,7 @@ class TidalClient(Client): :param data: :param auth: """ - r = self.session.post(url, data=data, auth=auth, verify=False).json() - return r + return self.session.post(url, data=data, auth=auth, verify=False).json() class SoundCloudClient(Client): @@ -1110,7 +1110,7 @@ class SoundCloudClient(Client): def __init__(self): """Create a SoundCloudClient.""" - self.session = gen_threadsafe_session( + self.session = SRSession( headers={ "User-Agent": AGENT, } diff --git a/streamrip/utils.py b/streamrip/utils.py index e3b43f1..9bda814 100644 --- a/streamrip/utils.py +++ b/streamrip/utils.py @@ -9,6 +9,8 @@ import os import shutil import subprocess import tempfile +import time +from multiprocessing import Lock from string import Formatter from typing import Dict, Hashable, Iterator, List, Optional, Tuple, Union @@ -307,6 +309,56 @@ def ext(quality: int, source: str): return ".flac" +class SRSession: + # requests per minute + PERIOD = 60.0 + + def __init__( + self, + headers: Optional[dict] = None, + pool_connections: int = 100, + pool_maxsize: int = 100, + requests_per_min: Optional[int] = None, + ): + + if headers is None: + headers = {} + + self.session = requests.Session() + adapter = requests.adapters.HTTPAdapter(pool_connections, pool_maxsize) + self.session.mount("https://", adapter) + self.session.headers.update(headers) + self.has_rate_limit = requests_per_min is not None + self.rpm = requests_per_min + + self.last_minute: float = time.time() + self.call_no: int = 0 + self.rate_limit_lock = Lock() if self.has_rate_limit else None + + def get(self, *args, **kwargs): + if self.has_rate_limit: # only use locks if there is a rate limit + assert self.rate_limit_lock is not None + assert self.rpm is not None + with self.rate_limit_lock: + now = time.time() + if self.call_no >= self.rpm: + if now - self.last_minute < SRSession.PERIOD: + time.sleep(SRSession.PERIOD - (now - self.last_minute)) + self.last_minute = time.time() + self.call_no = 0 + + self.call_no += 1 + + return self.session.get(*args, **kwargs) + + def update_headers(self, headers: dict): + self.session.headers.update(headers) + + # No rate limit on post + def post(self, *args, **kwargs) -> requests.Response: + self.session.post(*args, **kwargs) + + def gen_threadsafe_session( headers: dict = None, pool_connections: int = 100, pool_maxsize: int = 100 ) -> requests.Session: @@ -324,7 +376,7 @@ def gen_threadsafe_session( headers = {} session = requests.Session() - adapter = requests.adapters.HTTPAdapter(pool_connections=100, pool_maxsize=100) + adapter = requests.adapters.HTTPAdapter(pool_connections, pool_maxsize) session.mount("https://", adapter) session.headers.update(headers) return session @@ -373,7 +425,7 @@ def get_cover_urls(resp: dict, source: str) -> Optional[dict]: if source == "qobuz": cover_urls = resp["image"] - cover_urls["original"] = "org".join(cover_urls["large"].rsplit('600', 1)) + cover_urls["original"] = "org".join(cover_urls["large"].rsplit("600", 1)) return cover_urls if source == "tidal":