Skip to main content

pyo3/conversions/
num_rational.rs

1#![cfg(feature = "num-rational")]
2//! Conversions to and from [num-rational](https://docs.rs/num-rational) types.
3//!
//! This is useful for converting Python's [fractions.Fraction](https://docs.python.org/3/library/fractions.html) into and from a native Rust
5//! type.
6//!
7//!
8//! To use this feature, add to your **`Cargo.toml`**:
9//!
10//! ```toml
11//! [dependencies]
12#![doc = concat!("pyo3 = { version = \"", env!("CARGO_PKG_VERSION"),  "\", features = [\"num-rational\"] }")]
13//! num-rational = "0.4.1"
14//! ```
15//!
16//! # Example
17//!
18//! Rust code to create a function that adds five to a fraction:
19//!
20//! ```rust,no_run
21//! use num_rational::Ratio;
22//! use pyo3::prelude::*;
23//!
24//! #[pyfunction]
25//! fn add_five_to_fraction(fraction: Ratio<i32>) -> Ratio<i32> {
26//!     fraction + Ratio::new(5, 1)
27//! }
28//!
29//! #[pymodule]
30//! fn my_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
31//!     m.add_function(wrap_pyfunction!(add_five_to_fraction, m)?)?;
32//!     Ok(())
33//! }
34//! ```
35//!
36//! Python code that validates the functionality:
37//! ```python
38//! from my_module import add_five_to_fraction
39//! from fractions import Fraction
40//!
41//! fraction = Fraction(2,1)
//! fraction_plus_five = add_five_to_fraction(fraction)
43//! assert fraction + 5 == fraction_plus_five
44//! ```
45
46use crate::conversion::IntoPyObject;
47use crate::ffi;
48#[cfg(feature = "experimental-inspect")]
49use crate::inspect::PyStaticExpr;
50use crate::sync::PyOnceLock;
51#[cfg(feature = "experimental-inspect")]
52use crate::type_hint_identifier;
53use crate::types::any::PyAnyMethods;
54use crate::types::PyType;
55use crate::{Borrowed, Bound, FromPyObject, Py, PyAny, PyErr, PyResult, Python};
56#[cfg(feature = "num-bigint")]
57use num_bigint::BigInt;
58use num_rational::Ratio;
59
60static FRACTION_CLS: PyOnceLock<Py<PyType>> = PyOnceLock::new();
61
62fn get_fraction_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
63    FRACTION_CLS.import(py, "fractions", "Fraction")
64}
65
// Generates the conversions between Python's `fractions.Fraction` and
// `Ratio<$int>` for a single integer type `$int`:
//
// * `FromPyObject for Ratio<$int>` reads the `numerator` and `denominator`
//   attributes of the Python object, coerces each to a Python `int` and
//   extracts them as `$int`.
// * `IntoPyObject for Ratio<$int>` / `&Ratio<$int>` calls
//   `fractions.Fraction(numerator, denominator)`.
macro_rules! rational_conversion {
    ($int: ty) => {
        impl<'py> FromPyObject<'_, 'py> for Ratio<$int> {
            type Error = PyErr;

            #[cfg(feature = "experimental-inspect")]
            const INPUT_TYPE: PyStaticExpr = type_hint_identifier!("fractions", "Fraction");

            fn extract(obj: Borrowed<'_, 'py, PyAny>) -> Result<Self, Self::Error> {
                let py = obj.py();
                let py_numerator_obj = obj.getattr(crate::intern!(py, "numerator"))?;
                let py_denominator_obj = obj.getattr(crate::intern!(py, "denominator"))?;
                // SAFETY: `as_ptr()` comes from a live `Bound`, so it is a valid object
                // pointer. `PyNumber_Long` returns a new reference, or null with an
                // exception set, which is the contract `from_owned_ptr_or_err` expects.
                let numerator_owned = unsafe {
                    Bound::from_owned_ptr_or_err(py, ffi::PyNumber_Long(py_numerator_obj.as_ptr()))?
                };
                // SAFETY: same reasoning as for the numerator above.
                let denominator_owned = unsafe {
                    Bound::from_owned_ptr_or_err(
                        py,
                        ffi::PyNumber_Long(py_denominator_obj.as_ptr()),
                    )?
                };
                let rs_numerator: $int = numerator_owned.extract()?;
                let rs_denominator: $int = denominator_owned.extract()?;
                // `Ratio::new` panics on a zero denominator. The object is not required
                // to be a genuine `Fraction` (any object with the two attributes is
                // accepted), so report the problem as a Python exception instead.
                // `Default` yields zero for every integer type this macro is used with.
                if rs_denominator == <$int>::default() {
                    return Err(crate::exceptions::PyZeroDivisionError::new_err(
                        "denominator of a fraction must not be zero",
                    ));
                }
                Ok(Ratio::new(rs_numerator, rs_denominator))
            }
        }

        impl<'py> IntoPyObject<'py> for Ratio<$int> {
            type Target = PyAny;
            type Output = Bound<'py, Self::Target>;
            type Error = PyErr;

            #[cfg(feature = "experimental-inspect")]
            const OUTPUT_TYPE: PyStaticExpr = <&Ratio<$int>>::OUTPUT_TYPE;

            // Delegates to the by-reference implementation below.
            #[inline]
            fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
                (&self).into_pyobject(py)
            }
        }

        impl<'py> IntoPyObject<'py> for &Ratio<$int> {
            type Target = PyAny;
            type Output = Bound<'py, Self::Target>;
            type Error = PyErr;

            #[cfg(feature = "experimental-inspect")]
            const OUTPUT_TYPE: PyStaticExpr = type_hint_identifier!("fractions", "Fraction");

            // Builds `fractions.Fraction(numerator, denominator)`. The parts are cloned
            // because only a reference to the ratio is available and `$int` may be a
            // non-`Copy` type such as `BigInt`.
            fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
                get_fraction_cls(py)?.call1((self.numer().clone(), self.denom().clone()))
            }
        }
    };
}
121rational_conversion!(i8);
122rational_conversion!(i16);
123rational_conversion!(i32);
124rational_conversion!(isize);
125rational_conversion!(i64);
126#[cfg(feature = "num-bigint")]
127rational_conversion!(BigInt);
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::dict::PyDictMethods;
    use crate::types::PyDict;

    #[cfg(not(target_arch = "wasm32"))]
    use proptest::prelude::*;

    /// A negative `Fraction` built from a float extracts with the sign on the numerator.
    #[test]
    fn test_negative_fraction() {
        Python::attach(|py| {
            let locals = PyDict::new(py);
            py.run(
                c"import fractions\npy_frac = fractions.Fraction(-0.125)",
                None,
                Some(&locals),
            )
            .unwrap();
            let py_frac = locals.get_item("py_frac").unwrap().unwrap();
            let roundtripped: Ratio<i32> = py_frac.extract().unwrap();
            let rs_frac = Ratio::new(-1, 8);
            assert_eq!(roundtripped, rs_frac);
        })
    }

    /// An object without `numerator` / `denominator` attributes fails to extract.
    #[test]
    fn test_obj_with_incorrect_atts() {
        Python::attach(|py| {
            let locals = PyDict::new(py);
            py.run(
                c"not_fraction = \"contains_incorrect_atts\"",
                None,
                Some(&locals),
            )
            .unwrap();
            let py_frac = locals.get_item("not_fraction").unwrap().unwrap();
            assert!(py_frac.extract::<Ratio<i32>>().is_err());
        })
    }

    /// A `Fraction` constructed from another `Fraction` extracts correctly.
    #[test]
    fn test_fraction_with_fraction_type() {
        Python::attach(|py| {
            let locals = PyDict::new(py);
            py.run(
                c"import fractions\npy_frac = fractions.Fraction(fractions.Fraction(10))",
                None,
                Some(&locals),
            )
            .unwrap();
            let py_frac = locals.get_item("py_frac").unwrap().unwrap();
            let roundtripped: Ratio<i32> = py_frac.extract().unwrap();
            let rs_frac = Ratio::new(10, 1);
            assert_eq!(roundtripped, rs_frac);
        })
    }

    /// A `Fraction` constructed from a `Decimal` extracts correctly.
    #[test]
    fn test_fraction_with_decimal() {
        Python::attach(|py| {
            let locals = PyDict::new(py);
            py.run(
                c"import fractions\n\nfrom decimal import Decimal\npy_frac = fractions.Fraction(Decimal(\"1.1\"))",
                None,
                Some(&locals),
            )
            .unwrap();
            let py_frac = locals.get_item("py_frac").unwrap().unwrap();
            let roundtripped: Ratio<i32> = py_frac.extract().unwrap();
            let rs_frac = Ratio::new(11, 10);
            assert_eq!(roundtripped, rs_frac);
        })
    }

    /// A `Fraction` given as numerator/denominator compares equal to the same `Ratio`.
    #[test]
    fn test_fraction_with_num_den() {
        Python::attach(|py| {
            let locals = PyDict::new(py);
            py.run(
                c"import fractions\npy_frac = fractions.Fraction(10,5)",
                None,
                Some(&locals),
            )
            .unwrap();
            let py_frac = locals.get_item("py_frac").unwrap().unwrap();
            let roundtripped: Ratio<i32> = py_frac.extract().unwrap();
            let rs_frac = Ratio::new(10, 5);
            assert_eq!(roundtripped, rs_frac);
        })
    }

    /// Fixed-input stand-in for the property test below, used on wasm32 where the
    /// proptest import is compiled out.
    #[cfg(target_arch = "wasm32")]
    #[test]
    fn test_int_roundtrip() {
        Python::attach(|py| {
            let rs_frac = Ratio::new(1i32, 2);
            let py_frac = rs_frac.into_pyobject(py).unwrap();
            let roundtripped: Ratio<i32> = py_frac.extract().unwrap();
            assert_eq!(rs_frac, roundtripped);
        })
    }

    /// Fixed-input `BigInt` round trip for wasm32.
    ///
    /// Gated on `num-bigint` as well: `BigInt` is only imported, and
    /// `Ratio<BigInt>` conversions only exist, when that feature is enabled.
    #[cfg(all(target_arch = "wasm32", feature = "num-bigint"))]
    #[test]
    fn test_big_int_roundtrip() {
        Python::attach(|py| {
            let rs_frac = Ratio::from_float(5.5).unwrap();
            let py_frac = rs_frac.clone().into_pyobject(py).unwrap();
            let roundtripped: Ratio<BigInt> = py_frac.extract().unwrap();
            assert_eq!(rs_frac, roundtripped);
        })
    }

    #[cfg(not(target_arch = "wasm32"))]
    proptest! {
        #[test]
        fn test_int_roundtrip(num in any::<i32>(), den in any::<i32>()) {
            // `Ratio::new` panics on a zero denominator, so that input is not a valid
            // test case. `i32::MIN` is skipped too: normalising the sign of the ratio
            // may negate a component, and `-i32::MIN` can overflow in debug builds.
            prop_assume!(den != 0 && den != i32::MIN && num != i32::MIN);
            Python::attach(|py| {
                let rs_frac = Ratio::new(num, den);
                let py_frac = rs_frac.into_pyobject(py).unwrap();
                let roundtripped: Ratio<i32> = py_frac.extract().unwrap();
                assert_eq!(rs_frac, roundtripped);
            })
        }

        #[test]
        #[cfg(feature = "num-bigint")]
        fn test_big_int_roundtrip(num in any::<f32>()) {
            Python::attach(|py| {
                let rs_frac = Ratio::from_float(num).unwrap();
                let py_frac = rs_frac.clone().into_pyobject(py).unwrap();
                let roundtripped: Ratio<BigInt> = py_frac.extract().unwrap();
                assert_eq!(roundtripped, rs_frac);
            })
        }

    }

    /// Python itself refuses to build `Fraction("Infinity")`, so running the
    /// snippet must fail.
    #[test]
    fn test_infinity() {
        Python::attach(|py| {
            let locals = PyDict::new(py);
            let py_bound = py.run(
                c"import fractions\npy_frac = fractions.Fraction(\"Infinity\")",
                None,
                Some(&locals),
            );
            assert!(py_bound.is_err());
        })
    }
}
⚠️ Internal Docs ⚠️ Not Public API 👉 Official Docs Here