
Commit be3de0f

Add some code for evaluating FPx (not enabled)
1 parent d393bfe commit be3de0f

File tree

3 files changed: +253 −0 lines changed

exllamav2/experimental/fpx.py
exllamav2/generator/dynamic.py
exllamav2/linear.py


exllamav2/experimental/fpx.py

Lines changed: 249 additions & 0 deletions
@@ -0,0 +1,249 @@
import torch
from torch import Tensor
import gc

# From https://github.com/PygmalionAI/aphrodite-engine/blob/main/aphrodite/quantization/utils/fp6_utils.py

def _n_ones(n: int) -> int:
    return (1 << n) - 1

EBITS_F32, MBITS_F32 = 8, 23
F32_EXP_BIAS = _n_ones(EBITS_F32 - 1)

_ONES_TABLE = [_n_ones(i) for i in range(8)]
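# Worked example (illustration, not part of the original commit): _n_ones(n)
# returns 2 ** n - 1, so F32_EXP_BIAS = _n_ones(7) = 127. For FP6 e2m3
# (ebits = 2, mbits = 3) the exponent bias is _n_ones(1) = 1 and the largest
# normal magnitude is 2 ** (_n_ones(2) - 1) * (_n_ones(4) / 2 ** 3) = 7.5;
# for e3m2 the bias is _n_ones(2) = 3 and the maximum is 2 ** 4 * (7 / 4) = 28.0.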

def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
    """Convert FP32 numbers to sub-byte floating point numbers with the given
    number of exponent and mantissa bits.
    Input: torch.Tensor of dtype torch.float
    Output: torch.Tensor of dtype torch.uint8, where the bit encoding is stored
    in the least significant bits. e.g.
      fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding
      fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding
    Note: there are no special values (NaN, inf) support in this code. Values
    outside the representable range of FPx after rounding are clamped to the
    maximum FPx magnitude (sign is preserved).
    Code below is an adaptation of https://fburl.com/code/ciwofcg4
    Background 1: last answer in https://stackoverflow.com/questions/8981913/how-to-perform-round-to-even-with-floating-point-numbers  # noqa: E501
    Background 2: Computer Organization and Design, RISC-V edition, Chapter 3.5
    """
    assert x.dtype == torch.float
    assert 1 + ebits + mbits <= 8

    # calculate constants
    exp_bias = _n_ones(ebits - 1)
    max_int = _n_ones(ebits + mbits)
    sign_mask = 1 << (ebits + mbits)

    # TODO document this better
    magic_adder = _n_ones(MBITS_F32 - mbits - 1)

    # all E bits and M bits are 1s
    max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits))

    # E bits = 1, M bits = 0
    min_normal = 2 ** (1 - exp_bias)

    denorm_exp = (
        # exp bias conversion between formats
        (F32_EXP_BIAS - exp_bias)
        # mantissa length difference between formats
        + (MBITS_F32 - mbits)
        # add one to encoded exponent for denormalized numbers
        + 1
    )
    denorm_mask_int = denorm_exp << MBITS_F32

    # reinterpret int32 as float32
    denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(torch.float32)

    # save the sign
    # Note that we have torch.uint32, but some ops like cpu bit shifts
    # do not work on it. So, we stay in int32.
    x = x.view(torch.int32)
    sign = x & 0x80000000

    # set everything to positive, will add sign back at the end
    x = x ^ sign

    # TODO: can the branch floating point comparisons below be done without
    # converting to float? probably but need to verify
    x = x.view(torch.float)

    # rewrite saturate/denorm/norm branches without explicit data dependent
    # control flow, to be more compiler friendly
    saturate_mask = x >= max_normal
    denormal_mask = torch.logical_and(torch.logical_not(saturate_mask), x < min_normal)
    normal_mask = torch.logical_not(torch.logical_or(saturate_mask, denormal_mask))

    #
    # branch 1: saturate to max val - handled later in the code which combines
    # the branches
    #

    #
    # branch 2: to conversion to denormal as well as rounding up to normal
    #
    denormal_x = x + denorm_mask_float
    denormal_x = denormal_x.view(torch.int32)
    denormal_x -= denorm_mask_int
    denormal_x = denormal_x.to(torch.uint8)

    #
    # branch 3: stay in normal range, adjust the exponent and round
    #
    normal_x = x.view(torch.int32)
    # resulting mantissa is odd
    mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1
    # update exponent, rounding bias part 1
    val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder
    normal_x += val_to_add
    # rounding bias part 2
    normal_x += mant_odd
    # take the bits!
    normal_x = normal_x >> (MBITS_F32 - mbits)
    normal_x = normal_x.to(torch.uint8)

    #
    # combine the branches
    #
    x = torch.full_like(x, max_int, dtype=torch.uint8)
    x = torch.where(denormal_mask, denormal_x, x)
    x = torch.where(normal_mask, normal_x, x)

    # add sign back
    sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits)
    sign_lp = sign_lp.to(torch.uint8)
    # Right shift of a negative signed integer can fill the least significant
    # bits with either 1s or 0s, depending on the implementation. Since PyTorch
    # doesn't have an uint32 dtype, we mask out these bits to get just the
    # f4 sign bit
    sign_lp = sign_lp & sign_mask
    x = x | sign_lp

    return x.to(torch.uint8)

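# Worked example (illustration, not part of the original commit): encoding 1.0
# as fp6_e2m3 gives sign = 0, biased exponent = 0 + exp_bias = 1 and mantissa = 0,
# i.e. the bit pattern 0b001000 = 8; the maximum normal value 7.5 encodes as
# 0b011111 = 31, and anything larger saturates to that same code.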

def _fpx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
    """Convert sub-byte floating point numbers with the given number of exponent
    and mantissa bits to FP32.
    Input: torch.Tensor of dtype uint8, where the bit encoding is stored
    in the least significant bits. e.g.
      fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding
      fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding
    Output: torch.Tensor of dtype fp32 with the dequantized value
    """
    assert x.dtype == torch.uint8
    assert 1 + ebits + mbits <= 8

    sign_mask = 1 << (ebits + mbits)
    exp_bias = _n_ones(ebits - 1)
    mantissa_mask = _n_ones(mbits)

    # save the sign
    sign_lp = x & sign_mask

    # set everything to positive, will add sign back at the end
    x_pos = x ^ sign_lp

    #
    # 1. Calculate zero mask
    #
    zero_mask = x_pos == 0

    #
    # 2. Calculate the denormal path mask
    #
    denormal_mask = torch.logical_and((x_pos > 0), ((x_pos >> mbits) == 0))

    #
    # 3. Calculate the normal path
    #

    # calculate the new exponent and shift it to bits 2:9 of the result
    exp_biased_lp = x_pos >> mbits
    exp_biased_f32 = exp_biased_lp - exp_bias + F32_EXP_BIAS
    exp_biased_f32 = exp_biased_f32.to(torch.int32) << MBITS_F32

    # shift the mantissa to bits 10:32 of the result
    mantissa_lp_int32 = (x_pos & mantissa_mask).to(torch.int32)
    mantissa_f32 = mantissa_lp_int32 << (MBITS_F32 - mbits)
    result = exp_biased_f32 | mantissa_f32

    #
    # 4. Add the zero and denormal casts to the already casted normal path
    #
    result[zero_mask] = 0

    denormal_exp_biased = 1 - exp_bias + F32_EXP_BIAS

    # fast path.
    # without this, performance for FP4_E2M1 is slower by 2x
    if mbits == 1:
        result[denormal_mask] = (denormal_exp_biased - mbits) << MBITS_F32

    else:
        # iterate over all possible values of mantissa
        # i=0, j=1
        # i=1, j=10,11
        # i=2, j=100,101,110,111
        # and so on
        for i in range(mbits):
            for mantissa_cmp in range(1 << i, 1 << (i+1)):
                # left shift mantissa until it overflows (create an implicit 1)
                # subtract exponent by the same amount
                left_shift = mbits - i
                mantissa_f32 = (mantissa_cmp - (1 << i)) << (left_shift + MBITS_F32 - mbits)
                exp_biased_f32 = (denormal_exp_biased - left_shift) << MBITS_F32

                # we can update this in-place since the values won't overlap
                # torch.compile() may complain unsupported operand type(s) for |: 'SymInt' and 'int'
                # thus we use + instead of | here
                mantissa_lp_int32[mantissa_lp_int32 == mantissa_cmp] = exp_biased_f32 + mantissa_f32

        result = torch.where(denormal_mask, mantissa_lp_int32, result)

    # add sign back
    sign_f32 = sign_lp.to(torch.int32) << (MBITS_F32 - mbits + EBITS_F32 - ebits)
    result = result | sign_f32

    return result.view(torch.float)

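# Worked example (illustration, not part of the original commit): the e2m3 code
# 0b000011 has a zero exponent field and mantissa 3, so it takes the denormal
# path; the loop above renormalizes it to 1.5 * 2 ** -2 = 3/8 * 2 ** (1 - exp_bias) = 0.375.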

def to_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int) -> tuple[Tensor, Tensor]:
    # _n_ones() is not compatible with torch.compile() due to << operator
    # https://github.com/pytorch/pytorch/issues/119152
    # exp_bias = _n_ones(ebits - 1)
    # max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits))

    # workaround: global lookup table
    exp_bias = _ONES_TABLE[ebits - 1]
    max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (_ONES_TABLE[mbits + 1] / (2 ** mbits))

    tensor = tensor.float()
    scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal
    tensor_fpx = _f32_to_fpx_unpacked(tensor / scale.view(-1, 1), ebits, mbits)
    return tensor_fpx, scale.half()
    # tensor_tc_fpx = pack_tc_fpx(tensor_fpx, 1 + ebits + mbits)
    # return tensor_tc_fpx, scale.half()

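# Worked example (illustration, not part of the original commit): with
# ebits = 2, mbits = 3 the largest e2m3 magnitude is 7.5, so a weight row whose
# absolute maximum is 30.0 gets scale = 4.0; the row is divided by 4 before
# _f32_to_fpx_unpacked and multiplied back by the half-precision scale in
# from_scaled_tc_fpx.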

def from_scaled_tc_fpx(fpx_unpacked: Tensor, ebits: int, mbits: int, scale = None) -> Tensor:
    # fpx_unpacked = unpack_tc_fpx(tensor, 1 + ebits + mbits)
    tensor = _fpx_unpacked_to_f32(fpx_unpacked, ebits, mbits)
    if scale is not None:
        tensor = tensor * scale.float().view(-1, 1)
    return tensor


def fpxify(tensor: torch.Tensor, exponent: int, mantissa: int) -> torch.Tensor:
    """
    Convert to eXmY and back again
    """

    a = tensor.to("cuda:0").float()
    b, scale = to_scaled_tc_fpx(a, exponent, mantissa)
    c = from_scaled_tc_fpx(b, exponent, mantissa, scale)
    d = c.half().to(tensor.device)
    return d
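For reference, a minimal sketch (not part of the commit) of how the round-trip helpers above can be exercised on a 2D tensor; to_scaled_tc_fpx produces one scale per row, and from_scaled_tc_fpx reverses the mapping:

    import torch
    from exllamav2.experimental.fpx import to_scaled_tc_fpx, from_scaled_tc_fpx

    w = torch.randn(8, 16)                          # 2D tensor, one scale per row
    codes, scale = to_scaled_tc_fpx(w, 2, 3)        # e2m3 codes in the low 6 bits of each uint8
    print(codes.dtype, codes.shape, scale.dtype)    # torch.uint8, (8, 16), torch.float16
    w_rt = from_scaled_tc_fpx(codes, 2, 3, scale)   # dequantize back to float32
    print((w - w_rt).abs().max())                   # round-trip quantization error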

exllamav2/generator/dynamic.py

Lines changed: 1 addition & 0 deletions
@@ -1989,6 +1989,7 @@ def emit(
             return emit(results, emit_eos = True, eos_reason = "stop_token", stop_token = next_token.item())

         # Stop if we reach max_new_tokens
+        # TODO: Auto-extend option

         if self.new_tokens >= self.max_new_tokens - self.generator.num_draft_tokens:
             return emit(results, emit_eos = True, emit_held = True, eos_reason = "max_new_tokens")

exllamav2/linear.py

Lines changed: 3 additions & 0 deletions
@@ -9,6 +9,7 @@
 from exllamav2.tensor_p import BROADCAST_VC
 from exllamav2.util import unpack_4bit, pack_4bit
 import gc
+from exllamav2.experimental.fpx import fpxify

 from typing import TYPE_CHECKING

@@ -170,6 +171,7 @@ def load(

         elif isinstance(w, nn.Parameter):
             assert not self.has_bias, self.key + " has no bias tensor but bias is expected"
+            # w = nn.Parameter(fpxify(w.data, 2, 3), requires_grad = False)
             if self.normalize_unq:
                 w = self.normalize(w)
             if self.padding > 0: w = nn.Parameter(F.pad(w.data, (0, 0, 0, self.padding)).contiguous())
@@ -188,6 +190,7 @@ def load(
             if self.normalize_unq:
                 w = self.normalize(w[0]), w[1]
             ww = w[0]
+            # ww = nn.Parameter(fpxify(ww.data, 2, 3), requires_grad = False)
             wb = w[1]
             if self.padding > 0:
                 ww = nn.Parameter(F.pad(ww.data, (0, 0, 0, self.padding)).contiguous())
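
Enabling the evaluation path amounts to uncommenting the two fpxify(...) substitutions above. A minimal sketch (not part of the commit, and requiring a CUDA device since fpxify stages the tensor on cuda:0) of what that substitution does to an unquantized weight:

    import torch
    import torch.nn as nn
    from exllamav2.experimental.fpx import fpxify

    w = nn.Parameter(torch.randn(1024, 1024, dtype = torch.half), requires_grad = False)
    w_fp6 = nn.Parameter(fpxify(w.data, 2, 3), requires_grad = False)   # e2m3, as in the commented-out lines
    print((w.data.float() - w_fp6.data.float()).abs().mean())           # error introduced by the FP6 round trip

Other eXmY splits can be evaluated by changing the (exponent, mantissa) arguments, e.g. fpxify(w.data, 3, 2) for e3m2.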
