use crate::sync::GILOnceCell;
use crate::types::any::PyAnyMethods;
use crate::types::PyCFunction;
use crate::{intern, wrap_pyfunction, Bound, Py, PyAny, PyObject, PyResult, Python};
use pyo3_macros::pyfunction;
use std::sync::Arc;
use std::task::Wake;
/// `std::task::Wake` implementation that resolves an asyncio future when woken.
///
/// The inner cell starts out empty and is filled at most once until `reset`:
/// - `initialize_future` fills it with `Some(..)`: the running event loop plus a
///   new future created on that loop;
/// - waking while the cell is still empty fills it with `None`, so a later
///   `initialize_future` returns `Ok(None)` and no future is created.
///
/// `reset` empties the cell again so the next use re-initializes it.
pub struct AsyncioWaker(GILOnceCell<Option<LoopAndFuture>>);
impl AsyncioWaker {
pub(super) fn new() -> Self {
Self(GILOnceCell::new())
}
pub(super) fn reset(&mut self) {
self.0.take();
}
pub(super) fn initialize_future<'py>(
&self,
py: Python<'py>,
) -> PyResult<Option<&Bound<'py, PyAny>>> {
let init = || LoopAndFuture::new(py).map(Some);
let loop_and_future = self.0.get_or_try_init(py, init)?.as_ref();
Ok(loop_and_future.map(|LoopAndFuture { future, .. }| future.bind(py)))
}
}
impl Wake for AsyncioWaker {
    fn wake(self: Arc<Self>) {
        // Owning the `Arc` changes nothing here: reuse the by-reference path.
        Self::wake_by_ref(&self);
    }

    /// Resolves the asyncio future, if one has been created.
    ///
    /// This runs with the GIL held:
    /// - if the cell is still empty, it is filled with `None`, so a later
    ///   `initialize_future` produces no future;
    /// - if it holds a loop/future pair, the future's resolution is scheduled
    ///   on that loop via `LoopAndFuture::set_result`.
    ///
    /// # Panics
    /// Panics if `LoopAndFuture::set_result` returns an error. A closed event
    /// loop is tolerated there and does not produce an error.
    fn wake_by_ref(self: &Arc<Self>) {
        Python::with_gil(|py| {
            let state = self.0.get_or_init(py, || None);
            if let Some(pair) = state {
                let scheduled = pair.set_result(py);
                scheduled.expect("unexpected error in coroutine waker");
            }
        });
    }
}
/// An asyncio event loop together with a future created on that loop.
struct LoopAndFuture {
    /// Loop returned by `asyncio.get_running_loop()` when this value was built.
    /// `release_waiter` is scheduled on it via `call_soon_threadsafe`.
    event_loop: PyObject,
    /// Future obtained from `event_loop.create_future()`.
    /// When the waker fires, it is resolved with `None` unless it is already done.
    future: PyObject,
}
impl LoopAndFuture {
fn new(py: Python<'_>) -> PyResult<Self> {
static GET_RUNNING_LOOP: GILOnceCell<PyObject> = GILOnceCell::new();
let import = || -> PyResult<_> {
let module = py.import("asyncio")?;
Ok(module.getattr("get_running_loop")?.into())
};
let event_loop = GET_RUNNING_LOOP.get_or_try_init(py, import)?.call0(py)?;
let future = event_loop.call_method0(py, "create_future")?;
Ok(Self { event_loop, future })
}
fn set_result(&self, py: Python<'_>) -> PyResult<()> {
static RELEASE_WAITER: GILOnceCell<Py<PyCFunction>> = GILOnceCell::new();
let release_waiter = RELEASE_WAITER.get_or_try_init(py, || {
wrap_pyfunction!(release_waiter, py).map(Bound::unbind)
})?;
let call_soon_threadsafe = self.event_loop.call_method1(
py,
intern!(py, "call_soon_threadsafe"),
(release_waiter, self.future.bind(py)),
);
if let Err(err) = call_soon_threadsafe {
let is_closed = self.event_loop.call_method0(py, "is_closed")?;
if !is_closed.extract(py)? {
return Err(err);
}
}
Ok(())
}
}
// Resolves `future` with `None` unless it is already done.
//
// The event loop may already have finished the future (e.g. cancelled it)
// before this callback runs. asyncio's `Future.set_result` raises on a done
// future, so that case is skipped.
// (Plain `//` comments on purpose: `///` would become the Python `__doc__`.)
#[pyfunction(crate = "crate")]
fn release_waiter(future: &Bound<'_, PyAny>) -> PyResult<()> {
    let py = future.py();
    let already_done: bool = future.call_method0(intern!(py, "done"))?.extract()?;
    if already_done {
        return Ok(());
    }
    future.call_method1(intern!(py, "set_result"), (py.None(),))?;
    Ok(())
}