Skip to main content

pyo3/
byteswriter.rs

1#[cfg(feature = "experimental-inspect")]
2use crate::inspect::PyStaticExpr;
3#[cfg(feature = "experimental-inspect")]
4use crate::PyTypeInfo;
5#[cfg(not(Py_LIMITED_API))]
6use crate::{
7    err::error_on_minusone,
8    ffi::{
9        self,
10        compat::{
11            PyBytesWriter_Create, PyBytesWriter_Discard, PyBytesWriter_Finish,
12            PyBytesWriter_GetData, PyBytesWriter_GetSize, PyBytesWriter_Resize,
13        },
14    },
15    ffi_ptr_ext::FfiPtrExt,
16    py_result_ext::PyResultExt,
17};
18use crate::{types::PyBytes, Bound, IntoPyObject, PyErr, PyResult, Python};
19use std::io::IoSlice;
20#[cfg(not(Py_LIMITED_API))]
21use std::{
22    mem::ManuallyDrop,
23    ptr::{self, NonNull},
24};
25
26pub struct PyBytesWriter<'py> {
27    python: Python<'py>,
28    #[cfg(not(Py_LIMITED_API))]
29    writer: NonNull<ffi::compat::PyBytesWriter>,
30    #[cfg(Py_LIMITED_API)]
31    buffer: Vec<u8>,
32}
33
34impl<'py> PyBytesWriter<'py> {
35    /// Create a new `PyBytesWriter` with a default initial capacity.
36    #[inline]
37    pub fn new(py: Python<'py>) -> PyResult<Self> {
38        Self::with_capacity(py, 0)
39    }
40
41    /// Create a new `PyBytesWriter` with the specified initial capacity.
42    #[inline]
43    #[cfg_attr(Py_LIMITED_API, allow(clippy::unnecessary_wraps))]
44    pub fn with_capacity(py: Python<'py>, capacity: usize) -> PyResult<Self> {
45        #[cfg(not(Py_LIMITED_API))]
46        {
47            NonNull::new(unsafe { PyBytesWriter_Create(capacity as _) }).map_or_else(
48                || Err(PyErr::fetch(py)),
49                |writer| {
50                    let mut writer = PyBytesWriter { python: py, writer };
51                    if capacity > 0 {
52                        // SAFETY: By setting the length to 0, we ensure no bytes are considered uninitialized.
53                        unsafe { writer.set_len(0)? };
54                    }
55                    Ok(writer)
56                },
57            )
58        }
59
60        #[cfg(Py_LIMITED_API)]
61        {
62            Ok(PyBytesWriter {
63                python: py,
64                buffer: Vec::with_capacity(capacity),
65            })
66        }
67    }
68
69    /// Get the current length of the internal buffer.
70    #[inline]
71    pub fn len(&self) -> usize {
72        #[cfg(not(Py_LIMITED_API))]
73        unsafe {
74            PyBytesWriter_GetSize(self.writer.as_ptr()) as _
75        }
76
77        #[cfg(Py_LIMITED_API)]
78        {
79            self.buffer.len()
80        }
81    }
82
83    #[inline]
84    #[cfg(not(Py_LIMITED_API))]
85    fn as_mut_ptr(&mut self) -> *mut u8 {
86        unsafe { PyBytesWriter_GetData(self.writer.as_ptr()) as _ }
87    }
88
89    /// Set the length of the internal buffer to `new_len`. The new bytes are uninitialized.
90    ///
91    /// # Safety
92    /// The caller must ensure the new bytes are initialized. This will also make all pointers
93    /// returned by `as_mut_ptr` invalid, so the caller must not hold any references to the buffer
94    /// across this call.
95    #[inline]
96    #[cfg(not(Py_LIMITED_API))]
97    unsafe fn set_len(&mut self, new_len: usize) -> PyResult<()> {
98        unsafe {
99            error_on_minusone(
100                self.python,
101                PyBytesWriter_Resize(self.writer.as_ptr(), new_len as _),
102            )
103        }
104    }
105}
106
107impl<'py> TryFrom<PyBytesWriter<'py>> for Bound<'py, PyBytes> {
108    type Error = PyErr;
109
110    #[inline]
111    fn try_from(value: PyBytesWriter<'py>) -> Result<Self, Self::Error> {
112        let py = value.python;
113
114        #[cfg(not(Py_LIMITED_API))]
115        unsafe {
116            PyBytesWriter_Finish(ManuallyDrop::new(value).writer.as_ptr())
117                .assume_owned_or_err(py)
118                .cast_into_unchecked()
119        }
120
121        #[cfg(Py_LIMITED_API)]
122        {
123            Ok(PyBytes::new(py, &value.buffer))
124        }
125    }
126}
127
128impl<'py> IntoPyObject<'py> for PyBytesWriter<'py> {
129    type Target = PyBytes;
130    type Output = Bound<'py, PyBytes>;
131    type Error = PyErr;
132
133    #[cfg(feature = "experimental-inspect")]
134    const OUTPUT_TYPE: PyStaticExpr = PyBytes::TYPE_HINT;
135
136    #[inline]
137    fn into_pyobject(self, _py: Python<'py>) -> Result<Self::Output, Self::Error> {
138        self.try_into()
139    }
140}
141
142#[cfg(not(Py_LIMITED_API))]
143impl<'py> Drop for PyBytesWriter<'py> {
144    #[inline]
145    fn drop(&mut self) {
146        unsafe { PyBytesWriter_Discard(self.writer.as_ptr()) }
147    }
148}
149
150#[cfg(not(Py_LIMITED_API))]
151impl std::io::Write for PyBytesWriter<'_> {
152    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
153        self.write_all(buf)?;
154        Ok(buf.len())
155    }
156
157    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> std::io::Result<usize> {
158        let len = bufs.iter().map(|b| b.len()).sum();
159        let pos = self.len();
160
161        // SAFETY: We write the new uninitialized bytes below.
162        unsafe { self.set_len(self.len() + len)? }
163
164        // SAFETY: We ensured enough capacity above and the ptr will be valid because we will not be
165        // resizing the buffer until we have written all the data.
166        let mut ptr = unsafe { self.as_mut_ptr().add(pos) };
167
168        for buf in bufs {
169            // SAFETY: We have ensured enough capacity above.
170            unsafe { ptr::copy_nonoverlapping(buf.as_ptr(), ptr, buf.len()) };
171
172            // SAFETY: We just wrote buf.len() bytes
173            ptr = unsafe { ptr.add(buf.len()) };
174        }
175        Ok(len)
176    }
177
178    fn flush(&mut self) -> std::io::Result<()> {
179        Ok(())
180    }
181
182    fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
183        let len = buf.len();
184        let pos = self.len();
185
186        // SAFETY: We write the new uninitialized bytes below.
187        unsafe { self.set_len(pos + len)? }
188
189        // SAFETY: We have ensured enough capacity above.
190        unsafe { ptr::copy_nonoverlapping(buf.as_ptr(), self.as_mut_ptr().add(pos), len) };
191
192        Ok(())
193    }
194}
195
#[cfg(Py_LIMITED_API)]
impl std::io::Write for PyBytesWriter<'_> {
    // Under the limited API all data lives in `self.buffer`, so every operation is forwarded
    // unchanged to the `std::io::Write` implementation of `Vec<u8>`.

    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        <Vec<u8> as std::io::Write>::write(&mut self.buffer, buf)
    }

    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> std::io::Result<usize> {
        <Vec<u8> as std::io::Write>::write_vectored(&mut self.buffer, bufs)
    }

    fn flush(&mut self) -> std::io::Result<()> {
        <Vec<u8> as std::io::Write>::flush(&mut self.buffer)
    }

    fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
        <Vec<u8> as std::io::Write>::write_all(&mut self.buffer, buf)
    }

    fn write_fmt(&mut self, args: std::fmt::Arguments<'_>) -> std::io::Result<()> {
        <Vec<u8> as std::io::Write>::write_fmt(&mut self.buffer, args)
    }
}
218
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::PyBytesMethods;
    use std::io::Write;

    #[test]
    fn test_io_write() {
        Python::attach(|py| {
            let buf = b"hallo world";
            let mut writer = PyBytesWriter::new(py).unwrap();
            assert_eq!(writer.write(buf).unwrap(), 11);
            let bytes: Bound<'_, PyBytes> = writer.try_into().unwrap();
            assert_eq!(bytes.as_bytes(), buf);
        })
    }

    #[test]
    fn test_pre_allocated() {
        Python::attach(|py| {
            let buf = b"hallo world";
            let mut writer = PyBytesWriter::with_capacity(py, buf.len()).unwrap();
            assert_eq!(writer.len(), 0, "Writer position should be zero");
            assert_eq!(writer.write(buf).unwrap(), 11);
            let bytes: Bound<'_, PyBytes> = writer.try_into().unwrap();
            assert_eq!(bytes.as_bytes(), buf);
        })
    }

    #[test]
    fn test_io_write_vectored() {
        Python::attach(|py| {
            let bufs = [IoSlice::new(b"hallo "), IoSlice::new(b"world")];
            let mut writer = PyBytesWriter::new(py).unwrap();
            assert_eq!(writer.write_vectored(&bufs).unwrap(), 11);
            let bytes: Bound<'_, PyBytes> = writer.try_into().unwrap();
            assert_eq!(bytes.as_bytes(), b"hallo world");
        })
    }

    #[test]
    fn test_io_write_vectored_large() {
        Python::attach(|py| {
            let large_data = vec![b'\n'; 1024]; // 1 KB
            let bufs = [
                IoSlice::new(b"hallo"),
                IoSlice::new(&large_data),
                IoSlice::new(b"world"),
            ];
            let mut writer = PyBytesWriter::new(py).unwrap();
            assert_eq!(writer.write_vectored(&bufs).unwrap(), 1034);
            let bytes: Bound<'_, PyBytes> = writer.try_into().unwrap();
            assert!(bytes.as_bytes().starts_with(b"hallo\n"));
            assert!(bytes.as_bytes().ends_with(b"world"));
            assert_eq!(bytes.as_bytes().len(), 1034);
        })
    }

    /// Vectored writes must append at the current position, not at offset zero, and an empty
    /// slice among the buffers must not shift the data that follows it.
    #[test]
    fn test_io_write_vectored_appends_to_existing_data() {
        Python::attach(|py| {
            let mut writer = PyBytesWriter::new(py).unwrap();
            writer.write_all(b"hallo").unwrap();
            let bufs = [IoSlice::new(b" "), IoSlice::new(b""), IoSlice::new(b"world")];
            assert_eq!(writer.write_vectored(&bufs).unwrap(), 6);
            assert_eq!(writer.len(), 11);
            let bytes: Bound<'_, PyBytes> = writer.try_into().unwrap();
            assert_eq!(bytes.as_bytes(), b"hallo world");
        })
    }

    /// Zero-length writes and flushes are accepted and leave the writer empty.
    #[test]
    fn test_empty_writes() {
        Python::attach(|py| {
            let mut writer = PyBytesWriter::new(py).unwrap();
            assert_eq!(writer.write(b"").unwrap(), 0);
            assert_eq!(writer.write_vectored(&[]).unwrap(), 0);
            writer.flush().unwrap();
            assert_eq!(writer.len(), 0);
            let bytes: Bound<'_, PyBytes> = writer.try_into().unwrap();
            assert!(bytes.as_bytes().is_empty());
        })
    }

    #[test]
    fn test_large_data() {
        Python::attach(|py| {
            let mut writer = PyBytesWriter::new(py).unwrap();
            let large_data = vec![0; 1024]; // 1 KB
            writer.write_all(&large_data).unwrap();
            let bytes: Bound<'_, PyBytes> = writer.try_into().unwrap();
            assert_eq!(bytes.as_bytes(), large_data.as_slice());
        })
    }
}
⚠️ Internal Docs ⚠️ Not Public API 👉 Official Docs Here