pyo3/conversions/
bigdecimal.rs1#![cfg(feature = "bigdecimal")]
2#![doc = concat!("pyo3 = { version = \"", env!("CARGO_PKG_VERSION"), "\", features = [\"bigdecimal\"] }")]
13use std::str::FromStr;
53
54#[cfg(feature = "experimental-inspect")]
55use crate::inspect::PyStaticExpr;
56#[cfg(feature = "experimental-inspect")]
57use crate::type_hint_identifier;
58use crate::types::PyTuple;
59use crate::{
60 exceptions::PyValueError,
61 sync::PyOnceLock,
62 types::{PyAnyMethods, PyStringMethods, PyType},
63 Borrowed, Bound, FromPyObject, IntoPyObject, Py, PyAny, PyErr, PyResult, Python,
64};
65use bigdecimal::BigDecimal;
66use num_bigint::Sign;
67
68fn get_decimal_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
69 static DECIMAL_CLS: PyOnceLock<Py<PyType>> = PyOnceLock::new();
70 DECIMAL_CLS.import(py, "decimal", "Decimal")
71}
72
73fn get_invalid_operation_error_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
74 static INVALID_OPERATION_CLS: PyOnceLock<Py<PyType>> = PyOnceLock::new();
75 INVALID_OPERATION_CLS.import(py, "decimal", "InvalidOperation")
76}
77
78impl FromPyObject<'_, '_> for BigDecimal {
79 type Error = PyErr;
80
81 #[cfg(feature = "experimental-inspect")]
82 const INPUT_TYPE: PyStaticExpr = type_hint_identifier!("decimal", "Decimal");
83
84 fn extract(obj: Borrowed<'_, '_, PyAny>) -> PyResult<Self> {
85 let py_str = &obj.str()?;
86 let rs_str = &py_str.to_cow()?;
87 BigDecimal::from_str(rs_str).map_err(|e| PyValueError::new_err(e.to_string()))
88 }
89}
90
91impl<'py> IntoPyObject<'py> for BigDecimal {
92 type Target = PyAny;
93
94 type Output = Bound<'py, Self::Target>;
95
96 type Error = PyErr;
97
98 #[cfg(feature = "experimental-inspect")]
99 const OUTPUT_TYPE: PyStaticExpr = type_hint_identifier!("decimal", "Decimal");
100
101 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
102 let cls = get_decimal_cls(py)?;
103 let (bigint, scale) = self.into_bigint_and_scale();
104 if scale == 0 {
105 return cls.call1((bigint,));
106 }
107 let exponent = scale.checked_neg().ok_or_else(|| {
108 get_invalid_operation_error_cls(py)
109 .map_or_else(|err| err, |cls| PyErr::from_type(cls.clone(), ()))
110 })?;
111 let (sign, digits) = bigint.to_radix_be(10);
112 let signed = matches!(sign, Sign::Minus).into_pyobject(py)?;
113 let digits = PyTuple::new(py, digits)?;
114
115 cls.call1(((signed, digits, exponent),))
116 }
117}
118
#[cfg(test)]
mod test_bigdecimal {
    use super::*;
    use crate::types::dict::PyDictMethods;
    use crate::types::PyDict;
    use std::ffi::CString;

    use bigdecimal::{One, Zero};
    #[cfg(not(target_arch = "wasm32"))]
    use proptest::prelude::*;

    // Round-trips a Rust constant through Python and back.
    // - The Rust value converted to Python must compare equal to `decimal.Decimal($py)`.
    // - The Python-constructed Decimal must extract back to the original Rust value.
    macro_rules! convert_constants {
        ($name:ident, $rs:expr, $py:literal) => {
            #[test]
            fn $name() {
                Python::attach(|py| {
                    let rs_orig = $rs;
                    let rs_dec = rs_orig.clone().into_pyobject(py).unwrap();
                    let locals = PyDict::new(py);
                    locals.set_item("rs_dec", &rs_dec).unwrap();
                    py.run(
                        &CString::new(format!(
                            "import decimal\npy_dec = decimal.Decimal(\"{}\")\nassert py_dec == rs_dec",
                            $py
                        ))
                        .unwrap(),
                        None,
                        Some(&locals),
                    )
                    .unwrap();
                    let py_dec = locals.get_item("py_dec").unwrap().unwrap();
                    let py_result: BigDecimal = py_dec.extract().unwrap();
                    assert_eq!(rs_orig, py_result);
                })
            }
        };
    }

