533 lines
20 KiB
Python
533 lines
20 KiB
Python
import asyncio
|
|
import asyncio.coroutines
|
|
import contextvars
|
|
import functools
|
|
import inspect
|
|
import os
|
|
import sys
|
|
import threading
|
|
import warnings
|
|
import weakref
|
|
from concurrent.futures import Future, ThreadPoolExecutor
|
|
from typing import Any, Callable, Dict, Optional, overload
|
|
|
|
from .current_thread_executor import CurrentThreadExecutor
|
|
from .local import Local
|
|
|
|
|
|
def _restore_context(context):
    """
    Copy the variables recorded in ``context`` into the active context.

    Every context variable present in ``context`` is compared with its value
    in the currently active context. It is (re)assigned only when it is unset
    there or when the two values compare unequal, so downstream consumers see
    any changes that were made while running inside ``context``.
    """
    for cvar, recorded in context.items():
        try:
            needs_update = cvar.get() != recorded
        except LookupError:
            # The variable has no value at all in the active context.
            needs_update = True
        if needs_update:
            cvar.set(recorded)
|
|
|
|
|
|
def _iscoroutinefunction_or_partial(func: Any) -> bool:
    """
    Report whether ``func`` is a coroutine function, seeing through wrappers.

    On Python older than 3.8 the standard check does not recognise bound
    methods or ``functools.partial`` objects wrapping a coroutine function,
    so those layers are peeled off by hand first (logic taken from CPython).
    On newer versions the standard check is used directly.
    """
    if sys.version_info < (3, 8):
        # Strip bound-method wrappers, then partial wrappers, to reach the
        # underlying function.
        while inspect.ismethod(func):
            func = func.__func__
        while isinstance(func, functools.partial):
            func = func.func

    return asyncio.iscoroutinefunction(func)
|
|
|
|
|
|
class ThreadSensitiveContext:
    """Async context manager to manage context for thread sensitive mode

    This context manager controls which thread pool executor is used when in
    thread sensitive mode. By default, a single thread pool executor is shared
    within a process.

    In Python 3.7+, the ThreadSensitiveContext() context manager may be used to
    specify a thread pool per context.

    This context manager is re-entrant, so only the outer-most call to
    ThreadSensitiveContext will set the context.

    Usage:

    >>> import time
    >>> async with ThreadSensitiveContext():
    ...     await sync_to_async(time.sleep, 1)()
    """

    def __init__(self):
        # Token returned by ContextVar.set(); only held while this instance
        # is the one that installed itself as the active context.
        self.token = None

    async def __aenter__(self):
        """
        Install this instance as the thread-sensitive context, unless one is
        already active, and return the instance.
        """
        try:
            SyncToAsync.thread_sensitive_context.get()
        except LookupError:
            # Nothing is active yet, so this is the outer-most entry.
            self.token = SyncToAsync.thread_sensitive_context.set(self)

        return self

    async def __aexit__(self, exc, value, tb):
        """
        Tear down the context if this instance installed it: shut down the
        executor registered for this context (if any) and restore the
        previous value of the context variable.
        """
        if self.token is None:
            # A nested entry: the outer-most context owns the cleanup.
            return

        # A token may only be passed to ContextVar.reset() once. Drop our
        # reference before using it so that re-using this instance later as a
        # nested (non-owning) context cannot try to reset with a stale token.
        token, self.token = self.token, None

        executor = SyncToAsync.context_to_thread_executor.pop(self, None)
        try:
            if executor:
                executor.shutdown()
        finally:
            # Restore the context variable even if the shutdown failed.
            SyncToAsync.thread_sensitive_context.reset(token)
