Skip to main content

pyo3/conversions/
rust_decimal.rs

1#![cfg(feature = "rust_decimal")]
2//! Conversions to and from [rust_decimal](https://docs.rs/rust_decimal)'s [`Decimal`] type.
3//!
//! This is useful for converting Python's `decimal.Decimal` to and from a native Rust type.
5//!
6//! # Setup
7//!
8//! To use this feature, add to your **`Cargo.toml`**:
9//!
10//! ```toml
11//! [dependencies]
12#![doc = concat!("pyo3 = { version = \"", env!("CARGO_PKG_VERSION"),  "\", features = [\"rust_decimal\"] }")]
13//! rust_decimal = "1.0"
14//! ```
15//!
16//! Note that you must use a compatible version of rust_decimal and PyO3.
17//! The required rust_decimal version may vary based on the version of PyO3.
18//!
19//! # Example
20//!
//! Rust code to create a function that adds one to a `Decimal`:
22//!
23//! ```rust,no_run
24//! use rust_decimal::Decimal;
25//! use pyo3::prelude::*;
26//!
27//! #[pyfunction]
28//! fn add_one(d: Decimal) -> Decimal {
29//!     d + Decimal::ONE
30//! }
31//!
32//! #[pymodule]
33//! fn my_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
34//!     m.add_function(wrap_pyfunction!(add_one, m)?)?;
35//!     Ok(())
36//! }
37//! ```
38//!
39//! Python code that validates the functionality
40//!
41//!
42//! ```python
43//! from my_module import add_one
44//! from decimal import Decimal
45//!
46//! d = Decimal("2")
47//! value = add_one(d)
48//!
49//! assert d + 1 == value
50//! ```
51
52use crate::conversion::IntoPyObject;
53use crate::exceptions::PyValueError;
54#[cfg(feature = "experimental-inspect")]
55use crate::inspect::PyStaticExpr;
56use crate::sync::PyOnceLock;
57#[cfg(feature = "experimental-inspect")]
58use crate::type_hint_identifier;
59use crate::types::any::PyAnyMethods;
60use crate::types::string::PyStringMethods;
61use crate::types::PyType;
62use crate::{Borrowed, Bound, FromPyObject, Py, PyAny, PyErr, PyResult, Python};
63use rust_decimal::Decimal;
64use std::str::FromStr;
65
66impl FromPyObject<'_, '_> for Decimal {
67    type Error = PyErr;
68
69    #[cfg(feature = "experimental-inspect")]
70    const INPUT_TYPE: PyStaticExpr = type_hint_identifier!("decimal", "Decimal");
71
72    fn extract(obj: Borrowed<'_, '_, PyAny>) -> Result<Self, Self::Error> {
73        // use the string representation to not be lossy
74        if let Ok(val) = obj.extract() {
75            Ok(Decimal::new(val, 0))
76        } else {
77            let py_str = &obj.str()?;
78            let rs_str = &py_str.to_cow()?;
79            Decimal::from_str(rs_str).or_else(|_| {
80                Decimal::from_scientific(rs_str).map_err(|e| PyValueError::new_err(e.to_string()))
81            })
82        }
83    }
84}
85
86static DECIMAL_CLS: PyOnceLock<Py<PyType>> = PyOnceLock::new();
87
88fn get_decimal_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
89    DECIMAL_CLS.import(py, "decimal", "Decimal")
90}
91
92impl<'py> IntoPyObject<'py> for Decimal {
93    type Target = PyAny;
94    type Output = Bound<'py, Self::Target>;
95    type Error = PyErr;
96
97    #[cfg(feature = "experimental-inspect")]
98    const OUTPUT_TYPE: PyStaticExpr = type_hint_identifier!("decimal", "Decimal");
99
100    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
101        let dec_cls = get_decimal_cls(py)?;
102        // now call the constructor with the Rust Decimal string-ified
103        // to not be lossy
104        dec_cls.call1((self.to_string(),))
105    }
106}
107
108impl<'py> IntoPyObject<'py> for &Decimal {
109    type Target = PyAny;
110    type Output = Bound<'py, Self::Target>;
111    type Error = PyErr;
112
113    #[cfg(feature = "experimental-inspect")]
114    const OUTPUT_TYPE: PyStaticExpr = Decimal::OUTPUT_TYPE;
115
116    #[inline]
117    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
118        (*self).into_pyobject(py)
119    }
120}
121
#[cfg(test)]
mod test_rust_decimal {
    use super::*;
    use crate::types::dict::PyDictMethods;
    use crate::types::PyDict;
    use std::ffi::CString;

    #[cfg(not(target_arch = "wasm32"))]
    use proptest::prelude::*;

