1use std::{
2 cell::UnsafeCell,
3 sync::{Mutex, Once},
4 thread::ThreadId,
5};
6
7use crate::{
8 exceptions::{PyBaseException, PyTypeError},
9 ffi,
10 ffi_ptr_ext::FfiPtrExt,
11 types::{PyAnyMethods, PyTraceback, PyType},
12 Bound, Py, PyAny, PyErrArguments, PyObject, PyTypeInfo, Python,
13};
14
/// Storage for the state of a Python error, which is either still lazy or already normalized.
///
/// A state built by `from_inner` starts with a fresh `Once`; `make_normalized` replaces a
/// lazy `inner` with its normalized form in place, guarded by that `Once`.
pub(crate) struct PyErrState {
    // Guards normalization. Once it has completed, `inner` is expected to hold the
    // `Normalized` variant (the accessors treat anything else as unreachable).
    normalized: Once,
    // Id of the thread that entered the normalization closure, recorded so that a
    // re-entrant normalization attempt from that same thread can be turned into a panic.
    normalizing_thread: Mutex<Option<ThreadId>>,
    // The actual state. `make_normalized` takes the value out (leaving `None`) while it
    // normalizes it and then writes `Some(Normalized(..))` back.
    inner: UnsafeCell<Option<PyErrStateInner>>,
}
23
// SAFETY (as far as this file shows): `inner` is only written inside the
// `normalized.call_once` closure in `make_normalized`, and shared references into it are
// only handed out after that `Once` has completed. The lazy closure it may contain is
// required to be `Send + Sync` by `PyErrStateLazyFn`.
// NOTE(review): this also relies on the `Py<..>` handles inside `PyErrStateNormalized`
// being safe to move and share across threads - confirm against `Py`'s own impls.
unsafe impl Send for PyErrState {}
unsafe impl Sync for PyErrState {}
#[cfg(feature = "nightly")]
unsafe impl crate::marker::Ungil for PyErrState {}
30
impl PyErrState {
    /// Creates a state from a boxed closure that produces the exception type and value
    /// when it is eventually called.
    pub(crate) fn lazy(f: Box<PyErrStateLazyFn>) -> Self {
        Self::from_inner(PyErrStateInner::Lazy(f))
    }

    /// Creates a lazy state from an exception type and not-yet-converted arguments.
    ///
    /// `args.arguments(py)` is evaluated only when the stored closure is invoked.
    pub(crate) fn lazy_arguments(ptype: Py<PyAny>, args: impl PyErrArguments + 'static) -> Self {
        Self::from_inner(PyErrStateInner::Lazy(Box::new(move |py| {
            PyErrStateLazyFnOutput {
                ptype,
                pvalue: args.arguments(py),
            }
        })))
    }

    /// Creates a state that is normalized from the start.
    pub(crate) fn normalized(normalized: PyErrStateNormalized) -> Self {
        let state = Self::from_inner(PyErrStateInner::Normalized(normalized));
        // Complete the `Once` immediately so that `as_normalized` always takes its fast
        // path for this state and never reaches `make_normalized` (and the
        // `py.allow_threads` call inside it).
        state.normalized.call_once(|| {});
        state
    }

    /// Consumes the state and passes its inner value to `PyErrStateInner::restore`.
    ///
    /// # Panics
    ///
    /// Panics if `inner` is `None`, which `make_normalized` only leaves behind while a
    /// normalization is in progress (or if one unwound part-way through).
    pub(crate) fn restore(self, py: Python<'_>) {
        self.inner
            .into_inner()
            .expect("PyErr state should never be invalid outside of normalization")
            .restore(py)
    }

    /// Wraps `inner` together with a not-yet-completed `Once` and no recorded
    /// normalizing thread.
    fn from_inner(inner: PyErrStateInner) -> Self {
        Self {
            normalized: Once::new(),
            normalizing_thread: Mutex::new(None),
            inner: UnsafeCell::new(Some(inner)),
        }
    }

    /// Returns the normalized form of this state, normalizing it first if necessary.
    ///
    /// The fast path only checks whether the `Once` has completed; everything else is
    /// delegated to the `#[cold]` `make_normalized`.
    #[inline]
    pub(crate) fn as_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized {
        if self.normalized.is_completed() {
            match unsafe {
                // SAFETY: the `Once` has completed and `inner` is only written inside the
                // `call_once` closure in `make_normalized`, so nothing mutates it while
                // this shared reference is alive.
                &*self.inner.get()
            } {
                Some(PyErrStateInner::Normalized(n)) => return n,
                _ => unreachable!(),
            }
        }

        self.make_normalized(py)
    }

