1use crate::model::{
2 Argument, Arguments, Attribute, Class, Constant, Expr, Function, Module, Operator,
3 VariableLengthArgument,
4};
5use std::collections::{BTreeMap, BTreeSet, HashMap};
6use std::fmt::Write;
7use std::iter::once;
8use std::path::PathBuf;
9use std::str::FromStr;
10
11pub fn module_stub_files(module: &Module) -> HashMap<PathBuf, String> {
16 let mut output_files = HashMap::new();
17 add_module_stub_files(module, &[], &mut output_files);
18 output_files
19}
20
21fn add_module_stub_files(
22 module: &Module,
23 module_path: &[&str],
24 output_files: &mut HashMap<PathBuf, String>,
25) {
26 let mut file_path = PathBuf::new();
27 for e in module_path {
28 file_path = file_path.join(e);
29 }
30 output_files.insert(
31 file_path.join("__init__.pyi"),
32 module_stubs(module, module_path),
33 );
34 let mut module_path = module_path.to_vec();
35 module_path.push(&module.name);
36 for submodule in &module.modules {
37 if submodule.modules.is_empty() {
38 output_files.insert(
39 file_path.join(format!("{}.pyi", submodule.name)),
40 module_stubs(submodule, &module_path),
41 );
42 } else {
43 add_module_stub_files(submodule, &module_path, output_files);
44 }
45 }
46}
47
48fn module_stubs(module: &Module, parents: &[&str]) -> String {
50 let imports = Imports::create(module, parents);
51 let mut elements = Vec::new();
52 for attribute in &module.attributes {
53 elements.push(attribute_stubs(attribute, &imports));
54 }
55 for class in &module.classes {
56 elements.push(class_stubs(class, &imports));
57 }
58 for function in &module.functions {
59 elements.push(function_stubs(function, &imports, None));
60 }
61
62 if module.incomplete && !module.functions.iter().any(|f| f.name == "__getattr__") {
65 elements.push(function_stubs(
66 &Function {
67 name: "__getattr__".into(),
68 decorators: Vec::new(),
69 arguments: Arguments {
70 positional_only_arguments: Vec::new(),
71 arguments: vec![Argument {
72 name: "name".to_string(),
73 default_value: None,
74 annotation: Some(Expr::Name { id: "str".into() }),
75 }],
76 vararg: None,
77 keyword_only_arguments: Vec::new(),
78 kwarg: None,
79 },
80 returns: Some(Expr::Attribute {
81 value: Box::new(Expr::Name {
82 id: "_typeshed".into(),
83 }),
84 attr: "Incomplete".into(),
85 }),
86 is_async: false,
87 docstring: None,
88 },
89 &imports,
90 None,
91 ));
92 }
93
94 let mut final_elements = Vec::new();
95 if let Some(docstring) = &module.docstring {
96 final_elements.push(format!("\"\"\"\n{docstring}\n\"\"\""));
97 }
98 final_elements.extend(imports.imports);
99 final_elements.extend(elements);
100
101 let mut output = String::new();
102
103 for element in final_elements {
105 let is_multiline = element.contains('\n');
106 if is_multiline && !output.is_empty() && !output.ends_with("\n\n") {
107 output.push('\n');
108 }
109 output.push_str(&element);
110 output.push('\n');
111 if is_multiline {
112 output.push('\n');
113 }
114 }
115
116 if output.ends_with("\n\n") {
118 output.pop();
119 }
120 output
121}
122
123fn class_stubs(class: &Class, imports: &Imports) -> String {
124 let mut buffer = String::new();
125 for decorator in &class.decorators {
126 buffer.push('@');
127 imports.serialize_expr(decorator, &mut buffer);
128 buffer.push('\n');
129 }
130 buffer.push_str("class ");
131 buffer.push_str(&class.name);
132 if !class.bases.is_empty() {
133 buffer.push('(');
134 for (i, base) in class.bases.iter().enumerate() {
135 if i > 0 {
136 buffer.push_str(", ");
137 }
138 imports.serialize_expr(base, &mut buffer);
139 }
140 buffer.push(')');
141 }
142 buffer.push(':');
143 if class.docstring.is_none()
144 && class.methods.is_empty()
145 && class.attributes.is_empty()
146 && class.inner_classes.is_empty()
147 {
148 buffer.push_str(" ...");
149 }
150 if let Some(docstring) = &class.docstring {
151 buffer.push_str("\n \"\"\"");
152 for line in docstring.lines() {
153 buffer.push_str("\n ");
154 buffer.push_str(line);
155 }
156 buffer.push_str("\n \"\"\"");
157 }
158 for attribute in &class.attributes {
159 buffer.push_str("\n ");
161 buffer.push_str(&attribute_stubs(attribute, imports).replace('\n', "\n "));
162 }
163 for method in &class.methods {
164 buffer.push_str("\n ");
166 buffer
167 .push_str(&function_stubs(method, imports, Some(&class.name)).replace('\n', "\n "));
168 }
169 for inner_class in &class.inner_classes {
170 buffer.push_str("\n ");
172 buffer.push_str(&class_stubs(inner_class, imports).replace('\n', "\n "));
173 }
174 buffer
175}
176
177fn function_stubs(function: &Function, imports: &Imports, class_name: Option<&str>) -> String {
178 let mut parameters = Vec::new();
180 for argument in &function.arguments.positional_only_arguments {
181 parameters.push(argument_stub(argument, imports));
182 }
183 if !function.arguments.positional_only_arguments.is_empty() {
184 parameters.push("/".into());
185 }
186 for argument in &function.arguments.arguments {
187 parameters.push(argument_stub(argument, imports));
188 }
189 if let Some(argument) = &function.arguments.vararg {
190 parameters.push(format!(
191 "*{}",
192 variable_length_argument_stub(argument, imports)
193 ));
194 } else if !function.arguments.keyword_only_arguments.is_empty() {
195 parameters.push("*".into());
196 }
197 for argument in &function.arguments.keyword_only_arguments {
198 parameters.push(argument_stub(argument, imports));
199 }
200 if let Some(argument) = &function.arguments.kwarg {
201 parameters.push(format!(
202 "**{}",
203 variable_length_argument_stub(argument, imports)
204 ));
205 }
206 let mut buffer = String::new();
207 for decorator in &function.decorators {
208 buffer.push('@');
209 let mut decorator_buffer = String::new();
211 imports.serialize_expr(decorator, &mut decorator_buffer);
212 if let Some(class_name) = class_name {
213 if let Some(decorator) = decorator_buffer.strip_prefix(&format!("{class_name}.")) {
214 decorator_buffer = decorator.into();
215 }
216 }
217 buffer.push_str(&decorator_buffer);
218 buffer.push('\n');
219 }
220 if function.is_async {
221 buffer.push_str("async ");
222 }
223
224 buffer.push_str("def ");
225 buffer.push_str(&function.name);
226 buffer.push('(');
227 buffer.push_str(¶meters.join(", "));
228 buffer.push(')');
229 if let Some(returns) = &function.returns {
230 buffer.push_str(" -> ");
231 imports.serialize_expr(returns, &mut buffer);
232 }
233 if let Some(docstring) = &function.docstring {
234 buffer.push_str(":\n \"\"\"");
235 for line in docstring.lines() {
236 buffer.push_str("\n ");
237 buffer.push_str(line);
238 }
239 buffer.push_str("\n \"\"\"");
240 } else {
241 buffer.push_str(": ...");
242 }
243 buffer
244}
245
246fn attribute_stubs(attribute: &Attribute, imports: &Imports) -> String {
247 let mut buffer = attribute.name.clone();
248 if let Some(annotation) = &attribute.annotation {
249 buffer.push_str(": ");
250 imports.serialize_expr(annotation, &mut buffer);
251 }
252 if let Some(value) = &attribute.value {
253 buffer.push_str(" = ");
254 imports.serialize_expr(value, &mut buffer);
255 }
256 if let Some(docstring) = &attribute.docstring {
257 buffer.push_str("\n\"\"\"");
258 for line in docstring.lines() {
259 buffer.push('\n');
260 buffer.push_str(line);
261 }
262 buffer.push_str("\n\"\"\"");
263 }
264 buffer
265}
266
267fn argument_stub(argument: &Argument, imports: &Imports) -> String {
268 let mut buffer = argument.name.clone();
269 if let Some(annotation) = &argument.annotation {
270 buffer.push_str(": ");
271 imports.serialize_expr(annotation, &mut buffer);
272 }
273 if let Some(default_value) = &argument.default_value {
274 buffer.push_str(if argument.annotation.is_some() {
275 " = "
276 } else {
277 "="
278 });
279 imports.serialize_expr(default_value, &mut buffer);
280 }
281 buffer
282}
283
284fn variable_length_argument_stub(argument: &VariableLengthArgument, imports: &Imports) -> String {
285 let mut buffer = argument.name.clone();
286 if let Some(annotation) = &argument.annotation {
287 buffer.push_str(": ");
288 imports.serialize_expr(annotation, &mut buffer);
289 }
290 buffer
291}
292
293#[derive(Default)]
295struct Imports {
296 imports: Vec<String>,
298 renaming: BTreeMap<(String, String), String>,
300}
301
302impl Imports {
303 fn create(module: &Module, module_parents: &[&str]) -> Self {
311 let mut elements_used_in_annotations = ElementsUsedInAnnotations::new();
312 elements_used_in_annotations.walk_module(module);
313
314 let mut imports = Vec::new();
315 let mut renaming = BTreeMap::new();
316 let mut local_name_to_module_and_attribute = BTreeMap::new();
317
318 let current_module_name = module_parents
320 .iter()
321 .copied()
322 .chain(once(module.name.as_str()))
323 .collect::<Vec<_>>()
324 .join(".");
325
326 for name in module
328 .classes
329 .iter()
330 .map(|c| c.name.clone())
331 .chain(module.functions.iter().map(|f| f.name.clone()))
332 .chain(module.attributes.iter().map(|a| a.name.clone()))
333 {
334 local_name_to_module_and_attribute
335 .insert(name.clone(), (current_module_name.clone(), name.clone()));
336 }
337 local_name_to_module_and_attribute.remove(¤t_module_name);
339
340 for (module, attrs) in &elements_used_in_annotations.module_to_name {
342 let mut import_for_module = Vec::new();
343 for attr in attrs {
344 let (root_attr, attr_path) = attr
346 .split_once('.')
347 .map_or((attr.as_str(), None), |(root, path)| (root, Some(path)));
348 let mut local_name = root_attr.to_owned();
349 let mut already_imported = false;
350 while let Some((possible_conflict_module, possible_conflict_attr)) =
351 local_name_to_module_and_attribute.get(&local_name)
352 {
353 if possible_conflict_module == module && *possible_conflict_attr == root_attr {
354 already_imported = true;
356 break;
357 }
358 let number_of_digits_at_the_end = local_name
361 .bytes()
362 .rev()
363 .take_while(|b| b.is_ascii_digit())
364 .count();
365 let (local_name_prefix, local_name_number) =
366 local_name.split_at(local_name.len() - number_of_digits_at_the_end);
367 local_name = format!(
368 "{local_name_prefix}{}",
369 u64::from_str(local_name_number).unwrap_or(1) + 1
370 );
371 }
372 renaming.insert(
373 (module.clone(), attr.clone()),
374 if let Some(attr_path) = attr_path {
375 format!("{local_name}.{attr_path}")
376 } else {
377 local_name.clone()
378 },
379 );
380 if !already_imported {
381 local_name_to_module_and_attribute
382 .insert(local_name.clone(), (module.clone(), root_attr.to_owned()));
383 let is_not_aliased_builtin = module == "builtins" && local_name == root_attr;
384 if !is_not_aliased_builtin {
385 import_for_module.push(if local_name == root_attr {
386 local_name
387 } else {
388 format!("{root_attr} as {local_name}")
389 });
390 }
391 }
392 }
393 if !import_for_module.is_empty() {
394 imports.push(format!(
395 "from {module} import {}",
396 import_for_module.join(", ")
397 ));
398 }
399 }
400 imports.sort(); Self { imports, renaming }
403 }
404
405 fn serialize_expr(&self, expr: &Expr, buffer: &mut String) {
406 match expr {
407 Expr::Constant { value } => match value {
408 Constant::None => buffer.push_str("None"),
409 Constant::Bool(value) => buffer.push_str(if *value { "True" } else { "False" }),
410 Constant::Int(value) => buffer.push_str(value),
411 Constant::Float(value) => {
412 buffer.push_str(value);
413 if !value.contains(['.', 'e', 'E']) {
414 buffer.push('.'); }
416 }
417 Constant::Str(value) => {
418 buffer.push('"');
419 for c in value.chars() {
420 match c {
421 '"' => buffer.push_str("\\\""),
422 '\n' => buffer.push_str("\\n"),
423 '\r' => buffer.push_str("\\r"),
424 '\t' => buffer.push_str("\\t"),
425 '\\' => buffer.push_str("\\\\"),
426 '\0' => buffer.push_str("\\0"),
427 c @ '\x00'..'\x20' => {
428 write!(buffer, "\\x{:02x}", u32::from(c)).unwrap()
429 }
430 c => buffer.push(c),
431 }
432 }
433 buffer.push('"');
434 }
435 Constant::Ellipsis => buffer.push_str("..."),
436 },
437 Expr::Name { id } => {
438 buffer.push_str(
439 self.renaming
440 .get(&("builtins".into(), id.clone()))
441 .expect("All type hint attributes should have been visited"),
442 );
443 }
444 Expr::Attribute { value, attr } => {
445 if let Expr::Name { id, .. } = &**value {
446 buffer.push_str(
447 self.renaming
448 .get(&(id.clone(), attr.clone()))
449 .expect("All type hint attributes should have been visited"),
450 );
451 } else {
452 self.serialize_expr(value, buffer);
453 buffer.push('.');
454 buffer.push_str(attr);
455 }
456 }
457 Expr::BinOp { left, op, right } => {
458 self.serialize_expr(left, buffer);
459 buffer.push(' ');
460 buffer.push(match op {
461 Operator::BitOr => '|',
462 });
463 self.serialize_expr(right, buffer);
464 }
465 Expr::Tuple { elts } => {
466 buffer.push('(');
467 self.serialize_elts(elts, buffer);
468 if elts.len() == 1 {
469 buffer.push(',');
470 }
471 buffer.push(')')
472 }
473 Expr::List { elts } => {
474 buffer.push('[');
475 self.serialize_elts(elts, buffer);
476 buffer.push(']')
477 }
478 Expr::Subscript { value, slice } => {
479 self.serialize_expr(value, buffer);
480 buffer.push('[');
481 if let Expr::Tuple { elts } = &**slice {
482 self.serialize_elts(elts, buffer);
484 } else {
485 self.serialize_expr(slice, buffer);
486 }
487 buffer.push(']');
488 }
489 }
490 }
491
492 fn serialize_elts(&self, elts: &[Expr], buffer: &mut String) {
493 for (i, elt) in elts.iter().enumerate() {
494 if i > 0 {
495 buffer.push_str(", ");
496 }
497 self.serialize_expr(elt, buffer);
498 }
499 }
500}
501
502struct ElementsUsedInAnnotations {
504 module_to_name: BTreeMap<String, BTreeSet<String>>,
506}
507
508impl ElementsUsedInAnnotations {
509 fn new() -> Self {
510 Self {
511 module_to_name: BTreeMap::new(),
512 }
513 }
514
515 fn walk_module(&mut self, module: &Module) {
516 for attr in &module.attributes {
517 self.walk_attribute(attr);
518 }
519 for class in &module.classes {
520 self.walk_class(class);
521 }
522 for function in &module.functions {
523 self.walk_function(function);
524 }
525 if module.incomplete {
526 self.module_to_name
527 .entry("builtins".into())
528 .or_default()
529 .insert("str".into());
530 self.module_to_name
531 .entry("_typeshed".into())
532 .or_default()
533 .insert("Incomplete".into());
534 }
535 }
536
537 fn walk_class(&mut self, class: &Class) {
538 for base in &class.bases {
539 self.walk_expr(base);
540 }
541 for decorator in &class.decorators {
542 self.walk_expr(decorator);
543 }
544 for method in &class.methods {
545 self.walk_function(method);
546 }
547 for attr in &class.attributes {
548 self.walk_attribute(attr);
549 }
550 for class in &class.inner_classes {
551 self.walk_class(class);
552 }
553 }
554
555 fn walk_attribute(&mut self, attribute: &Attribute) {
556 if let Some(type_hint) = &attribute.annotation {
557 self.walk_expr(type_hint);
558 }
559 }
560
561 fn walk_function(&mut self, function: &Function) {
562 for decorator in &function.decorators {
563 self.walk_expr(decorator);
564 }
565 for arg in function
566 .arguments
567 .positional_only_arguments
568 .iter()
569 .chain(&function.arguments.arguments)
570 .chain(&function.arguments.keyword_only_arguments)
571 {
572 if let Some(type_hint) = &arg.annotation {
573 self.walk_expr(type_hint);
574 }
575 }
576 for arg in function
577 .arguments
578 .vararg
579 .as_ref()
580 .iter()
581 .chain(&function.arguments.kwarg.as_ref())
582 {
583 if let Some(type_hint) = &arg.annotation {
584 self.walk_expr(type_hint);
585 }
586 }
587 if let Some(type_hint) = &function.returns {
588 self.walk_expr(type_hint);
589 }
590 }
591
592 fn walk_expr(&mut self, expr: &Expr) {
593 match expr {
594 Expr::Name { id } => {
595 self.module_to_name
596 .entry("builtins".into())
597 .or_default()
598 .insert(id.clone());
599 }
600 Expr::Attribute { value, attr } => {
601 if let Expr::Name { id } = &**value {
602 self.module_to_name
603 .entry(id.into())
604 .or_default()
605 .insert(attr.clone());
606 } else {
607 self.walk_expr(value)
608 }
609 }
610 Expr::BinOp { left, right, .. } => {
611 self.walk_expr(left);
612 self.walk_expr(right);
613 }
614 Expr::Subscript { value, slice } => {
615 self.walk_expr(value);
616 self.walk_expr(slice);
617 }
618 Expr::Tuple { elts } | Expr::List { elts } => {
619 for elt in elts {
620 self.walk_expr(elt)
621 }
622 }
623 Expr::Constant { .. } => (),
624 }
625 }
626}
627
628#[cfg(test)]
629mod tests {
630 use super::*;
631 use crate::model::Arguments;
632
633 #[test]
634 fn function_stubs_with_variable_length() {
635 let function = Function {
636 name: "func".into(),
637 decorators: Vec::new(),
638 arguments: Arguments {
639 positional_only_arguments: vec![Argument {
640 name: "posonly".into(),
641 default_value: None,
642 annotation: None,
643 }],
644 arguments: vec![Argument {
645 name: "arg".into(),
646 default_value: None,
647 annotation: None,
648 }],
649 vararg: Some(VariableLengthArgument {
650 name: "varargs".into(),
651 annotation: None,
652 }),
653 keyword_only_arguments: vec![Argument {
654 name: "karg".into(),
655 default_value: None,
656 annotation: Some(Expr::Constant {
657 value: Constant::Str("str".into()),
658 }),
659 }],
660 kwarg: Some(VariableLengthArgument {
661 name: "kwarg".into(),
662 annotation: Some(Expr::Constant {
663 value: Constant::Str("str".into()),
664 }),
665 }),
666 },
667 returns: Some(Expr::Constant {
668 value: Constant::Str("list[str]".into()),
669 }),
670 is_async: false,
671 docstring: None,
672 };
673 assert_eq!(
674 "def func(posonly, /, arg, *varargs, karg: \"str\", **kwarg: \"str\") -> \"list[str]\": ...",
675 function_stubs(&function, &Imports::default(), None)
676 )
677 }
678
679 #[test]
680 fn function_stubs_without_variable_length() {
681 let function = Function {
682 name: "afunc".into(),
683 decorators: Vec::new(),
684 arguments: Arguments {
685 positional_only_arguments: vec![Argument {
686 name: "posonly".into(),
687 default_value: Some(Expr::Constant {
688 value: Constant::Int("1".into()),
689 }),
690 annotation: None,
691 }],
692 arguments: vec![Argument {
693 name: "arg".into(),
694 default_value: Some(Expr::Constant {
695 value: Constant::Bool(true),
696 }),
697 annotation: None,
698 }],
699 vararg: None,
700 keyword_only_arguments: vec![Argument {
701 name: "karg".into(),
702 default_value: Some(Expr::Constant {
703 value: Constant::Str("foo".into()),
704 }),
705 annotation: Some(Expr::Constant {
706 value: Constant::Str("str".into()),
707 }),
708 }],
709 kwarg: None,
710 },
711 returns: None,
712 is_async: false,
713 docstring: None,
714 };
715 assert_eq!(
716 "def afunc(posonly=1, /, arg=True, *, karg: \"str\" = \"foo\"): ...",
717 function_stubs(&function, &Imports::default(), None)
718 )
719 }
720
721 #[test]
722 fn test_function_async() {
723 let function = Function {
724 name: "foo".into(),
725 decorators: Vec::new(),
726 arguments: Arguments {
727 positional_only_arguments: Vec::new(),
728 arguments: Vec::new(),
729 vararg: None,
730 keyword_only_arguments: Vec::new(),
731 kwarg: None,
732 },
733 returns: None,
734 is_async: true,
735 docstring: None,
736 };
737 assert_eq!(
738 "async def foo(): ...",
739 function_stubs(&function, &Imports::default(), None)
740 )
741 }
742
743 #[test]
744 fn test_import() {
745 let big_type = Expr::Subscript {
746 value: Box::new(Expr::Name { id: "dict".into() }),
747 slice: Box::new(Expr::Tuple {
748 elts: vec![
749 Expr::Attribute {
750 value: Box::new(Expr::Name {
751 id: "foo.bar".into(),
752 }),
753 attr: "A".into(),
754 },
755 Expr::Tuple {
756 elts: vec![
757 Expr::Attribute {
758 value: Box::new(Expr::Name { id: "foo".into() }),
759 attr: "A.C".into(),
760 },
761 Expr::Attribute {
762 value: Box::new(Expr::Attribute {
763 value: Box::new(Expr::Name { id: "foo".into() }),
764 attr: "A".into(),
765 }),
766 attr: "D".into(),
767 },
768 Expr::Attribute {
769 value: Box::new(Expr::Name { id: "foo".into() }),
770 attr: "B".into(),
771 },
772 Expr::Attribute {
773 value: Box::new(Expr::Name { id: "bat".into() }),
774 attr: "A".into(),
775 },
776 Expr::Attribute {
777 value: Box::new(Expr::Name {
778 id: "foo.bar".into(),
779 }),
780 attr: "int".into(),
781 },
782 Expr::Name { id: "int".into() },
783 Expr::Name { id: "float".into() },
784 ],
785 },
786 ],
787 }),
788 };
789 let imports = Imports::create(
790 &Module {
791 name: "bar".into(),
792 modules: Vec::new(),
793 classes: vec![
794 Class {
795 name: "A".into(),
796 bases: vec![Expr::Name { id: "dict".into() }],
797 methods: Vec::new(),
798 attributes: Vec::new(),
799 decorators: vec![Expr::Attribute {
800 value: Box::new(Expr::Name {
801 id: "typing".into(),
802 }),
803 attr: "final".into(),
804 }],
805 inner_classes: Vec::new(),
806 docstring: None,
807 },
808 Class {
809 name: "int".into(),
810 bases: Vec::new(),
811 methods: Vec::new(),
812 attributes: Vec::new(),
813 decorators: Vec::new(),
814 inner_classes: Vec::new(),
815 docstring: None,
816 },
817 ],
818 functions: vec![Function {
819 name: String::new(),
820 decorators: Vec::new(),
821 arguments: Arguments {
822 positional_only_arguments: Vec::new(),
823 arguments: Vec::new(),
824 vararg: None,
825 keyword_only_arguments: Vec::new(),
826 kwarg: None,
827 },
828 returns: Some(big_type.clone()),
829 is_async: false,
830 docstring: None,
831 }],
832 attributes: Vec::new(),
833 incomplete: true,
834 docstring: None,
835 },
836 &["foo"],
837 );
838 assert_eq!(
839 &imports.imports,
840 &[
841 "from _typeshed import Incomplete",
842 "from bat import A as A2",
843 "from builtins import int as int2",
844 "from foo import A as A3, B",
845 "from typing import final"
846 ]
847 );
848 let mut output = String::new();
849 imports.serialize_expr(&big_type, &mut output);
850 assert_eq!(output, "dict[A, (A3.C, A3.D, B, A2, int, int2, float)]");
851 }
852}