|
|
|
|
|
|
class AsyncToSync:
    """
    Utility class which turns an awaitable that only works on the thread with
    the event loop into a synchronous callable that works in a subthread.

    If the call stack contains an async loop, the code runs there.
    Otherwise, the code runs in a new loop in a new thread.

    Either way, this thread then pauses and waits to run any thread_sensitive
    code called from further down the call stack using SyncToAsync, before
    finally exiting once the async task returns.
    """

    # Maps launched Tasks to the threads that launched them (for locals impl)
    launch_map: "Dict[asyncio.Task[object], threading.Thread]" = {}

    # Keeps track of which CurrentThreadExecutor to use. This uses an asgiref
    # Local, not a threadlocal, so that tasks can work out what their parent used.
    executors = Local()

    # When we can't find a CurrentThreadExecutor from the context, such as
    # inside create_task, we'll look it up here from the running event loop.
    loop_thread_executors: "Dict[asyncio.AbstractEventLoop, CurrentThreadExecutor]" = {}

    def __init__(self, awaitable, force_new_loop=False):
        """
        Wrap ``awaitable`` and work out which event loop calls should run on.

        A warning (not an error) is emitted when ``awaitable`` does not look
        like a coroutine function. With ``force_new_loop`` set, every call
        runs in a fresh loop; otherwise the running loop is used, or failing
        that the loop recorded by SyncToAsync for the current thread.
        """
        if not callable(awaitable) or (
            not _iscoroutinefunction_or_partial(awaitable)
            and not _iscoroutinefunction_or_partial(
                getattr(awaitable, "__call__", awaitable)
            )
        ):
            # Python does not have very reliable detection of async functions
            # (lots of false negatives) so this is just a warning.
            warnings.warn(
                "async_to_sync was passed a non-async-marked callable", stacklevel=2
            )
        self.awaitable = awaitable
        try:
            self.__self__ = self.awaitable.__self__
        except AttributeError:
            pass
        if force_new_loop:
            # They have asked that we always run in a new sub-loop.
            self.main_event_loop = None
        else:
            try:
                self.main_event_loop = asyncio.get_running_loop()
            except RuntimeError:
                # There's no event loop in this thread. Look for the threadlocal if
                # we're inside SyncToAsync
                main_event_loop_pid = getattr(
                    SyncToAsync.threadlocal, "main_event_loop_pid", None
                )
                # We make sure the parent loop is from the same process - if
                # they've forked, this is not going to be valid any more (#194)
                if main_event_loop_pid and main_event_loop_pid == os.getpid():
                    self.main_event_loop = getattr(
                        SyncToAsync.threadlocal, "main_event_loop", None
                    )
                else:
                    self.main_event_loop = None

    def __call__(self, *args, **kwargs):
        """
        Run the wrapped awaitable with the given arguments and block until it
        finishes, returning its result or re-raising its exception.

        Raises RuntimeError when called from a thread that has a running
        event loop.
        """
        # You can't call AsyncToSync from a thread with a running event loop
        try:
            event_loop = asyncio.get_running_loop()
        except RuntimeError:
            pass
        else:
            if event_loop.is_running():
                raise RuntimeError(
                    "You cannot use AsyncToSync in the same thread as an async event loop - "
                    "just await the async function directly."
                )

        # Wrapping context in list so it can be reassigned from within
        # `main_wrap`.
        context = [contextvars.copy_context()]

        # Make a future for the return information
        call_result = Future()
        # Get the source thread
        source_thread = threading.current_thread()
        # Make a CurrentThreadExecutor we'll use to idle in this thread - we
        # need one for every sync frame, even if there's one above us in the
        # same thread.
        if hasattr(self.executors, "current"):
            old_current_executor = self.executors.current
        else:
            old_current_executor = None
        current_executor = CurrentThreadExecutor()
        self.executors.current = current_executor
        loop = None
        # Use call_soon_threadsafe to schedule a synchronous callback on the
        # main event loop's thread if it's there, otherwise make a new loop
        # in this thread.
        try:
            awaitable = self.main_wrap(
                args, kwargs, call_result, source_thread, sys.exc_info(), context
            )

            if not (self.main_event_loop and self.main_event_loop.is_running()):
                # Make our own event loop - in a new thread - and run inside that.
                loop = asyncio.new_event_loop()
                self.loop_thread_executors[loop] = current_executor
                loop_executor = ThreadPoolExecutor(max_workers=1)
                try:
                    loop_future = loop_executor.submit(
                        self._run_event_loop, loop, awaitable
                    )
                    if current_executor:
                        # Run the CurrentThreadExecutor until the future is done
                        current_executor.run_until_future(loop_future)
                    # Wait for future and/or allow for exception propagation
                    loop_future.result()
                finally:
                    # This executor is private to this call, so release its
                    # worker thread explicitly instead of leaving that to
                    # garbage collection. wait=False lets the worker finish
                    # whatever it is running without delaying an exception
                    # that is propagating out of this block.
                    loop_executor.shutdown(wait=False)
            else:
                # Call it inside the existing loop
                self.main_event_loop.call_soon_threadsafe(
                    self.main_event_loop.create_task, awaitable
                )
                if current_executor:
                    # Run the CurrentThreadExecutor until the future is done
                    current_executor.run_until_future(call_result)
        finally:
            # Clean up any executor we were running
            if loop is not None:
                del self.loop_thread_executors[loop]
            if hasattr(self.executors, "current"):
                del self.executors.current
            if old_current_executor:
                self.executors.current = old_current_executor
            _restore_context(context[0])

        # Wait for results from the future.
        return call_result.result()

    def _run_event_loop(self, loop, coro):
        """
        Runs the given event loop (designed to be called in a thread).

        Drives ``coro`` to completion on ``loop``, then cancels and awaits
        any tasks still pending, reports exceptions from those tasks to the
        loop's exception handler, shuts down async generators, closes the
        loop and finally sets this thread's event loop to
        ``self.main_event_loop``.
        """
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(coro)
        finally:
            try:
                # mimic asyncio.run() behavior
                # cancel unexhausted async generators
                tasks = asyncio.all_tasks(loop)
                for task in tasks:
                    task.cancel()

                async def gather():
                    await asyncio.gather(*tasks, return_exceptions=True)

                loop.run_until_complete(gather())
                for task in tasks:
                    if task.cancelled():
                        continue
                    if task.exception() is not None:
                        loop.call_exception_handler(
                            {
                                "message": "unhandled exception during loop shutdown",
                                "exception": task.exception(),
                                "task": task,
                            }
                        )
                if hasattr(loop, "shutdown_asyncgens"):
                    loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                loop.close()
                asyncio.set_event_loop(self.main_event_loop)

    def __get__(self, parent, objtype):
        """
        Include self for methods
        """
        func = functools.partial(self.__call__, parent)
        return functools.update_wrapper(func, self.awaitable)

    async def main_wrap(
        self, args, kwargs, call_result, source_thread, exc_info, context
    ):
        """
        Wraps the awaitable with something that puts the result into the
        result/exception future.

        ``context`` is a one-element list: its context is applied before the
        awaitable runs and the element is replaced with a fresh copy of the
        context afterwards, so the caller can pick up any changes.
        """
        if context is not None:
            _restore_context(context[0])

        current_task = SyncToAsync.get_current_task()
        self.launch_map[current_task] = source_thread
        try:
            # If we have an exception, run the function inside the except block
            # after raising it so exc_info is correctly populated.
            if exc_info[1]:
                try:
                    raise exc_info[1]
                except BaseException:
                    result = await self.awaitable(*args, **kwargs)
            else:
                result = await self.awaitable(*args, **kwargs)
        except BaseException as e:
            call_result.set_exception(e)
        else:
            call_result.set_result(result)
        finally:
            del self.launch_map[current_task]

            context[0] = contextvars.copy_context()