    /// Slow path of `as_normalized`: runs the normalization under the `Once` and returns
    /// a reference to the result.
    ///
    /// # Panics
    ///
    /// Panics with "Re-entrant normalization of PyErrState detected" when the thread
    /// recorded in `normalizing_thread` is the current thread.
    #[cold]
    fn make_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized {
        // Re-entrancy guard: if this very thread is already recorded as the normalizing
        // thread, calling `call_once` again from here would not be able to make progress,
        // so panic with a clear message instead. The lock guard is a temporary of this
        // `if let`, so it is released before the closure below locks the same mutex.
        if let Some(thread) = self.normalizing_thread.lock().unwrap().as_ref() {
            assert!(
                !(*thread == std::thread::current().id()),
                "Re-entrant normalization of PyErrState detected"
            );
        }

        // Enter the `Once` inside `allow_threads` (which, per PyO3's API, releases the GIL
        // for the duration of the closure); the GIL is re-acquired through
        // `Python::with_gil` only around the actual normalization.
        py.allow_threads(|| {
            self.normalized.call_once(|| {
                // Record which thread is normalizing, for the guard above.
                self.normalizing_thread
                    .lock()
                    .unwrap()
                    .replace(std::thread::current().id());

                let state = unsafe {
                    // SAFETY: this runs inside the `call_once` closure, and readers only
                    // dereference `inner` after the `Once` has completed.
                    (*self.inner.get())
                        .take()
                        .expect("Cannot normalize a PyErr while already normalizing it.")
                };

                let normalized_state =
                    Python::with_gil(|py| PyErrStateInner::Normalized(state.normalize(py)));

                unsafe {
                    // SAFETY: still inside the `call_once` closure, see above.
                    *self.inner.get() = Some(normalized_state);
                }
            })
        });

        match unsafe {
            // SAFETY: `call_once` has returned, so the write above (or an earlier
            // completed normalization) is finished and `inner` is not written again.
            &*self.inner.get()
        } {
            Some(PyErrStateInner::Normalized(n)) => n,
            _ => unreachable!(),
        }
    }
}
134
/// A normalized error state: an exception instance and, when the `Py_3_12` cfg is not
/// set, its type and optional traceback stored next to it.
pub(crate) struct PyErrStateNormalized {
    // Stored only without `Py_3_12`; with it, `ptype()` derives the type from `pvalue`.
    #[cfg(not(Py_3_12))]
    ptype: Py<PyType>,
    /// The exception instance.
    pub pvalue: Py<PyBaseException>,
    // Stored only without `Py_3_12`; with it, `ptraceback()` queries `pvalue` instead.
    #[cfg(not(Py_3_12))]
    ptraceback: Option<Py<PyTraceback>>,
}
142
143impl PyErrStateNormalized {
144 pub(crate) fn new(pvalue: Bound<'_, PyBaseException>) -> Self {
145 Self {
146 #[cfg(not(Py_3_12))]
147 ptype: pvalue.get_type().into(),
148 #[cfg(not(Py_3_12))]
149 ptraceback: unsafe {
150 Py::from_owned_ptr_or_opt(
151 pvalue.py(),
152 ffi::PyException_GetTraceback(pvalue.as_ptr()),
153 )
154 },
155 pvalue: pvalue.into(),
156 }
157 }
158
159 #[cfg(not(Py_3_12))]
160 pub(crate) fn ptype<'py>(&self, py: Python<'py>) -> Bound<'py, PyType> {
161 self.ptype.bind(py).clone()
162 }
163
164 #[cfg(Py_3_12)]
165 pub(crate) fn ptype<'py>(&self, py: Python<'py>) -> Bound<'py, PyType> {
166 self.pvalue.bind(py).get_type()
167 }
168
169 #[cfg(not(Py_3_12))]
170 pub(crate) fn ptraceback<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyTraceback>> {
171 self.ptraceback
172 .as_ref()
173 .map(|traceback| traceback.bind(py).clone())
174 }
175
176 #[cfg(Py_3_12)]
177 pub(crate) fn ptraceback<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyTraceback>> {
178 unsafe {
179 ffi::PyException_GetTraceback(self.pvalue.as_ptr())
180 .assume_owned_or_opt(py)
181 .map(|b| b.downcast_into_unchecked())
182 }
183 }
184
185 pub(crate) fn take(py: Python<'_>) -> Option<PyErrStateNormalized> {
186 #[cfg(Py_3_12)]
187 {
188 unsafe { ffi::PyErr_GetRaisedException().assume_owned_or_opt(py) }.map(|pvalue| {
191 PyErrStateNormalized {
192 pvalue: unsafe { pvalue.downcast_into_unchecked() }.unbind(),
194 }
195 })
196 }
197
198 #[cfg(not(Py_3_12))]
199 {
200 let (ptype, pvalue, ptraceback) = unsafe {
201 let mut ptype: *mut ffi::PyObject = std::ptr::null_mut();
202 let mut pvalue: *mut ffi::PyObject = std::ptr::null_mut();
203 let mut ptraceback: *mut ffi::PyObject = std::ptr::null_mut();
204
205 ffi::PyErr_Fetch(&mut ptype, &mut pvalue, &mut ptraceback);
206
207 if !ptype.is_null() {
209 ffi::PyErr_NormalizeException(&mut ptype, &mut pvalue, &mut ptraceback);
210 }
211
212 (
215 ptype
216 .assume_owned_or_opt(py)
217 .map(|b| b.downcast_into_unchecked()),
218 pvalue
219 .assume_owned_or_opt(py)
220 .map(|b| b.downcast_into_unchecked()),
221 ptraceback
222 .assume_owned_or_opt(py)
223 .map(|b| b.downcast_into_unchecked()),
224 )
225 };
226
227 ptype.map(|ptype| PyErrStateNormalized {
228 ptype: ptype.unbind(),
229 pvalue: pvalue.expect("normalized exception value missing").unbind(),
230 ptraceback: ptraceback.map(Bound::unbind),
231 })
232 }
233 }
234
235 #[cfg(not(Py_3_12))]
236 unsafe fn from_normalized_ffi_tuple(
237 py: Python<'_>,
238 ptype: *mut ffi::PyObject,
239 pvalue: *mut ffi::PyObject,
240 ptraceback: *mut ffi::PyObject,
241 ) -> Self {
242 PyErrStateNormalized {
243 ptype: unsafe { Py::from_owned_ptr_or_opt(py, ptype).expect("Exception type missing") },
244 pvalue: unsafe {
245 Py::from_owned_ptr_or_opt(py, pvalue).expect("Exception value missing")
246 },
247 ptraceback: unsafe { Py::from_owned_ptr_or_opt(py, ptraceback) },
248 }
249 }
250
251 pub fn clone_ref(&self, py: Python<'_>) -> Self {
252 Self {
253 #[cfg(not(Py_3_12))]
254 ptype: self.ptype.clone_ref(py),
255 pvalue: self.pvalue.clone_ref(py),
256 #[cfg(not(Py_3_12))]
257 ptraceback: self
258 .ptraceback
259 .as_ref()
260 .map(|ptraceback| ptraceback.clone_ref(py)),
261 }
262 }
263}
264
/// What a lazy error closure produces: the exception type object and the value that
/// `raise_lazy` passes on to the interpreter together with it.
pub(crate) struct PyErrStateLazyFnOutput {
    // Checked by `raise_lazy` with `PyExceptionClass_Check` before it is used.
    pub(crate) ptype: PyObject,
    // Handed to `PyErr_SetObject` as the value for `ptype`.
    pub(crate) pvalue: PyObject,
}
269
/// Type of the closure stored in a lazy error state: called at most once with a `Python`
/// token to produce the exception type and value, and required to be `Send + Sync`.
pub(crate) type PyErrStateLazyFn =
    dyn for<'py> FnOnce(Python<'py>) -> PyErrStateLazyFnOutput + Send + Sync;