    convert_constants!(convert_zero, BigDecimal::zero(), "0");
    convert_constants!(convert_one, BigDecimal::one(), "1");
    convert_constants!(convert_neg_one, -BigDecimal::one(), "-1");
    convert_constants!(convert_two, BigDecimal::from(2), "2");
    convert_constants!(convert_ten, BigDecimal::from_str("10").unwrap(), "10");
    convert_constants!(
        convert_one_hundred_point_one,
        BigDecimal::from_str("100.1").unwrap(),
        "100.1"
    );
    convert_constants!(
        convert_one_thousand,
        BigDecimal::from_str("1000").unwrap(),
        "1000"
    );
    convert_constants!(
        convert_scientific,
        BigDecimal::from_str("1e10").unwrap(),
        "1e10"
    );
    // `-1` has scale 0 and takes the `Decimal(int)` shortcut, so this is the case
    // that exercises a negative sign in the `(sign, digits, exponent)` tuple path.
    convert_constants!(
        convert_negative_fraction,
        BigDecimal::from_str("-12.345").unwrap(),
        "-12.345"
    );
    // A fraction whose digits tuple is shorter than the exponent's magnitude.
    convert_constants!(
        convert_small_fraction,
        BigDecimal::from_str("0.001").unwrap(),
        "0.001"
    );

    #[cfg(not(target_arch = "wasm32"))]
    proptest! {
        #[test]
        fn test_roundtrip(
            number in 0..28u32
        ) {
            let num = BigDecimal::from(number);
            Python::attach(|py| {
                let rs_dec = num.clone().into_pyobject(py).unwrap();
                let locals = PyDict::new(py);
                locals.set_item("rs_dec", &rs_dec).unwrap();
                py.run(
                    &CString::new(format!(
                        "import decimal\npy_dec = decimal.Decimal(\"{num}\")\nassert py_dec == rs_dec")).unwrap(),
                    None, Some(&locals)).unwrap();
                let roundtripped: BigDecimal = rs_dec.extract().unwrap();
                assert_eq!(num, roundtripped);
            })
        }

        #[test]
        fn test_integers(num in any::<i64>()) {
            Python::attach(|py| {
                let py_num = num.into_pyobject(py).unwrap();
                let roundtripped: BigDecimal = py_num.extract().unwrap();
                let rs_dec = BigDecimal::from(num);
                assert_eq!(rs_dec, roundtripped);
            })
        }
    }

    #[test]
    fn test_nan() {
        Python::attach(|py| {
            let locals = PyDict::new(py);
            py.run(
                c"import decimal\npy_dec = decimal.Decimal(\"NaN\")",
                None,
                Some(&locals),
            )
            .unwrap();
            let py_dec = locals.get_item("py_dec").unwrap().unwrap();
            let roundtripped: Result<BigDecimal, PyErr> = py_dec.extract();
            assert!(roundtripped.is_err());
        })
    }

    #[test]
    fn test_infinity() {
        Python::attach(|py| {
            let locals = PyDict::new(py);
            py.run(
                c"import decimal\npy_dec = decimal.Decimal(\"Infinity\")",
                None,
                Some(&locals),
            )
            .unwrap();
            let py_dec = locals.get_item("py_dec").unwrap().unwrap();
            let roundtripped: Result<BigDecimal, PyErr> = py_dec.extract();
            assert!(roundtripped.is_err());
        })
    }

    // A scale of i64::MIN cannot be negated into a Python exponent. The conversion
    // must fail with decimal.InvalidOperation rather than overflow.
    #[test]
    fn test_scale_overflow_raises_invalid_operation() {
        Python::attach(|py| {
            let value = BigDecimal::new(num_bigint::BigInt::from(1), i64::MIN);
            let err = value.into_pyobject(py).unwrap_err();
            let invalid_operation = get_invalid_operation_error_cls(py).unwrap();
            assert!(err.is_instance(py, invalid_operation.as_any()));
        })
    }

    #[test]
    fn test_no_precision_loss() {
        Python::attach(|py| {
            let src = "1e4";
            let expected = get_decimal_cls(py)
                .unwrap()
                .call1((src,))
                .unwrap()
                .call_method0("as_tuple")
                .unwrap();
            let actual = src
                .parse::<BigDecimal>()
                .unwrap()
                .into_pyobject(py)
                .unwrap()
                .call_method0("as_tuple")
                .unwrap();

            assert!(actual.eq(expected).unwrap());
        });
    }
}