This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Checkpoint to reduce fp8_weight tensor saved for backwards #193

Open
wants to merge 1 commit into base: main

34 changes: 31 additions & 3 deletions float8_experimental/float8_dynamic_linear.py
@@ -40,6 +40,27 @@ def backward(ctx, gradY):
         )


+def cast_weight_linear(
+    x_fp8: Float8Tensor, weight: torch.Tensor, scale: torch.Tensor, bias, emulate: bool
+) -> torch.Tensor:
+    """Cast the weight to fp8_e4m3fn and apply a linear op.
+    Why a new function for something that can be inlined?
+    Because we want to call torch.utils.checkpoint on this function.
+    We always want to recompute the cast of the weight to fp8, since we can
+    trivially fuse it into the transpose/contiguous of the weight during the backward pass.
+
+    Args:
+        x_fp8 (Float8Tensor): input activation in fp8
+        weight (torch.Tensor): weight tensor in higher precision
+        scale (torch.Tensor): scale tensor for the weight
+        bias: bias tensor in higher precision
+        emulate (bool): whether to emulate fp8 matmul logic in float32
+    """
+    w_fp8 = Float8Tensor.to_float8(weight, scale, torch.float8_e4m3fn, emulate=emulate)
+    y = torch.nn.functional.linear(x_fp8, w_fp8, bias)
+    return y
+
+
 class Float8DynamicLinear(torch.nn.Linear):
     """
     A wrapper around a `torch.nn.Linear` module which does fp8 compute. By on the fly
@@ -48,9 +69,16 @@ class Float8DynamicLinear(torch.nn.Linear):

     def forward(self, x):
         x_fp8 = self.cast_to_float8(x)
-        w_fp8 = self.cast_to_float8(self.weight)
-
-        y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
+        scale = tensor_to_scale(self.weight, torch.float8_e4m3fn)
+        y = torch.utils.checkpoint.checkpoint(
+            cast_weight_linear,
+            x_fp8,
+            self.weight,
+            scale,
+            self.bias,
+            self.emulate,
+            use_reentrant=False,
+        )

         # Cast gradY to float8_e5m2 during backward
         y = self.cast_to_float8e5m2_bw(y)
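For reference, a minimal sketch (plain PyTorch, not code from this PR) of the mechanism the change relies on: `torch.utils.checkpoint` with `use_reentrant=False` saves only the inputs of the wrapped function and re-runs it during backward, so a low-precision weight copy created inside the function is never kept for the backward pass. The bfloat16 cast and the `cast_and_linear` helper below are illustrative stand-ins for the fp8 cast in `cast_weight_linear`.

import torch
import torch.utils.checkpoint as checkpoint


def cast_and_linear(x, weight, bias):
    # The low-precision copies are created inside the checkpointed function,
    # so they are recomputed during backward instead of being saved.
    w_lp = weight.to(torch.bfloat16)
    return torch.nn.functional.linear(
        x.to(torch.bfloat16), w_lp, bias.to(torch.bfloat16)
    )


x = torch.randn(16, 32, requires_grad=True)
w = torch.randn(64, 32, requires_grad=True)
b = torch.randn(64, requires_grad=True)

# Only the inputs (x, w, b) are saved for backward; cast_and_linear is re-run
# when gradients are computed, trading a small recompute for less saved memory.
y = checkpoint.checkpoint(cast_and_linear, x, w, b, use_reentrant=False)
y.sum().backward()

With the checkpointed forward in this PR, only the checkpoint inputs (the fp8 activation, the higher-precision weight the module already holds, the scale, and the bias) are kept for backward, so the fp8 copy of the weight no longer needs to be saved.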