// pyo3/err/err_state.rs

1use std::{
2    cell::UnsafeCell,
3    sync::{Mutex, Once},
4    thread::ThreadId,
5};
6
7use crate::{
8    exceptions::{PyBaseException, PyTypeError},
9    ffi,
10    ffi_ptr_ext::FfiPtrExt,
11    types::{PyAnyMethods, PyTraceback, PyType},
12    Bound, Py, PyAny, PyErrArguments, PyObject, PyTypeInfo, Python,
13};
14
// Storage for an error's state. The state is created either lazy or already normalized
// (see the constructors below) and is converted to the normalized form on demand by
// `as_normalized` / `make_normalized`.
pub(crate) struct PyErrState {
    // Safety: can only hand out references when in the "normalized" state. Will never change
    // after normalization.
    // Completed once `inner` holds the `Normalized` variant; `as_normalized` checks it before
    // creating a shared reference into `inner`.
    normalized: Once,
    // Guard against re-entrancy when normalizing the exception state.
    // Holds the id of the thread that ran the normalization closure, or `None` if that
    // closure has not run. `make_normalized` panics when the calling thread's id is stored.
    normalizing_thread: Mutex<Option<ThreadId>>,
    // The current state. It is `None` only while `make_normalized` has taken the previous
    // state out in order to normalize it.
    inner: UnsafeCell<Option<PyErrStateInner>>,
}
23
// Safety: The inner value is protected by locking to ensure that only the normalized state is
// handed out as a reference.
// Concretely, within this file `inner` is written through a shared reference only inside the
// `normalized.call_once` closure in `make_normalized`, and shared references into `inner` are
// created only after that `Once` has completed.
unsafe impl Send for PyErrState {}
unsafe impl Sync for PyErrState {}
// Only compiled when the "nightly" feature is enabled.
#[cfg(feature = "nightly")]
unsafe impl crate::marker::Ungil for PyErrState {}
30
31impl PyErrState {
32    pub(crate) fn lazy(f: Box<PyErrStateLazyFn>) -> Self {
33        Self::from_inner(PyErrStateInner::Lazy(f))
34    }
35
36    pub(crate) fn lazy_arguments(ptype: Py<PyAny>, args: impl PyErrArguments + 'static) -> Self {
37        Self::from_inner(PyErrStateInner::Lazy(Box::new(move |py| {
38            PyErrStateLazyFnOutput {
39                ptype,
40                pvalue: args.arguments(py),
41            }
42        })))
43    }
44
45    pub(crate) fn normalized(normalized: PyErrStateNormalized) -> Self {
46        let state = Self::from_inner(PyErrStateInner::Normalized(normalized));
47        // This state is already normalized, by completing the Once immediately we avoid
48        // reaching the `py.allow_threads` in `make_normalized` which is less efficient
49        // and introduces a GIL switch which could deadlock.
50        // See https://github.com/PyO3/pyo3/issues/4764
51        state.normalized.call_once(|| {});
52        state
53    }
54
55    pub(crate) fn restore(self, py: Python<'_>) {
56        self.inner
57            .into_inner()
58            .expect("PyErr state should never be invalid outside of normalization")
59            .restore(py)
60    }
61
62    fn from_inner(inner: PyErrStateInner) -> Self {
63        Self {
64            normalized: Once::new(),
65            normalizing_thread: Mutex::new(None),
66            inner: UnsafeCell::new(Some(inner)),
67        }
68    }
69
70    #[inline]
71    pub(crate) fn as_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized {
72        if self.normalized.is_completed() {
73            match unsafe {
74                // Safety: self.inner will never be written again once normalized.
75                &*self.inner.get()
76            } {
77                Some(PyErrStateInner::Normalized(n)) => return n,
78                _ => unreachable!(),
79            }
80        }
81
82        self.make_normalized(py)
83    }
84
85    #[cold]
86    fn make_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized {
87        // This process is safe because:
88        // - Access is guaranteed not to be concurrent thanks to `Python` GIL token
89        // - Write happens only once, and then never will change again.
90
91        // Guard against re-entrant normalization, because `Once` does not provide
92        // re-entrancy guarantees.
93        if let Some(thread) = self.normalizing_thread.lock().unwrap().as_ref() {
94            assert!(
95                !(*thread == std::thread::current().id()),
96                "Re-entrant normalization of PyErrState detected"
97            );
98        }
99
100        // avoid deadlock of `.call_once` with the GIL
101        py.allow_threads(|| {
102            self.normalized.call_once(|| {
103                self.normalizing_thread
104                    .lock()
105                    .unwrap()
106                    .replace(std::thread::current().id());
107
108                // Safety: no other thread can access the inner value while we are normalizing it.
109                let state = unsafe {
110                    (*self.inner.get())
111                        .take()
112                        .expect("Cannot normalize a PyErr while already normalizing it.")
113                };
114
115                let normalized_state =
116                    Python::with_gil(|py| PyErrStateInner::Normalized(state.normalize(py)));
117
118                // Safety: no other thread can access the inner value while we are normalizing it.
119                unsafe {
120                    *self.inner.get() = Some(normalized_state);
121                }
122            })
123        });
124
125        match unsafe {
126            // Safety: self.inner will never be written again once normalized.
127            &*self.inner.get()
128        } {
129            Some(PyErrStateInner::Normalized(n)) => n,
130            _ => unreachable!(),
131        }
132    }
133}
134
// An exception in normalized form.
//
// When the `Py_3_12` cfg is set only `pvalue` is stored, and the `ptype()` / `ptraceback()`
// accessors derive the type and traceback from it. Otherwise all three parts are stored.
pub(crate) struct PyErrStateNormalized {
    // Exception type (not stored under `Py_3_12`).
    #[cfg(not(Py_3_12))]
    ptype: Py<PyType>,
    // The exception instance.
    pub pvalue: Py<PyBaseException>,
    // Traceback, if there is one (not stored under `Py_3_12`).
    #[cfg(not(Py_3_12))]
    ptraceback: Option<Py<PyTraceback>>,
}
142
143impl PyErrStateNormalized {
144    pub(crate) fn new(pvalue: Bound<'_, PyBaseException>) -> Self {
145        Self {
146            #[cfg(not(Py_3_12))]
147            ptype: pvalue.get_type().into(),
148            #[cfg(not(Py_3_12))]
149            ptraceback: unsafe {
150                Py::from_owned_ptr_or_opt(
151                    pvalue.py(),
152                    ffi::PyException_GetTraceback(pvalue.as_ptr()),
153                )
154            },
155            pvalue: pvalue.into(),
156        }
157    }
158
159    #[cfg(not(Py_3_12))]
160    pub(crate) fn ptype<'py>(&self, py: Python<'py>) -> Bound<'py, PyType> {
161        self.ptype.bind(py).clone()
162    }
163
164    #[cfg(Py_3_12)]
165    pub(crate) fn ptype<'py>(&self, py: Python<'py>) -> Bound<'py, PyType> {
166        self.pvalue.bind(py).get_type()
167    }
168
169    #[cfg(not(Py_3_12))]
170    pub(crate) fn ptraceback<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyTraceback>> {
171        self.ptraceback
172            .as_ref()
173            .map(|traceback| traceback.bind(py).clone())
174    }
175
176    #[cfg(Py_3_12)]
177    pub(crate) fn ptraceback<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyTraceback>> {
178        unsafe {
179            ffi::PyException_GetTraceback(self.pvalue.as_ptr())
180                .assume_owned_or_opt(py)
181                .map(|b| b.downcast_into_unchecked())
182        }
183    }
184
185    pub(crate) fn take(py: Python<'_>) -> Option<PyErrStateNormalized> {
186        #[cfg(Py_3_12)]
187        {
188            // Safety: PyErr_GetRaisedException can be called when attached to Python and
189            // returns either NULL or an owned reference.
190            unsafe { ffi::PyErr_GetRaisedException().assume_owned_or_opt(py) }.map(|pvalue| {
191                PyErrStateNormalized {
192                    // Safety: PyErr_GetRaisedException returns a valid exception type.
193                    pvalue: unsafe { pvalue.downcast_into_unchecked() }.unbind(),
194                }
195            })
196        }
197
198        #[cfg(not(Py_3_12))]
199        {
200            let (ptype, pvalue, ptraceback) = unsafe {
201                let mut ptype: *mut ffi::PyObject = std::ptr::null_mut();
202                let mut pvalue: *mut ffi::PyObject = std::ptr::null_mut();
203                let mut ptraceback: *mut ffi::PyObject = std::ptr::null_mut();
204
205                ffi::PyErr_Fetch(&mut ptype, &mut pvalue, &mut ptraceback);
206
207                // Ensure that the exception coming from the interpreter is normalized.
208                if !ptype.is_null() {
209                    ffi::PyErr_NormalizeException(&mut ptype, &mut pvalue, &mut ptraceback);
210                }
211
212                // Safety: PyErr_NormalizeException will have produced up to three owned
213                // references of the correct types.
214                (
215                    ptype
216                        .assume_owned_or_opt(py)
217                        .map(|b| b.downcast_into_unchecked()),
218                    pvalue
219                        .assume_owned_or_opt(py)
220                        .map(|b| b.downcast_into_unchecked()),
221                    ptraceback
222                        .assume_owned_or_opt(py)
223                        .map(|b| b.downcast_into_unchecked()),
224                )
225            };
226
227            ptype.map(|ptype| PyErrStateNormalized {
228                ptype: ptype.unbind(),
229                pvalue: pvalue.expect("normalized exception value missing").unbind(),
230                ptraceback: ptraceback.map(Bound::unbind),
231            })
232        }
233    }
234
235    #[cfg(not(Py_3_12))]
236    unsafe fn from_normalized_ffi_tuple(
237        py: Python<'_>,
238        ptype: *mut ffi::PyObject,
239        pvalue: *mut ffi::PyObject,
240        ptraceback: *mut ffi::PyObject,
241    ) -> Self {
242        PyErrStateNormalized {
243            ptype: unsafe { Py::from_owned_ptr_or_opt(py, ptype).expect("Exception type missing") },
244            pvalue: unsafe {
245                Py::from_owned_ptr_or_opt(py, pvalue).expect("Exception value missing")
246            },
247            ptraceback: unsafe { Py::from_owned_ptr_or_opt(py, ptraceback) },
248        }
249    }
250
251    pub fn clone_ref(&self, py: Python<'_>) -> Self {
252        Self {
253            #[cfg(not(Py_3_12))]
254            ptype: self.ptype.clone_ref(py),
255            pvalue: self.pvalue.clone_ref(py),
256            #[cfg(not(Py_3_12))]
257            ptraceback: self
258                .ptraceback
259                .as_ref()
260                .map(|ptraceback| ptraceback.clone_ref(py)),
261        }
262    }
263}
264
// Result of calling a `PyErrStateLazyFn`: the two objects that `raise_lazy` passes to the
// interpreter.
pub(crate) struct PyErrStateLazyFnOutput {
    // Object to raise as the exception type. `raise_lazy` raises a `TypeError` instead if
    // this does not pass `PyExceptionClass_Check`.
    pub(crate) ptype: PyObject,
    // Object passed together with `ptype` to `PyErr_SetObject`.
    pub(crate) pvalue: PyObject,
}
269
// Type of the deferred constructor stored in `PyErrStateInner::Lazy`: a one-shot closure
// that receives the `Python` token and produces the exception type and value. It must be
// `Send + Sync`.
pub(crate) type PyErrStateLazyFn =
    dyn for<'py> FnOnce(Python<'py>) -> PyErrStateLazyFnOutput + Send + Sync;
