pyo3_macros_backend/
module.rs

//! Code generation for the function that initializes a Python module and adds classes and functions to it.
2
3#[cfg(feature = "experimental-inspect")]
4use crate::introspection::module_introspection_code;
5use crate::{
6    attributes::{
7        self, kw, take_attributes, take_pyo3_options, CrateAttribute, GILUsedAttribute,
8        ModuleAttribute, NameAttribute, SubmoduleAttribute,
9    },
10    get_doc,
11    pyclass::PyClassPyO3Option,
12    pyfunction::{impl_wrap_pyfunction, PyFunctionOptions},
13    utils::{has_attribute, has_attribute_with_namespace, Ctx, IdentOrStr, LitCStr},
14};
15use proc_macro2::{Span, TokenStream};
16use quote::quote;
17use std::ffi::CString;
18use syn::{
19    ext::IdentExt,
20    parse::{Parse, ParseStream},
21    parse_quote, parse_quote_spanned,
22    punctuated::Punctuated,
23    spanned::Spanned,
24    token::Comma,
25    Item, Meta, Path, Result,
26};
27
28#[derive(Default)]
29pub struct PyModuleOptions {
30    krate: Option<CrateAttribute>,
31    name: Option<NameAttribute>,
32    module: Option<ModuleAttribute>,
33    submodule: Option<kw::submodule>,
34    gil_used: Option<GILUsedAttribute>,
35}
36
37impl Parse for PyModuleOptions {
38    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
39        let mut options: PyModuleOptions = Default::default();
40
41        options.add_attributes(
42            Punctuated::<PyModulePyO3Option, syn::Token![,]>::parse_terminated(input)?,
43        )?;
44
45        Ok(options)
46    }
47}
48
49impl PyModuleOptions {
50    fn take_pyo3_options(&mut self, attrs: &mut Vec<syn::Attribute>) -> Result<()> {
51        self.add_attributes(take_pyo3_options(attrs)?)
52    }
53
54    fn add_attributes(
55        &mut self,
56        attrs: impl IntoIterator<Item = PyModulePyO3Option>,
57    ) -> Result<()> {
58        macro_rules! set_option {
59            ($key:ident $(, $extra:literal)?) => {
60                {
61                    ensure_spanned!(
62                        self.$key.is_none(),
63                        $key.span() => concat!("`", stringify!($key), "` may only be specified once" $(, $extra)?)
64                    );
65                    self.$key = Some($key);
66                }
67            };
68        }
69        for attr in attrs {
70            match attr {
71                PyModulePyO3Option::Crate(krate) => set_option!(krate),
72                PyModulePyO3Option::Name(name) => set_option!(name),
73                PyModulePyO3Option::Module(module) => set_option!(module),
74                PyModulePyO3Option::Submodule(submodule) => set_option!(
75                    submodule,
76                    " (it is implicitly always specified for nested modules)"
77                ),
78                PyModulePyO3Option::GILUsed(gil_used) => {
79                    set_option!(gil_used)
80                }
81            }
82        }
83        Ok(())
84    }
85}
86
87pub fn pymodule_module_impl(
88    module: &mut syn::ItemMod,
89    mut options: PyModuleOptions,
90) -> Result<TokenStream> {
91    let syn::ItemMod {
92        attrs,
93        vis,
94        unsafety: _,
95        ident,
96        mod_token,
97        content,
98        semi: _,
99    } = module;
100    let items = if let Some((_, items)) = content {
101        items
102    } else {
103        bail_spanned!(mod_token.span() => "`#[pymodule]` can only be used on inline modules")
104    };
105    options.take_pyo3_options(attrs)?;
106    let ctx = &Ctx::new(&options.krate, None);
107    let Ctx { pyo3_path, .. } = ctx;
108    let doc = get_doc(attrs, None, ctx);
109    let name = options
110        .name
111        .map_or_else(|| ident.unraw(), |name| name.value.0);
112    let full_name = if let Some(module) = &options.module {
113        format!("{}.{}", module.value.value(), name)
114    } else {
115        name.to_string()
116    };
117
118    let mut module_items = Vec::new();
119    let mut module_items_cfg_attrs = Vec::new();
120
121    fn extract_use_items(
122        source: &syn::UseTree,
123        cfg_attrs: &[syn::Attribute],
124        target_items: &mut Vec<syn::Ident>,
125        target_cfg_attrs: &mut Vec<Vec<syn::Attribute>>,
126    ) -> Result<()> {
127        match source {
128            syn::UseTree::Name(name) => {
129                target_items.push(name.ident.clone());
130                target_cfg_attrs.push(cfg_attrs.to_vec());
131            }
132            syn::UseTree::Path(path) => {
133                extract_use_items(&path.tree, cfg_attrs, target_items, target_cfg_attrs)?
134            }
135            syn::UseTree::Group(group) => {
136                for tree in &group.items {
137                    extract_use_items(tree, cfg_attrs, target_items, target_cfg_attrs)?
138                }
139            }
140            syn::UseTree::Glob(glob) => {
141                bail_spanned!(glob.span() => "#[pymodule] cannot import glob statements")
142            }
143            syn::UseTree::Rename(rename) => {
144                target_items.push(rename.rename.clone());
145                target_cfg_attrs.push(cfg_attrs.to_vec());
146            }
147        }
148        Ok(())
149    }
150
151    let mut pymodule_init = None;
152
153    for item in &mut *items {
154        match item {
155            Item::Use(item_use) => {
156                let is_pymodule_export =
157                    find_and_remove_attribute(&mut item_use.attrs, "pymodule_export");
158                if is_pymodule_export {
159                    let cfg_attrs = get_cfg_attributes(&item_use.attrs);
160                    extract_use_items(
161                        &item_use.tree,
162                        &cfg_attrs,
163                        &mut module_items,
164                        &mut module_items_cfg_attrs,
165                    )?;
166                }
167            }
168            Item::Fn(item_fn) => {
169                ensure_spanned!(
170                    !has_attribute(&item_fn.attrs, "pymodule_export"),
171                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
172                );
173                let is_pymodule_init =
174                    find_and_remove_attribute(&mut item_fn.attrs, "pymodule_init");
175                let ident = &item_fn.sig.ident;
176                if is_pymodule_init {
177                    ensure_spanned!(
178                        !has_attribute(&item_fn.attrs, "pyfunction"),
179                        item_fn.span() => "`#[pyfunction]` cannot be used alongside `#[pymodule_init]`"
180                    );
181                    ensure_spanned!(pymodule_init.is_none(), item_fn.span() => "only one `#[pymodule_init]` may be specified");
182                    pymodule_init = Some(quote! { #ident(module)?; });
183                } else if has_attribute(&item_fn.attrs, "pyfunction")
184                    || has_attribute_with_namespace(
185                        &item_fn.attrs,
186                        Some(pyo3_path),
187                        &["pyfunction"],
188                    )
189                    || has_attribute_with_namespace(
190                        &item_fn.attrs,
191                        Some(pyo3_path),
192                        &["prelude", "pyfunction"],
193                    )
194                {
195                    module_items.push(ident.clone());
196                    module_items_cfg_attrs.push(get_cfg_attributes(&item_fn.attrs));
197                }
198            }
199            Item::Struct(item_struct) => {
200                ensure_spanned!(
201                    !has_attribute(&item_struct.attrs, "pymodule_export"),
202                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
203                );
204                if has_attribute(&item_struct.attrs, "pyclass")
205                    || has_attribute_with_namespace(
206                        &item_struct.attrs,
207                        Some(pyo3_path),
208                        &["pyclass"],
209                    )
210                    || has_attribute_with_namespace(
211                        &item_struct.attrs,
212                        Some(pyo3_path),
213                        &["prelude", "pyclass"],
214                    )
215                {
216                    module_items.push(item_struct.ident.clone());
217                    module_items_cfg_attrs.push(get_cfg_attributes(&item_struct.attrs));
218                    if !has_pyo3_module_declared::<PyClassPyO3Option>(
219                        &item_struct.attrs,
220                        "pyclass",
221                        |option| matches!(option, PyClassPyO3Option::Module(_)),
222                    )? {
223                        set_module_attribute(&mut item_struct.attrs, &full_name);
224                    }
225                }
226            }
227            Item::Enum(item_enum) => {
228                ensure_spanned!(
229                    !has_attribute(&item_enum.attrs, "pymodule_export"),
230                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
231                );
232                if has_attribute(&item_enum.attrs, "pyclass")
233                    || has_attribute_with_namespace(&item_enum.attrs, Some(pyo3_path), &["pyclass"])
234                    || has_attribute_with_namespace(
235                        &item_enum.attrs,
236                        Some(pyo3_path),
237                        &["prelude", "pyclass"],
238                    )
239                {
240                    module_items.push(item_enum.ident.clone());
241                    module_items_cfg_attrs.push(get_cfg_attributes(&item_enum.attrs));
242                    if !has_pyo3_module_declared::<PyClassPyO3Option>(
243                        &item_enum.attrs,
244                        "pyclass",
245                        |option| matches!(option, PyClassPyO3Option::Module(_)),
246                    )? {
247                        set_module_attribute(&mut item_enum.attrs, &full_name);
248                    }
249                }
250            }
251            Item::Mod(item_mod) => {
252                ensure_spanned!(
253                    !has_attribute(&item_mod.attrs, "pymodule_export"),
254                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
255                );
256                if has_attribute(&item_mod.attrs, "pymodule")
257                    || has_attribute_with_namespace(&item_mod.attrs, Some(pyo3_path), &["pymodule"])
258                    || has_attribute_with_namespace(
259                        &item_mod.attrs,
260                        Some(pyo3_path),
261                        &["prelude", "pymodule"],
262                    )
263                {
264                    module_items.push(item_mod.ident.clone());
265                    module_items_cfg_attrs.push(get_cfg_attributes(&item_mod.attrs));
266                    if !has_pyo3_module_declared::<PyModulePyO3Option>(
267                        &item_mod.attrs,
268                        "pymodule",
269                        |option| matches!(option, PyModulePyO3Option::Module(_)),
270                    )? {
271                        set_module_attribute(&mut item_mod.attrs, &full_name);
272                    }
273                    item_mod
274                        .attrs
275                        .push(parse_quote_spanned!(item_mod.mod_token.span()=> #[pyo3(submodule)]));
276                }
277            }
278            Item::ForeignMod(item) => {
279                ensure_spanned!(
280                    !has_attribute(&item.attrs, "pymodule_export"),
281                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
282                );
283            }
284            Item::Trait(item) => {
285                ensure_spanned!(
286                    !has_attribute(&item.attrs, "pymodule_export"),
287                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
288                );
289            }
290            Item::Const(item) => {
291                ensure_spanned!(
292                    !has_attribute(&item.attrs, "pymodule_export"),
293                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
294                );
295            }
296            Item::Static(item) => {
297                ensure_spanned!(
298                    !has_attribute(&item.attrs, "pymodule_export"),
299                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
300                );
301            }
302            Item::Macro(item) => {
303                ensure_spanned!(
304                    !has_attribute(&item.attrs, "pymodule_export"),
305                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
306                );
307            }
308            Item::ExternCrate(item) => {
309                ensure_spanned!(
310                    !has_attribute(&item.attrs, "pymodule_export"),
311                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
312                );
313            }
314            Item::Impl(item) => {
315                ensure_spanned!(
316                    !has_attribute(&item.attrs, "pymodule_export"),
317                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
318                );
319            }
320            Item::TraitAlias(item) => {
321                ensure_spanned!(
322                    !has_attribute(&item.attrs, "pymodule_export"),
323                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
324                );
325            }
326            Item::Type(item) => {
327                ensure_spanned!(
328                    !has_attribute(&item.attrs, "pymodule_export"),
329                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
330                );
331            }
332            Item::Union(item) => {
333                ensure_spanned!(
334                    !has_attribute(&item.attrs, "pymodule_export"),
335                    item.span() => "`#[pymodule_export]` may only be used on `use` statements"
336                );
337            }
338            _ => (),
339        }
340    }
341
342    #[cfg(feature = "experimental-inspect")]
343    let introspection = module_introspection_code(
344        pyo3_path,
345        &name.to_string(),
346        &module_items,
347        &module_items_cfg_attrs,
348    );
349    #[cfg(not(feature = "experimental-inspect"))]
350    let introspection = quote! {};
351
352    let module_def = quote! {{
353        use #pyo3_path::impl_::pymodule as impl_;
354        const INITIALIZER: impl_::ModuleInitializer = impl_::ModuleInitializer(__pyo3_pymodule);
355        unsafe {
356           impl_::ModuleDef::new(
357                __PYO3_NAME,
358                #doc,
359                INITIALIZER
360            )
361        }
362    }};
363    let initialization = module_initialization(
364        &name,
365        ctx,
366        module_def,
367        options.submodule.is_some(),
368        options.gil_used.map_or(true, |op| op.value.value),
369    );
370
371    Ok(quote!(
372        #(#attrs)*
373        #vis #mod_token #ident {
374            #(#items)*
375
376            #initialization
377            #introspection
378
379            fn __pyo3_pymodule(module: &#pyo3_path::Bound<'_, #pyo3_path::types::PyModule>) -> #pyo3_path::PyResult<()> {
380                use #pyo3_path::impl_::pymodule::PyAddToModule;
381                #(
382                    #(#module_items_cfg_attrs)*
383                    #module_items::_PYO3_DEF.add_to_module(module)?;
384                )*
385                #pymodule_init
386                ::std::result::Result::Ok(())
387            }
388        }
389    ))
390}
391
392/// Generates the function that is called by the python interpreter to initialize the native
393/// module
394pub fn pymodule_function_impl(
395    function: &mut syn::ItemFn,
396    mut options: PyModuleOptions,
397) -> Result<TokenStream> {
398    options.take_pyo3_options(&mut function.attrs)?;
399    process_functions_in_module(&options, function)?;
400    let ctx = &Ctx::new(&options.krate, None);
401    let Ctx { pyo3_path, .. } = ctx;
402    let ident = &function.sig.ident;
403    let name = options
404        .name
405        .map_or_else(|| ident.unraw(), |name| name.value.0);
406    let vis = &function.vis;
407    let doc = get_doc(&function.attrs, None, ctx);
408
409    let initialization = module_initialization(
410        &name,
411        ctx,
412        quote! { MakeDef::make_def() },
413        false,
414        options.gil_used.map_or(true, |op| op.value.value),
415    );
416
417    #[cfg(feature = "experimental-inspect")]
418    let introspection = module_introspection_code(pyo3_path, &name.to_string(), &[], &[]);
419    #[cfg(not(feature = "experimental-inspect"))]
420    let introspection = quote! {};
421
422    // Module function called with optional Python<'_> marker as first arg, followed by the module.
423    let mut module_args = Vec::new();
424    if function.sig.inputs.len() == 2 {
425        module_args.push(quote!(module.py()));
426    }
427    module_args
428        .push(quote!(::std::convert::Into::into(#pyo3_path::impl_::pymethods::BoundRef(module))));
429
430    Ok(quote! {
431        #[doc(hidden)]
432        #vis mod #ident {
433            #initialization
434            #introspection
435        }
436
437        // Generate the definition inside an anonymous function in the same scope as the original function -
438        // this avoids complications around the fact that the generated module has a different scope
439        // (and `super` doesn't always refer to the outer scope, e.g. if the `#[pymodule] is
440        // inside a function body)
441        #[allow(unknown_lints, non_local_definitions)]
442        impl #ident::MakeDef {
443            const fn make_def() -> #pyo3_path::impl_::pymodule::ModuleDef {
444                fn __pyo3_pymodule(module: &#pyo3_path::Bound<'_, #pyo3_path::types::PyModule>) -> #pyo3_path::PyResult<()> {
445                    #ident(#(#module_args),*)
446                }
447
448                const INITIALIZER: #pyo3_path::impl_::pymodule::ModuleInitializer = #pyo3_path::impl_::pymodule::ModuleInitializer(__pyo3_pymodule);
449                unsafe {
450                    #pyo3_path::impl_::pymodule::ModuleDef::new(
451                        #ident::__PYO3_NAME,
452                        #doc,
453                        INITIALIZER
454                    )
455                }
456            }
457        }
458    })
459}
460
461fn module_initialization(
462    name: &syn::Ident,
463    ctx: &Ctx,
464    module_def: TokenStream,
465    is_submodule: bool,
466    gil_used: bool,
467) -> TokenStream {
468    let Ctx { pyo3_path, .. } = ctx;
469    let pyinit_symbol = format!("PyInit_{}", name);
470    let name = name.to_string();
471    let pyo3_name = LitCStr::new(CString::new(name).unwrap(), Span::call_site(), ctx);
472
473    let mut result = quote! {
474        #[doc(hidden)]
475        pub const __PYO3_NAME: &'static ::std::ffi::CStr = #pyo3_name;
476
477        pub(super) struct MakeDef;
478        #[doc(hidden)]
479        pub static _PYO3_DEF: #pyo3_path::impl_::pymodule::ModuleDef = #module_def;
480        #[doc(hidden)]
481        // so wrapped submodules can see what gil_used is
482        pub static __PYO3_GIL_USED: bool = #gil_used;
483    };
484    if !is_submodule {
485        result.extend(quote! {
486            /// This autogenerated function is called by the python interpreter when importing
487            /// the module.
488            #[doc(hidden)]
489            #[export_name = #pyinit_symbol]
490            pub unsafe extern "C" fn __pyo3_init() -> *mut #pyo3_path::ffi::PyObject {
491                unsafe { #pyo3_path::impl_::trampoline::module_init(|py| _PYO3_DEF.make_module(py, #gil_used)) }
492            }
493        });
494    }
495    result
496}
497
498/// Finds and takes care of the #[pyfn(...)] in `#[pymodule]`
499fn process_functions_in_module(options: &PyModuleOptions, func: &mut syn::ItemFn) -> Result<()> {
500    let ctx = &Ctx::new(&options.krate, None);
501    let Ctx { pyo3_path, .. } = ctx;
502    let mut stmts: Vec<syn::Stmt> = Vec::new();
503
504    for mut stmt in func.block.stmts.drain(..) {
505        if let syn::Stmt::Item(Item::Fn(func)) = &mut stmt {
506            if let Some(pyfn_args) = get_pyfn_attr(&mut func.attrs)? {
507                let module_name = pyfn_args.modname;
508                let wrapped_function = impl_wrap_pyfunction(func, pyfn_args.options)?;
509                let name = &func.sig.ident;
510                let statements: Vec<syn::Stmt> = syn::parse_quote! {
511                    #wrapped_function
512                    {
513                        use #pyo3_path::types::PyModuleMethods;
514                        #module_name.add_function(#pyo3_path::wrap_pyfunction!(#name, #module_name.as_borrowed())?)?;
515                    }
516                };
517                stmts.extend(statements);
518            }
519        };
520        stmts.push(stmt);
521    }
522
523    func.block.stmts = stmts;
524    Ok(())
525}
526
527pub struct PyFnArgs {
528    modname: Path,
529    options: PyFunctionOptions,
530}
531
532impl Parse for PyFnArgs {
533    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
534        let modname = input.parse().map_err(
535            |e| err_spanned!(e.span() => "expected module as first argument to #[pyfn()]"),
536        )?;
537
538        if input.is_empty() {
539            return Ok(Self {
540                modname,
541                options: Default::default(),
542            });
543        }
544
545        let _: Comma = input.parse()?;
546
547        Ok(Self {
548            modname,
549            options: input.parse()?,
550        })
551    }
552}
553
554/// Extracts the data from the #[pyfn(...)] attribute of a function
555fn get_pyfn_attr(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Option<PyFnArgs>> {
556    let mut pyfn_args: Option<PyFnArgs> = None;
557
558    take_attributes(attrs, |attr| {
559        if attr.path().is_ident("pyfn") {
560            ensure_spanned!(
561                pyfn_args.is_none(),
562                attr.span() => "`#[pyfn] may only be specified once"
563            );
564            pyfn_args = Some(attr.parse_args()?);
565            Ok(true)
566        } else {
567            Ok(false)
568        }
569    })?;
570
571    if let Some(pyfn_args) = &mut pyfn_args {
572        pyfn_args
573            .options
574            .add_attributes(take_pyo3_options(attrs)?)?;
575    }
576
577    Ok(pyfn_args)
578}
579
580fn get_cfg_attributes(attrs: &[syn::Attribute]) -> Vec<syn::Attribute> {
581    attrs
582        .iter()
583        .filter(|attr| attr.path().is_ident("cfg"))
584        .cloned()
585        .collect()
586}
587
588fn find_and_remove_attribute(attrs: &mut Vec<syn::Attribute>, ident: &str) -> bool {
589    let mut found = false;
590    attrs.retain(|attr| {
591        if attr.path().is_ident(ident) {
592            found = true;
593            false
594        } else {
595            true
596        }
597    });
598    found
599}
600
601impl PartialEq<syn::Ident> for IdentOrStr<'_> {
602    fn eq(&self, other: &syn::Ident) -> bool {
603        match self {
604            IdentOrStr::Str(s) => other == s,
605            IdentOrStr::Ident(i) => other == i,
606        }
607    }
608}
609
610fn set_module_attribute(attrs: &mut Vec<syn::Attribute>, module_name: &str) {
611    attrs.push(parse_quote!(#[pyo3(module = #module_name)]));
612}
613
614fn has_pyo3_module_declared<T: Parse>(
615    attrs: &[syn::Attribute],
616    root_attribute_name: &str,
617    is_module_option: impl Fn(&T) -> bool + Copy,
618) -> Result<bool> {
619    for attr in attrs {
620        if (attr.path().is_ident("pyo3") || attr.path().is_ident(root_attribute_name))
621            && matches!(attr.meta, Meta::List(_))
622        {
623            for option in &attr.parse_args_with(Punctuated::<T, Comma>::parse_terminated)? {
624                if is_module_option(option) {
625                    return Ok(true);
626                }
627            }
628        }
629    }
630    Ok(false)
631}
632
633enum PyModulePyO3Option {
634    Submodule(SubmoduleAttribute),
635    Crate(CrateAttribute),
636    Name(NameAttribute),
637    Module(ModuleAttribute),
638    GILUsed(GILUsedAttribute),
639}
640
641impl Parse for PyModulePyO3Option {
642    fn parse(input: ParseStream<'_>) -> Result<Self> {
643        let lookahead = input.lookahead1();
644        if lookahead.peek(attributes::kw::name) {
645            input.parse().map(PyModulePyO3Option::Name)
646        } else if lookahead.peek(syn::Token![crate]) {
647            input.parse().map(PyModulePyO3Option::Crate)
648        } else if lookahead.peek(attributes::kw::module) {
649            input.parse().map(PyModulePyO3Option::Module)
650        } else if lookahead.peek(attributes::kw::submodule) {
651            input.parse().map(PyModulePyO3Option::Submodule)
652        } else if lookahead.peek(attributes::kw::gil_used) {
653            input.parse().map(PyModulePyO3Option::GILUsed)
654        } else {
655            Err(lookahead.error())
656        }
657    }
658}
⚠️ Internal Docs ⚠️ Not Public API 👉 Official Docs Here