1use crate::conversion::IntoPyObject;
2use crate::instance::Bound;
3use crate::types::any::PyAnyMethods;
4use crate::types::PySequence;
5use crate::{err::DowncastError, ffi, FromPyObject, PyAny, PyResult, Python};
6use crate::{exceptions, PyErr};
7
8impl<'py, T, const N: usize> IntoPyObject<'py> for [T; N]
9where
10 T: IntoPyObject<'py>,
11{
12 type Target = PyAny;
13 type Output = Bound<'py, Self::Target>;
14 type Error = PyErr;
15
16 #[inline]
21 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
22 T::owned_sequence_into_pyobject(self, py, crate::conversion::private::Token)
23 }
24}
25
26impl<'a, 'py, T, const N: usize> IntoPyObject<'py> for &'a [T; N]
27where
28 &'a T: IntoPyObject<'py>,
29{
30 type Target = PyAny;
31 type Output = Bound<'py, Self::Target>;
32 type Error = PyErr;
33
34 #[inline]
35 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
36 self.as_slice().into_pyobject(py)
37 }
38}
39
40impl<'py, T, const N: usize> FromPyObject<'py> for [T; N]
41where
42 T: FromPyObject<'py>,
43{
44 fn extract_bound(obj: &Bound<'py, PyAny>) -> PyResult<Self> {
45 create_array_from_obj(obj)
46 }
47}
48
49fn create_array_from_obj<'py, T, const N: usize>(obj: &Bound<'py, PyAny>) -> PyResult<[T; N]>
50where
51 T: FromPyObject<'py>,
52{
53 let seq = unsafe {
56 if ffi::PySequence_Check(obj.as_ptr()) != 0 {
57 obj.downcast_unchecked::<PySequence>()
58 } else {
59 return Err(DowncastError::new(obj, "Sequence").into());
60 }
61 };
62 let seq_len = seq.len()?;
63 if seq_len != N {
64 return Err(invalid_sequence_length(N, seq_len));
65 }
66 array_try_from_fn(|idx| seq.get_item(idx).and_then(|any| any.extract()))
67}
68
/// Builds a `[T; N]` by calling `cb(0)`, `cb(1)`, ..., `cb(N - 1)` in order.
///
/// This is a stable stand-in for the unstable `core::array::try_from_fn`.
///
/// # Errors
///
/// Returns the first `Err` produced by `cb`. No later index is visited after that.
///
/// # Panics
///
/// Propagates any panic raised by `cb`.
///
/// On both the error and the panic path, every element already produced is
/// dropped exactly once before control leaves this function.
fn array_try_from_fn<E, F, T, const N: usize>(mut cb: F) -> Result<[T; N], E>
where
    F: FnMut(usize) -> Result<T, E>,
{
    /// Drop guard owning the initialized prefix `dst[..initialized]` of the
    /// array under construction, so that prefix is cleaned up on early return
    /// (`?`) or unwinding.
    struct ArrayGuard<T, const N: usize> {
        dst: *mut T,
        initialized: usize,
    }

    impl<T, const N: usize> Drop for ArrayGuard<T, N> {
        fn drop(&mut self) {
            debug_assert!(self.initialized <= N);
            let initialized_part = core::ptr::slice_from_raw_parts_mut(self.dst, self.initialized);
            // SAFETY: exactly the first `initialized` slots behind `dst` were
            // written and have not been moved out, so dropping them is sound.
            unsafe {
                core::ptr::drop_in_place(initialized_part);
            }
        }
    }

    // `[MaybeUninit<T>; N]` would be nicer but is awkward to create on stable,
    // so the whole array is kept uninitialized instead.
    let mut array: core::mem::MaybeUninit<[T; N]> = core::mem::MaybeUninit::uninit();
    // Derive the element pointer exactly once. All writes and the guard's cleanup
    // go through this same pointer. A second `&mut array` borrow would invalidate
    // it under Rust's aliasing rules (Stacked/Tree Borrows).
    let mut guard: ArrayGuard<T, N> = ArrayGuard {
        dst: array.as_mut_ptr().cast::<T>(),
        initialized: 0,
    };
    for i in 0..N {
        let value = cb(i)?;
        // SAFETY: `i < N`, so `dst.add(i)` stays inside the array allocation.
        // Slot `i` is still uninitialized, so `write` overwrites no live value.
        unsafe { guard.dst.add(i).write(value) };
        guard.initialized += 1;
    }
    // Every slot is initialized; ownership moves to the returned array.
    core::mem::forget(guard);
    // SAFETY: the loop above wrote all `N` elements.
    Ok(unsafe { array.assume_init() })
}
110
111fn invalid_sequence_length(expected: usize, actual: usize) -> PyErr {
112 exceptions::PyValueError::new_err(format!(
113 "expected a sequence of length {} (got {})",
114 expected, actual
115 ))
116}
117
#[cfg(test)]
mod tests {
    use std::{
        panic,
        rc::Rc,
        sync::atomic::{AtomicUsize, Ordering},
    };

