from __future__ import annotations import array import math import socket import sys import types from collections.abc import AsyncIterator, Iterable from concurrent.futures import Future from dataclasses import dataclass from functools import partial from io import IOBase from os import PathLike from signal import Signals from socket import AddressFamily, SocketKind from types import TracebackType from typing import ( IO, Any, AsyncGenerator, Awaitable, Callable, Collection, ContextManager, Coroutine, Generic, Mapping, NoReturn, Sequence, TypeVar, cast, overload, ) import trio.from_thread import trio.lowlevel from outcome import Error, Outcome, Value from trio.lowlevel import ( current_root_task, current_task, wait_readable, wait_writable, ) from trio.socket import SocketType as TrioSocketType from trio.to_thread import run_sync from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc from .._core._eventloop import claim_worker_thread from .._core._exceptions import ( BrokenResourceError, BusyResourceError, ClosedResourceError, EndOfStream, ) from .._core._sockets import convert_ipv6_sockaddr from .._core._streams import create_memory_object_stream from .._core._synchronization import CapacityLimiter as BaseCapacityLimiter from .._core._synchronization import Event as BaseEvent from .._core._synchronization import ResourceGuard from .._core._tasks import CancelScope as BaseCancelScope from ..abc import IPSockAddrType, UDPPacketType, UNIXDatagramPacketType from ..abc._eventloop import AsyncBackend from ..streams.memory import MemoryObjectSendStream if sys.version_info >= (3, 10): from typing import ParamSpec else: from typing_extensions import ParamSpec if sys.version_info >= (3, 11): from typing import TypeVarTuple, Unpack else: from exceptiongroup import BaseExceptionGroup from typing_extensions import TypeVarTuple, Unpack T = TypeVar("T") T_Retval = TypeVar("T_Retval") T_SockAddr = TypeVar("T_SockAddr", str, IPSockAddrType) PosArgsT = TypeVarTuple("PosArgsT") P = ParamSpec("P") # # Event loop # RunVar = trio.lowlevel.RunVar # # Timeouts and cancellation # class CancelScope(BaseCancelScope): def __new__( cls, original: trio.CancelScope | None = None, **kwargs: object ) -> CancelScope: return object.__new__(cls) def __init__(self, original: trio.CancelScope | None = None, **kwargs: Any) -> None: self.__original = original or trio.CancelScope(**kwargs) def __enter__(self) -> CancelScope: self.__original.__enter__() return self def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> bool | None: # https://github.com/python-trio/trio-typing/pull/79 return self.__original.__exit__(exc_type, exc_val, exc_tb) def cancel(self) -> None: self.__original.cancel() @property def deadline(self) -> float: return self.__original.deadline @deadline.setter def deadline(self, value: float) -> None: self.__original.deadline = value @property def cancel_called(self) -> bool: return self.__original.cancel_called @property def cancelled_caught(self) -> bool: return self.__original.cancelled_caught @property def shield(self) -> bool: return self.__original.shield @shield.setter def shield(self, value: bool) -> None: self.__original.shield = value # # Task groups # class TaskGroup(abc.TaskGroup): def __init__(self) -> None: self._active = False self._nursery_manager = trio.open_nursery(strict_exception_groups=True) self.cancel_scope = None # type: ignore[assignment] async def __aenter__(self) -> TaskGroup: self._active = True self._nursery = await self._nursery_manager.__aenter__() self.cancel_scope = CancelScope(self._nursery.cancel_scope) return self async def __aexit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> bool | None: try: return await self._nursery_manager.__aexit__(exc_type, exc_val, exc_tb) except BaseExceptionGroup as exc: _, rest = exc.split(trio.Cancelled) if not rest: cancelled_exc = trio.Cancelled._create() raise cancelled_exc from exc raise finally: self._active = False def start_soon( self, func: Callable[[Unpack[PosArgsT]], Awaitable[Any]], *args: Unpack[PosArgsT], name: object = None, ) -> None: if not self._active: raise RuntimeError( "This task group is not active; no new tasks can be started." ) self._nursery.start_soon(func, *args, name=name) async def start( self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None ) -> Any: if not self._active: raise RuntimeError( "This task group is not active; no new tasks can be started." ) return await self._nursery.start(func, *args, name=name) # # Threads # class BlockingPortal(abc.BlockingPortal): def __new__(cls) -> BlockingPortal: return object.__new__(cls) def __init__(self) -> None: super().__init__() self._token = trio.lowlevel.current_trio_token() def _spawn_task_from_thread( self, func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval], args: tuple[Unpack[PosArgsT]], kwargs: dict[str, Any], name: object, future: Future[T_Retval], ) -> None: trio.from_thread.run_sync( partial(self._task_group.start_soon, name=name), self._call_func, func, args, kwargs, future, trio_token=self._token, ) # # Subprocesses # @dataclass(eq=False) class ReceiveStreamWrapper(abc.ByteReceiveStream): _stream: trio.abc.ReceiveStream async def receive(self, max_bytes: int | None = None) -> bytes: try: data = await self._stream.receive_some(max_bytes) except trio.ClosedResourceError as exc: raise ClosedResourceError from exc.__cause__ except trio.BrokenResourceError as exc: raise BrokenResourceError from exc.__cause__ if data: return data else: raise EndOfStream async def aclose(self) -> None: await self._stream.aclose() @dataclass(eq=False) class SendStreamWrapper(abc.ByteSendStream): _stream: trio.abc.SendStream async def send(self, item: bytes) -> None: try: await self._stream.send_all(item) except trio.ClosedResourceError as exc: raise ClosedResourceError from exc.__cause__ except trio.BrokenResourceError as exc: raise BrokenResourceError from exc.__cause__ async def aclose(self) -> None: await self._stream.aclose() @dataclass(eq=False) class Process(abc.Process): _process: trio.Process _stdin: abc.ByteSendStream | None _stdout: abc.ByteReceiveStream | None _stderr: abc.ByteReceiveStream | None async def aclose(self) -> None: with CancelScope(shield=True): if self._stdin: await self._stdin.aclose() if self._stdout: await self._stdout.aclose() if self._stderr: await self._stderr.aclose() try: await self.wait() except BaseException: self.kill() with CancelScope(shield=True): await self.wait() raise async def wait(self) -> int: return await self._process.wait() def terminate(self) -> None: self._process.terminate() def kill(self) -> None: self._process.kill() def send_signal(self, signal: Signals) -> None: self._process.send_signal(signal) @property def pid(self) -> int: return self._process.pid @property def returncode(self) -> int | None: return self._process.returncode @property def stdin(self) -> abc.ByteSendStream | None: return self._stdin @property def stdout(self) -> abc.ByteReceiveStream | None: return self._stdout @property def stderr(self) -> abc.ByteReceiveStream | None: return self._stderr class _ProcessPoolShutdownInstrument(trio.abc.Instrument): def after_run(self) -> None: super().after_run() current_default_worker_process_limiter: trio.lowlevel.RunVar = RunVar( "current_default_worker_process_limiter" ) async def _shutdown_process_pool(workers: set[abc.Process]) -> None: try: await trio.sleep(math.inf) except trio.Cancelled: for process in workers: if process.returncode is None: process.kill() with CancelScope(shield=True): for process in workers: await process.aclose() # # Sockets and networking # class _TrioSocketMixin(Generic[T_SockAddr]): def __init__(self, trio_socket: TrioSocketType) -> None: self._trio_socket = trio_socket self._closed = False def _check_closed(self) -> None: if self._closed: raise ClosedResourceError if self._trio_socket.fileno() < 0: raise BrokenResourceError @property def _raw_socket(self) -> socket.socket: return self._trio_socket._sock # type: ignore[attr-defined] async def aclose(self) -> None: if self._trio_socket.fileno() >= 0: self._closed = True self._trio_socket.close() def _convert_socket_error(self, exc: BaseException) -> NoReturn: if isinstance(exc, trio.ClosedResourceError): raise ClosedResourceError from exc elif self._trio_socket.fileno() < 0 and self._closed: raise ClosedResourceError from None elif isinstance(exc, OSError): raise BrokenResourceError from exc else: raise exc class SocketStream(_TrioSocketMixin, abc.SocketStream): def __init__(self, trio_socket: TrioSocketType) -> None: super().__init__(trio_socket) self._receive_guard = ResourceGuard("reading from") self._send_guard = ResourceGuard("writing to") async def receive(self, max_bytes: int = 65536) -> bytes: with self._receive_guard: try: data = await self._trio_socket.recv(max_bytes) except BaseException as exc: self._convert_socket_error(exc) if data: return data else: raise EndOfStream async def send(self, item: bytes) -> None: with self._send_guard: view = memoryview(item) while view: try: bytes_sent = await self._trio_socket.send(view) except BaseException as exc: self._convert_socket_error(exc) view = view[bytes_sent:] async def send_eof(self) -> None: self._trio_socket.shutdown(socket.SHUT_WR) class UNIXSocketStream(SocketStream, abc.UNIXSocketStream): async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]: if not isinstance(msglen, int) or msglen < 0: raise ValueError("msglen must be a non-negative integer") if not isinstance(maxfds, int) or maxfds < 1: raise ValueError("maxfds must be a positive integer") fds = array.array("i") await trio.lowlevel.checkpoint() with self._receive_guard: while True: try: message, ancdata, flags, addr = await self._trio_socket.recvmsg( msglen, socket.CMSG_LEN(maxfds * fds.itemsize) ) except BaseException as exc: self._convert_socket_error(exc) else: if not message and not ancdata: raise EndOfStream break for cmsg_level, cmsg_type, cmsg_data in ancdata: if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS: raise RuntimeError( f"Received unexpected ancillary data; message = {message!r}, " f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}" ) fds.frombytes(cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) return message, list(fds) async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None: if not message: raise ValueError("message must not be empty") if not fds: raise ValueError("fds must not be empty") filenos: list[int] = [] for fd in fds: if isinstance(fd, int): filenos.append(fd) elif isinstance(fd, IOBase): filenos.append(fd.fileno()) fdarray = array.array("i", filenos) await trio.lowlevel.checkpoint() with self._send_guard: while True: try: await self._trio_socket.sendmsg( [message], [ ( socket.SOL_SOCKET, socket.SCM_RIGHTS, fdarray, ) ], ) break except BaseException as exc: self._convert_socket_error(exc) class TCPSocketListener(_TrioSocketMixin, abc.SocketListener): def __init__(self, raw_socket: socket.socket): super().__init__(trio.socket.from_stdlib_socket(raw_socket)) self._accept_guard = ResourceGuard("accepting connections from") async def accept(self) -> SocketStream: with self._accept_guard: try: trio_socket, _addr = await self._trio_socket.accept() except BaseException as exc: self._convert_socket_error(exc) trio_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) return SocketStream(trio_socket) class UNIXSocketListener(_TrioSocketMixin, abc.SocketListener): def __init__(self, raw_socket: socket.socket): super().__init__(trio.socket.from_stdlib_socket(raw_socket)) self._accept_guard = ResourceGuard("accepting connections from") async def accept(self) -> UNIXSocketStream: with self._accept_guard: try: trio_socket, _addr = await self._trio_socket.accept() except BaseException as exc: self._convert_socket_error(exc) return UNIXSocketStream(trio_socket) class UDPSocket(_TrioSocketMixin[IPSockAddrType], abc.UDPSocket): def __init__(self, trio_socket: TrioSocketType) -> None: super().__init__(trio_socket) self._receive_guard = ResourceGuard("reading from") self._send_guard = ResourceGuard("writing to") async def receive(self) -> tuple[bytes, IPSockAddrType]: with self._receive_guard: try: data, addr = await self._trio_socket.recvfrom(65536) return data, convert_ipv6_sockaddr(addr) except BaseException as exc: self._convert_socket_error(exc) async def send(self, item: UDPPacketType) -> None: with self._send_guard: try: await self._trio_socket.sendto(*item) except BaseException as exc: self._convert_socket_error(exc) class ConnectedUDPSocket(_TrioSocketMixin[IPSockAddrType], abc.ConnectedUDPSocket): def __init__(self, trio_socket: TrioSocketType) -> None: super().__init__(trio_socket) self._receive_guard = ResourceGuard("reading from") self._send_guard = ResourceGuard("writing to") async def receive(self) -> bytes: with self._receive_guard: try: return await self._trio_socket.recv(65536) except BaseException as exc: self._convert_socket_error(exc) async def send(self, item: bytes) -> None: with self._send_guard: try: await self._trio_socket.send(item) except BaseException as exc: self._convert_socket_error(exc) class UNIXDatagramSocket(_TrioSocketMixin[str], abc.UNIXDatagramSocket): def __init__(self, trio_socket: TrioSocketType) -> None: super().__init__(trio_socket) self._receive_guard = ResourceGuard("reading from") self._send_guard = ResourceGuard("writing to") async def receive(self) -> UNIXDatagramPacketType: with self._receive_guard: try: data, addr = await self._trio_socket.recvfrom(65536) return data, addr except BaseException as exc: self._convert_socket_error(exc) async def send(self, item: UNIXDatagramPacketType) -> None: with self._send_guard: try: await self._trio_socket.sendto(*item) except BaseException as exc: self._convert_socket_error(exc) class ConnectedUNIXDatagramSocket( _TrioSocketMixin[str], abc.ConnectedUNIXDatagramSocket ): def __init__(self, trio_socket: TrioSocketType) -> None: super().__init__(trio_socket) self._receive_guard = ResourceGuard("reading from") self._send_guard = ResourceGuard("writing to") async def receive(self) -> bytes: with self._receive_guard: try: return await self._trio_socket.recv(65536) except BaseException as exc: self._convert_socket_error(exc) async def send(self, item: bytes) -> None: with self._send_guard: try: await self._trio_socket.send(item) except BaseException as exc: self._convert_socket_error(exc) # # Synchronization # class Event(BaseEvent): def __new__(cls) -> Event: return object.__new__(cls) def __init__(self) -> None: self.__original = trio.Event() def is_set(self) -> bool: return self.__original.is_set() async def wait(self) -> None: return await self.__original.wait() def statistics(self) -> EventStatistics: orig_statistics = self.__original.statistics() return EventStatistics(tasks_waiting=orig_statistics.tasks_waiting) def set(self) -> None: self.__original.set() class CapacityLimiter(BaseCapacityLimiter): def __new__( cls, total_tokens: float | None = None, *, original: trio.CapacityLimiter | None = None, ) -> CapacityLimiter: return object.__new__(cls) def __init__( self, total_tokens: float | None = None, *, original: trio.CapacityLimiter | None = None, ) -> None: if original is not None: self.__original = original else: assert total_tokens is not None self.__original = trio.CapacityLimiter(total_tokens) async def __aenter__(self) -> None: return await self.__original.__aenter__() async def __aexit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: await self.__original.__aexit__(exc_type, exc_val, exc_tb) @property def total_tokens(self) -> float: return self.__original.total_tokens @total_tokens.setter def total_tokens(self, value: float) -> None: self.__original.total_tokens = value @property def borrowed_tokens(self) -> int: return self.__original.borrowed_tokens @property def available_tokens(self) -> float: return self.__original.available_tokens def acquire_nowait(self) -> None: self.__original.acquire_nowait() def acquire_on_behalf_of_nowait(self, borrower: object) -> None: self.__original.acquire_on_behalf_of_nowait(borrower) async def acquire(self) -> None: await self.__original.acquire() async def acquire_on_behalf_of(self, borrower: object) -> None: await self.__original.acquire_on_behalf_of(borrower) def release(self) -> None: return self.__original.release() def release_on_behalf_of(self, borrower: object) -> None: return self.__original.release_on_behalf_of(borrower) def statistics(self) -> CapacityLimiterStatistics: orig = self.__original.statistics() return CapacityLimiterStatistics( borrowed_tokens=orig.borrowed_tokens, total_tokens=orig.total_tokens, borrowers=tuple(orig.borrowers), tasks_waiting=orig.tasks_waiting, ) _capacity_limiter_wrapper: trio.lowlevel.RunVar = RunVar("_capacity_limiter_wrapper") # # Signal handling # class _SignalReceiver: _iterator: AsyncIterator[int] def __init__(self, signals: tuple[Signals, ...]): self._signals = signals def __enter__(self) -> _SignalReceiver: self._cm = trio.open_signal_receiver(*self._signals) self._iterator = self._cm.__enter__() return self def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> bool | None: return self._cm.__exit__(exc_type, exc_val, exc_tb) def __aiter__(self) -> _SignalReceiver: return self async def __anext__(self) -> Signals: signum = await self._iterator.__anext__() return Signals(signum) # # Testing and debugging # class TestRunner(abc.TestRunner): def __init__(self, **options: Any) -> None: from queue import Queue self._call_queue: Queue[Callable[[], object]] = Queue() self._send_stream: MemoryObjectSendStream | None = None self._options = options def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: types.TracebackType | None, ) -> None: if self._send_stream: self._send_stream.close() while self._send_stream is not None: self._call_queue.get()() async def _run_tests_and_fixtures(self) -> None: self._send_stream, receive_stream = create_memory_object_stream(1) with receive_stream: async for coro, outcome_holder in receive_stream: try: retval = await coro except BaseException as exc: outcome_holder.append(Error(exc)) else: outcome_holder.append(Value(retval)) def _main_task_finished(self, outcome: object) -> None: self._send_stream = None def _call_in_runner_task( self, func: Callable[P, Awaitable[T_Retval]], *args: P.args, **kwargs: P.kwargs, ) -> T_Retval: if self._send_stream is None: trio.lowlevel.start_guest_run( self._run_tests_and_fixtures, run_sync_soon_threadsafe=self._call_queue.put, done_callback=self._main_task_finished, **self._options, ) while self._send_stream is None: self._call_queue.get()() outcome_holder: list[Outcome] = [] self._send_stream.send_nowait((func(*args, **kwargs), outcome_holder)) while not outcome_holder: self._call_queue.get()() return outcome_holder[0].unwrap() def run_asyncgen_fixture( self, fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]], kwargs: dict[str, Any], ) -> Iterable[T_Retval]: asyncgen = fixture_func(**kwargs) fixturevalue: T_Retval = self._call_in_runner_task(asyncgen.asend, None) yield fixturevalue try: self._call_in_runner_task(asyncgen.asend, None) except StopAsyncIteration: pass else: self._call_in_runner_task(asyncgen.aclose) raise RuntimeError("Async generator fixture did not stop") def run_fixture( self, fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]], kwargs: dict[str, Any], ) -> T_Retval: return self._call_in_runner_task(fixture_func, **kwargs) def run_test( self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any] ) -> None: self._call_in_runner_task(test_func, **kwargs) class TrioBackend(AsyncBackend): @classmethod def run( cls, func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], args: tuple[Unpack[PosArgsT]], kwargs: dict[str, Any], options: dict[str, Any], ) -> T_Retval: return trio.run(func, *args) @classmethod def current_token(cls) -> object: return trio.lowlevel.current_trio_token() @classmethod def current_time(cls) -> float: return trio.current_time() @classmethod def cancelled_exception_class(cls) -> type[BaseException]: return trio.Cancelled @classmethod async def checkpoint(cls) -> None: await trio.lowlevel.checkpoint() @classmethod async def checkpoint_if_cancelled(cls) -> None: await trio.lowlevel.checkpoint_if_cancelled() @classmethod async def cancel_shielded_checkpoint(cls) -> None: await trio.lowlevel.cancel_shielded_checkpoint() @classmethod async def sleep(cls, delay: float) -> None: await trio.sleep(delay) @classmethod def create_cancel_scope( cls, *, deadline: float = math.inf, shield: bool = False ) -> abc.CancelScope: return CancelScope(deadline=deadline, shield=shield) @classmethod def current_effective_deadline(cls) -> float: return trio.current_effective_deadline() @classmethod def create_task_group(cls) -> abc.TaskGroup: return TaskGroup() @classmethod def create_event(cls) -> abc.Event: return Event() @classmethod def create_capacity_limiter(cls, total_tokens: float) -> CapacityLimiter: return CapacityLimiter(total_tokens) @classmethod async def run_sync_in_worker_thread( cls, func: Callable[[Unpack[PosArgsT]], T_Retval], args: tuple[Unpack[PosArgsT]], abandon_on_cancel: bool = False, limiter: abc.CapacityLimiter | None = None, ) -> T_Retval: def wrapper() -> T_Retval: with claim_worker_thread(TrioBackend, token): return func(*args) token = TrioBackend.current_token() return await run_sync( wrapper, abandon_on_cancel=abandon_on_cancel, limiter=cast(trio.CapacityLimiter, limiter), ) @classmethod def check_cancelled(cls) -> None: trio.from_thread.check_cancelled() @classmethod def run_async_from_thread( cls, func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], args: tuple[Unpack[PosArgsT]], token: object, ) -> T_Retval: return trio.from_thread.run(func, *args) @classmethod def run_sync_from_thread( cls, func: Callable[[Unpack[PosArgsT]], T_Retval], args: tuple[Unpack[PosArgsT]], token: object, ) -> T_Retval: return trio.from_thread.run_sync(func, *args) @classmethod def create_blocking_portal(cls) -> abc.BlockingPortal: return BlockingPortal() @classmethod async def open_process( cls, command: str | bytes | Sequence[str | bytes], *, shell: bool, stdin: int | IO[Any] | None, stdout: int | IO[Any] | None, stderr: int | IO[Any] | None, cwd: str | bytes | PathLike | None = None, env: Mapping[str, str] | None = None, start_new_session: bool = False, ) -> Process: process = await trio.lowlevel.open_process( # type: ignore[misc] command, # type: ignore[arg-type] stdin=stdin, stdout=stdout, stderr=stderr, shell=shell, cwd=cwd, env=env, start_new_session=start_new_session, ) stdin_stream = SendStreamWrapper(process.stdin) if process.stdin else None stdout_stream = ReceiveStreamWrapper(process.stdout) if process.stdout else None stderr_stream = ReceiveStreamWrapper(process.stderr) if process.stderr else None return Process(process, stdin_stream, stdout_stream, stderr_stream) @classmethod def setup_process_pool_exit_at_shutdown(cls, workers: set[abc.Process]) -> None: trio.lowlevel.spawn_system_task(_shutdown_process_pool, workers) @classmethod async def connect_tcp( cls, host: str, port: int, local_address: IPSockAddrType | None = None ) -> SocketStream: family = socket.AF_INET6 if ":" in host else socket.AF_INET trio_socket = trio.socket.socket(family) trio_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) if local_address: await trio_socket.bind(local_address) try: await trio_socket.connect((host, port)) except BaseException: trio_socket.close() raise return SocketStream(trio_socket) @classmethod async def connect_unix(cls, path: str | bytes) -> abc.UNIXSocketStream: trio_socket = trio.socket.socket(socket.AF_UNIX) try: await trio_socket.connect(path) except BaseException: trio_socket.close() raise return UNIXSocketStream(trio_socket) @classmethod def create_tcp_listener(cls, sock: socket.socket) -> abc.SocketListener: return TCPSocketListener(sock) @classmethod def create_unix_listener(cls, sock: socket.socket) -> abc.SocketListener: return UNIXSocketListener(sock) @classmethod async def create_udp_socket( cls, family: socket.AddressFamily, local_address: IPSockAddrType | None, remote_address: IPSockAddrType | None, reuse_port: bool, ) -> UDPSocket | ConnectedUDPSocket: trio_socket = trio.socket.socket(family=family, type=socket.SOCK_DGRAM) if reuse_port: trio_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) if local_address: await trio_socket.bind(local_address) if remote_address: await trio_socket.connect(remote_address) return ConnectedUDPSocket(trio_socket) else: return UDPSocket(trio_socket) @classmethod @overload async def create_unix_datagram_socket( cls, raw_socket: socket.socket, remote_path: None ) -> abc.UNIXDatagramSocket: ... @classmethod @overload async def create_unix_datagram_socket( cls, raw_socket: socket.socket, remote_path: str | bytes ) -> abc.ConnectedUNIXDatagramSocket: ... @classmethod async def create_unix_datagram_socket( cls, raw_socket: socket.socket, remote_path: str | bytes | None ) -> abc.UNIXDatagramSocket | abc.ConnectedUNIXDatagramSocket: trio_socket = trio.socket.from_stdlib_socket(raw_socket) if remote_path: await trio_socket.connect(remote_path) return ConnectedUNIXDatagramSocket(trio_socket) else: return UNIXDatagramSocket(trio_socket) @classmethod async def getaddrinfo( cls, host: bytes | str | None, port: str | int | None, *, family: int | AddressFamily = 0, type: int | SocketKind = 0, proto: int = 0, flags: int = 0, ) -> list[ tuple[ AddressFamily, SocketKind, int, str, tuple[str, int] | tuple[str, int, int, int], ] ]: return await trio.socket.getaddrinfo(host, port, family, type, proto, flags) @classmethod async def getnameinfo( cls, sockaddr: IPSockAddrType, flags: int = 0 ) -> tuple[str, str]: return await trio.socket.getnameinfo(sockaddr, flags) @classmethod async def wait_socket_readable(cls, sock: socket.socket) -> None: try: await wait_readable(sock) except trio.ClosedResourceError as exc: raise ClosedResourceError().with_traceback(exc.__traceback__) from None except trio.BusyResourceError: raise BusyResourceError("reading from") from None @classmethod async def wait_socket_writable(cls, sock: socket.socket) -> None: try: await wait_writable(sock) except trio.ClosedResourceError as exc: raise ClosedResourceError().with_traceback(exc.__traceback__) from None except trio.BusyResourceError: raise BusyResourceError("writing to") from None @classmethod def current_default_thread_limiter(cls) -> CapacityLimiter: try: return _capacity_limiter_wrapper.get() except LookupError: limiter = CapacityLimiter( original=trio.to_thread.current_default_thread_limiter() ) _capacity_limiter_wrapper.set(limiter) return limiter @classmethod def open_signal_receiver( cls, *signals: Signals ) -> ContextManager[AsyncIterator[Signals]]: return _SignalReceiver(signals) @classmethod def get_current_task(cls) -> TaskInfo: task = current_task() parent_id = None if task.parent_nursery and task.parent_nursery.parent_task: parent_id = id(task.parent_nursery.parent_task) return TaskInfo(id(task), parent_id, task.name, task.coro) @classmethod def get_running_tasks(cls) -> list[TaskInfo]: root_task = current_root_task() assert root_task task_infos = [TaskInfo(id(root_task), None, root_task.name, root_task.coro)] nurseries = root_task.child_nurseries while nurseries: new_nurseries: list[trio.Nursery] = [] for nursery in nurseries: for task in nursery.child_tasks: task_infos.append( TaskInfo( id(task), id(nursery.parent_task), task.name, task.coro ) ) new_nurseries.extend(task.child_nurseries) nurseries = new_nurseries return task_infos @classmethod async def wait_all_tasks_blocked(cls) -> None: from trio.testing import wait_all_tasks_blocked await wait_all_tasks_blocked() @classmethod def create_test_runner(cls, options: dict[str, Any]) -> TestRunner: return TestRunner(**options) backend_class = TrioBackend