diff --git a/toot/ahttp.py b/toot/ahttp.py new file mode 100644 index 0000000..1c416b3 --- /dev/null +++ b/toot/ahttp.py @@ -0,0 +1,76 @@ +import asyncio +import logging + +from types import SimpleNamespace +from typing import Optional, Tuple +from aiohttp import ( + ClientResponse, + ClientSession, + TraceConfig, + TraceRequestEndParams, + TraceRequestStartParams, +) +from toot import __version__ +from toot.cli import Context + + +logger = logging.getLogger(__name__) + + +async def make_session(context: Context) -> ClientSession: + base_url = context.app.base_url if context.app else None + headers = {"User-Agent": f"toot/{__version__}"} + + if context.user: + headers["Authorization"] = f"Bearer {context.user.access_token}" + + trace_config = logger_trace_config() + + return ClientSession( + headers=headers, + base_url=base_url, + trace_configs=[trace_config], + ) + + +async def get_error(response: ClientResponse) -> Tuple[Optional[str], Optional[str]]: + """Attempt to extract the error and error description from response body. + + See: https://docs.joinmastodon.org/entities/error/ + """ + try: + data = await response.json() + return data.get("error"), data.get("error_description") + except Exception: + pass + + return None, None + + +def logger_trace_config() -> TraceConfig: + async def on_request_start( + session: ClientSession, + context: SimpleNamespace, + params: TraceRequestStartParams, + ): + context.start = asyncio.get_event_loop().time() + logger.debug(f"--> {params.method} {params.url}") + + async def on_request_end( + session: ClientSession, + context: SimpleNamespace, + params: TraceRequestEndParams, + ): + elapsed = round(1000 * (asyncio.get_event_loop().time() - context.start)) + logger.debug( + f"<-- {params.method} {params.url} HTTP {params.response.status} {elapsed}ms" + ) + + trace_config = TraceConfig() + trace_config.on_request_start.append(on_request_start) + trace_config.on_request_end.append(on_request_end) + return trace_config + + +async def verify_credentials(session: ClientSession) -> ClientResponse: + return await session.get("/api/v1/accounts/verify_credentials") diff --git a/toot/cli/__init__.py b/toot/cli/__init__.py index a4698ff..f2826cc 100644 --- a/toot/cli/__init__.py +++ b/toot/cli/__init__.py @@ -1,3 +1,5 @@ +import asyncio +import aiohttp import click import logging import os @@ -126,6 +128,22 @@ def pass_context(f: "t.Callable[te.Concatenate[Context, P], R]") -> "t.Callable[ return wrapped +def pass_session(f: "t.Callable[te.Concatenate[aiohttp.ClientSession, P], t.Awaitable[R]]") -> "t.Callable[P, t.Awaitable[R]]": + """Pass the toot Context as first argument.""" + from toot.ahttp import make_session + + @wraps(f) + async def wrapped(*args: "P.args", **kwargs: "P.kwargs") -> R: + context = get_context() + session = await make_session(context) + try: + return await f(session, *args, **kwargs) + finally: + await session.close() + + return wrapped + + def get_context() -> Context: click_context = click.get_current_context() obj: TootObj = click_context.obj @@ -146,6 +164,15 @@ def get_context() -> Context: return Context(app, user, obj.color, obj.debug) +def async_command(f: "t.Callable[P, t.Awaitable[R]]") -> "t.Callable[P, R]": + # Integrating click with asyncio: + # https://github.com/pallets/click/issues/85#issuecomment-503464628 + @wraps(f) + def wrapper(*args: "P.args", **kwargs: "P.kwargs") -> R: + return asyncio.run(f(*args, **kwargs)) + + return wrapper + json_option = click.option( "--json", is_flag=True, diff --git a/toot/cli/read.py b/toot/cli/read.py index 32ce49a..9d2f00e 100644 --- a/toot/cli/read.py +++ b/toot/cli/read.py @@ -1,28 +1,30 @@ +from aiohttp import ClientSession import click import json as pyjson from itertools import chain from typing import Optional -from toot import api +from toot import api, ahttp from toot.cli.validators import validate_instance -from toot.entities import Instance, Status, from_dict, Account +from toot.entities import Instance, Status, from_dict, Account, from_response from toot.exceptions import ApiError, ConsoleError from toot.output import print_account, print_instance, print_search_results, print_status, print_timeline -from toot.cli import InstanceParamType, cli, get_context, json_option, pass_context, Context +from toot.cli import InstanceParamType, async_command, cli, get_context, json_option, pass_context, Context, pass_session @cli.command() @json_option -@pass_context -def whoami(ctx: Context, json: bool): +@async_command +@pass_session +async def whoami(session: ClientSession, json: bool): """Display logged in user details""" - response = api.verify_credentials(ctx.app, ctx.user) + response = await ahttp.verify_credentials(session) if json: - click.echo(response.text) + click.echo(await response.text()) else: - account = from_dict(Account, response.json()) + account = await from_response(Account, response) print_account(account) diff --git a/toot/entities.py b/toot/entities.py index 3563e12..e8ac49f 100644 --- a/toot/entities.py +++ b/toot/entities.py @@ -17,6 +17,8 @@ from functools import lru_cache from typing import Any, Dict, NamedTuple, Optional, Type, TypeVar, Union from typing import get_args, get_origin, get_type_hints +from aiohttp import ClientResponse + from toot.utils import get_text from toot.utils.datetime import parse_datetime @@ -497,6 +499,10 @@ def from_dict(cls: Type[T], data: Data) -> T: return cls(**dict(_fields())) +async def from_response(cls: Type[T], response: ClientResponse) -> T: + return from_dict(cls, await response.json()) + + @lru_cache def _get_fields(cls: type) -> t.List[Field]: hints = get_type_hints(cls)