    use crate::{
        conversion::IntoPyObject,
        exceptions::PyTypeError,
        ffi,
        types::{any::PyAnyMethods, PyBytes, PyBytesMethods},
    };
    use crate::{types::PyList, PyResult, Python};

    /// A panic inside the element callback must drop exactly the elements already
    /// built. The panic fires at index 2, so two elements are dropped.
    #[test]
    fn array_try_from_fn() {
        static DROP_COUNTER: AtomicUsize = AtomicUsize::new(0);
        struct CountDrop;
        impl Drop for CountDrop {
            fn drop(&mut self) {
                DROP_COUNTER.fetch_add(1, Ordering::SeqCst);
            }
        }
        let _ = catch_unwind_silent(move || {
            let _: Result<[CountDrop; 4], ()> = super::array_try_from_fn(|idx| {
                // A real panic mid-construction is the point of this test.
                #[allow(clippy::manual_assert)]
                if idx == 2 {
                    panic!("peek a boo");
                }
                Ok(CountDrop)
            });
        });
        assert_eq!(DROP_COUNTER.load(Ordering::SeqCst), 2);
    }

    /// An `Err` from the element callback stops construction immediately and
    /// drops the prefix that was already built.
    #[test]
    fn array_try_from_fn_drops_prefix_on_error() {
        let tracker = Rc::new(());
        let mut calls = Vec::new();
        let result: Result<[Rc<()>; 4], usize> = super::array_try_from_fn(|idx| {
            calls.push(idx);
            if idx == 2 {
                Err(idx)
            } else {
                Ok(Rc::clone(&tracker))
            }
        });
        assert_eq!(result.err(), Some(2));
        // No index after the failing one is visited.
        assert_eq!(calls, [0, 1, 2]);
        // The two clones stored before the error were released by the guard.
        assert_eq!(Rc::strong_count(&tracker), 1);
    }

    #[test]
    fn test_extract_bytearray_to_array() {
        Python::with_gil(|py| {
            let v: [u8; 33] = py
                .eval(
                    ffi::c_str!("bytearray(b'abcabcabcabcabcabcabcabcabcabcabc')"),
                    None,
                    None,
                )
                .unwrap()
                .extract()
                .unwrap();
            assert!(&v == b"abcabcabcabcabcabcabcabcabcabcabc");
        })
    }

    #[test]
    fn test_extract_small_bytearray_to_array() {
        Python::with_gil(|py| {
            let v: [u8; 3] = py
                .eval(ffi::c_str!("bytearray(b'abc')"), None, None)
                .unwrap()
                .extract()
                .unwrap();
            assert!(&v == b"abc");
        });
    }

    #[test]
    fn test_into_pyobject_array_conversion() {
        Python::with_gil(|py| {
            let array: [f32; 4] = [0.0, -16.0, 16.0, 42.0];
            let pyobject = array.into_pyobject(py).unwrap();
            let pylist = pyobject.downcast::<PyList>().unwrap();
            assert_eq!(pylist.get_item(0).unwrap().extract::<f32>().unwrap(), 0.0);
            assert_eq!(pylist.get_item(1).unwrap().extract::<f32>().unwrap(), -16.0);
            assert_eq!(pylist.get_item(2).unwrap().extract::<f32>().unwrap(), 16.0);
            assert_eq!(pylist.get_item(3).unwrap().extract::<f32>().unwrap(), 42.0);
        });
    }

