Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 4afcd0b

Browse files
committed
fix kernel args
1 parent 3932fab commit 4afcd0b

File tree

3 files changed

+13
-9
lines changed

3 files changed

+13
-9
lines changed

float8_experimental/float8_tensor.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ def to_fp8_no_autograd(
4949

5050
if x.dim() in {3, 4}:
5151
prev_x_shape = x.shape
52-
x = x.view(-1, x.size(-1))
52+
x = x.reshape(-1, x.size(-1))
5353
bits_fp8 = saturated_cast(x, x_scale, float8_dtype)
54-
bits_fp8 = bits_fp8.view(prev_x_shape)
54+
bits_fp8 = bits_fp8.reshape(prev_x_shape)
5555
else:
5656
bits_fp8 = saturated_cast(x, x_scale, float8_dtype)
5757
else:

float8_experimental/float8_utils.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -70,14 +70,10 @@ def amax_history_to_scale_stack(
7070

7171
@torch.no_grad()
7272
def tensor_to_amax(x, distributed_reduction=False):
73-
if False and float8_experimental.config.use_fused_cast and x.is_cuda:
73+
if float8_experimental.config.use_fused_cast and x.is_cuda:
7474
from float8_experimental.fused_kernels.fused_casting_kernels import abs_max
7575

7676
amax = abs_max(x)
77-
diff = abs_max(x) - x.abs().max().to(torch.float32)
78-
assert (
79-
diff.item() == 0
80-
), f"Expected {amax} to be equal to {x.abs().max().to(torch.float32)} but got {diff}"
8177
else:
8278
amax = x.abs().max().to(torch.float32)
8379

float8_experimental/fused_kernels/fused_casting_kernels.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,21 @@ def abs_max_kernel(
6161

6262

6363
def abs_max(x: torch.Tensor) -> torch.Tensor:
64-
"Calculates the global max of the absolute values of a tensor"
64+
"""Calculates the global max of the absolute values of a tensor
65+
66+
This kernel launches a grid of 512 threads, each thread calculates the
67+
maximum of x.numel // 512 elements. The results are then reduced to a single
68+
value in a follow up kernel.
69+
70+
Args:
71+
x: Input tensor to calculate the abs_max for
72+
"""
6573
x = x.contiguous()
6674
if x.numel() % 512 == 0:
6775
output = torch.full(
6876
(512, 1), -float("inf"), device=x.device, dtype=torch.float32
6977
)
70-
grid = lambda meta: (meta["X_BLOCK_SIZE"],)
78+
grid = lambda meta: (512,)
7179
X_BLOCK_SIZE = 1
7280
R_BLOCK_SIZE = 1024
7381
r_numel = x.numel() // 512

0 commit comments

Comments
 (0)