pyo3/conversions/
bigdecimal.rs

1#![cfg(feature = "bigdecimal")]
2//! Conversions to and from [bigdecimal](https://docs.rs/bigdecimal)'s [`BigDecimal`] type.
3//!
//! This is useful for converting Python's `decimal.Decimal` to and from a native Rust type.
5//!
6//! # Setup
7//!
8//! To use this feature, add to your **`Cargo.toml`**:
9//!
10//! ```toml
11//! [dependencies]
12#![doc = concat!("pyo3 = { version = \"", env!("CARGO_PKG_VERSION"),  "\", features = [\"bigdecimal\"] }")]
13//! bigdecimal = "0.4"
14//! ```
15//!
16//! Note that you must use a compatible version of bigdecimal and PyO3.
17//! The required bigdecimal version may vary based on the version of PyO3.
18//!
19//! # Example
20//!
//! Rust code to create a function that adds one to a [`BigDecimal`]:
22//!
23//! ```rust
24//! use bigdecimal::BigDecimal;
25//! use pyo3::prelude::*;
26//!
27//! #[pyfunction]
28//! fn add_one(d: BigDecimal) -> BigDecimal {
29//!     d + 1
30//! }
31//!
32//! #[pymodule]
33//! fn my_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
34//!     m.add_function(wrap_pyfunction!(add_one, m)?)?;
35//!     Ok(())
36//! }
37//! ```
38//!
//! Python code that validates the functionality:
//!
42//! ```python
43//! from my_module import add_one
44//! from decimal import Decimal
45//!
46//! d = Decimal("2")
47//! value = add_one(d)
48//!
49//! assert d + 1 == value
50//! ```
51
52use std::str::FromStr;
53
54use crate::types::PyTuple;
55use crate::{
56    exceptions::PyValueError,
57    sync::GILOnceCell,
58    types::{PyAnyMethods, PyStringMethods, PyType},
59    Bound, FromPyObject, IntoPyObject, Py, PyAny, PyErr, PyResult, Python,
60};
61use bigdecimal::BigDecimal;
62use num_bigint::Sign;
63
64fn get_decimal_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
65    static DECIMAL_CLS: GILOnceCell<Py<PyType>> = GILOnceCell::new();
66    DECIMAL_CLS.import(py, "decimal", "Decimal")
67}
68
69fn get_invalid_operation_error_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
70    static INVALID_OPERATION_CLS: GILOnceCell<Py<PyType>> = GILOnceCell::new();
71    INVALID_OPERATION_CLS.import(py, "decimal", "InvalidOperation")
72}
73
74impl FromPyObject<'_> for BigDecimal {
75    fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
76        let py_str = &obj.str()?;
77        let rs_str = &py_str.to_cow()?;
78        BigDecimal::from_str(rs_str).map_err(|e| PyValueError::new_err(e.to_string()))
79    }
80}
81
82impl<'py> IntoPyObject<'py> for BigDecimal {
83    type Target = PyAny;
84
85    type Output = Bound<'py, Self::Target>;
86
87    type Error = PyErr;
88
89    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
90        let cls = get_decimal_cls(py)?;
91        let (bigint, scale) = self.into_bigint_and_scale();
92        if scale == 0 {
93            return cls.call1((bigint,));
94        }
95        let exponent = scale.checked_neg().ok_or_else(|| {
96            get_invalid_operation_error_cls(py)
97                .map_or_else(|err| err, |cls| PyErr::from_type(cls.clone(), ()))
98        })?;
99        let (sign, digits) = bigint.to_radix_be(10);
100        let signed = matches!(sign, Sign::Minus).into_pyobject(py)?;
101        let digits = PyTuple::new(py, digits)?;
102
103        cls.call1(((signed, digits, exponent),))
104    }
105}
106
#[cfg(test)]
mod test_bigdecimal {
    use super::*;
    use crate::types::dict::PyDictMethods;
    use crate::types::PyDict;
    use std::ffi::CString;

    use crate::ffi;
    use bigdecimal::{One, Zero};
    use num_bigint::BigInt;
    #[cfg(not(target_arch = "wasm32"))]
    use proptest::prelude::*;

    /// Generates a test named `$name` that checks both conversion directions.
    ///
    /// - Converting `$rs` to Python must compare equal to `decimal.Decimal($py)`.
    /// - Extracting that Python `Decimal` back to Rust must equal `$rs`.
    macro_rules! convert_constants {
        ($name:ident, $rs:expr, $py:literal) => {
            #[test]
            fn $name() {
                Python::attach(|py| {
                    let rs_orig = $rs;
                    let rs_dec = rs_orig.clone().into_pyobject(py).unwrap();
                    let locals = PyDict::new(py);
                    locals.set_item("rs_dec", &rs_dec).unwrap();
                    // Checks if BigDecimal -> Python Decimal conversion is correct
                    py.run(
                        &CString::new(format!(
                            "import decimal\npy_dec = decimal.Decimal(\"{}\")\nassert py_dec == rs_dec",
                            $py
                        ))
                        .unwrap(),
                        None,
                        Some(&locals),
                    )
                    .unwrap();
                    // Checks if Python Decimal -> BigDecimal conversion is correct
                    let py_dec = locals.get_item("py_dec").unwrap().unwrap();
                    let py_result: BigDecimal = py_dec.extract().unwrap();
                    assert_eq!(rs_orig, py_result);
                })
            }
        };
    }

