|
2 | 2 | import torch.nn as nn |
3 | 3 | from float8_experimental.config import ScalingType |
4 | 4 | from float8_experimental.float8_scaling_utils import ( |
5 | | - cast_to_float8_e4m3_dynamic, |
| 5 | + cast_to_float8_dynamic, |
6 | 6 | NoopFwToFloat8E5M2BwDynamic, |
7 | 7 | ) |
8 | 8 | from float8_experimental.float8_tensor import GemmInputRole |
| 9 | +from float8_experimental.float8_utils import e4m3_dtype |
9 | 10 | from torch.distributed._tensor import DTensor |
10 | 11 | from torch.distributed.device_mesh import DeviceMesh |
11 | 12 | from torch.distributed.tensor.parallel import ( |
@@ -45,8 +46,9 @@ def _prepare_input_fn( |
45 | 46 | input_tensor, device_mesh, input_layouts, run_check=False |
46 | 47 | ) |
47 | 48 |
|
48 | | - input_tensor = cast_to_float8_e4m3_dynamic( |
| 49 | + input_tensor = cast_to_float8_dynamic( |
49 | 50 | input_tensor, |
| 51 | + e4m3_dtype, |
50 | 52 | mod.linear_mm_config, |
51 | 53 | gemm_input_role=GemmInputRole.INPUT, |
52 | 54 | ) # DTensor(Float8Tensor) |
@@ -98,8 +100,9 @@ def _prepare_input_fn( |
98 | 100 | input_tensor, device_mesh, input_layouts, run_check=False |
99 | 101 | ) |
100 | 102 |
|
101 | | - input_tensor = cast_to_float8_e4m3_dynamic( |
| 103 | + input_tensor = cast_to_float8_dynamic( |
102 | 104 | input_tensor, |
| 105 | + e4m3_dtype, |
103 | 106 | mod.linear_mm_config, |
104 | 107 | gemm_input_role=GemmInputRole.INPUT, |
105 | 108 | ) # DTensor(Float8Tensor) |
@@ -196,8 +199,9 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout): |
196 | 199 | input, mesh, (input_layout,), run_check=False |
197 | 200 | ) |
198 | 201 |
|
199 | | - dt_inp = cast_to_float8_e4m3_dynamic( |
| 202 | + dt_inp = cast_to_float8_dynamic( |
200 | 203 | dt_inp, |
| 204 | + e4m3_dtype, |
201 | 205 | self.linear_mm_config, |
202 | 206 | gemm_input_role=GemmInputRole.INPUT, |
203 | 207 | ) # DTensor(Float8Tensor) |
|
0 commit comments