    /// Generates a `#[test]` named `$name` that converts the Rust value `$rs`
    /// to Python, compares it in Python against `decimal.Decimal($py)`, and
    /// then extracts that Python value back into Rust and compares it with
    /// `$rs`.
    macro_rules! convert_constants {
        ($name:ident, $rs:expr, $py:literal) => {
            #[test]
            fn $name() {
                Python::attach(|py| {
                    let expected = $rs;
                    let converted = expected.into_pyobject(py).unwrap();
                    let scope = PyDict::new(py);
                    scope.set_item("rs_dec", &converted).unwrap();

                    // Rust -> Python: the converted value must equal the
                    // Decimal that Python builds from the literal.
                    let script = CString::new(format!(
                        "import decimal\npy_dec = decimal.Decimal({})\nassert py_dec == rs_dec",
                        $py
                    ))
                    .unwrap();
                    py.run(&script, None, Some(&scope)).unwrap();

                    // Python -> Rust: extracting Python's value must give the
                    // original constant back.
                    let extracted: Decimal = scope
                        .get_item("py_dec")
                        .unwrap()
                        .unwrap()
                        .extract()
                        .unwrap();
                    assert_eq!(expected, extracted);
                })
            }
        };
    }

    convert_constants!(convert_zero, Decimal::ZERO, "0");
    convert_constants!(convert_one, Decimal::ONE, "1");
    convert_constants!(convert_neg_one, Decimal::NEGATIVE_ONE, "-1");
    convert_constants!(convert_two, Decimal::TWO, "2");
    convert_constants!(convert_ten, Decimal::TEN, "10");
    convert_constants!(convert_one_hundred, Decimal::ONE_HUNDRED, "100");
    convert_constants!(convert_one_thousand, Decimal::ONE_THOUSAND, "1000");

    #[cfg(not(target_arch = "wasm32"))]
    proptest! {
        /// Arbitrary decimals built from raw parts survive Rust -> Python -> Rust.
        #[test]
        fn test_roundtrip(
            lo in any::<u32>(),
            mid in any::<u32>(),
            high in any::<u32>(),
            negative in any::<bool>(),
            scale in 0..28u32
        ) {
            let num = Decimal::from_parts(lo, mid, high, negative, scale);
            Python::attach(|py| {
                let converted = num.into_pyobject(py).unwrap();
                let scope = PyDict::new(py);
                scope.set_item("rs_dec", &converted).unwrap();

                // Python parses the Rust string form itself and must agree
                // with the converted object.
                let script = CString::new(format!(
                    "import decimal\npy_dec = decimal.Decimal(\"{num}\")\nassert py_dec == rs_dec"
                ))
                .unwrap();
                py.run(&script, None, Some(&scope)).unwrap();

                let extracted: Decimal = converted.extract().unwrap();
                assert_eq!(num, extracted);
            })
        }

        /// Python integers built from any `i64` extract as a `Decimal` with scale zero.
        #[test]
        fn test_integers(num in any::<i64>()) {
            Python::attach(|py| {
                let py_int = num.into_pyobject(py).unwrap();
                let extracted: Decimal = py_int.extract().unwrap();
                let expected = Decimal::new(num, 0);
                assert_eq!(expected, extracted);
            })
        }
    }

    /// Extracting Python's `Decimal("NaN")` must fail.
    #[test]
    fn test_nan() {
        Python::attach(|py| {
            let scope = PyDict::new(py);
            py.run(
                c"import decimal\npy_dec = decimal.Decimal(\"NaN\")",
                None,
                Some(&scope),
            )
            .unwrap();
            let py_value = scope.get_item("py_dec").unwrap().unwrap();
            let outcome = py_value.extract::<Decimal>();
            assert!(outcome.is_err());
        })
    }

    /// A Python Decimal created from `"1e3"` extracts to the same value that
    /// `Decimal::from_scientific("1e3")` produces.
    #[test]
    fn test_scientific_notation() {
        Python::attach(|py| {
            let scope = PyDict::new(py);
            py.run(
                c"import decimal\npy_dec = decimal.Decimal(\"1e3\")",
                None,
                Some(&scope),
            )
            .unwrap();
            let py_value = scope.get_item("py_dec").unwrap().unwrap();
            let extracted = py_value.extract::<Decimal>().unwrap();
            let expected = Decimal::from_scientific("1e3").unwrap();
            assert_eq!(expected, extracted);
        })
    }

    /// Extracting Python's `Decimal("Infinity")` must fail.
    #[test]
    fn test_infinity() {
        Python::attach(|py| {
            let scope = PyDict::new(py);
            py.run(
                c"import decimal\npy_dec = decimal.Decimal(\"Infinity\")",
                None,
                Some(&scope),
            )
            .unwrap();
            let py_value = scope.get_item("py_dec").unwrap().unwrap();
            let outcome = py_value.extract::<Decimal>();
            assert!(outcome.is_err());
        })
    }
}
⚠️ Internal Docs ⚠️ Not Public API 👉 Official Docs Here