1use crate::attributes::KeywordAttribute;
2#[cfg(feature = "experimental-inspect")]
3use crate::introspection::{function_introspection_code, introspection_id_const};
4use crate::utils::{Ctx, LitCStr};
5use crate::{
6 attributes::{
7 self, get_pyo3_options, take_attributes, take_pyo3_options, CrateAttribute,
8 FromPyWithAttribute, NameAttribute, TextSignatureAttribute,
9 },
10 method::{self, CallingConvention, FnArg},
11 pymethod::check_generic,
12};
13use proc_macro2::{Span, TokenStream};
14use quote::{format_ident, quote, ToTokens};
15use std::cmp::PartialEq;
16use std::ffi::CString;
17use syn::parse::{Parse, ParseStream};
18use syn::punctuated::Punctuated;
19use syn::{ext::IdentExt, spanned::Spanned, LitStr, Path, Result, Token};
20
21mod signature;
22
23pub use self::signature::{ConstructorAttribute, FunctionSignature, SignatureAttribute};
24
/// Per-argument options collected from `#[pyo3(...)]` attributes on a `#[pyfunction]`
/// argument (see `PyFunctionArgPyO3Attributes::from_attrs`).
///
/// At most one of `from_py_with` and `cancel_handle` is set; `from_attrs` rejects both.
#[derive(Clone, Debug)]
pub struct PyFunctionArgPyO3Attributes {
    /// The `from_py_with` option, if it was given for this argument.
    pub from_py_with: Option<FromPyWithAttribute>,
    /// The `cancel_handle` keyword, if it was given for this argument.
    pub cancel_handle: Option<attributes::kw::cancel_handle>,
}
30
/// A single option accepted inside `#[pyo3(...)]` on a function argument.
enum PyFunctionArgPyO3Attribute {
    /// The `from_py_with` option.
    FromPyWith(FromPyWithAttribute),
    /// The bare `cancel_handle` keyword.
    CancelHandle(attributes::kw::cancel_handle),
}
35
36impl Parse for PyFunctionArgPyO3Attribute {
37 fn parse(input: ParseStream<'_>) -> Result<Self> {
38 let lookahead = input.lookahead1();
39 if lookahead.peek(attributes::kw::cancel_handle) {
40 input.parse().map(PyFunctionArgPyO3Attribute::CancelHandle)
41 } else if lookahead.peek(attributes::kw::from_py_with) {
42 input.parse().map(PyFunctionArgPyO3Attribute::FromPyWith)
43 } else {
44 Err(lookahead.error())
45 }
46 }
47}
48
49impl PyFunctionArgPyO3Attributes {
50 pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
52 let mut attributes = PyFunctionArgPyO3Attributes {
53 from_py_with: None,
54 cancel_handle: None,
55 };
56 take_attributes(attrs, |attr| {
57 if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
58 for attr in pyo3_attrs {
59 match attr {
60 PyFunctionArgPyO3Attribute::FromPyWith(from_py_with) => {
61 ensure_spanned!(
62 attributes.from_py_with.is_none(),
63 from_py_with.span() => "`from_py_with` may only be specified once per argument"
64 );
65 attributes.from_py_with = Some(from_py_with);
66 }
67 PyFunctionArgPyO3Attribute::CancelHandle(cancel_handle) => {
68 ensure_spanned!(
69 attributes.cancel_handle.is_none(),
70 cancel_handle.span() => "`cancel_handle` may only be specified once per argument"
71 );
72 attributes.cancel_handle = Some(cancel_handle);
73 }
74 }
75 ensure_spanned!(
76 attributes.from_py_with.is_none() || attributes.cancel_handle.is_none(),
77 attributes.cancel_handle.unwrap().span() => "`from_py_with` and `cancel_handle` cannot be specified together"
78 );
79 }
80 Ok(true)
81 } else {
82 Ok(false)
83 }
84 })?;
85 Ok(attributes)
86 }
87}
88
/// The `message` option inside `warn(...)`, holding a string literal.
type PyFunctionWarningMessageAttribute = KeywordAttribute<attributes::kw::message, LitStr>;
/// The `category` option inside `warn(...)`, holding a path to the warning type.
type PyFunctionWarningCategoryAttribute = KeywordAttribute<attributes::kw::category, Path>;

