pyo3/conversions/
rust_decimal.rs

1#![cfg(feature = "rust_decimal")]
2//! Conversions to and from [rust_decimal](https://docs.rs/rust_decimal)'s [`Decimal`] type.
3//!
//! This is useful for converting Python's `decimal.Decimal` into and from a native Rust type.
5//!
6//! # Setup
7//!
8//! To use this feature, add to your **`Cargo.toml`**:
9//!
10//! ```toml
11//! [dependencies]
12#![doc = concat!("pyo3 = { version = \"", env!("CARGO_PKG_VERSION"),  "\", features = [\"rust_decimal\"] }")]
13//! rust_decimal = "1.0"
14//! ```
15//!
16//! Note that you must use a compatible version of rust_decimal and PyO3.
17//! The required rust_decimal version may vary based on the version of PyO3.
18//!
19//! # Example
20//!
//! Rust code to create a function that adds one to a `Decimal`:
22//!
23//! ```rust,no_run
24//! use rust_decimal::Decimal;
25//! use pyo3::prelude::*;
26//!
27//! #[pyfunction]
28//! fn add_one(d: Decimal) -> Decimal {
29//!     d + Decimal::ONE
30//! }
31//!
32//! #[pymodule]
33//! fn my_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
34//!     m.add_function(wrap_pyfunction!(add_one, m)?)?;
35//!     Ok(())
36//! }
37//! ```
38//!
//! Python code that validates the functionality:
//!
42//! ```python
43//! from my_module import add_one
44//! from decimal import Decimal
45//!
46//! d = Decimal("2")
47//! value = add_one(d)
48//!
49//! assert d + 1 == value
50//! ```
51
52use crate::conversion::IntoPyObject;
53use crate::exceptions::PyValueError;
54use crate::sync::GILOnceCell;
55use crate::types::any::PyAnyMethods;
56use crate::types::string::PyStringMethods;
57use crate::types::PyType;
58use crate::{Bound, FromPyObject, Py, PyAny, PyErr, PyResult, Python};
59use rust_decimal::Decimal;
60use std::str::FromStr;
61
62impl FromPyObject<'_> for Decimal {
63    fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
64        // use the string representation to not be lossy
65        if let Ok(val) = obj.extract() {
66            Ok(Decimal::new(val, 0))
67        } else {
68            let py_str = &obj.str()?;
69            let rs_str = &py_str.to_cow()?;
70            Decimal::from_str(rs_str).or_else(|_| {
71                Decimal::from_scientific(rs_str).map_err(|e| PyValueError::new_err(e.to_string()))
72            })
73        }
74    }
75}
76
77static DECIMAL_CLS: GILOnceCell<Py<PyType>> = GILOnceCell::new();
78
79fn get_decimal_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
80    DECIMAL_CLS.import(py, "decimal", "Decimal")
81}
82
83impl<'py> IntoPyObject<'py> for Decimal {
84    type Target = PyAny;
85    type Output = Bound<'py, Self::Target>;
86    type Error = PyErr;
87
88    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
89        let dec_cls = get_decimal_cls(py)?;
90        // now call the constructor with the Rust Decimal string-ified
91        // to not be lossy
92        dec_cls.call1((self.to_string(),))
93    }
94}
95
96impl<'py> IntoPyObject<'py> for &Decimal {
97    type Target = PyAny;
98    type Output = Bound<'py, Self::Target>;
99    type Error = PyErr;
100
101    #[inline]
102    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
103        (*self).into_pyobject(py)
104    }
105}
106
#[cfg(test)]
mod test_rust_decimal {
    use super::*;
    use crate::types::dict::PyDictMethods;
    use crate::types::PyDict;
    use std::ffi::CString;

    use crate::ffi;
    #[cfg(not(target_arch = "wasm32"))]
    use proptest::prelude::*;

    /// Generates a `#[test]` named `$name` that checks both conversion directions
    /// for the Rust value `$rs`.
    ///
    /// The test converts `$rs` to Python and asserts it equals
    /// `decimal.Decimal($py)`. The `$py` literal is spliced verbatim into the
    /// Python source. It then extracts that Python value back to Rust and asserts
    /// it equals `$rs`.
    macro_rules! convert_constants {
        ($name:ident, $rs:expr, $py:literal) => {
            #[test]
            fn $name() {
                Python::with_gil(|py| {
                    let rs_orig = $rs;
                    let rs_dec = rs_orig.into_pyobject(py).unwrap();
                    let locals = PyDict::new(py);
                    locals.set_item("rs_dec", &rs_dec).unwrap();
                    // Checks if Rust Decimal -> Python Decimal conversion is correct
                    py.run(
                        &CString::new(format!(
                            "import decimal\npy_dec = decimal.Decimal({})\nassert py_dec == rs_dec",
                            $py
                        ))
                        .unwrap(),
                        None,
                        Some(&locals),
                    )
                    .unwrap();
                    // Checks if Python Decimal -> Rust Decimal conversion is correct
                    let py_dec = locals.get_item("py_dec").unwrap().unwrap();
                    let py_result: Decimal = py_dec.extract().unwrap();
                    assert_eq!(rs_orig, py_result);
                })
            }
        };
    }

