Source code for faust.livecheck.patches.aiohttp

"""LiveCheck :pypi:`aiohttp` integration."""
from contextlib import ExitStack
from types import SimpleNamespace
from typing import Any, List, Optional

import aiohttp
from aiohttp import web

from faust.livecheck.locals import current_test_stack
from faust.livecheck.models import TestExecution

__all__ = ['patch_all', 'patch_aiohttp_session', 'LiveCheckMiddleware']


[docs]def patch_all() -> None: """Patch all :pypi:`aiohttp` functions to integrate with LiveCheck.""" patch_aiohttp_session()
[docs]def patch_aiohttp_session() -> None: """Patch :class:`aiohttp.ClientSession` to integrate with LiveCheck. If there is any currently active test, we will use that to forward LiveCheck HTTP headers to the new HTTP request. """ from aiohttp import TraceConfig from aiohttp import client # monkeypatch to remove ridiculous "do not subclass" warning. def __init_subclass__() -> None: ... client.ClientSession.__init_subclass__ = __init_subclass__ async def _on_request_start( session: aiohttp.ClientSession, trace_config_ctx: SimpleNamespace, params: aiohttp.TraceRequestStartParams) -> None: test = current_test_stack.top if test is not None: params.headers.update(test.as_headers()) class ClientSession(client.ClientSession): def __init__(self, trace_configs: Optional[List[TraceConfig]] = None, **kwargs: Any) -> None: if trace_configs is None: trace_configs = [] trace_config = aiohttp.TraceConfig() trace_config.on_request_start.append(_on_request_start) trace_configs.append(trace_config) super().__init__(trace_configs=trace_configs, **kwargs) client.ClientSession = ClientSession
[docs]@web.middleware class LiveCheckMiddleware: """LiveCheck support for :pypi:`aiohttp` web servers. This middleware is applied to all incoming web requests, and is used to extract LiveCheck HTTP headers. If the web request is configured with the correct set of LiveCheck headers, we will use that to set the "current test" context. """ async def __call__(self, request: web.Request, handler: Any) -> Any: """Call to handle new web request.""" related_test = TestExecution.from_headers(request.headers) with ExitStack() as stack: if related_test: stack.enter_context(current_test_stack.push(related_test)) return await handler(request)