1#![allow(missing_docs)]
2use std::cell::UnsafeCell;
5use std::marker::PhantomData;
6use std::mem::ManuallyDrop;
7use std::sync::atomic::{AtomicUsize, Ordering};
8
9use crate::impl_::pyclass::{
10 PyClassBaseType, PyClassDict, PyClassImpl, PyClassThreadChecker, PyClassWeakRef,
11};
12use crate::internal::get_slot::TP_FREE;
13use crate::type_object::{PyLayout, PySizedLayout};
14use crate::types::{PyType, PyTypeMethods};
15use crate::{ffi, PyClass, PyTypeInfo, Python};
16
17use super::{PyBorrowError, PyBorrowMutError};
18
/// Type-level description of how a `#[pyclass]` takes part in runtime borrow checking.
///
/// Implemented by the marker types `ImmutableClass`, `MutableClass` and
/// `ExtendsMutableAncestor<M>`.
pub trait PyClassMutability {
    /// What the class object physically stores in its own contents for borrow tracking.
    type Storage: PyClassBorrowChecker;
    /// The checker that borrow operations on this class consult; for classes with a
    /// mutable ancestor this differs from `Storage` and lives in the ancestor.
    type Checker: PyClassBorrowChecker;
    /// Marker that a `frozen` subclass of this class ends up with.
    type ImmutableChild: PyClassMutability;
    /// Marker that a non-frozen subclass of this class ends up with.
    type MutableChild: PyClassMutability;
}
30
/// Marker: a `frozen` class with no mutable ancestor; no borrow tracking is needed.
pub struct ImmutableClass(());
/// Marker: a non-frozen class with no mutable ancestor; it stores its own borrow flag.
pub struct MutableClass(());
/// Marker: a class that has a mutable ancestor and reuses that ancestor's borrow flag.
/// `M` records whether the class itself is mutable (`MutableClass`) or frozen
/// (`ImmutableClass`).
pub struct ExtendsMutableAncestor<M: PyClassMutability>(PhantomData<M>);
34
35impl PyClassMutability for ImmutableClass {
36 type Storage = EmptySlot;
37 type Checker = EmptySlot;
38 type ImmutableChild = ImmutableClass;
39 type MutableChild = MutableClass;
40}
41
42impl PyClassMutability for MutableClass {
43 type Storage = BorrowChecker;
44 type Checker = BorrowChecker;
45 type ImmutableChild = ExtendsMutableAncestor<ImmutableClass>;
46 type MutableChild = ExtendsMutableAncestor<MutableClass>;
47}
48
49impl<M: PyClassMutability> PyClassMutability for ExtendsMutableAncestor<M> {
50 type Storage = EmptySlot;
51 type Checker = BorrowChecker;
52 type ImmutableChild = ExtendsMutableAncestor<ImmutableClass>;
53 type MutableChild = ExtendsMutableAncestor<MutableClass>;
54}
55
/// Atomic state behind `BorrowChecker`.
///
/// Holds `BorrowFlag::UNUSED` when free and `BorrowFlag::HAS_MUTABLE_BORROW` while
/// mutably borrowed. Any other value is the number of outstanding shared borrows.
#[derive(Debug)]
struct BorrowFlag(AtomicUsize);
58
59impl BorrowFlag {
60 pub(crate) const UNUSED: usize = 0;
61 const HAS_MUTABLE_BORROW: usize = usize::MAX;
62 fn increment(&self) -> Result<(), PyBorrowError> {
63 let mut value = self.0.load(Ordering::Relaxed);
65 loop {
66 if value == BorrowFlag::HAS_MUTABLE_BORROW {
67 return Err(PyBorrowError { _private: () });
68 }
69 match self.0.compare_exchange(
70 value,
73 value + 1,
74 Ordering::AcqRel,
77 Ordering::Relaxed,
79 ) {
80 Ok(..) => {
81 break Ok(());
82 }
83 Err(changed_value) => {
84 value = changed_value;
86 }
87 }
88 }
89 }
90 fn decrement(&self) {
91 self.0.fetch_sub(1, Ordering::Release);
93 }
94}
95
/// Zero-sized stand-in used where no borrow flag needs to be stored or checked.
pub struct EmptySlot(());
/// Runtime borrow checker for mutable classes, backed by an atomic `BorrowFlag`.
pub struct BorrowChecker(BorrowFlag);
98
/// Runtime borrow tracking for the contents of a class object.
///
/// The `try_*` methods attempt to take a borrow and report failure. Each
/// successful `try_*` call is meant to be paired with the matching `release_*`.
pub trait PyClassBorrowChecker {
    /// Creates a checker in the "not borrowed" state.
    fn new() -> Self;