    convert_constants!(convert_zero, Decimal::ZERO, "0");
    convert_constants!(convert_one, Decimal::ONE, "1");
    convert_constants!(convert_neg_one, Decimal::NEGATIVE_ONE, "-1");
    convert_constants!(convert_two, Decimal::TWO, "2");
    convert_constants!(convert_ten, Decimal::TEN, "10");
    convert_constants!(convert_one_hundred, Decimal::ONE_HUNDRED, "100");
    convert_constants!(convert_one_thousand, Decimal::ONE_THOUSAND, "1000");

    #[cfg(not(target_arch = "wasm32"))]
    proptest! {
        // Builds an arbitrary Decimal from its raw parts, converts it to Python,
        // checks Python parses the same text to an equal value, and then checks
        // the Python object extracts back to the original Decimal.
        #[test]
        fn test_roundtrip(
            lo in any::<u32>(),
            mid in any::<u32>(),
            high in any::<u32>(),
            negative in any::<bool>(),
            // rust_decimal supports scales 0 through 28 inclusive. The upper bound
            // is included so the maximum scale is exercised as well. Tiny values at
            // that scale are printed by Python in scientific form (e.g. "1E-28"),
            // which covers the `from_scientific` fallback at its limit.
            scale in 0..=28u32
        ) {
            let num = Decimal::from_parts(lo, mid, high, negative, scale);
            Python::with_gil(|py| {
                let rs_dec = num.into_pyobject(py).unwrap();
                let locals = PyDict::new(py);
                locals.set_item("rs_dec", &rs_dec).unwrap();
                py.run(
                    &CString::new(format!(
                        "import decimal\npy_dec = decimal.Decimal(\"{}\")\nassert py_dec == rs_dec",
                        num
                    ))
                    .unwrap(),
                    None,
                    Some(&locals),
                )
                .unwrap();
                let roundtripped: Decimal = rs_dec.extract().unwrap();
                assert_eq!(num, roundtripped);
            })
        }

        // Python ints in i64 range must extract to the zero-scale Decimal of the
        // same value.
        #[test]
        fn test_integers(num in any::<i64>()) {
            Python::with_gil(|py| {
                let py_num = num.into_pyobject(py).unwrap();
                let roundtripped: Decimal = py_num.extract().unwrap();
                let rs_dec = Decimal::new(num, 0);
                assert_eq!(rs_dec, roundtripped);
            })
        }
    }

    /// `decimal.Decimal("NaN")` has no `Decimal` equivalent, so extraction must fail.
    #[test]
    fn test_nan() {
        Python::with_gil(|py| {
            let locals = PyDict::new(py);
            py.run(
                ffi::c_str!("import decimal\npy_dec = decimal.Decimal(\"NaN\")"),
                None,
                Some(&locals),
            )
            .unwrap();
            let py_dec = locals.get_item("py_dec").unwrap().unwrap();
            let roundtripped: Result<Decimal, PyErr> = py_dec.extract();
            assert!(roundtripped.is_err());
        })
    }

    /// A Python decimal whose `str()` uses scientific notation must still extract
    /// to the same value.
    #[test]
    fn test_scientific_notation() {
        Python::with_gil(|py| {
            let locals = PyDict::new(py);
            py.run(
                ffi::c_str!("import decimal\npy_dec = decimal.Decimal(\"1e3\")"),
                None,
                Some(&locals),
            )
            .unwrap();
            let py_dec = locals.get_item("py_dec").unwrap().unwrap();
            let roundtripped: Decimal = py_dec.extract().unwrap();
            let rs_dec = Decimal::from_scientific("1e3").unwrap();
            assert_eq!(rs_dec, roundtripped);
        })
    }

    /// `decimal.Decimal("Infinity")` has no `Decimal` equivalent, so extraction must fail.
    #[test]
    fn test_infinity() {
        Python::with_gil(|py| {
            let locals = PyDict::new(py);
            py.run(
                ffi::c_str!("import decimal\npy_dec = decimal.Decimal(\"Infinity\")"),
                None,
                Some(&locals),
            )
            .unwrap();
            let py_dec = locals.get_item("py_dec").unwrap().unwrap();
            let roundtripped: Result<Decimal, PyErr> = py_dec.extract();
            assert!(roundtripped.is_err());
        })
    }
}
⚠️ Internal Docs ⚠️ Not Public API 👉 Official Docs Here