Skip to content

Commit d86aacd

Browse files
authored
Merge pull request #105 from ayasyrev/changed_args
changed fields, set fields
2 parents e3c89ed + 2bd255a commit d86aacd

File tree

3 files changed

+56
-14
lines changed

3 files changed

+56
-14
lines changed

src/model_constructor/helpers.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -93,24 +93,50 @@ def __repr_args__(self) -> list[tuple[str, str]]:
9393
if (str_value := self._get_str_value(field))
9494
]
9595

96-
def __repr_changed_args__(self) -> list[str]:
97-
"""Return list repr for changed fields"""
96+
def __repr_set_fields__(self) -> list[str]:
97+
"""Return list repr for fields set at init"""
9898
return [
9999
f"{field}: {self._get_str_value(field)}"
100100
for field in self.model_fields_set # pylint: disable=E1133
101101
if field != "name"
102102
]
103103

104+
def __repr_changed_fields__(self) -> list[str]:
105+
"""Return list repr for changed fields"""
106+
return [
107+
f"{field}: {self._get_str_value(field)}"
108+
for field in self.changed_fields
109+
if field != "name"
110+
]
111+
112+
@property
113+
def changed_fields(self) -> dict[str, Any]:
114+
# return "\n".join(self.__repr_changed_fields__())
115+
return {
116+
field: self._get_str_value(field)
117+
for field in self.model_fields # pylint: disable=E1133
118+
if getattr(self, field) != self.model_fields[field].default
119+
}
120+
104121
def print_cfg(self) -> None:
105122
"""Print full config"""
106123
print(self.__repr__())
107124

108-
def print_changed(self) -> None:
109-
"""Print changed fields."""
110-
changed_fields = self.__repr_changed_args__()
111-
if changed_fields:
125+
def print_set_fields(self) -> None:
126+
"""Print fields changed at init."""
127+
set_fields = self.__repr_set_fields__()
128+
if set_fields:
129+
print("Set fields:")
130+
for field in set_fields:
131+
print(field)
132+
else:
133+
print("Nothing changed")
134+
135+
def print_changed_fields(self) -> None:
136+
"""Print fields changed at init."""
137+
if self.changed_fields:
112138
print("Changed fields:")
113-
for i in changed_fields:
114-
print(" ", i)
139+
for field in self.changed_fields:
140+
print(f"{field}: {self._get_str_value(field)}")
115141
else:
116142
print("Nothing changed")

src/model_constructor/model_constructor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def __call__(self) -> nn.Sequential:
216216
OrderedDict([("stem", self.stem), ("body", self.body), ("head", self.head)]) # type: ignore
217217
)
218218
self.init_cnn(model) # pylint: disable=too-many-function-args
219-
extra_repr = self.__repr_changed_args__()
219+
extra_repr = self.__repr_changed_fields__()
220220
if extra_repr:
221221
model.extra_repr = lambda: ", ".join(extra_repr)
222222
return model

tests/test_helpers.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def test_cfg_repr_print(capsys: CaptureFixture[str]):
5050
cfg = Cfg()
5151
repr_res = cfg.__repr__()
5252
assert repr_res == "Cfg(\n )"
53-
cfg.print_changed()
53+
cfg.print_set_fields()
5454
out = capsys.readouterr().out
5555
assert out == "Nothing changed\n"
5656
cfg.name = "cfg_name"
@@ -59,13 +59,29 @@ def test_cfg_repr_print(capsys: CaptureFixture[str]):
5959
cfg.print_cfg()
6060
out = capsys.readouterr().out
6161
assert out == "Cfg(\n name='cfg_name')\n"
62-
# changed fields. default - name is not in changed
62+
# Set fields. default - name is not in changed
6363
cfg = Cfg2(name="cfg_name")
64-
cfg.print_changed()
64+
cfg.print_set_fields()
6565
out = capsys.readouterr().out
6666
assert out == "Nothing changed\n"
6767
assert "name" in cfg.model_fields_set
6868
cfg = Cfg2(int_value=0)
69-
cfg.print_changed()
69+
cfg.print_set_fields()
7070
out = capsys.readouterr().out
71-
assert out == "Changed fields:\n int_value: 0\n"
71+
assert out == "Set fields:\nint_value: 0\n"
72+
# Changed fields
73+
cfg = Cfg2(name="cfg_name")
74+
assert cfg.changed_fields == {"name": "cfg_name"}
75+
cfg.int_value = 1
76+
cfg.name = None
77+
assert cfg.changed_fields == {"int_value": 1}
78+
# print
79+
cfg.print_changed_fields()
80+
out = capsys.readouterr().out
81+
assert out == "Changed fields:\nint_value: 1\n"
82+
# return to default
83+
cfg.int_value = 10
84+
assert not cfg.changed_fields
85+
cfg.print_changed_fields()
86+
out = capsys.readouterr().out
87+
assert out == "Nothing changed\n"

0 commit comments

Comments
 (0)