272
/// The two forms an error state can take.
enum PyErrStateInner {
    // Not yet evaluated: the closure still has to be called to obtain type and value.
    Lazy(Box<PyErrStateLazyFn>),
    // Already normalized.
    Normalized(PyErrStateNormalized),
}
277
278impl PyErrStateInner {
279 fn normalize(self, py: Python<'_>) -> PyErrStateNormalized {
280 match self {
281 #[cfg(not(Py_3_12))]
282 PyErrStateInner::Lazy(lazy) => {
283 let (ptype, pvalue, ptraceback) = lazy_into_normalized_ffi_tuple(py, lazy);
284 unsafe {
285 PyErrStateNormalized::from_normalized_ffi_tuple(py, ptype, pvalue, ptraceback)
286 }
287 }
288 #[cfg(Py_3_12)]
289 PyErrStateInner::Lazy(lazy) => {
290 raise_lazy(py, lazy);
293 PyErrStateNormalized::take(py)
294 .expect("exception missing after writing to the interpreter")
295 }
296 PyErrStateInner::Normalized(normalized) => normalized,
297 }
298 }
299
300 #[cfg(not(Py_3_12))]
301 fn restore(self, py: Python<'_>) {
302 let (ptype, pvalue, ptraceback) = match self {
303 PyErrStateInner::Lazy(lazy) => lazy_into_normalized_ffi_tuple(py, lazy),
304 PyErrStateInner::Normalized(PyErrStateNormalized {
305 ptype,
306 pvalue,
307 ptraceback,
308 }) => (
309 ptype.into_ptr(),
310 pvalue.into_ptr(),
311 ptraceback.map_or(std::ptr::null_mut(), Py::into_ptr),
312 ),
313 };
314 unsafe { ffi::PyErr_Restore(ptype, pvalue, ptraceback) }
315 }
316
317 #[cfg(Py_3_12)]
318 fn restore(self, py: Python<'_>) {
319 match self {
320 PyErrStateInner::Lazy(lazy) => raise_lazy(py, lazy),
321 PyErrStateInner::Normalized(PyErrStateNormalized { pvalue }) => unsafe {
322 ffi::PyErr_SetRaisedException(pvalue.into_ptr())
323 },
324 }
325 }
326}
327
328#[cfg(not(Py_3_12))]
329fn lazy_into_normalized_ffi_tuple(
330 py: Python<'_>,
331 lazy: Box<PyErrStateLazyFn>,
332) -> (*mut ffi::PyObject, *mut ffi::PyObject, *mut ffi::PyObject) {
333 raise_lazy(py, lazy);
336 let mut ptype = std::ptr::null_mut();
337 let mut pvalue = std::ptr::null_mut();
338 let mut ptraceback = std::ptr::null_mut();
339 unsafe {
340 ffi::PyErr_Fetch(&mut ptype, &mut pvalue, &mut ptraceback);
341 ffi::PyErr_NormalizeException(&mut ptype, &mut pvalue, &mut ptraceback);
342 }
343 (ptype, pvalue, ptraceback)
344}
345
346fn raise_lazy(py: Python<'_>, lazy: Box<PyErrStateLazyFn>) {
354 let PyErrStateLazyFnOutput { ptype, pvalue } = lazy(py);
355 unsafe {
356 if ffi::PyExceptionClass_Check(ptype.as_ptr()) == 0 {
357 ffi::PyErr_SetString(
358 PyTypeError::type_object_raw(py).cast(),
359 ffi::c_str!("exceptions must derive from BaseException").as_ptr(),
360 )
361 } else {
362 ffi::PyErr_SetObject(ptype.as_ptr(), pvalue.as_ptr())
363 }
364 }
365}
366
#[cfg(test)]
mod tests {
    use crate::{
        exceptions::PyValueError, sync::GILOnceCell, PyErr, PyErrArguments, PyObject, Python,
    };

