Astroport.ONE/venv/lib/python3.11/site-packages/graphql/execution/middleware.py
2024-03-01 16:15:45 +01:00

64 lines
2.4 KiB
Python

from functools import partial, reduce
from inspect import isfunction
from typing import Callable, Iterator, Dict, List, Tuple, Any, Optional
# Names exported by ``from ... import *``.
__all__ = [
    "MiddlewareManager",
]

# Alias for a field resolver: any callable. Per the MiddlewareManager notes,
# its result may be a plain value or an awaitable.
GraphQLFieldResolver = Callable[..., Any]
class MiddlewareManager:
    """Manager for the middleware chain.

    Wraps field resolver functions with the configured middleware. Middleware can
    be given as functions or as objects. For an object, its ``resolve`` method is
    used as the middleware function. Each middleware function receives the next
    function of the chain as its first argument.

    Resolvers return "AwaitableOrValue"s. Every middleware function must therefore
    check whether a value is awaitable before awaiting it.
    """

    # "__dict__" is kept in the slots so custom attributes can still be attached
    # to instances (they are not used internally).
    __slots__ = (
        "__dict__",
        "middlewares",
        "_middleware_resolvers",
        "_cached_resolvers",
    )

    # Maps an original field resolver to its middleware-wrapped version.
    _cached_resolvers: Dict[GraphQLFieldResolver, GraphQLFieldResolver]
    # Middleware functions in the given order, or None when no middleware was passed.
    _middleware_resolvers: Optional[List[Callable]]

    def __init__(self, *middlewares: Any):
        self.middlewares = middlewares
        if middlewares:
            self._middleware_resolvers = list(get_middleware_resolvers(middlewares))
        else:
            # Nothing to wrap: get_field_resolver() hands resolvers back unchanged.
            self._middleware_resolvers = None
        self._cached_resolvers = {}

    def get_field_resolver(
        self, field_resolver: GraphQLFieldResolver
    ) -> GraphQLFieldResolver:
        """Wrap the provided resolver with the middleware.

        Returns a callable that runs the middleware chain around
        ``field_resolver``. The wrapped callable is built once per resolver and
        then served from a cache. Without middleware, ``field_resolver`` is
        returned as is.
        """
        middleware_fns = self._middleware_resolvers
        if middleware_fns is None:
            return field_resolver

        cache = self._cached_resolvers
        try:
            return cache[field_resolver]
        except KeyError:
            pass

        chained = field_resolver
        for middleware_fn in middleware_fns:
            # Each middleware wraps everything built so far, so the last
            # middleware in the list ends up as the outermost call.
            chained = partial(middleware_fn, chained)
        cache[field_resolver] = chained
        return chained
def get_middleware_resolvers(middlewares: Tuple[Any, ...]) -> Iterator[Callable]:
    """Yield middleware functions for the given middleware functions or objects.

    Each entry of ``middlewares`` is turned into a middleware function as follows:

    * a plain Python function is used as is;
    * otherwise, if the entry has a ``resolve`` attribute that is not ``None``,
      that attribute is used (middleware provided as an object with a
      ``resolve`` method);
    * otherwise, if the entry is itself callable and is not a class (a bound
      method, a ``functools.partial``, a builtin or an instance defining
      ``__call__``), the entry is used as is.

    Entries that match none of these rules are skipped.
    """
    for middleware in middlewares:
        if isfunction(middleware):
            yield middleware
            continue
        # middleware provided as object with 'resolve' method
        resolver_func = getattr(middleware, "resolve", None)
        if resolver_func is not None:
            yield resolver_func
        elif callable(middleware) and not isinstance(middleware, type):
            # inspect.isfunction() is False for bound methods, partials,
            # builtins and callable instances. Without this branch such
            # middleware would be dropped silently.
            yield middleware