mirror of
https://github.com/ihabunek/toot.git
synced 2024-09-22 04:25:55 -04:00
aiohttp poc
This commit is contained in:
parent
b97a995dc4
commit
4f508bd26a
76
toot/ahttp.py
Normal file
76
toot/ahttp.py
Normal file
@ -0,0 +1,76 @@
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Optional, Tuple
|
||||
from aiohttp import (
|
||||
ClientResponse,
|
||||
ClientSession,
|
||||
TraceConfig,
|
||||
TraceRequestEndParams,
|
||||
TraceRequestStartParams,
|
||||
)
|
||||
from toot import __version__
|
||||
from toot.cli import Context
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def make_session(context: Context) -> ClientSession:
|
||||
base_url = context.app.base_url if context.app else None
|
||||
headers = {"User-Agent": f"toot/{__version__}"}
|
||||
|
||||
if context.user:
|
||||
headers["Authorization"] = f"Bearer {context.user.access_token}"
|
||||
|
||||
trace_config = logger_trace_config()
|
||||
|
||||
return ClientSession(
|
||||
headers=headers,
|
||||
base_url=base_url,
|
||||
trace_configs=[trace_config],
|
||||
)
|
||||
|
||||
|
||||
async def get_error(response: ClientResponse) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""Attempt to extract the error and error description from response body.
|
||||
|
||||
See: https://docs.joinmastodon.org/entities/error/
|
||||
"""
|
||||
try:
|
||||
data = await response.json()
|
||||
return data.get("error"), data.get("error_description")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
def logger_trace_config() -> TraceConfig:
|
||||
async def on_request_start(
|
||||
session: ClientSession,
|
||||
context: SimpleNamespace,
|
||||
params: TraceRequestStartParams,
|
||||
):
|
||||
context.start = asyncio.get_event_loop().time()
|
||||
logger.debug(f"--> {params.method} {params.url}")
|
||||
|
||||
async def on_request_end(
|
||||
session: ClientSession,
|
||||
context: SimpleNamespace,
|
||||
params: TraceRequestEndParams,
|
||||
):
|
||||
elapsed = round(1000 * (asyncio.get_event_loop().time() - context.start))
|
||||
logger.debug(
|
||||
f"<-- {params.method} {params.url} HTTP {params.response.status} {elapsed}ms"
|
||||
)
|
||||
|
||||
trace_config = TraceConfig()
|
||||
trace_config.on_request_start.append(on_request_start)
|
||||
trace_config.on_request_end.append(on_request_end)
|
||||
return trace_config
|
||||
|
||||
|
||||
async def verify_credentials(session: ClientSession) -> ClientResponse:
|
||||
return await session.get("/api/v1/accounts/verify_credentials")
|
@ -1,3 +1,5 @@
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import click
|
||||
import logging
|
||||
import os
|
||||
@ -126,6 +128,22 @@ def pass_context(f: "t.Callable[te.Concatenate[Context, P], R]") -> "t.Callable[
|
||||
return wrapped
|
||||
|
||||
|
||||
def pass_session(f: "t.Callable[te.Concatenate[aiohttp.ClientSession, P], t.Awaitable[R]]") -> "t.Callable[P, t.Awaitable[R]]":
|
||||
"""Pass the toot Context as first argument."""
|
||||
from toot.ahttp import make_session
|
||||
|
||||
@wraps(f)
|
||||
async def wrapped(*args: "P.args", **kwargs: "P.kwargs") -> R:
|
||||
context = get_context()
|
||||
session = await make_session(context)
|
||||
try:
|
||||
return await f(session, *args, **kwargs)
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def get_context() -> Context:
|
||||
click_context = click.get_current_context()
|
||||
obj: TootObj = click_context.obj
|
||||
@ -146,6 +164,15 @@ def get_context() -> Context:
|
||||
return Context(app, user, obj.color, obj.debug)
|
||||
|
||||
|
||||
def async_command(f: "t.Callable[P, t.Awaitable[R]]") -> "t.Callable[P, R]":
|
||||
# Integrating click with asyncio:
|
||||
# https://github.com/pallets/click/issues/85#issuecomment-503464628
|
||||
@wraps(f)
|
||||
def wrapper(*args: "P.args", **kwargs: "P.kwargs") -> R:
|
||||
return asyncio.run(f(*args, **kwargs))
|
||||
|
||||
return wrapper
|
||||
|
||||
json_option = click.option(
|
||||
"--json",
|
||||
is_flag=True,
|
||||
|
@ -1,28 +1,30 @@
|
||||
from aiohttp import ClientSession
|
||||
import click
|
||||
import json as pyjson
|
||||
|
||||
from itertools import chain
|
||||
from typing import Optional
|
||||
|
||||
from toot import api
|
||||
from toot import api, ahttp
|
||||
from toot.cli.validators import validate_instance
|
||||
from toot.entities import Instance, Status, from_dict, Account
|
||||
from toot.entities import Instance, Status, from_dict, Account, from_response
|
||||
from toot.exceptions import ApiError, ConsoleError
|
||||
from toot.output import print_account, print_instance, print_search_results, print_status, print_timeline
|
||||
from toot.cli import InstanceParamType, cli, get_context, json_option, pass_context, Context
|
||||
from toot.cli import InstanceParamType, async_command, cli, get_context, json_option, pass_context, Context, pass_session
|
||||
|
||||
|
||||
@cli.command()
|
||||
@json_option
|
||||
@pass_context
|
||||
def whoami(ctx: Context, json: bool):
|
||||
@async_command
|
||||
@pass_session
|
||||
async def whoami(session: ClientSession, json: bool):
|
||||
"""Display logged in user details"""
|
||||
response = api.verify_credentials(ctx.app, ctx.user)
|
||||
response = await ahttp.verify_credentials(session)
|
||||
|
||||
if json:
|
||||
click.echo(response.text)
|
||||
click.echo(await response.text())
|
||||
else:
|
||||
account = from_dict(Account, response.json())
|
||||
account = await from_response(Account, response)
|
||||
print_account(account)
|
||||
|
||||
|
||||
|
@ -17,6 +17,8 @@ from functools import lru_cache
|
||||
from typing import Any, Dict, NamedTuple, Optional, Type, TypeVar, Union
|
||||
from typing import get_args, get_origin, get_type_hints
|
||||
|
||||
from aiohttp import ClientResponse
|
||||
|
||||
from toot.utils import get_text
|
||||
from toot.utils.datetime import parse_datetime
|
||||
|
||||
@ -497,6 +499,10 @@ def from_dict(cls: Type[T], data: Data) -> T:
|
||||
return cls(**dict(_fields()))
|
||||
|
||||
|
||||
async def from_response(cls: Type[T], response: ClientResponse) -> T:
|
||||
return from_dict(cls, await response.json())
|
||||
|
||||
|
||||
@lru_cache
|
||||
def _get_fields(cls: type) -> t.List[Field]:
|
||||
hints = get_type_hints(cls)
|
||||
|
Loading…
Reference in New Issue
Block a user