Skip to main content

pyo3/pycell/
impl_.rs

1#![allow(missing_docs)]
2//! Crate-private implementation of PyClassObject
3
4use std::cell::UnsafeCell;
5use std::marker::PhantomData;
6use std::mem::{offset_of, ManuallyDrop, MaybeUninit};
7use std::sync::atomic::{AtomicUsize, Ordering};
8
9use crate::impl_::pyclass::{
10    PyClassBaseType, PyClassDict, PyClassImpl, PyClassThreadChecker, PyClassWeakRef, PyObjectOffset,
11};
12use crate::internal::get_slot::{TP_DEALLOC, TP_FREE};
13use crate::type_object::{PyLayout, PySizedLayout, PyTypeInfo};
14use crate::types::PyType;
15use crate::{ffi, PyClass, Python};
16
17use crate::types::PyTypeMethods;
18
19use super::{PyBorrowError, PyBorrowMutError};
20
21pub trait PyClassMutability {
22    // The storage for this inheritance layer. Only the first mutable class in
23    // an inheritance hierarchy needs to store the borrow flag.
24    type Storage: PyClassBorrowChecker;
25    // The borrow flag needed to implement this class' mutability. Empty until
26    // the first mutable class, at which point it is BorrowChecker and will be
27    // for all subclasses.
28    type Checker: PyClassBorrowChecker;
29    type ImmutableChild: PyClassMutability;
30    type MutableChild: PyClassMutability;
31}
32
33pub struct ImmutableClass(());
34pub struct MutableClass(());
35pub struct ExtendsMutableAncestor<M: PyClassMutability>(PhantomData<M>);
36
37impl PyClassMutability for ImmutableClass {
38    type Storage = EmptySlot;
39    type Checker = EmptySlot;
40    type ImmutableChild = ImmutableClass;
41    type MutableChild = MutableClass;
42}
43
44impl PyClassMutability for MutableClass {
45    type Storage = BorrowChecker;
46    type Checker = BorrowChecker;
47    type ImmutableChild = ExtendsMutableAncestor<ImmutableClass>;
48    type MutableChild = ExtendsMutableAncestor<MutableClass>;
49}
50
51impl<M: PyClassMutability> PyClassMutability for ExtendsMutableAncestor<M> {
52    type Storage = EmptySlot;
53    type Checker = BorrowChecker;
54    type ImmutableChild = ExtendsMutableAncestor<ImmutableClass>;
55    type MutableChild = ExtendsMutableAncestor<MutableClass>;
56}
57
58#[derive(Debug)]
59struct BorrowFlag(AtomicUsize);
60
61impl BorrowFlag {
62    pub(crate) const UNUSED: usize = 0;
63    const HAS_MUTABLE_BORROW: usize = usize::MAX;
64    fn increment(&self) -> Result<(), PyBorrowError> {
65        // relaxed is OK because we will read the value again in the compare_exchange
66        let mut value = self.0.load(Ordering::Relaxed);
67        loop {
68            if value == BorrowFlag::HAS_MUTABLE_BORROW {
69                return Err(PyBorrowError { _private: () });
70            }
71            match self.0.compare_exchange(
72                // only increment if the value hasn't changed since the
73                // last atomic load
74                value,
75                value + 1,
76                // reading the value is happens-after a previous write
77                // writing the new value is happens-after the previous read
78                Ordering::AcqRel,
79                // relaxed is OK here because we're going to try to read again
80                Ordering::Relaxed,
81            ) {
82                Ok(..) => {
83                    break Ok(());
84                }
85                Err(changed_value) => {
86                    // value changed under us, need to try again
87                    value = changed_value;
88                }
89            }
90        }
91    }
92    fn decrement(&self) {
93        // relaxed load is OK but decrements must happen-before the next read
94        self.0.fetch_sub(1, Ordering::Release);
95    }
96}
97
98pub struct EmptySlot(());
99pub struct BorrowChecker(BorrowFlag);
100
101pub trait PyClassBorrowChecker {
102    /// Initial value for self
103    fn new() -> Self
104    where
105        Self: Sized;
106
107    /// Increments immutable borrow count, if possible
108    fn try_borrow(&self) -> Result<(), PyBorrowError>;
109
110    /// Decrements immutable borrow count
111    fn release_borrow(&self);
112    /// Increments mutable borrow count, if possible
113    fn try_borrow_mut(&self) -> Result<(), PyBorrowMutError>;
114    /// Decremements mutable borrow count
115    fn release_borrow_mut(&self);
116}
117
118impl PyClassBorrowChecker for EmptySlot {
119    #[inline]
120    fn new() -> Self {
121        EmptySlot(())
122    }
123
124    #[inline]
125    fn try_borrow(&self) -> Result<(), PyBorrowError> {
126        Ok(())
127    }
128
129    #[inline]
130    fn release_borrow(&self) {}
131
132    #[inline]
133    fn try_borrow_mut(&self) -> Result<(), PyBorrowMutError> {
134        unreachable!()
135    }
136
137    #[inline]
138    fn release_borrow_mut(&self) {
139        unreachable!()
140    }
141}
142
143impl PyClassBorrowChecker for BorrowChecker {
144    #[inline]
145    fn new() -> Self {
146        Self(BorrowFlag(AtomicUsize::new(BorrowFlag::UNUSED)))
147    }
148
149    fn try_borrow(&self) -> Result<(), PyBorrowError> {
150        self.0.increment()
151    }
152
153    fn release_borrow(&self) {
154        self.0.decrement();
155    }
156
157    fn try_borrow_mut(&self) -> Result<(), PyBorrowMutError> {
158        let flag = &self.0;
159        match flag.0.compare_exchange(
160            // only allowed to transition to mutable borrow if the reference is
161            // currently unused
162            BorrowFlag::UNUSED,
163            BorrowFlag::HAS_MUTABLE_BORROW,
164            // On success, reading the flag and updating its state are an atomic
165            // operation
166            Ordering::AcqRel,
167            // It doesn't matter precisely when the failure gets turned
168            // into an error
169            Ordering::Relaxed,
170        ) {
171            Ok(..) => Ok(()),
172            Err(..) => Err(PyBorrowMutError { _private: () }),
173        }
174    }
175
176    fn release_borrow_mut(&self) {
177        self.0 .0.store(BorrowFlag::UNUSED, Ordering::Release)
178    }
179}
180
181pub trait GetBorrowChecker<T: PyClassImpl> {
182    fn borrow_checker(
183        class_object: &T::Layout,
184    ) -> &<T::PyClassMutability as PyClassMutability>::Checker;
185}
186
187impl<T: PyClassImpl<PyClassMutability = Self>> GetBorrowChecker<T> for MutableClass {
188    fn borrow_checker(class_object: &T::Layout) -> &BorrowChecker {
189        &class_object.contents().borrow_checker
190    }
191}
192
193impl<T: PyClassImpl<PyClassMutability = Self>> GetBorrowChecker<T> for ImmutableClass {
194    fn borrow_checker(class_object: &T::Layout) -> &EmptySlot {
195        &class_object.contents().borrow_checker
196    }
197}
198
199impl<T: PyClassImpl<PyClassMutability = Self>, M: PyClassMutability> GetBorrowChecker<T>
200    for ExtendsMutableAncestor<M>
201where
202    T::BaseType: PyClassImpl + PyClassBaseType<LayoutAsBase = <T::BaseType as PyClassImpl>::Layout>,
203    <T::BaseType as PyClassImpl>::PyClassMutability: PyClassMutability<Checker = BorrowChecker>,
204{
205    fn borrow_checker(class_object: &T::Layout) -> &BorrowChecker {
206        <<T::BaseType as PyClassImpl>::PyClassMutability as GetBorrowChecker<T::BaseType>>::borrow_checker(class_object.ob_base())
207    }
208}
209
210/// Base layout of PyClassObject with a known sized base type.
211/// Corresponds to [PyObject](https://docs.python.org/3/c-api/structures.html#c.PyObject) from the C API.
212#[doc(hidden)]
213#[repr(C)]
214pub struct PyClassObjectBase<T> {
215    ob_base: T,
216}
217
218unsafe impl<T, U> PyLayout<T> for PyClassObjectBase<U> where U: PySizedLayout<T> {}
219
220impl<T, U> PyClassObjectBaseLayout<T> for PyClassObjectBase<U>
221where
222    U: PySizedLayout<T>,
223    T: PyTypeInfo,
224{
225    fn ensure_threadsafe(&self) {}
226    fn check_threadsafe(&self) -> Result<(), PyBorrowError> {
227        Ok(())
228    }
229    unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) {
230        unsafe { tp_dealloc(slf, &T::type_object(py)) };
231    }
232}
233
234/// Base layout of PyClassObject with an unknown sized base type.
235/// Corresponds to [PyVarObject](https://docs.python.org/3/c-api/structures.html#c.PyVarObject) from the C API.
236#[doc(hidden)]
237#[repr(C)]
238pub struct PyVariableClassObjectBase {
239    ob_base: ffi::PyVarObject,
240}
241
242unsafe impl<T> PyLayout<T> for PyVariableClassObjectBase {}
243
244impl<T: PyTypeInfo> PyClassObjectBaseLayout<T> for PyVariableClassObjectBase {
245    fn ensure_threadsafe(&self) {}
246    fn check_threadsafe(&self) -> Result<(), PyBorrowError> {
247        Ok(())
248    }
249    unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) {
250        unsafe { tp_dealloc(slf, &T::type_object(py)) };
251    }
252}
253
254/// Implementation of tp_dealloc.
255/// # Safety
256/// - `slf` must be a valid pointer to an instance of the type at `type_obj` or a subclass.
257/// - `slf` must not be used after this call (as it will be freed).
258unsafe fn tp_dealloc(slf: *mut ffi::PyObject, type_obj: &crate::Bound<'_, PyType>) {
259    let py = type_obj.py();
260    unsafe {
261        // FIXME: there is potentially subtle issues here if the base is overwritten
262        // at runtime? To be investigated.
263        let type_ptr = type_obj.as_type_ptr();
264        let actual_type = PyType::from_borrowed_type_ptr(py, ffi::Py_TYPE(slf));
265
266        // For `#[pyclass]` types which inherit from PyAny, we can just call tp_free
267        if std::ptr::eq(type_ptr, &raw const ffi::PyBaseObject_Type) {
268            let tp_free = actual_type
269                .get_slot(TP_FREE)
270                .expect("PyBaseObject_Type should have tp_free");
271            return tp_free(slf.cast());
272        }
273
274        // More complex native types (e.g. `extends=PyDict`) require calling the base's dealloc.
275        // FIXME: should this be using actual_type.tp_dealloc?
276        if let Some(dealloc) = type_obj.get_slot(TP_DEALLOC) {
277            // Before CPython 3.11 BaseException_dealloc would use Py_GC_UNTRACK which
278            // assumes the exception is currently GC tracked, so we have to re-track
279            // before calling the dealloc so that it can safely call Py_GC_UNTRACK.
280            #[cfg(not(any(Py_3_11, PyPy)))]
281            if ffi::PyType_FastSubclass(type_ptr, ffi::Py_TPFLAGS_BASE_EXC_SUBCLASS) == 1 {
282                ffi::PyObject_GC_Track(slf.cast());
283            }
284            dealloc(slf);
285        } else {
286            type_obj.get_slot(TP_FREE).expect("type missing tp_free")(slf.cast());
287        }
288    }
289}
290
291/// functionality common to all PyObjects regardless of the layout
292#[doc(hidden)]
293pub trait PyClassObjectBaseLayout<T>: PyLayout<T> {
294    fn ensure_threadsafe(&self);
295    fn check_threadsafe(&self) -> Result<(), PyBorrowError>;
296    /// Implementation of tp_dealloc.
297    /// # Safety
298    /// - slf must be a valid pointer to an instance of a T or a subclass.
299    /// - slf must not be used after this call (as it will be freed).
300    unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject);
301}
302
303/// Functionality required for creating and managing the memory associated with a pyclass annotated struct.
304#[doc(hidden)]
305#[diagnostic::on_unimplemented(
306    message = "the class layout is not valid",
307    label = "required for `#[pyclass(extends=...)]`",
308    note = "the python version being built against influences which layouts are valid"
309)]
310pub trait PyClassObjectLayout<T: PyClassImpl>: PyClassObjectBaseLayout<T> {
311    /// Gets the offset of the contents from the start of the struct in bytes.
312    const CONTENTS_OFFSET: PyObjectOffset;
313
314    /// Used to set `PyType_Spec::basicsize`
315    /// ([docs](https://docs.python.org/3/c-api/type.html#c.PyType_Spec.basicsize))
316    const BASIC_SIZE: ffi::Py_ssize_t;
317
318    /// Gets the offset of the dictionary from the start of the struct in bytes.
319    const DICT_OFFSET: PyObjectOffset;
320
321    /// Gets the offset of the weakref list from the start of the struct in bytes.
322    const WEAKLIST_OFFSET: PyObjectOffset;
323
324    /// Obtain a pointer to the contents of an uninitialized PyObject of this type.
325    ///
326    /// SAFETY: `obj` must have the layout that the implementation is expecting
327    unsafe fn contents_uninit(
328        obj: *mut ffi::PyObject,
329    ) -> *mut MaybeUninit<PyClassObjectContents<T>>;
330
331    /// Obtain a reference to the structure that contains the pyclass struct and associated metadata.
332    fn contents(&self) -> &PyClassObjectContents<T>;
333
334    /// Obtain a mutable reference to the structure that contains the pyclass struct and associated metadata.
335    fn contents_mut(&mut self) -> &mut PyClassObjectContents<T>;
336
337    /// Obtain a pointer to the pyclass struct.
338    fn get_ptr(&self) -> *mut T;
339
340    /// obtain a reference to the data at the start of the PyObject.
341    fn ob_base(&self) -> &<T::BaseType as PyClassBaseType>::LayoutAsBase;
342
343    fn borrow_checker(&self) -> &<T::PyClassMutability as PyClassMutability>::Checker;
344}
345
346#[repr(C)]
347pub struct PyClassObjectContents<T: PyClassImpl> {
348    pub(crate) value: ManuallyDrop<UnsafeCell<T>>,
349    pub(crate) borrow_checker: <T::PyClassMutability as PyClassMutability>::Storage,
350    pub(crate) thread_checker: T::ThreadChecker,
351    pub(crate) dict: T::Dict,
352    pub(crate) weakref: T::WeakRef,
353}
354
355impl<T: PyClassImpl> PyClassObjectContents<T> {
356    pub(crate) fn new(init: T) -> Self {
357        PyClassObjectContents {
358            value: ManuallyDrop::new(UnsafeCell::new(init)),
359            borrow_checker: <T::PyClassMutability as PyClassMutability>::Storage::new(),
360            thread_checker: T::ThreadChecker::new(),
361            dict: T::Dict::INIT,
362            weakref: T::WeakRef::INIT,
363        }
364    }
365
366    unsafe fn dealloc(&mut self, py: Python<'_>, py_object: *mut ffi::PyObject) {
367        if self.thread_checker.can_drop(py) {
368            unsafe { ManuallyDrop::drop(&mut self.value) };
369        }
370        self.dict.clear_dict(py);
371        unsafe { self.weakref.clear_weakrefs(py_object, py) };
372    }
373}
374
375/// The layout of a PyClassObject with a known sized base class.
376#[repr(C)]
377pub struct PyStaticClassObject<T: PyClassImpl> {
378    ob_base: <T::BaseType as PyClassBaseType>::LayoutAsBase,
379    contents: PyClassObjectContents<T>,
380}
381
382impl<T: PyClassImpl<Layout = Self>> PyClassObjectLayout<T> for PyStaticClassObject<T> {
383    /// Gets the offset of the contents from the start of the struct in bytes.
384    const CONTENTS_OFFSET: PyObjectOffset = {
385        let offset = offset_of!(Self, contents);
386        // Py_ssize_t may not be equal to isize on all platforms
387        assert!(offset <= ffi::Py_ssize_t::MAX as usize);
388        PyObjectOffset::Absolute(offset as ffi::Py_ssize_t)
389    };
390
391    const BASIC_SIZE: ffi::Py_ssize_t = {
392        let size = std::mem::size_of::<Self>();
393        assert!(size <= ffi::Py_ssize_t::MAX as usize);
394        size as _
395    };
396
397    const DICT_OFFSET: PyObjectOffset = {
398        let offset = offset_of!(PyStaticClassObject<T>, contents)
399            + offset_of!(PyClassObjectContents<T>, dict);
400        assert!(offset <= ffi::Py_ssize_t::MAX as usize);
401        PyObjectOffset::Absolute(offset as _)
402    };
403
404    const WEAKLIST_OFFSET: PyObjectOffset = {
405        let offset = offset_of!(PyStaticClassObject<T>, contents)
406            + offset_of!(PyClassObjectContents<T>, weakref);
407        assert!(offset <= ffi::Py_ssize_t::MAX as usize);
408        PyObjectOffset::Absolute(offset as _)
409    };
410
411    unsafe fn contents_uninit(
412        obj: *mut ffi::PyObject,
413    ) -> *mut MaybeUninit<PyClassObjectContents<T>> {
414        #[repr(C)]
415        struct PartiallyInitializedClassObject<T: PyClassImpl> {
416            _ob_base: <T::BaseType as PyClassBaseType>::LayoutAsBase,
417            contents: MaybeUninit<PyClassObjectContents<T>>,
418        }
419        let obj = obj.cast::<PartiallyInitializedClassObject<T>>();
420        unsafe { &raw mut (*obj).contents }
421    }
422
423    fn contents(&self) -> &PyClassObjectContents<T> {
424        &self.contents
425    }
426
427    fn contents_mut(&mut self) -> &mut PyClassObjectContents<T> {
428        &mut self.contents
429    }
430
431    fn get_ptr(&self) -> *mut T {
432        self.contents.value.get()
433    }
434
435    fn ob_base(&self) -> &<T::BaseType as PyClassBaseType>::LayoutAsBase {
436        &self.ob_base
437    }
438
439    fn borrow_checker(&self) -> &<T::PyClassMutability as PyClassMutability>::Checker {
440        T::PyClassMutability::borrow_checker(self)
441    }
442}
443
444unsafe impl<T: PyClassImpl> PyLayout<T> for PyStaticClassObject<T> {}
445impl<T: PyClass> PySizedLayout<T> for PyStaticClassObject<T> {}
446
447impl<T: PyClassImpl<Layout = Self>> PyClassObjectBaseLayout<T> for PyStaticClassObject<T>
448where
449    <T::BaseType as PyClassBaseType>::LayoutAsBase: PyClassObjectBaseLayout<T::BaseType>,
450{
451    fn ensure_threadsafe(&self) {
452        self.contents.thread_checker.ensure();
453        self.ob_base.ensure_threadsafe();
454    }
455    fn check_threadsafe(&self) -> Result<(), PyBorrowError> {
456        if !self.contents.thread_checker.check() {
457            return Err(PyBorrowError { _private: () });
458        }
459        self.ob_base.check_threadsafe()
460    }
461    unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) {
462        // Safety: Python only calls tp_dealloc when no references to the object remain.
463        let class_object = unsafe { &mut *(slf.cast::<T::Layout>()) };
464        unsafe { class_object.contents_mut().dealloc(py, slf) };
465        unsafe { <T::BaseType as PyClassBaseType>::LayoutAsBase::tp_dealloc(py, slf) }
466    }
467}
468
469/// A layout for a PyClassObject with an unknown sized base type.
470///
471/// Utilises [PEP-697](https://peps.python.org/pep-0697/)
472#[doc(hidden)]
473#[repr(C)]
474pub struct PyVariableClassObject<T: PyClassImpl> {
475    ob_base: <T::BaseType as PyClassBaseType>::LayoutAsBase,
476}
477
478#[cfg(Py_3_12)]
479impl<T: PyClass<Layout = Self>> PyVariableClassObject<T> {
480    /// # Safety
481    /// - `obj` must have the layout that the implementation is expecting
482    /// - thread must be attached to the interpreter
483    unsafe fn get_contents_of_obj(
484        obj: *mut ffi::PyObject,
485    ) -> *mut MaybeUninit<PyClassObjectContents<T>> {
486        // TODO: it would be nice to eventually avoid coupling to the PyO3 statics here, maybe using
487        // 3.14's PyType_GetBaseByToken, to support PEP 587 / multiple interpreters better
488        // SAFETY: caller guarantees attached to the interpreter
489        let type_obj = T::type_object_raw(unsafe { Python::assume_attached() });
490        let pointer = unsafe { ffi::PyObject_GetTypeData(obj, type_obj) };
491        pointer.cast()
492    }
493
494    fn get_contents_ptr(&self) -> *mut PyClassObjectContents<T> {
495        unsafe {
496            Self::get_contents_of_obj(self as *const PyVariableClassObject<T> as *mut ffi::PyObject)
497        }
498        .cast()
499    }
500}
501
#[cfg(Py_3_12)]
impl<T: PyClass<Layout = Self>> PyClassObjectLayout<T> for PyVariableClassObject<T> {
    // The contents start at the beginning of this class's type data, so all
    // offsets are relative to it.
    const CONTENTS_OFFSET: PyObjectOffset = PyObjectOffset::Relative(0);

