Skip to content

Commit a8cfb75

Browse files
committed
Arm backend: Add extra_repr methods to MXFP modules
Add extra_repr methods to MXFPLinearOp and MXFPConv2dOp to make them show more detailed info when printed. Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com> Change-Id: Id412b7da6369304a087f1a392f10278cab022533
1 parent a2438e4 commit a8cfb75

2 files changed

Lines changed: 37 additions & 0 deletions

File tree

backends/arm/ao_ext/ops/mxfp_conv2d_op.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@
1010
1111
"""
1212

13+
from typing import cast
14+
1315
import torch
1416
import torch.nn.functional as F
17+
1518
from executorch.backends.arm.ao_ext.mxfp import (
1619
_cast_to_block_scaled_cpu_ref,
1720
mxfp_dtype_to_str,
@@ -257,6 +260,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
257260
output = output.to(self.output_dtype)
258261
return output
259262

263+
def extra_repr(self) -> str:
264+
weight_qdata = cast(torch.Tensor, self.weight_qdata)
265+
weight_shape = weight_qdata.shape
266+
in_channels = _get_num_input_channels(weight_qdata, self.weight_dtype)
267+
repr_parts = [
268+
f"in_channels={in_channels}",
269+
f"out_channels={weight_shape[0]}",
270+
f"kernel_size={(weight_shape[1], weight_shape[2])}",
271+
f"stride={self.stride}",
272+
f"padding={self.padding}",
273+
f"dilation={self.dilation}",
274+
f"groups={self.groups}",
275+
f"bias={self.bias is not None}",
276+
f"weight_dtype={self.weight_dtype}",
277+
f"block_size={self.block_size}",
278+
]
279+
return ", ".join(repr_parts)
280+
260281

261282
def transform_conv2d_to_mxfp(
262283
module: torch.nn.Module,

backends/arm/ao_ext/ops/mxfp_linear_op.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@
1010
1111
"""
1212

13+
from typing import cast
14+
1315
import torch
1416
import torch.nn.functional as F
17+
1518
from executorch.backends.arm.ao_ext.mxfp import (
1619
_cast_to_block_scaled_cpu_ref,
1720
mxfp_dtype_to_str,
@@ -179,6 +182,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
179182
output = output.to(self.output_dtype)
180183
return output
181184

185+
def extra_repr(self) -> str:
186+
weight_qdata = cast(torch.Tensor, self.weight_qdata)
187+
weight_shape = weight_qdata.shape
188+
in_features = _get_num_input_features(weight_qdata, self.weight_dtype)
189+
repr_parts = [
190+
f"in_features={in_features}",
191+
f"out_features={weight_shape[1]}",
192+
f"bias={self.bias is not None}",
193+
f"weight_dtype={self.weight_dtype}",
194+
f"block_size={self.block_size}",
195+
]
196+
return ", ".join(repr_parts)
197+
182198

183199
def transform_linear_to_mxfp(
184200
module: torch.nn.Module,

0 commit comments

Comments
 (0)