pyo3/conversions/
rust_decimal.rs1#![cfg(feature = "rust_decimal")]
2#![doc = concat!("pyo3 = { version = \"", env!("CARGO_PKG_VERSION"), "\", features = [\"rust_decimal\"] }")]
13use crate::conversion::IntoPyObject;
53use crate::exceptions::PyValueError;
54use crate::sync::GILOnceCell;
55use crate::types::any::PyAnyMethods;
56use crate::types::string::PyStringMethods;
57use crate::types::PyType;
58use crate::{Bound, FromPyObject, Py, PyAny, PyErr, PyResult, Python};
59use rust_decimal::Decimal;
60use std::str::FromStr;
61
62impl FromPyObject<'_> for Decimal {
63 fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
64 if let Ok(val) = obj.extract() {
66 Ok(Decimal::new(val, 0))
67 } else {
68 let py_str = &obj.str()?;
69 let rs_str = &py_str.to_cow()?;
70 Decimal::from_str(rs_str).or_else(|_| {
71 Decimal::from_scientific(rs_str).map_err(|e| PyValueError::new_err(e.to_string()))
72 })
73 }
74 }
75}
76
77static DECIMAL_CLS: GILOnceCell<Py<PyType>> = GILOnceCell::new();
78
79fn get_decimal_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
80 DECIMAL_CLS.import(py, "decimal", "Decimal")
81}
82
83impl<'py> IntoPyObject<'py> for Decimal {
84 type Target = PyAny;
85 type Output = Bound<'py, Self::Target>;
86 type Error = PyErr;
87
88 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
89 let dec_cls = get_decimal_cls(py)?;
90 dec_cls.call1((self.to_string(),))
93 }
94}
95
96impl<'py> IntoPyObject<'py> for &Decimal {
97 type Target = PyAny;
98 type Output = Bound<'py, Self::Target>;
99 type Error = PyErr;
100
101 #[inline]
102 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
103 (*self).into_pyobject(py)
104 }
105}
106
#[cfg(test)]
mod test_rust_decimal {
    use super::*;
    use crate::types::dict::PyDictMethods;
    use crate::types::PyDict;
    use std::ffi::CString;

    use crate::ffi;
    #[cfg(not(target_arch = "wasm32"))]
    use proptest::prelude::*;

    /// Generates a test named `$name` that:
    /// 1. converts the Rust value `$rs` to Python and checks it equals
    ///    `decimal.Decimal($py)` evaluated by the interpreter, then
    /// 2. extracts that interpreter-built value back into Rust and checks it
    ///    equals `$rs`.
    macro_rules! convert_constants {
        ($name:ident, $rs:expr, $py:literal) => {
            #[test]
            fn $name() {
                Python::with_gil(|py| {
                    let rs_orig = $rs;
                    let rs_dec = rs_orig.into_pyobject(py).unwrap();
                    let locals = PyDict::new(py);
                    locals.set_item("rs_dec", &rs_dec).unwrap();
                    // Checks the Rust -> Python direction inside the interpreter.
                    py.run(
                        &CString::new(format!(
                            "import decimal\npy_dec = decimal.Decimal({})\nassert py_dec == rs_dec",
                            $py
                        ))
                        .unwrap(),
                        None,
                        Some(&locals),
                    )
                    .unwrap();
                    // Checks the Python -> Rust direction on the interpreter-built value.
                    let py_dec = locals.get_item("py_dec").unwrap().unwrap();
                    let py_result: Decimal = py_dec.extract().unwrap();
                    assert_eq!(rs_orig, py_result);
                })
            }
        };
    }

    convert_constants!(convert_zero, Decimal::ZERO, "0");
    convert_constants!(convert_one, Decimal::ONE, "1");
    convert_constants!(convert_neg_one, Decimal::NEGATIVE_ONE, "-1");
    convert_constants!(convert_two, Decimal::TWO, "2");
    convert_constants!(convert_ten, Decimal::TEN, "10");
    convert_constants!(convert_one_hundred, Decimal::ONE_HUNDRED, "100");
    convert_constants!(convert_one_thousand, Decimal::ONE_THOUSAND, "1000");
    // Boundaries of the 96-bit mantissa: these do not fit in an i64, so the
    // extraction must go through the string path.
    convert_constants!(convert_max, Decimal::MAX, "79228162514264337593543950335");
    convert_constants!(convert_min, Decimal::MIN, "-79228162514264337593543950335");

