pyo3/conversions/
bigdecimal.rs1#![cfg(feature = "bigdecimal")]
2#![doc = concat!("pyo3 = { version = \"", env!("CARGO_PKG_VERSION"), "\", features = [\"bigdecimal\"] }")]
13use core::str::FromStr;
53
54#[cfg(feature = "experimental-inspect")]
55use crate::inspect::PyStaticExpr;
56use crate::platform::prelude::*;
57#[cfg(feature = "experimental-inspect")]
58use crate::type_hint_identifier;
59use crate::types::PyTuple;
60use crate::{
61 exceptions::PyValueError,
62 sync::PyOnceLock,
63 types::{PyAnyMethods, PyStringMethods, PyType},
64 Borrowed, Bound, FromPyObject, IntoPyObject, Py, PyAny, PyErr, PyResult, Python,
65};
66use bigdecimal::BigDecimal;
67use num_bigint::Sign;
68
69fn get_decimal_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
70 static DECIMAL_CLS: PyOnceLock<Py<PyType>> = PyOnceLock::new();
71 DECIMAL_CLS.import(py, "decimal", "Decimal")
72}
73
74fn get_invalid_operation_error_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
75 static INVALID_OPERATION_CLS: PyOnceLock<Py<PyType>> = PyOnceLock::new();
76 INVALID_OPERATION_CLS.import(py, "decimal", "InvalidOperation")
77}
78
79impl FromPyObject<'_, '_> for BigDecimal {
80 type Error = PyErr;
81
82 #[cfg(feature = "experimental-inspect")]
83 const INPUT_TYPE: PyStaticExpr = type_hint_identifier!("decimal", "Decimal");
84
85 fn extract(obj: Borrowed<'_, '_, PyAny>) -> PyResult<Self> {
86 let py_str = &obj.str()?;
87 let rs_str = &py_str.to_cow()?;
88 BigDecimal::from_str(rs_str).map_err(|e| PyValueError::new_err(e.to_string()))
89 }
90}
91
92impl<'py> IntoPyObject<'py> for BigDecimal {
93 type Target = PyAny;
94
95 type Output = Bound<'py, Self::Target>;
96
97 type Error = PyErr;
98
99 #[cfg(feature = "experimental-inspect")]
100 const OUTPUT_TYPE: PyStaticExpr = type_hint_identifier!("decimal", "Decimal");
101
102 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
103 let cls = get_decimal_cls(py)?;
104 let (bigint, scale) = self.into_bigint_and_scale();
105 if scale == 0 {
106 return cls.call1((bigint,));
107 }
108 let exponent = scale.checked_neg().ok_or_else(|| {
109 get_invalid_operation_error_cls(py)
110 .map_or_else(|err| err, |cls| PyErr::from_type(cls.clone(), ()))
111 })?;
112 let (sign, digits) = bigint.to_radix_be(10);
113 let signed = matches!(sign, Sign::Minus).into_pyobject(py)?;
114 let digits = PyTuple::new(py, digits)?;
115
116 cls.call1(((signed, digits, exponent),))
117 }
118}
119
#[cfg(test)]
mod test_bigdecimal {
    use super::*;
    use crate::types::dict::PyDictMethods;
    use crate::types::PyDict;
    use alloc::ffi::CString;

    use bigdecimal::{One, Zero};
    #[cfg(not(target_arch = "wasm32"))]
    use proptest::prelude::*;

    /// Generates a test named `$name` that checks both directions for one value:
    /// `$rs` converted to Python must equal `decimal.Decimal($py)`, and that
    /// Python value extracted back must equal `$rs`.
    macro_rules! convert_constants {
        ($name:ident, $rs:expr, $py:literal) => {
            #[test]
            fn $name() {
                Python::attach(|py| {
                    let original = $rs;
                    let converted = original.clone().into_pyobject(py).unwrap();
                    let scope = PyDict::new(py);
                    scope.set_item("rs_dec", &converted).unwrap();

                    // Rust -> Python: compare against a natively built Decimal.
                    let script = CString::new(format!(
                        "import decimal\npy_dec = decimal.Decimal(\"{}\")\nassert py_dec == rs_dec",
                        $py
                    ))
                    .unwrap();
                    py.run(&script, None, Some(&scope)).unwrap();

                    // Python -> Rust: extract the natively built Decimal.
                    let native = scope.get_item("py_dec").unwrap().unwrap();
                    let extracted: BigDecimal = native.extract().unwrap();
                    assert_eq!(original, extracted);
                })
            }
        };
    }

