pyo3/conversions/
num_complex.rs

1#![cfg(feature = "num-complex")]
2
//! Conversions to and from [num-complex](https://docs.rs/num-complex)'s
//! [`Complex`]`<`[`f32`]`>` and [`Complex`]`<`[`f64`]`>`.
5//!
//! num-complex's [`Complex`] supports more operations than PyO3's [`PyComplex`]
//! and can be used with the rest of the Rust ecosystem.
8//!
9//! # Setup
10//!
11//! To use this feature, add this to your **`Cargo.toml`**:
12//!
13//! ```toml
14//! [dependencies]
15//! # change * to the latest versions
16//! num-complex = "*"
17#![doc = concat!("pyo3 = { version = \"", env!("CARGO_PKG_VERSION"),  "\", features = [\"num-complex\"] }")]
18//! ```
19//!
20//! Note that you must use compatible versions of num-complex and PyO3.
21//! The required num-complex version may vary based on the version of PyO3.
22//!
23//! # Examples
24//!
25//! Using [num-complex](https://docs.rs/num-complex) and [nalgebra](https://docs.rs/nalgebra)
26//! to create a pyfunction that calculates the eigenvalues of a 2x2 matrix.
27//! ```ignore
28//! # // not tested because nalgebra isn't supported on msrv
29//! # // please file an issue if it breaks!
30//! use nalgebra::base::{dimension::Const, Matrix};
31//! use num_complex::Complex;
32//! use pyo3::prelude::*;
33//!
34//! type T = Complex<f64>;
35//!
36//! #[pyfunction]
37//! fn get_eigenvalues(m11: T, m12: T, m21: T, m22: T) -> Vec<T> {
38//!     let mat = Matrix::<T, Const<2>, Const<2>, _>::new(m11, m12, m21, m22);
39//!
40//!     match mat.eigenvalues() {
41//!         Some(e) => e.data.as_slice().to_vec(),
42//!         None => vec![],
43//!     }
44//! }
45//!
46//! #[pymodule]
47//! fn my_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
48//!     m.add_function(wrap_pyfunction!(get_eigenvalues, m)?)?;
49//!     Ok(())
50//! }
51//! # // test
52//! # use assert_approx_eq::assert_approx_eq;
53//! # use nalgebra::ComplexField;
54//! # use pyo3::types::PyComplex;
55//! #
56//! # fn main() -> PyResult<()> {
57//! #     Python::with_gil(|py| -> PyResult<()> {
58//! #         let module = PyModule::new(py, "my_module")?;
59//! #
//! #         module.add_function(wrap_pyfunction!(get_eigenvalues, module)?)?;
61//! #
62//! #         let m11 = PyComplex::from_doubles(py, 0_f64, -1_f64);
63//! #         let m12 = PyComplex::from_doubles(py, 1_f64, 0_f64);
64//! #         let m21 = PyComplex::from_doubles(py, 2_f64, -1_f64);
65//! #         let m22 = PyComplex::from_doubles(py, -1_f64, 0_f64);
66//! #
67//! #         let result = module
68//! #             .getattr("get_eigenvalues")?
69//! #             .call1((m11, m12, m21, m22))?;
70//! #         println!("eigenvalues: {:?}", result);
71//! #
72//! #         let result = result.extract::<Vec<T>>()?;
73//! #         let e0 = result[0];
74//! #         let e1 = result[1];
75//! #
76//! #         assert_approx_eq!(e0, Complex::new(1_f64, -1_f64));
77//! #         assert_approx_eq!(e1, Complex::new(-2_f64, 0_f64));
78//! #
79//! #         Ok(())
80//! #     })
81//! # }
82//! ```
83//!
84//! Python code:
85//! ```python
86//! from my_module import get_eigenvalues
87//!
88//! m11 = complex(0,-1)
89//! m12 = complex(1,0)
90//! m21 = complex(2,-1)
91//! m22 = complex(-1,0)
92//!
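//! result = get_eigenvalues(m11,m12,m21,m22)
//! expected = [complex(1,-1), complex(-2,0)]
//! # eigenvalues are floating-point results, so compare with a tolerance
//! assert all(abs(r - e) < 1e-9 for r, e in zip(result, expected))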
95//! ```
96use crate::{
97    ffi,
98    ffi_ptr_ext::FfiPtrExt,
99    types::{any::PyAnyMethods, PyComplex},
100    Bound, FromPyObject, PyAny, PyErr, PyResult, Python,
101};
102use num_complex::Complex;
103use std::os::raw::c_double;
104
105impl PyComplex {
106    /// Creates a new Python `PyComplex` object from `num_complex`'s [`Complex`].
107    pub fn from_complex_bound<F: Into<c_double>>(
108        py: Python<'_>,
109        complex: Complex<F>,
110    ) -> Bound<'_, PyComplex> {
111        unsafe {
112            ffi::PyComplex_FromDoubles(complex.re.into(), complex.im.into())
113                .assume_owned(py)
114                .downcast_into_unchecked()
115        }
116    }
117}
118
119macro_rules! complex_conversion {
120    ($float: ty) => {
121        #[cfg_attr(docsrs, doc(cfg(feature = "num-complex")))]
122        impl<'py> crate::conversion::IntoPyObject<'py> for Complex<$float> {
123            type Target = PyComplex;
124            type Output = Bound<'py, Self::Target>;
125            type Error = std::convert::Infallible;
126
127            fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
128                unsafe {
129                    Ok(
130                        ffi::PyComplex_FromDoubles(self.re as c_double, self.im as c_double)
131                            .assume_owned(py)
132                            .downcast_into_unchecked(),
133                    )
134                }
135            }
136        }
137
138        #[cfg_attr(docsrs, doc(cfg(feature = "num-complex")))]
139        impl<'py> crate::conversion::IntoPyObject<'py> for &Complex<$float> {
140            type Target = PyComplex;
141            type Output = Bound<'py, Self::Target>;
142            type Error = std::convert::Infallible;
143
144            #[inline]
145            fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
146                (*self).into_pyobject(py)
147            }
148        }
149
150        #[cfg_attr(docsrs, doc(cfg(feature = "num-complex")))]
151        impl FromPyObject<'_> for Complex<$float> {
152            fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Complex<$float>> {
153                #[cfg(not(any(Py_LIMITED_API, PyPy)))]
154                unsafe {
155                    let val = ffi::PyComplex_AsCComplex(obj.as_ptr());
156                    if val.real == -1.0 {
157                        if let Some(err) = PyErr::take(obj.py()) {
158                            return Err(err);
159                        }
160                    }
161                    Ok(Complex::new(val.real as $float, val.imag as $float))
162                }
163
164                #[cfg(any(Py_LIMITED_API, PyPy))]
165                unsafe {
166                    let complex;
167                    let obj = if obj.is_instance_of::<PyComplex>() {
168                        obj
169                    } else if let Some(method) =
170                        obj.lookup_special(crate::intern!(obj.py(), "__complex__"))?
171                    {
172                        complex = method.call0()?;
173                        &complex
174                    } else {
175                        // `obj` might still implement `__float__` or `__index__`, which will be
176                        // handled by `PyComplex_{Real,Imag}AsDouble`, including propagating any
177                        // errors if those methods don't exist / raise exceptions.
178                        obj
179                    };
180                    let ptr = obj.as_ptr();
181                    let real = ffi::PyComplex_RealAsDouble(ptr);
182                    if real == -1.0 {
183                        if let Some(err) = PyErr::take(obj.py()) {
184                            return Err(err);
185                        }
186                    }
187                    let imag = ffi::PyComplex_ImagAsDouble(ptr);
188                    Ok(Complex::new(real as $float, imag as $float))
189                }
190            }
191        }
192    };
193}
194complex_conversion!(f32);
195complex_conversion!(f64);
196
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::common::generate_unique_module_name;
    use crate::types::{complex::PyComplexMethods, PyModule};
    use crate::IntoPyObject;
    use pyo3_ffi::c_str;

    /// Compiles `code` into a fresh, uniquely named Python module.
    fn load_module<'py>(py: Python<'py>, code: &std::ffi::CStr) -> Bound<'py, PyModule> {
        PyModule::from_code(
            py,
            code,
            c_str!("test.py"),
            &generate_unique_module_name("test"),
        )
        .unwrap()
    }

    /// Instantiates class `name` from `module` with no arguments and extracts
    /// the instance as a `Complex<f64>`, panicking on any failure.
    fn extract_instance(module: &Bound<'_, PyModule>, name: &str) -> Complex<f64> {
        module
            .getattr(name)
            .unwrap()
            .call0()
            .unwrap()
            .extract()
            .unwrap()
    }

    #[test]
    fn from_complex() {
        Python::with_gil(|py| {
            let py_complex = PyComplex::from_complex_bound(py, Complex::new(3.0, 1.2));
            assert_eq!(py_complex.real(), 3.0);
            assert_eq!(py_complex.imag(), 1.2);
        });
    }

    #[test]
    fn to_from_complex() {
        Python::with_gil(|py| {
            let original = Complex::new(3.0f64, 1.2);
            let round_tripped: Complex<f64> =
                original.into_pyobject(py).unwrap().extract().unwrap();
            assert_eq!(round_tripped, original);
        });
    }

    #[test]
    fn from_complex_err() {
        Python::with_gil(|py| {
            let list = vec![1i32].into_pyobject(py).unwrap();
            assert!(list.extract::<Complex<f64>>().is_err());
        });
    }

    #[test]
    fn from_python_magic() {
        Python::with_gil(|py| {
            let module = load_module(
                py,
                c_str!(
                    r#"
class A:
    def __complex__(self): return 3.0+1.2j
class B:
    def __float__(self): return 3.0
class C:
    def __index__(self): return 3
                "#
                ),
            );
            assert_eq!(extract_instance(&module, "A"), Complex::new(3.0, 1.2));
            assert_eq!(extract_instance(&module, "B"), Complex::new(3.0, 0.0));
            // Before Python 3.8, `__index__` wasn't tried by `float`/`complex`.
            #[cfg(Py_3_8)]
            {
                assert_eq!(extract_instance(&module, "C"), Complex::new(3.0, 0.0));
            }
        })
    }

    #[test]
    fn from_python_inherited_magic() {
        Python::with_gil(|py| {
            let module = load_module(
                py,
                c_str!(
                    r#"
class First: pass
class ComplexMixin:
    def __complex__(self): return 3.0+1.2j
class FloatMixin:
    def __float__(self): return 3.0
class IndexMixin:
    def __index__(self): return 3
class A(First, ComplexMixin): pass
class B(First, FloatMixin): pass
class C(First, IndexMixin): pass
                "#
                ),
            );
            assert_eq!(extract_instance(&module, "A"), Complex::new(3.0, 1.2));
            assert_eq!(extract_instance(&module, "B"), Complex::new(3.0, 0.0));
            #[cfg(Py_3_8)]
            {
                assert_eq!(extract_instance(&module, "C"), Complex::new(3.0, 0.0));
            }
        })
    }

    #[test]
    fn from_python_noncallable_descriptor_magic() {
        // Functions and lambdas implement the descriptor protocol in a way that makes
        // `type(inst).attr(inst)` equivalent to `inst.attr()` for methods, but this isn't the only
        // way the descriptor protocol might be implemented.
        Python::with_gil(|py| {
            let module = load_module(
                py,
                c_str!(
                    r#"
class A:
    @property
    def __complex__(self):
        return lambda: 3.0+1.2j
                "#
                ),
            );
            assert_eq!(extract_instance(&module, "A"), Complex::new(3.0, 1.2));
        })
    }

    #[test]
    fn from_python_nondescriptor_magic() {
        // Magic methods don't need to implement the descriptor protocol, if they're callable.
        Python::with_gil(|py| {
            let module = load_module(
                py,
                c_str!(
                    r#"
class MyComplex:
    def __call__(self): return 3.0+1.2j
class A:
    __complex__ = MyComplex()
                "#
                ),
            );
            assert_eq!(extract_instance(&module, "A"), Complex::new(3.0, 1.2));
        })
    }
}
⚠️ Internal Docs ⚠️ Not Public API 👉 Official Docs Here