|
|
|
|
|
|
class SyncToAsync:
    """
    Utility class which turns a synchronous callable into an awaitable that
    runs in a threadpool. It also sets a threadlocal inside the thread so
    calls to AsyncToSync can escape it.

    If thread_sensitive is passed, the code will run in the same thread as any
    outer code. This is needed for underlying Python code that is not
    threadsafe (for example, code which handles SQLite database connections).

    If the outermost program is async (i.e. SyncToAsync is outermost), then
    this will be a dedicated single sub-thread that all sync code runs in,
    one after the other. If the outermost program is sync (i.e. AsyncToSync is
    outermost), this will just be the main thread. This is achieved by idling
    with a CurrentThreadExecutor while AsyncToSync is blocking its sync parent,
    rather than just blocking.

    If executor is passed in, that will be used instead of the loop's default executor.
    In order to pass in an executor, thread_sensitive must be set to False, otherwise
    a TypeError will be raised.
    """

    # If they've set ASGI_THREADS, update the default asyncio executor for now
    if "ASGI_THREADS" in os.environ:
        # We use get_event_loop here - not get_running_loop - as this will
        # be run at import time, and we want to update the main thread's loop.
        loop = asyncio.get_event_loop()
        loop.set_default_executor(
            ThreadPoolExecutor(max_workers=int(os.environ["ASGI_THREADS"]))
        )

    # Maps launched threads to the coroutines that spawned them
    launch_map: "Dict[threading.Thread, asyncio.Task[object]]" = {}

    # Storage for main event loop references
    threadlocal = threading.local()

    # Single-thread executor for thread-sensitive code
    single_thread_executor = ThreadPoolExecutor(max_workers=1)

    # Maintain a contextvar for the current execution context. Optionally used
    # for thread sensitive mode. The stored value is the ThreadSensitiveContext
    # instance that installed itself.
    thread_sensitive_context: "contextvars.ContextVar[ThreadSensitiveContext]" = (
        contextvars.ContextVar("thread_sensitive_context")
    )

    # Contextvar that is used to detect if the single thread executor
    # would be awaited on while already being used in the same context
    deadlock_context: "contextvars.ContextVar[bool]" = contextvars.ContextVar(
        "deadlock_context"
    )

    # Maintaining a weak reference to the context ensures that thread pools are
    # erased once the context goes out of scope. This terminates the thread pool.
    context_to_thread_executor: "weakref.WeakKeyDictionary[object, ThreadPoolExecutor]" = (
        weakref.WeakKeyDictionary()
    )

    def __init__(
        self,
        func: Callable[..., Any],
        thread_sensitive: bool = True,
        executor: Optional["ThreadPoolExecutor"] = None,
    ) -> None:
        """
        Wrap the synchronous callable ``func``.

        Raises TypeError if ``func`` is not callable or is a coroutine
        function, or if ``executor`` is given while ``thread_sensitive`` is
        True.
        """
        if (
            not callable(func)
            or _iscoroutinefunction_or_partial(func)
            or _iscoroutinefunction_or_partial(getattr(func, "__call__", func))
        ):
            raise TypeError("sync_to_async can only be applied to sync functions.")
        self.func = func
        functools.update_wrapper(self, func)
        self._thread_sensitive = thread_sensitive
        self._is_coroutine = asyncio.coroutines._is_coroutine  # type: ignore
        if thread_sensitive and executor is not None:
            raise TypeError("executor must not be set when thread_sensitive is True")
        self._executor = executor
        try:
            self.__self__ = func.__self__  # type: ignore
        except AttributeError:
            pass

    async def __call__(self, *args, **kwargs):
        """
        Run the wrapped function in a worker thread and return its result.

        The executor is chosen as described on the class. The function runs
        inside a copy of the current context, and any context variable
        changes it makes are copied back once it has finished.

        Raises RuntimeError if the shared single-thread executor would be
        awaited while already in use in the same context.
        """
        loop = asyncio.get_running_loop()

        # Work out what thread to run the code in
        if self._thread_sensitive:
            if hasattr(AsyncToSync.executors, "current"):
                # If we have a parent sync thread above somewhere, use that
                executor = AsyncToSync.executors.current
            elif self.thread_sensitive_context and self.thread_sensitive_context.get(
                None
            ):
                # If we have a way of retrieving the current context, attempt
                # to use a per-context thread pool executor
                thread_sensitive_context = self.thread_sensitive_context.get()

                if thread_sensitive_context in self.context_to_thread_executor:
                    # Re-use thread executor in current context
                    executor = self.context_to_thread_executor[thread_sensitive_context]
                else:
                    # Create new thread executor in current context
                    executor = ThreadPoolExecutor(max_workers=1)
                    self.context_to_thread_executor[thread_sensitive_context] = executor
            elif loop in AsyncToSync.loop_thread_executors:
                # Re-use thread executor for running loop
                executor = AsyncToSync.loop_thread_executors[loop]
            elif self.deadlock_context and self.deadlock_context.get(False):
                raise RuntimeError(
                    "Single thread executor already being used, would deadlock"
                )
            else:
                # Otherwise, we run it in a fixed single thread
                executor = self.single_thread_executor
                if self.deadlock_context:
                    self.deadlock_context.set(True)
        else:
            # Use the passed in executor, or the loop's default if it is None
            executor = self._executor

        # The function is invoked through context.run() so that it sees (and
        # can modify) a copy of the caller's context variables.
        context = contextvars.copy_context()
        child = functools.partial(self.func, *args, **kwargs)

        try:
            # Run the code in the right thread
            future = loop.run_in_executor(
                executor,
                functools.partial(
                    self.thread_handler,
                    loop,
                    self.get_current_task(),
                    sys.exc_info(),
                    context.run,
                    child,
                ),
            )
            ret = await asyncio.wait_for(future, timeout=None)

        finally:
            _restore_context(context)
            if self.deadlock_context:
                self.deadlock_context.set(False)

        return ret

    def __get__(self, parent, objtype):
        """
        Include self for methods

        The bound callable carries the wrapped function's metadata, matching
        what AsyncToSync.__get__ does.
        """
        func = functools.partial(self.__call__, parent)
        return functools.update_wrapper(func, self.func)

    def thread_handler(self, loop, source_task, exc_info, func, *args, **kwargs):
        """
        Wraps the sync application with exception handling.

        Runs in the worker thread: records ``loop`` and the current process
        id in the threadlocal (read by AsyncToSync), registers the calling
        task in ``launch_map`` for the duration of the call, then invokes
        ``func(*args, **kwargs)`` and returns its result.
        """
        # Set the threadlocal for AsyncToSync
        self.threadlocal.main_event_loop = loop
        self.threadlocal.main_event_loop_pid = os.getpid()
        # Set the task mapping (used for the locals module)
        current_thread = threading.current_thread()
        if AsyncToSync.launch_map.get(source_task) == current_thread:
            # Our parent task was launched from this same thread, so don't make
            # a launch map entry - let it shortcut over us! (and stop infinite loops)
            parent_set = False
        else:
            self.launch_map[current_thread] = source_task
            parent_set = True
        # Run the function
        try:
            # If we have an exception, run the function inside the except block
            # after raising it so exc_info is correctly populated.
            if exc_info[1]:
                try:
                    raise exc_info[1]
                except BaseException:
                    return func(*args, **kwargs)
            else:
                return func(*args, **kwargs)
        finally:
            # Only delete the launch_map parent if we set it, otherwise it is
            # from someone else.
            if parent_set:
                del self.launch_map[current_thread]

    @staticmethod
    def get_current_task():
        """
        Implementation of asyncio.current_task()
        that returns None if there is no task.
        """
        try:
            return asyncio.current_task()
        except RuntimeError:
            return None
|
|
|
|
|
|
# Lowercase aliases (and decorator friendliness)
|
|
# Binds the lowercase name to the class itself, so async_to_sync(fn) builds
# an AsyncToSync wrapper around fn.
async_to_sync = AsyncToSync
|
|
|
|
|
|
@overload
def sync_to_async(
    func: None = None,
    thread_sensitive: bool = True,
    executor: Optional["ThreadPoolExecutor"] = None,
) -> Callable[[Callable[..., Any]], SyncToAsync]:
    ...


@overload
def sync_to_async(
    func: Callable[..., Any],
    thread_sensitive: bool = True,
    executor: Optional["ThreadPoolExecutor"] = None,
) -> SyncToAsync:
    ...


def sync_to_async(func=None, thread_sensitive=True, executor=None):
    """
    Build a SyncToAsync wrapper, usable as a plain call or as a decorator.

    When ``func`` is given it is wrapped immediately. When it is omitted, a
    one-argument callable is returned that wraps whatever function it is
    later applied to, using the same ``thread_sensitive`` and ``executor``
    settings.
    """
    options = {"thread_sensitive": thread_sensitive, "executor": executor}
    if func is not None:
        return SyncToAsync(func, **options)
    return lambda f: SyncToAsync(f, **options)
|