pyo3/conversions/
bigdecimal.rs

1#![cfg(feature = "bigdecimal")]
2//! Conversions to and from [bigdecimal](https://docs.rs/bigdecimal)'s [`BigDecimal`] type.
3//!
4//! This is useful for converting Python's decimal.Decimal into and from a native Rust type.
5//!
6//! # Setup
7//!
8//! To use this feature, add to your **`Cargo.toml`**:
9//!
10//! ```toml
11//! [dependencies]
12#![doc = concat!("pyo3 = { version = \"", env!("CARGO_PKG_VERSION"),  "\", features = [\"bigdecimal\"] }")]
13//! bigdecimal = "0.4"
14//! ```
15//!
16//! Note that you must use a compatible version of bigdecimal and PyO3.
17//! The required bigdecimal version may vary based on the version of PyO3.
18//!
19//! # Example
20//!
21//! Rust code to create a function that adds one to a BigDecimal
22//!
23//! ```rust
24//! use bigdecimal::BigDecimal;
25//! use pyo3::prelude::*;
26//!
27//! #[pyfunction]
28//! fn add_one(d: BigDecimal) -> BigDecimal {
29//!     d + 1
30//! }
31//!
32//! #[pymodule]
33//! fn my_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
34//!     m.add_function(wrap_pyfunction!(add_one, m)?)?;
35//!     Ok(())
36//! }
37//! ```
38//!
39//! Python code that validates the functionality
40//!
41//!
42//! ```python
43//! from my_module import add_one
44//! from decimal import Decimal
45//!
46//! d = Decimal("2")
47//! value = add_one(d)
48//!
49//! assert d + 1 == value
50//! ```
51
52use std::str::FromStr;
53
54use crate::{
55    exceptions::PyValueError,
56    sync::GILOnceCell,
57    types::{PyAnyMethods, PyStringMethods, PyType},
58    Bound, FromPyObject, IntoPyObject, Py, PyAny, PyErr, PyResult, Python,
59};
60use bigdecimal::BigDecimal;
61
62static DECIMAL_CLS: GILOnceCell<Py<PyType>> = GILOnceCell::new();
63
64fn get_decimal_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
65    DECIMAL_CLS.import(py, "decimal", "Decimal")
66}
67
68impl FromPyObject<'_> for BigDecimal {
69    fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
70        let py_str = &obj.str()?;
71        let rs_str = &py_str.to_cow()?;
72        BigDecimal::from_str(rs_str).map_err(|e| PyValueError::new_err(e.to_string()))
73    }
74}
75
76impl<'py> IntoPyObject<'py> for BigDecimal {
77    type Target = PyAny;
78
79    type Output = Bound<'py, Self::Target>;
80
81    type Error = PyErr;
82
83    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
84        let cls = get_decimal_cls(py)?;
85        cls.call1((self.to_string(),))
86    }
87}
88
#[cfg(test)]
mod test_bigdecimal {
    use super::*;
    use crate::ffi;
    use crate::types::dict::PyDictMethods;
    use crate::types::PyDict;
    use bigdecimal::{One, Zero};
    #[cfg(not(target_arch = "wasm32"))]
    use proptest::prelude::*;
    use std::ffi::CString;

    /// Builds a Python snippet that parses `literal` with `decimal.Decimal`,
    /// stores it as `py_dec`, and asserts it equals the `rs_dec` local.
    fn roundtrip_script(literal: impl std::fmt::Display) -> CString {
        CString::new(format!(
            "import decimal\npy_dec = decimal.Decimal(\"{}\")\nassert py_dec == rs_dec",
            literal
        ))
        .unwrap()
    }

    /// Runs `code`, which is expected to bind `py_dec`, then attempts to
    /// extract that value as a `BigDecimal`.
    fn extract_py_dec(py: Python<'_>, code: &std::ffi::CStr) -> PyResult<BigDecimal> {
        let locals = PyDict::new(py);
        py.run(code, None, Some(&locals)).unwrap();
        let py_dec = locals.get_item("py_dec").unwrap().unwrap();
        py_dec.extract()
    }

    /// Generates a test checking both conversion directions for one constant:
    /// the Rust value must equal Python's `Decimal(literal)`, and extracting
    /// that Python value must give back the Rust value.
    macro_rules! convert_constants {
        ($name:ident, $rs:expr, $py:literal) => {
            #[test]
            fn $name() {
                Python::with_gil(|py| {
                    let expected: BigDecimal = $rs;
                    let locals = PyDict::new(py);
                    let rs_dec = expected.clone().into_pyobject(py).unwrap();
                    locals.set_item("rs_dec", &rs_dec).unwrap();
                    // Rust -> Python: the script itself asserts equality.
                    py.run(&roundtrip_script($py), None, Some(&locals)).unwrap();
                    // Python -> Rust: pull the Python-built Decimal back out.
                    let extracted: BigDecimal = locals
                        .get_item("py_dec")
                        .unwrap()
                        .unwrap()
                        .extract()
                        .unwrap();
                    assert_eq!(expected, extracted);
                })
            }
        };
    }

    convert_constants!(convert_zero, BigDecimal::zero(), "0");
    convert_constants!(convert_one, BigDecimal::one(), "1");
    convert_constants!(convert_neg_one, -BigDecimal::one(), "-1");
    convert_constants!(convert_two, BigDecimal::from(2), "2");
    convert_constants!(convert_ten, "10".parse::<BigDecimal>().unwrap(), "10");
    convert_constants!(
        convert_one_hundred_point_one,
        "100.1".parse::<BigDecimal>().unwrap(),
        "100.1"
    );
    convert_constants!(
        convert_one_thousand,
        "1000".parse::<BigDecimal>().unwrap(),
        "1000"
    );
    convert_constants!(
        convert_scientific,
        "1e10".parse::<BigDecimal>().unwrap(),
        "1e10"
    );

    #[cfg(not(target_arch = "wasm32"))]
    proptest! {
        // Small non-negative integers survive Rust -> Python -> Rust unchanged.
        #[test]
        fn test_roundtrip(
            number in 0..28u32
        ) {
            let num = BigDecimal::from(number);
            Python::with_gil(|py| {
                let locals = PyDict::new(py);
                let rs_dec = num.clone().into_pyobject(py).unwrap();
                locals.set_item("rs_dec", &rs_dec).unwrap();
                py.run(&roundtrip_script(&num), None, Some(&locals)).unwrap();
                let back: BigDecimal = rs_dec.extract().unwrap();
                assert_eq!(num, back);
            })
        }

        // Python ints extract to the same value as BigDecimal::from(i64).
        #[test]
        fn test_integers(num in any::<i64>()) {
            Python::with_gil(|py| {
                let py_num = num.into_pyobject(py).unwrap();
                let extracted: BigDecimal = py_num.extract().unwrap();
                assert_eq!(BigDecimal::from(num), extracted);
            })
        }
    }

    #[test]
    fn test_nan() {
        Python::with_gil(|py| {
            let outcome = extract_py_dec(
                py,
                ffi::c_str!("import decimal\npy_dec = decimal.Decimal(\"NaN\")"),
            );
            assert!(outcome.is_err());
        })
    }

    #[test]
    fn test_infinity() {
        Python::with_gil(|py| {
            let outcome = extract_py_dec(
                py,
                ffi::c_str!("import decimal\npy_dec = decimal.Decimal(\"Infinity\")"),
            );
            assert!(outcome.is_err());
        })
    }
}
⚠️ Internal Docs ⚠️ Not Public API 👉 Official Docs Here