    /// Attempts to take a shared borrow.
    fn try_borrow(&self) -> Result<(), PyBorrowError>;

    /// Releases a shared borrow taken by `try_borrow`.
    fn release_borrow(&self);
    /// Attempts to take the exclusive (mutable) borrow.
    fn try_borrow_mut(&self) -> Result<(), PyBorrowMutError>;
    /// Releases the exclusive borrow taken by `try_borrow_mut`.
    fn release_borrow_mut(&self);
}
113
114impl PyClassBorrowChecker for EmptySlot {
115 #[inline]
116 fn new() -> Self {
117 EmptySlot(())
118 }
119
120 #[inline]
121 fn try_borrow(&self) -> Result<(), PyBorrowError> {
122 Ok(())
123 }
124
125 #[inline]
126 fn release_borrow(&self) {}
127
128 #[inline]
129 fn try_borrow_mut(&self) -> Result<(), PyBorrowMutError> {
130 unreachable!()
131 }
132
133 #[inline]
134 fn release_borrow_mut(&self) {
135 unreachable!()
136 }
137}
138
139impl PyClassBorrowChecker for BorrowChecker {
140 #[inline]
141 fn new() -> Self {
142 Self(BorrowFlag(AtomicUsize::new(BorrowFlag::UNUSED)))
143 }
144
145 fn try_borrow(&self) -> Result<(), PyBorrowError> {
146 self.0.increment()
147 }
148
149 fn release_borrow(&self) {
150 self.0.decrement();
151 }
152
153 fn try_borrow_mut(&self) -> Result<(), PyBorrowMutError> {
154 let flag = &self.0;
155 match flag.0.compare_exchange(
156 BorrowFlag::UNUSED,
159 BorrowFlag::HAS_MUTABLE_BORROW,
160 Ordering::AcqRel,
163 Ordering::Relaxed,
166 ) {
167 Ok(..) => Ok(()),
168 Err(..) => Err(PyBorrowMutError { _private: () }),
169 }
170 }
171
172 fn release_borrow_mut(&self) {
173 self.0 .0.store(BorrowFlag::UNUSED, Ordering::Release)
174 }
175}
176
/// Locates the borrow checker that governs a `PyClassObject<T>`.
///
/// Implemented on the mutability marker types:
/// - For `MutableClass` and `ImmutableClass`, the checker is read from the
///   object's own contents.
/// - For `ExtendsMutableAncestor`, the checker is fetched from the base class object.
pub trait GetBorrowChecker<T: PyClassImpl> {
    fn borrow_checker(
        class_object: &PyClassObject<T>,
    ) -> &<T::PyClassMutability as PyClassMutability>::Checker;
}
182
// Mutable root class: the `BorrowChecker` lives in this object's own contents.
impl<T: PyClassImpl<PyClassMutability = Self>> GetBorrowChecker<T> for MutableClass {
    fn borrow_checker(class_object: &PyClassObject<T>) -> &BorrowChecker {
        &class_object.contents.borrow_checker
    }
}

// Frozen class without a mutable ancestor: its (zero-sized) `EmptySlot` lives in
// this object's own contents.
impl<T: PyClassImpl<PyClassMutability = Self>> GetBorrowChecker<T> for ImmutableClass {
    fn borrow_checker(class_object: &PyClassObject<T>) -> &EmptySlot {
        &class_object.contents.borrow_checker
    }
}

