pyo3/conversions/
bigdecimal.rs1#![cfg(feature = "bigdecimal")]
2#![doc = concat!("pyo3 = { version = \"", env!("CARGO_PKG_VERSION"), "\", features = [\"bigdecimal\"] }")]
13use std::str::FromStr;
53
54use crate::types::PyTuple;
55use crate::{
56 exceptions::PyValueError,
57 sync::GILOnceCell,
58 types::{PyAnyMethods, PyStringMethods, PyType},
59 Bound, FromPyObject, IntoPyObject, Py, PyAny, PyErr, PyResult, Python,
60};
61use bigdecimal::BigDecimal;
62use num_bigint::Sign;
63
64fn get_decimal_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
65 static DECIMAL_CLS: GILOnceCell<Py<PyType>> = GILOnceCell::new();
66 DECIMAL_CLS.import(py, "decimal", "Decimal")
67}
68
69fn get_invalid_operation_error_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
70 static INVALID_OPERATION_CLS: GILOnceCell<Py<PyType>> = GILOnceCell::new();
71 INVALID_OPERATION_CLS.import(py, "decimal", "InvalidOperation")
72}
73
74impl FromPyObject<'_> for BigDecimal {
75 fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
76 let py_str = &obj.str()?;
77 let rs_str = &py_str.to_cow()?;
78 BigDecimal::from_str(rs_str).map_err(|e| PyValueError::new_err(e.to_string()))
79 }
80}
81
82impl<'py> IntoPyObject<'py> for BigDecimal {
83 type Target = PyAny;
84
85 type Output = Bound<'py, Self::Target>;
86
87 type Error = PyErr;
88
89 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
90 let cls = get_decimal_cls(py)?;
91 let (bigint, scale) = self.into_bigint_and_scale();
92 if scale == 0 {
93 return cls.call1((bigint,));
94 }
95 let exponent = scale.checked_neg().ok_or_else(|| {
96 get_invalid_operation_error_cls(py)
97 .map_or_else(|err| err, |cls| PyErr::from_type(cls.clone(), ()))
98 })?;
99 let (sign, digits) = bigint.to_radix_be(10);
100 let signed = matches!(sign, Sign::Minus).into_pyobject(py)?;
101 let digits = PyTuple::new(py, digits)?;
102
103 cls.call1(((signed, digits, exponent),))
104 }
105}
106
// Tests for the `BigDecimal` <-> `decimal.Decimal` conversions defined in the
// parent module.
#[cfg(test)]
mod test_bigdecimal {
    use super::*;
    use crate::types::dict::PyDictMethods;
    use crate::types::PyDict;
    use std::ffi::CString;

    use crate::ffi;
    use bigdecimal::{One, Zero};
    #[cfg(not(target_arch = "wasm32"))]
    use proptest::prelude::*;

    // Generates a test named `$name` which checks the constant `$rs` against
    // the Python decimal literal `$py` in both conversion directions.
    macro_rules! convert_constants {
        ($name:ident, $rs:expr, $py:literal) => {
            #[test]
            fn $name() {
                Python::attach(|py| {
                    let expected = $rs;
                    let converted = expected.clone().into_pyobject(py).unwrap();
                    let scope = PyDict::new(py);
                    scope.set_item("rs_dec", &converted).unwrap();

                    // Rust -> Python: the converted value must compare equal to
                    // a Decimal that Python built from the literal.
                    let script = CString::new(format!(
                        "import decimal\npy_dec = decimal.Decimal(\"{}\")\nassert py_dec == rs_dec",
                        $py
                    ))
                    .unwrap();
                    py.run(&script, None, Some(&scope)).unwrap();

                    // Python -> Rust: extracting Python's Decimal must give back
                    // the original Rust value.
                    let py_dec = scope.get_item("py_dec").unwrap().unwrap();
                    let extracted: BigDecimal = py_dec.extract().unwrap();
                    assert_eq!(expected, extracted);
                })
            }
        };
    }

    convert_constants!(convert_zero, BigDecimal::zero(), "0");
    convert_constants!(convert_one, BigDecimal::one(), "1");
    convert_constants!(convert_neg_one, -BigDecimal::one(), "-1");
    convert_constants!(convert_two, BigDecimal::from(2), "2");
    convert_constants!(convert_ten, "10".parse::<BigDecimal>().unwrap(), "10");
    convert_constants!(
        convert_one_hundred_point_one,
        "100.1".parse::<BigDecimal>().unwrap(),
        "100.1"
    );
    convert_constants!(
        convert_one_thousand,
        "1000".parse::<BigDecimal>().unwrap(),
        "1000"
    );
    convert_constants!(
        convert_scientific,
        "1e10".parse::<BigDecimal>().unwrap(),
        "1e10"
    );

    #[cfg(not(target_arch = "wasm32"))]
    proptest! {
        // Small unsigned integers survive Rust -> Python -> Rust unchanged and
        // compare equal to the Decimal Python builds from their text form.
        #[test]
        fn test_roundtrip(number in 0..28u32) {
            let num = BigDecimal::from(number);
            Python::attach(|py| {
                let converted = num.clone().into_pyobject(py).unwrap();
                let scope = PyDict::new(py);
                scope.set_item("rs_dec", &converted).unwrap();

                let script = CString::new(format!(
                    "import decimal\npy_dec = decimal.Decimal(\"{num}\")\nassert py_dec == rs_dec"
                ))
                .unwrap();
                py.run(&script, None, Some(&scope)).unwrap();

                let extracted: BigDecimal = converted.extract().unwrap();
                assert_eq!(num, extracted);
            })
        }

        // Any `i64` converted to a Python object extracts as the matching
        // `BigDecimal`.
        #[test]
        fn test_integers(num in any::<i64>()) {
            Python::attach(|py| {
                let py_num = num.into_pyobject(py).unwrap();
                let extracted = py_num.extract::<BigDecimal>().unwrap();
                let expected = BigDecimal::from(num);
                assert_eq!(expected, extracted);
            })
        }
    }

    // Runs `code`, which must bind `py_dec`, and asserts that extracting
    // `py_dec` as a `BigDecimal` fails.
    fn assert_py_dec_not_extractable(code: &std::ffi::CStr) {
        Python::attach(|py| {
            let scope = PyDict::new(py);
            py.run(code, None, Some(&scope)).unwrap();
            let py_dec = scope.get_item("py_dec").unwrap().unwrap();
            let outcome: Result<BigDecimal, PyErr> = py_dec.extract();
            assert!(outcome.is_err());
        })
    }

    #[test]
    fn test_nan() {
        assert_py_dec_not_extractable(ffi::c_str!(
            "import decimal\npy_dec = decimal.Decimal(\"NaN\")"
        ));
    }

    #[test]
    fn test_infinity() {
        assert_py_dec_not_extractable(ffi::c_str!(
            "import decimal\npy_dec = decimal.Decimal(\"Infinity\")"
        ));
    }

    // The `as_tuple()` of a converted value must match the one Python produces
    // for the same literal.
    #[test]
    fn test_no_precision_loss() {
        Python::attach(|py| {
            let src = "1e4";

            let decimal_type = get_decimal_cls(py).unwrap();
            let built_by_python = decimal_type.call1((src,)).unwrap();
            let expected = built_by_python.call_method0("as_tuple").unwrap();

            let parsed = src.parse::<BigDecimal>().unwrap();
            let built_by_rust = parsed.into_pyobject(py).unwrap();
            let actual = built_by_rust.call_method0("as_tuple").unwrap();

            assert!(actual.eq(expected).unwrap());
        });
    }
}