    convert_constants!(convert_zero, BigDecimal::zero(), "0");
    convert_constants!(convert_one, BigDecimal::one(), "1");
    convert_constants!(convert_neg_one, -BigDecimal::one(), "-1");
    convert_constants!(convert_two, BigDecimal::from(2), "2");
    convert_constants!(convert_ten, BigDecimal::from_str("10").unwrap(), "10");
    convert_constants!(
        convert_one_hundred_point_one,
        BigDecimal::from_str("100.1").unwrap(),
        "100.1"
    );
    convert_constants!(
        convert_one_thousand,
        BigDecimal::from_str("1000").unwrap(),
        "1000"
    );
    convert_constants!(
        convert_scientific,
        BigDecimal::from_str("1e10").unwrap(),
        "1e10"
    );

    #[cfg(not(target_arch = "wasm32"))]
    proptest! {
        // Small unsigned integers survive Rust -> Python -> Rust unchanged.
        #[test]
        fn test_roundtrip(
            number in 0..28u32
        ) {
            let num = BigDecimal::from(number);
            Python::attach(|py| {
                let converted = num.clone().into_pyobject(py).unwrap();
                let scope = PyDict::new(py);
                scope.set_item("rs_dec", &converted).unwrap();

                let script = CString::new(format!(
                    "import decimal\npy_dec = decimal.Decimal(\"{num}\")\nassert py_dec == rs_dec"
                ))
                .unwrap();
                py.run(&script, None, Some(&scope)).unwrap();

                let extracted: BigDecimal = converted.extract().unwrap();
                assert_eq!(num, extracted);
            })
        }

        // A Python int extracts to the same value as `BigDecimal::from(i64)`.
        #[test]
        fn test_integers(num in any::<i64>()) {
            Python::attach(|py| {
                let py_int = num.into_pyobject(py).unwrap();
                let extracted: BigDecimal = py_int.extract().unwrap();
                let expected = BigDecimal::from(num);
                assert_eq!(expected, extracted);
            })
        }
    }

    /// Runs `code`, which must bind a variable named `py_dec`, and asserts that
    /// extracting that variable as a `BigDecimal` returns an error.
    fn assert_extraction_fails(code: &core::ffi::CStr) {
        Python::attach(|py| {
            let scope = PyDict::new(py);
            py.run(code, None, Some(&scope)).unwrap();

            let py_dec = scope.get_item("py_dec").unwrap().unwrap();
            let outcome: Result<BigDecimal, PyErr> = py_dec.extract();
            assert!(outcome.is_err());
        })
    }

    #[test]
    fn test_nan() {
        assert_extraction_fails(c"import decimal\npy_dec = decimal.Decimal(\"NaN\")");
    }

    #[test]
    fn test_infinity() {
        assert_extraction_fails(c"import decimal\npy_dec = decimal.Decimal(\"Infinity\")");
    }

    /// The converted value must have the same `as_tuple()` representation as
    /// a `Decimal` built from the same source text on the Python side.
    #[test]
    fn test_no_precision_loss() {
        Python::attach(|py| {
            let src = "1e4";

            let native = get_decimal_cls(py).unwrap().call1((src,)).unwrap();
            let expected = native.call_method0("as_tuple").unwrap();

            let parsed = src.parse::<BigDecimal>().unwrap();
            let converted = parsed.into_pyobject(py).unwrap();
            let actual = converted.call_method0("as_tuple").unwrap();

            assert!(actual.eq(expected).unwrap());
        });
    }
}