Astroport.ONE/venv/lib/python3.11/site-packages/graphql/execution/map_async_iterator.py
2024-03-01 16:15:45 +01:00

116 lines
3.6 KiB
Python

from asyncio import CancelledError, Event, Task, ensure_future, wait
from concurrent.futures import FIRST_COMPLETED
from inspect import isasyncgen, isawaitable
from typing import cast, Any, AsyncIterable, Callable, Optional, Set, Type, Union
from types import TracebackType
# Public API of this module: the only name exported by ``from ... import *``.
__all__ = ["MapAsyncIterator"]
# noinspection PyAttributeOutsideInit
class MapAsyncIterator:
    """Map an AsyncIterable over a callback function.

    Given an AsyncIterable and a callback function, return an AsyncIterator which
    produces values mapped via calling the callback function. If the callback
    returns an awaitable, it is awaited and its result is produced instead.

    When the resulting AsyncIterator is closed, the underlying AsyncIterable will also
    be closed.
    """

    def __init__(self, iterable: AsyncIterable, callback: Callable) -> None:
        self.iterator = iterable.__aiter__()
        self.callback = callback
        # Set once the iterator is closed. A pending ``__anext__`` waits on this
        # event concurrently with the underlying iterator, so closing interrupts it.
        self._close_event = Event()

    def __aiter__(self) -> "MapAsyncIterator":
        """Get the iterator object."""
        return self

    async def __anext__(self) -> Any:
        """Get the next value of the iterator, mapped through the callback.

        While the iterator is open, the next value of the underlying iterator is
        awaited concurrently with the close event. A concurrent call to
        :meth:`aclose` therefore ends a pending call with ``StopAsyncIteration``.
        Once closed, an underlying async generator is still asked for its next
        value. Any other underlying iterator is treated as exhausted.

        Exceptions raised by the underlying iterator or by the callback propagate.

        Raises:
            StopAsyncIteration: if the iterator is exhausted or has been closed.
            CancelledError: re-raised after closing this iterator if the call is
                cancelled while waiting for the next value.
        """
        if self.is_closed:
            if not isasyncgen(self.iterator):
                raise StopAsyncIteration
            # NOTE(review): presumably because closing a running async generator
            # can fail (see aclose), so it may still have values to deliver.
            value = await self.iterator.__anext__()
        else:
            close_task = ensure_future(self._close_event.wait())
            next_task = ensure_future(self.iterator.__anext__())
            try:
                pending: Set[Task] = (
                    await wait([close_task, next_task], return_when=FIRST_COMPLETED)
                )[1]
            except CancelledError:
                # cancel underlying tasks and close
                close_task.cancel()
                next_task.cancel()
                await self.aclose()
                raise  # re-raise the cancellation
            for task in pending:
                task.cancel()
            # Closing takes precedence, even if a value arrived at the same time.
            if close_task.done():
                raise StopAsyncIteration
            error = next_task.exception()
            if error:
                raise error
            value = next_task.result()
        result = self.callback(value)
        return await result if isawaitable(result) else result

    async def athrow(
        self,
        type_: Union[BaseException, Type[BaseException]],
        value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        """Throw an exception into the asynchronous iterator.

        If the underlying iterator has an ``athrow`` method, the arguments are
        forwarded to it and its return value is discarded. Otherwise this iterator
        is closed and the exception is raised from here. If the iterator is
        already closed, nothing happens.

        Args:
            type_: An exception class or an exception instance.
            value: An optional exception instance. In the fallback path it is
                raised in place of ``type_``.
            traceback: An optional traceback to attach to the raised exception.
        """
        if self.is_closed:
            return
        athrow = getattr(self.iterator, "athrow", None)
        if athrow:
            await athrow(type_, value, traceback)
            return
        await self.aclose()
        if value is None:
            if traceback is None:
                raise type_
            # ``type_`` may already be an instance; only instantiate classes.
            # (Checking ``value`` here was a bug: it is always None in this branch,
            # so exception instances were wrongly "called".)
            value = (
                type_
                if isinstance(type_, BaseException)
                else cast(Type[BaseException], type_)()
            )
        if traceback is not None:
            value = value.with_traceback(traceback)
        raise value

    async def aclose(self) -> None:
        """Close the iterator, and the underlying iterator if it supports it.

        Calling this on an already closed iterator does nothing.
        """
        if self.is_closed:
            return
        aclose = getattr(self.iterator, "aclose", None)
        if aclose:
            try:
                await aclose()
            except RuntimeError:
                # Closing the underlying iterator is best-effort. An async
                # generator that is currently running refuses aclose() with a
                # RuntimeError; this iterator is marked closed regardless.
                pass
        self.is_closed = True

    @property
    def is_closed(self) -> bool:
        """Check whether the iterator is closed."""
        return self._close_event.is_set()

    @is_closed.setter
    def is_closed(self, value: bool) -> None:
        """Mark the iterator as closed (or reopen it when ``value`` is false)."""
        if value:
            self._close_event.set()
        else:
            self._close_event.clear()