// Descendant of a mutable class: its own storage is an `EmptySlot`, so delegate to
// the base class object (`ob_base`). Recursion continues until it reaches the class
// that actually stores the `BorrowChecker`, so every level of the hierarchy shares
// one borrow flag.
impl<T: PyClassImpl<PyClassMutability = Self>, M: PyClassMutability> GetBorrowChecker<T>
    for ExtendsMutableAncestor<M>
where
    T::BaseType: PyClassImpl + PyClassBaseType<LayoutAsBase = PyClassObject<T::BaseType>>,
    <T::BaseType as PyClassImpl>::PyClassMutability: PyClassMutability<Checker = BorrowChecker>,
{
    fn borrow_checker(class_object: &PyClassObject<T>) -> &BorrowChecker {
        <<T::BaseType as PyClassImpl>::PyClassMutability as GetBorrowChecker<T::BaseType>>::borrow_checker(&class_object.ob_base)
    }
}
205
/// Base layout used when the base of a `#[pyclass]` is a native Python type.
/// It simply embeds the native object struct `T`.
/// NOTE(review): presumably `T` is an `ffi` object struct such as `PyObject`; confirm.
#[doc(hidden)]
#[repr(C)]
pub struct PyClassObjectBase<T> {
    ob_base: T,
}

// This impl asserts that wrapping a sized native layout in a single-field
// `#[repr(C)]` struct preserves that layout, so the wrapper is a valid layout
// for whatever `U` is a layout of.
unsafe impl<T, U> PyLayout<T> for PyClassObjectBase<U> where U: PySizedLayout<T> {}
214
/// Operations every class-object layout in an inheritance chain provides. Each
/// layout handles its own part and then delegates to its base layout.
#[doc(hidden)]
pub trait PyClassObjectLayout<T>: PyLayout<T> {
    /// Runs the thread checks of this layout and all base layouts.
    fn ensure_threadsafe(&self);
    /// Fallible variant of `ensure_threadsafe`, reporting failure as `PyBorrowError`.
    fn check_threadsafe(&self) -> Result<(), PyBorrowError>;
    /// Tears down this layout's part of the object and delegates to the base.
    ///
    /// # Safety
    ///
    /// `slf` must point to an object laid out as `Self`, and must not be used
    /// after this call.
    unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject);
}
225
impl<T, U> PyClassObjectLayout<T> for PyClassObjectBase<U>
where
    U: PySizedLayout<T>,
    T: PyTypeInfo,
{
    // The native base has no thread checker of its own, so both thread-safety
    // hooks are no-ops. They end the recursion started by `PyClassObject`'s impl.
    fn ensure_threadsafe(&self) {}
    fn check_threadsafe(&self) -> Result<(), PyBorrowError> {
        Ok(())
    }
    /// Releases the native-base part of the object after all Rust-side contents
    /// have been torn down.
    ///
    /// # Safety
    ///
    /// `slf` must point to a live Python object whose runtime type is `T` or a
    /// subtype of it. It must not be used after this call, because its memory is
    /// handed to `tp_dealloc` / `tp_free`.
    unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) {
        unsafe {
            // `type_ptr` is the static base type `T`. `actual_type` is the runtime
            // type of `slf`, which may be a subclass of `T`.
            let type_obj = T::type_object(py);
            let type_ptr = type_obj.as_type_ptr();
            let actual_type = PyType::from_borrowed_type_ptr(py, ffi::Py_TYPE(slf));

            // The base is plain `object`: there is no native state to tear down,
            // so just free the memory through the runtime type's `tp_free` slot.
            if type_ptr == std::ptr::addr_of_mut!(ffi::PyBaseObject_Type) {
                let tp_free = actual_type
                    .get_slot(TP_FREE)
                    .expect("PyBaseObject_Type should have tp_free");
                return tp_free(slf.cast());
            }

            // Any other native base: run the base type's own `tp_dealloc` if it has
            // one. Otherwise free the memory with the runtime type's `tp_free`.
            #[cfg(not(Py_LIMITED_API))]
            {
                // NOTE(review): this calls the base type's `tp_dealloc`, not
                // `actual_type`'s; confirm that is intended.
                if let Some(dealloc) = (*type_ptr).tp_dealloc {
                    // Before CPython 3.11 the exception dealloc appears to assume the
                    // object is still GC-tracked, so re-track it before delegating.
                    #[cfg(not(any(Py_3_11, PyPy)))]
                    if ffi::PyType_FastSubclass(type_ptr, ffi::Py_TPFLAGS_BASE_EXC_SUBCLASS) == 1 {
                        ffi::PyObject_GC_Track(slf.cast());
                    }
                    dealloc(slf);
                } else {
                    (*actual_type.as_type_ptr())
                        .tp_free
                        .expect("type missing tp_free")(slf.cast());
                }
            }

            // With the limited API, only `object` can be a native base, so the
            // early return above is the only reachable path.
            #[cfg(Py_LIMITED_API)]
            unreachable!("subclassing native types is not possible with the `abi3` feature");
        }
    }
}
276
/// The in-memory layout of the Python object backing a `#[pyclass]` `T`.
///
/// `#[repr(C)]` keeps `ob_base` at offset 0, so a pointer to this struct can be
/// used as a pointer to the base type's layout.
#[repr(C)]
pub struct PyClassObject<T: PyClassImpl> {
    /// Layout of the base class (a native-object base or another `PyClassObject`).
    pub(crate) ob_base: <T::BaseType as PyClassBaseType>::LayoutAsBase,
    /// Data that `T` adds on top of its base.
    pub(crate) contents: PyClassObjectContents<T>,
}

