Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit f8f935c

Browse files
committed
need to handle broadcasted grads
1 parent 1efdd9b commit f8f935c

File tree

3 files changed

+4
-1
lines changed

3 files changed

+4
-1
lines changed

benchmarks/bench_dynamic_linear_fused_cast.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,8 @@ def main(
116116
ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward()
117117

118118
def float8_forw_backward():
119-
linear_float8(input_tensor).sum().backward()
119+
out = linear_float8(input_tensor)
120+
out.sum().backward()
120121

121122
def n_times(n, fn, *args, **kwargs):
122123
def wrapper(*args, **kwargs):

float8_experimental/float8_tensor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def to_fp8_no_autograd(
4848
from driss_torch import saturated_cast
4949

5050
bits_fp8 = saturated_cast(x, x_scale, float8_dtype)
51+
5152
else:
5253
x_scaled = x * x_scale
5354
bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype)

float8_experimental/fused_kernels/fused_casting_kernels.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def maximum(a, b):
2828
mask |= a != a
2929
return tl.where(mask, a, b)
3030

31+
3132
@triton.jit
3233
def abs_max_kernel(
3334
x_ptr,

0 commit comments

Comments
 (0)