Create rate limited requests session

This commit is contained in:
Nathan Thomas 2023-06-20 20:35:31 -07:00
parent fd353d57cc
commit ef34756046
2 changed files with 66 additions and 14 deletions

View file

@ -40,7 +40,7 @@ from .exceptions import (
NonStreamable, NonStreamable,
) )
from .spoofbuz import Spoofer from .spoofbuz import Spoofer
from .utils import gen_threadsafe_session, get_quality, safe_get from .utils import SRSession, gen_threadsafe_session, get_quality, safe_get
logger = logging.getLogger("streamrip") logger = logging.getLogger("streamrip")
@ -134,7 +134,7 @@ class QobuzClient(Client):
str(kwargs["app_id"]), str(kwargs["app_id"]),
kwargs["secrets"], kwargs["secrets"],
) )
self.session = gen_threadsafe_session( self.session = SRSession(
headers={"User-Agent": AGENT, "X-App-Id": self.app_id} headers={"User-Agent": AGENT, "X-App-Id": self.app_id}
) )
self._validate_secrets() self._validate_secrets()
@ -223,7 +223,7 @@ class QobuzClient(Client):
if not hasattr(self, "sec"): if not hasattr(self, "sec"):
if not hasattr(self, "session"): if not hasattr(self, "session"):
self.session = gen_threadsafe_session( self.session = SRSession(
headers={"User-Agent": AGENT, "X-App-Id": self.app_id} headers={"User-Agent": AGENT, "X-App-Id": self.app_id}
) )
self._validate_secrets() self._validate_secrets()
@ -343,7 +343,9 @@ class QobuzClient(Client):
return self._gen_pages(epoint, params) return self._gen_pages(epoint, params)
def _api_login(self, use_auth_token: bool, email_or_userid: str, password_or_token: str): def _api_login(
self, use_auth_token: bool, email_or_userid: str, password_or_token: str
):
"""Log into the api to get the user authentication token. """Log into the api to get the user authentication token.
:param use_auth_token: :param use_auth_token:
@ -380,7 +382,7 @@ class QobuzClient(Client):
raise IneligibleError("Free accounts are not eligible to download tracks.") raise IneligibleError("Free accounts are not eligible to download tracks.")
self.uat = resp["user_auth_token"] self.uat = resp["user_auth_token"]
self.session.headers.update({"X-User-Auth-Token": self.uat}) self.session.update_headers({"X-User-Auth-Token": self.uat})
self.label = resp["user"]["credential"]["parameters"]["short_label"] self.label = resp["user"]["credential"]["parameters"]["short_label"]
def _api_get_file_url( def _api_get_file_url(
@ -472,7 +474,6 @@ class DeezerClient(Client):
def __init__(self): def __init__(self):
"""Create a DeezerClient.""" """Create a DeezerClient."""
self.client = deezer.Deezer() self.client = deezer.Deezer()
# self.session = gen_threadsafe_session()
# no login required # no login required
self.logged_in = False self.logged_in = False
@ -645,7 +646,7 @@ class DeezloaderClient(Client):
def __init__(self): def __init__(self):
"""Create a DeezloaderClient.""" """Create a DeezloaderClient."""
self.session = gen_threadsafe_session() self.session = SRSession()
# no login required # no login required
self.logged_in = True self.logged_in = True
@ -735,7 +736,7 @@ class TidalClient(Client):
self.refresh_token = None self.refresh_token = None
self.expiry = None self.expiry = None
self.session = gen_threadsafe_session() self.session = SRSession()
def login( def login(
self, self,
@ -994,7 +995,7 @@ class TidalClient(Client):
def _update_authorization(self): def _update_authorization(self):
"""Update the requests session headers with the auth token.""" """Update the requests session headers with the auth token."""
self.session.headers.update(self.authorization) self.session.update_headers(self.authorization)
@property @property
def authorization(self): def authorization(self):
@ -1094,8 +1095,7 @@ class TidalClient(Client):
:param data: :param data:
:param auth: :param auth:
""" """
r = self.session.post(url, data=data, auth=auth, verify=False).json() return self.session.post(url, data=data, auth=auth, verify=False).json()
return r
class SoundCloudClient(Client): class SoundCloudClient(Client):
@ -1110,7 +1110,7 @@ class SoundCloudClient(Client):
def __init__(self): def __init__(self):
"""Create a SoundCloudClient.""" """Create a SoundCloudClient."""
self.session = gen_threadsafe_session( self.session = SRSession(
headers={ headers={
"User-Agent": AGENT, "User-Agent": AGENT,
} }

View file

@ -9,6 +9,8 @@ import os
import shutil import shutil
import subprocess import subprocess
import tempfile import tempfile
import time
from multiprocessing import Lock
from string import Formatter from string import Formatter
from typing import Dict, Hashable, Iterator, List, Optional, Tuple, Union from typing import Dict, Hashable, Iterator, List, Optional, Tuple, Union
@ -307,6 +309,56 @@ def ext(quality: int, source: str):
return ".flac" return ".flac"
class SRSession:
# requests per minute
PERIOD = 60.0
def __init__(
self,
headers: Optional[dict] = None,
pool_connections: int = 100,
pool_maxsize: int = 100,
requests_per_min: Optional[int] = None,
):
if headers is None:
headers = {}
self.session = requests.Session()
adapter = requests.adapters.HTTPAdapter(pool_connections, pool_maxsize)
self.session.mount("https://", adapter)
self.session.headers.update(headers)
self.has_rate_limit = requests_per_min is not None
self.rpm = requests_per_min
self.last_minute: float = time.time()
self.call_no: int = 0
self.rate_limit_lock = Lock() if self.has_rate_limit else None
def get(self, *args, **kwargs):
if self.has_rate_limit: # only use locks if there is a rate limit
assert self.rate_limit_lock is not None
assert self.rpm is not None
with self.rate_limit_lock:
now = time.time()
if self.call_no >= self.rpm:
if now - self.last_minute < SRSession.PERIOD:
time.sleep(SRSession.PERIOD - (now - self.last_minute))
self.last_minute = time.time()
self.call_no = 0
self.call_no += 1
return self.session.get(*args, **kwargs)
def update_headers(self, headers: dict):
self.session.headers.update(headers)
# No rate limit on post
def post(self, *args, **kwargs) -> requests.Response:
self.session.post(*args, **kwargs)
def gen_threadsafe_session( def gen_threadsafe_session(
headers: dict = None, pool_connections: int = 100, pool_maxsize: int = 100 headers: dict = None, pool_connections: int = 100, pool_maxsize: int = 100
) -> requests.Session: ) -> requests.Session:
@ -324,7 +376,7 @@ def gen_threadsafe_session(
headers = {} headers = {}
session = requests.Session() session = requests.Session()
adapter = requests.adapters.HTTPAdapter(pool_connections=100, pool_maxsize=100) adapter = requests.adapters.HTTPAdapter(pool_connections, pool_maxsize)
session.mount("https://", adapter) session.mount("https://", adapter)
session.headers.update(headers) session.headers.update(headers)
return session return session
@ -373,7 +425,7 @@ def get_cover_urls(resp: dict, source: str) -> Optional[dict]:
if source == "qobuz": if source == "qobuz":
cover_urls = resp["image"] cover_urls = resp["image"]
cover_urls["original"] = "org".join(cover_urls["large"].rsplit('600', 1)) cover_urls["original"] = "org".join(cover_urls["large"].rsplit("600", 1))
return cover_urls return cover_urls
if source == "tidal": if source == "tidal":