// pyo3_macros_backend/pyfunction.rs

1use crate::attributes::KeywordAttribute;
2#[cfg(feature = "experimental-inspect")]
3use crate::introspection::{function_introspection_code, introspection_id_const};
4use crate::utils::{Ctx, LitCStr};
5use crate::{
6    attributes::{
7        self, get_pyo3_options, take_attributes, take_pyo3_options, CrateAttribute,
8        FromPyWithAttribute, NameAttribute, TextSignatureAttribute,
9    },
10    method::{self, CallingConvention, FnArg},
11    pymethod::check_generic,
12};
13use proc_macro2::{Span, TokenStream};
14use quote::{format_ident, quote, ToTokens};
15use std::cmp::PartialEq;
16use std::ffi::CString;
17use syn::parse::{Parse, ParseStream};
18use syn::punctuated::Punctuated;
19use syn::{ext::IdentExt, spanned::Spanned, LitStr, Path, Result, Token};
20
21mod signature;
22
23pub use self::signature::{ConstructorAttribute, FunctionSignature, SignatureAttribute};
24
25#[derive(Clone, Debug)]
26pub struct PyFunctionArgPyO3Attributes {
27    pub from_py_with: Option<FromPyWithAttribute>,
28    pub cancel_handle: Option<attributes::kw::cancel_handle>,
29}
30
31enum PyFunctionArgPyO3Attribute {
32    FromPyWith(FromPyWithAttribute),
33    CancelHandle(attributes::kw::cancel_handle),
34}
35
36impl Parse for PyFunctionArgPyO3Attribute {
37    fn parse(input: ParseStream<'_>) -> Result<Self> {
38        let lookahead = input.lookahead1();
39        if lookahead.peek(attributes::kw::cancel_handle) {
40            input.parse().map(PyFunctionArgPyO3Attribute::CancelHandle)
41        } else if lookahead.peek(attributes::kw::from_py_with) {
42            input.parse().map(PyFunctionArgPyO3Attribute::FromPyWith)
43        } else {
44            Err(lookahead.error())
45        }
46    }
47}
48
49impl PyFunctionArgPyO3Attributes {
50    /// Parses #[pyo3(from_python_with = "func")]
51    pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
52        let mut attributes = PyFunctionArgPyO3Attributes {
53            from_py_with: None,
54            cancel_handle: None,
55        };
56        take_attributes(attrs, |attr| {
57            if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
58                for attr in pyo3_attrs {
59                    match attr {
60                        PyFunctionArgPyO3Attribute::FromPyWith(from_py_with) => {
61                            ensure_spanned!(
62                                attributes.from_py_with.is_none(),
63                                from_py_with.span() => "`from_py_with` may only be specified once per argument"
64                            );
65                            attributes.from_py_with = Some(from_py_with);
66                        }
67                        PyFunctionArgPyO3Attribute::CancelHandle(cancel_handle) => {
68                            ensure_spanned!(
69                                attributes.cancel_handle.is_none(),
70                                cancel_handle.span() => "`cancel_handle` may only be specified once per argument"
71                            );
72                            attributes.cancel_handle = Some(cancel_handle);
73                        }
74                    }
75                    ensure_spanned!(
76                        attributes.from_py_with.is_none() || attributes.cancel_handle.is_none(),
77                        attributes.cancel_handle.unwrap().span() => "`from_py_with` and `cancel_handle` cannot be specified together"
78                    );
79                }
80                Ok(true)
81            } else {
82                Ok(false)
83            }
84        })?;
85        Ok(attributes)
86    }
87}
88
89type PyFunctionWarningMessageAttribute = KeywordAttribute<attributes::kw::message, LitStr>;
90type PyFunctionWarningCategoryAttribute = KeywordAttribute<attributes::kw::category, Path>;
91
92pub struct PyFunctionWarningAttribute {
93    pub message: PyFunctionWarningMessageAttribute,
94    pub category: Option<PyFunctionWarningCategoryAttribute>,
95    pub span: Span,
96}
97
98#[derive(PartialEq)]
99pub enum PyFunctionWarningCategory {
100    Path(Path),
101    UserWarning,
102    DeprecationWarning, // TODO: unused for now, intended for pyo3(deprecated) special-case
103}
104
105pub struct PyFunctionWarning {
106    pub message: LitStr,
107    pub category: PyFunctionWarningCategory,
108    pub span: Span,
109}
110
111impl From<PyFunctionWarningAttribute> for PyFunctionWarning {
112    fn from(value: PyFunctionWarningAttribute) -> Self {
113        Self {
114            message: value.message.value,
115            category: value
116                .category
117                .map_or(PyFunctionWarningCategory::UserWarning, |cat| {
118                    PyFunctionWarningCategory::Path(cat.value)
119                }),
120            span: value.span,
121        }
122    }
123}
124
125pub trait WarningFactory {
126    fn build_py_warning(&self, ctx: &Ctx) -> TokenStream;
127    fn span(&self) -> Span;
128}
129
130impl WarningFactory for PyFunctionWarning {
131    fn build_py_warning(&self, ctx: &Ctx) -> TokenStream {
132        let message = &self.message.value();
133        let c_message = LitCStr::new(
134            CString::new(message.clone()).unwrap(),
135            Spanned::span(&message),
136            ctx,
137        );
138        let pyo3_path = &ctx.pyo3_path;
139        let category = match &self.category {
140            PyFunctionWarningCategory::Path(path) => quote! {#path},
141            PyFunctionWarningCategory::UserWarning => {
142                quote! {#pyo3_path::exceptions::PyUserWarning}
143            }
144            PyFunctionWarningCategory::DeprecationWarning => {
145                quote! {#pyo3_path::exceptions::PyDeprecationWarning}
146            }
147        };
148        quote! {
149            #pyo3_path::PyErr::warn(py, &<#category as #pyo3_path::PyTypeInfo>::type_object(py), #c_message, 1)?;
150        }
151    }
152
153    fn span(&self) -> Span {
154        self.span
155    }
156}
157
158impl<T: WarningFactory> WarningFactory for Vec<T> {
159    fn build_py_warning(&self, ctx: &Ctx) -> TokenStream {
160        let warnings = self.iter().map(|warning| warning.build_py_warning(ctx));
161
162        quote! {
163            #(#warnings)*
164        }
165    }
166
167    fn span(&self) -> Span {
168        self.iter()
169            .map(|val| val.span())
170            .reduce(|acc, span| acc.join(span).unwrap_or(acc))
171            .unwrap()
172    }
173}
174
175impl Parse for PyFunctionWarningAttribute {
176    fn parse(input: ParseStream<'_>) -> Result<Self> {
177        let mut message: Option<PyFunctionWarningMessageAttribute> = None;
178        let mut category: Option<PyFunctionWarningCategoryAttribute> = None;
179
180        let span = input.parse::<attributes::kw::warn>()?.span();
181
182        let content;
183        syn::parenthesized!(content in input);
184
185        while !content.is_empty() {
186            let lookahead = content.lookahead1();
187
188            if lookahead.peek(attributes::kw::message) {
189                message = content
190                    .parse::<PyFunctionWarningMessageAttribute>()
191                    .map(Some)?;
192            } else if lookahead.peek(attributes::kw::category) {
193                category = content
194                    .parse::<PyFunctionWarningCategoryAttribute>()
195                    .map(Some)?;
196            } else {
197                return Err(lookahead.error());
198            }
199
200            if content.peek(Token![,]) {
201                content.parse::<Token![,]>()?;
202            }
203        }
204
205        Ok(PyFunctionWarningAttribute {
206            message: message.ok_or(syn::Error::new(
207                content.span(),
208                "missing `message` in `warn` attribute",
209            ))?,
210            category,
211            span,
212        })
213    }
214}
215
216impl ToTokens for PyFunctionWarningAttribute {
217    fn to_tokens(&self, tokens: &mut TokenStream) {
218        let message_tokens = self.message.to_token_stream();
219        let category_tokens = self
220            .category
221            .as_ref()
222            .map_or(quote! {}, |cat| cat.to_token_stream());
223
224        let token_stream = quote! {
225            warn(#message_tokens, #category_tokens)
226        };
227
228        tokens.extend(token_stream);
229    }
230}
231
232#[derive(Default)]
233pub struct PyFunctionOptions {
234    pub pass_module: Option<attributes::kw::pass_module>,
235    pub name: Option<NameAttribute>,
236    pub signature: Option<SignatureAttribute>,
237    pub text_signature: Option<TextSignatureAttribute>,
238    pub krate: Option<CrateAttribute>,
239    pub warnings: Vec<PyFunctionWarning>,
240}
241
242impl Parse for PyFunctionOptions {
243    fn parse(input: ParseStream<'_>) -> Result<Self> {
244        let mut options = PyFunctionOptions::default();
245
246        let attrs = Punctuated::<PyFunctionOption, syn::Token![,]>::parse_terminated(input)?;
247        options.add_attributes(attrs)?;
248
249        Ok(options)
250    }
251}
252
253pub enum PyFunctionOption {
254    Name(NameAttribute),
255    PassModule(attributes::kw::pass_module),
256    Signature(SignatureAttribute),
257    TextSignature(TextSignatureAttribute),
258    Crate(CrateAttribute),
259    Warning(PyFunctionWarningAttribute),
260}
261
262impl Parse for PyFunctionOption {
263    fn parse(input: ParseStream<'_>) -> Result<Self> {
264        let lookahead = input.lookahead1();
265        if lookahead.peek(attributes::kw::name) {
266            input.parse().map(PyFunctionOption::Name)
267        } else if lookahead.peek(attributes::kw::pass_module) {
268            input.parse().map(PyFunctionOption::PassModule)
269        } else if lookahead.peek(attributes::kw::signature) {
270            input.parse().map(PyFunctionOption::Signature)
271        } else if lookahead.peek(attributes::kw::text_signature) {
272            input.parse().map(PyFunctionOption::TextSignature)
273        } else if lookahead.peek(syn::Token![crate]) {
274            input.parse().map(PyFunctionOption::Crate)
275        } else if lookahead.peek(attributes::kw::warn) {
276            input.parse().map(PyFunctionOption::Warning)
277        } else {
278            Err(lookahead.error())
279        }
280    }
281}
282
283impl PyFunctionOptions {
284    pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
285        let mut options = PyFunctionOptions::default();
286        options.add_attributes(take_pyo3_options(attrs)?)?;
287        Ok(options)
288    }
289
290    pub fn add_attributes(
291        &mut self,
292        attrs: impl IntoIterator<Item = PyFunctionOption>,
293    ) -> Result<()> {
294        macro_rules! set_option {
295            ($key:ident) => {
296                {
297                    ensure_spanned!(
298                        self.$key.is_none(),
299                        $key.span() => concat!("`", stringify!($key), "` may only be specified once")
300                    );
301                    self.$key = Some($key);
302                }
303            };
304        }
305        for attr in attrs {
306            match attr {
307                PyFunctionOption::Name(name) => set_option!(name),
308                PyFunctionOption::PassModule(pass_module) => set_option!(pass_module),
309                PyFunctionOption::Signature(signature) => set_option!(signature),
310                PyFunctionOption::TextSignature(text_signature) => set_option!(text_signature),
311                PyFunctionOption::Crate(krate) => set_option!(krate),
312                PyFunctionOption::Warning(warning) => {
313                    self.warnings.push(warning.into());
314                }
315            }
316        }
317        Ok(())
318    }
319}
320
321pub fn build_py_function(
322    ast: &mut syn::ItemFn,
323    mut options: PyFunctionOptions,
324) -> syn::Result<TokenStream> {
325    options.add_attributes(take_pyo3_options(&mut ast.attrs)?)?;
326    impl_wrap_pyfunction(ast, options)
327}
328
329/// Generates python wrapper over a function that allows adding it to a python module as a python
330/// function
331pub fn impl_wrap_pyfunction(
332    func: &mut syn::ItemFn,
333    options: PyFunctionOptions,
334) -> syn::Result<TokenStream> {
335    check_generic(&func.sig)?;
336    let PyFunctionOptions {
337        pass_module,
338        name,
339        signature,
340        text_signature,
341        krate,
342        warnings,
343    } = options;
344
345    let ctx = &Ctx::new(&krate, Some(&func.sig));
346    let Ctx { pyo3_path, .. } = &ctx;
347
348    let python_name = name
349        .as_ref()
350        .map_or_else(|| &func.sig.ident, |name| &name.value.0)
351        .unraw();
352
353    let tp = if pass_module.is_some() {
354        let span = match func.sig.inputs.first() {
355            Some(syn::FnArg::Typed(first_arg)) => first_arg.ty.span(),
356            Some(syn::FnArg::Receiver(_)) | None => bail_spanned!(
357                func.sig.paren_token.span.join() => "expected `&PyModule` or `Py<PyModule>` as first argument with `pass_module`"
358            ),
359        };
360        method::FnType::FnModule(span)
361    } else {
362        method::FnType::FnStatic
363    };
364
365    let arguments = func
366        .sig
367        .inputs
368        .iter_mut()
369        .skip(if tp.skip_first_rust_argument_in_python_signature() {
370            1
371        } else {
372            0
373        })
374        .map(FnArg::parse)
375        .collect::<syn::Result<Vec<_>>>()?;
376
377    let signature = if let Some(signature) = signature {
378        FunctionSignature::from_arguments_and_attribute(arguments, signature)?
379    } else {
380        FunctionSignature::from_arguments(arguments)
381    };
382
383    let vis = &func.vis;
384    let name = &func.sig.ident;
385
386    #[cfg(feature = "experimental-inspect")]
387    let introspection = function_introspection_code(
388        pyo3_path,
389        Some(name),
390        &name.to_string(),
391        &signature,
392        None,
393        [] as [String; 0],
394        None,
395    );
396    #[cfg(not(feature = "experimental-inspect"))]
397    let introspection = quote! {};
398    #[cfg(feature = "experimental-inspect")]
399    let introspection_id = introspection_id_const();
400    #[cfg(not(feature = "experimental-inspect"))]
401    let introspection_id = quote! {};
402
403    let spec = method::FnSpec {
404        tp,
405        name: &func.sig.ident,
406        convention: CallingConvention::from_signature(&signature),
407        python_name,
408        signature,
409        text_signature,
410        asyncness: func.sig.asyncness,
411        unsafety: func.sig.unsafety,
412        warnings,
413    };
414
415    let wrapper_ident = format_ident!("__pyfunction_{}", spec.name);
416    if spec.asyncness.is_some() {
417        ensure_spanned!(
418            cfg!(feature = "experimental-async"),
419            spec.asyncness.span() => "async functions are only supported with the `experimental-async` feature"
420        );
421    }
422    let wrapper = spec.get_wrapper_function(&wrapper_ident, None, ctx)?;
423    let methoddef = spec.get_methoddef(wrapper_ident, &spec.get_doc(&func.attrs, ctx), ctx);
424
425    let wrapped_pyfunction = quote! {
426        // Create a module with the same name as the `#[pyfunction]` - this way `use <the function>`
427        // will actually bring both the module and the function into scope.
428        #[doc(hidden)]
429        #vis mod #name {
430            pub(crate) struct MakeDef;
431            pub const _PYO3_DEF: #pyo3_path::impl_::pymethods::PyMethodDef = MakeDef::_PYO3_DEF;
432            #introspection_id
433        }
434
435        // Generate the definition inside an anonymous function in the same scope as the original function -
436        // this avoids complications around the fact that the generated module has a different scope
437        // (and `super` doesn't always refer to the outer scope, e.g. if the `#[pyfunction] is
438        // inside a function body)
439        #[allow(unknown_lints, non_local_definitions)]
440        impl #name::MakeDef {
441            const _PYO3_DEF: #pyo3_path::impl_::pymethods::PyMethodDef = #methoddef;
442        }
443
444        #[allow(non_snake_case)]
445        #wrapper
446
447        #introspection
448    };
449    Ok(wrapped_pyfunction)
450}
// Internal docs, not public API: see the official PyO3 documentation instead.