@@ -31,7 +31,7 @@ def bad_name(name: str) -> bool:
31
31
32
32
33
33
class BaseModel (pydantic .BaseModel ):
34
- def __repr_args__ (self ) -> pydantic .ReprArgs :
34
+ def __repr_args__ (self ) -> pydantic .ReprArgs : # type: ignore
35
35
for k , v in super ().__repr_args__ ():
36
36
if v :
37
37
yield k , v
@@ -62,12 +62,25 @@ def json(self, **kwargs) -> str:
62
62
63
63
64
64
class Module (BaseModel ):
65
+ function_overloads : typing .Dict [str , typing .List [Signature ]] = pydantic .Field (
66
+ default_factory = dict
67
+ )
65
68
functions : typing .Dict [str , Signature ] = pydantic .Field (default_factory = dict )
66
69
classes : typing .Dict [str , Class ] = pydantic .Field (default_factory = dict )
67
70
properties : typing .Dict [str , typing .Tuple [Metadata , Type ]] = pydantic .Field (
68
71
default_factory = dict
69
72
)
70
73
74
+ @pydantic .root_validator
75
+ def set_overloads (cls , values ):
76
+ """
77
+ Set the overloads to be the values, if they are not set
78
+ """
79
+ values ["function_overloads" ] = values .get ("function_overloads" , {}) or {
80
+ k : create_overloads (v ) for k , v in values ["functions" ].items ()
81
+ }
82
+ return values
83
+
71
84
@property
72
85
def source (self ) -> str :
73
86
# try:
@@ -88,10 +101,8 @@ def body(
88
101
)
89
102
yield from assign_properties (self .properties )
90
103
91
- for name , sig in sort_items (self .functions ):
92
- if bad_name (name ):
93
- continue
94
- yield sig .function_def (name , "function" )
104
+ yield from function_defs (self .function_overloads , self .functions , "function" )
105
+
95
106
for name , class_ in sort_items (self .classes ):
96
107
yield class_ .class_def (name )
97
108
@@ -107,18 +118,37 @@ def __ior__(self, other: Module) -> Module:
107
118
update_ior (self .functions , other .functions )
108
119
# property -> function
109
120
merge_intersection (self .functions , self .properties , merge_property_into_method )
121
+ # function overloads
122
+ update_overloads (self .function_overloads , other .function_overloads )
110
123
111
124
# classes
112
125
update_ior (self .classes , other .classes )
113
126
# function -> class constructor
114
127
merge_intersection (self .classes , self .functions , merge_method_class )
128
+ merge_intersection (
129
+ self .classes , self .function_overloads , merge_method_overloads_class
130
+ )
115
131
116
132
return self
117
133
118
134
135
+ def create_overloads (s : Signature ) -> typing .List [Signature ]:
136
+ """
137
+ Copies signature to create overloads. Needs to copy because we inplace update the signature later on
138
+ """
139
+ return [s .copy (deep = True )]
140
+
141
+
119
142
class Class (BaseModel ):
143
+ constructor_overloads : typing .List [Signature ] = pydantic .Field (default_factory = list )
120
144
constructor : typing .Union [Signature , None ] = None
145
+ method_overloads : typing .Dict [str , typing .List [Signature ]] = pydantic .Field (
146
+ default_factory = dict
147
+ )
121
148
methods : typing .Dict [str , Signature ] = pydantic .Field (default_factory = dict )
149
+ classmethod_overloads : typing .Dict [str , typing .List [Signature ]] = pydantic .Field (
150
+ default_factory = dict
151
+ )
122
152
classmethods : typing .Dict [str , Signature ] = pydantic .Field (default_factory = dict )
123
153
properties : typing .Dict [str , typing .Tuple [Metadata , Type ]] = pydantic .Field (
124
154
default_factory = dict
@@ -127,26 +157,48 @@ class Class(BaseModel):
127
157
default_factory = dict
128
158
)
129
159
160
+ @pydantic .root_validator
161
+ def set_overloads (cls , values ):
162
+ """
163
+ Set the overloads to be the values, if not passed in.
164
+ """
165
+ if values ["constructor" ] and not values ["constructor_overloads" ]:
166
+ values ["constructor_overloads" ] = create_overloads (values ["constructor" ])
167
+ values ["method_overloads" ] = values .get ("method_overloads" , {}) or {
168
+ k : create_overloads (v ) for k , v in values ["methods" ].items ()
169
+ }
170
+ values ["classmethod_overloads" ] = values .get ("classmethod_overloads" , {}) or {
171
+ k : create_overloads (v ) for k , v in values ["classmethods" ].items ()
172
+ }
173
+ return values
174
+
130
175
def class_def (self , name : str ) -> cst .ClassDef :
131
176
return cst .ClassDef (cst .Name (name ), cst .IndentedBlock (list (self .body )),)
132
177
133
178
@property
134
179
def body (self ) -> typing .Iterable [cst .BaseStatement ]:
135
180
if self .constructor is not None :
136
- yield self .constructor .function_def ("__init__" , "method" , indent = 1 )
181
+ function_def_with_overloads (
182
+ "__init__" ,
183
+ self .constructor_overloads ,
184
+ self .constructor ,
185
+ "method" ,
186
+ indent = 1 ,
187
+ )
137
188
yield from assign_properties (self .classproperties , True )
138
189
139
- for name , sig in sort_items (self .classmethods ):
140
- if bad_name (name ):
141
- continue
142
- yield sig .function_def (name , "classmethod" , indent = 1 )
143
-
190
+ yield from function_defs (
191
+ self .classmethod_overloads , self .classmethods , "classmethod" , 1
192
+ )
144
193
yield from assign_properties (self .properties )
145
194
146
- for name , sig in sort_items (self .methods ):
147
- if bad_name (name ):
148
- continue
149
- yield sig .function_def (name , "method" , indent = 1 )
195
+ yield from function_defs (self .method_overloads , self .methods , "method" , 1 )
196
+
197
+ def ior_constructor (self , other : typing .Optional [Signature ]) -> None :
198
+ if self .constructor and other :
199
+ self .constructor |= other
200
+ else :
201
+ self .constructor = other
150
202
151
203
def __ior__ (self , other : Class ) -> Class :
152
204
"""
@@ -159,10 +211,9 @@ def __ior__(self, other: Class) -> Class:
159
211
160
212
* (implicit) property -> classmethod
161
213
"""
162
- if self .constructor and other .constructor :
163
- self .constructor |= other .constructor
164
- else :
165
- self .constructor = other .constructor
214
+ self .ior_constructor (other .constructor )
215
+ merge_overloads (self .constructor_overloads , other .constructor_overloads )
216
+
166
217
# properties
167
218
update_metadata_and_types (self .properties , other .properties )
168
219
@@ -177,11 +228,19 @@ def __ior__(self, other: Class) -> Class:
177
228
update_ior (self .methods , other .methods )
178
229
# property -> method
179
230
merge_intersection (self .methods , self .properties , merge_property_into_method )
231
+ # method overloads
232
+ update_overloads (self .method_overloads , other .method_overloads )
180
233
181
234
# class methods
182
235
update_ior (self .classmethods , other .classmethods )
183
236
# method -> classmethod
184
237
merge_intersection (self .classmethods , self .methods , operator .ior )
238
+ # classmethod overloads
239
+ merge_intersection (
240
+ self .classmethod_overloads , self .method_overloads , merge_overloads
241
+ )
242
+ update_overloads (self .classmethod_overloads , other .classmethod_overloads )
243
+
185
244
# classproperty -> classmethod
186
245
merge_intersection (
187
246
self .classmethods , self .classproperties , merge_property_into_method
@@ -226,6 +285,37 @@ def sort_items(d: typing.Dict[str, V]) -> typing.Iterable[typing.Tuple[str, V]]:
226
285
PartialKeyOrdering = typing .List [typing .Tuple [str , str ]]
227
286
228
287
288
+ def function_defs (
289
+ overloads : typing .Dict [str , typing .List [Signature ]],
290
+ functions : typing .Dict [str , Signature ],
291
+ type : typing .Literal ["function" , "classmethod" , "method" ],
292
+ indent : int = 0 ,
293
+ ) -> typing .Iterable [cst .FunctionDef ]:
294
+ """
295
+ Exports CST for functions plus overloads
296
+ """
297
+ for name , sig in sort_items (functions ):
298
+ if bad_name (name ):
299
+ continue
300
+ yield from function_def_with_overloads (
301
+ name , overloads .get (name , []), sig , type , indent
302
+ )
303
+
304
+
305
+ def function_def_with_overloads (
306
+ name : str ,
307
+ overloads : typing .List [Signature ],
308
+ fn : Signature ,
309
+ type : typing .Literal ["function" , "classmethod" , "method" ],
310
+ indent : int = 0 ,
311
+ ) -> typing .Iterable [cst .FunctionDef ]:
312
+ # Don't print overloads if just one!
313
+ if len (overloads ) > 1 :
314
+ for overload in overloads :
315
+ yield overload .function_def (name , type , indent = indent , overload = True )
316
+ yield fn .function_def (name , type , indent = indent )
317
+
318
+
229
319
class Signature (BaseModel ):
230
320
# See for a helpful spec https://www.python.org/dev/peps/pep-0570/#syntax-and-semantics
231
321
# Also keyword only PEP https://www.python.org/dev/peps/pep-3102/
@@ -292,14 +382,20 @@ def function_def(
292
382
name : str ,
293
383
type : typing .Literal ["function" , "classmethod" , "method" ],
294
384
indent = 0 ,
385
+ overload = False ,
295
386
) -> cst .FunctionDef :
387
+ decorators : typing .List [cst .Decorator ] = []
388
+ if overload :
389
+ decorators .append (cst .Decorator (cst .Name ("overload" )))
390
+ if type == "classmethod" :
391
+ decorators .append (cst .Decorator (cst .Name ("classmethod" )))
296
392
return cst .FunctionDef (
297
393
cst .Name (name ),
298
394
self .parameters (type ),
299
395
cst .IndentedBlock (
300
396
[cst .SimpleStatementLine ([s ]) for s in self .body (indent )]
301
397
),
302
- [ cst . Decorator ( cst . Name ( "classmethod" ))] if type == "classmethod" else [] ,
398
+ decorators ,
303
399
)
304
400
305
401
def body (self , indent : int ) -> typing .Iterable [cst .BaseSmallStatement ]:
@@ -413,6 +509,17 @@ def from_bound_params(
413
509
),
414
510
)
415
511
512
+ def content_equal (self , other : Signature ) -> bool :
513
+ """
514
+ Returns true if all fields besides metadata are equal
515
+ """
516
+ for field in self .__fields__ .keys ():
517
+ if field == "metadata" :
518
+ continue
519
+ if getattr (self , field ) != getattr (other , field ):
520
+ return False
521
+ return True
522
+
416
523
def __ior__ (self , other : Signature ) -> Signature :
417
524
self ._copy_pos_only (other )
418
525
self ._copy_pos_or_kw (other )
@@ -696,7 +803,12 @@ def merge_property_into_method(
696
803
697
804
698
805
def merge_method_class (l : Class , r : Signature ) -> Class :
699
- l |= Class (constructor = r )
806
+ l .ior_constructor (r )
807
+ return l
808
+
809
+
810
+ def merge_method_overloads_class (l : Class , r : typing .List [Signature ]) -> Class :
811
+ merge_overloads (l .constructor_overloads , r )
700
812
return l
701
813
702
814
@@ -710,6 +822,27 @@ def merge_methods_properties(
710
822
update_metadata_and_types = functools .partial (update , f = merge_methods_properties )
711
823
712
824
825
+ def merge_overloads (
826
+ old : typing .List [Signature ], new : typing .List [Signature ]
827
+ ) -> typing .List [Signature ]:
828
+ """
829
+ Merges the two lists of overloads, updating the old list. Any signatures that match, we take the union of their metadata.
830
+ """
831
+ for new_sig in new :
832
+ # Iterate through all the old signatures, if we find one which matches, update that and break
833
+ for old_sig in old :
834
+ if new_sig .content_equal (old_sig ):
835
+ update_add (old_sig .metadata , new_sig .metadata )
836
+ break
837
+ # Otherwise, we didn't find one, so add the new one to the old
838
+ else :
839
+ old .append (new_sig )
840
+ return old
841
+
842
+
843
+ update_overloads = functools .partial (update , f = merge_overloads )
844
+
845
+
713
846
T = typing .TypeVar ("T" )
714
847
715
848
# Set theoretic functions on lists
0 commit comments