    #[test]
    fn test_extract_invalid_sequence_length() {
        Python::with_gil(|py| {
            let v: PyResult<[u8; 3]> = py
                .eval(ffi::c_str!("bytearray(b'abcdefg')"), None, None)
                .unwrap()
                .extract();
            assert_eq!(
                v.unwrap_err().to_string(),
                "ValueError: expected a sequence of length 3 (got 7)"
            );
        })
    }

    #[test]
    fn test_intopyobject_array_conversion() {
        Python::with_gil(|py| {
            let array: [f32; 4] = [0.0, -16.0, 16.0, 42.0];
            let pylist = array
                .into_pyobject(py)
                .unwrap()
                .downcast_into::<PyList>()
                .unwrap();

            assert_eq!(pylist.get_item(0).unwrap().extract::<f32>().unwrap(), 0.0);
            assert_eq!(pylist.get_item(1).unwrap().extract::<f32>().unwrap(), -16.0);
            assert_eq!(pylist.get_item(2).unwrap().extract::<f32>().unwrap(), 16.0);
            assert_eq!(pylist.get_item(3).unwrap().extract::<f32>().unwrap(), 42.0);
        });
    }

    #[test]
    fn test_array_intopyobject_impl() {
        Python::with_gil(|py| {
            let bytes: [u8; 6] = *b"foobar";
            let obj = bytes.into_pyobject(py).unwrap();
            assert!(obj.is_instance_of::<PyBytes>());
            let obj = obj.downcast_into::<PyBytes>().unwrap();
            assert_eq!(obj.as_bytes(), &bytes);

            let nums: [u16; 4] = [0, 1, 2, 3];
            let obj = nums.into_pyobject(py).unwrap();
            assert!(obj.is_instance_of::<PyList>());
        });
    }

    /// `&[T; N]` converts through the slice impl. `&[u8; N]` becomes `bytes`;
    /// other element types become a `list`.
    #[test]
    fn test_borrowed_array_intopyobject_impl() {
        Python::with_gil(|py| {
            let bytes: [u8; 6] = *b"foobar";
            let obj = (&bytes).into_pyobject(py).unwrap();
            assert!(obj.is_instance_of::<PyBytes>());
            let obj = obj.downcast_into::<PyBytes>().unwrap();
            assert_eq!(obj.as_bytes(), &bytes);

            let nums: [u16; 3] = [7, 8, 9];
            let list = (&nums)
                .into_pyobject(py)
                .unwrap()
                .downcast_into::<PyList>()
                .unwrap();
            assert_eq!(list.get_item(0).unwrap().extract::<u16>().unwrap(), 7);
            assert_eq!(list.get_item(1).unwrap().extract::<u16>().unwrap(), 8);
            assert_eq!(list.get_item(2).unwrap().extract::<u16>().unwrap(), 9);
        });
    }

    #[test]
    fn test_extract_non_iterable_to_array() {
        Python::with_gil(|py| {
            let v = py.eval(ffi::c_str!("42"), None, None).unwrap();
            v.extract::<i32>().unwrap();
            let err = v.extract::<[i32; 1]>().unwrap_err();
            // Non-sequences are rejected with a downcast error, surfaced as `TypeError`.
            assert!(err.is_instance_of::<PyTypeError>(py));
        });
    }

    #[cfg(feature = "macros")]
    #[test]
    fn test_pyclass_intopy_array_conversion() {
        #[crate::pyclass(crate = "crate")]
        struct Foo;

        Python::with_gil(|py| {
            let array: [Foo; 8] = [Foo, Foo, Foo, Foo, Foo, Foo, Foo, Foo];
            let list = array
                .into_pyobject(py)
                .unwrap()
                .downcast_into::<PyList>()
                .unwrap();
            let _bound = list.get_item(4).unwrap().downcast::<Foo>().unwrap();
        });
    }

    /// Runs `f`, catching any panic without printing the default panic message.
    ///
    /// NOTE: the panic hook is process-global. While this runs, a panic in a
    /// concurrently running test may also be silenced.
    fn catch_unwind_silent<F, R>(f: F) -> std::thread::Result<R>
    where
        F: FnOnce() -> R + panic::UnwindSafe,
    {
        let prev_hook = panic::take_hook();
        panic::set_hook(Box::new(|_| {}));
        let result = panic::catch_unwind(f);
        panic::set_hook(prev_hook);
        result
    }
}