pyo3/conversions/
rust_decimal.rs1#![cfg(feature = "rust_decimal")]
2#![doc = concat!("pyo3 = { version = \"", env!("CARGO_PKG_VERSION"), "\", features = [\"rust_decimal\"] }")]
13use crate::conversion::IntoPyObject;
53use crate::exceptions::PyValueError;
54#[cfg(feature = "experimental-inspect")]
55use crate::inspect::PyStaticExpr;
56use crate::sync::PyOnceLock;
57#[cfg(feature = "experimental-inspect")]
58use crate::type_hint_identifier;
59use crate::types::any::PyAnyMethods;
60use crate::types::string::PyStringMethods;
61use crate::types::PyType;
62use crate::{Borrowed, Bound, FromPyObject, Py, PyAny, PyErr, PyResult, Python};
63use rust_decimal::Decimal;
64use std::str::FromStr;
65
66impl FromPyObject<'_, '_> for Decimal {
67 type Error = PyErr;
68
69 #[cfg(feature = "experimental-inspect")]
70 const INPUT_TYPE: PyStaticExpr = type_hint_identifier!("decimal", "Decimal");
71
72 fn extract(obj: Borrowed<'_, '_, PyAny>) -> Result<Self, Self::Error> {
73 if let Ok(val) = obj.extract() {
75 Ok(Decimal::new(val, 0))
76 } else {
77 let py_str = &obj.str()?;
78 let rs_str = &py_str.to_cow()?;
79 Decimal::from_str(rs_str).or_else(|_| {
80 Decimal::from_scientific(rs_str).map_err(|e| PyValueError::new_err(e.to_string()))
81 })
82 }
83 }
84}
85
86static DECIMAL_CLS: PyOnceLock<Py<PyType>> = PyOnceLock::new();
87
88fn get_decimal_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
89 DECIMAL_CLS.import(py, "decimal", "Decimal")
90}
91
92impl<'py> IntoPyObject<'py> for Decimal {
93 type Target = PyAny;
94 type Output = Bound<'py, Self::Target>;
95 type Error = PyErr;
96
97 #[cfg(feature = "experimental-inspect")]
98 const OUTPUT_TYPE: PyStaticExpr = type_hint_identifier!("decimal", "Decimal");
99
100 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
101 let dec_cls = get_decimal_cls(py)?;
102 dec_cls.call1((self.to_string(),))
105 }
106}
107
108impl<'py> IntoPyObject<'py> for &Decimal {
109 type Target = PyAny;
110 type Output = Bound<'py, Self::Target>;
111 type Error = PyErr;
112
113 #[cfg(feature = "experimental-inspect")]
114 const OUTPUT_TYPE: PyStaticExpr = Decimal::OUTPUT_TYPE;
115
116 #[inline]
117 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
118 (*self).into_pyobject(py)
119 }
120}
121
#[cfg(test)]
mod test_rust_decimal {
    use super::*;
    use crate::types::dict::PyDictMethods;
    use crate::types::PyDict;
    use std::ffi::CString;

    #[cfg(not(target_arch = "wasm32"))]
    use proptest::prelude::*;

