pyo3/conversions/
num_rational.rs1#![cfg(feature = "num-rational")]
2#![doc = concat!("pyo3 = { version = \"", env!("CARGO_PKG_VERSION"), "\", features = [\"num-rational\"] }")]
13use crate::conversion::IntoPyObject;
47use crate::ffi;
48use crate::sync::GILOnceCell;
49use crate::types::any::PyAnyMethods;
50use crate::types::PyType;
51use crate::{Bound, FromPyObject, Py, PyAny, PyErr, PyResult, Python};
52
53#[cfg(feature = "num-bigint")]
54use num_bigint::BigInt;
55use num_rational::Ratio;
56
57static FRACTION_CLS: GILOnceCell<Py<PyType>> = GILOnceCell::new();
58
59fn get_fraction_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
60 FRACTION_CLS.import(py, "fractions", "Fraction")
61}
62
/// Implements `FromPyObject` and `IntoPyObject` (owned and borrowed) for `Ratio<T>`,
/// where `T` is the integer type passed to the macro.
///
/// * Python -> Rust: accepts any object exposing `numerator` and `denominator`
///   attributes (for example `fractions.Fraction`). Each attribute is run through
///   Python's `int()` protocol (`PyNumber_Long`) and then extracted as `T`.
/// * Rust -> Python: always builds a `fractions.Fraction(numer, denom)`.
macro_rules! rational_conversion {
    ($int: ty) => {
        impl<'py> FromPyObject<'py> for Ratio<$int> {
            // Errors:
            // - `AttributeError` if `numerator` or `denominator` is missing;
            // - whatever `int()` raises for a non-integral attribute value;
            // - an extraction error if a value does not fit the target integer type;
            // - `ZeroDivisionError` if the denominator is zero.
            fn extract_bound(obj: &Bound<'py, PyAny>) -> PyResult<Self> {
                let py = obj.py();
                let py_numerator_obj = obj.getattr(crate::intern!(py, "numerator"))?;
                let py_denominator_obj = obj.getattr(crate::intern!(py, "denominator"))?;
                // SAFETY: the pointer comes from a live `Bound` while the GIL is held.
                // `PyNumber_Long` returns a new (owned) reference, or NULL with a Python
                // exception set, and `from_owned_ptr_or_err` handles both cases.
                let numerator_owned = unsafe {
                    Bound::from_owned_ptr_or_err(py, ffi::PyNumber_Long(py_numerator_obj.as_ptr()))?
                };
                // SAFETY: same reasoning as for the numerator above.
                let denominator_owned = unsafe {
                    Bound::from_owned_ptr_or_err(
                        py,
                        ffi::PyNumber_Long(py_denominator_obj.as_ptr()),
                    )?
                };
                let rs_numerator: $int = numerator_owned.extract()?;
                let rs_denominator: $int = denominator_owned.extract()?;
                // `Ratio::new` panics on a zero denominator. Objects other than a real
                // `fractions.Fraction` can report one, so raise a Python error instead,
                // mirroring `fractions.Fraction(n, 0)`. `Default` is zero for every
                // integer type this macro is instantiated with.
                if rs_denominator == <$int>::default() {
                    return Err(crate::exceptions::PyZeroDivisionError::new_err(
                        "denominator of rational number cannot be zero",
                    ));
                }
                // `Ratio::new` also reduces and normalises the sign, so unreduced or
                // negative-denominator inputs are accepted.
                Ok(Ratio::new(rs_numerator, rs_denominator))
            }
        }

        impl<'py> IntoPyObject<'py> for Ratio<$int> {
            type Target = PyAny;
            type Output = Bound<'py, Self::Target>;
            type Error = PyErr;

            #[inline]
            fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
                // Delegate to the borrowed implementation below.
                (&self).into_pyobject(py)
            }
        }

        impl<'py> IntoPyObject<'py> for &Ratio<$int> {
            type Target = PyAny;
            type Output = Bound<'py, Self::Target>;
            type Error = PyErr;

            // Calls `fractions.Fraction(numer, denom)`. Errors if the `fractions`
            // import fails or the constructor raises.
            fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
                get_fraction_cls(py)?.call1((self.numer().clone(), self.denom().clone()))
            }
        }
    };
}
107rational_conversion!(i8);
108rational_conversion!(i16);
109rational_conversion!(i32);
110rational_conversion!(isize);
111rational_conversion!(i64);
112#[cfg(feature = "num-bigint")]
113rational_conversion!(BigInt);
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::dict::PyDictMethods;
    use crate::types::PyDict;

    #[cfg(not(target_arch = "wasm32"))]
    use proptest::prelude::*;

    #[test]
    fn test_negative_fraction() {
        Python::with_gil(|py| {
            let locals = PyDict::new(py);
            py.run(
                ffi::c_str!("import fractions\npy_frac = fractions.Fraction(-0.125)"),
                None,
                Some(&locals),
            )
            .unwrap();
            let py_frac = locals.get_item("py_frac").unwrap().unwrap();
            let roundtripped: Ratio<i32> = py_frac.extract().unwrap();
            let rs_frac = Ratio::new(-1, 8);
            assert_eq!(roundtripped, rs_frac);
        })
    }

    #[test]
    fn test_obj_with_incorrect_atts() {
        Python::with_gil(|py| {
            let locals = PyDict::new(py);
            py.run(
                ffi::c_str!("not_fraction = \"contains_incorrect_atts\""),
                None,
                Some(&locals),
            )
            .unwrap();
            let py_frac = locals.get_item("not_fraction").unwrap().unwrap();
            assert!(py_frac.extract::<Ratio<i32>>().is_err());
        })
    }

    #[test]
    fn test_fraction_with_fraction_type() {
        Python::with_gil(|py| {
            let locals = PyDict::new(py);
            py.run(
                ffi::c_str!(
                    "import fractions\npy_frac = fractions.Fraction(fractions.Fraction(10))"
                ),
                None,
                Some(&locals),
            )
            .unwrap();
            let py_frac = locals.get_item("py_frac").unwrap().unwrap();
            let roundtripped: Ratio<i32> = py_frac.extract().unwrap();
            let rs_frac = Ratio::new(10, 1);
            assert_eq!(roundtripped, rs_frac);
        })
    }

    #[test]
    fn test_fraction_with_decimal() {
        Python::with_gil(|py| {
            let locals = PyDict::new(py);
            py.run(
                ffi::c_str!("import fractions\n\nfrom decimal import Decimal\npy_frac = fractions.Fraction(Decimal(\"1.1\"))"),
                None,
                Some(&locals),
            )
            .unwrap();
            let py_frac = locals.get_item("py_frac").unwrap().unwrap();
            let roundtripped: Ratio<i32> = py_frac.extract().unwrap();
            let rs_frac = Ratio::new(11, 10);
            assert_eq!(roundtripped, rs_frac);
        })
    }

    #[test]
    fn test_fraction_with_num_den() {
        Python::with_gil(|py| {
            let locals = PyDict::new(py);
            py.run(
                ffi::c_str!("import fractions\npy_frac = fractions.Fraction(10,5)"),
                None,
                Some(&locals),
            )
            .unwrap();
            let py_frac = locals.get_item("py_frac").unwrap().unwrap();
            let roundtripped: Ratio<i32> = py_frac.extract().unwrap();
            // `Ratio::new` reduces 10/5 to 2/1, matching what Python stores.
            let rs_frac = Ratio::new(10, 5);
            assert_eq!(roundtripped, rs_frac);
        })
    }

    #[cfg(target_arch = "wasm32")]
    #[test]
    fn test_int_roundtrip() {
        Python::with_gil(|py| {
            let rs_frac = Ratio::new(1i32, 2);
            let py_frac = rs_frac.into_pyobject(py).unwrap();
            let roundtripped: Ratio<i32> = py_frac.extract().unwrap();
            assert_eq!(rs_frac, roundtripped);
        })
    }

    // `BigInt` is only in scope when the `num-bigint` feature is enabled, so this
    // test must be gated on it just like the proptest variant below.
    #[cfg(all(target_arch = "wasm32", feature = "num-bigint"))]
    #[test]
    fn test_big_int_roundtrip() {
        Python::with_gil(|py| {
            let rs_frac = Ratio::from_float(5.5).unwrap();
            let py_frac = rs_frac.clone().into_pyobject(py).unwrap();
            let roundtripped: Ratio<BigInt> = py_frac.extract().unwrap();
            assert_eq!(rs_frac, roundtripped);
        })
    }

    #[cfg(not(target_arch = "wasm32"))]
    proptest! {
        // The denominator strategy excludes zero because `Ratio::new` panics on it,
        // which would abort the test rather than exercise the conversion.
        #[test]
        fn test_int_roundtrip(num in any::<i32>(), den in any::<i32>().prop_filter("denominator must be non-zero", |d| *d != 0)) {
            Python::with_gil(|py| {
                let rs_frac = Ratio::new(num, den);
                let py_frac = rs_frac.into_pyobject(py).unwrap();
                let roundtripped: Ratio<i32> = py_frac.extract().unwrap();
                assert_eq!(rs_frac, roundtripped);
            })
        }

        #[test]
        #[cfg(feature = "num-bigint")]
        fn test_big_int_roundtrip(num in any::<f32>()) {
            Python::with_gil(|py| {
                let rs_frac = Ratio::from_float(num).unwrap();
                let py_frac = rs_frac.clone().into_pyobject(py).unwrap();
                let roundtripped: Ratio<BigInt> = py_frac.extract().unwrap();
                assert_eq!(roundtripped, rs_frac);
            })
        }
    }

    #[test]
    fn test_infinity() {
        Python::with_gil(|py| {
            let locals = PyDict::new(py);
            let py_bound = py.run(
                ffi::c_str!("import fractions\npy_frac = fractions.Fraction(\"Infinity\")"),
                None,
                Some(&locals),
            );
            assert!(py_bound.is_err());
        })
    }
}