Skip to main content

pyo3/conversions/
bigdecimal.rs

1#![cfg(feature = "bigdecimal")]
2//! Conversions to and from [bigdecimal](https://docs.rs/bigdecimal)'s [`BigDecimal`] type.
3//!
//! This is useful for converting Python's `decimal.Decimal` into and from a native Rust type.
5//!
6//! # Setup
7//!
8//! To use this feature, add to your **`Cargo.toml`**:
9//!
10//! ```toml
11//! [dependencies]
12#![doc = concat!("pyo3 = { version = \"", env!("CARGO_PKG_VERSION"),  "\", features = [\"bigdecimal\"] }")]
13//! bigdecimal = "0.4"
14//! ```
15//!
16//! Note that you must use a compatible version of bigdecimal and PyO3.
17//! The required bigdecimal version may vary based on the version of PyO3.
18//!
19//! # Example
20//!
//! Rust code to create a function that adds one to a `BigDecimal`:
22//!
23//! ```rust
24//! use bigdecimal::BigDecimal;
25//! use pyo3::prelude::*;
26//!
27//! #[pyfunction]
28//! fn add_one(d: BigDecimal) -> BigDecimal {
29//!     d + 1
30//! }
31//!
32//! #[pymodule]
33//! fn my_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
34//!     m.add_function(wrap_pyfunction!(add_one, m)?)?;
35//!     Ok(())
36//! }
37//! ```
38//!
//! Python code that validates the functionality:
//!
42//! ```python
43//! from my_module import add_one
44//! from decimal import Decimal
45//!
46//! d = Decimal("2")
47//! value = add_one(d)
48//!
49//! assert d + 1 == value
50//! ```
51
52use core::str::FromStr;
53
54#[cfg(feature = "experimental-inspect")]
55use crate::inspect::PyStaticExpr;
56use crate::platform::prelude::*;
57#[cfg(feature = "experimental-inspect")]
58use crate::type_hint_identifier;
59use crate::types::PyTuple;
60use crate::{
61    exceptions::PyValueError,
62    sync::PyOnceLock,
63    types::{PyAnyMethods, PyStringMethods, PyType},
64    Borrowed, Bound, FromPyObject, IntoPyObject, Py, PyAny, PyErr, PyResult, Python,
65};
66use bigdecimal::BigDecimal;
67use num_bigint::Sign;
68
69fn get_decimal_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
70    static DECIMAL_CLS: PyOnceLock<Py<PyType>> = PyOnceLock::new();
71    DECIMAL_CLS.import(py, "decimal", "Decimal")
72}
73
74fn get_invalid_operation_error_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
75    static INVALID_OPERATION_CLS: PyOnceLock<Py<PyType>> = PyOnceLock::new();
76    INVALID_OPERATION_CLS.import(py, "decimal", "InvalidOperation")
77}
78
79impl FromPyObject<'_, '_> for BigDecimal {
80    type Error = PyErr;
81
82    #[cfg(feature = "experimental-inspect")]
83    const INPUT_TYPE: PyStaticExpr = type_hint_identifier!("decimal", "Decimal");
84
85    fn extract(obj: Borrowed<'_, '_, PyAny>) -> PyResult<Self> {
86        let py_str = &obj.str()?;
87        let rs_str = &py_str.to_cow()?;
88        BigDecimal::from_str(rs_str).map_err(|e| PyValueError::new_err(e.to_string()))
89    }
90}
91
92impl<'py> IntoPyObject<'py> for BigDecimal {
93    type Target = PyAny;
94
95    type Output = Bound<'py, Self::Target>;
96
97    type Error = PyErr;
98
99    #[cfg(feature = "experimental-inspect")]
100    const OUTPUT_TYPE: PyStaticExpr = type_hint_identifier!("decimal", "Decimal");
101
102    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
103        let cls = get_decimal_cls(py)?;
104        let (bigint, scale) = self.into_bigint_and_scale();
105        if scale == 0 {
106            return cls.call1((bigint,));
107        }
108        let exponent = scale.checked_neg().ok_or_else(|| {
109            get_invalid_operation_error_cls(py)
110                .map_or_else(|err| err, |cls| PyErr::from_type(cls.clone(), ()))
111        })?;
112        let (sign, digits) = bigint.to_radix_be(10);
113        let signed = matches!(sign, Sign::Minus).into_pyobject(py)?;
114        let digits = PyTuple::new(py, digits)?;
115
116        cls.call1(((signed, digits, exponent),))
117    }
118}
119
#[cfg(test)]
mod test_bigdecimal {
    use super::*;
    use crate::types::dict::PyDictMethods;
    use crate::types::PyDict;
    use alloc::ffi::CString;

    use bigdecimal::{One, Zero};
    #[cfg(not(target_arch = "wasm32"))]
    use proptest::prelude::*;

