This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 24f09e4

Update on "[2/x] clean up casting functions: delayed scaling"
Summary:

Removes delayed scaling from `float8_tensor.py`. After this PR, the invariant is that everything in `float8_tensor.py` requires the scale to be calculated elsewhere. This moves the codebase toward separation of concerns: calculating the scale (via various scaling strategies) is separate from creating an instance of `Float8Tensor`. Note that stateful delayed scaling is the reason we need this separation.

Test Plan:

```
./test/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
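To illustrate the invariant this commit establishes, here is a minimal, self-contained sketch (not code from this repo): the caller computes the scale via whatever scaling strategy it uses, and the casting path only consumes a tensor plus a precomputed scale. The helpers `tensor_to_amax` and `amax_to_scale` mirror utility names referenced in this codebase but are reimplemented here as assumptions.

```python
import torch

# float8 e4m3 representable max (448.0); used to derive scales
E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

def tensor_to_amax(t: torch.Tensor) -> torch.Tensor:
    # absolute maximum of the tensor, the input to scale calculation
    return t.abs().max()

def amax_to_scale(amax: torch.Tensor, float8_max: float) -> torch.Tensor:
    # map the observed dynamic range onto the float8 range,
    # guarding against division by zero
    return torch.tensor(float8_max) / torch.clamp(amax, min=1e-12)

def cast_to_float8(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # the cast only consumes a precomputed scale; it no longer
    # updates any amax buffer (that is the scaling strategy's job)
    return (t * scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)

x = torch.randn(16, 16)
scale = amax_to_scale(tensor_to_amax(x), E4M3_MAX)  # computed by the caller
x_fp8 = cast_to_float8(x, scale)
```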
1 parent 08f4052 commit 24f09e4

1 file changed: +0 −4 lines

float8_experimental/float8_tensor.py

Lines changed: 0 additions & 4 deletions
```diff
@@ -207,7 +207,6 @@ def forward(
         tensor: torch.Tensor,
         scale: torch.Tensor,
         float8_dtype=e4m3_dtype,
-        # amax_buffer: Optional[torch.Tensor] = None,
         linear_mm_config: Optional[LinearMMConfig] = None,
         gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
     ):
@@ -216,11 +215,8 @@ def forward(
             tensor: the tensor to convert
             scale: the scale to use to convert the tensor
             float8_dtype: the float8 dtype either, torch.float8_e4m3fn or torch.float8_e5m2fn
-            amax_buffer: an Optional buffer buffer to store the amax value in prior to conversion
             emulate: whether to emulate the matmuls in fp32
         """
-        # if amax_buffer is not None:
-        #     amax_buffer.fill_(tensor_to_amax(tensor))
 
         return to_fp8_no_autograd(
             tensor,
```
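The deleted lines show that the cast itself used to optionally record the amax into a buffer in place. After this PR, that bookkeeping belongs to the stateful scaling strategy, outside `float8_tensor.py`. Below is a hypothetical sketch of what such a delayed-scaling caller might look like; the `DelayedScalingState` class and its methods are illustrative assumptions, not this repo's API.

```python
import torch

class DelayedScalingState:
    # hypothetical stateful delayed-scaling strategy, for illustration only
    def __init__(self, history_len: int = 16, float8_max: float = 448.0):
        self.amax_history = torch.zeros(history_len)
        self.float8_max = float8_max

    def update_and_get_scale(self, t: torch.Tensor) -> torch.Tensor:
        # the caller, not the cast, records the amax observation ...
        self.amax_history = torch.roll(self.amax_history, 1)
        self.amax_history[0] = t.abs().max()
        # ... and derives the scale from the history (max over the window)
        amax = torch.clamp(self.amax_history.max(), min=1e-12)
        return self.float8_max / amax

state = DelayedScalingState()
x = torch.randn(16, 16)
scale = state.update_and_get_scale(x)  # stateful step, outside the cast
# the casting function now only receives (tensor, scale)
```

This keeps `Float8Tensor` construction free of scaling state, which is exactly the separation of concerns the summary describes.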