    /// An error whose arguments ask for the value of that same error must hit the
    /// re-entrancy panic.
    #[test]
    #[should_panic(expected = "Re-entrant normalization of PyErrState detected")]
    fn test_reentrant_normalization() {
        static ERR: GILOnceCell<PyErr> = GILOnceCell::new();

        struct RecursiveArgs;

        impl PyErrArguments for RecursiveArgs {
            fn arguments(self, py: Python<'_>) -> PyObject {
                // Requesting the value of `ERR` from inside its own argument conversion
                // is what makes the normalization re-entrant.
                let err = ERR.get(py).expect("is set just below");
                err.value(py).clone().into()
            }
        }

        Python::with_gil(|py| {
            let err = PyValueError::new_err(RecursiveArgs);
            ERR.set(py, err).unwrap();
            ERR.get(py).expect("is set just above").value(py);
        })
    }

    /// Several threads asking for the value of one lazy error, whose arguments call
    /// `allow_threads` themselves, must all finish.
    // NOTE(review): excluded on wasm32, presumably because threads cannot be spawned
    // there - confirm.
    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_no_deadlock_thread_switch() {
        static ERR: GILOnceCell<PyErr> = GILOnceCell::new();

        struct GILSwitchArgs;

        impl PyErrArguments for GILSwitchArgs {
            fn arguments(self, py: Python<'_>) -> PyObject {
                // Give other threads a chance to run while the arguments are converted.
                let pause = std::time::Duration::from_millis(10);
                py.allow_threads(|| std::thread::sleep(pause));
                py.None()
            }
        }

        Python::with_gil(|py| {
            ERR.set(py, PyValueError::new_err(GILSwitchArgs)).unwrap();
        });

        // Start every worker before joining any of them, so they run side by side.
        let mut workers = Vec::with_capacity(10);
        for _ in 0..10 {
            workers.push(std::thread::spawn(|| {
                Python::with_gil(|py| {
                    ERR.get(py).expect("is set just above").value(py);
                });
            }));
        }

        for worker in workers {
            worker.join().unwrap();
        }

        Python::with_gil(|py| {
            let err = ERR.get(py).expect("is set above");
            assert!(err.is_instance_of::<PyValueError>(py))
        });
    }
}