Skip to content

Commit c4310e0

Browse files
committed
print type_string of tensor attributes
Signed-off-by: Masaki Kozuki <[email protected]>
1 parent edfd224 commit c4310e0

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

thunder/core/proxies.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2059,6 +2059,25 @@ def __repr__(self):
20592059
tensors = {n: getattr(self, n) for n in tensor_names}
20602060
return f'<{type(self).__name__}(name="{self.name}", dtype={self.dtype}, shape={self._shape}, {tensors=}, {metadata=})>'
20612061

2062+
def type_string(self) -> str:
2063+
base_str = f"{self.device.device_str()} {self.dtype.shortname()}{list(self._shape)}"
2064+
2065+
if self._subclass_type is not None:
2066+
type_str = f"{self._subclass_type.__name__} of {base_str}"
2067+
else:
2068+
type_str = base_str
2069+
2070+
if self._tensors:
2071+
if (tensor_attr_names := getattr(self, "_tensor_attr_names", [])):
2072+
tensor_attr_type_str = ", ".join([
2073+
f"{name}: {t.type_string()}" for name, t in zip(tensor_attr_names, self._tensors)
2074+
])
2075+
else:
2076+
tensor_attr_type_str = ", ".join([t.type_string() for t in self._tensors])
2077+
type_str = type_str + f" ({tensor_attr_type_str})"
2078+
2079+
return type_str
2080+
20622081

20632082
class TorchAutogradFunctionCtxProxy(Proxy, TorchAutogradFunctionCtxProxyInterface):
20642083
def __init__(

0 commit comments

Comments
 (0)