1#[cfg(feature = "experimental-inspect")]
4use crate::introspection::module_introspection_code;
5use crate::{
6 attributes::{
7 self, kw, take_attributes, take_pyo3_options, CrateAttribute, GILUsedAttribute,
8 ModuleAttribute, NameAttribute, SubmoduleAttribute,
9 },
10 get_doc,
11 pyclass::PyClassPyO3Option,
12 pyfunction::{impl_wrap_pyfunction, PyFunctionOptions},
13 utils::{has_attribute, has_attribute_with_namespace, Ctx, IdentOrStr, LitCStr},
14};
15use proc_macro2::{Span, TokenStream};
16use quote::quote;
17use std::ffi::CString;
18use syn::{
19 ext::IdentExt,
20 parse::{Parse, ParseStream},
21 parse_quote, parse_quote_spanned,
22 punctuated::Punctuated,
23 spanned::Spanned,
24 token::Comma,
25 Item, Meta, Path, Result,
26};
27
28#[derive(Default)]
29pub struct PyModuleOptions {
30 krate: Option<CrateAttribute>,
31 name: Option<NameAttribute>,
32 module: Option<ModuleAttribute>,
33 submodule: Option<kw::submodule>,
34 gil_used: Option<GILUsedAttribute>,
35}
36
37impl Parse for PyModuleOptions {
38 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
39 let mut options: PyModuleOptions = Default::default();
40
41 options.add_attributes(
42 Punctuated::<PyModulePyO3Option, syn::Token![,]>::parse_terminated(input)?,
43 )?;
44
45 Ok(options)
46 }
47}
48
49impl PyModuleOptions {
50 fn take_pyo3_options(&mut self, attrs: &mut Vec<syn::Attribute>) -> Result<()> {
51 self.add_attributes(take_pyo3_options(attrs)?)
52 }
53
54 fn add_attributes(
55 &mut self,
56 attrs: impl IntoIterator<Item = PyModulePyO3Option>,
57 ) -> Result<()> {
58 macro_rules! set_option {
59 ($key:ident $(, $extra:literal)?) => {
60 {
61 ensure_spanned!(
62 self.$key.is_none(),
63 $key.span() => concat!("`", stringify!($key), "` may only be specified once" $(, $extra)?)
64 );
65 self.$key = Some($key);
66 }
67 };
68 }
69 for attr in attrs {
70 match attr {
71 PyModulePyO3Option::Crate(krate) => set_option!(krate),
72 PyModulePyO3Option::Name(name) => set_option!(name),
73 PyModulePyO3Option::Module(module) => set_option!(module),
74 PyModulePyO3Option::Submodule(submodule) => set_option!(
75 submodule,
76 " (it is implicitly always specified for nested modules)"
77 ),
78 PyModulePyO3Option::GILUsed(gil_used) => {
79 set_option!(gil_used)
80 }
81 }
82 }
83 Ok(())
84 }
85}
86
87pub fn pymodule_module_impl(
88 module: &mut syn::ItemMod,
89 mut options: PyModuleOptions,
90) -> Result<TokenStream> {
91 let syn::ItemMod {
92 attrs,
93 vis,
94 unsafety: _,
95 ident,
96 mod_token,
97 content,
98 semi: _,
99 } = module;
100 let items = if let Some((_, items)) = content {
101 items
102 } else {
103 bail_spanned!(mod_token.span() => "`#[pymodule]` can only be used on inline modules")
104 };
105 options.take_pyo3_options(attrs)?;
106 let ctx = &Ctx::new(&options.krate, None);
107 let Ctx { pyo3_path, .. } = ctx;
108 let doc = get_doc(attrs, None, ctx);
109 let name = options
110 .name
111 .map_or_else(|| ident.unraw(), |name| name.value.0);
112 let full_name = if let Some(module) = &options.module {
113 format!("{}.{}", module.value.value(), name)
114 } else {
115 name.to_string()
116 };
117
118 let mut module_items = Vec::new();
119 let mut module_items_cfg_attrs = Vec::new();
120
121 fn extract_use_items(
122 source: &syn::UseTree,
123 cfg_attrs: &[syn::Attribute],
124 target_items: &mut Vec<syn::Ident>,
125 target_cfg_attrs: &mut Vec<Vec<syn::Attribute>>,
126 ) -> Result<()> {
127 match source {
128 syn::UseTree::Name(name) => {
129 target_items.push(name.ident.clone());
130 target_cfg_attrs.push(cfg_attrs.to_vec());
131 }
132 syn::UseTree::Path(path) => {
133 extract_use_items(&path.tree, cfg_attrs, target_items, target_cfg_attrs)?
134 }
135 syn::UseTree::Group(group) => {
136 for tree in &group.items {
137 extract_use_items(tree, cfg_attrs, target_items, target_cfg_attrs)?
138 }
139 }
140 syn::UseTree::Glob(glob) => {
141 bail_spanned!(glob.span() => "#[pymodule] cannot import glob statements")
142 }
143 syn::UseTree::Rename(rename) => {
144 target_items.push(rename.rename.clone());
145 target_cfg_attrs.push(cfg_attrs.to_vec());
146 }
147 }
148 Ok(())
149 }
150
151 let mut pymodule_init = None;
152
153 for item in &mut *items {
154 match item {
155 Item::Use(item_use) => {
156 let is_pymodule_export =
157 find_and_remove_attribute(&mut item_use.attrs, "pymodule_export");
158 if is_pymodule_export {
159 let cfg_attrs = get_cfg_attributes(&item_use.attrs);
160 extract_use_items(
161 &item_use.tree,
162 &cfg_attrs,
163 &mut module_items,
164 &mut module_items_cfg_attrs,
165 )?;
166 }
167 }
168 Item::Fn(item_fn) => {
169 ensure_spanned!(
170 !has_attribute(&item_fn.attrs, "pymodule_export"),
171 item.span() => "`#[pymodule_export]` may only be used on `use` statements"
172 );
173 let is_pymodule_init =
174 find_and_remove_attribute(&mut item_fn.attrs, "pymodule_init");
175 let ident = &item_fn.sig.ident;
176 if is_pymodule_init {
177 ensure_spanned!(
178 !has_attribute(&item_fn.attrs, "pyfunction"),
179 item_fn.span() => "`#[pyfunction]` cannot be used alongside `#[pymodule_init]`"
180 );
181 ensure_spanned!(pymodule_init.is_none(), item_fn.span() => "only one `#[pymodule_init]` may be specified");
182 pymodule_init = Some(quote! { #ident(module)?; });
183 } else if has_attribute(&item_fn.attrs, "pyfunction")
184 || has_attribute_with_namespace(
185 &item_fn.attrs,
186 Some(pyo3_path),
187 &["pyfunction"],
188 )
189 || has_attribute_with_namespace(
190 &item_fn.attrs,
191 Some(pyo3_path),
192 &["prelude", "pyfunction"],
193 )
194 {
195 module_items.push(ident.clone());
196 module_items_cfg_attrs.push(get_cfg_attributes(&item_fn.attrs));
197 }
198 }
199 Item::Struct(item_struct) => {
200 ensure_spanned!(
201 !has_attribute(&item_struct.attrs, "pymodule_export"),
202 item.span() => "`#[pymodule_export]` may only be used on `use` statements"
203 );
204 if has_attribute(&item_struct.attrs, "pyclass")
205 || has_attribute_with_namespace(
206 &item_struct.attrs,
207 Some(pyo3_path),
208 &["pyclass"],
209 )
210 || has_attribute_with_namespace(
211 &item_struct.attrs,
212 Some(pyo3_path),
213 &["prelude", "pyclass"],
214 )
215 {
216 module_items.push(item_struct.ident.clone());
217 module_items_cfg_attrs.push(get_cfg_attributes(&item_struct.attrs));
218 if !has_pyo3_module_declared::<PyClassPyO3Option>(
219 &item_struct.attrs,
220 "pyclass",
221 |option| matches!(option, PyClassPyO3Option::Module(_)),
222 )? {
223 set_module_attribute(&mut item_struct.attrs, &full_name);
224 }
225 }
226 }
227 Item::Enum(item_enum) => {
228 ensure_spanned!(
229 !has_attribute(&item_enum.attrs, "pymodule_export"),
230 item.span() => "`#[pymodule_export]` may only be used on `use` statements"
231 );
232 if has_attribute(&item_enum.attrs, "pyclass")
233 || has_attribute_with_namespace(&item_enum.attrs, Some(pyo3_path), &["pyclass"])
234 || has_attribute_with_namespace(
235 &item_enum.attrs,
236 Some(pyo3_path),
237 &["prelude", "pyclass"],
238 )
239 {
240 module_items.push(item_enum.ident.clone());
241 module_items_cfg_attrs.push(get_cfg_attributes(&item_enum.attrs));
242 if !has_pyo3_module_declared::<PyClassPyO3Option>(
243 &item_enum.attrs,
244 "pyclass",
245 |option| matches!(option, PyClassPyO3Option::Module(_)),
246 )? {
247 set_module_attribute(&mut item_enum.attrs, &full_name);
248 }
249 }
250 }
251 Item::Mod(item_mod) => {
252 ensure_spanned!(
253 !has_attribute(&item_mod.attrs, "pymodule_export"),
254 item.span() => "`#[pymodule_export]` may only be used on `use` statements"
255 );
256 if has_attribute(&item_mod.attrs, "pymodule")
257 || has_attribute_with_namespace(&item_mod.attrs, Some(pyo3_path), &["pymodule"])
258 || has_attribute_with_namespace(
259 &item_mod.attrs,
260 Some(pyo3_path),
261 &["prelude", "pymodule"],
262 )
263 {
264 module_items.push(item_mod.ident.clone());
265 module_items_cfg_attrs.push(get_cfg_attributes(&item_mod.attrs));
266 if !has_pyo3_module_declared::<PyModulePyO3Option>(
267 &item_mod.attrs,
268 "pymodule",
269 |option| matches!(option, PyModulePyO3Option::Module(_)),
270 )? {
271 set_module_attribute(&mut item_mod.attrs, &full_name);
272 }
273 item_mod
274 .attrs
275 .push(parse_quote_spanned!(item_mod.mod_token.span()=> #[pyo3(submodule)]));
276 }
277 }
278 Item::ForeignMod(item) => {
279 ensure_spanned!(
280 !has_attribute(&item.attrs, "pymodule_export"),
281 item.span() => "`#[pymodule_export]` may only be used on `use` statements"
282 );
283 }
284 Item::Trait(item) => {
285 ensure_spanned!(
286 !has_attribute(&item.attrs, "pymodule_export"),
287 item.span() => "`#[pymodule_export]` may only be used on `use` statements"
288 );
289 }
290 Item::Const(item) => {
291 ensure_spanned!(
292 !has_attribute(&item.attrs, "pymodule_export"),
293 item.span() => "`#[pymodule_export]` may only be used on `use` statements"
294 );
295 }
296 Item::Static(item) => {
297 ensure_spanned!(
298 !has_attribute(&item.attrs, "pymodule_export"),
299 item.span() => "`#[pymodule_export]` may only be used on `use` statements"
300 );
301 }
302 Item::Macro(item) => {
303 ensure_spanned!(
304 !has_attribute(&item.attrs, "pymodule_export"),
305 item.span() => "`#[pymodule_export]` may only be used on `use` statements"
306 );
307 }
308 Item::ExternCrate(item) => {
309 ensure_spanned!(
310 !has_attribute(&item.attrs, "pymodule_export"),
311 item.span() => "`#[pymodule_export]` may only be used on `use` statements"
312 );
313 }
314 Item::Impl(item) => {
315 ensure_spanned!(
316 !has_attribute(&item.attrs, "pymodule_export"),
317 item.span() => "`#[pymodule_export]` may only be used on `use` statements"
318 );
319 }
320 Item::TraitAlias(item) => {
321 ensure_spanned!(
322 !has_attribute(&item.attrs, "pymodule_export"),
323 item.span() => "`#[pymodule_export]` may only be used on `use` statements"
324 );
325 }
326 Item::Type(item) => {
327 ensure_spanned!(
328 !has_attribute(&item.attrs, "pymodule_export"),
329 item.span() => "`#[pymodule_export]` may only be used on `use` statements"
330 );
331 }
332 Item::Union(item) => {
333 ensure_spanned!(
334 !has_attribute(&item.attrs, "pymodule_export"),
335 item.span() => "`#[pymodule_export]` may only be used on `use` statements"
336 );
337 }
338 _ => (),
339 }
340 }
341
342 #[cfg(feature = "experimental-inspect")]
343 let introspection = module_introspection_code(
344 pyo3_path,
345 &name.to_string(),
346 &module_items,
347 &module_items_cfg_attrs,
348 );
349 #[cfg(not(feature = "experimental-inspect"))]
350 let introspection = quote! {};
351
352 let module_def = quote! {{
353 use #pyo3_path::impl_::pymodule as impl_;
354 const INITIALIZER: impl_::ModuleInitializer = impl_::ModuleInitializer(__pyo3_pymodule);
355 unsafe {
356 impl_::ModuleDef::new(
357 __PYO3_NAME,
358 #doc,
359 INITIALIZER
360 )
361 }
362 }};
363 let initialization = module_initialization(
364 &name,
365 ctx,
366 module_def,
367 options.submodule.is_some(),
368 options.gil_used.map_or(true, |op| op.value.value),
369 );
370
371 Ok(quote!(
372 #(#attrs)*
373 #vis #mod_token #ident {
374 #(#items)*
375
376 #initialization
377 #introspection
378
379 fn __pyo3_pymodule(module: &#pyo3_path::Bound<'_, #pyo3_path::types::PyModule>) -> #pyo3_path::PyResult<()> {
380 use #pyo3_path::impl_::pymodule::PyAddToModule;
381 #(
382 #(#module_items_cfg_attrs)*
383 #module_items::_PYO3_DEF.add_to_module(module)?;
384 )*
385 #pymodule_init
386 ::std::result::Result::Ok(())
387 }
388 }
389 ))
390}
391
392pub fn pymodule_function_impl(
395 function: &mut syn::ItemFn,
396 mut options: PyModuleOptions,
397) -> Result<TokenStream> {
398 options.take_pyo3_options(&mut function.attrs)?;
399 process_functions_in_module(&options, function)?;
400 let ctx = &Ctx::new(&options.krate, None);
401 let Ctx { pyo3_path, .. } = ctx;
402 let ident = &function.sig.ident;
403 let name = options
404 .name
405 .map_or_else(|| ident.unraw(), |name| name.value.0);
406 let vis = &function.vis;
407 let doc = get_doc(&function.attrs, None, ctx);
408
409 let initialization = module_initialization(
410 &name,
411 ctx,
412 quote! { MakeDef::make_def() },
413 false,
414 options.gil_used.map_or(true, |op| op.value.value),
415 );
416
417 #[cfg(feature = "experimental-inspect")]
418 let introspection = module_introspection_code(pyo3_path, &name.to_string(), &[], &[]);
419 #[cfg(not(feature = "experimental-inspect"))]
420 let introspection = quote! {};
421
422 let mut module_args = Vec::new();
424 if function.sig.inputs.len() == 2 {
425 module_args.push(quote!(module.py()));
426 }
427 module_args
428 .push(quote!(::std::convert::Into::into(#pyo3_path::impl_::pymethods::BoundRef(module))));
429
430 Ok(quote! {
431 #[doc(hidden)]
432 #vis mod #ident {
433 #initialization
434 #introspection
435 }
436
437 #[allow(unknown_lints, non_local_definitions)]
442 impl #ident::MakeDef {
443 const fn make_def() -> #pyo3_path::impl_::pymodule::ModuleDef {
444 fn __pyo3_pymodule(module: &#pyo3_path::Bound<'_, #pyo3_path::types::PyModule>) -> #pyo3_path::PyResult<()> {
445 #ident(#(#module_args),*)
446 }
447
448 const INITIALIZER: #pyo3_path::impl_::pymodule::ModuleInitializer = #pyo3_path::impl_::pymodule::ModuleInitializer(__pyo3_pymodule);
449 unsafe {
450 #pyo3_path::impl_::pymodule::ModuleDef::new(
451 #ident::__PYO3_NAME,
452 #doc,
453 INITIALIZER
454 )
455 }
456 }
457 }
458 })
459}
460
461fn module_initialization(
462 name: &syn::Ident,
463 ctx: &Ctx,
464 module_def: TokenStream,
465 is_submodule: bool,
466 gil_used: bool,
467) -> TokenStream {
468 let Ctx { pyo3_path, .. } = ctx;
469 let pyinit_symbol = format!("PyInit_{}", name);
470 let name = name.to_string();
471 let pyo3_name = LitCStr::new(CString::new(name).unwrap(), Span::call_site(), ctx);
472
473 let mut result = quote! {
474 #[doc(hidden)]
475 pub const __PYO3_NAME: &'static ::std::ffi::CStr = #pyo3_name;
476
477 pub(super) struct MakeDef;
478 #[doc(hidden)]
479 pub static _PYO3_DEF: #pyo3_path::impl_::pymodule::ModuleDef = #module_def;
480 #[doc(hidden)]
481 pub static __PYO3_GIL_USED: bool = #gil_used;
483 };
484 if !is_submodule {
485 result.extend(quote! {
486 #[doc(hidden)]
489 #[export_name = #pyinit_symbol]
490 pub unsafe extern "C" fn __pyo3_init() -> *mut #pyo3_path::ffi::PyObject {
491 unsafe { #pyo3_path::impl_::trampoline::module_init(|py| _PYO3_DEF.make_module(py, #gil_used)) }
492 }
493 });
494 }
495 result
496}
497
498fn process_functions_in_module(options: &PyModuleOptions, func: &mut syn::ItemFn) -> Result<()> {
500 let ctx = &Ctx::new(&options.krate, None);
501 let Ctx { pyo3_path, .. } = ctx;
502 let mut stmts: Vec<syn::Stmt> = Vec::new();
503
504 for mut stmt in func.block.stmts.drain(..) {
505 if let syn::Stmt::Item(Item::Fn(func)) = &mut stmt {
506 if let Some(pyfn_args) = get_pyfn_attr(&mut func.attrs)? {
507 let module_name = pyfn_args.modname;
508 let wrapped_function = impl_wrap_pyfunction(func, pyfn_args.options)?;
509 let name = &func.sig.ident;
510 let statements: Vec<syn::Stmt> = syn::parse_quote! {
511 #wrapped_function
512 {
513 use #pyo3_path::types::PyModuleMethods;
514 #module_name.add_function(#pyo3_path::wrap_pyfunction!(#name, #module_name.as_borrowed())?)?;
515 }
516 };
517 stmts.extend(statements);
518 }
519 };
520 stmts.push(stmt);
521 }
522
523 func.block.stmts = stmts;
524 Ok(())
525}
526
527pub struct PyFnArgs {
528 modname: Path,
529 options: PyFunctionOptions,
530}
531
532impl Parse for PyFnArgs {
533 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
534 let modname = input.parse().map_err(
535 |e| err_spanned!(e.span() => "expected module as first argument to #[pyfn()]"),
536 )?;
537
538 if input.is_empty() {
539 return Ok(Self {
540 modname,
541 options: Default::default(),
542 });
543 }
544
545 let _: Comma = input.parse()?;
546
547 Ok(Self {
548 modname,
549 options: input.parse()?,
550 })
551 }
552}
553
554fn get_pyfn_attr(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Option<PyFnArgs>> {
556 let mut pyfn_args: Option<PyFnArgs> = None;
557
558 take_attributes(attrs, |attr| {
559 if attr.path().is_ident("pyfn") {
560 ensure_spanned!(
561 pyfn_args.is_none(),
562 attr.span() => "`#[pyfn] may only be specified once"
563 );
564 pyfn_args = Some(attr.parse_args()?);
565 Ok(true)
566 } else {
567 Ok(false)
568 }
569 })?;
570
571 if let Some(pyfn_args) = &mut pyfn_args {
572 pyfn_args
573 .options
574 .add_attributes(take_pyo3_options(attrs)?)?;
575 }
576
577 Ok(pyfn_args)
578}
579
580fn get_cfg_attributes(attrs: &[syn::Attribute]) -> Vec<syn::Attribute> {
581 attrs
582 .iter()
583 .filter(|attr| attr.path().is_ident("cfg"))
584 .cloned()
585 .collect()
586}
587
588fn find_and_remove_attribute(attrs: &mut Vec<syn::Attribute>, ident: &str) -> bool {
589 let mut found = false;
590 attrs.retain(|attr| {
591 if attr.path().is_ident(ident) {
592 found = true;
593 false
594 } else {
595 true
596 }
597 });
598 found
599}
600
601impl PartialEq<syn::Ident> for IdentOrStr<'_> {
602 fn eq(&self, other: &syn::Ident) -> bool {
603 match self {
604 IdentOrStr::Str(s) => other == s,
605 IdentOrStr::Ident(i) => other == i,
606 }
607 }
608}
609
610fn set_module_attribute(attrs: &mut Vec<syn::Attribute>, module_name: &str) {
611 attrs.push(parse_quote!(#[pyo3(module = #module_name)]));
612}
613
614fn has_pyo3_module_declared<T: Parse>(
615 attrs: &[syn::Attribute],
616 root_attribute_name: &str,
617 is_module_option: impl Fn(&T) -> bool + Copy,
618) -> Result<bool> {
619 for attr in attrs {
620 if (attr.path().is_ident("pyo3") || attr.path().is_ident(root_attribute_name))
621 && matches!(attr.meta, Meta::List(_))
622 {
623 for option in &attr.parse_args_with(Punctuated::<T, Comma>::parse_terminated)? {
624 if is_module_option(option) {
625 return Ok(true);
626 }
627 }
628 }
629 }
630 Ok(false)
631}
632
633enum PyModulePyO3Option {
634 Submodule(SubmoduleAttribute),
635 Crate(CrateAttribute),
636 Name(NameAttribute),
637 Module(ModuleAttribute),
638 GILUsed(GILUsedAttribute),
639}
640
641impl Parse for PyModulePyO3Option {
642 fn parse(input: ParseStream<'_>) -> Result<Self> {
643 let lookahead = input.lookahead1();
644 if lookahead.peek(attributes::kw::name) {
645 input.parse().map(PyModulePyO3Option::Name)
646 } else if lookahead.peek(syn::Token![crate]) {
647 input.parse().map(PyModulePyO3Option::Crate)
648 } else if lookahead.peek(attributes::kw::module) {
649 input.parse().map(PyModulePyO3Option::Module)
650 } else if lookahead.peek(attributes::kw::submodule) {
651 input.parse().map(PyModulePyO3Option::Submodule)
652 } else if lookahead.peek(attributes::kw::gil_used) {
653 input.parse().map(PyModulePyO3Option::GILUsed)
654 } else {
655 Err(lookahead.error())
656 }
657 }
658}