    /// Evaluates `decimal.Decimal(<arg>)` in Python and returns the resulting object.
    ///
    /// `arg` is pasted verbatim into the Python source, so string arguments must carry their
    /// own quotes (e.g. `"\"NaN\""`).
    fn py_decimal<'py>(py: Python<'py>, arg: &str) -> Bound<'py, PyAny> {
        let locals = PyDict::new(py);
        let code =
            CString::new(format!("import decimal\npy_dec = decimal.Decimal({arg})")).unwrap();
        py.run(&code, None, Some(&locals)).unwrap();
        locals.get_item("py_dec").unwrap().unwrap()
    }

    /// Generates a test checking both conversion directions for a constant.
    ///
    /// - Rust -> Python: `$rs` must compare equal to `decimal.Decimal($py)`.
    /// - Python -> Rust: extracting `decimal.Decimal($py)` must give back `$rs`.
    macro_rules! convert_constants {
        ($name:ident, $rs:expr, $py:literal) => {
            #[test]
            fn $name() {
                Python::attach(|py| {
                    let rs_orig = $rs;
                    let rs_dec = rs_orig.into_pyobject(py).unwrap();
                    let locals = PyDict::new(py);
                    locals.set_item("rs_dec", &rs_dec).unwrap();
                    // Build the same value on the Python side and compare it there.
                    py.run(
                        &CString::new(format!(
                            "import decimal\npy_dec = decimal.Decimal({})\nassert py_dec == rs_dec",
                            $py
                        ))
                        .unwrap(),
                        None,
                        Some(&locals),
                    )
                    .unwrap();
                    // Then check that the Python value converts back to the original constant.
                    let py_dec = locals.get_item("py_dec").unwrap().unwrap();
                    let py_result: Decimal = py_dec.extract().unwrap();
                    assert_eq!(rs_orig, py_result);
                })
            }
        };
    }

    convert_constants!(convert_zero, Decimal::ZERO, "0");
    convert_constants!(convert_one, Decimal::ONE, "1");
    convert_constants!(convert_neg_one, Decimal::NEGATIVE_ONE, "-1");
    convert_constants!(convert_two, Decimal::TWO, "2");
    convert_constants!(convert_ten, Decimal::TEN, "10");
    convert_constants!(convert_one_hundred, Decimal::ONE_HUNDRED, "100");
    convert_constants!(convert_one_thousand, Decimal::ONE_THOUSAND, "1000");

    #[cfg(not(target_arch = "wasm32"))]
    proptest! {
        #[test]
        fn test_roundtrip(
            lo in any::<u32>(),
            mid in any::<u32>(),
            high in any::<u32>(),
            negative in any::<bool>(),
            // rust_decimal supports scales 0 through 28 inclusive; include the maximum.
            scale in 0..=28u32
        ) {
            let num = Decimal::from_parts(lo, mid, high, negative, scale);
            Python::attach(|py| {
                let rs_dec = num.into_pyobject(py).unwrap();
                let locals = PyDict::new(py);
                locals.set_item("rs_dec", &rs_dec).unwrap();
                py.run(
                    &CString::new(format!(
                        "import decimal\npy_dec = decimal.Decimal(\"{num}\")\nassert py_dec == rs_dec"
                    ))
                    .unwrap(),
                    None,
                    Some(&locals),
                )
                .unwrap();
                let roundtripped: Decimal = rs_dec.extract().unwrap();
                assert_eq!(num, roundtripped);
            })
        }

        #[test]
        fn test_integers(num in any::<i64>()) {
            Python::attach(|py| {
                let py_num = num.into_pyobject(py).unwrap();
                let roundtripped: Decimal = py_num.extract().unwrap();
                let rs_dec = Decimal::new(num, 0);
                assert_eq!(rs_dec, roundtripped);
            })
        }
    }

    #[test]
    fn test_nan() {
        Python::attach(|py| {
            let py_dec = py_decimal(py, "\"NaN\"");
            let roundtripped: Result<Decimal, PyErr> = py_dec.extract();
            assert!(roundtripped.is_err());
        })
    }

    #[test]
    fn test_scientific_notation() {
        Python::attach(|py| {
            let py_dec = py_decimal(py, "\"1e3\"");
            let roundtripped: Decimal = py_dec.extract().unwrap();
            let rs_dec = Decimal::from_scientific("1e3").unwrap();
            assert_eq!(rs_dec, roundtripped);
        })
    }

    #[test]
    fn test_infinity() {
        Python::attach(|py| {
            let py_dec = py_decimal(py, "\"Infinity\"");
            let roundtripped: Result<Decimal, PyErr> = py_dec.extract();
            assert!(roundtripped.is_err());
        })
    }

    #[test]
    fn test_out_of_range() {
        Python::attach(|py| {
            // 10^30 (31 digits) is larger than any value a rust_decimal `Decimal` can hold.
            let py_dec = py_decimal(py, "\"1000000000000000000000000000000\"");
            let converted: Result<Decimal, PyErr> = py_dec.extract();
            assert!(converted.is_err());
        })
    }
}