@@ -31,7 +31,7 @@ def bad_name(name: str) -> bool:
3131
3232
3333class BaseModel (pydantic .BaseModel ):
34- def __repr_args__ (self ) -> pydantic .ReprArgs :
34+ def __repr_args__ (self ) -> pydantic .ReprArgs : # type: ignore
3535 for k , v in super ().__repr_args__ ():
3636 if v :
3737 yield k , v
@@ -62,12 +62,25 @@ def json(self, **kwargs) -> str:
6262
6363
6464class Module (BaseModel ):
65+ function_overloads : typing .Dict [str , typing .List [Signature ]] = pydantic .Field (
66+ default_factory = dict
67+ )
6568 functions : typing .Dict [str , Signature ] = pydantic .Field (default_factory = dict )
6669 classes : typing .Dict [str , Class ] = pydantic .Field (default_factory = dict )
6770 properties : typing .Dict [str , typing .Tuple [Metadata , Type ]] = pydantic .Field (
6871 default_factory = dict
6972 )
7073
74+ @pydantic .root_validator
75+ def set_overloads (cls , values ):
76+ """
77+ Set the overloads to be the values, if they are not set
78+ """
79+ values ["function_overloads" ] = values .get ("function_overloads" , {}) or {
80+ k : create_overloads (v ) for k , v in values ["functions" ].items ()
81+ }
82+ return values
83+
7184 @property
7285 def source (self ) -> str :
7386 # try:
@@ -88,10 +101,8 @@ def body(
88101 )
89102 yield from assign_properties (self .properties )
90103
91- for name , sig in sort_items (self .functions ):
92- if bad_name (name ):
93- continue
94- yield sig .function_def (name , "function" )
104+ yield from function_defs (self .function_overloads , self .functions , "function" )
105+
95106 for name , class_ in sort_items (self .classes ):
96107 yield class_ .class_def (name )
97108
@@ -107,18 +118,37 @@ def __ior__(self, other: Module) -> Module:
107118 update_ior (self .functions , other .functions )
108119 # property -> function
109120 merge_intersection (self .functions , self .properties , merge_property_into_method )
121+ # function overloads
122+ update_overloads (self .function_overloads , other .function_overloads )
110123
111124 # classes
112125 update_ior (self .classes , other .classes )
113126 # function -> class constructor
114127 merge_intersection (self .classes , self .functions , merge_method_class )
128+ merge_intersection (
129+ self .classes , self .function_overloads , merge_method_overloads_class
130+ )
115131
116132 return self
117133
118134
135+ def create_overloads (s : Signature ) -> typing .List [Signature ]:
136+ """
137+ Copies signature to create overloads. Needs to copy because we inplace update the signature later on
138+ """
139+ return [s .copy (deep = True )]
140+
141+
119142class Class (BaseModel ):
143+ constructor_overloads : typing .List [Signature ] = pydantic .Field (default_factory = list )
120144 constructor : typing .Union [Signature , None ] = None
145+ method_overloads : typing .Dict [str , typing .List [Signature ]] = pydantic .Field (
146+ default_factory = dict
147+ )
121148 methods : typing .Dict [str , Signature ] = pydantic .Field (default_factory = dict )
149+ classmethod_overloads : typing .Dict [str , typing .List [Signature ]] = pydantic .Field (
150+ default_factory = dict
151+ )
122152 classmethods : typing .Dict [str , Signature ] = pydantic .Field (default_factory = dict )
123153 properties : typing .Dict [str , typing .Tuple [Metadata , Type ]] = pydantic .Field (
124154 default_factory = dict
@@ -127,26 +157,48 @@ class Class(BaseModel):
127157 default_factory = dict
128158 )
129159
160+ @pydantic .root_validator
161+ def set_overloads (cls , values ):
162+ """
163+ Set the overloads to be the values, if not passed in.
164+ """
165+ if values ["constructor" ] and not values ["constructor_overloads" ]:
166+ values ["constructor_overloads" ] = create_overloads (values ["constructor" ])
167+ values ["method_overloads" ] = values .get ("method_overloads" , {}) or {
168+ k : create_overloads (v ) for k , v in values ["methods" ].items ()
169+ }
170+ values ["classmethod_overloads" ] = values .get ("classmethod_overloads" , {}) or {
171+ k : create_overloads (v ) for k , v in values ["classmethods" ].items ()
172+ }
173+ return values
174+
130175 def class_def (self , name : str ) -> cst .ClassDef :
131176 return cst .ClassDef (cst .Name (name ), cst .IndentedBlock (list (self .body )),)
132177
133178 @property
134179 def body (self ) -> typing .Iterable [cst .BaseStatement ]:
135180 if self .constructor is not None :
136- yield self .constructor .function_def ("__init__" , "method" , indent = 1 )
181+ function_def_with_overloads (
182+ "__init__" ,
183+ self .constructor_overloads ,
184+ self .constructor ,
185+ "method" ,
186+ indent = 1 ,
187+ )
137188 yield from assign_properties (self .classproperties , True )
138189
139- for name , sig in sort_items (self .classmethods ):
140- if bad_name (name ):
141- continue
142- yield sig .function_def (name , "classmethod" , indent = 1 )
143-
190+ yield from function_defs (
191+ self .classmethod_overloads , self .classmethods , "classmethod" , 1
192+ )
144193 yield from assign_properties (self .properties )
145194
146- for name , sig in sort_items (self .methods ):
147- if bad_name (name ):
148- continue
149- yield sig .function_def (name , "method" , indent = 1 )
195+ yield from function_defs (self .method_overloads , self .methods , "method" , 1 )
196+
197+ def ior_constructor (self , other : typing .Optional [Signature ]) -> None :
198+ if self .constructor and other :
199+ self .constructor |= other
200+ else :
201+ self .constructor = other
150202
151203 def __ior__ (self , other : Class ) -> Class :
152204 """
@@ -159,10 +211,9 @@ def __ior__(self, other: Class) -> Class:
159211
160212 * (implicit) property -> classmethod
161213 """
162- if self .constructor and other .constructor :
163- self .constructor |= other .constructor
164- else :
165- self .constructor = other .constructor
214+ self .ior_constructor (other .constructor )
215+ merge_overloads (self .constructor_overloads , other .constructor_overloads )
216+
166217 # properties
167218 update_metadata_and_types (self .properties , other .properties )
168219
@@ -177,11 +228,19 @@ def __ior__(self, other: Class) -> Class:
177228 update_ior (self .methods , other .methods )
178229 # property -> method
179230 merge_intersection (self .methods , self .properties , merge_property_into_method )
231+ # method overloads
232+ update_overloads (self .method_overloads , other .method_overloads )
180233
181234 # class methods
182235 update_ior (self .classmethods , other .classmethods )
183236 # method -> classmethod
184237 merge_intersection (self .classmethods , self .methods , operator .ior )
238+ # classmethod overloads
239+ merge_intersection (
240+ self .classmethod_overloads , self .method_overloads , merge_overloads
241+ )
242+ update_overloads (self .classmethod_overloads , other .classmethod_overloads )
243+
185244 # classproperty -> classmethod
186245 merge_intersection (
187246 self .classmethods , self .classproperties , merge_property_into_method
@@ -226,6 +285,37 @@ def sort_items(d: typing.Dict[str, V]) -> typing.Iterable[typing.Tuple[str, V]]:
226285PartialKeyOrdering = typing .List [typing .Tuple [str , str ]]
227286
228287
288+ def function_defs (
289+ overloads : typing .Dict [str , typing .List [Signature ]],
290+ functions : typing .Dict [str , Signature ],
291+ type : typing .Literal ["function" , "classmethod" , "method" ],
292+ indent : int = 0 ,
293+ ) -> typing .Iterable [cst .FunctionDef ]:
294+ """
295+ Exports CST for functions plus overloads
296+ """
297+ for name , sig in sort_items (functions ):
298+ if bad_name (name ):
299+ continue
300+ yield from function_def_with_overloads (
301+ name , overloads .get (name , []), sig , type , indent
302+ )
303+
304+
305+ def function_def_with_overloads (
306+ name : str ,
307+ overloads : typing .List [Signature ],
308+ fn : Signature ,
309+ type : typing .Literal ["function" , "classmethod" , "method" ],
310+ indent : int = 0 ,
311+ ) -> typing .Iterable [cst .FunctionDef ]:
312+ # Don't print overloads if just one!
313+ if len (overloads ) > 1 :
314+ for overload in overloads :
315+ yield overload .function_def (name , type , indent = indent , overload = True )
316+ yield fn .function_def (name , type , indent = indent )
317+
318+
229319class Signature (BaseModel ):
230320 # See for a helpful spec https://www.python.org/dev/peps/pep-0570/#syntax-and-semantics
231321 # Also keyword only PEP https://www.python.org/dev/peps/pep-3102/
@@ -292,14 +382,20 @@ def function_def(
292382 name : str ,
293383 type : typing .Literal ["function" , "classmethod" , "method" ],
294384 indent = 0 ,
385+ overload = False ,
295386 ) -> cst .FunctionDef :
387+ decorators : typing .List [cst .Decorator ] = []
388+ if overload :
389+ decorators .append (cst .Decorator (cst .Name ("overload" )))
390+ if type == "classmethod" :
391+ decorators .append (cst .Decorator (cst .Name ("classmethod" )))
296392 return cst .FunctionDef (
297393 cst .Name (name ),
298394 self .parameters (type ),
299395 cst .IndentedBlock (
300396 [cst .SimpleStatementLine ([s ]) for s in self .body (indent )]
301397 ),
302- [ cst . Decorator ( cst . Name ( "classmethod" ))] if type == "classmethod" else [] ,
398+ decorators ,
303399 )
304400
305401 def body (self , indent : int ) -> typing .Iterable [cst .BaseSmallStatement ]:
@@ -413,6 +509,17 @@ def from_bound_params(
413509 ),
414510 )
415511
512+ def content_equal (self , other : Signature ) -> bool :
513+ """
514+ Returns true if all fields besides metadata are equal
515+ """
516+ for field in self .__fields__ .keys ():
517+ if field == "metadata" :
518+ continue
519+ if getattr (self , field ) != getattr (other , field ):
520+ return False
521+ return True
522+
416523 def __ior__ (self , other : Signature ) -> Signature :
417524 self ._copy_pos_only (other )
418525 self ._copy_pos_or_kw (other )
@@ -696,7 +803,12 @@ def merge_property_into_method(
696803
697804
698805def merge_method_class (l : Class , r : Signature ) -> Class :
699- l |= Class (constructor = r )
806+ l .ior_constructor (r )
807+ return l
808+
809+
810+ def merge_method_overloads_class (l : Class , r : typing .List [Signature ]) -> Class :
811+ merge_overloads (l .constructor_overloads , r )
700812 return l
701813
702814
@@ -710,6 +822,27 @@ def merge_methods_properties(
710822update_metadata_and_types = functools .partial (update , f = merge_methods_properties )
711823
712824
825+ def merge_overloads (
826+ old : typing .List [Signature ], new : typing .List [Signature ]
827+ ) -> typing .List [Signature ]:
828+ """
829+ Merges the two lists of overloads, updating the old list. Any signatures that match, we take the union of their metadata.
830+ """
831+ for new_sig in new :
832+ # Iterate through all the old signatures, if we find one which matches, update that and break
833+ for old_sig in old :
834+ if new_sig .content_equal (old_sig ):
835+ update_add (old_sig .metadata , new_sig .metadata )
836+ break
837+ # Otherwise, we didn't find one, so add the new one to the old
838+ else :
839+ old .append (new_sig )
840+ return old
841+
842+
843+ update_overloads = functools .partial (update , f = merge_overloads )
844+
845+
713846T = typing .TypeVar ("T" )
714847
715848# Set theoretic functions on lists
0 commit comments