    #[cfg(not(target_arch = "wasm32"))]
    proptest! {
        #[test]
        fn test_roundtrip(
            lo in any::<u32>(),
            mid in any::<u32>(),
            high in any::<u32>(),
            negative in any::<bool>(),
            // 28 is rust_decimal's maximum scale; include it rather than stopping at 27.
            scale in 0..=28u32
        ) {
            let num = Decimal::from_parts(lo, mid, high, negative, scale);
            Python::with_gil(|py| {
                let rs_dec = num.into_pyobject(py).unwrap();
                let locals = PyDict::new(py);
                locals.set_item("rs_dec", &rs_dec).unwrap();
                py.run(
                    &CString::new(format!(
                        "import decimal\npy_dec = decimal.Decimal(\"{}\")\nassert py_dec == rs_dec",
                        num
                    ))
                    .unwrap(),
                    None,
                    Some(&locals),
                )
                .unwrap();
                let roundtripped: Decimal = rs_dec.extract().unwrap();
                assert_eq!(num, roundtripped);
            })
        }

        #[test]
        fn test_integers(num in any::<i64>()) {
            Python::with_gil(|py| {
                let py_num = num.into_pyobject(py).unwrap();
                let roundtripped: Decimal = py_num.extract().unwrap();
                let rs_dec = Decimal::new(num, 0);
                assert_eq!(rs_dec, roundtripped);
            })
        }
    }

    #[test]
    fn test_int_beyond_i64() {
        // A Python int that overflows i64 must still convert, via its str() form.
        Python::with_gil(|py| {
            let locals = PyDict::new(py);
            py.run(ffi::c_str!("py_int = 2 ** 70"), None, Some(&locals))
                .unwrap();
            let py_int = locals.get_item("py_int").unwrap().unwrap();
            let extracted: Decimal = py_int.extract().unwrap();
            let expected: Decimal = "1180591620717411303424".parse().unwrap();
            assert_eq!(expected, extracted);
        })
    }

    #[test]
    fn test_reference_conversion() {
        // Exercises the `&Decimal` impl, which is otherwise untested.
        Python::with_gil(|py| {
            let num = Decimal::new(-12345, 2);
            let rs_dec = (&num).into_pyobject(py).unwrap();
            let locals = PyDict::new(py);
            locals.set_item("rs_dec", &rs_dec).unwrap();
            py.run(
                ffi::c_str!(
                    "import decimal\nassert isinstance(rs_dec, decimal.Decimal)\nassert rs_dec == decimal.Decimal(\"-123.45\")"
                ),
                None,
                Some(&locals),
            )
            .unwrap();
            let roundtripped: Decimal = rs_dec.extract().unwrap();
            assert_eq!(num, roundtripped);
        })
    }

    #[test]
    fn test_nan() {
        Python::with_gil(|py| {
            let locals = PyDict::new(py);
            py.run(
                ffi::c_str!("import decimal\npy_dec = decimal.Decimal(\"NaN\")"),
                None,
                Some(&locals),
            )
            .unwrap();
            let py_dec = locals.get_item("py_dec").unwrap().unwrap();
            let roundtripped: Result<Decimal, PyErr> = py_dec.extract();
            assert!(roundtripped.is_err());
        })
    }

    #[test]
    fn test_scientific_notation() {
        Python::with_gil(|py| {
            let locals = PyDict::new(py);
            py.run(
                ffi::c_str!("import decimal\npy_dec = decimal.Decimal(\"1e3\")"),
                None,
                Some(&locals),
            )
            .unwrap();
            let py_dec = locals.get_item("py_dec").unwrap().unwrap();
            let roundtripped: Decimal = py_dec.extract().unwrap();
            let rs_dec = Decimal::from_scientific("1e3").unwrap();
            assert_eq!(rs_dec, roundtripped);
        })
    }

    #[test]
    fn test_infinity() {
        Python::with_gil(|py| {
            let locals = PyDict::new(py);
            py.run(
                ffi::c_str!("import decimal\npy_dec = decimal.Decimal(\"Infinity\")"),
                None,
                Some(&locals),
            )
            .unwrap();
            let py_dec = locals.get_item("py_dec").unwrap().unwrap();
            let roundtripped: Result<Decimal, PyErr> = py_dec.extract();
            assert!(roundtripped.is_err());
        })
    }
}
239}