pyo3_macros_backend/
module.rs

//! Code generation for the function that initializes a Python module and adds classes and functions to it.
2
3#[cfg(feature = "experimental-inspect")]
4use crate::introspection::{introspection_id_const, module_introspection_code};
5use crate::{
6    attributes::{
7        self, kw, take_attributes, take_pyo3_options, CrateAttribute, GILUsedAttribute,
8        ModuleAttribute, NameAttribute, SubmoduleAttribute,
9    },
10    get_doc,
11    pyclass::PyClassPyO3Option,
12    pyfunction::{impl_wrap_pyfunction, PyFunctionOptions},
13    utils::{has_attribute, has_attribute_with_namespace, Ctx, IdentOrStr, LitCStr},
14};
15use proc_macro2::{Span, TokenStream};
16use quote::quote;
17use std::ffi::CString;
18use syn::{
19    ext::IdentExt,
20    parse::{Parse, ParseStream},
21    parse_quote, parse_quote_spanned,
22    punctuated::Punctuated,
23    spanned::Spanned,
24    token::Comma,
25    Item, Meta, Path, Result,
26};
27
28#[derive(Default)]
29pub struct PyModuleOptions {
30    krate: Option<CrateAttribute>,
31    name: Option<NameAttribute>,
32    module: Option<ModuleAttribute>,
33    submodule: Option<kw::submodule>,
34    gil_used: Option<GILUsedAttribute>,
35}
36
37impl Parse for PyModuleOptions {
38    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
39        let mut options: PyModuleOptions = Default::default();
40
41        options.add_attributes(
42            Punctuated::<PyModulePyO3Option, syn::Token![,]>::parse_terminated(input)?,
43        )?;
44
45        Ok(options)
46    }
47}
48
49impl PyModuleOptions {
50    fn take_pyo3_options(&mut self, attrs: &mut Vec<syn::Attribute>) -> Result<()> {
51        self.add_attributes(take_pyo3_options(attrs)?)
52    }
53
54    fn add_attributes(
55        &mut self,
56        attrs: impl IntoIterator<Item = PyModulePyO3Option>,
57    ) -> Result<()> {
58        macro_rules! set_option {
59            ($key:ident $(, $extra:literal)?) => {
60                {
61                    ensure_spanned!(
62                        self.$key.is_none(),
63                        $key.span() => concat!("`", stringify!($key), "` may only be specified once" $(, $extra)?)
64                    );
65                    self.$key = Some($key);
66                }
67            };
68        }
69        for attr in attrs {
70            match attr {
71                PyModulePyO3Option::Crate(krate) => set_option!(krate),
72                PyModulePyO3Option::Name(name) => set_option!(name),
73                PyModulePyO3Option::Module(module) => set_option!(module),
74                PyModulePyO3Option::Submodule(submodule) => set_option!(
75                    submodule,
76                    " (it is implicitly always specified for nested modules)"
77                ),
78                PyModulePyO3Option::GILUsed(gil_used) => {
79                    set_option!(gil_used)
80                }
81            }
82        }
83        Ok(())
84    }
85}
86
87pub fn pymodule_module_impl(
88    module: &mut syn::ItemMod,
89    mut options: PyModuleOptions,
90) -> Result<TokenStream> {
91    let syn::ItemMod {
92        attrs,
93        vis,
94        unsafety: _,
95        ident,
96        mod_token,
97        content,
98        semi: _,
99    } = module;
100    let items = if let Some((_, items)) = content {
101        items
102    } else {
103        bail_spanned!(mod_token.span() => "`#[pymodule]` can only be used on inline modules")
104    };
105    options.take_pyo3_options(attrs)?;
106    let ctx = &Ctx::new(&options.krate, None);
107    let Ctx { pyo3_path, .. } = ctx;
108    let doc = get_doc(attrs, None, ctx);
109    let name = options
110        .name
111        .map_or_else(|| ident.unraw(), |name| name.value.0);
112    let full_name = if let Some(module) = &options.module {
113        format!("{}.{}", module.value.value(), name)
114    } else {
115        name.to_string()
116    };
117
118    let mut module_items = Vec::new();
119    let mut module_items_cfg_attrs = Vec::new();
120
121    fn extract_use_items(
122        source: &syn::UseTree,
123        cfg_attrs: &[syn::Attribute],
124        target_items: &mut Vec<syn::Ident>,
125        target_cfg_attrs: &mut Vec<Vec<syn::Attribute>>,
126    ) -> Result<()> {
127        match source {
128            syn::UseTree::Name(name) => {
129                target_items.push(name.ident.clone());
130                target_cfg_attrs.push(cfg_attrs.to_vec());
131            }
132            syn::UseTree::Path(path) => {
133                extract_use_items(&path.tree, cfg_attrs, target_items, target_cfg_attrs)?
134            }
135            syn::UseTree::Group(group) => {
136                for tree in &group.items {
137                    extract_use_items(tree, cfg_attrs, target_items, target_cfg_attrs)?
138                }
139            }
140            syn::UseTree::Glob(glob) => {
141                bail_spanned!(glob.span() => "#[pymodule] cannot import glob statements")
142            }
143            syn::UseTree::Rename(rename) => {
144                target_items.push(rename.rename.clone());
145                target_cfg_attrs.push(cfg_attrs.to_vec());
146            }
147        }
148        Ok(())
149    }
150
151    let mut pymodule_init = None;
152
153    for item in &mut *items {
154        match item {
155            Item::Use(item_use) => {
156                let is_pymodule_export =
157                    find_and_remove_attribute(&mut item_use.attrs, "pymodule_export");
158                if is_pymodule_export {
159                    let cfg_attrs = get_cfg_attributes(&item_use.attrs);
160                    extract_use_items(
161                        &item_use.tree,
162                        &cfg_attrs,
163                        &mut module_items,
164                        &mut module_items_cfg_attrs,
165                    )?;
166                }
167            }
168            Item::Fn(item_fn) => {
169                ensure_spanned!(
170                    !has_attribute(&item_fn.attrs, "pymodule_export"),
171                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
172                );
173                let is_pymodule_init =
174                    find_and_remove_attribute(&mut item_fn.attrs, "pymodule_init");
175                let ident = &item_fn.sig.ident;
176                if is_pymodule_init {
177                    ensure_spanned!(
178                        !has_attribute(&item_fn.attrs, "pyfunction"),
179                        item_fn.span() => "`#[pyfunction]` cannot be used alongside `#[pymodule_init]`"
180                    );
181                    ensure_spanned!(pymodule_init.is_none(), item_fn.span() => "only one `#[pymodule_init]` may be specified");
182                    pymodule_init = Some(quote! { #ident(module)?; });
183                } else if has_attribute(&item_fn.attrs, "pyfunction")
184                    || has_attribute_with_namespace(
185                        &item_fn.attrs,
186                        Some(pyo3_path),
187                        &["pyfunction"],
188                    )
189                    || has_attribute_with_namespace(
190                        &item_fn.attrs,
191                        Some(pyo3_path),
192                        &["prelude", "pyfunction"],
193                    )
194                {
195                    module_items.push(ident.clone());
196                    module_items_cfg_attrs.push(get_cfg_attributes(&item_fn.attrs));
197                }
198            }
199            Item::Struct(item_struct) => {
200                ensure_spanned!(
201                    !has_attribute(&item_struct.attrs, "pymodule_export"),
202                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
203                );
204                if has_attribute(&item_struct.attrs, "pyclass")
205                    || has_attribute_with_namespace(
206                        &item_struct.attrs,
207                        Some(pyo3_path),
208                        &["pyclass"],
209                    )
210                    || has_attribute_with_namespace(
211                        &item_struct.attrs,
212                        Some(pyo3_path),
213                        &["prelude", "pyclass"],
214                    )
215                {
216                    module_items.push(item_struct.ident.clone());
217                    module_items_cfg_attrs.push(get_cfg_attributes(&item_struct.attrs));
218                    if !has_pyo3_module_declared::<PyClassPyO3Option>(
219                        &item_struct.attrs,
220                        "pyclass",
221                        |option| matches!(option, PyClassPyO3Option::Module(_)),
222                    )? {
223                        set_module_attribute(&mut item_struct.attrs, &full_name);
224                    }
225                }
226            }
227            Item::Enum(item_enum) => {
228                ensure_spanned!(
229                    !has_attribute(&item_enum.attrs, "pymodule_export"),
230                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
231                );
232                if has_attribute(&item_enum.attrs, "pyclass")
233                    || has_attribute_with_namespace(&item_enum.attrs, Some(pyo3_path), &["pyclass"])
234                    || has_attribute_with_namespace(
235                        &item_enum.attrs,
236                        Some(pyo3_path),
237                        &["prelude", "pyclass"],
238                    )
239                {
240                    module_items.push(item_enum.ident.clone());
241                    module_items_cfg_attrs.push(get_cfg_attributes(&item_enum.attrs));
242                    if !has_pyo3_module_declared::<PyClassPyO3Option>(
243                        &item_enum.attrs,
244                        "pyclass",
245                        |option| matches!(option, PyClassPyO3Option::Module(_)),
246                    )? {
247                        set_module_attribute(&mut item_enum.attrs, &full_name);
248                    }
249                }
250            }
251            Item::Mod(item_mod) => {
252                ensure_spanned!(
253                    !has_attribute(&item_mod.attrs, "pymodule_export"),
254                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
255                );
256                if has_attribute(&item_mod.attrs, "pymodule")
257                    || has_attribute_with_namespace(&item_mod.attrs, Some(pyo3_path), &["pymodule"])
258                    || has_attribute_with_namespace(
259                        &item_mod.attrs,
260                        Some(pyo3_path),
261                        &["prelude", "pymodule"],
262                    )
263                {
264                    module_items.push(item_mod.ident.clone());
265                    module_items_cfg_attrs.push(get_cfg_attributes(&item_mod.attrs));
266                    if !has_pyo3_module_declared::<PyModulePyO3Option>(
267                        &item_mod.attrs,
268                        "pymodule",
269                        |option| matches!(option, PyModulePyO3Option::Module(_)),
270                    )? {
271                        set_module_attribute(&mut item_mod.attrs, &full_name);
272                    }
273                    item_mod
274                        .attrs
275                        .push(parse_quote_spanned!(item_mod.mod_token.span()=> #[pyo3(submodule)]));
276                }
277            }
278            Item::ForeignMod(item) => {
279                ensure_spanned!(
280                    !has_attribute(&item.attrs, "pymodule_export"),
281                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
282                );
283            }
284            Item::Trait(item) => {
285                ensure_spanned!(
286                    !has_attribute(&item.attrs, "pymodule_export"),
287                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
288                );
289            }
290            Item::Const(item) => {
291                ensure_spanned!(
292                    !has_attribute(&item.attrs, "pymodule_export"),
293                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
294                );
295            }
296            Item::Static(item) => {
297                ensure_spanned!(
298                    !has_attribute(&item.attrs, "pymodule_export"),
299                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
300                );
301            }
302            Item::Macro(item) => {
303                ensure_spanned!(
304                    !has_attribute(&item.attrs, "pymodule_export"),
305                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
306                );
307            }
308            Item::ExternCrate(item) => {
309                ensure_spanned!(
310                    !has_attribute(&item.attrs, "pymodule_export"),
311                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
312                );
313            }
314            Item::Impl(item) => {
315                ensure_spanned!(
316                    !has_attribute(&item.attrs, "pymodule_export"),
317                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
318                );
319            }
320            Item::TraitAlias(item) => {
321                ensure_spanned!(
322                    !has_attribute(&item.attrs, "pymodule_export"),
323                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
324                );
325            }
326            Item::Type(item) => {
327                ensure_spanned!(
328                    !has_attribute(&item.attrs, "pymodule_export"),
329                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
330                );
331            }
332            Item::Union(item) => {
333                ensure_spanned!(
334                    !has_attribute(&item.attrs, "pymodule_export"),
335                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
336                );
337            }
338            _ => (),
339        }
340    }
341
342    #[cfg(feature = "experimental-inspect")]
343    let introspection = module_introspection_code(
344        pyo3_path,
345        &name.to_string(),
346        &module_items,
347        &module_items_cfg_attrs,
348    );
349    #[cfg(not(feature = "experimental-inspect"))]
350    let introspection = quote! {};
351    #[cfg(feature = "experimental-inspect")]
352    let introspection_id = introspection_id_const();
353    #[cfg(not(feature = "experimental-inspect"))]
354    let introspection_id = quote! {};
355
356    let module_def = quote! {{
357        use #pyo3_path::impl_::pymodule as impl_;
358        const INITIALIZER: impl_::ModuleInitializer = impl_::ModuleInitializer(__pyo3_pymodule);
359        unsafe {
360           impl_::ModuleDef::new(
361                __PYO3_NAME,
362                #doc,
363                INITIALIZER
364            )
365        }
366    }};
367    let initialization = module_initialization(
368        &name,
369        ctx,
370        module_def,
371        options.submodule.is_some(),
372        options.gil_used.map_or(true, |op| op.value.value),
373    );
374
375    Ok(quote!(
376        #(#attrs)*
377        #vis #mod_token #ident {
378            #(#items)*
379
380            #initialization
381            #introspection
382            #introspection_id
383
384            fn __pyo3_pymodule(module: &#pyo3_path::Bound<'_, #pyo3_path::types::PyModule>) -> #pyo3_path::PyResult<()> {
385                use #pyo3_path::impl_::pymodule::PyAddToModule;
386                #(
387                    #(#module_items_cfg_attrs)*
388                    #module_items::_PYO3_DEF.add_to_module(module)?;
389                )*
390                #pymodule_init
391                ::std::result::Result::Ok(())
392            }
393        }
394    ))
395}
396
397/// Generates the function that is called by the python interpreter to initialize the native
398/// module
399pub fn pymodule_function_impl(
400    function: &mut syn::ItemFn,
401    mut options: PyModuleOptions,
402) -> Result<TokenStream> {
403    options.take_pyo3_options(&mut function.attrs)?;
404    process_functions_in_module(&options, function)?;
405    let ctx = &Ctx::new(&options.krate, None);
406    let Ctx { pyo3_path, .. } = ctx;
407    let ident = &function.sig.ident;
408    let name = options
409        .name
410        .map_or_else(|| ident.unraw(), |name| name.value.0);
411    let vis = &function.vis;
412    let doc = get_doc(&function.attrs, None, ctx);
413
414    let initialization = module_initialization(
415        &name,
416        ctx,
417        quote! { MakeDef::make_def() },
418        false,
419        options.gil_used.map_or(true, |op| op.value.value),
420    );
421
422    #[cfg(feature = "experimental-inspect")]
423    let introspection = module_introspection_code(pyo3_path, &name.to_string(), &[], &[]);
424    #[cfg(not(feature = "experimental-inspect"))]
425    let introspection = quote! {};
426    #[cfg(feature = "experimental-inspect")]
427    let introspection_id = introspection_id_const();
428    #[cfg(not(feature = "experimental-inspect"))]
429    let introspection_id = quote! {};
430
431    // Module function called with optional Python<'_> marker as first arg, followed by the module.
432    let mut module_args = Vec::new();
433    if function.sig.inputs.len() == 2 {
434        module_args.push(quote!(module.py()));
435    }
436    module_args
437        .push(quote!(::std::convert::Into::into(#pyo3_path::impl_::pymethods::BoundRef(module))));
438
439    Ok(quote! {
440        #[doc(hidden)]
441        #vis mod #ident {
442            #initialization
443            #introspection
444            #introspection_id
445        }
446
447        // Generate the definition inside an anonymous function in the same scope as the original function -
448        // this avoids complications around the fact that the generated module has a different scope
449        // (and `super` doesn't always refer to the outer scope, e.g. if the `#[pymodule] is
450        // inside a function body)
451        #[allow(unknown_lints, non_local_definitions)]
452        impl #ident::MakeDef {
453            const fn make_def() -> #pyo3_path::impl_::pymodule::ModuleDef {
454                fn __pyo3_pymodule(module: &#pyo3_path::Bound<'_, #pyo3_path::types::PyModule>) -> #pyo3_path::PyResult<()> {
455                    #ident(#(#module_args),*)
456                }
457
458                const INITIALIZER: #pyo3_path::impl_::pymodule::ModuleInitializer = #pyo3_path::impl_::pymodule::ModuleInitializer(__pyo3_pymodule);
459                unsafe {
460                    #pyo3_path::impl_::pymodule::ModuleDef::new(
461                        #ident::__PYO3_NAME,
462                        #doc,
463                        INITIALIZER
464                    )
465                }
466            }
467        }
468    })
469}
470
471fn module_initialization(
472    name: &syn::Ident,
473    ctx: &Ctx,
474    module_def: TokenStream,
475    is_submodule: bool,
476    gil_used: bool,
477) -> TokenStream {
478    let Ctx { pyo3_path, .. } = ctx;
479    let pyinit_symbol = format!("PyInit_{}", name);
480    let name = name.to_string();
481    let pyo3_name = LitCStr::new(CString::new(name).unwrap(), Span::call_site(), ctx);
482
483    let mut result = quote! {
484        #[doc(hidden)]
485        pub const __PYO3_NAME: &'static ::std::ffi::CStr = #pyo3_name;
486
487        pub(super) struct MakeDef;
488        #[doc(hidden)]
489        pub static _PYO3_DEF: #pyo3_path::impl_::pymodule::ModuleDef = #module_def;
490        #[doc(hidden)]
491        // so wrapped submodules can see what gil_used is
492        pub static __PYO3_GIL_USED: bool = #gil_used;
493    };
494    if !is_submodule {
495        result.extend(quote! {
496            /// This autogenerated function is called by the python interpreter when importing
497            /// the module.
498            #[doc(hidden)]
499            #[export_name = #pyinit_symbol]
500            pub unsafe extern "C" fn __pyo3_init() -> *mut #pyo3_path::ffi::PyObject {
501                unsafe { #pyo3_path::impl_::trampoline::module_init(|py| _PYO3_DEF.make_module(py, #gil_used)) }
502            }
503        });
504    }
505    result
506}
507
508/// Finds and takes care of the #[pyfn(...)] in `#[pymodule]`
509fn process_functions_in_module(options: &PyModuleOptions, func: &mut syn::ItemFn) -> Result<()> {
510    let ctx = &Ctx::new(&options.krate, None);
511    let Ctx { pyo3_path, .. } = ctx;
512    let mut stmts: Vec<syn::Stmt> = Vec::new();
513
514    for mut stmt in func.block.stmts.drain(..) {
515        if let syn::Stmt::Item(Item::Fn(func)) = &mut stmt {
516            if let Some(pyfn_args) = get_pyfn_attr(&mut func.attrs)? {
517                let module_name = pyfn_args.modname;
518                let wrapped_function = impl_wrap_pyfunction(func, pyfn_args.options)?;
519                let name = &func.sig.ident;
520                let statements: Vec<syn::Stmt> = syn::parse_quote! {
521                    #wrapped_function
522                    {
523                        use #pyo3_path::types::PyModuleMethods;
524                        #module_name.add_function(#pyo3_path::wrap_pyfunction!(#name, #module_name.as_borrowed())?)?;
525                    }
526                };
527                stmts.extend(statements);
528            }
529        };
530        stmts.push(stmt);
531    }
532
533    func.block.stmts = stmts;
534    Ok(())
535}
536
537pub struct PyFnArgs {
538    modname: Path,
539    options: PyFunctionOptions,
540}
541
542impl Parse for PyFnArgs {
543    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
544        let modname = input.parse().map_err(
545            |e| err_spanned!(e.span() => "expected module as first argument to #[pyfn()]"),
546        )?;
547
548        if input.is_empty() {
549            return Ok(Self {
550                modname,
551                options: Default::default(),
552            });
553        }
554
555        let _: Comma = input.parse()?;
556
557        Ok(Self {
558            modname,
559            options: input.parse()?,
560        })
561    }
562}
563
564/// Extracts the data from the #[pyfn(...)] attribute of a function
565fn get_pyfn_attr(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Option<PyFnArgs>> {
566    let mut pyfn_args: Option<PyFnArgs> = None;
567
568    take_attributes(attrs, |attr| {
569        if attr.path().is_ident("pyfn") {
570            ensure_spanned!(
571                pyfn_args.is_none(),
572                attr.span() => "`#[pyfn] may only be specified once"
573            );
574            pyfn_args = Some(attr.parse_args()?);
575            Ok(true)
576        } else {
577            Ok(false)
578        }
579    })?;
580
581    if let Some(pyfn_args) = &mut pyfn_args {
582        pyfn_args
583            .options
584            .add_attributes(take_pyo3_options(attrs)?)?;
585    }
586
587    Ok(pyfn_args)
588}
589
590fn get_cfg_attributes(attrs: &[syn::Attribute]) -> Vec<syn::Attribute> {
591    attrs
592        .iter()
593        .filter(|attr| attr.path().is_ident("cfg"))
594        .cloned()
595        .collect()
596}
597
598fn find_and_remove_attribute(attrs: &mut Vec<syn::Attribute>, ident: &str) -> bool {
599    let mut found = false;
600    attrs.retain(|attr| {
601        if attr.path().is_ident(ident) {
602            found = true;
603            false
604        } else {
605            true
606        }
607    });
608    found
609}
610
611impl PartialEq<syn::Ident> for IdentOrStr<'_> {
612    fn eq(&self, other: &syn::Ident) -> bool {
613        match self {
614            IdentOrStr::Str(s) => other == s,
615            IdentOrStr::Ident(i) => other == i,
616        }
617    }
618}
619
620fn set_module_attribute(attrs: &mut Vec<syn::Attribute>, module_name: &str) {
621    attrs.push(parse_quote!(#[pyo3(module = #module_name)]));
622}
623
624fn has_pyo3_module_declared<T: Parse>(
625    attrs: &[syn::Attribute],
626    root_attribute_name: &str,
627    is_module_option: impl Fn(&T) -> bool + Copy,
628) -> Result<bool> {
629    for attr in attrs {
630        if (attr.path().is_ident("pyo3") || attr.path().is_ident(root_attribute_name))
631            && matches!(attr.meta, Meta::List(_))
632        {
633            for option in &attr.parse_args_with(Punctuated::<T, Comma>::parse_terminated)? {
634                if is_module_option(option) {
635                    return Ok(true);
636                }
637            }
638        }
639    }
640    Ok(false)
641}
642
643enum PyModulePyO3Option {
644    Submodule(SubmoduleAttribute),
645    Crate(CrateAttribute),
646    Name(NameAttribute),
647    Module(ModuleAttribute),
648    GILUsed(GILUsedAttribute),
649}
650
651impl Parse for PyModulePyO3Option {
652    fn parse(input: ParseStream<'_>) -> Result<Self> {
653        let lookahead = input.lookahead1();
654        if lookahead.peek(attributes::kw::name) {
655            input.parse().map(PyModulePyO3Option::Name)
656        } else if lookahead.peek(syn::Token![crate]) {
657            input.parse().map(PyModulePyO3Option::Crate)
658        } else if lookahead.peek(attributes::kw::module) {
659            input.parse().map(PyModulePyO3Option::Module)
660        } else if lookahead.peek(attributes::kw::submodule) {
661            input.parse().map(PyModulePyO3Option::Submodule)
662        } else if lookahead.peek(attributes::kw::gil_used) {
663            input.parse().map(PyModulePyO3Option::GILUsed)
664        } else {
665            Err(lookahead.error())
666        }
667    }
668}
⚠️ Internal Docs ⚠️ Not Public API 👉 Official Docs Here