    /// Generates a test that converts `$rs` to Python and checks it equals
    /// `decimal.Decimal($py)`. It then extracts that Python value back and checks it equals `$rs`.
    macro_rules! convert_constants {
        ($name:ident, $rs:expr, $py:literal) => {
            #[test]
            fn $name() {
                Python::attach(|py| {
                    let rs_orig = $rs;
                    let rs_dec = rs_orig.clone().into_pyobject(py).unwrap();
                    let locals = PyDict::new(py);
                    locals.set_item("rs_dec", &rs_dec).unwrap();
                    // Checks if BigDecimal -> Python Decimal conversion is correct
                    py.run(
                        &CString::new(format!(
                            "import decimal\npy_dec = decimal.Decimal(\"{}\")\nassert py_dec == rs_dec",
                            $py
                        ))
                        .unwrap(),
                        None,
                        Some(&locals),
                    )
                    .unwrap();
                    // Checks if Python Decimal -> BigDecimal conversion is correct
                    let py_dec = locals.get_item("py_dec").unwrap().unwrap();
                    let py_result: BigDecimal = py_dec.extract().unwrap();
                    assert_eq!(rs_orig, py_result);
                })
            }
        };
    }

    convert_constants!(convert_zero, BigDecimal::zero(), "0");
    convert_constants!(convert_one, BigDecimal::one(), "1");
    convert_constants!(convert_neg_one, -BigDecimal::one(), "-1");
    convert_constants!(convert_two, BigDecimal::from(2), "2");
    convert_constants!(convert_ten, BigDecimal::from_str("10").unwrap(), "10");
    convert_constants!(
        convert_one_hundred_point_one,
        BigDecimal::from_str("100.1").unwrap(),
        "100.1"
    );
    convert_constants!(
        convert_one_thousand,
        BigDecimal::from_str("1000").unwrap(),
        "1000"
    );
    convert_constants!(
        convert_scientific,
        BigDecimal::from_str("1e10").unwrap(),
        "1e10"
    );
    // Negative and purely fractional values go through the `(sign, digits, exponent)`
    // tuple path with the sign flag set and/or a negative exponent.
    convert_constants!(
        convert_neg_one_hundred_point_one,
        BigDecimal::from_str("-100.1").unwrap(),
        "-100.1"
    );
    convert_constants!(
        convert_small_fraction,
        BigDecimal::from_str("0.001").unwrap(),
        "0.001"
    );
    convert_constants!(
        convert_neg_scientific_fraction,
        BigDecimal::from_str("-1e-10").unwrap(),
        "-1e-10"
    );

    #[cfg(not(target_arch = "wasm32"))]
    proptest! {
        #[test]
        fn test_roundtrip(
            number in 0..28u32
        ) {
            let num = BigDecimal::from(number);
            Python::attach(|py| {
                let rs_dec = num.clone().into_pyobject(py).unwrap();
                let locals = PyDict::new(py);
                locals.set_item("rs_dec", &rs_dec).unwrap();
                py.run(
                    &CString::new(format!(
                       "import decimal\npy_dec = decimal.Decimal(\"{num}\")\nassert py_dec == rs_dec")).unwrap(),
                None, Some(&locals)).unwrap();
                let roundtripped: BigDecimal = rs_dec.extract().unwrap();
                assert_eq!(num, roundtripped);
            })
        }

        #[test]
        fn test_integers(num in any::<i64>()) {
            Python::attach(|py| {
                let py_num = num.into_pyobject(py).unwrap();
                let roundtripped: BigDecimal = py_num.extract().unwrap();
                let rs_dec = BigDecimal::from(num);
                assert_eq!(rs_dec, roundtripped);
            })
        }

        // Arbitrary signed mantissas with positive and negative scales, covering both
        // the integer fast path (scale 0) and the tuple path in either direction.
        #[test]
        fn test_fractional_roundtrip(mantissa in any::<i64>(), scale in -20i64..20) {
            let num = BigDecimal::new(mantissa.into(), scale);
            Python::attach(|py| {
                let py_dec = num.clone().into_pyobject(py).unwrap();
                let roundtripped: BigDecimal = py_dec.extract().unwrap();
                assert_eq!(num, roundtripped);
            })
        }
    }

    #[test]
    fn test_nan() {
        Python::attach(|py| {
            let locals = PyDict::new(py);
            py.run(
                c"import decimal\npy_dec = decimal.Decimal(\"NaN\")",
                None,
                Some(&locals),
            )
            .unwrap();
            let py_dec = locals.get_item("py_dec").unwrap().unwrap();
            let roundtripped: Result<BigDecimal, PyErr> = py_dec.extract();
            assert!(roundtripped.is_err());
        })
    }

    #[test]
    fn test_infinity() {
        Python::attach(|py| {
            let locals = PyDict::new(py);
            py.run(
                c"import decimal\npy_dec = decimal.Decimal(\"Infinity\")",
                None,
                Some(&locals),
            )
            .unwrap();
            let py_dec = locals.get_item("py_dec").unwrap().unwrap();
            let roundtripped: Result<BigDecimal, PyErr> = py_dec.extract();
            assert!(roundtripped.is_err());
        })
    }

    #[test]
    fn test_no_precision_loss() {
        Python::attach(|py| {
            let src = "1e4";
            let expected = get_decimal_cls(py)
                .unwrap()
                .call1((src,))
                .unwrap()
                .call_method0("as_tuple")
                .unwrap();
            let actual = src
                .parse::<BigDecimal>()
                .unwrap()
                .into_pyobject(py)
                .unwrap()
                .call_method0("as_tuple")
                .unwrap();

            assert!(actual.eq(expected).unwrap());
        });
    }
}
⚠️ Internal Docs ⚠️ Not Public API 👉 Official Docs Here