    convert_constants!(convert_zero, BigDecimal::zero(), "0");
    convert_constants!(convert_one, BigDecimal::one(), "1");
    convert_constants!(convert_neg_one, -BigDecimal::one(), "-1");
    convert_constants!(convert_two, BigDecimal::from(2), "2");
    convert_constants!(convert_ten, BigDecimal::from_str("10").unwrap(), "10");
    convert_constants!(
        convert_one_hundred_point_one,
        BigDecimal::from_str("100.1").unwrap(),
        "100.1"
    );
    convert_constants!(
        convert_one_thousand,
        BigDecimal::from_str("1000").unwrap(),
        "1000"
    );
    convert_constants!(
        convert_scientific,
        BigDecimal::from_str("1e10").unwrap(),
        "1e10"
    );
    // `-1.5` and `0.001` cannot be represented with scale 0, so they go through
    // the `(sign, digits, exponent)` tuple path. `-1.5` also covers the negative sign.
    convert_constants!(
        convert_neg_fraction,
        BigDecimal::from_str("-1.5").unwrap(),
        "-1.5"
    );
    convert_constants!(
        convert_small_fraction,
        BigDecimal::from_str("0.001").unwrap(),
        "0.001"
    );
    // A zero written with trailing fractional digits.
    convert_constants!(
        convert_zero_with_scale,
        BigDecimal::from_str("0.00").unwrap(),
        "0.00"
    );

    #[cfg(not(target_arch = "wasm32"))]
    proptest! {
        #[test]
        fn test_roundtrip(
            number in 0..28u32
        ) {
            let num = BigDecimal::from(number);
            Python::attach(|py| {
                let rs_dec = num.clone().into_pyobject(py).unwrap();
                let locals = PyDict::new(py);
                locals.set_item("rs_dec", &rs_dec).unwrap();
                py.run(
                    &CString::new(format!(
                        "import decimal\npy_dec = decimal.Decimal(\"{num}\")\nassert py_dec == rs_dec"
                    ))
                    .unwrap(),
                    None,
                    Some(&locals),
                )
                .unwrap();
                let roundtripped: BigDecimal = rs_dec.extract().unwrap();
                assert_eq!(num, roundtripped);
            })
        }

        #[test]
        fn test_scaled_roundtrip(mantissa in any::<i64>(), scale in -30i64..30) {
            // Covers both signs and positive and negative scales, so values go
            // through either the integer path or the tuple path of `into_pyobject`.
            let num = BigDecimal::new(BigInt::from(mantissa), scale);
            Python::attach(|py| {
                let py_dec = num.clone().into_pyobject(py).unwrap();
                let roundtripped: BigDecimal = py_dec.extract().unwrap();
                assert_eq!(num, roundtripped);
            })
        }

        #[test]
        fn test_integers(num in any::<i64>()) {
            Python::attach(|py| {
                let py_num = num.into_pyobject(py).unwrap();
                let roundtripped: BigDecimal = py_num.extract().unwrap();
                let rs_dec = BigDecimal::from(num);
                assert_eq!(rs_dec, roundtripped);
            })
        }
    }

    #[test]
    fn test_nan() {
        Python::attach(|py| {
            let locals = PyDict::new(py);
            py.run(
                ffi::c_str!("import decimal\npy_dec = decimal.Decimal(\"NaN\")"),
                None,
                Some(&locals),
            )
            .unwrap();
            let py_dec = locals.get_item("py_dec").unwrap().unwrap();
            let roundtripped: Result<BigDecimal, PyErr> = py_dec.extract();
            assert!(roundtripped.is_err());
        })
    }

    #[test]
    fn test_infinity() {
        Python::attach(|py| {
            let locals = PyDict::new(py);
            py.run(
                ffi::c_str!("import decimal\npy_dec = decimal.Decimal(\"Infinity\")"),
                None,
                Some(&locals),
            )
            .unwrap();
            let py_dec = locals.get_item("py_dec").unwrap().unwrap();
            let roundtripped: Result<BigDecimal, PyErr> = py_dec.extract();
            assert!(roundtripped.is_err());
        })
    }

    #[test]
    fn test_no_precision_loss() {
        Python::attach(|py| {
            // The digit tuple and exponent must match what Python itself produces,
            // not just the numeric value.
            for src in ["1e4", "-1.5", "123.456"] {
                let expected = get_decimal_cls(py)
                    .unwrap()
                    .call1((src,))
                    .unwrap()
                    .call_method0("as_tuple")
                    .unwrap();
                let actual = src
                    .parse::<BigDecimal>()
                    .unwrap()
                    .into_pyobject(py)
                    .unwrap()
                    .call_method0("as_tuple")
                    .unwrap();

                assert!(
                    actual.eq(expected).unwrap(),
                    "as_tuple mismatch for {src}"
                );
            }
        });
    }

    #[test]
    fn test_scale_overflow() {
        Python::attach(|py| {
            // Negating `i64::MIN` overflows, so no Python exponent can represent it.
            let dec = BigDecimal::new(BigInt::from(1), i64::MIN);
            let err = dec.into_pyobject(py).unwrap_err();
            let invalid_operation = get_invalid_operation_error_cls(py).unwrap();
            assert!(err.is_instance(py, invalid_operation.as_any()));
        })
    }
}
⚠️ Internal Docs ⚠️ Not Public API 👉 Official Docs Here