272
// The two forms an error state can take.
enum PyErrStateInner {
    // The exception has not been created yet; the closure produces its type and value.
    Lazy(Box<PyErrStateLazyFn>),
    // The exception is already in normalized form.
    Normalized(PyErrStateNormalized),
}
277
278impl PyErrStateInner {
279    fn normalize(self, py: Python<'_>) -> PyErrStateNormalized {
280        match self {
281            #[cfg(not(Py_3_12))]
282            PyErrStateInner::Lazy(lazy) => {
283                let (ptype, pvalue, ptraceback) = lazy_into_normalized_ffi_tuple(py, lazy);
284                unsafe {
285                    PyErrStateNormalized::from_normalized_ffi_tuple(py, ptype, pvalue, ptraceback)
286                }
287            }
288            #[cfg(Py_3_12)]
289            PyErrStateInner::Lazy(lazy) => {
290                // To keep the implementation simple, just write the exception into the interpreter,
291                // which will cause it to be normalized
292                raise_lazy(py, lazy);
293                PyErrStateNormalized::take(py)
294                    .expect("exception missing after writing to the interpreter")
295            }
296            PyErrStateInner::Normalized(normalized) => normalized,
297        }
298    }
299
300    #[cfg(not(Py_3_12))]
301    fn restore(self, py: Python<'_>) {
302        let (ptype, pvalue, ptraceback) = match self {
303            PyErrStateInner::Lazy(lazy) => lazy_into_normalized_ffi_tuple(py, lazy),
304            PyErrStateInner::Normalized(PyErrStateNormalized {
305                ptype,
306                pvalue,
307                ptraceback,
308            }) => (
309                ptype.into_ptr(),
310                pvalue.into_ptr(),
311                ptraceback.map_or(std::ptr::null_mut(), Py::into_ptr),
312            ),
313        };
314        unsafe { ffi::PyErr_Restore(ptype, pvalue, ptraceback) }
315    }
316
317    #[cfg(Py_3_12)]
318    fn restore(self, py: Python<'_>) {
319        match self {
320            PyErrStateInner::Lazy(lazy) => raise_lazy(py, lazy),
321            PyErrStateInner::Normalized(PyErrStateNormalized { pvalue }) => unsafe {
322                ffi::PyErr_SetRaisedException(pvalue.into_ptr())
323            },
324        }
325    }
326}
327
328#[cfg(not(Py_3_12))]
329fn lazy_into_normalized_ffi_tuple(
330    py: Python<'_>,
331    lazy: Box<PyErrStateLazyFn>,
332) -> (*mut ffi::PyObject, *mut ffi::PyObject, *mut ffi::PyObject) {
333    // To be consistent with 3.12 logic, go via raise_lazy, but also then normalize
334    // the resulting exception
335    raise_lazy(py, lazy);
336    let mut ptype = std::ptr::null_mut();
337    let mut pvalue = std::ptr::null_mut();
338    let mut ptraceback = std::ptr::null_mut();
339    unsafe {
340        ffi::PyErr_Fetch(&mut ptype, &mut pvalue, &mut ptraceback);
341        ffi::PyErr_NormalizeException(&mut ptype, &mut pvalue, &mut ptraceback);
342    }
343    (ptype, pvalue, ptraceback)
344}
345
346/// Raises a "lazy" exception state into the Python interpreter.
347///
348/// In principle this could be split in two; first a function to create an exception
349/// in a normalized state, and then a call to `PyErr_SetRaisedException` to raise it.
350///
351/// This would require either moving some logic from C to Rust, or requesting a new
352/// API in CPython.
353fn raise_lazy(py: Python<'_>, lazy: Box<PyErrStateLazyFn>) {
354    let PyErrStateLazyFnOutput { ptype, pvalue } = lazy(py);
355    unsafe {
356        if ffi::PyExceptionClass_Check(ptype.as_ptr()) == 0 {
357            ffi::PyErr_SetString(
358                PyTypeError::type_object_raw(py).cast(),
359                ffi::c_str!("exceptions must derive from BaseException").as_ptr(),
360            )
361        } else {
362            ffi::PyErr_SetObject(ptype.as_ptr(), pvalue.as_ptr())
363        }
364    }
365}
366
#[cfg(test)]
mod tests {

