pyo3/conversions/
num_bigint.rs

1#![cfg(feature = "num-bigint")]
2//!  Conversions to and from [num-bigint](https://docs.rs/num-bigint)’s [`BigInt`] and [`BigUint`] types.
3//!
4//! This is useful for converting Python integers when they may not fit in Rust's built-in integer types.
5//!
6//! # Setup
7//!
8//! To use this feature, add this to your **`Cargo.toml`**:
9//!
10//! ```toml
11//! [dependencies]
12//! num-bigint = "*"
13#![doc = concat!("pyo3 = { version = \"", env!("CARGO_PKG_VERSION"),  "\", features = [\"num-bigint\"] }")]
14//! ```
15//!
16//! Note that you must use compatible versions of num-bigint and PyO3.
17//! The required num-bigint version may vary based on the version of PyO3.
18//!
19//! ## Examples
20//!
//! Using [`BigInt`] to correctly increment an arbitrary-precision integer.
//! This is not possible with Rust's native integers if the Python integer is too large:
//! in that case the conversion fails and raises `OverflowError`.
24//! ```rust,no_run
25//! use num_bigint::BigInt;
26//! use pyo3::prelude::*;
27//!
28//! #[pyfunction]
29//! fn add_one(n: BigInt) -> BigInt {
30//!     n + 1
31//! }
32//!
33//! #[pymodule]
34//! fn my_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
35//!     m.add_function(wrap_pyfunction!(add_one, m)?)?;
36//!     Ok(())
37//! }
38//! ```
39//!
40//! Python code:
41//! ```python
42//! from my_module import add_one
43//!
44//! n = 1 << 1337
45//! value = add_one(n)
46//!
47//! assert n + 1 == value
48//! ```
49
50#[cfg(Py_LIMITED_API)]
51use crate::types::{bytes::PyBytesMethods, PyBytes};
52use crate::{
53    conversion::IntoPyObject,
54    ffi,
55    instance::Bound,
56    types::{any::PyAnyMethods, PyInt},
57    FromPyObject, Py, PyAny, PyErr, PyResult, Python,
58};
59
60use num_bigint::{BigInt, BigUint};
61
62#[cfg(not(Py_LIMITED_API))]
63use num_bigint::Sign;
64
65// for identical functionality between BigInt and BigUint
66macro_rules! bigint_conversion {
67    ($rust_ty: ty, $is_signed: literal, $to_bytes: path) => {
68        #[cfg_attr(docsrs, doc(cfg(feature = "num-bigint")))]
69        impl<'py> IntoPyObject<'py> for $rust_ty {
70            type Target = PyInt;
71            type Output = Bound<'py, Self::Target>;
72            type Error = PyErr;
73
74            #[inline]
75            fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
76                (&self).into_pyobject(py)
77            }
78        }
79
80        #[cfg_attr(docsrs, doc(cfg(feature = "num-bigint")))]
81        impl<'py> IntoPyObject<'py> for &$rust_ty {
82            type Target = PyInt;
83            type Output = Bound<'py, Self::Target>;
84            type Error = PyErr;
85
86            #[cfg(not(Py_LIMITED_API))]
87            fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
88                use crate::ffi_ptr_ext::FfiPtrExt;
89                let bytes = $to_bytes(&self);
90                unsafe {
91                    Ok(ffi::_PyLong_FromByteArray(
92                        bytes.as_ptr().cast(),
93                        bytes.len(),
94                        1,
95                        $is_signed.into(),
96                    )
97                    .assume_owned(py)
98                    .downcast_into_unchecked())
99                }
100            }
101
102            #[cfg(Py_LIMITED_API)]
103            fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
104                use $crate::py_result_ext::PyResultExt;
105                let bytes = $to_bytes(&self);
106                let bytes_obj = PyBytes::new(py, &bytes);
107                let kwargs = if $is_signed {
108                    let kwargs = crate::types::PyDict::new(py);
109                    kwargs.set_item(crate::intern!(py, "signed"), true)?;
110                    Some(kwargs)
111                } else {
112                    None
113                };
114                unsafe {
115                    py.get_type::<PyInt>()
116                        .call_method("from_bytes", (bytes_obj, "little"), kwargs.as_ref())
117                        .downcast_into_unchecked()
118                }
119            }
120        }
121    };
122}
123
124bigint_conversion!(BigUint, false, BigUint::to_bytes_le);
125bigint_conversion!(BigInt, true, BigInt::to_signed_bytes_le);
126
127#[cfg_attr(docsrs, doc(cfg(feature = "num-bigint")))]
128impl<'py> FromPyObject<'py> for BigInt {
129    fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<BigInt> {
130        let py = ob.py();
131        // fast path - checking for subclass of `int` just checks a bit in the type object
132        let num_owned: Py<PyInt>;
133        let num = if let Ok(long) = ob.downcast::<PyInt>() {
134            long
135        } else {
136            num_owned = unsafe { Py::from_owned_ptr_or_err(py, ffi::PyNumber_Index(ob.as_ptr()))? };
137            num_owned.bind(py)
138        };
139        #[cfg(not(Py_LIMITED_API))]
140        {
141            let mut buffer = int_to_u32_vec::<true>(num)?;
142            let sign = if buffer.last().copied().map_or(false, |last| last >> 31 != 0) {
143                // BigInt::new takes an unsigned array, so need to convert from two's complement
144                // flip all bits, 'subtract' 1 (by adding one to the unsigned array)
145                let mut elements = buffer.iter_mut();
146                for element in elements.by_ref() {
147                    *element = (!*element).wrapping_add(1);
148                    if *element != 0 {
149                        // if the element didn't wrap over, no need to keep adding further ...
150                        break;
151                    }
152                }
153                // ... so just two's complement the rest
154                for element in elements {
155                    *element = !*element;
156                }
157                Sign::Minus
158            } else {
159                Sign::Plus
160            };
161            Ok(BigInt::new(sign, buffer))
162        }
163        #[cfg(Py_LIMITED_API)]
164        {
165            let n_bits = int_n_bits(num)?;
166            if n_bits == 0 {
167                return Ok(BigInt::from(0isize));
168            }
169            let bytes = int_to_py_bytes(num, (n_bits + 8) / 8, true)?;
170            Ok(BigInt::from_signed_bytes_le(bytes.as_bytes()))
171        }
172    }
173}
174
175#[cfg_attr(docsrs, doc(cfg(feature = "num-bigint")))]
176impl<'py> FromPyObject<'py> for BigUint {
177    fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<BigUint> {
178        let py = ob.py();
179        // fast path - checking for subclass of `int` just checks a bit in the type object
180        let num_owned: Py<PyInt>;
181        let num = if let Ok(long) = ob.downcast::<PyInt>() {
182            long
183        } else {
184            num_owned = unsafe { Py::from_owned_ptr_or_err(py, ffi::PyNumber_Index(ob.as_ptr()))? };
185            num_owned.bind(py)
186        };
187        #[cfg(not(Py_LIMITED_API))]
188        {
189            let buffer = int_to_u32_vec::<false>(num)?;
190            Ok(BigUint::new(buffer))
191        }
192        #[cfg(Py_LIMITED_API)]
193        {
194            let n_bits = int_n_bits(num)?;
195            if n_bits == 0 {
196                return Ok(BigUint::from(0usize));
197            }
198            let bytes = int_to_py_bytes(num, (n_bits + 7) / 8, false)?;
199            Ok(BigUint::from_bytes_le(bytes.as_bytes()))
200        }
201    }
202}
203
204#[cfg(not(any(Py_LIMITED_API, Py_3_13)))]
205#[inline]
206fn int_to_u32_vec<const SIGNED: bool>(long: &Bound<'_, PyInt>) -> PyResult<Vec<u32>> {
207    let mut buffer = Vec::new();
208    let n_bits = int_n_bits(long)?;
209    if n_bits == 0 {
210        return Ok(buffer);
211    }
212    let n_digits = if SIGNED {
213        (n_bits + 32) / 32
214    } else {
215        (n_bits + 31) / 32
216    };
217    buffer.reserve_exact(n_digits);
218    unsafe {
219        crate::err::error_on_minusone(
220            long.py(),
221            ffi::_PyLong_AsByteArray(
222                long.as_ptr().cast(),
223                buffer.as_mut_ptr() as *mut u8,
224                n_digits * 4,
225                1,
226                SIGNED.into(),
227            ),
228        )?;
229        buffer.set_len(n_digits)
230    };
231    buffer
232        .iter_mut()
233        .for_each(|chunk| *chunk = u32::from_le(*chunk));
234
235    Ok(buffer)
236}
237
/// Copies the value of `long` into `u32` digits, least significant digit first, using
/// `PyLong_AsNativeBytes`.
///
/// With `SIGNED` the digits hold a two's complement representation. Otherwise they hold the
/// unsigned value, and negative inputs are rejected (`Py_ASNATIVEBYTES_REJECT_NEGATIVE`).
///
/// # Errors
/// Returns the pending Python exception if either `PyLong_AsNativeBytes` call fails.
#[cfg(all(not(Py_LIMITED_API), Py_3_13))]
#[inline]
fn int_to_u32_vec<const SIGNED: bool>(long: &Bound<'_, PyInt>) -> PyResult<Vec<u32>> {
    let py = long.py();
    let mut flags = ffi::Py_ASNATIVEBYTES_LITTLE_ENDIAN;
    if !SIGNED {
        flags |= ffi::Py_ASNATIVEBYTES_UNSIGNED_BUFFER | ffi::Py_ASNATIVEBYTES_REJECT_NEGATIVE;
    }
    // With a null buffer and zero length the call only reports how many bytes are needed;
    // a negative result means an exception has been set.
    let n_bytes =
        unsafe { ffi::PyLong_AsNativeBytes(long.as_ptr().cast(), std::ptr::null_mut(), 0, flags) };
    let n_bytes: usize = n_bytes
        .try_into()
        .map_err(|_| crate::PyErr::fetch(py))?;
    if n_bytes == 0 {
        return Ok(Vec::new());
    }
    // Round up to whole `u32` digits (same as `n_bytes.div_ceil(4)`, without needing Rust 1.73).
    let n_digits = (n_bytes + 3) / 4;
    let buffer_len = (n_digits * 4)
        .try_into()
        .expect("buffer size derived from a Py_ssize_t byte count fits in Py_ssize_t");
    let mut buffer: Vec<u32> = Vec::with_capacity(n_digits);
    // SAFETY: `buffer` has capacity for `n_digits * 4` bytes, which is the size passed in.
    let written = unsafe {
        ffi::PyLong_AsNativeBytes(
            long.as_ptr().cast(),
            buffer.as_mut_ptr().cast(),
            buffer_len,
            flags,
        )
    };
    // Never expose the buffer if the conversion failed: it would be uninitialized and an
    // exception would be left pending.
    if written < 0 {
        return Err(crate::PyErr::fetch(py));
    }
    // SAFETY: the call succeeded, and a successful call fills the whole buffer (sign-extending
    // the value beyond the bytes it strictly needs), so all `n_digits` digits are initialized.
    unsafe { buffer.set_len(n_digits) };
    // The bytes were written little endian; convert each digit to native byte order.
    for digit in &mut buffer {
        *digit = u32::from_le(*digit);
    }
    Ok(buffer)
}
275
/// Calls `long.to_bytes(n_bytes, "little", ...)` and returns the resulting `bytes` object.
///
/// `signed=True` is passed only when `is_signed` is set; otherwise Python's default
/// (`signed=False`) applies.
#[cfg(Py_LIMITED_API)]
fn int_to_py_bytes<'py>(
    long: &Bound<'py, PyInt>,
    n_bytes: usize,
    is_signed: bool,
) -> PyResult<Bound<'py, PyBytes>> {
    use crate::intern;
    let py = long.py();
    let kwargs = is_signed
        .then(|| {
            let signed_kwargs = crate::types::PyDict::new(py);
            signed_kwargs
                .set_item(intern!(py, "signed"), true)
                .map(|()| signed_kwargs)
        })
        .transpose()?;
    let method_name = intern!(py, "to_bytes");
    let byteorder = intern!(py, "little");
    let result = long.call_method(method_name, (n_bytes, byteorder), kwargs.as_ref())?;
    result.downcast_into::<PyBytes>().map_err(Into::into)
}
298
299#[inline]
300#[cfg(any(not(Py_3_13), Py_LIMITED_API))]
301fn int_n_bits(long: &Bound<'_, PyInt>) -> PyResult<usize> {
302    let py = long.py();
303    #[cfg(not(Py_LIMITED_API))]
304    {
305        // fast path
306        let n_bits = unsafe { ffi::_PyLong_NumBits(long.as_ptr()) };
307        if n_bits == (-1isize as usize) {
308            return Err(crate::PyErr::fetch(py));
309        }
310        Ok(n_bits)
311    }
312
313    #[cfg(Py_LIMITED_API)]
314    {
315        // slow path
316        long.call_method0(crate::intern!(py, "bit_length"))
317            .and_then(|any| any.extract())
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::common::generate_unique_module_name;
    use crate::types::{PyDict, PyModule};
    use indoc::indoc;
    use pyo3_ffi::c_str;

    /// Yields the Fibonacci sequence 1, 1, 2, 3, 5, ... computed with the Rust type `T`.
    fn rust_fib<T>() -> impl Iterator<Item = T>
    where
        T: From<u16>,
        for<'a> &'a T: std::ops::Add<Output = T>,
    {
        let mut f0: T = T::from(1);
        let mut f1: T = T::from(1);
        std::iter::from_fn(move || {
            let f2 = &f0 + &f1;
            Some(std::mem::replace(&mut f0, std::mem::replace(&mut f1, f2)))
        })
    }

    /// Yields the same sequence as [`rust_fib`], computed with Python `int`s.
    fn python_fib(py: Python<'_>) -> impl Iterator<Item = Bound<'_, PyAny>> + '_ {
        let mut f0 = 1i32.into_pyobject(py).unwrap().into_any();
        let mut f1 = 1i32.into_pyobject(py).unwrap().into_any();
        std::iter::from_fn(move || {
            let f2 = f0.call_method1("__add__", (&f1,)).unwrap();
            Some(std::mem::replace(&mut f0, std::mem::replace(&mut f1, f2)))
        })
    }

    #[test]
    fn convert_biguint() {
        Python::with_gil(|py| {
            // check the first 2000 numbers in the fibonacci sequence
            for (py_result, rs_result) in python_fib(py).zip(rust_fib::<BigUint>()).take(2000) {
                // Python -> Rust
                assert_eq!(py_result.extract::<BigUint>().unwrap(), rs_result);
                // Rust -> Python
                assert!(py_result.eq(rs_result).unwrap());
            }
        });
    }

    #[test]
    fn convert_bigint() {
        Python::with_gil(|py| {
            // check the first 2000 numbers in the fibonacci sequence
            for (py_result, rs_result) in python_fib(py).zip(rust_fib::<BigInt>()).take(2000) {
                // Python -> Rust
                assert_eq!(py_result.extract::<BigInt>().unwrap(), rs_result);
                // Rust -> Python
                assert!(py_result.eq(&rs_result).unwrap());

                // negate

                let rs_result = rs_result * -1;
                let py_result = py_result.call_method0("__neg__").unwrap();

                // Python -> Rust
                assert_eq!(py_result.extract::<BigInt>().unwrap(), rs_result);
                // Rust -> Python
                assert!(py_result.eq(rs_result).unwrap());
            }
        });
    }

    /// Builds a module defining `C`, a class that is not an `int` but implements `__index__`.
    fn python_index_class(py: Python<'_>) -> Bound<'_, PyModule> {
        let index_code = c_str!(indoc!(
            r#"
                class C:
                    def __init__(self, x):
                        self.x = x
                    def __index__(self):
                        return self.x
                "#
        ));
        PyModule::from_code(
            py,
            index_code,
            c_str!("index.py"),
            &generate_unique_module_name("index"),
        )
        .unwrap()
    }

    #[test]
    fn convert_index_class() {
        Python::with_gil(|py| {
            let index = python_index_class(py);
            let locals = PyDict::new(py);
            locals.set_item("index", index).unwrap();
            let ob = py.eval(c_str!("index.C(10)"), None, Some(&locals)).unwrap();
            // Both extractions must go through `__index__` and yield the indexed value.
            let signed: BigInt = ob.extract().unwrap();
            assert_eq!(signed, BigInt::from(10));
            let unsigned: BigUint = ob.extract().unwrap();
            assert_eq!(unsigned, BigUint::from(10u32));
        });
    }

    #[test]
    fn handle_zero() {
        Python::with_gil(|py| {
            let zero = 0i32.into_pyobject(py).unwrap();
            let signed: BigInt = zero.extract().unwrap();
            assert_eq!(signed, BigInt::from(0));
            let unsigned: BigUint = zero.extract().unwrap();
            assert_eq!(unsigned, BigUint::from(0u32));
        })
    }

    /// A negative Python int has no `BigUint` representation, so extraction must fail
    /// (the exact exception type differs between the C API and limited API paths).
    #[test]
    fn negative_biguint_is_rejected() {
        Python::with_gil(|py| {
            let minus_one = (-1i32).into_pyobject(py).unwrap();
            assert!(minus_one.extract::<BigUint>().is_err());
        })
    }

    /// `OverflowError` on converting Python int to BigInt, see issue #629
    #[test]
    fn check_overflow() {
        Python::with_gil(|py| {
            // Round-trips `$value` (of type `$T`) through a Python int using the GIL token `$py`.
            macro_rules! test {
                ($T:ty, $value:expr, $py:expr) => {
                    let value = $value;
                    println!("{}: {}", stringify!($T), value);
                    let python_value = value.clone().into_pyobject($py).unwrap();
                    let roundtrip_value = python_value.extract::<$T>().unwrap();
                    assert_eq!(value, roundtrip_value);
                };
            }

            for i in 0..=256usize {
                // test a lot of values to help catch other bugs too
                test!(BigInt, BigInt::from(i), py);
                test!(BigUint, BigUint::from(i), py);
                test!(BigInt, -BigInt::from(i), py);
                test!(BigInt, BigInt::from(1) << i, py);
                test!(BigUint, BigUint::from(1u32) << i, py);
                test!(BigInt, -BigInt::from(1) << i, py);
                test!(BigInt, (BigInt::from(1) << i) + 1u32, py);
                test!(BigUint, (BigUint::from(1u32) << i) + 1u32, py);
                test!(BigInt, (-BigInt::from(1) << i) + 1u32, py);
                test!(BigInt, (BigInt::from(1) << i) - 1u32, py);
                test!(BigUint, (BigUint::from(1u32) << i) - 1u32, py);
                test!(BigInt, (-BigInt::from(1) << i) - 1u32, py);
            }
        });
    }
}
⚠️ Internal Docs ⚠️ Not Public API 👉 Official Docs Here