/// The per-class fields appended after the base layout.
#[repr(C)]
pub(crate) struct PyClassObjectContents<T: PyClassImpl> {
    /// The Rust value. It is wrapped in `ManuallyDrop` because it is dropped
    /// explicitly in `tp_dealloc`, and in `UnsafeCell` because access is guarded
    /// by runtime borrow checks.
    pub(crate) value: ManuallyDrop<UnsafeCell<T>>,
    /// Borrow-flag storage; the zero-sized `EmptySlot` when nothing is stored here.
    pub(crate) borrow_checker: <T::PyClassMutability as PyClassMutability>::Storage,
    /// Thread checker consulted by the thread-safety hooks and by `tp_dealloc`.
    pub(crate) thread_checker: T::ThreadChecker,
    /// Instance `__dict__` slot; its byte offset is reported by `dict_offset`.
    pub(crate) dict: T::Dict,
    /// Weak-reference list slot; its byte offset is reported by `weaklist_offset`.
    pub(crate) weakref: T::WeakRef,
}
292
293impl<T: PyClassImpl> PyClassObject<T> {
294 pub(crate) fn get_ptr(&self) -> *mut T {
295 self.contents.value.get()
296 }
297
298 pub(crate) fn dict_offset() -> ffi::Py_ssize_t {
300 use memoffset::offset_of;
301
302 let offset =
303 offset_of!(PyClassObject<T>, contents) + offset_of!(PyClassObjectContents<T>, dict);
304
305 #[allow(clippy::useless_conversion)]
307 offset.try_into().expect("offset should fit in Py_ssize_t")
308 }
309
310 pub(crate) fn weaklist_offset() -> ffi::Py_ssize_t {
312 use memoffset::offset_of;
313
314 let offset =
315 offset_of!(PyClassObject<T>, contents) + offset_of!(PyClassObjectContents<T>, weakref);
316
317 #[allow(clippy::useless_conversion)]
319 offset.try_into().expect("offset should fit in Py_ssize_t")
320 }
321}
322
323impl<T: PyClassImpl> PyClassObject<T> {
324 pub(crate) fn borrow_checker(&self) -> &<T::PyClassMutability as PyClassMutability>::Checker {
325 T::PyClassMutability::borrow_checker(self)
326 }
327}
328
// This impl asserts that `PyClassObject<T>` is a valid layout for `T`. That rests on
// the struct being `#[repr(C)]` with the base layout as its first field.
unsafe impl<T: PyClassImpl> PyLayout<T> for PyClassObject<T> {}
// Marked as a sized layout only for `T: PyClass`, a stronger bound than the
// `PyClassImpl` bound used for `PyLayout` above.
impl<T: PyClass> PySizedLayout<T> for PyClassObject<T> {}
331
impl<T: PyClassImpl> PyClassObjectLayout<T> for PyClassObject<T>
where
    <T::BaseType as PyClassBaseType>::LayoutAsBase: PyClassObjectLayout<T::BaseType>,
{
    /// Runs this class's thread checker, then the base layout's checks.
    fn ensure_threadsafe(&self) {
        self.contents.thread_checker.ensure();
        self.ob_base.ensure_threadsafe();
    }
    /// Fallible counterpart of `ensure_threadsafe`. Returns `PyBorrowError` if this
    /// class's thread checker rejects the current thread; otherwise defers to the
    /// base layout.
    fn check_threadsafe(&self) -> Result<(), PyBorrowError> {
        if !self.contents.thread_checker.check() {
            return Err(PyBorrowError { _private: () });
        }
        self.ob_base.check_threadsafe()
    }
    /// Tears down the Rust-side contents of the object, then delegates to the
    /// base layout's `tp_dealloc`.
    ///
    /// The steps run in this order:
    /// 1. drop the Rust value,
    /// 2. clear the instance dict,
    /// 3. clear weak references,
    /// 4. hand the object to the base layout.
    ///
    /// # Safety
    ///
    /// `slf` must point to an object laid out as `PyClassObject<T>` that no other
    /// code can still access, because it is reborrowed as `&mut` below.
    unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) {
        let class_object = unsafe { &mut *(slf.cast::<PyClassObject<T>>()) };
        // When `can_drop` returns false, the value is leaked (never dropped).
        // Presumably this keeps thread-bound values from being dropped on the
        // wrong thread.
        if class_object.contents.thread_checker.can_drop(py) {
            unsafe { ManuallyDrop::drop(&mut class_object.contents.value) };
        }
        class_object.contents.dict.clear_dict(py);
        unsafe {
            class_object.contents.weakref.clear_weakrefs(slf, py);
            // Recurse into the base layout. The native-base layout at the root of
            // the chain performs the final dealloc / free.
            <T::BaseType as PyClassBaseType>::LayoutAsBase::tp_dealloc(py, slf)
        }
    }
}
359
#[cfg(test)]
#[cfg(feature = "macros")]
mod tests {
    use super::*;

