pyo3/conversions/
bigdecimal.rs1#![cfg(feature = "bigdecimal")]
2#![doc = concat!("pyo3 = { version = \"", env!("CARGO_PKG_VERSION"), "\", features = [\"bigdecimal\"] }")]
13use std::str::FromStr;
53
54use crate::{
55 exceptions::PyValueError,
56 sync::GILOnceCell,
57 types::{PyAnyMethods, PyStringMethods, PyType},
58 Bound, FromPyObject, IntoPyObject, Py, PyAny, PyErr, PyResult, Python,
59};
60use bigdecimal::BigDecimal;
61
62static DECIMAL_CLS: GILOnceCell<Py<PyType>> = GILOnceCell::new();
63
64fn get_decimal_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
65 DECIMAL_CLS.import(py, "decimal", "Decimal")
66}
67
68impl FromPyObject<'_> for BigDecimal {
69 fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
70 let py_str = &obj.str()?;
71 let rs_str = &py_str.to_cow()?;
72 BigDecimal::from_str(rs_str).map_err(|e| PyValueError::new_err(e.to_string()))
73 }
74}
75
76impl<'py> IntoPyObject<'py> for BigDecimal {
77 type Target = PyAny;
78
79 type Output = Bound<'py, Self::Target>;
80
81 type Error = PyErr;
82
83 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
84 let cls = get_decimal_cls(py)?;
85 cls.call1((self.to_string(),))
86 }
87}
88
#[cfg(test)]
mod test_bigdecimal {
    use super::*;
    use crate::ffi;
    use crate::types::dict::PyDictMethods;
    use crate::types::PyDict;
    use bigdecimal::{One, Zero};
    #[cfg(not(target_arch = "wasm32"))]
    use proptest::prelude::*;
    use std::ffi::{CStr, CString};

    /// Builds a Python snippet that creates `py_dec = decimal.Decimal("<literal>")` and
    /// asserts it compares equal to the `rs_dec` local supplied by the caller.
    fn compare_snippet(literal: impl std::fmt::Display) -> CString {
        CString::new(format!(
            "import decimal\npy_dec = decimal.Decimal(\"{}\")\nassert py_dec == rs_dec",
            literal
        ))
        .unwrap()
    }

    /// Runs `code` with a fresh locals dict and returns the `py_dec` value it defines.
    fn python_decimal<'py>(py: Python<'py>, code: &CStr) -> Bound<'py, PyAny> {
        let scope = PyDict::new(py);
        py.run(code, None, Some(&scope)).unwrap();
        scope.get_item("py_dec").unwrap().unwrap()
    }

    /// Generates a test that converts `$rs` to Python, checks it equals
    /// `decimal.Decimal("$py")`, then converts that Python value back and compares with `$rs`.
    macro_rules! convert_constants {
        ($name:ident, $rs:expr, $py:literal) => {
            #[test]
            fn $name() {
                Python::with_gil(|py| {
                    let expected = $rs;
                    let as_python = expected.clone().into_pyobject(py).unwrap();
                    let scope = PyDict::new(py);
                    scope.set_item("rs_dec", &as_python).unwrap();
                    py.run(&compare_snippet($py), None, Some(&scope)).unwrap();
                    let from_python = scope.get_item("py_dec").unwrap().unwrap();
                    let extracted: BigDecimal = from_python.extract().unwrap();
                    assert_eq!(expected, extracted);
                })
            }
        };
    }

    convert_constants!(convert_zero, BigDecimal::zero(), "0");
    convert_constants!(convert_one, BigDecimal::one(), "1");
    convert_constants!(convert_neg_one, -BigDecimal::one(), "-1");
    convert_constants!(convert_two, BigDecimal::from(2), "2");
    convert_constants!(convert_ten, BigDecimal::from_str("10").unwrap(), "10");
    convert_constants!(
        convert_one_hundred_point_one,
        BigDecimal::from_str("100.1").unwrap(),
        "100.1"
    );
    convert_constants!(
        convert_one_thousand,
        BigDecimal::from_str("1000").unwrap(),
        "1000"
    );
    convert_constants!(
        convert_scientific,
        BigDecimal::from_str("1e10").unwrap(),
        "1e10"
    );

    #[cfg(not(target_arch = "wasm32"))]
    proptest! {
        #[test]
        fn test_roundtrip(number in 0..28u32) {
            let value = BigDecimal::from(number);
            Python::with_gil(|py| {
                let as_python = value.clone().into_pyobject(py).unwrap();
                let scope = PyDict::new(py);
                scope.set_item("rs_dec", &as_python).unwrap();
                py.run(&compare_snippet(&value), None, Some(&scope)).unwrap();
                let back: BigDecimal = as_python.extract().unwrap();
                assert_eq!(value, back);
            })
        }

        #[test]
        fn test_integers(num in any::<i64>()) {
            Python::with_gil(|py| {
                let as_python = num.into_pyobject(py).unwrap();
                let back: BigDecimal = as_python.extract().unwrap();
                let expected = BigDecimal::from(num);
                assert_eq!(expected, back);
            })
        }
    }

    #[test]
    fn test_nan() {
        Python::with_gil(|py| {
            let from_python = python_decimal(
                py,
                ffi::c_str!("import decimal\npy_dec = decimal.Decimal(\"NaN\")"),
            );
            let extracted: Result<BigDecimal, PyErr> = from_python.extract();
            assert!(extracted.is_err());
        })
    }

    #[test]
    fn test_infinity() {
        Python::with_gil(|py| {
            let from_python = python_decimal(
                py,
                ffi::c_str!("import decimal\npy_dec = decimal.Decimal(\"Infinity\")"),
            );
            let extracted: Result<BigDecimal, PyErr> = from_python.extract();
            assert!(extracted.is_err());
        })
    }
}