Add more type hints

This commit is contained in:
Andre Basche 2023-06-28 19:02:11 +02:00
parent ad0d065b03
commit 9eb99f283b
30 changed files with 392 additions and 243 deletions

View file

@ -1,19 +1,24 @@
import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Callable, Dict
from typing import Dict, Any
import aiohttp
from pyhon import const
from pyhon.connection.handler.base import ConnectionHandler
from pyhon.typedefs import Callback
_LOGGER = logging.getLogger(__name__)
class HonAnonymousConnectionHandler(ConnectionHandler):
_HEADERS: Dict = ConnectionHandler._HEADERS | {"x-api-key": const.API_KEY}
_HEADERS: Dict[str, str] = ConnectionHandler._HEADERS | {"x-api-key": const.API_KEY}
@asynccontextmanager
async def _intercept(self, method: Callable, *args, **kwargs) -> AsyncIterator:
async def _intercept(
self, method: Callback, *args: Any, **kwargs: Any
) -> AsyncIterator[aiohttp.ClientResponse]:
kwargs["headers"] = kwargs.pop("headers", {}) | self._HEADERS
async with method(*args, **kwargs) as response:
if response.status == 403:

View file

@ -1,12 +1,13 @@
import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Optional, Callable, List, Tuple
from typing import Optional, List, Tuple, Any
import aiohttp
from pyhon import const
from pyhon.connection.handler.base import ConnectionHandler
from pyhon.typedefs import Callback
_LOGGER = logging.getLogger(__name__)
@ -28,9 +29,9 @@ class HonAuthConnectionHandler(ConnectionHandler):
@asynccontextmanager
async def _intercept(
self, method: Callable, *args, loop: int = 0, **kwargs
) -> AsyncIterator:
self, method: Callback, *args: Any, **kwargs: Any
) -> AsyncIterator[aiohttp.ClientResponse]:
kwargs["headers"] = kwargs.pop("headers", {}) | self._HEADERS
async with method(*args, **kwargs) as response:
self._called_urls.append((response.status, response.request_info.url))
self._called_urls.append((response.status, str(response.request_info.url)))
yield response

View file

@ -1,18 +1,20 @@
import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Optional, Callable, Dict
from types import TracebackType
from typing import Optional, Dict, Type, Any, Protocol
import aiohttp
from typing_extensions import Self
from pyhon import const, exceptions
from pyhon.typedefs import Callback
_LOGGER = logging.getLogger(__name__)
class ConnectionHandler:
_HEADERS: Dict = {
_HEADERS: Dict[str, str] = {
"user-agent": const.USER_AGENT,
"Content-Type": "application/json",
}
@ -24,32 +26,49 @@ class ConnectionHandler:
async def __aenter__(self) -> Self:
return await self.create()
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
await self.close()
@property
def session(self) -> aiohttp.ClientSession:
if self._session is None:
raise exceptions.NoSessionException
return self._session
async def create(self) -> Self:
if self._create_session:
self._session = aiohttp.ClientSession()
return self
@asynccontextmanager
def _intercept(self, method: Callable, *args, loop: int = 0, **kwargs):
def _intercept(
self, method: Callback, *args: Any, loop: int = 0, **kwargs: Any
) -> AsyncIterator[aiohttp.ClientResponse]:
raise NotImplementedError
@asynccontextmanager
async def get(self, *args, **kwargs) -> AsyncIterator[aiohttp.ClientResponse]:
async def get(
self, *args: Any, **kwargs: Any
) -> AsyncIterator[aiohttp.ClientResponse]:
if self._session is None:
raise exceptions.NoSessionException()
response: aiohttp.ClientResponse
async with self._intercept(self._session.get, *args, **kwargs) as response:
async with self._intercept(self._session.get, *args, **kwargs) as response: # type: ignore[arg-type]
yield response
@asynccontextmanager
async def post(self, *args, **kwargs) -> AsyncIterator[aiohttp.ClientResponse]:
async def post(
self, *args: Any, **kwargs: Any
) -> AsyncIterator[aiohttp.ClientResponse]:
if self._session is None:
raise exceptions.NoSessionException()
response: aiohttp.ClientResponse
async with self._intercept(self._session.post, *args, **kwargs) as response:
async with self._intercept(self._session.post, *args, **kwargs) as response: # type: ignore[arg-type]
yield response
async def close(self) -> None:

View file

@ -2,7 +2,7 @@ import json
import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Optional, Callable, Dict
from typing import Optional, Dict, Any
import aiohttp
from typing_extensions import Self
@ -11,6 +11,7 @@ from pyhon.connection.auth import HonAuth
from pyhon.connection.device import HonDevice
from pyhon.connection.handler.base import ConnectionHandler
from pyhon.exceptions import HonAuthenticationError, NoAuthenticationException
from pyhon.typedefs import Callback
_LOGGER = logging.getLogger(__name__)
@ -41,10 +42,10 @@ class HonConnectionHandler(ConnectionHandler):
async def create(self) -> Self:
await super().create()
self._auth = HonAuth(self._session, self._email, self._password, self._device)
self._auth = HonAuth(self.session, self._email, self._password, self._device)
return self
async def _check_headers(self, headers: Dict) -> Dict:
async def _check_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
if not (self.auth.cognito_token and self.auth.id_token):
await self.auth.authenticate()
headers["cognito-token"] = self.auth.cognito_token
@ -53,18 +54,16 @@ class HonConnectionHandler(ConnectionHandler):
@asynccontextmanager
async def _intercept(
self, method: Callable, *args, loop: int = 0, **kwargs
) -> AsyncIterator:
self, method: Callback, *args: Any, loop: int = 0, **kwargs: Dict[str, str]
) -> AsyncIterator[aiohttp.ClientResponse]:
kwargs["headers"] = await self._check_headers(kwargs.get("headers", {}))
async with method(*args, **kwargs) as response:
async with method(args[0], *args[1:], **kwargs) as response:
if (
self.auth.token_expires_soon or response.status in [401, 403]
) and loop == 0:
_LOGGER.info("Try refreshing token...")
await self.auth.refresh()
async with self._intercept(
method, *args, loop=loop + 1, **kwargs
) as result:
async with self._intercept(method, loop=loop + 1, **kwargs) as result:
yield result
elif (
self.auth.token_is_expired or response.status in [401, 403]
@ -76,9 +75,7 @@ class HonConnectionHandler(ConnectionHandler):
await response.text(),
)
await self.create()
async with self._intercept(
method, *args, loop=loop + 1, **kwargs
) as result:
async with self._intercept(method, loop=loop + 1, **kwargs) as result:
yield result
elif loop >= 2:
_LOGGER.error(