mirror of
https://github.com/nathom/streamrip.git
synced 2025-05-09 14:11:55 -04:00
Create rate limited requests session
This commit is contained in:
parent
fd353d57cc
commit
ef34756046
2 changed files with 66 additions and 14 deletions
|
@ -40,7 +40,7 @@ from .exceptions import (
|
|||
NonStreamable,
|
||||
)
|
||||
from .spoofbuz import Spoofer
|
||||
from .utils import gen_threadsafe_session, get_quality, safe_get
|
||||
from .utils import SRSession, gen_threadsafe_session, get_quality, safe_get
|
||||
|
||||
logger = logging.getLogger("streamrip")
|
||||
|
||||
|
@ -134,7 +134,7 @@ class QobuzClient(Client):
|
|||
str(kwargs["app_id"]),
|
||||
kwargs["secrets"],
|
||||
)
|
||||
self.session = gen_threadsafe_session(
|
||||
self.session = SRSession(
|
||||
headers={"User-Agent": AGENT, "X-App-Id": self.app_id}
|
||||
)
|
||||
self._validate_secrets()
|
||||
|
@ -223,7 +223,7 @@ class QobuzClient(Client):
|
|||
|
||||
if not hasattr(self, "sec"):
|
||||
if not hasattr(self, "session"):
|
||||
self.session = gen_threadsafe_session(
|
||||
self.session = SRSession(
|
||||
headers={"User-Agent": AGENT, "X-App-Id": self.app_id}
|
||||
)
|
||||
self._validate_secrets()
|
||||
|
@ -343,7 +343,9 @@ class QobuzClient(Client):
|
|||
|
||||
return self._gen_pages(epoint, params)
|
||||
|
||||
def _api_login(self, use_auth_token: bool, email_or_userid: str, password_or_token: str):
|
||||
def _api_login(
|
||||
self, use_auth_token: bool, email_or_userid: str, password_or_token: str
|
||||
):
|
||||
"""Log into the api to get the user authentication token.
|
||||
|
||||
:param use_auth_token:
|
||||
|
@ -380,7 +382,7 @@ class QobuzClient(Client):
|
|||
raise IneligibleError("Free accounts are not eligible to download tracks.")
|
||||
|
||||
self.uat = resp["user_auth_token"]
|
||||
self.session.headers.update({"X-User-Auth-Token": self.uat})
|
||||
self.session.update_headers({"X-User-Auth-Token": self.uat})
|
||||
self.label = resp["user"]["credential"]["parameters"]["short_label"]
|
||||
|
||||
def _api_get_file_url(
|
||||
|
@ -472,7 +474,6 @@ class DeezerClient(Client):
|
|||
def __init__(self):
|
||||
"""Create a DeezerClient."""
|
||||
self.client = deezer.Deezer()
|
||||
# self.session = gen_threadsafe_session()
|
||||
|
||||
# no login required
|
||||
self.logged_in = False
|
||||
|
@ -645,7 +646,7 @@ class DeezloaderClient(Client):
|
|||
|
||||
def __init__(self):
|
||||
"""Create a DeezloaderClient."""
|
||||
self.session = gen_threadsafe_session()
|
||||
self.session = SRSession()
|
||||
|
||||
# no login required
|
||||
self.logged_in = True
|
||||
|
@ -735,7 +736,7 @@ class TidalClient(Client):
|
|||
self.refresh_token = None
|
||||
self.expiry = None
|
||||
|
||||
self.session = gen_threadsafe_session()
|
||||
self.session = SRSession()
|
||||
|
||||
def login(
|
||||
self,
|
||||
|
@ -994,7 +995,7 @@ class TidalClient(Client):
|
|||
|
||||
def _update_authorization(self):
|
||||
"""Update the requests session headers with the auth token."""
|
||||
self.session.headers.update(self.authorization)
|
||||
self.session.update_headers(self.authorization)
|
||||
|
||||
@property
|
||||
def authorization(self):
|
||||
|
@ -1094,8 +1095,7 @@ class TidalClient(Client):
|
|||
:param data:
|
||||
:param auth:
|
||||
"""
|
||||
r = self.session.post(url, data=data, auth=auth, verify=False).json()
|
||||
return r
|
||||
return self.session.post(url, data=data, auth=auth, verify=False).json()
|
||||
|
||||
|
||||
class SoundCloudClient(Client):
|
||||
|
@ -1110,7 +1110,7 @@ class SoundCloudClient(Client):
|
|||
|
||||
def __init__(self):
|
||||
"""Create a SoundCloudClient."""
|
||||
self.session = gen_threadsafe_session(
|
||||
self.session = SRSession(
|
||||
headers={
|
||||
"User-Agent": AGENT,
|
||||
}
|
||||
|
|
|
@ -9,6 +9,8 @@ import os
|
|||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
from multiprocessing import Lock
|
||||
from string import Formatter
|
||||
from typing import Dict, Hashable, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
|
@ -307,6 +309,56 @@ def ext(quality: int, source: str):
|
|||
return ".flac"
|
||||
|
||||
|
||||
class SRSession:
|
||||
# requests per minute
|
||||
PERIOD = 60.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
headers: Optional[dict] = None,
|
||||
pool_connections: int = 100,
|
||||
pool_maxsize: int = 100,
|
||||
requests_per_min: Optional[int] = None,
|
||||
):
|
||||
|
||||
if headers is None:
|
||||
headers = {}
|
||||
|
||||
self.session = requests.Session()
|
||||
adapter = requests.adapters.HTTPAdapter(pool_connections, pool_maxsize)
|
||||
self.session.mount("https://", adapter)
|
||||
self.session.headers.update(headers)
|
||||
self.has_rate_limit = requests_per_min is not None
|
||||
self.rpm = requests_per_min
|
||||
|
||||
self.last_minute: float = time.time()
|
||||
self.call_no: int = 0
|
||||
self.rate_limit_lock = Lock() if self.has_rate_limit else None
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
if self.has_rate_limit: # only use locks if there is a rate limit
|
||||
assert self.rate_limit_lock is not None
|
||||
assert self.rpm is not None
|
||||
with self.rate_limit_lock:
|
||||
now = time.time()
|
||||
if self.call_no >= self.rpm:
|
||||
if now - self.last_minute < SRSession.PERIOD:
|
||||
time.sleep(SRSession.PERIOD - (now - self.last_minute))
|
||||
self.last_minute = time.time()
|
||||
self.call_no = 0
|
||||
|
||||
self.call_no += 1
|
||||
|
||||
return self.session.get(*args, **kwargs)
|
||||
|
||||
def update_headers(self, headers: dict):
|
||||
self.session.headers.update(headers)
|
||||
|
||||
# No rate limit on post
|
||||
def post(self, *args, **kwargs) -> requests.Response:
|
||||
self.session.post(*args, **kwargs)
|
||||
|
||||
|
||||
def gen_threadsafe_session(
|
||||
headers: dict = None, pool_connections: int = 100, pool_maxsize: int = 100
|
||||
) -> requests.Session:
|
||||
|
@ -324,7 +376,7 @@ def gen_threadsafe_session(
|
|||
headers = {}
|
||||
|
||||
session = requests.Session()
|
||||
adapter = requests.adapters.HTTPAdapter(pool_connections=100, pool_maxsize=100)
|
||||
adapter = requests.adapters.HTTPAdapter(pool_connections, pool_maxsize)
|
||||
session.mount("https://", adapter)
|
||||
session.headers.update(headers)
|
||||
return session
|
||||
|
@ -373,7 +425,7 @@ def get_cover_urls(resp: dict, source: str) -> Optional[dict]:
|
|||
|
||||
if source == "qobuz":
|
||||
cover_urls = resp["image"]
|
||||
cover_urls["original"] = "org".join(cover_urls["large"].rsplit('600', 1))
|
||||
cover_urls["original"] = "org".join(cover_urls["large"].rsplit("600", 1))
|
||||
return cover_urls
|
||||
|
||||
if source == "tidal":
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue