Skip to content

Commit f68b3c2

Browse files
Merge pull request #54 from data-apis/fix-overloads
Include overloads
2 parents c93c406 + 98ecac6 commit f68b3c2

File tree

1 file changed

+154
-21
lines changed

1 file changed

+154
-21
lines changed

record_api/apis.py

+154-21
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def bad_name(name: str) -> bool:
3131

3232

3333
class BaseModel(pydantic.BaseModel):
34-
def __repr_args__(self) -> pydantic.ReprArgs:
34+
def __repr_args__(self) -> pydantic.ReprArgs: # type: ignore
3535
for k, v in super().__repr_args__():
3636
if v:
3737
yield k, v
@@ -62,12 +62,25 @@ def json(self, **kwargs) -> str:
6262

6363

6464
class Module(BaseModel):
65+
function_overloads: typing.Dict[str, typing.List[Signature]] = pydantic.Field(
66+
default_factory=dict
67+
)
6568
functions: typing.Dict[str, Signature] = pydantic.Field(default_factory=dict)
6669
classes: typing.Dict[str, Class] = pydantic.Field(default_factory=dict)
6770
properties: typing.Dict[str, typing.Tuple[Metadata, Type]] = pydantic.Field(
6871
default_factory=dict
6972
)
7073

74+
@pydantic.root_validator
75+
def set_overloads(cls, values):
76+
"""
77+
Set the overloads to be the values, if they are not set
78+
"""
79+
values["function_overloads"] = values.get("function_overloads", {}) or {
80+
k: create_overloads(v) for k, v in values["functions"].items()
81+
}
82+
return values
83+
7184
@property
7285
def source(self) -> str:
7386
# try:
@@ -88,10 +101,8 @@ def body(
88101
)
89102
yield from assign_properties(self.properties)
90103

91-
for name, sig in sort_items(self.functions):
92-
if bad_name(name):
93-
continue
94-
yield sig.function_def(name, "function")
104+
yield from function_defs(self.function_overloads, self.functions, "function")
105+
95106
for name, class_ in sort_items(self.classes):
96107
yield class_.class_def(name)
97108

@@ -107,18 +118,37 @@ def __ior__(self, other: Module) -> Module:
107118
update_ior(self.functions, other.functions)
108119
# property -> function
109120
merge_intersection(self.functions, self.properties, merge_property_into_method)
121+
# function overloads
122+
update_overloads(self.function_overloads, other.function_overloads)
110123

111124
# classes
112125
update_ior(self.classes, other.classes)
113126
# function -> class constructor
114127
merge_intersection(self.classes, self.functions, merge_method_class)
128+
merge_intersection(
129+
self.classes, self.function_overloads, merge_method_overloads_class
130+
)
115131

116132
return self
117133

118134

135+
def create_overloads(s: Signature) -> typing.List[Signature]:
136+
"""
137+
Copies signature to create overloads. Needs to copy because we inplace update the signature later on
138+
"""
139+
return [s.copy(deep=True)]
140+
141+
119142
class Class(BaseModel):
143+
constructor_overloads: typing.List[Signature] = pydantic.Field(default_factory=list)
120144
constructor: typing.Union[Signature, None] = None
145+
method_overloads: typing.Dict[str, typing.List[Signature]] = pydantic.Field(
146+
default_factory=dict
147+
)
121148
methods: typing.Dict[str, Signature] = pydantic.Field(default_factory=dict)
149+
classmethod_overloads: typing.Dict[str, typing.List[Signature]] = pydantic.Field(
150+
default_factory=dict
151+
)
122152
classmethods: typing.Dict[str, Signature] = pydantic.Field(default_factory=dict)
123153
properties: typing.Dict[str, typing.Tuple[Metadata, Type]] = pydantic.Field(
124154
default_factory=dict
@@ -127,26 +157,48 @@ class Class(BaseModel):
127157
default_factory=dict
128158
)
129159

160+
@pydantic.root_validator
161+
def set_overloads(cls, values):
162+
"""
163+
Set the overloads to be the values, if not passed in.
164+
"""
165+
if values["constructor"] and not values["constructor_overloads"]:
166+
values["constructor_overloads"] = create_overloads(values["constructor"])
167+
values["method_overloads"] = values.get("method_overloads", {}) or {
168+
k: create_overloads(v) for k, v in values["methods"].items()
169+
}
170+
values["classmethod_overloads"] = values.get("classmethod_overloads", {}) or {
171+
k: create_overloads(v) for k, v in values["classmethods"].items()
172+
}
173+
return values
174+
130175
def class_def(self, name: str) -> cst.ClassDef:
131176
return cst.ClassDef(cst.Name(name), cst.IndentedBlock(list(self.body)),)
132177

133178
@property
134179
def body(self) -> typing.Iterable[cst.BaseStatement]:
135180
if self.constructor is not None:
136-
yield self.constructor.function_def("__init__", "method", indent=1)
181+
function_def_with_overloads(
182+
"__init__",
183+
self.constructor_overloads,
184+
self.constructor,
185+
"method",
186+
indent=1,
187+
)
137188
yield from assign_properties(self.classproperties, True)
138189

139-
for name, sig in sort_items(self.classmethods):
140-
if bad_name(name):
141-
continue
142-
yield sig.function_def(name, "classmethod", indent=1)
143-
190+
yield from function_defs(
191+
self.classmethod_overloads, self.classmethods, "classmethod", 1
192+
)
144193
yield from assign_properties(self.properties)
145194

146-
for name, sig in sort_items(self.methods):
147-
if bad_name(name):
148-
continue
149-
yield sig.function_def(name, "method", indent=1)
195+
yield from function_defs(self.method_overloads, self.methods, "method", 1)
196+
197+
def ior_constructor(self, other: typing.Optional[Signature]) -> None:
198+
if self.constructor and other:
199+
self.constructor |= other
200+
else:
201+
self.constructor = other
150202

151203
def __ior__(self, other: Class) -> Class:
152204
"""
@@ -159,10 +211,9 @@ def __ior__(self, other: Class) -> Class:
159211
160212
* (implicit) property -> classmethod
161213
"""
162-
if self.constructor and other.constructor:
163-
self.constructor |= other.constructor
164-
else:
165-
self.constructor = other.constructor
214+
self.ior_constructor(other.constructor)
215+
merge_overloads(self.constructor_overloads, other.constructor_overloads)
216+
166217
# properties
167218
update_metadata_and_types(self.properties, other.properties)
168219

@@ -177,11 +228,19 @@ def __ior__(self, other: Class) -> Class:
177228
update_ior(self.methods, other.methods)
178229
# property -> method
179230
merge_intersection(self.methods, self.properties, merge_property_into_method)
231+
# method overloads
232+
update_overloads(self.method_overloads, other.method_overloads)
180233

181234
# class methods
182235
update_ior(self.classmethods, other.classmethods)
183236
# method -> classmethod
184237
merge_intersection(self.classmethods, self.methods, operator.ior)
238+
# classmethod overloads
239+
merge_intersection(
240+
self.classmethod_overloads, self.method_overloads, merge_overloads
241+
)
242+
update_overloads(self.classmethod_overloads, other.classmethod_overloads)
243+
185244
# classproperty -> classmethod
186245
merge_intersection(
187246
self.classmethods, self.classproperties, merge_property_into_method
@@ -226,6 +285,37 @@ def sort_items(d: typing.Dict[str, V]) -> typing.Iterable[typing.Tuple[str, V]]:
226285
PartialKeyOrdering = typing.List[typing.Tuple[str, str]]
227286

228287

288+
def function_defs(
289+
overloads: typing.Dict[str, typing.List[Signature]],
290+
functions: typing.Dict[str, Signature],
291+
type: typing.Literal["function", "classmethod", "method"],
292+
indent: int = 0,
293+
) -> typing.Iterable[cst.FunctionDef]:
294+
"""
295+
Exports CST for functions plus overloads
296+
"""
297+
for name, sig in sort_items(functions):
298+
if bad_name(name):
299+
continue
300+
yield from function_def_with_overloads(
301+
name, overloads.get(name, []), sig, type, indent
302+
)
303+
304+
305+
def function_def_with_overloads(
306+
name: str,
307+
overloads: typing.List[Signature],
308+
fn: Signature,
309+
type: typing.Literal["function", "classmethod", "method"],
310+
indent: int = 0,
311+
) -> typing.Iterable[cst.FunctionDef]:
312+
# Don't print overloads if just one!
313+
if len(overloads) > 1:
314+
for overload in overloads:
315+
yield overload.function_def(name, type, indent=indent, overload=True)
316+
yield fn.function_def(name, type, indent=indent)
317+
318+
229319
class Signature(BaseModel):
230320
# See for a helpful spec https://www.python.org/dev/peps/pep-0570/#syntax-and-semantics
231321
# Also keyword only PEP https://www.python.org/dev/peps/pep-3102/
@@ -292,14 +382,20 @@ def function_def(
292382
name: str,
293383
type: typing.Literal["function", "classmethod", "method"],
294384
indent=0,
385+
overload=False,
295386
) -> cst.FunctionDef:
387+
decorators: typing.List[cst.Decorator] = []
388+
if overload:
389+
decorators.append(cst.Decorator(cst.Name("overload")))
390+
if type == "classmethod":
391+
decorators.append(cst.Decorator(cst.Name("classmethod")))
296392
return cst.FunctionDef(
297393
cst.Name(name),
298394
self.parameters(type),
299395
cst.IndentedBlock(
300396
[cst.SimpleStatementLine([s]) for s in self.body(indent)]
301397
),
302-
[cst.Decorator(cst.Name("classmethod"))] if type == "classmethod" else [],
398+
decorators,
303399
)
304400

305401
def body(self, indent: int) -> typing.Iterable[cst.BaseSmallStatement]:
@@ -413,6 +509,17 @@ def from_bound_params(
413509
),
414510
)
415511

512+
def content_equal(self, other: Signature) -> bool:
513+
"""
514+
Returns true if all fields besides metadata are equal
515+
"""
516+
for field in self.__fields__.keys():
517+
if field == "metadata":
518+
continue
519+
if getattr(self, field) != getattr(other, field):
520+
return False
521+
return True
522+
416523
def __ior__(self, other: Signature) -> Signature:
417524
self._copy_pos_only(other)
418525
self._copy_pos_or_kw(other)
@@ -696,7 +803,12 @@ def merge_property_into_method(
696803

697804

698805
def merge_method_class(l: Class, r: Signature) -> Class:
699-
l |= Class(constructor=r)
806+
l.ior_constructor(r)
807+
return l
808+
809+
810+
def merge_method_overloads_class(l: Class, r: typing.List[Signature]) -> Class:
811+
merge_overloads(l.constructor_overloads, r)
700812
return l
701813

702814

@@ -710,6 +822,27 @@ def merge_methods_properties(
710822
update_metadata_and_types = functools.partial(update, f=merge_methods_properties)
711823

712824

825+
def merge_overloads(
826+
old: typing.List[Signature], new: typing.List[Signature]
827+
) -> typing.List[Signature]:
828+
"""
829+
Merges the two lists of overloads, updating the old list. Any signatures that match, we take the union of their metadata.
830+
"""
831+
for new_sig in new:
832+
# Iterate through all the old signatures, if we find one which matches, update that and break
833+
for old_sig in old:
834+
if new_sig.content_equal(old_sig):
835+
update_add(old_sig.metadata, new_sig.metadata)
836+
break
837+
# Otherwise, we didn't find one, so add the new one to the old
838+
else:
839+
old.append(new_sig)
840+
return old
841+
842+
843+
update_overloads = functools.partial(update, f=merge_overloads)
844+
845+
713846
T = typing.TypeVar("T")
714847

715848
# Set theoretic functions on lists

0 commit comments

Comments
 (0)