pyo3_macros_backend/pyfunction.rs

1#[cfg(feature = "experimental-inspect")]
2use crate::introspection::function_introspection_code;
3use crate::utils::Ctx;
4use crate::{
5    attributes::{
6        self, get_pyo3_options, take_attributes, take_pyo3_options, CrateAttribute,
7        FromPyWithAttribute, NameAttribute, TextSignatureAttribute,
8    },
9    method::{self, CallingConvention, FnArg},
10    pymethod::check_generic,
11};
12use proc_macro2::TokenStream;
13use quote::{format_ident, quote};
14use syn::parse::{Parse, ParseStream};
15use syn::punctuated::Punctuated;
16use syn::{ext::IdentExt, spanned::Spanned, Result};
17
18mod signature;
19
20pub use self::signature::{ConstructorAttribute, FunctionSignature, SignatureAttribute};
21
22#[derive(Clone, Debug)]
23pub struct PyFunctionArgPyO3Attributes {
24    pub from_py_with: Option<FromPyWithAttribute>,
25    pub cancel_handle: Option<attributes::kw::cancel_handle>,
26}
27
28enum PyFunctionArgPyO3Attribute {
29    FromPyWith(FromPyWithAttribute),
30    CancelHandle(attributes::kw::cancel_handle),
31}
32
33impl Parse for PyFunctionArgPyO3Attribute {
34    fn parse(input: ParseStream<'_>) -> Result<Self> {
35        let lookahead = input.lookahead1();
36        if lookahead.peek(attributes::kw::cancel_handle) {
37            input.parse().map(PyFunctionArgPyO3Attribute::CancelHandle)
38        } else if lookahead.peek(attributes::kw::from_py_with) {
39            input.parse().map(PyFunctionArgPyO3Attribute::FromPyWith)
40        } else {
41            Err(lookahead.error())
42        }
43    }
44}
45
46impl PyFunctionArgPyO3Attributes {
47    /// Parses #[pyo3(from_python_with = "func")]
48    pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
49        let mut attributes = PyFunctionArgPyO3Attributes {
50            from_py_with: None,
51            cancel_handle: None,
52        };
53        take_attributes(attrs, |attr| {
54            if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
55                for attr in pyo3_attrs {
56                    match attr {
57                        PyFunctionArgPyO3Attribute::FromPyWith(from_py_with) => {
58                            ensure_spanned!(
59                                attributes.from_py_with.is_none(),
60                                from_py_with.span() => "`from_py_with` may only be specified once per argument"
61                            );
62                            attributes.from_py_with = Some(from_py_with);
63                        }
64                        PyFunctionArgPyO3Attribute::CancelHandle(cancel_handle) => {
65                            ensure_spanned!(
66                                attributes.cancel_handle.is_none(),
67                                cancel_handle.span() => "`cancel_handle` may only be specified once per argument"
68                            );
69                            attributes.cancel_handle = Some(cancel_handle);
70                        }
71                    }
72                    ensure_spanned!(
73                        attributes.from_py_with.is_none() || attributes.cancel_handle.is_none(),
74                        attributes.cancel_handle.unwrap().span() => "`from_py_with` and `cancel_handle` cannot be specified together"
75                    );
76                }
77                Ok(true)
78            } else {
79                Ok(false)
80            }
81        })?;
82        Ok(attributes)
83    }
84}
85
86#[derive(Default)]
87pub struct PyFunctionOptions {
88    pub pass_module: Option<attributes::kw::pass_module>,
89    pub name: Option<NameAttribute>,
90    pub signature: Option<SignatureAttribute>,
91    pub text_signature: Option<TextSignatureAttribute>,
92    pub krate: Option<CrateAttribute>,
93}
94
95impl Parse for PyFunctionOptions {
96    fn parse(input: ParseStream<'_>) -> Result<Self> {
97        let mut options = PyFunctionOptions::default();
98
99        let attrs = Punctuated::<PyFunctionOption, syn::Token![,]>::parse_terminated(input)?;
100        options.add_attributes(attrs)?;
101
102        Ok(options)
103    }
104}
105
106pub enum PyFunctionOption {
107    Name(NameAttribute),
108    PassModule(attributes::kw::pass_module),
109    Signature(SignatureAttribute),
110    TextSignature(TextSignatureAttribute),
111    Crate(CrateAttribute),
112}
113
114impl Parse for PyFunctionOption {
115    fn parse(input: ParseStream<'_>) -> Result<Self> {
116        let lookahead = input.lookahead1();
117        if lookahead.peek(attributes::kw::name) {
118            input.parse().map(PyFunctionOption::Name)
119        } else if lookahead.peek(attributes::kw::pass_module) {
120            input.parse().map(PyFunctionOption::PassModule)
121        } else if lookahead.peek(attributes::kw::signature) {
122            input.parse().map(PyFunctionOption::Signature)
123        } else if lookahead.peek(attributes::kw::text_signature) {
124            input.parse().map(PyFunctionOption::TextSignature)
125        } else if lookahead.peek(syn::Token![crate]) {
126            input.parse().map(PyFunctionOption::Crate)
127        } else {
128            Err(lookahead.error())
129        }
130    }
131}
132
133impl PyFunctionOptions {
134    pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
135        let mut options = PyFunctionOptions::default();
136        options.add_attributes(take_pyo3_options(attrs)?)?;
137        Ok(options)
138    }
139
140    pub fn add_attributes(
141        &mut self,
142        attrs: impl IntoIterator<Item = PyFunctionOption>,
143    ) -> Result<()> {
144        macro_rules! set_option {
145            ($key:ident) => {
146                {
147                    ensure_spanned!(
148                        self.$key.is_none(),
149                        $key.span() => concat!("`", stringify!($key), "` may only be specified once")
150                    );
151                    self.$key = Some($key);
152                }
153            };
154        }
155        for attr in attrs {
156            match attr {
157                PyFunctionOption::Name(name) => set_option!(name),
158                PyFunctionOption::PassModule(pass_module) => set_option!(pass_module),
159                PyFunctionOption::Signature(signature) => set_option!(signature),
160                PyFunctionOption::TextSignature(text_signature) => set_option!(text_signature),
161                PyFunctionOption::Crate(krate) => set_option!(krate),
162            }
163        }
164        Ok(())
165    }
166}
167
168pub fn build_py_function(
169    ast: &mut syn::ItemFn,
170    mut options: PyFunctionOptions,
171) -> syn::Result<TokenStream> {
172    options.add_attributes(take_pyo3_options(&mut ast.attrs)?)?;
173    impl_wrap_pyfunction(ast, options)
174}
175
176/// Generates python wrapper over a function that allows adding it to a python module as a python
177/// function
178pub fn impl_wrap_pyfunction(
179    func: &mut syn::ItemFn,
180    options: PyFunctionOptions,
181) -> syn::Result<TokenStream> {
182    check_generic(&func.sig)?;
183    let PyFunctionOptions {
184        pass_module,
185        name,
186        signature,
187        text_signature,
188        krate,
189    } = options;
190
191    let ctx = &Ctx::new(&krate, Some(&func.sig));
192    let Ctx { pyo3_path, .. } = &ctx;
193
194    let python_name = name
195        .as_ref()
196        .map_or_else(|| &func.sig.ident, |name| &name.value.0)
197        .unraw();
198
199    let tp = if pass_module.is_some() {
200        let span = match func.sig.inputs.first() {
201            Some(syn::FnArg::Typed(first_arg)) => first_arg.ty.span(),
202            Some(syn::FnArg::Receiver(_)) | None => bail_spanned!(
203                func.sig.paren_token.span.join() => "expected `&PyModule` or `Py<PyModule>` as first argument with `pass_module`"
204            ),
205        };
206        method::FnType::FnModule(span)
207    } else {
208        method::FnType::FnStatic
209    };
210
211    let arguments = func
212        .sig
213        .inputs
214        .iter_mut()
215        .skip(if tp.skip_first_rust_argument_in_python_signature() {
216            1
217        } else {
218            0
219        })
220        .map(FnArg::parse)
221        .collect::<syn::Result<Vec<_>>>()?;
222
223    let signature = if let Some(signature) = signature {
224        FunctionSignature::from_arguments_and_attribute(arguments, signature)?
225    } else {
226        FunctionSignature::from_arguments(arguments)
227    };
228
229    let spec = method::FnSpec {
230        tp,
231        name: &func.sig.ident,
232        convention: CallingConvention::from_signature(&signature),
233        python_name,
234        signature,
235        text_signature,
236        asyncness: func.sig.asyncness,
237        unsafety: func.sig.unsafety,
238    };
239
240    let vis = &func.vis;
241    let name = &func.sig.ident;
242
243    let wrapper_ident = format_ident!("__pyfunction_{}", spec.name);
244    let wrapper = spec.get_wrapper_function(&wrapper_ident, None, ctx)?;
245    let methoddef = spec.get_methoddef(wrapper_ident, &spec.get_doc(&func.attrs, ctx), ctx);
246    #[cfg(feature = "experimental-inspect")]
247    let introspection = function_introspection_code(pyo3_path, &name.to_string());
248    #[cfg(not(feature = "experimental-inspect"))]
249    let introspection = quote! {};
250
251    let wrapped_pyfunction = quote! {
252        // Create a module with the same name as the `#[pyfunction]` - this way `use <the function>`
253        // will actually bring both the module and the function into scope.
254        #[doc(hidden)]
255        #vis mod #name {
256            pub(crate) struct MakeDef;
257            pub const _PYO3_DEF: #pyo3_path::impl_::pymethods::PyMethodDef = MakeDef::_PYO3_DEF;
258            #introspection
259        }
260
261        // Generate the definition inside an anonymous function in the same scope as the original function -
262        // this avoids complications around the fact that the generated module has a different scope
263        // (and `super` doesn't always refer to the outer scope, e.g. if the `#[pyfunction] is
264        // inside a function body)
265        #[allow(unknown_lints, non_local_definitions)]
266        impl #name::MakeDef {
267            const _PYO3_DEF: #pyo3_path::impl_::pymethods::PyMethodDef = #methoddef;
268        }
269
270        #[allow(non_snake_case)]
271        #wrapper
272    };
273    Ok(wrapped_pyfunction)
274}
⚠️ Internal Docs ⚠️ Not Public API 👉 Official Docs Here