pyo3/conversions/
num_complex.rs1#![cfg(feature = "num-complex")]
2
3#![doc = concat!("pyo3 = { version = \"", env!("CARGO_PKG_VERSION"), "\", features = [\"num-complex\"] }")]
18use crate::{
97 ffi,
98 ffi_ptr_ext::FfiPtrExt,
99 types::{any::PyAnyMethods, PyComplex},
100 Bound, FromPyObject, PyAny, PyErr, PyResult, Python,
101};
102use num_complex::Complex;
103use std::os::raw::c_double;
104
105impl PyComplex {
106 pub fn from_complex_bound<F: Into<c_double>>(
108 py: Python<'_>,
109 complex: Complex<F>,
110 ) -> Bound<'_, PyComplex> {
111 unsafe {
112 ffi::PyComplex_FromDoubles(complex.re.into(), complex.im.into())
113 .assume_owned(py)
114 .downcast_into_unchecked()
115 }
116 }
117}
118
119macro_rules! complex_conversion {
120 ($float: ty) => {
121 #[cfg_attr(docsrs, doc(cfg(feature = "num-complex")))]
122 impl<'py> crate::conversion::IntoPyObject<'py> for Complex<$float> {
123 type Target = PyComplex;
124 type Output = Bound<'py, Self::Target>;
125 type Error = std::convert::Infallible;
126
127 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
128 unsafe {
129 Ok(
130 ffi::PyComplex_FromDoubles(self.re as c_double, self.im as c_double)
131 .assume_owned(py)
132 .downcast_into_unchecked(),
133 )
134 }
135 }
136 }
137
138 #[cfg_attr(docsrs, doc(cfg(feature = "num-complex")))]
139 impl<'py> crate::conversion::IntoPyObject<'py> for &Complex<$float> {
140 type Target = PyComplex;
141 type Output = Bound<'py, Self::Target>;
142 type Error = std::convert::Infallible;
143
144 #[inline]
145 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
146 (*self).into_pyobject(py)
147 }
148 }
149
150 #[cfg_attr(docsrs, doc(cfg(feature = "num-complex")))]
151 impl FromPyObject<'_> for Complex<$float> {
152 fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Complex<$float>> {
153 #[cfg(not(any(Py_LIMITED_API, PyPy)))]
154 unsafe {
155 let val = ffi::PyComplex_AsCComplex(obj.as_ptr());
156 if val.real == -1.0 {
157 if let Some(err) = PyErr::take(obj.py()) {
158 return Err(err);
159 }
160 }
161 Ok(Complex::new(val.real as $float, val.imag as $float))
162 }
163
164 #[cfg(any(Py_LIMITED_API, PyPy))]
165 unsafe {
166 let complex;
167 let obj = if obj.is_instance_of::<PyComplex>() {
168 obj
169 } else if let Some(method) =
170 obj.lookup_special(crate::intern!(obj.py(), "__complex__"))?
171 {
172 complex = method.call0()?;
173 &complex
174 } else {
175 obj
179 };
180 let ptr = obj.as_ptr();
181 let real = ffi::PyComplex_RealAsDouble(ptr);
182 if real == -1.0 {
183 if let Some(err) = PyErr::take(obj.py()) {
184 return Err(err);
185 }
186 }
187 let imag = ffi::PyComplex_ImagAsDouble(ptr);
188 Ok(Complex::new(real as $float, imag as $float))
189 }
190 }
191 }
192 };
193}
194complex_conversion!(f32);
195complex_conversion!(f64);
196
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::common::generate_unique_module_name;
    use crate::types::{complex::PyComplexMethods, PyModule};
    use crate::IntoPyObject;
    use pyo3_ffi::c_str;

    /// Executes `code` as a freshly, uniquely named Python module (reported as `test.py`),
    /// panicking if loading fails.
    fn load_module<'py>(py: Python<'py>, code: &std::ffi::CStr) -> Bound<'py, PyModule> {
        let module_name = generate_unique_module_name("test");
        PyModule::from_code(py, code, c_str!("test.py"), &module_name).unwrap()
    }

    /// Looks up `class_name` in `module` and calls it with no arguments.
    fn instantiate<'py>(module: &Bound<'py, PyModule>, class_name: &str) -> Bound<'py, PyAny> {
        module.getattr(class_name).unwrap().call0().unwrap()
    }

    /// Extracts `obj` as a `Complex<f64>`, panicking if the conversion fails.
    fn as_c64(obj: &Bound<'_, PyAny>) -> Complex<f64> {
        obj.extract::<Complex<f64>>().unwrap()
    }

    #[test]
    fn from_complex() {
        Python::with_gil(|py| {
            let py_c = PyComplex::from_complex_bound(py, Complex::new(3.0, 1.2));
            assert_eq!((py_c.real(), py_c.imag()), (3.0, 1.2));
        });
    }

    #[test]
    fn to_from_complex() {
        Python::with_gil(|py| {
            let original = Complex::new(3.0f64, 1.2);
            let round_tripped = original.into_pyobject(py).unwrap();
            assert_eq!(as_c64(&round_tripped), original);
        });
    }

    #[test]
    fn from_complex_err() {
        Python::with_gil(|py| {
            let list = vec![1i32].into_pyobject(py).unwrap();
            let extracted = list.extract::<Complex<f64>>();
            assert!(extracted.is_err());
        });
    }

    #[test]
    fn from_python_magic() {
        Python::with_gil(|py| {
            let module = load_module(
                py,
                c_str!(
                    r#"
class A:
    def __complex__(self): return 3.0+1.2j
class B:
    def __float__(self): return 3.0
class C:
    def __index__(self): return 3
                "#
                ),
            );
            let cases = [("A", Complex::new(3.0, 1.2)), ("B", Complex::new(3.0, 0.0))];
            for (class_name, expected) in cases {
                assert_eq!(as_c64(&instantiate(&module, class_name)), expected);
            }
            // The `__index__` fallback of CPython's float conversions exists only on 3.8+.
            #[cfg(Py_3_8)]
            {
                assert_eq!(as_c64(&instantiate(&module, "C")), Complex::new(3.0, 0.0));
            }
        })
    }

    #[test]
    fn from_python_inherited_magic() {
        Python::with_gil(|py| {
            let module = load_module(
                py,
                c_str!(
                    r#"
class First: pass
class ComplexMixin:
    def __complex__(self): return 3.0+1.2j
class FloatMixin:
    def __float__(self): return 3.0
class IndexMixin:
    def __index__(self): return 3
class A(First, ComplexMixin): pass
class B(First, FloatMixin): pass
class C(First, IndexMixin): pass
                "#
                ),
            );
            let cases = [("A", Complex::new(3.0, 1.2)), ("B", Complex::new(3.0, 0.0))];
            for (class_name, expected) in cases {
                assert_eq!(as_c64(&instantiate(&module, class_name)), expected);
            }
            #[cfg(Py_3_8)]
            {
                assert_eq!(as_c64(&instantiate(&module, "C")), Complex::new(3.0, 0.0));
            }
        })
    }

    #[test]
    fn from_python_noncallable_descriptor_magic() {
        Python::with_gil(|py| {
            // Here `__complex__` is a `property`: the object found on the type is not itself
            // callable, so the special-method lookup has to go through the descriptor
            // protocol to obtain the lambda, which is then called.
            let module = load_module(
                py,
                c_str!(
                    r#"
class A:
    @property
    def __complex__(self):
        return lambda: 3.0+1.2j
                "#
                ),
            );
            let instance = instantiate(&module, "A");
            assert_eq!(as_c64(&instance), Complex::new(3.0, 1.2));
        })
    }

    #[test]
    fn from_python_nondescriptor_magic() {
        Python::with_gil(|py| {
            // `MyComplex` instances are callable but define no `__get__`, so the attribute
            // found on `A` is used as-is (not bound to the instance) and called without
            // arguments.
            let module = load_module(
                py,
                c_str!(
                    r#"
class MyComplex:
    def __call__(self): return 3.0+1.2j
class A:
    __complex__ = MyComplex()
                "#
                ),
            );
            let instance = instantiate(&module, "A");
            assert_eq!(as_c64(&instance), Complex::new(3.0, 1.2));
        })
    }
}