pyo3/conversions/std/array.rs

1use crate::conversion::IntoPyObject;
2use crate::instance::Bound;
3use crate::types::any::PyAnyMethods;
4use crate::types::PySequence;
5use crate::{err::DowncastError, ffi, FromPyObject, PyAny, PyResult, Python};
6use crate::{exceptions, PyErr};
7
8impl<'py, T, const N: usize> IntoPyObject<'py> for [T; N]
9where
10    T: IntoPyObject<'py>,
11{
12    type Target = PyAny;
13    type Output = Bound<'py, Self::Target>;
14    type Error = PyErr;
15
16    /// Turns [`[u8; N]`](std::array) into [`PyBytes`], all other `T`s will be turned into a [`PyList`]
17    ///
18    /// [`PyBytes`]: crate::types::PyBytes
19    /// [`PyList`]: crate::types::PyList
20    #[inline]
21    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
22        T::owned_sequence_into_pyobject(self, py, crate::conversion::private::Token)
23    }
24}
25
26impl<'a, 'py, T, const N: usize> IntoPyObject<'py> for &'a [T; N]
27where
28    &'a T: IntoPyObject<'py>,
29{
30    type Target = PyAny;
31    type Output = Bound<'py, Self::Target>;
32    type Error = PyErr;
33
34    #[inline]
35    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
36        self.as_slice().into_pyobject(py)
37    }
38}
39
40impl<'py, T, const N: usize> FromPyObject<'py> for [T; N]
41where
42    T: FromPyObject<'py>,
43{
44    fn extract_bound(obj: &Bound<'py, PyAny>) -> PyResult<Self> {
45        create_array_from_obj(obj)
46    }
47}
48
49fn create_array_from_obj<'py, T, const N: usize>(obj: &Bound<'py, PyAny>) -> PyResult<[T; N]>
50where
51    T: FromPyObject<'py>,
52{
53    // Types that pass `PySequence_Check` usually implement enough of the sequence protocol
54    // to support this function and if not, we will only fail extraction safely.
55    let seq = unsafe {
56        if ffi::PySequence_Check(obj.as_ptr()) != 0 {
57            obj.downcast_unchecked::<PySequence>()
58        } else {
59            return Err(DowncastError::new(obj, "Sequence").into());
60        }
61    };
62    let seq_len = seq.len()?;
63    if seq_len != N {
64        return Err(invalid_sequence_length(N, seq_len));
65    }
66    array_try_from_fn(|idx| seq.get_item(idx).and_then(|any| any.extract()))
67}
68
// TODO: use `std::array::try_from_fn` if it stabilises
// (https://github.com/rust-lang/rust/issues/89379).
/// Builds `[T; N]` by calling `cb(0)`, `cb(1)`, ..., `cb(N - 1)` in that order.
///
/// Returns the first `Err` produced by `cb`; `cb` is not called again after that. If `cb`
/// returns an error or panics part-way through, the elements built so far are dropped before
/// the error (or panic) propagates, so nothing leaks.
fn array_try_from_fn<E, F, T, const N: usize>(mut cb: F) -> Result<[T; N], E>
where
    F: FnMut(usize) -> Result<T, E>,
{
    /// Drops the initialized prefix `dst[..initialized]` when construction is abandoned.
    struct ArrayGuard<T, const N: usize> {
        dst: *mut T,
        initialized: usize,
    }

    impl<T, const N: usize> Drop for ArrayGuard<T, N> {
        fn drop(&mut self) {
            debug_assert!(self.initialized <= N);
            let initialized_part = core::ptr::slice_from_raw_parts_mut(self.dst, self.initialized);
            // SAFETY: exactly the first `initialized` slots were written through `dst`, and
            // none of them has been moved out, so they are valid values to drop exactly once.
            unsafe {
                core::ptr::drop_in_place(initialized_part);
            }
        }
    }

    let mut array = core::mem::MaybeUninit::<[T; N]>::uninit();
    // Derive ONE raw pointer and perform every write, and the guard's cleanup, through it.
    // Calling `array.as_mut_ptr()` a second time would create a fresh `&mut array`. Under the
    // Stacked/Tree Borrows aliasing models checked by Miri, that invalidates the guard's
    // pointer, making the cleanup in `ArrayGuard::drop` undefined behaviour.
    let mut guard = ArrayGuard::<T, N> {
        dst: array.as_mut_ptr().cast::<T>(),
        initialized: 0,
    };

    for i in 0..N {
        // Call back outside `unsafe`: an early `?` return or a panic drops `guard`, which in
        // turn drops the `i` elements written so far.
        let value = cb(i)?;
        // SAFETY: `i < N`, so `dst.add(i)` stays inside the array allocation, and slot `i`
        // is still uninitialized, so the write does not overwrite (leak) a live value.
        unsafe { guard.dst.add(i).write(value) };
        guard.initialized += 1;
    }

    // All `N` slots are initialized and ownership moves to the returned array, so the guard
    // must not drop them.
    core::mem::forget(guard);
    // SAFETY: the loop above wrote every one of the `N` elements.
    Ok(unsafe { array.assume_init() })
}
110
111fn invalid_sequence_length(expected: usize, actual: usize) -> PyErr {
112    exceptions::PyValueError::new_err(format!(
113        "expected a sequence of length {} (got {})",
114        expected, actual
115    ))
116}
117
#[cfg(test)]
mod tests {
    use std::{
        panic,
        sync::atomic::{AtomicUsize, Ordering},
    };