    const BASIC_SIZE: ffi::Py_ssize_t = {
        let size = std::mem::size_of::<PyClassObjectContents<T>>();
        assert!(size <= ffi::Py_ssize_t::MAX as usize);
        // Negative: ask CPython to allocate this much extra space on top of the base.
        -(size as ffi::Py_ssize_t)
    };

    const DICT_OFFSET: PyObjectOffset = {
        let offset = offset_of!(PyClassObjectContents<T>, dict);
        assert!(offset <= ffi::Py_ssize_t::MAX as usize);
        PyObjectOffset::Relative(offset as ffi::Py_ssize_t)
    };

    const WEAKLIST_OFFSET: PyObjectOffset = {
        let offset = offset_of!(PyClassObjectContents<T>, weakref);
        assert!(offset <= ffi::Py_ssize_t::MAX as usize);
        PyObjectOffset::Relative(offset as ffi::Py_ssize_t)
    };

    unsafe fn contents_uninit(
        obj: *mut ffi::PyObject,
    ) -> *mut MaybeUninit<PyClassObjectContents<T>> {
        // SAFETY: forwarded from this function's own contract.
        unsafe { Self::get_contents_of_obj(obj) }
    }

    fn get_ptr(&self) -> *mut T {
        let cell: &UnsafeCell<T> = &self.contents().value;
        cell.get()
    }

