Add mypy checks

This commit is contained in:
Andre Basche 2023-04-15 15:55:22 +02:00
parent b6ca12ebff
commit f54b7b2dbf
6 changed files with 73 additions and 63 deletions

View file

@ -3,10 +3,11 @@ import logging
import re
import secrets
import urllib
from contextlib import suppress
from dataclasses import dataclass
from datetime import datetime, timedelta
from pprint import pformat
from typing import Dict, Optional
from typing import Dict, Optional, List
from urllib import parse
from urllib.parse import quote
@ -82,14 +83,15 @@ class HonAuth:
if fail:
raise exceptions.HonAuthenticationError("Can't login")
def _generate_nonce(self) -> str:
@staticmethod
def _generate_nonce() -> str:
nonce = secrets.token_hex(16)
return f"{nonce[:8]}-{nonce[8:12]}-{nonce[12:16]}-{nonce[16:20]}-{nonce[20:]}"
async def _load_login(self):
async def _load_login(self) -> bool:
login_url = await self._introduce()
login_url = await self._handle_redirects(login_url)
await self._login_url(login_url)
return await self._login_url(login_url)
async def _introduce(self) -> str:
redirect_uri = urllib.parse.quote(f"{const.APP}://mobilesdk/detect/oauth/done")
@ -101,8 +103,8 @@ class HonAuth:
"scope": "api openid refresh_token web",
"nonce": self._generate_nonce(),
}
params = "&".join([f"{k}={v}" for k, v in params.items()])
url = f"{const.AUTH_API}/services/oauth2/authorize/expid_Login?{params}"
params_encode = "&".join([f"{k}={v}" for k, v in params.items()])
url = f"{const.AUTH_API}/services/oauth2/authorize/expid_Login?{params_encode}"
async with self._request.get(url) as response:
text = await response.text()
self._expires = datetime.utcnow()
@ -115,7 +117,7 @@ class HonAuth:
async def _manual_redirect(self, url: str) -> str:
async with self._request.get(url, allow_redirects=False) as response:
if not (new_location := response.headers.get("Location")):
if not (new_location := response.headers.get("Location", "")):
await self._error_logger(response)
return new_location
@ -138,11 +140,11 @@ class HonAuth:
)
return True
await self._error_logger(response)
return False
async def _login(self):
start_url = parse.unquote(self._login_data.url.split("startURL=")[-1]).split(
"%3D"
)[0]
async def _login(self) -> str:
start_url = self._login_data.url.rsplit("startURL=", maxsplit=1)[-1]
start_url = parse.unquote(start_url).split("%3D")[0]
action = {
"id": "79;a",
"descriptor": "apex://LightningLoginCustomController/ACTION$login",
@ -175,19 +177,13 @@ class HonAuth:
params=params,
) as response:
if response.status == 200:
try:
data = await response.json()
return data["events"][0]["attributes"]["values"]["url"]
except json.JSONDecodeError:
pass
except KeyError:
_LOGGER.error(
"Can't get login url - %s", pformat(await response.json())
)
with suppress(json.JSONDecodeError, KeyError):
result = await response.json()
return result["events"][0]["attributes"]["values"]["url"]
await self._error_logger(response)
return ""
def _parse_token_data(self, text):
def _parse_token_data(self, text: str) -> None:
if access_token := re.findall("access_token=(.*?)&", text):
self._access_token = access_token[0]
if refresh_token := re.findall("refresh_token=(.*?)&", text):
@ -195,22 +191,26 @@ class HonAuth:
if id_token := re.findall("id_token=(.*?)&", text):
self._id_token = id_token[0]
async def _get_token(self, url):
async def _get_token(self, url: str) -> bool:
async with self._request.get(url) as response:
if response.status != 200:
await self._error_logger(response)
return False
url = re.findall("href\\s*=\\s*[\"'](.+?)[\"']", await response.text())
if not url:
url_search = re.findall(
"href\\s*=\\s*[\"'](.+?)[\"']", await response.text()
)
if not url_search:
await self._error_logger(response)
return False
if "ProgressiveLogin" in url[0]:
async with self._request.get(url[0]) as response:
if "ProgressiveLogin" in url_search[0]:
async with self._request.get(url_search[0]) as response:
if response.status != 200:
await self._error_logger(response)
return False
url = re.findall("href\\s*=\\s*[\"'](.*?)[\"']", await response.text())
url = "/".join(const.AUTH_API.split("/")[:-1]) + url[0]
url_search = re.findall(
"href\\s*=\\s*[\"'](.*?)[\"']", await response.text()
)
url = "/".join(const.AUTH_API.split("/")[:-1]) + url_search[0]
async with self._request.get(url) as response:
if response.status != 200:
await self._error_logger(response)
@ -218,7 +218,7 @@ class HonAuth:
self._parse_token_data(await response.text())
return True
async def _api_auth(self):
async def _api_auth(self) -> bool:
post_headers = {"id-token": self._id_token}
data = self._device.get()
async with self._request.post(
@ -232,7 +232,7 @@ class HonAuth:
self._cognito_token = json_data["cognitoUser"]["Token"]
return True
async def authenticate(self):
async def authenticate(self) -> None:
self.clear()
try:
if not await self._load_login():
@ -246,7 +246,7 @@ class HonAuth:
except exceptions.HonNoAuthenticationNeeded:
return
async def refresh(self):
async def refresh(self) -> bool:
params = {
"client_id": const.CLIENT_ID,
"refresh_token": self._refresh_token,
@ -264,7 +264,7 @@ class HonAuth:
self._access_token = data["access_token"]
return await self._api_auth()
def clear(self):
def clear(self) -> None:
self._session.cookie_jar.clear_domain(const.AUTH_API.split("/")[-2])
self._request.called_urls = []
self._cognito_token = ""

View file

@ -1,7 +1,7 @@
import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Optional, Callable, Dict
from typing import Optional, Callable, Dict, Any
import aiohttp
from typing_extensions import Self
@ -37,18 +37,18 @@ class ConnectionHandler:
raise NotImplementedError
@asynccontextmanager
async def get(self, *args, **kwargs) -> AsyncIterator[Callable]:
async def get(self, *args, **kwargs) -> AsyncIterator[aiohttp.ClientResponse]:
if self._session is None:
raise exceptions.NoSessionException()
response: Callable
response: aiohttp.ClientResponse
async with self._intercept(self._session.get, *args, **kwargs) as response:
yield response
@asynccontextmanager
async def post(self, *args, **kwargs) -> AsyncIterator[Callable]:
async def post(self, *args, **kwargs) -> AsyncIterator[aiohttp.ClientResponse]:
if self._session is None:
raise exceptions.NoSessionException()
response: Callable
response: aiohttp.ClientResponse
async with self._intercept(self._session.post, *args, **kwargs) as response:
yield response

View file

@ -10,7 +10,7 @@ from typing_extensions import Self
from pyhon.connection.auth import HonAuth
from pyhon.connection.device import HonDevice
from pyhon.connection.handler.base import ConnectionHandler
from pyhon.exceptions import HonAuthenticationError
from pyhon.exceptions import HonAuthenticationError, NoAuthenticationException
_LOGGER = logging.getLogger(__name__)
@ -30,7 +30,9 @@ class HonConnectionHandler(ConnectionHandler):
self._auth: Optional[HonAuth] = None
@property
def auth(self) -> Optional[HonAuth]:
def auth(self) -> HonAuth:
if self._auth is None:
raise NoAuthenticationException()
return self._auth
@property
@ -39,16 +41,14 @@ class HonConnectionHandler(ConnectionHandler):
async def create(self) -> Self:
await super().create()
self._auth: HonAuth = HonAuth(
self._session, self._email, self._password, self._device
)
self._auth = HonAuth(self._session, self._email, self._password, self._device)
return self
async def _check_headers(self, headers: Dict) -> Dict:
if not (self._auth.cognito_token and self._auth.id_token):
await self._auth.authenticate()
headers["cognito-token"] = self._auth.cognito_token
headers["id-token"] = self._auth.id_token
if not (self.auth.cognito_token and self.auth.id_token):
await self.auth.authenticate()
headers["cognito-token"] = self.auth.cognito_token
headers["id-token"] = self.auth.id_token
return self._HEADERS | headers
@asynccontextmanager
@ -58,16 +58,16 @@ class HonConnectionHandler(ConnectionHandler):
kwargs["headers"] = await self._check_headers(kwargs.get("headers", {}))
async with method(*args, **kwargs) as response:
if (
self._auth.token_expires_soon or response.status in [401, 403]
self.auth.token_expires_soon or response.status in [401, 403]
) and loop == 0:
_LOGGER.info("Try refreshing token...")
await self._auth.refresh()
await self.auth.refresh()
async with self._intercept(
method, *args, loop=loop + 1, **kwargs
) as result:
yield result
elif (
self._auth.token_is_expired or response.status in [401, 403]
self.auth.token_is_expired or response.status in [401, 403]
) and loop == 1:
_LOGGER.warning(
"%s - Error %s - %s",