File tree Expand file tree Collapse file tree 2 files changed +45
-4
lines changed Expand file tree Collapse file tree 2 files changed +45
-4
lines changed Original file line number Diff line number Diff line change
1
+ import pytest
2
+ import torch
3
+
4
+ from torchao .float8 .float8_utils import _round_down_to_power_of_2
5
+
6
+
7
+ @pytest .mark .parametrize (
8
+ "input_shape" ,
9
+ [
10
+ (1 ,),
11
+ (2 , 3 ),
12
+ (8 , 2048 , 4 , 1024 ),
13
+ ],
14
+ )
15
+ @pytest .mark .parametrize (
16
+ "multiplier" ,
17
+ [
18
+ 1.0 ,
19
+ 2.5 ,
20
+ 10.0 ,
21
+ ],
22
+ )
23
+ def test_round_down_to_power_of_2 (input_shape : tuple [int ], multiplier : int ):
24
+ input_tensor = torch .rand (* input_shape , dtype = torch .float32 ) * multiplier
25
+ expected_output = torch .exp2 (torch .floor (torch .log2 (input_tensor )))
26
+ result = _round_down_to_power_of_2 (input_tensor )
27
+ assert torch .equal (
28
+ result , expected_output
29
+ ), f"expected { expected_output } , but got { result } "
30
+
31
+
32
+ def test_non_float32_input ():
33
+ non_float32_tensor = torch .tensor ([3.0 ], dtype = torch .float64 )
34
+ with pytest .raises (AssertionError , match = "input must be float32 tensor" ):
35
+ _round_down_to_power_of_2 (non_float32_tensor )
Original file line number Diff line number Diff line change @@ -49,10 +49,7 @@ def amax_to_scale(
49
49
else :
50
50
raise ValueError (f"Unsupported float8_dtype: { float8_dtype } " )
51
51
if round_scales_to_power_of_2 :
52
- # rounds down to the nearest power of 2
53
- res = res .view (torch .int32 )
54
- res = (res >> 23 ) << 23
55
- res = res .view (torch .float32 )
52
+ res = _round_down_to_power_of_2 (res )
56
53
return res
57
54
58
55
@@ -286,3 +283,12 @@ def config_has_stateful_scaling(config: Float8LinearConfig) -> bool:
286
283
or config .cast_config_weight .scaling_type != ScalingType .DYNAMIC
287
284
or config .cast_config_grad_output .scaling_type != ScalingType .DYNAMIC
288
285
)
286
+
287
+
288
+ def _round_down_to_power_of_2 (x : torch .Tensor ) -> torch .Tensor :
289
+ assert x .dtype == torch .float32 , "input must be float32 tensor"
290
+ # rounds down to the nearest power of 2
291
+ x = x .view (torch .int32 )
292
+ x = (x >> 23 ) << 23
293
+ x = x .view (torch .float32 )
294
+ return x
You can’t perform that action at this time.
0 commit comments