Commit 262b180

fix autocast handling for float8 training rowwise recipes (#2587)
Summary:

Breakage reported by a customer; fixing and adding a test.

Two unrelated changes:
1. Delete a duplicate autocast test (it tests the same thing as the one I'm changing).
2. Modify the `Float8TrainingTensor` repr to print `lp_dtype` instead of `dtype`, since logically it is printing the dtype of the low precision data.

Test Plan:

```bash
pytest test/float8/test_base.py -k test_autocast_outputs -s -x
```

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 9f4ee3e commit 262b180
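For context, a minimal sketch of the setup this commit fixes: a module converted with a rowwise float8 recipe and run under `torch.autocast`. The class and function names are taken from the test diff below; the `torchao.float8` import path and exact public API surface are assumptions.

```python
# Hedged sketch of the scenario exercised by the new test: a float8 rowwise
# recipe under torch.autocast. Requires a CUDA device; names follow the diff,
# the import path from torchao.float8 is an assumption.
import copy

import torch
import torch.nn as nn

from torchao.float8 import (
    Float8LinearConfig,
    Float8LinearRecipeName,
    convert_to_float8_training,
)

m_ref = nn.Sequential(
    nn.Linear(32, 32, device="cuda", dtype=torch.float32),
    nn.Linear(32, 32, device="cuda", dtype=torch.float32),
)
config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE)
m = convert_to_float8_training(copy.deepcopy(m_ref), config=config)

x = torch.randn(16, 32, device="cuda")
with torch.autocast("cuda", dtype=torch.bfloat16):
    y = m(x)  # previously broke for rowwise recipes; see float8_ops.py below
assert y.dtype == torch.bfloat16
```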

3 files changed: +15 −36 lines changed

test/float8/test_base.py

Lines changed: 13 additions & 34 deletions
```diff
@@ -410,51 +410,30 @@ def test_linear_from_recipe(
     @pytest.mark.parametrize(
         "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
+    @pytest.mark.parametrize(
+        "recipe_name",
+        [
+            Float8LinearRecipeName.TENSORWISE,
+            Float8LinearRecipeName.ROWWISE,
+            Float8LinearRecipeName.ROWWISE_WITH_GW_HP,
+        ],
+    )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_autocast_outputs(
         self,
         emulate: bool,
         linear_dtype: torch.dtype,
+        recipe_name: Float8LinearRecipeName,
     ):
         m_ref = nn.Sequential(
             nn.Linear(32, 32, device="cuda", dtype=linear_dtype),
             nn.Linear(32, 32, device="cuda", dtype=linear_dtype),
         )
-        config = Float8LinearConfig(
-            emulate=emulate,
-        )
-        m = convert_to_float8_training(copy.deepcopy(m_ref), config=config)
-
-        # autocast off
-        x = torch.randn(16, 32, device="cuda", dtype=linear_dtype)
-        y = m(x)
-        assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}"
-
-        # autocast on
-        with torch.autocast("cuda"):
-            y = m(x)
-        assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}"
-
-        with torch.autocast("cuda", dtype=torch.bfloat16):
-            y = m(x)
-        assert y.dtype == torch.bfloat16, (
-            f"y.dtype is {y.dtype}, expected {torch.bfloat16}"
-        )
-
-    @pytest.mark.parametrize(
-        "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
-    )
-    @pytest.mark.parametrize(
-        "emulate", [True, False] if is_sm_at_least_89() else [True]
-    )
-    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool):
-        m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
-        config = Float8LinearConfig(emulate=emulate)
-        m = Float8Linear.from_float(copy.deepcopy(m), config)
+        config = Float8LinearConfig.from_recipe_name(recipe_name)
+        # work around config being frozen
+        object.__setattr__(config, "emulate", emulate)
 
-        # Cast the module to dtype
-        m = m.to(dtype=linear_dtype)
+        m = convert_to_float8_training(copy.deepcopy(m_ref), config=config)
 
         # autocast off
         x = torch.randn(16, 32, device="cuda", dtype=linear_dtype)
```

torchao/float8/float8_ops.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -444,7 +444,6 @@ def autocast_to_copy(aten_op, args, kwargs=None):
     when the input is a Float8TrainingTensor, presenting as a fp32
     tensor.
     """
-    _assert_tensorwise_scale(aten_op, args[0]._scale)
     assert isinstance(args[0], Float8TrainingTensor)
     assert len(kwargs) == 1 and "dtype" in kwargs, (
         "Only support dtype kwarg for autocast"
@@ -459,6 +458,7 @@ def autocast_to_copy(aten_op, args, kwargs=None):
         kwargs["dtype"],
         args[0]._linear_mm_config,
         args[0]._gemm_input_role,
+        args[0]._axiswise_dim,
     )
 
 
```

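The one-line addition above forwards `args[0]._axiswise_dim` into the rebuilt tensor (and drops the tensorwise-only assertion), so rowwise-scaled tensors survive the cast that autocast inserts. Below is a standalone sketch of that failure mode, using a hypothetical wrapper class rather than torchao's real tensor subclass.

```python
# Standalone illustration (hypothetical class, not torchao code) of the bug class
# fixed here: when autocast rebuilds a scaled tensor with a new presented dtype,
# every piece of scaling metadata must be carried over, including the axiswise
# dim that rowwise recipes rely on.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ScaledTensorSketch:
    data: torch.Tensor             # low-precision payload
    scale: torch.Tensor            # tensorwise scale or per-row scales
    orig_dtype: torch.dtype        # dtype this tensor presents to autocast
    axiswise_dim: Optional[int] = None  # None => tensorwise, int => rowwise


def autocast_cast(t: ScaledTensorSketch, dtype: torch.dtype) -> ScaledTensorSketch:
    # Mirrors the shape of the patched handler: rebuild the wrapper with the new
    # presented dtype while forwarding all scaling metadata, axiswise_dim included
    # (forgetting it is exactly the kind of breakage fixed above).
    return ScaledTensorSketch(t.data, t.scale, dtype, t.axiswise_dim)


rowwise = ScaledTensorSketch(
    data=torch.randint(0, 255, (4, 8), dtype=torch.uint8),
    scale=torch.rand(4, 1),
    orig_dtype=torch.float32,
    axiswise_dim=-1,
)
casted = autocast_cast(rowwise, torch.bfloat16)
assert casted.axiswise_dim == -1  # rowwise metadata survives the dtype cast
```
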
torchao/float8/float8_training_tensor.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -319,7 +319,7 @@ def __new__(
         return self
 
     def __repr__(self):
-        return f"Float8TrainingTensor(dtype={self._data.dtype}, scale={self._scale}, linear_mm_config={self._linear_mm_config}, axiswise_dim={self._axiswise_dim}\ngemm_input_role={self._gemm_input_role}\nas_orig_prec={self.to_original_precision()}"
+        return f"Float8TrainingTensor(lp_dtype={self._data.dtype}, scale={self._scale}, linear_mm_config={self._linear_mm_config}, axiswise_dim={self._axiswise_dim}\ngemm_input_role={self._gemm_input_role}\nas_orig_prec={self.to_original_precision()}"
 
     def __tensor_flatten__(self):
         ctx = {
```
