Skip to content

Commit a2b25e5

Browse files
pre-commit-ci[bot]crcrpar
authored andcommitted
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci Signed-off-by: Masaki Kozuki <[email protected]>
1 parent 782ea70 commit a2b25e5

File tree

2 files changed

+18
-7
lines changed

2 files changed

+18
-7
lines changed

thunder/core/proxies.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -730,7 +730,6 @@ def is_dynamic(self) -> bool:
730730
# fn is the function to call if executing outside a language context
731731
@staticmethod
732732
def _elementwise_unary_helper(a, name, fn, type_promotion_kind=None):
733-
734733
vala = pyval(a)
735734

736735
trace: None | TraceCtx = get_tracectx()
@@ -2063,18 +2062,28 @@ def type_string(self) -> str:
20632062
base_str = f"{self.device.device_str()} {self.dtype.shortname()}{list(self._shape)}"
20642063

20652064
if self._subclass_type is not None:
2066-
type_str = f"{self._subclass_type.__name__} of {base_str}"
2065+
type_str = f"{self._subclass_type.__name__}[{base_str}]"
20672066
else:
20682067
type_str = base_str
20692068

20702069
if self._tensors:
2071-
if (tensor_attr_names := getattr(self, "_tensor_attr_names", [])):
2072-
tensor_attr_type_str = ", ".join([
2073-
f"{name}: {t.type_string()}" for name, t in zip(tensor_attr_names, self._tensors)
2074-
])
2070+
if tensor_attr_names := getattr(self, "_tensor_attr_names", []):
2071+
tensor_attr_type_str = ", ".join(
2072+
[f"{name}: {t.type_string()}" for name, t in zip(tensor_attr_names, self._tensors)]
2073+
)
20752074
else:
20762075
tensor_attr_type_str = ", ".join([t.type_string() for t in self._tensors])
2077-
type_str = type_str + f" ({tensor_attr_type_str})"
2076+
2077+
if self._non_tensors:
2078+
if non_tensor_attr_names := getattr(self, "_non_tensor_attr_names", []):
2079+
non_tensor_attr_type_str = ", ".join(
2080+
[f"{name}: {v}" for name, v in zip(non_tensor_attr_names, self._non_tensors)]
2081+
)
2082+
else:
2083+
non_tensor_attr_type_str = ", ".join(str(v) for v in self._non_tensors)
2084+
type_str = type_str + f" (tensors: {tensor_attr_type_str}, constants: {non_tensor_attr_type_str})"
2085+
else:
2086+
type_str = type_str + f" ({tensor_attr_type_str})"
20782087

20792088
return type_str
20802089

thunder/tests/test_tensor_subclass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,8 @@ def f(x: torch.Tensor, scale: torch.Tensor):
159159
actual = jitted(x, scale)
160160
torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale))
161161

162+
print(thunder.last_traces(jitted)[0])
163+
162164

163165
@instantiate(
164166
dtypes=(thunder.core.dtypes.float32,),

0 commit comments

Comments
 (0)