/// A parsed `warn(...)` function option, before conversion into [`PyFunctionWarning`].
pub struct PyFunctionWarningAttribute {
    /// The warning message; required by the parser.
    pub message: PyFunctionWarningMessageAttribute,
    /// The warning category; `None` becomes `UserWarning` on conversion to `PyFunctionWarning`.
    pub category: Option<PyFunctionWarningCategoryAttribute>,
    /// Span of the `warn` keyword.
    pub span: Span,
}
97
/// The Python warning class used when emitting a generated warning.
#[derive(PartialEq)]
pub enum PyFunctionWarningCategory {
    /// A user-supplied path to a warning type (from `category = ...`).
    Path(Path),
    /// `exceptions::PyUserWarning`; used when no `category` is given.
    UserWarning,
    /// `exceptions::PyDeprecationWarning`.
    DeprecationWarning,
}
104
/// A warning the generated wrapper emits through `PyErr::warn` (see `build_py_warning`).
pub struct PyFunctionWarning {
    /// The warning text; converted to a C string when code is generated.
    pub message: LitStr,
    /// Which Python warning class to use.
    pub category: PyFunctionWarningCategory,
    /// Span used for diagnostics about this warning.
    pub span: Span,
}
110
111impl From<PyFunctionWarningAttribute> for PyFunctionWarning {
112 fn from(value: PyFunctionWarningAttribute) -> Self {
113 Self {
114 message: value.message.value,
115 category: value
116 .category
117 .map_or(PyFunctionWarningCategory::UserWarning, |cat| {
118 PyFunctionWarningCategory::Path(cat.value)
119 }),
120 span: value.span,
121 }
122 }
123}
124
/// A source of generated code that emits one or more Python warnings.
pub trait WarningFactory {
    /// Returns statements that emit the warning(s).
    ///
    /// The implementations in this file generate code that refers to a `py` token and uses
    /// `?`, so the output must be placed where both are valid.
    fn build_py_warning(&self, ctx: &Ctx) -> TokenStream;
    /// Returns the span to attach diagnostics about the warning(s) to.
    fn span(&self) -> Span;
}
129
130impl WarningFactory for PyFunctionWarning {
131 fn build_py_warning(&self, ctx: &Ctx) -> TokenStream {
132 let message = &self.message.value();
133 let c_message = LitCStr::new(
134 CString::new(message.clone()).unwrap(),
135 Spanned::span(&message),
136 ctx,
137 );
138 let pyo3_path = &ctx.pyo3_path;
139 let category = match &self.category {
140 PyFunctionWarningCategory::Path(path) => quote! {#path},
141 PyFunctionWarningCategory::UserWarning => {
142 quote! {#pyo3_path::exceptions::PyUserWarning}
143 }
144 PyFunctionWarningCategory::DeprecationWarning => {
145 quote! {#pyo3_path::exceptions::PyDeprecationWarning}
146 }
147 };
148 quote! {
149 #pyo3_path::PyErr::warn(py, &<#category as #pyo3_path::PyTypeInfo>::type_object(py), #c_message, 1)?;
150 }
151 }
152
153 fn span(&self) -> Span {
154 self.span
155 }
156}
157
158impl<T: WarningFactory> WarningFactory for Vec<T> {
159 fn build_py_warning(&self, ctx: &Ctx) -> TokenStream {
160 let warnings = self.iter().map(|warning| warning.build_py_warning(ctx));
161
162 quote! {
163 #(#warnings)*
164 }
165 }
166
167 fn span(&self) -> Span {
168 self.iter()
169 .map(|val| val.span())
170 .reduce(|acc, span| acc.join(span).unwrap_or(acc))
171 .unwrap()
172 }
173}
174
175impl Parse for PyFunctionWarningAttribute {
176 fn parse(input: ParseStream<'_>) -> Result<Self> {
177 let mut message: Option<PyFunctionWarningMessageAttribute> = None;
178 let mut category: Option<PyFunctionWarningCategoryAttribute> = None;
179
180 let span = input.parse::<attributes::kw::warn>()?.span();
181
182 let content;
183 syn::parenthesized!(content in input);
184
185 while !content.is_empty() {
186 let lookahead = content.lookahead1();
187
188 if lookahead.peek(attributes::kw::message) {
189 message = content
190 .parse::<PyFunctionWarningMessageAttribute>()
191 .map(Some)?;
192 } else if lookahead.peek(attributes::kw::category) {
193 category = content
194 .parse::<PyFunctionWarningCategoryAttribute>()
195 .map(Some)?;
196 } else {
197 return Err(lookahead.error());
198 }
199
200 if content.peek(Token![,]) {
201 content.parse::<Token![,]>()?;
202 }
203 }
204
205 Ok(PyFunctionWarningAttribute {
206 message: message.ok_or(syn::Error::new(
207 content.span(),
208 "missing `message` in `warn` attribute",
209 ))?,
210 category,
211 span,
212 })
213 }
214}
215
216impl ToTokens for PyFunctionWarningAttribute {
217 fn to_tokens(&self, tokens: &mut TokenStream) {
218 let message_tokens = self.message.to_token_stream();
219 let category_tokens = self
220 .category
221 .as_ref()
222 .map_or(quote! {}, |cat| cat.to_token_stream());
223
224 let token_stream = quote! {
225 warn(#message_tokens, #category_tokens)
226 };
227
228 tokens.extend(token_stream);
229 }
230}
231
/// Options for a `#[pyfunction]`, parsed directly (see the `Parse` impl) and/or gathered from
/// `#[pyo3(...)]` attributes on the function.
#[derive(Default)]
pub struct PyFunctionOptions {
    /// `pass_module`: the first Rust argument receives the module and is left out of the
    /// Python signature.
    pub pass_module: Option<attributes::kw::pass_module>,
    /// `name`: the Python-visible name, overriding the Rust identifier.
    pub name: Option<NameAttribute>,
    /// `signature`: an explicit Python signature for the arguments.
    pub signature: Option<SignatureAttribute>,
    /// `text_signature`: forwarded unchanged to the generated method spec.
    pub text_signature: Option<TextSignatureAttribute>,
    /// `crate`: used to build the `Ctx` (the path to pyo3 used in generated code).
    pub krate: Option<CrateAttribute>,
    /// One entry per `warn(...)` option, in the order they were given.
    pub warnings: Vec<PyFunctionWarning>,
}
241
242impl Parse for PyFunctionOptions {
243 fn parse(input: ParseStream<'_>) -> Result<Self> {
244 let mut options = PyFunctionOptions::default();
245
246 let attrs = Punctuated::<PyFunctionOption, syn::Token![,]>::parse_terminated(input)?;
247 options.add_attributes(attrs)?;
248
249 Ok(options)
250 }
251}
252
/// One option of a `#[pyfunction]`, as parsed by `PyFunctionOption::parse`.
pub enum PyFunctionOption {
    /// The `name` option.
    Name(NameAttribute),
    /// The bare `pass_module` keyword.
    PassModule(attributes::kw::pass_module),
    /// The `signature` option.
    Signature(SignatureAttribute),
    /// The `text_signature` option.
    TextSignature(TextSignatureAttribute),
    /// The `crate` option (stored in the `krate` field, since `crate` is a Rust keyword).
    Crate(CrateAttribute),
    /// A `warn(...)` option; unlike the others it may be given more than once.
    Warning(PyFunctionWarningAttribute),
}
261
262impl Parse for PyFunctionOption {
263 fn parse(input: ParseStream<'_>) -> Result<Self> {
264 let lookahead = input.lookahead1();
265 if lookahead.peek(attributes::kw::name) {
266 input.parse().map(PyFunctionOption::Name)
267 } else if lookahead.peek(attributes::kw::pass_module) {
268 input.parse().map(PyFunctionOption::PassModule)
269 } else if lookahead.peek(attributes::kw::signature) {
270 input.parse().map(PyFunctionOption::Signature)
271 } else if lookahead.peek(attributes::kw::text_signature) {
272 input.parse().map(PyFunctionOption::TextSignature)
273 } else if lookahead.peek(syn::Token![crate]) {
274 input.parse().map(PyFunctionOption::Crate)
275 } else if lookahead.peek(attributes::kw::warn) {
276 input.parse().map(PyFunctionOption::Warning)
277 } else {
278 Err(lookahead.error())
279 }
280 }
281}
282
283impl PyFunctionOptions {
284 pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
285 let mut options = PyFunctionOptions::default();
286 options.add_attributes(take_pyo3_options(attrs)?)?;
287 Ok(options)
288 }
289
290 pub fn add_attributes(
291 &mut self,
292 attrs: impl IntoIterator<Item = PyFunctionOption>,
293 ) -> Result<()> {
294 macro_rules! set_option {
295 ($key:ident) => {
296 {
297 ensure_spanned!(
298 self.$key.is_none(),
299 $key.span() => concat!("`", stringify!($key), "` may only be specified once")
300 );
301 self.$key = Some($key);
302 }
303 };
304 }
305 for attr in attrs {
306 match attr {
307 PyFunctionOption::Name(name) => set_option!(name),
308 PyFunctionOption::PassModule(pass_module) => set_option!(pass_module),
309 PyFunctionOption::Signature(signature) => set_option!(signature),
310 PyFunctionOption::TextSignature(text_signature) => set_option!(text_signature),
311 PyFunctionOption::Crate(krate) => set_option!(krate),
312 PyFunctionOption::Warning(warning) => {
313 self.warnings.push(warning.into());
314 }
315 }
316 }
317 Ok(())
318 }
319}
320
321pub fn build_py_function(
322 ast: &mut syn::ItemFn,
323 mut options: PyFunctionOptions,
324) -> syn::Result<TokenStream> {
325 options.add_attributes(take_pyo3_options(&mut ast.attrs)?)?;
326 impl_wrap_pyfunction(ast, options)
327}
328
329pub fn impl_wrap_pyfunction(
332 func: &mut syn::ItemFn,
333 options: PyFunctionOptions,
334) -> syn::Result<TokenStream> {
335 check_generic(&func.sig)?;
336 let PyFunctionOptions {
337 pass_module,
338 name,
339 signature,
340 text_signature,
341 krate,
342 warnings,
343 } = options;
344
345 let ctx = &Ctx::new(&krate, Some(&func.sig));
346 let Ctx { pyo3_path, .. } = &ctx;
347
348 let python_name = name
349 .as_ref()
350 .map_or_else(|| &func.sig.ident, |name| &name.value.0)
351 .unraw();
352
353 let tp = if pass_module.is_some() {
354 let span = match func.sig.inputs.first() {
355 Some(syn::FnArg::Typed(first_arg)) => first_arg.ty.span(),
356 Some(syn::FnArg::Receiver(_)) | None => bail_spanned!(
357 func.sig.paren_token.span.join() => "expected `&PyModule` or `Py<PyModule>` as first argument with `pass_module`"
358 ),
359 };
360 method::FnType::FnModule(span)
361 } else {
362 method::FnType::FnStatic
363 };
364
365 let arguments = func
366 .sig
367 .inputs
368 .iter_mut()
369 .skip(if tp.skip_first_rust_argument_in_python_signature() {
370 1
371 } else {
372 0
373 })
374 .map(FnArg::parse)
375 .collect::<syn::Result<Vec<_>>>()?;
376
377 let signature = if let Some(signature) = signature {
378 FunctionSignature::from_arguments_and_attribute(arguments, signature)?
379 } else {
380 FunctionSignature::from_arguments(arguments)
381 };
382
383 let vis = &func.vis;
384 let name = &func.sig.ident;
385
386 #[cfg(feature = "experimental-inspect")]
387 let introspection = function_introspection_code(
388 pyo3_path,
389 Some(name),
390 &name.to_string(),
391 &signature,
392 None,
393 [] as [String; 0],
394 None,
395 );
396 #[cfg(not(feature = "experimental-inspect"))]
397 let introspection = quote! {};
398 #[cfg(feature = "experimental-inspect")]
399 let introspection_id = introspection_id_const();
400 #[cfg(not(feature = "experimental-inspect"))]
401 let introspection_id = quote! {};
402
403 let spec = method::FnSpec {
404 tp,
405 name: &func.sig.ident,
406 convention: CallingConvention::from_signature(&signature),
407 python_name,
408 signature,
409 text_signature,
410 asyncness: func.sig.asyncness,
411 unsafety: func.sig.unsafety,
412 warnings,
413 };
414
415 let wrapper_ident = format_ident!("__pyfunction_{}", spec.name);
416 if spec.asyncness.is_some() {
417 ensure_spanned!(
418 cfg!(feature = "experimental-async"),
419 spec.asyncness.span() => "async functions are only supported with the `experimental-async` feature"
420 );
421 }
422 let wrapper = spec.get_wrapper_function(&wrapper_ident, None, ctx)?;
423 let methoddef = spec.get_methoddef(wrapper_ident, &spec.get_doc(&func.attrs, ctx), ctx);
424
425 let wrapped_pyfunction = quote! {
426 #[doc(hidden)]
429 #vis mod #name {
430 pub(crate) struct MakeDef;
431 pub const _PYO3_DEF: #pyo3_path::impl_::pymethods::PyMethodDef = MakeDef::_PYO3_DEF;
432 #introspection_id
433 }
434
435 #[allow(unknown_lints, non_local_definitions)]
440 impl #name::MakeDef {
441 const _PYO3_DEF: #pyo3_path::impl_::pymethods::PyMethodDef = #methoddef;
442 }
443
444 #[allow(non_snake_case)]
445 #wrapper
446
447 #introspection
448 };
449 Ok(wrapped_pyfunction)
450}