pyo3/conversions/std/
array.rs

1use crate::conversion::IntoPyObject;
2use crate::instance::Bound;
3use crate::types::any::PyAnyMethods;
4use crate::types::PySequence;
5use crate::{err::DowncastError, ffi, FromPyObject, PyAny, PyResult, Python};
6use crate::{exceptions, PyErr};
7
8impl<'py, T, const N: usize> IntoPyObject<'py> for [T; N]
9where
10    T: IntoPyObject<'py>,
11{
12    type Target = PyAny;
13    type Output = Bound<'py, Self::Target>;
14    type Error = PyErr;
15
16    /// Turns [`[u8; N]`](std::array) into [`PyBytes`], all other `T`s will be turned into a [`PyList`]
17    ///
18    /// [`PyBytes`]: crate::types::PyBytes
19    /// [`PyList`]: crate::types::PyList
20    #[inline]
21    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
22        T::owned_sequence_into_pyobject(self, py, crate::conversion::private::Token)
23    }
24}
25
26impl<'a, 'py, T, const N: usize> IntoPyObject<'py> for &'a [T; N]
27where
28    &'a T: IntoPyObject<'py>,
29{
30    type Target = PyAny;
31    type Output = Bound<'py, Self::Target>;
32    type Error = PyErr;
33
34    #[inline]
35    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
36        self.as_slice().into_pyobject(py)
37    }
38}
39
40impl<'py, T, const N: usize> FromPyObject<'py> for [T; N]
41where
42    T: FromPyObject<'py>,
43{
44    fn extract_bound(obj: &Bound<'py, PyAny>) -> PyResult<Self> {
45        create_array_from_obj(obj)
46    }
47}
48
49fn create_array_from_obj<'py, T, const N: usize>(obj: &Bound<'py, PyAny>) -> PyResult<[T; N]>
50where
51    T: FromPyObject<'py>,
52{
53    // Types that pass `PySequence_Check` usually implement enough of the sequence protocol
54    // to support this function and if not, we will only fail extraction safely.
55    let seq = unsafe {
56        if ffi::PySequence_Check(obj.as_ptr()) != 0 {
57            obj.downcast_unchecked::<PySequence>()
58        } else {
59            return Err(DowncastError::new(obj, "Sequence").into());
60        }
61    };
62    let seq_len = seq.len()?;
63    if seq_len != N {
64        return Err(invalid_sequence_length(N, seq_len));
65    }
66    array_try_from_fn(|idx| seq.get_item(idx).and_then(|any| any.extract()))
67}
68
// TODO use std::array::try_from_fn, if that stabilises:
// (https://github.com/rust-lang/rust/issues/89379)
/// Builds a `[T; N]` by calling `cb(0)`, `cb(1)`, ..., `cb(N - 1)` in order.
///
/// Returns the first `Err` produced by `cb` without calling it again. If `cb` returns an
/// error or panics, every element that was already produced is dropped exactly once.
fn array_try_from_fn<E, F, T, const N: usize>(mut cb: F) -> Result<[T; N], E>
where
    F: FnMut(usize) -> Result<T, E>,
{
    // Helper to safely create arrays since the standard library doesn't
    // provide one yet. Shouldn't be necessary in the future.
    //
    // Tracks how many leading slots of the buffer hold live values so that they can be
    // dropped if construction is abandoned part-way (error return or unwinding).
    struct ArrayGuard<T, const N: usize> {
        dst: *mut T,
        initialized: usize,
    }

    impl<T, const N: usize> Drop for ArrayGuard<T, N> {
        fn drop(&mut self) {
            debug_assert!(self.initialized <= N);
            let initialized_part = core::ptr::slice_from_raw_parts_mut(self.dst, self.initialized);
            // SAFETY: `dst` is the start of the array buffer and exactly the first
            // `initialized` slots have been written, so the slice covers only live values.
            unsafe {
                core::ptr::drop_in_place(initialized_part);
            }
        }
    }

    // [MaybeUninit<T>; N] would be "nicer" but is actually difficult to create - there are nightly
    // APIs which would make this easier.
    let mut array: core::mem::MaybeUninit<[T; N]> = core::mem::MaybeUninit::uninit();
    // The buffer pointer is derived exactly once: every write and the clean-up in `Drop` go
    // through `guard.dst`, so no later reborrow of `array` sits between creating the pointer
    // and using it.
    let mut guard: ArrayGuard<T, N> = ArrayGuard {
        dst: array.as_mut_ptr().cast::<T>(),
        initialized: 0,
    };
    for idx in 0..N {
        // Run the callback outside of any `unsafe` block; on `Err` (or panic) the guard
        // drops the elements written so far.
        let value = cb(idx)?;
        // SAFETY: `guard.initialized == idx < N`, so the target slot lies inside the buffer
        // and is still uninitialized; `write` does not drop any previous contents.
        unsafe {
            guard.dst.add(guard.initialized).write(value);
        }
        guard.initialized += 1;
    }
    // All slots are live now; ownership moves to the returned array, so the guard must not
    // drop them.
    core::mem::forget(guard);
    // SAFETY: the loop above wrote all `N` elements.
    Ok(unsafe { array.assume_init() })
}
110
111fn invalid_sequence_length(expected: usize, actual: usize) -> PyErr {
112    exceptions::PyValueError::new_err(format!(
113        "expected a sequence of length {expected} (got {actual})"
114    ))
115}
116
117#[cfg(test)]
118mod tests {
119    use std::{
120        panic,
121        sync::atomic::{AtomicUsize, Ordering},
122    };
123
124    use crate::{
125        conversion::IntoPyObject,
126        ffi,
127        types::{any::PyAnyMethods, PyBytes, PyBytesMethods},
128    };
129    use crate::{types::PyList, PyResult, Python};
130
131    #[test]
132    fn array_try_from_fn() {
133        static DROP_COUNTER: AtomicUsize = AtomicUsize::new(0);
134        struct CountDrop;
135        impl Drop for CountDrop {
136            fn drop(&mut self) {
137                DROP_COUNTER.fetch_add(1, Ordering::SeqCst);
138            }
139        }
140        let _ = catch_unwind_silent(move || {
141            let _: Result<[CountDrop; 4], ()> = super::array_try_from_fn(|idx| {
142                #[allow(clippy::manual_assert)]
143                if idx == 2 {
144                    panic!("peek a boo");
145                }
146                Ok(CountDrop)
147            });
148        });
149        assert_eq!(DROP_COUNTER.load(Ordering::SeqCst), 2);
150    }
151
152    #[test]
153    fn test_extract_bytearray_to_array() {
154        Python::attach(|py| {
155            let v: [u8; 33] = py
156                .eval(
157                    ffi::c_str!("bytearray(b'abcabcabcabcabcabcabcabcabcabcabc')"),
158                    None,
159                    None,
160                )
161                .unwrap()
162                .extract()
163                .unwrap();
164            assert!(&v == b"abcabcabcabcabcabcabcabcabcabcabc");
165        })
166    }
167
168    #[test]
169    fn test_extract_small_bytearray_to_array() {
170        Python::attach(|py| {
171            let v: [u8; 3] = py
172                .eval(ffi::c_str!("bytearray(b'abc')"), None, None)
173                .unwrap()
174                .extract()
175                .unwrap();
176            assert!(&v == b"abc");
177        });
178    }
179    #[test]
180    fn test_into_pyobject_array_conversion() {
181        Python::attach(|py| {
182            let array: [f32; 4] = [0.0, -16.0, 16.0, 42.0];
183            let pyobject = array.into_pyobject(py).unwrap();
184            let pylist = pyobject.downcast::<PyList>().unwrap();
185            assert_eq!(pylist.get_item(0).unwrap().extract::<f32>().unwrap(), 0.0);
186            assert_eq!(pylist.get_item(1).unwrap().extract::<f32>().unwrap(), -16.0);
187            assert_eq!(pylist.get_item(2).unwrap().extract::<f32>().unwrap(), 16.0);
188            assert_eq!(pylist.get_item(3).unwrap().extract::<f32>().unwrap(), 42.0);
189        });
190    }
191
192    #[test]
193    fn test_extract_invalid_sequence_length() {
194        Python::attach(|py| {
195            let v: PyResult<[u8; 3]> = py
196                .eval(ffi::c_str!("bytearray(b'abcdefg')"), None, None)
197                .unwrap()
198                .extract();
199            assert_eq!(
200                v.unwrap_err().to_string(),
201                "ValueError: expected a sequence of length 3 (got 7)"
202            );
203        })
204    }
205
206    #[test]
207    fn test_intopyobject_array_conversion() {
208        Python::attach(|py| {
209            let array: [f32; 4] = [0.0, -16.0, 16.0, 42.0];
210            let pylist = array
211                .into_pyobject(py)
212                .unwrap()
213                .downcast_into::<PyList>()
214                .unwrap();
215
216            assert_eq!(pylist.get_item(0).unwrap().extract::<f32>().unwrap(), 0.0);
217            assert_eq!(pylist.get_item(1).unwrap().extract::<f32>().unwrap(), -16.0);
218            assert_eq!(pylist.get_item(2).unwrap().extract::<f32>().unwrap(), 16.0);
219            assert_eq!(pylist.get_item(3).unwrap().extract::<f32>().unwrap(), 42.0);
220        });
221    }
222
223    #[test]
224    fn test_array_intopyobject_impl() {
225        Python::attach(|py| {
226            let bytes: [u8; 6] = *b"foobar";
227            let obj = bytes.into_pyobject(py).unwrap();
228            assert!(obj.is_instance_of::<PyBytes>());
229            let obj = obj.downcast_into::<PyBytes>().unwrap();
230            assert_eq!(obj.as_bytes(), &bytes);
231
232            let nums: [u16; 4] = [0, 1, 2, 3];
233            let obj = nums.into_pyobject(py).unwrap();
234            assert!(obj.is_instance_of::<PyList>());
235        });
236    }
237
238    #[test]
239    fn test_extract_non_iterable_to_array() {
240        Python::attach(|py| {
241            let v = py.eval(ffi::c_str!("42"), None, None).unwrap();
242            v.extract::<i32>().unwrap();
243            v.extract::<[i32; 1]>().unwrap_err();
244        });
245    }
246
247    #[cfg(feature = "macros")]
248    #[test]
249    fn test_pyclass_intopy_array_conversion() {
250        #[crate::pyclass(crate = "crate")]
251        struct Foo;
252
253        Python::attach(|py| {
254            let array: [Foo; 8] = [Foo, Foo, Foo, Foo, Foo, Foo, Foo, Foo];
255            let list = array
256                .into_pyobject(py)
257                .unwrap()
258                .downcast_into::<PyList>()
259                .unwrap();
260            let _bound = list.get_item(4).unwrap().downcast::<Foo>().unwrap();
261        });
262    }
263
264    // https://stackoverflow.com/a/59211505
265    fn catch_unwind_silent<F, R>(f: F) -> std::thread::Result<R>
266    where
267        F: FnOnce() -> R + panic::UnwindSafe,
268    {
269        let prev_hook = panic::take_hook();
270        panic::set_hook(Box::new(|_| {}));
271        let result = panic::catch_unwind(f);
272        panic::set_hook(prev_hook);
273        result
274    }
275}
⚠️ Internal Docs ⚠️ Not Public API 👉 Official Docs Here