From c07f1d9014c0decbbc98fa61bfaa67165d6c7ebe Mon Sep 17 00:00:00 2001
From: drisspg
Date: Mon, 22 Jan 2024 18:22:29 -0800
Subject: [PATCH] checkpoint to reduce memory usage, only do dynamic for now

---
 float8_experimental/float8_dynamic_linear.py | 34 ++++++++++++++++++--
 1 file changed, 31 insertions(+), 3 deletions(-)

diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py
index 58e352da..1af700b3 100644
--- a/float8_experimental/float8_dynamic_linear.py
+++ b/float8_experimental/float8_dynamic_linear.py
@@ -40,6 +40,27 @@ def backward(ctx, gradY):
         )
 
 
+def cast_weight_linear(
+    x_fp8: Float8Tensor, weight: torch.Tensor, scale: torch.Tensor, bias, emulate: bool
+) -> torch.Tensor:
+    """Cast the weight to fp8_e4m3fn and do the linear.
+    Why a new function for something that can be inlined?
+    Because we want to call torch.utils.checkpoint on this function.
+    We always want to recompute the cast of the weight to fp8 since we can
+    trivially fuse it into the transpose/contiguous of the weight during the backward pass.
+
+    Args:
+        x_fp8 (Float8Tensor): input activation in fp8
+        weight (torch.Tensor): weight tensor in higher precision
+        scale (torch.Tensor): scale tensor for the weight
+        bias: bias tensor in higher precision
+        emulate (bool): whether to emulate fp8 matmul logic in float32
+    """
+    w_fp8 = Float8Tensor.to_float8(weight, scale, torch.float8_e4m3fn, emulate=emulate)
+    y = torch.nn.functional.linear(x_fp8, w_fp8, bias)
+    return y
+
+
 class Float8DynamicLinear(torch.nn.Linear):
     """
     A wrapper around a `torch.nn.Linear` module which does fp8 compute. By on the fly
@@ -48,9 +69,16 @@ class Float8DynamicLinear(torch.nn.Linear):
 
     def forward(self, x):
         x_fp8 = self.cast_to_float8(x)
-        w_fp8 = self.cast_to_float8(self.weight)
-
-        y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
+        scale = tensor_to_scale(self.weight, torch.float8_e4m3fn)
+        y = torch.utils.checkpoint.checkpoint(
+            cast_weight_linear,
+            x_fp8,
+            self.weight,
+            scale,
+            self.bias,
+            self.emulate,
+            use_reentrant=False,
+        )
 
         # Cast gradY to float8_e5m2 during backward
         y = self.cast_to_float8e5m2_bw(y)
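
Note: the patch above wraps the weight cast in torch.utils.checkpoint.checkpoint
with use_reentrant=False, so the fp8 copy of the weight is recomputed during
backward instead of being kept alive by autograd. Below is a minimal,
self-contained sketch of that pattern using only stock PyTorch. The names
CheckpointedCastLinear and cast_weight_linear_sketch are illustrative (not part
of this repo), and a bfloat16 cast stands in for the repo-specific
Float8Tensor.to_float8 cast.

import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


def cast_weight_linear_sketch(x, weight, bias):
    # Stand-in for the fp8 cast: downcast the weight (and input) to bfloat16.
    # Under checkpointing this cast result is not saved for backward; the whole
    # function is re-run when gradients are computed.
    w_lp = weight.to(torch.bfloat16)
    b_lp = bias.to(torch.bfloat16) if bias is not None else None
    return F.linear(x.to(torch.bfloat16), w_lp, b_lp)


class CheckpointedCastLinear(torch.nn.Linear):
    def forward(self, x):
        # use_reentrant=False mirrors the non-reentrant checkpointing used in
        # the patch above.
        return checkpoint(
            cast_weight_linear_sketch,
            x,
            self.weight,
            self.bias,
            use_reentrant=False,
        )


if __name__ == "__main__":
    lin = CheckpointedCastLinear(16, 16)
    out = lin(torch.randn(4, 16, requires_grad=True))
    out.float().sum().backward()  # backward recomputes the cast + matmul
    assert lin.weight.grad is not None

The trade-off is the same as in the patch: a cheap re-cast of the weight in
backward in exchange for not storing the low-precision copy between forward and
backward.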