    use crate::prelude::*;
    use crate::pyclass::boolean_struct::{False, True};

    // A hierarchy covering every mutable/frozen combination across three levels.
    // Names read from the leaf towards the root; for example,
    // `MutableChildOfImmutableBase` is a non-frozen class extending a frozen one.

    #[pyclass(crate = "crate", subclass)]
    struct MutableBase;

    #[pyclass(crate = "crate", extends = MutableBase, subclass)]
    struct MutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = MutableBase, frozen, subclass)]
    struct ImmutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = MutableChildOfMutableBase)]
    struct MutableChildOfMutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = ImmutableChildOfMutableBase)]
    struct MutableChildOfImmutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = MutableChildOfMutableBase, frozen)]
    struct ImmutableChildOfMutableChildOfMutableBase;

    #[pyclass(crate = "crate", extends = ImmutableChildOfMutableBase, frozen)]
    struct ImmutableChildOfImmutableChildOfMutableBase;

    #[pyclass(crate = "crate", frozen, subclass)]
    struct ImmutableBase;

    #[pyclass(crate = "crate", extends = ImmutableBase, subclass)]
    struct MutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = ImmutableBase, frozen, subclass)]
    struct ImmutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = MutableChildOfImmutableBase)]
    struct MutableChildOfMutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = ImmutableChildOfImmutableBase)]
    struct MutableChildOfImmutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = MutableChildOfImmutableBase, frozen)]
    struct ImmutableChildOfMutableChildOfImmutableBase;

    #[pyclass(crate = "crate", extends = ImmutableChildOfImmutableBase, frozen)]
    struct ImmutableChildOfImmutableChildOfImmutableBase;

    // Compile-time assertions: each helper type-checks only if `T` has the given
    // `Frozen` flag and `PyClassMutability` marker.
    fn assert_mutable<T: PyClass<Frozen = False, PyClassMutability = MutableClass>>() {}
    fn assert_immutable<T: PyClass<Frozen = True, PyClassMutability = ImmutableClass>>() {}
    fn assert_mutable_with_mutable_ancestor<
        T: PyClass<Frozen = False, PyClassMutability = ExtendsMutableAncestor<MutableClass>>,
    >() {
    }
    fn assert_immutable_with_mutable_ancestor<
        T: PyClass<Frozen = True, PyClassMutability = ExtendsMutableAncestor<ImmutableClass>>,
    >() {
    }

    /// Checks the mutability marker derived for every class in the hierarchy.
    #[test]
    fn test_inherited_mutability() {
        assert_mutable::<MutableBase>();

        assert_mutable_with_mutable_ancestor::<MutableChildOfMutableBase>();
        assert_immutable_with_mutable_ancestor::<ImmutableChildOfMutableBase>();

        assert_mutable_with_mutable_ancestor::<MutableChildOfMutableChildOfMutableBase>();
        assert_mutable_with_mutable_ancestor::<MutableChildOfImmutableChildOfMutableBase>();
        assert_immutable_with_mutable_ancestor::<ImmutableChildOfMutableChildOfMutableBase>();
        assert_immutable_with_mutable_ancestor::<ImmutableChildOfImmutableChildOfMutableBase>();

        assert_immutable::<ImmutableBase>();
        assert_immutable::<ImmutableChildOfImmutableBase>();
        assert_immutable::<ImmutableChildOfImmutableChildOfImmutableBase>();

        assert_mutable::<MutableChildOfImmutableBase>();
        assert_mutable::<MutableChildOfImmutableChildOfImmutableBase>();

        assert_mutable_with_mutable_ancestor::<MutableChildOfMutableChildOfImmutableBase>();
        assert_immutable_with_mutable_ancestor::<ImmutableChildOfMutableChildOfImmutableBase>();
    }

    /// A mutable borrow of the leaf blocks every shared or mutable borrow at any
    /// level, because all levels share the root's borrow flag.
    #[test]
    fn test_mutable_borrow_prevents_further_borrows() {
        Python::with_gil(|py| {
            let mmm = Py::new(
                py,
                PyClassInitializer::from(MutableBase)
                    .add_subclass(MutableChildOfMutableBase)
                    .add_subclass(MutableChildOfMutableChildOfMutableBase),
            )
            .unwrap();

            let mmm_bound: &Bound<'_, MutableChildOfMutableChildOfMutableBase> = mmm.bind(py);

            let mmm_refmut = mmm_bound.borrow_mut();

            // While the mutable borrow is held, nothing else may be borrowed.
            assert!(mmm_bound
                .extract::<PyRef<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_bound
                .extract::<PyRef<'_, MutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_bound.extract::<PyRef<'_, MutableBase>>().is_err());
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_bound.extract::<PyRefMut<'_, MutableBase>>().is_err());

            drop(mmm_refmut);

            // Once it is released, every kind of borrow succeeds again.
            assert!(mmm_bound
                .extract::<PyRef<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound
                .extract::<PyRef<'_, MutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound.extract::<PyRef<'_, MutableBase>>().is_ok());
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound.extract::<PyRefMut<'_, MutableBase>>().is_ok());
        })
    }

    /// A shared borrow of the leaf allows further shared borrows at any level
    /// but blocks mutable ones until it is released.
    #[test]
    fn test_immutable_borrows_prevent_mutable_borrows() {
        Python::with_gil(|py| {
            let mmm = Py::new(
                py,
                PyClassInitializer::from(MutableBase)
                    .add_subclass(MutableChildOfMutableBase)
                    .add_subclass(MutableChildOfMutableChildOfMutableBase),
            )
            .unwrap();

            let mmm_bound: &Bound<'_, MutableChildOfMutableChildOfMutableBase> = mmm.bind(py);

            // A shared (`PyRef`) borrow, not a mutable one.
            let mmm_ref = mmm_bound.borrow();

            assert!(mmm_bound
                .extract::<PyRef<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound
                .extract::<PyRef<'_, MutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound.extract::<PyRef<'_, MutableBase>>().is_ok());

            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableBase>>()
                .is_err());
            assert!(mmm_bound.extract::<PyRefMut<'_, MutableBase>>().is_err());

            drop(mmm_ref);

            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound
                .extract::<PyRefMut<'_, MutableChildOfMutableBase>>()
                .is_ok());
            assert!(mmm_bound.extract::<PyRefMut<'_, MutableBase>>().is_ok());
        })
    }

    /// Many threads race on `try_borrow_mut`; every successful mutation must be
    /// reflected in the final value.
    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_thread_safety() {
        #[crate::pyclass(crate = "crate")]
        struct MyClass {
            x: u64,
        }

        Python::with_gil(|py| {
            let inst = Py::new(py, MyClass { x: 0 }).unwrap();

            let total_modifications = py.allow_threads(|| {
                std::thread::scope(|s| {
                    let threads = (0..10)
                        .map(|_| {
                            s.spawn(|| {
                                Python::with_gil(|py| {
                                    let mut local_modifications = 0;
                                    for _ in 0..100 {
                                        if let Ok(mut i) = inst.try_borrow_mut(py) {
                                            i.x += 1;
                                            local_modifications += 1;
                                        }
                                    }
                                    local_modifications
                                })
                            })
                        })
                        .collect::<Vec<_>>();

                    threads.into_iter().map(|t| t.join().unwrap()).sum::<u64>()
                })
            });

            assert_eq!(total_modifications, inst.borrow(py).x);
        });
    }

    /// Exercises the raw `BorrowChecker` without Python. The writer updates two
    /// cells under a mutable borrow, and the reader must always see them equal
    /// under a shared borrow. This checks that the flag's memory orderings
    /// publish both writes.
    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_thread_safety_2() {
        struct SyncUnsafeCell<T>(UnsafeCell<T>);
        unsafe impl<T> Sync for SyncUnsafeCell<T> {}

        impl<T> SyncUnsafeCell<T> {
            fn get(&self) -> *mut T {
                self.0.get()
            }
        }

        let data = SyncUnsafeCell(UnsafeCell::new(0));
        let data2 = SyncUnsafeCell(UnsafeCell::new(0));
        let borrow_checker = BorrowChecker(BorrowFlag(AtomicUsize::new(BorrowFlag::UNUSED)));

        std::thread::scope(|s| {
            s.spawn(|| {
                for _ in 0..1_000_000 {
                    if borrow_checker.try_borrow_mut().is_ok() {
                        unsafe { *data.get() += 1 };
                        unsafe { *data2.get() += 1 };
                        borrow_checker.release_borrow_mut();
                    }
                }
            });

            s.spawn(|| {
                for _ in 0..1_000_000 {
                    if borrow_checker.try_borrow().is_ok() {
                        assert_eq!(unsafe { *data.get() }, unsafe { *data2.get() });
                        borrow_checker.release_borrow();
                    }
                }
            });
        });
    }
}