1use crate::attributes::{
2 self, get_pyo3_options, CrateAttribute, DefaultAttribute, FromPyWithAttribute,
3 RenameAllAttribute, RenamingRule,
4};
5use crate::utils::{self, Ctx};
6use proc_macro2::TokenStream;
7use quote::{format_ident, quote, quote_spanned, ToTokens};
8use syn::{
9 ext::IdentExt,
10 parenthesized,
11 parse::{Parse, ParseStream},
12 parse_quote,
13 punctuated::Punctuated,
14 spanned::Spanned,
15 Attribute, DataEnum, DeriveInput, Fields, Ident, LitStr, Result, Token,
16};
17
18struct Enum<'a> {
20 enum_ident: &'a Ident,
21 variants: Vec<Container<'a>>,
22}
23
24impl<'a> Enum<'a> {
25 fn new(data_enum: &'a DataEnum, ident: &'a Ident, options: ContainerOptions) -> Result<Self> {
30 ensure_spanned!(
31 !data_enum.variants.is_empty(),
32 ident.span() => "cannot derive FromPyObject for empty enum"
33 );
34 let variants = data_enum
35 .variants
36 .iter()
37 .map(|variant| {
38 let mut variant_options = ContainerOptions::from_attrs(&variant.attrs)?;
39 if let Some(rename_all) = &options.rename_all {
40 ensure_spanned!(
41 variant_options.rename_all.is_none(),
42 variant_options.rename_all.span() => "Useless variant `rename_all` - enum is already annotated with `rename_all"
43 );
44 variant_options.rename_all = Some(rename_all.clone());
45
46 }
47 let var_ident = &variant.ident;
48 Container::new(
49 &variant.fields,
50 parse_quote!(#ident::#var_ident),
51 variant_options,
52 )
53 })
54 .collect::<Result<Vec<_>>>()?;
55
56 Ok(Enum {
57 enum_ident: ident,
58 variants,
59 })
60 }
61
62 fn build(&self, ctx: &Ctx) -> TokenStream {
64 let Ctx { pyo3_path, .. } = ctx;
65 let mut var_extracts = Vec::new();
66 let mut variant_names = Vec::new();
67 let mut error_names = Vec::new();
68
69 for var in &self.variants {
70 let struct_derive = var.build(ctx);
71 let ext = quote!({
72 let maybe_ret = || -> #pyo3_path::PyResult<Self> {
73 #struct_derive
74 }();
75
76 match maybe_ret {
77 ok @ ::std::result::Result::Ok(_) => return ok,
78 ::std::result::Result::Err(err) => err
79 }
80 });
81
82 var_extracts.push(ext);
83 variant_names.push(var.path.segments.last().unwrap().ident.to_string());
84 error_names.push(&var.err_name);
85 }
86 let ty_name = self.enum_ident.to_string();
87 quote!(
88 let errors = [
89 #(#var_extracts),*
90 ];
91 ::std::result::Result::Err(
92 #pyo3_path::impl_::frompyobject::failed_to_extract_enum(
93 obj.py(),
94 #ty_name,
95 &[#(#variant_names),*],
96 &[#(#error_names),*],
97 &errors
98 )
99 )
100 )
101 }
102}
103
104struct NamedStructField<'a> {
105 ident: &'a syn::Ident,
106 getter: Option<FieldGetter>,
107 from_py_with: Option<FromPyWithAttribute>,
108 default: Option<DefaultAttribute>,
109}
110
111struct TupleStructField {
112 from_py_with: Option<FromPyWithAttribute>,
113}
114
115enum ContainerType<'a> {
119 Struct(Vec<NamedStructField<'a>>),
123 StructNewtype(&'a syn::Ident, Option<FromPyWithAttribute>),
127 Tuple(Vec<TupleStructField>),
132 TupleNewtype(Option<FromPyWithAttribute>),
136}
137
138struct Container<'a> {
142 path: syn::Path,
143 ty: ContainerType<'a>,
144 err_name: String,
145 rename_rule: Option<RenamingRule>,
146}
147
148impl<'a> Container<'a> {
149 fn new(fields: &'a Fields, path: syn::Path, options: ContainerOptions) -> Result<Self> {
153 let style = match fields {
154 Fields::Unnamed(unnamed) if !unnamed.unnamed.is_empty() => {
155 ensure_spanned!(
156 options.rename_all.is_none(),
157 options.rename_all.span() => "`rename_all` is useless on tuple structs and variants."
158 );
159 let mut tuple_fields = unnamed
160 .unnamed
161 .iter()
162 .map(|field| {
163 let attrs = FieldPyO3Attributes::from_attrs(&field.attrs)?;
164 ensure_spanned!(
165 attrs.getter.is_none(),
166 field.span() => "`getter` is not permitted on tuple struct elements."
167 );
168 ensure_spanned!(
169 attrs.default.is_none(),
170 field.span() => "`default` is not permitted on tuple struct elements."
171 );
172 Ok(TupleStructField {
173 from_py_with: attrs.from_py_with,
174 })
175 })
176 .collect::<Result<Vec<_>>>()?;
177
178 if tuple_fields.len() == 1 {
179 let field = tuple_fields.pop().unwrap();
182 ContainerType::TupleNewtype(field.from_py_with)
183 } else if options.transparent {
184 bail_spanned!(
185 fields.span() => "transparent structs and variants can only have 1 field"
186 );
187 } else {
188 ContainerType::Tuple(tuple_fields)
189 }
190 }
191 Fields::Named(named) if !named.named.is_empty() => {
192 let mut struct_fields = named
193 .named
194 .iter()
195 .map(|field| {
196 let ident = field
197 .ident
198 .as_ref()
199 .expect("Named fields should have identifiers");
200 let mut attrs = FieldPyO3Attributes::from_attrs(&field.attrs)?;
201
202 if let Some(ref from_item_all) = options.from_item_all {
203 if let Some(replaced) = attrs.getter.replace(FieldGetter::GetItem(None))
204 {
205 match replaced {
206 FieldGetter::GetItem(Some(item_name)) => {
207 attrs.getter = Some(FieldGetter::GetItem(Some(item_name)));
208 }
209 FieldGetter::GetItem(None) => bail_spanned!(from_item_all.span() => "Useless `item` - the struct is already annotated with `from_item_all`"),
210 FieldGetter::GetAttr(_) => bail_spanned!(
211 from_item_all.span() => "The struct is already annotated with `from_item_all`, `attribute` is not allowed"
212 ),
213 }
214 }
215 }
216
217 Ok(NamedStructField {
218 ident,
219 getter: attrs.getter,
220 from_py_with: attrs.from_py_with,
221 default: attrs.default,
222 })
223 })
224 .collect::<Result<Vec<_>>>()?;
225 if struct_fields.iter().all(|field| field.default.is_some()) {
226 bail_spanned!(
227 fields.span() => "cannot derive FromPyObject for structs and variants with only default values"
228 )
229 } else if options.transparent {
230 ensure_spanned!(
231 struct_fields.len() == 1,
232 fields.span() => "transparent structs and variants can only have 1 field"
233 );
234 ensure_spanned!(
235 options.rename_all.is_none(),
236 options.rename_all.span() => "`rename_all` is not permitted on `transparent` structs and variants"
237 );
238 let field = struct_fields.pop().unwrap();
239 ensure_spanned!(
240 field.getter.is_none(),
241 field.ident.span() => "`transparent` structs may not have a `getter` for the inner field"
242 );
243 ContainerType::StructNewtype(field.ident, field.from_py_with)
244 } else {
245 ContainerType::Struct(struct_fields)
246 }
247 }
248 _ => bail_spanned!(
249 fields.span() => "cannot derive FromPyObject for empty structs and variants"
250 ),
251 };
252 let err_name = options.annotation.map_or_else(
253 || path.segments.last().unwrap().ident.to_string(),
254 |lit_str| lit_str.value(),
255 );
256
257 let v = Container {
258 path,
259 ty: style,
260 err_name,
261 rename_rule: options.rename_all.map(|v| v.value.rule),
262 };
263 Ok(v)
264 }
265
266 fn name(&self) -> String {
267 let mut value = String::new();
268 for segment in &self.path.segments {
269 if !value.is_empty() {
270 value.push_str("::");
271 }
272 value.push_str(&segment.ident.to_string());
273 }
274 value
275 }
276
277 fn build(&self, ctx: &Ctx) -> TokenStream {
279 match &self.ty {
280 ContainerType::StructNewtype(ident, from_py_with) => {
281 self.build_newtype_struct(Some(ident), from_py_with, ctx)
282 }
283 ContainerType::TupleNewtype(from_py_with) => {
284 self.build_newtype_struct(None, from_py_with, ctx)
285 }
286 ContainerType::Tuple(tups) => self.build_tuple_struct(tups, ctx),
287 ContainerType::Struct(tups) => self.build_struct(tups, ctx),
288 }
289 }
290
291 fn build_newtype_struct(
292 &self,
293 field_ident: Option<&Ident>,
294 from_py_with: &Option<FromPyWithAttribute>,
295 ctx: &Ctx,
296 ) -> TokenStream {
297 let Ctx { pyo3_path, .. } = ctx;
298 let self_ty = &self.path;
299 let struct_name = self.name();
300 if let Some(ident) = field_ident {
301 let field_name = ident.to_string();
302 if let Some(FromPyWithAttribute {
303 kw,
304 value: expr_path,
305 }) = from_py_with
306 {
307 let extractor = quote_spanned! { kw.span =>
308 { let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
309 };
310 quote! {
311 Ok(#self_ty {
312 #ident: #pyo3_path::impl_::frompyobject::extract_struct_field_with(#extractor, obj, #struct_name, #field_name)?
313 })
314 }
315 } else {
316 quote! {
317 Ok(#self_ty {
318 #ident: #pyo3_path::impl_::frompyobject::extract_struct_field(obj, #struct_name, #field_name)?
319 })
320 }
321 }
322 } else if let Some(FromPyWithAttribute {
323 kw,
324 value: expr_path,
325 }) = from_py_with
326 {
327 let extractor = quote_spanned! { kw.span =>
328 { let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
329 };
330 quote! {
331 #pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#extractor, obj, #struct_name, 0).map(#self_ty)
332 }
333 } else {
334 quote! {
335 #pyo3_path::impl_::frompyobject::extract_tuple_struct_field(obj, #struct_name, 0).map(#self_ty)
336 }
337 }
338 }
339
340 fn build_tuple_struct(&self, struct_fields: &[TupleStructField], ctx: &Ctx) -> TokenStream {
341 let Ctx { pyo3_path, .. } = ctx;
342 let self_ty = &self.path;
343 let struct_name = &self.name();
344 let field_idents: Vec<_> = (0..struct_fields.len())
345 .map(|i| format_ident!("arg{}", i))
346 .collect();
347 let fields = struct_fields.iter().zip(&field_idents).enumerate().map(|(index, (field, ident))| {
348 if let Some(FromPyWithAttribute {
349 kw,
350 value: expr_path, ..
351 }) = &field.from_py_with {
352 let extractor = quote_spanned! { kw.span =>
353 { let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
354 };
355 quote! {
356 #pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#extractor, &#ident, #struct_name, #index)?
357 }
358 } else {
359 quote!{
360 #pyo3_path::impl_::frompyobject::extract_tuple_struct_field(&#ident, #struct_name, #index)?
361 }}
362 });
363
364 quote!(
365 match #pyo3_path::types::PyAnyMethods::extract(obj) {
366 ::std::result::Result::Ok((#(#field_idents),*)) => ::std::result::Result::Ok(#self_ty(#(#fields),*)),
367 ::std::result::Result::Err(err) => ::std::result::Result::Err(err),
368 }
369 )
370 }
371
372 fn build_struct(&self, struct_fields: &[NamedStructField<'_>], ctx: &Ctx) -> TokenStream {
373 let Ctx { pyo3_path, .. } = ctx;
374 let self_ty = &self.path;
375 let struct_name = self.name();
376 let mut fields: Punctuated<TokenStream, Token![,]> = Punctuated::new();
377 for field in struct_fields {
378 let ident = field.ident;
379 let field_name = ident.unraw().to_string();
380 let getter = match field.getter.as_ref().unwrap_or(&FieldGetter::GetAttr(None)) {
381 FieldGetter::GetAttr(Some(name)) => {
382 quote!(#pyo3_path::types::PyAnyMethods::getattr(obj, #pyo3_path::intern!(obj.py(), #name)))
383 }
384 FieldGetter::GetAttr(None) => {
385 let name = self
386 .rename_rule
387 .map(|rule| utils::apply_renaming_rule(rule, &field_name));
388 let name = name.as_deref().unwrap_or(&field_name);
389 quote!(#pyo3_path::types::PyAnyMethods::getattr(obj, #pyo3_path::intern!(obj.py(), #name)))
390 }
391 FieldGetter::GetItem(Some(syn::Lit::Str(key))) => {
392 quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #pyo3_path::intern!(obj.py(), #key)))
393 }
394 FieldGetter::GetItem(Some(key)) => {
395 quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #key))
396 }
397 FieldGetter::GetItem(None) => {
398 let name = self
399 .rename_rule
400 .map(|rule| utils::apply_renaming_rule(rule, &field_name));
401 let name = name.as_deref().unwrap_or(&field_name);
402 quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #pyo3_path::intern!(obj.py(), #name)))
403 }
404 };
405 let extractor = if let Some(FromPyWithAttribute {
406 kw,
407 value: expr_path,
408 }) = &field.from_py_with
409 {
410 let extractor = quote_spanned! { kw.span =>
411 { let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
412 };
413 quote! (#pyo3_path::impl_::frompyobject::extract_struct_field_with(#extractor, &#getter?, #struct_name, #field_name)?)
414 } else {
415 quote!(#pyo3_path::impl_::frompyobject::extract_struct_field(&value, #struct_name, #field_name)?)
416 };
417 let extracted = if let Some(default) = &field.default {
418 let default_expr = if let Some(default_expr) = &default.value {
419 default_expr.to_token_stream()
420 } else {
421 quote!(::std::default::Default::default())
422 };
423 quote!(if let ::std::result::Result::Ok(value) = #getter {
424 #extractor
425 } else {
426 #default_expr
427 })
428 } else {
429 quote!({
430 let value = #getter?;
431 #extractor
432 })
433 };
434
435 fields.push(quote!(#ident: #extracted));
436 }
437
438 quote!(::std::result::Result::Ok(#self_ty{#fields}))
439 }
440}
441
442#[derive(Default)]
443struct ContainerOptions {
444 transparent: bool,
446 from_item_all: Option<attributes::kw::from_item_all>,
448 annotation: Option<syn::LitStr>,
450 krate: Option<CrateAttribute>,
452 rename_all: Option<RenameAllAttribute>,
454}
455
456enum ContainerPyO3Attribute {
458 Transparent(attributes::kw::transparent),
460 ItemAll(attributes::kw::from_item_all),
462 ErrorAnnotation(LitStr),
464 Crate(CrateAttribute),
466 RenameAll(RenameAllAttribute),
468}
469
470impl Parse for ContainerPyO3Attribute {
471 fn parse(input: ParseStream<'_>) -> Result<Self> {
472 let lookahead = input.lookahead1();
473 if lookahead.peek(attributes::kw::transparent) {
474 let kw: attributes::kw::transparent = input.parse()?;
475 Ok(ContainerPyO3Attribute::Transparent(kw))
476 } else if lookahead.peek(attributes::kw::from_item_all) {
477 let kw: attributes::kw::from_item_all = input.parse()?;
478 Ok(ContainerPyO3Attribute::ItemAll(kw))
479 } else if lookahead.peek(attributes::kw::annotation) {
480 let _: attributes::kw::annotation = input.parse()?;
481 let _: Token![=] = input.parse()?;
482 input.parse().map(ContainerPyO3Attribute::ErrorAnnotation)
483 } else if lookahead.peek(Token![crate]) {
484 input.parse().map(ContainerPyO3Attribute::Crate)
485 } else if lookahead.peek(attributes::kw::rename_all) {
486 input.parse().map(ContainerPyO3Attribute::RenameAll)
487 } else {
488 Err(lookahead.error())
489 }
490 }
491}
492
493impl ContainerOptions {
494 fn from_attrs(attrs: &[Attribute]) -> Result<Self> {
495 let mut options = ContainerOptions::default();
496
497 for attr in attrs {
498 if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
499 for pyo3_attr in pyo3_attrs {
500 match pyo3_attr {
501 ContainerPyO3Attribute::Transparent(kw) => {
502 ensure_spanned!(
503 !options.transparent,
504 kw.span() => "`transparent` may only be provided once"
505 );
506 options.transparent = true;
507 }
508 ContainerPyO3Attribute::ItemAll(kw) => {
509 ensure_spanned!(
510 options.from_item_all.is_none(),
511 kw.span() => "`from_item_all` may only be provided once"
512 );
513 options.from_item_all = Some(kw);
514 }
515 ContainerPyO3Attribute::ErrorAnnotation(lit_str) => {
516 ensure_spanned!(
517 options.annotation.is_none(),
518 lit_str.span() => "`annotation` may only be provided once"
519 );
520 options.annotation = Some(lit_str);
521 }
522 ContainerPyO3Attribute::Crate(path) => {
523 ensure_spanned!(
524 options.krate.is_none(),
525 path.span() => "`crate` may only be provided once"
526 );
527 options.krate = Some(path);
528 }
529 ContainerPyO3Attribute::RenameAll(rename_all) => {
530 ensure_spanned!(
531 options.rename_all.is_none(),
532 rename_all.span() => "`rename_all` may only be provided once"
533 );
534 options.rename_all = Some(rename_all);
535 }
536 }
537 }
538 }
539 }
540 Ok(options)
541 }
542}
543
544#[derive(Clone, Debug)]
546struct FieldPyO3Attributes {
547 getter: Option<FieldGetter>,
548 from_py_with: Option<FromPyWithAttribute>,
549 default: Option<DefaultAttribute>,
550}
551
552#[derive(Clone, Debug)]
553enum FieldGetter {
554 GetItem(Option<syn::Lit>),
555 GetAttr(Option<LitStr>),
556}
557
558enum FieldPyO3Attribute {
559 Getter(FieldGetter),
560 FromPyWith(FromPyWithAttribute),
561 Default(DefaultAttribute),
562}
563
564impl Parse for FieldPyO3Attribute {
565 fn parse(input: ParseStream<'_>) -> Result<Self> {
566 let lookahead = input.lookahead1();
567 if lookahead.peek(attributes::kw::attribute) {
568 let _: attributes::kw::attribute = input.parse()?;
569 if input.peek(syn::token::Paren) {
570 let content;
571 let _ = parenthesized!(content in input);
572 let attr_name: LitStr = content.parse()?;
573 if !content.is_empty() {
574 return Err(content.error(
575 "expected at most one argument: `attribute` or `attribute(\"name\")`",
576 ));
577 }
578 ensure_spanned!(
579 !attr_name.value().is_empty(),
580 attr_name.span() => "attribute name cannot be empty"
581 );
582 Ok(FieldPyO3Attribute::Getter(FieldGetter::GetAttr(Some(
583 attr_name,
584 ))))
585 } else {
586 Ok(FieldPyO3Attribute::Getter(FieldGetter::GetAttr(None)))
587 }
588 } else if lookahead.peek(attributes::kw::item) {
589 let _: attributes::kw::item = input.parse()?;
590 if input.peek(syn::token::Paren) {
591 let content;
592 let _ = parenthesized!(content in input);
593 let key = content.parse()?;
594 if !content.is_empty() {
595 return Err(
596 content.error("expected at most one argument: `item` or `item(key)`")
597 );
598 }
599 Ok(FieldPyO3Attribute::Getter(FieldGetter::GetItem(Some(key))))
600 } else {
601 Ok(FieldPyO3Attribute::Getter(FieldGetter::GetItem(None)))
602 }
603 } else if lookahead.peek(attributes::kw::from_py_with) {
604 input.parse().map(FieldPyO3Attribute::FromPyWith)
605 } else if lookahead.peek(Token![default]) {
606 input.parse().map(FieldPyO3Attribute::Default)
607 } else {
608 Err(lookahead.error())
609 }
610 }
611}
612
613impl FieldPyO3Attributes {
614 fn from_attrs(attrs: &[Attribute]) -> Result<Self> {
616 let mut getter = None;
617 let mut from_py_with = None;
618 let mut default = None;
619
620 for attr in attrs {
621 if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
622 for pyo3_attr in pyo3_attrs {
623 match pyo3_attr {
624 FieldPyO3Attribute::Getter(field_getter) => {
625 ensure_spanned!(
626 getter.is_none(),
627 attr.span() => "only one of `attribute` or `item` can be provided"
628 );
629 getter = Some(field_getter);
630 }
631 FieldPyO3Attribute::FromPyWith(from_py_with_attr) => {
632 ensure_spanned!(
633 from_py_with.is_none(),
634 attr.span() => "`from_py_with` may only be provided once"
635 );
636 from_py_with = Some(from_py_with_attr);
637 }
638 FieldPyO3Attribute::Default(default_attr) => {
639 ensure_spanned!(
640 default.is_none(),
641 attr.span() => "`default` may only be provided once"
642 );
643 default = Some(default_attr);
644 }
645 }
646 }
647 }
648 }
649
650 Ok(FieldPyO3Attributes {
651 getter,
652 from_py_with,
653 default,
654 })
655 }
656}
657
658fn verify_and_get_lifetime(generics: &syn::Generics) -> Result<Option<&syn::LifetimeParam>> {
659 let mut lifetimes = generics.lifetimes();
660 let lifetime = lifetimes.next();
661 ensure_spanned!(
662 lifetimes.next().is_none(),
663 generics.span() => "FromPyObject can be derived with at most one lifetime parameter"
664 );
665 Ok(lifetime)
666}
667
668pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result<TokenStream> {
677 let options = ContainerOptions::from_attrs(&tokens.attrs)?;
678 let ctx = &Ctx::new(&options.krate, None);
679 let Ctx { pyo3_path, .. } = &ctx;
680
681 let (_, ty_generics, _) = tokens.generics.split_for_impl();
682 let mut trait_generics = tokens.generics.clone();
683 let lt_param = if let Some(lt) = verify_and_get_lifetime(&trait_generics)? {
684 lt.clone()
685 } else {
686 trait_generics.params.push(parse_quote!('py));
687 parse_quote!('py)
688 };
689 let (impl_generics, _, where_clause) = trait_generics.split_for_impl();
690
691 let mut where_clause = where_clause.cloned().unwrap_or_else(|| parse_quote!(where));
692 for param in trait_generics.type_params() {
693 let gen_ident = ¶m.ident;
694 where_clause
695 .predicates
696 .push(parse_quote!(#gen_ident: #pyo3_path::FromPyObject<'py>))
697 }
698
699 let derives = match &tokens.data {
700 syn::Data::Enum(en) => {
701 if options.transparent || options.annotation.is_some() {
702 bail_spanned!(tokens.span() => "`transparent` or `annotation` is not supported \
703 at top level for enums");
704 }
705 let en = Enum::new(en, &tokens.ident, options)?;
706 en.build(ctx)
707 }
708 syn::Data::Struct(st) => {
709 if let Some(lit_str) = &options.annotation {
710 bail_spanned!(lit_str.span() => "`annotation` is unsupported for structs");
711 }
712 let ident = &tokens.ident;
713 let st = Container::new(&st.fields, parse_quote!(#ident), options)?;
714 st.build(ctx)
715 }
716 syn::Data::Union(_) => bail_spanned!(
717 tokens.span() => "#[derive(FromPyObject)] is not supported for unions"
718 ),
719 };
720
721 let ident = &tokens.ident;
722 Ok(quote!(
723 #[automatically_derived]
724 impl #impl_generics #pyo3_path::FromPyObject<#lt_param> for #ident #ty_generics #where_clause {
725 fn extract_bound(obj: &#pyo3_path::Bound<#lt_param, #pyo3_path::PyAny>) -> #pyo3_path::PyResult<Self> {
726 #derives
727 }
728 }
729 ))
730}