    fn ob_base(&self) -> &<T::BaseType as PyClassBaseType>::LayoutAsBase {
        &self.ob_base
    }

    fn contents(&self) -> &PyClassObjectContents<T> {
        let ptr = self.get_contents_ptr().cast_const();
        // SAFETY: the pointer comes from PyObject_GetTypeData for this object.
        let contents = unsafe { ptr.as_ref() };
        contents.expect("should be able to cast PyClassObjectContents pointer")
    }

    fn contents_mut(&mut self) -> &mut PyClassObjectContents<T> {
        let ptr = self.get_contents_ptr();
        // SAFETY: as in `contents`; `&mut self` guarantees exclusivity.
        let contents = unsafe { ptr.as_mut() };
        contents.expect("should be able to cast PyClassObjectContents pointer")
    }

    fn borrow_checker(&self) -> &<T::PyClassMutability as PyClassMutability>::Checker {
        <T::PyClassMutability as GetBorrowChecker<T>>::borrow_checker(self)
    }
}
551
552unsafe impl<T: PyClassImpl> PyLayout<T> for PyVariableClassObject<T> {}
553
554#[cfg(Py_3_12)]
555impl<T: PyClass<Layout = Self>> PyClassObjectBaseLayout<T> for PyVariableClassObject<T>
556where
557    <T::BaseType as PyClassBaseType>::LayoutAsBase: PyClassObjectBaseLayout<T::BaseType>,
558{
559    fn ensure_threadsafe(&self) {
560        self.contents().thread_checker.ensure();
561        self.ob_base.ensure_threadsafe();
562    }
563    fn check_threadsafe(&self) -> Result<(), PyBorrowError> {
564        if !self.contents().thread_checker.check() {
565            return Err(PyBorrowError { _private: () });
566        }
567        self.ob_base.check_threadsafe()
568    }
569    unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) {
570        // Safety: Python only calls tp_dealloc when no references to the object remain.
571        let class_object = unsafe { &mut *(slf.cast::<T::Layout>()) };
572        unsafe { class_object.contents_mut().dealloc(py, slf) };
573        unsafe { <T::BaseType as PyClassBaseType>::LayoutAsBase::tp_dealloc(py, slf) }
574    }
575}
576
#[cfg(test)]
#[cfg(feature = "macros")]
mod tests {
    use super::*;