    use crate::{
        conversion::IntoPyObject,
        ffi,
        types::{any::PyAnyMethods, PyBytes, PyBytesMethods},
    };
    use crate::{types::PyList, PyResult, Python};

    /// A panic in the callback must drop exactly the elements built before it.
    #[test]
    fn array_try_from_fn() {
        static DROP_COUNTER: AtomicUsize = AtomicUsize::new(0);
        struct CountDrop;
        impl Drop for CountDrop {
            fn drop(&mut self) {
                DROP_COUNTER.fetch_add(1, Ordering::SeqCst);
            }
        }
        let _ = catch_unwind_silent(move || {
            let _: Result<[CountDrop; 4], ()> = super::array_try_from_fn(|idx| {
                // The explicit `panic!` is the behavior under test; an `assert!` would hide it.
                #[allow(clippy::manual_assert)]
                if idx == 2 {
                    panic!("peek a boo");
                }
                Ok(CountDrop)
            });
        });
        assert_eq!(DROP_COUNTER.load(Ordering::SeqCst), 2);
    }

    /// An `Err` from the callback must drop exactly the elements built before it.
    ///
    /// Unlike the zero-sized `CountDrop` above, this element type owns heap memory. The
    /// guard's cleanup therefore performs real (non-zero-sized) memory accesses, which tools
    /// such as Miri can check.
    #[test]
    fn array_try_from_fn_error_drops_initialized_prefix() {
        static DROP_COUNTER: AtomicUsize = AtomicUsize::new(0);
        #[allow(dead_code)] // the payload only exists to make the type non-zero-sized
        struct CountDrop(Box<usize>);
        impl Drop for CountDrop {
            fn drop(&mut self) {
                DROP_COUNTER.fetch_add(1, Ordering::SeqCst);
            }
        }
        let result: Result<[CountDrop; 5], &str> = super::array_try_from_fn(|idx| {
            if idx == 3 {
                Err("stop")
            } else {
                Ok(CountDrop(Box::new(idx)))
            }
        });
        assert!(matches!(result, Err("stop")));
        assert_eq!(DROP_COUNTER.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_extract_bytearray_to_array() {
        Python::with_gil(|py| {
            let v: [u8; 33] = py
                .eval(
                    ffi::c_str!("bytearray(b'abcabcabcabcabcabcabcabcabcabcabc')"),
                    None,
                    None,
                )
                .unwrap()
                .extract()
                .unwrap();
            assert!(&v == b"abcabcabcabcabcabcabcabcabcabcabc");
        })
    }

    #[test]
    fn test_extract_small_bytearray_to_array() {
        Python::with_gil(|py| {
            let v: [u8; 3] = py
                .eval(ffi::c_str!("bytearray(b'abc')"), None, None)
                .unwrap()
                .extract()
                .unwrap();
            assert!(&v == b"abc");
        });
    }

    #[test]
    fn test_into_pyobject_array_conversion() {
        Python::with_gil(|py| {
            let array: [f32; 4] = [0.0, -16.0, 16.0, 42.0];
            let pyobject = array.into_pyobject(py).unwrap();
            let pylist = pyobject.downcast::<PyList>().unwrap();
            assert_eq!(pylist.get_item(0).unwrap().extract::<f32>().unwrap(), 0.0);
            assert_eq!(pylist.get_item(1).unwrap().extract::<f32>().unwrap(), -16.0);
            assert_eq!(pylist.get_item(2).unwrap().extract::<f32>().unwrap(), 16.0);
            assert_eq!(pylist.get_item(3).unwrap().extract::<f32>().unwrap(), 42.0);
        });
    }

    #[test]
    fn test_extract_invalid_sequence_length() {
        Python::with_gil(|py| {
            let v: PyResult<[u8; 3]> = py
                .eval(ffi::c_str!("bytearray(b'abcdefg')"), None, None)
                .unwrap()
                .extract();
            assert_eq!(
                v.unwrap_err().to_string(),
                "ValueError: expected a sequence of length 3 (got 7)"
            );
        })
    }

    #[test]
    fn test_intopyobject_array_conversion() {
        Python::with_gil(|py| {
            let array: [f32; 4] = [0.0, -16.0, 16.0, 42.0];
            let pylist = array
                .into_pyobject(py)
                .unwrap()
                .downcast_into::<PyList>()
                .unwrap();

            assert_eq!(pylist.get_item(0).unwrap().extract::<f32>().unwrap(), 0.0);
            assert_eq!(pylist.get_item(1).unwrap().extract::<f32>().unwrap(), -16.0);
            assert_eq!(pylist.get_item(2).unwrap().extract::<f32>().unwrap(), 16.0);
            assert_eq!(pylist.get_item(3).unwrap().extract::<f32>().unwrap(), 42.0);
        });
    }

    #[test]
    fn test_array_intopyobject_impl() {
        Python::with_gil(|py| {
            let bytes: [u8; 6] = *b"foobar";
            let obj = bytes.into_pyobject(py).unwrap();
            assert!(obj.is_instance_of::<PyBytes>());
            let obj = obj.downcast_into::<PyBytes>().unwrap();
            assert_eq!(obj.as_bytes(), &bytes);

            let nums: [u16; 4] = [0, 1, 2, 3];
            let obj = nums.into_pyobject(py).unwrap();
            assert!(obj.is_instance_of::<PyList>());
        });
    }

    #[test]
    fn test_extract_non_iterable_to_array() {
        Python::with_gil(|py| {
            let v = py.eval(ffi::c_str!("42"), None, None).unwrap();
            v.extract::<i32>().unwrap();
            v.extract::<[i32; 1]>().unwrap_err();
        });
    }

    #[cfg(feature = "macros")]
    #[test]
    fn test_pyclass_intopy_array_conversion() {
        #[crate::pyclass(crate = "crate")]
        struct Foo;

        Python::with_gil(|py| {
            let array: [Foo; 8] = [Foo, Foo, Foo, Foo, Foo, Foo, Foo, Foo];
            let list = array
                .into_pyobject(py)
                .unwrap()
                .downcast_into::<PyList>()
                .unwrap();
            let _bound = list.get_item(4).unwrap().downcast::<Foo>().unwrap();
        });
    }

    /// Runs `f`, catching any panic without printing the panic message.
    ///
    /// NOTE: the panic hook is process-global. While it is swapped out, panic messages from
    /// concurrently running tests are also suppressed.
    // https://stackoverflow.com/a/59211505
    fn catch_unwind_silent<F, R>(f: F) -> std::thread::Result<R>
    where
        F: FnOnce() -> R + panic::UnwindSafe,
    {
        let prev_hook = panic::take_hook();
        panic::set_hook(Box::new(|_| {}));
        let result = panic::catch_unwind(f);
        panic::set_hook(prev_hook);
        result
    }
}
⚠️ Internal Docs ⚠️ Not Public API 👉 Official Docs Here