    use crate::{
        exceptions::PyValueError, sync::GILOnceCell, PyErr, PyErrArguments, PyObject, Python,
    };

    /// Normalizing an error whose arguments normalize that same error must panic
    /// with the re-entrancy message.
    #[test]
    #[should_panic(expected = "Re-entrant normalization of PyErrState detected")]
    fn test_reentrant_normalization() {
        static ERR: GILOnceCell<PyErr> = GILOnceCell::new();

        struct SelfNormalizingArgs;

        impl PyErrArguments for SelfNormalizingArgs {
            fn arguments(self, py: Python<'_>) -> PyObject {
                let err = ERR.get(py).expect("is set just below");
                // .value(py) triggers normalization
                let value = err.value(py);
                value.clone().into()
            }
        }

        Python::with_gil(|py| {
            let err = PyValueError::new_err(SelfNormalizingArgs);
            ERR.set(py, err).unwrap();
            ERR.get(py).expect("is set just above").value(py);
        })
    }

    /// Several threads normalizing the same error, while the normalization itself
    /// releases the GIL, must all finish.
    #[test]
    #[cfg(not(target_arch = "wasm32"))] // We are building wasm Python with pthreads disabled
    fn test_no_deadlock_thread_switch() {
        static ERR: GILOnceCell<PyErr> = GILOnceCell::new();

        struct GilReleasingArgs;

        impl PyErrArguments for GilReleasingArgs {
            fn arguments(self, py: Python<'_>) -> PyObject {
                // releasing the GIL potentially allows for other threads to deadlock
                // with the normalization going on here
                let pause = std::time::Duration::from_millis(10);
                py.allow_threads(|| {
                    std::thread::sleep(pause);
                });
                py.None()
            }
        }

        Python::with_gil(|py| {
            let err = PyValueError::new_err(GilReleasingArgs);
            ERR.set(py, err).unwrap()
        });

        // Let many threads attempt to read the normalized value at the same time
        let mut workers = Vec::with_capacity(10);
        for _ in 0..10 {
            let worker = std::thread::spawn(|| {
                Python::with_gil(|py| {
                    ERR.get(py).expect("is set just above").value(py);
                });
            });
            workers.push(worker);
        }

        for worker in workers {
            worker.join().unwrap();
        }

        // We should never have deadlocked, and should be able to run
        // this assertion
        Python::with_gil(|py| {
            let err = ERR.get(py).expect("is set above");
            assert!(err.is_instance_of::<PyValueError>(py))
        });
    }
}
// ⚠️ Internal docs, not public API. 👉 See the official documentation instead.