pyo3/coroutine/
waker.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
use crate::sync::GILOnceCell;
use crate::types::any::PyAnyMethods;
use crate::types::PyCFunction;
use crate::{intern, wrap_pyfunction, Bound, Py, PyAny, PyObject, PyResult, Python};
use pyo3_macros::pyfunction;
use std::sync::Arc;
use std::task::Wake;

/// Lazy `asyncio.Future` wrapper, implementing [`Wake`] by calling `Future.set_result`.
///
/// The asyncio future is left uninitialized until [`initialize_future`][1] is called.
/// If [`wake`][2] is called before the future is initialized (i.e. during Rust future polling),
/// [`initialize_future`][1] will return `None` (roughly equivalent to `asyncio.sleep(0)`).
///
/// The inner cell is in one of three states:
/// - empty: neither `initialize_future` nor `wake` has run yet (also the state after `reset`);
/// - `Some(None)`: the waker was woken before any asyncio future was created;
/// - `Some(Some(_))`: the running event loop was captured and a future was created on it.
///
/// [1]: AsyncioWaker::initialize_future
/// [2]: AsyncioWaker::wake
pub struct AsyncioWaker(GILOnceCell<Option<LoopAndFuture>>);

impl AsyncioWaker {
    pub(super) fn new() -> Self {
        Self(GILOnceCell::new())
    }

    pub(super) fn reset(&mut self) {
        self.0.take();
    }

    pub(super) fn initialize_future<'py>(
        &self,
        py: Python<'py>,
    ) -> PyResult<Option<&Bound<'py, PyAny>>> {
        let init = || LoopAndFuture::new(py).map(Some);
        let loop_and_future = self.0.get_or_try_init(py, init)?.as_ref();
        Ok(loop_and_future.map(|LoopAndFuture { future, .. }| future.bind(py)))
    }
}

impl Wake for AsyncioWaker {
    fn wake(self: Arc<Self>) {
        self.wake_by_ref()
    }

    fn wake_by_ref(self: &Arc<Self>) {
        Python::with_gil(|gil| {
            if let Some(loop_and_future) = self.0.get_or_init(gil, || None) {
                loop_and_future
                    .set_result(gil)
                    .expect("unexpected error in coroutine waker");
            }
        });
    }
}

/// The asyncio event loop that was running when the waker's future was created,
/// paired with the `asyncio.Future` created on that loop.
struct LoopAndFuture {
    /// Result of `asyncio.get_running_loop()`; used to schedule the future's
    /// resolution via `call_soon_threadsafe`.
    event_loop: PyObject,
    /// Future obtained from `event_loop.create_future()`; resolved by `release_waiter`.
    future: PyObject,
}

impl LoopAndFuture {
    /// Captures the currently running asyncio event loop and creates a new future on it.
    ///
    /// `asyncio.get_running_loop` is looked up once and cached in a static `GILOnceCell`.
    ///
    /// # Errors
    /// Returns the Python error raised by importing `asyncio`, by `get_running_loop`
    /// (e.g. when there is no running event loop), or by `create_future`.
    fn new(py: Python<'_>) -> PyResult<Self> {
        static GET_RUNNING_LOOP: GILOnceCell<PyObject> = GILOnceCell::new();
        let import = || -> PyResult<_> {
            let module = py.import("asyncio")?;
            Ok(module.getattr("get_running_loop")?.into())
        };
        let event_loop = GET_RUNNING_LOOP.get_or_try_init(py, import)?.call0(py)?;
        // `new` can run again every time the waker is reset and re-initialized, so intern
        // the method name (as the rest of this module does) instead of building a fresh
        // Python string on each call.
        let future = event_loop.call_method0(py, intern!(py, "create_future"))?;
        Ok(Self { event_loop, future })
    }

    /// Schedules `release_waiter(future)` on the event loop. That call resolves the future
    /// with `None` unless the future is already done.
    ///
    /// If the event loop is already closed, the wake-up is dropped and `Ok(())` is returned.
    ///
    /// # Errors
    /// Returns the error from `call_soon_threadsafe` when the loop is still open, or any
    /// error raised while checking `is_closed` / extracting its result.
    fn set_result(&self, py: Python<'_>) -> PyResult<()> {
        static RELEASE_WAITER: GILOnceCell<Py<PyCFunction>> = GILOnceCell::new();
        let release_waiter = RELEASE_WAITER.get_or_try_init(py, || {
            wrap_pyfunction!(release_waiter, py).map(Bound::unbind)
        })?;
        // `Future.set_result` must be called in the event loop thread,
        // so it is scheduled through `call_soon_threadsafe`.
        let call_soon_threadsafe = self.event_loop.call_method1(
            py,
            intern!(py, "call_soon_threadsafe"),
            (release_waiter, self.future.bind(py)),
        );
        if let Err(err) = call_soon_threadsafe {
            // `call_soon_threadsafe` raises if the event loop is closed; instead of
            // catching an unspecific `RuntimeError`, check directly whether it's closed.
            let is_closed = self
                .event_loop
                .call_method0(py, intern!(py, "is_closed"))?;
            if !is_closed.extract(py)? {
                return Err(err);
            }
        }
        Ok(())
    }
}

/// Resolves `future` with `None`, unless it is already done.
///
/// The event loop may cancel the future before the waker fires, and calling
/// `set_result` on a done future raises; hence the `done()` check first.
/// See <https://github.com/python/cpython/blob/main/Lib/asyncio/tasks.py#L452C5-L452C5>
#[pyfunction(crate = "crate")]
fn release_waiter(future: &Bound<'_, PyAny>) -> PyResult<()> {
    let py = future.py();
    let already_done: bool = future.call_method0(intern!(py, "done"))?.extract()?;
    if already_done {
        return Ok(());
    }
    future.call_method1(intern!(py, "set_result"), (py.None(),))?;
    Ok(())
}
⚠️ Internal Docs ⚠️ Not Public API 👉 Official Docs Here