
Commit 3932fab

my abs_max is busted

1 parent f430039

File tree: 3 files changed (+35 -20)

float8_experimental/float8_tensor.py (+7 -2)

```diff
@@ -47,8 +47,13 @@ def to_fp8_no_autograd(
     ):
         from driss_torch import saturated_cast

-        bits_fp8 = saturated_cast(x, x_scale, float8_dtype)
-
+        if x.dim() in {3, 4}:
+            prev_x_shape = x.shape
+            x = x.view(-1, x.size(-1))
+            bits_fp8 = saturated_cast(x, x_scale, float8_dtype)
+            bits_fp8 = bits_fp8.view(prev_x_shape)
+        else:
+            bits_fp8 = saturated_cast(x, x_scale, float8_dtype)
     else:
         x_scaled = x * x_scale
         bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype)
```
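
The new branch collapses 3D and 4D inputs to 2D before the fused cast and restores the shape afterwards, which suggests `saturated_cast` only accepts 2D layouts. A minimal sketch of the same flatten/cast/restore pattern, with a plain PyTorch cast standing in for `saturated_cast` (assumptions: the stand-in omits the saturation clamp the real kernel applies, and `torch.float8_e4m3fn` is available in your PyTorch build):

```python
import torch

def cast_with_shape_handling(x: torch.Tensor, x_scale: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for driss_torch.saturated_cast: scale, then cast.
    # A real saturated cast would clamp to the float8 finite range first.
    def fake_cast(t: torch.Tensor) -> torch.Tensor:
        return (t * x_scale).to(torch.float8_e4m3fn)

    if x.dim() in {3, 4}:
        prev_x_shape = x.shape
        out = fake_cast(x.view(-1, x.size(-1)))  # collapse leading dims (x assumed contiguous)
        return out.view(prev_x_shape)            # restore the original shape
    return fake_cast(x)
```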

float8_experimental/float8_utils.py (+7 -2)

```diff
@@ -70,11 +70,16 @@ def amax_history_to_scale_stack(

 @torch.no_grad()
 def tensor_to_amax(x, distributed_reduction=False):
-    if float8_experimental.config.use_fused_cast and x.is_cuda:
+    if False and float8_experimental.config.use_fused_cast and x.is_cuda:
         from float8_experimental.fused_kernels.fused_casting_kernels import abs_max
+
         amax = abs_max(x)
+        diff = abs_max(x) - x.abs().max().to(torch.float32)
+        assert (
+            diff.item() == 0
+        ), f"Expected {amax} to be equal to {x.abs().max().to(torch.float32)} but got {diff}"
     else:
-        amax = x.abs().max()
+        amax = x.abs().max().to(torch.float32)

     # If the user asked for distributed reduction, do it.
     # If the user did not ask for it, assume that it will
```
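
The `if False and ...` guard switches the fused path off entirely, which matches the commit message: the Triton `abs_max` is returning wrong values, so every call now takes the eager branch (the assert is unreachable while the guard is in place, but documents the intended parity check). The eager branch also gains a `.to(torch.float32)` so both paths return the same dtype. A standalone version of that parity check, under a hypothetical name, that could be run while debugging the kernel:

```python
import torch
from float8_experimental.fused_kernels.fused_casting_kernels import abs_max

def check_fused_amax(x: torch.Tensor) -> None:
    # Eager reference: abs-max, promoted to float32 like the fused path.
    ref = x.abs().max().to(torch.float32)
    got = abs_max(x)
    diff = got - ref
    # max is order-independent, so exact equality (not just closeness)
    # is a reasonable expectation here, unlike for a sum reduction.
    assert diff.item() == 0, f"Expected {got} to equal {ref}, got diff {diff}"

# e.g. check_fused_amax(torch.randn(4, 2048, 4096, device="cuda"))
```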

float8_experimental/fused_kernels/fused_casting_kernels.py (+21 -16)

```diff
@@ -48,7 +48,7 @@ def abs_max_kernel(
     r_mask = r_index < r_numel
     values = tl.load(
         x_ptr + (r_index + (r_numel * x_index)),
-        x_mask & r_mask,
+        r_mask,
         eviction_policy="evict_last",
         other=0.0,
     ).to(tl.float32)
@@ -62,21 +62,26 @@ def abs_max_kernel(

 def abs_max(x: torch.Tensor) -> torch.Tensor:
     "Calculates the global max of the absolute values of a tensor"
-    output = torch.empty((512, 1), device=x.device, dtype=torch.float32)
-    n_elements = x.numel()
-    grid = lambda meta: (meta["X_BLOCK_SIZE"],)
-    X_BLOCK_SIZE = 1
-    R_BLOCK_SIZE = 1024
-    r_numel = n_elements // 512
-    abs_max_kernel[grid](
-        x,
-        output,
-        x_numel=512,
-        r_numel=r_numel,
-        X_BLOCK_SIZE=X_BLOCK_SIZE,
-        R_BLOCK_SIZE=R_BLOCK_SIZE,
-    )
-    return output.max()
+    x = x.contiguous()
+    if x.numel() % 512 == 0:
+        output = torch.full(
+            (512, 1), -float("inf"), device=x.device, dtype=torch.float32
+        )
+        grid = lambda meta: (meta["X_BLOCK_SIZE"],)
+        X_BLOCK_SIZE = 1
+        R_BLOCK_SIZE = 1024
+        r_numel = x.numel() // 512
+        abs_max_kernel[grid](
+            x,
+            output,
+            x_numel=512,
+            r_numel=r_numel,
+            X_BLOCK_SIZE=X_BLOCK_SIZE,
+            R_BLOCK_SIZE=R_BLOCK_SIZE,
+        )
+        return output.max()
+    else:
+        return x.abs().max().to(torch.float32)


 @triton.jit
```
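
Two details of the rewrite stand out. First, the output buffer is now created with `torch.full(..., -float("inf"))` instead of `torch.empty`: an uninitialized buffer can hold stale values larger than the true maximum, which would survive the final `output.max()` and is a plausible source of the busted results. Second, the fixed 512-way split only covers tensors whose element count divides evenly by 512, so the wrapper now falls back to the eager path otherwise. A hedged eager sketch of the reduction schedule the kernel implements (hypothetical name `abs_max_reference`):

```python
import torch

def abs_max_reference(x: torch.Tensor) -> torch.Tensor:
    # Mirrors the fused kernel's schedule: 512 program instances each reduce a
    # contiguous chunk of r_numel elements (matching the row-major indexing
    # x_ptr + r_index + r_numel * x_index), then the partials are reduced again.
    x = x.contiguous()
    if x.numel() % 512 == 0:
        r_numel = x.numel() // 512
        partials = x.view(512, r_numel).abs().amax(dim=1)  # one partial per "program"
        return partials.max().to(torch.float32)
    # Same fallback the commit adds for sizes the 512-way split cannot cover.
    return x.abs().max().to(torch.float32)
```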