    use crate::prelude::*;
    use crate::pyclass::boolean_struct::{False, True};

    #[pyclass(crate = "crate", subclass)]
    struct MutableBase;

    #[pyclass(crate = "crate", extends = MutableBase, subclass)]
    struct MutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = MutableBase, frozen, subclass)]
    struct ImmutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = MutableChildOfMutableBase)]
    struct MutableChildOfMutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = ImmutableChildOfMutableBase)]
    struct MutableChildOfImmutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = MutableChildOfMutableBase, frozen)]
    struct ImmutableChildOfMutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = ImmutableChildOfMutableBase, frozen)]
    struct ImmutableChildOfImmutableChildOfMutableBase;

    #[pyclass(crate = "crate", frozen, subclass)]
    struct ImmutableBase;

    #[pyclass(crate = "crate", extends = ImmutableBase, subclass)]
    struct MutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = ImmutableBase, frozen, subclass)]
    struct ImmutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = MutableChildOfImmutableBase)]
    struct MutableChildOfMutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = ImmutableChildOfImmutableBase)]
    struct MutableChildOfImmutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = MutableChildOfImmutableBase, frozen)]
    struct ImmutableChildOfMutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = ImmutableChildOfImmutableBase, frozen)]
    struct ImmutableChildOfImmutableChildOfImmutableBase;

    #[pyclass(crate = "crate", subclass)]
    struct BaseWithData(#[allow(unused)] u64);

    #[pyclass(crate = "crate", extends = BaseWithData)]
    struct ChildWithData(#[allow(unused)] u64);

    #[pyclass(crate = "crate", extends = BaseWithData)]
    struct ChildWithoutData;

    #[test]
    fn test_inherited_size() {
        let base_size = PyStaticClassObject::<BaseWithData>::BASIC_SIZE;
        assert!(base_size > 0); // negative indicates variable sized
        assert_eq!(
            base_size,
            PyStaticClassObject::<ChildWithoutData>::BASIC_SIZE
        );
        assert!(base_size < PyStaticClassObject::<ChildWithData>::BASIC_SIZE);
    }

    fn assert_mutable<T: PyClass<Frozen = False, PyClassMutability = MutableClass>>() {}
    fn assert_immutable<T: PyClass<Frozen = True, PyClassMutability = ImmutableClass>>() {}
    fn assert_mutable_with_mutable_ancestor<
        T: PyClass<Frozen = False, PyClassMutability = ExtendsMutableAncestor<MutableClass>>,
    >() {
    }
    fn assert_immutable_with_mutable_ancestor<
        T: PyClass<Frozen = True, PyClassMutability = ExtendsMutableAncestor<ImmutableClass>>,
    >() {
    }

    #[test]
    fn test_inherited_mutability() {
        // mutable base
        assert_mutable::<MutableBase>();

        // children of mutable base have a mutable ancestor
        assert_mutable_with_mutable_ancestor::<MutableChildOfMutableBase>();
        assert_immutable_with_mutable_ancestor::<ImmutableChildOfMutableBase>();

        // grandchildren of mutable base have a mutable ancestor
        assert_mutable_with_mutable_ancestor::<MutableChildOfMutableChildOfMutableBase>();
        assert_mutable_with_mutable_ancestor::<MutableChildOfImmutableChildOfMutableBase>();
        assert_immutable_with_mutable_ancestor::<ImmutableChildOfMutableChildOfMutableBase>();
        assert_immutable_with_mutable_ancestor::<ImmutableChildOfImmutableChildOfMutableBase>();

        // immutable base and children
        assert_immutable::<ImmutableBase>();
        assert_immutable::<ImmutableChildOfImmutableBase>();
        assert_immutable::<ImmutableChildOfImmutableChildOfImmutableBase>();

        // mutable children of immutable at any level are simply mutable
        assert_mutable::<MutableChildOfImmutableBase>();
        assert_mutable::<MutableChildOfImmutableChildOfImmutableBase>();

        // children of the mutable child display this property
        assert_mutable_with_mutable_ancestor::<MutableChildOfMutableChildOfImmutableBase>();
        assert_immutable_with_mutable_ancestor::<ImmutableChildOfMutableChildOfImmutableBase>();
    }

    #[test]
    fn test_mutable_borrow_prevents_further_borrows() {
        Python::attach(|py| {
            let mmm = Py::new(
                py,
                PyClassInitializer::from(MutableBase)
                    .add_subclass(MutableChildOfMutableBase)
                    .add_subclass(MutableChildOfMutableChildOfMutableBase),
            )
            .unwrap();

            let mmm_bound: &Bound<'_, MutableChildOfMutableChildOfMutableBase> = mmm.bind(py);

            let mmm_refmut = mmm_bound.borrow_mut();

            // Cannot take any other mutable or immutable borrows whilst the object is borrowed mutably
            assert!(mmm_bound
                .extract::<PyRef<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_bound
                .extract::<PyRef<'_, MutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_bound.extract::<PyRef<'_, MutableBase>>().is_err());
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_bound.extract::<PyRefMut<'_, MutableBase>>().is_err());

            // With the borrow dropped, all other borrow attempts will succeed
            drop(mmm_refmut);

            assert!(mmm_bound
                .extract::<PyRef<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound
                .extract::<PyRef<'_, MutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound.extract::<PyRef<'_, MutableBase>>().is_ok());
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound.extract::<PyRefMut<'_, MutableBase>>().is_ok());
        })
    }

    #[test]
    fn test_immutable_borrows_prevent_mutable_borrows() {
        Python::attach(|py| {
            let mmm = Py::new(
                py,
                PyClassInitializer::from(MutableBase)
                    .add_subclass(MutableChildOfMutableBase)
                    .add_subclass(MutableChildOfMutableChildOfMutableBase),
            )
            .unwrap();

            let mmm_bound: &Bound<'_, MutableChildOfMutableChildOfMutableBase> = mmm.bind(py);

            // A shared (immutable) borrow, hence `mmm_ref` rather than `mmm_refmut`.
            let mmm_ref = mmm_bound.borrow();

            // Further immutable borrows are ok
            assert!(mmm_bound
                .extract::<PyRef<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound
                .extract::<PyRef<'_, MutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound.extract::<PyRef<'_, MutableBase>>().is_ok());

            // Further mutable borrows are not ok
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_bound.extract::<PyRefMut<'_, MutableBase>>().is_err());

            // With the borrow dropped, all mutable borrow attempts will succeed
            drop(mmm_ref);

            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound.extract::<PyRefMut<'_, MutableBase>>().is_ok());
        })
    }

    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_thread_safety() {
        #[crate::pyclass(crate = "crate")]
        struct MyClass {
            x: u64,
        }

        Python::attach(|py| {
            let inst = Py::new(py, MyClass { x: 0 }).unwrap();

            let total_modifications = py.detach(|| {
                std::thread::scope(|s| {
                    // Spawn a bunch of threads all racing to write to
                    // the same instance of `MyClass`.
                    let threads = (0..10)
                        .map(|_| {
                            s.spawn(|| {
                                Python::attach(|py| {
                                    // Each thread records its own view of how many writes it made
                                    let mut local_modifications = 0;
                                    for _ in 0..100 {
                                        if let Ok(mut i) = inst.try_borrow_mut(py) {
                                            i.x += 1;
                                            local_modifications += 1;
                                        }
                                    }
                                    local_modifications
                                })
                            })
                        })
                        .collect::<Vec<_>>();

                    // Sum up the total number of writes made by all threads
                    threads.into_iter().map(|t| t.join().unwrap()).sum::<u64>()
                })
            });

            // If the implementation is free of data races, the total number of writes
            // should match the final value of `x`.
            assert_eq!(total_modifications, inst.borrow(py).x);
        });
    }

    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_thread_safety_2() {
        struct SyncUnsafeCell<T>(UnsafeCell<T>);
        unsafe impl<T> Sync for SyncUnsafeCell<T> {}

        impl<T> SyncUnsafeCell<T> {
            fn get(&self) -> *mut T {
                self.0.get()
            }
        }

        let data = SyncUnsafeCell(UnsafeCell::new(0));
        let data2 = SyncUnsafeCell(UnsafeCell::new(0));
        let borrow_checker = BorrowChecker(BorrowFlag(AtomicUsize::new(BorrowFlag::UNUSED)));

        std::thread::scope(|s| {
            s.spawn(|| {
                for _ in 0..1_000_000 {
                    if borrow_checker.try_borrow_mut().is_ok() {
                        // thread 1 writes to both values during the mutable borrow
                        unsafe { *data.get() += 1 };
                        unsafe { *data2.get() += 1 };
                        borrow_checker.release_borrow_mut();
                    }
                }
            });

            s.spawn(|| {
                for _ in 0..1_000_000 {
                    if borrow_checker.try_borrow().is_ok() {
                        // if the borrow checker is working correctly, it should be impossible
                        // for thread 2 to observe a difference in the two values
                        assert_eq!(unsafe { *data.get() }, unsafe { *data2.get() });
                        borrow_checker.release_borrow();
                    }
                }
            });
        });
    }
}
⚠️ Internal Docs ⚠️ Not Public API 👉 Official Docs Here