QUANTIZE_CONFIG_FILENAME = "quantize_config.json"

+def unpack(x, dim, bits=4):
+    return unpack_row(x, bits) if dim == 0 else unpack_col(x, bits)
+
+def unpack_col(x, bits):
+    mask = 2 ** bits - 1
+    pack_size = 32 // bits
+    unpacked_x = torch.zeros((x.shape[0], x.shape[1] * pack_size), dtype=torch.int)
+    for i in range(pack_size):
+        unpacked_x[:, i::pack_size] = (x >> (i * bits)) & mask
+    return unpacked_x
+
+def unpack_row(x, bits):
+    mask = 2 ** bits - 1
+    pack_size = 32 // bits
+    unpacked_x = torch.zeros((x.shape[0] * pack_size, x.shape[1]), dtype=torch.int)
+    for i in range(pack_size):
+        unpacked_x[i::pack_size] = (x >> (i * bits)) & mask
+    return unpacked_x
+
+
+def pack(x, dim, bits=4):
+    return pack_row(x, bits) if dim == 0 else pack_col(x, bits)
+
+def pack_col(x, bits):
+    mask = 2 ** bits - 1
+    pack_size = 32 // bits
+    packed_x = torch.zeros((x.shape[0], x.shape[1] // pack_size), dtype=torch.int)
+    for i in range(pack_size):
+        packed_x |= (x[:, i::pack_size] & mask) << (i * bits)
+    return packed_x
+
+def pack_row(x, bits):
+    mask = 2 ** bits - 1
+    pack_size = 32 // bits
+    packed_x = torch.zeros((x.shape[0] // pack_size, x.shape[1]), dtype=torch.int)
+    for i in range(pack_size):
+        packed_x |= (x[i::pack_size] & mask) << (i * bits)
+    return packed_x
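As a quick sanity check of the packing scheme above, the two operations are exact inverses; a minimal sketch (assuming `torch` is available and that the dimension being packed is a multiple of `32 // bits`):

import torch

# Random 4-bit values; both dims are multiples of pack_size = 32 // 4 = 8.
x = torch.randint(0, 16, (16, 8), dtype=torch.int)

# Packing then unpacking along either axis reproduces the input exactly.
assert torch.equal(unpack(pack(x, dim=0), dim=0), x)
assert torch.equal(unpack(pack(x, dim=1), dim=1), x)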

class Weights:
    def __init__(
@@ -101,7 +139,7 @@ def get_partial_sharded(self, tensor_name: str, dim: int):
        tensor = tensor.to(device=self.device)
        return tensor

-    def get_sharded(self, tensor_name: str, dim: int):
+    def get_sharded(self, tensor_name: str, dim: int, perm=None, packed=False):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
@@ -110,17 +148,53 @@ def get_sharded(self, tensor_name: str, dim: int):
        assert (
            size % world_size == 0
        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
-        return self.get_partial_sharded(tensor_name, dim)
+        if perm is None:
+            return self.get_partial_sharded(tensor_name, dim)
+        else:
+            return self.get_shuffle_sharded(tensor_name, dim, perm, packed)
+
+    def get_shuffle_sharded(self, tensor_name: str, dim: int, perm, packed: bool):
+        filename, tensor_name = self.get_filename(tensor_name)
+        world_size = self.process_group.size()
+        rank = self.process_group.rank()

-    def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
+        f = self._get_handle(filename)
+        tensor = f.get_tensor(tensor_name)
+        perm = perm.to(device=tensor.device)
+        size = tensor.shape[dim]
+        block_size = size // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+
+        # TODO: pack-unpack on cuda to speed up this part
+        if dim == 0:
+            if packed:
+                tensor = pack(unpack(tensor, dim)[perm], dim)[start:stop]
+            else:
+                tensor = tensor[perm][start:stop]
+        elif dim == 1:
+            if packed:
+                tensor = pack(unpack(tensor, dim)[:, perm], dim)[:, start:stop]
+            else:
+                tensor = tensor[:, perm][:, start:stop]
+        else:
+            raise NotImplementedError("Let's make that generic when needed")
+        # Special case for gptq which shouldn't convert
+        # u4 which are disguised as int32
+        if tensor.dtype != torch.int32:
+            tensor = tensor.to(dtype=self.dtype)
+        tensor = tensor.to(device=self.device)
+        return tensor
+
+    def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int, col_perm=None):
        if quantize == "gptq":
            try:
-                qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1)
+                qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1, perm=col_perm, packed=False) for p in prefixes], dim=1)
            except RuntimeError:
                raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")

-            qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1)
-            scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1)
+            qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1, perm=col_perm, packed=True) for p in prefixes], dim=1)
+            scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1, perm=col_perm, packed=False) for p in prefixes], dim=1)
            w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
            for w2 in w[1:]:
                torch.testing.assert_close(w2, w[0])
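To illustrate what `get_shuffle_sharded` does to a packed tensor on a single rank, here is a small sketch built on the pack/unpack helpers above; the shapes, the permutation, and the `world_size`/`rank` values are made up for illustration only:

import torch

world_size, rank = 2, 0
x = torch.randint(0, 16, (16, 4), dtype=torch.int)  # unpacked 4-bit values
perm = torch.randperm(x.shape[0])                    # row permutation supplied by the caller

packed = pack(x, dim=0)                              # int32 words, as stored on disk
block_size = packed.shape[0] // world_size
start, stop = rank * block_size, (rank + 1) * block_size

# Unpack, permute the logical rows, repack, then take this rank's block.
shard = pack(unpack(packed, 0)[perm], 0)[start:stop]

# Equivalent to permuting the unpacked rows first, repacking, then slicing.
expected = pack(x[perm], dim=0)[start:stop]
assert torch.equal(shard, expected)

# Note: permuting the packed rows directly would reorder whole 8-row groups,
# not individual 4-bit rows, which is why the unpack/pack round trip is needed.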
@@ -141,39 +215,36 @@ def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
            weight = torch.cat(w, dim=dim)
        return weight

-    def get_multi_weights_row(self, prefix: str, quantize: str):
+    def get_multi_weights_row(self, prefix: str, quantize: str, row_perm=None, noshard=False):
        if quantize == "gptq":
            bits, groupsize = self._get_gptq_params()

-            use_exllama = bits == 4
-
-            if self.process_group.size() > 1:
-                g_idx = self.get_tensor(f"{prefix}.g_idx")
-                if g_idx is not None:
-                    if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
-                        # Exllama implementation does not support row tensor parallelism with act-order, as
-                        # it would require to reorder input activations that are split unto several GPUs
-                        use_exllama = False
+            from text_generation_server.utils.layers import HAS_EXLLAMA
+            is_preshuffle = row_perm is not None
+            is_masked_matmul = noshard
+            assert (is_preshuffle != is_masked_matmul
+                    or not (is_preshuffle or is_masked_matmul)), f"TP-aware optimizations can't both be enabled at the same time: {is_preshuffle=}, {is_masked_matmul=}"
+            use_exllama = ((bits == 4) and HAS_EXLLAMA) or is_preshuffle or is_masked_matmul
+            if self.process_group.rank == 0:
+                if use_exllama:
+                    logger.info(f"Using exllama kernels for row {prefix}")
+                else:
+                    logger.warning(
+                        "Exllama GPTQ cuda kernels (which are faster) could have been used, but are disabled via the DISABLE_EXLLAMA env var,"
+                        " or not currently installed, try using BUILD_EXTENSIONS=True"
+                    )

            try:
-                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
+                qweight = self.get_sharded(f"{prefix}.qweight",
+                                           dim=0,
+                                           perm=row_perm if use_exllama else None,
+                                           packed=True,
+                ) if not is_masked_matmul else self.get_tensor(f"{prefix}.qweight")
            except RuntimeError:
                raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")

-            from text_generation_server.utils.layers import HAS_EXLLAMA
-            if use_exllama:
-                use_exllama = HAS_EXLLAMA
-            if self.process_group.rank == 0:
-                if use_exllama:
-                    logger.info(f"Using exllama kernels for row {prefix}")
-                else:
-                    logger.warning(
-                        "Exllama GPTQ cuda kernels (which are faster) could have been used, but are disabled via the DISABLE_EXLLAMA env var,"
-                        " or not currently installed, try using BUILD_EXTENSIONS=True"
-                    )
-
            if use_exllama:
-                if groupsize >= 0:
+                if groupsize >= 0 and not is_masked_matmul:
                    # Exllama reorders the weights in advance and the activations on the fly, thus
                    # the scales and zero-points do not need to be reordered.
                    qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
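For context, a hedged sketch of how a caller might select between the two TP-aware modes gated above; `weights`, `prefix`, and the way the row permutation is derived (argsort of `g_idx`) are assumptions for illustration, not part of this diff:

import torch

# Mode 1 (pre-shuffle): have the loader permute the quantized rows with a
# caller-supplied order (here, speculatively, the act-order sort of g_idx).
row_perm = torch.argsort(weights.get_tensor(f"{prefix}.g_idx"))
w = weights.get_multi_weights_row(prefix, quantize="gptq", row_perm=row_perm)

# Mode 2 (masked matmul): skip row sharding and load the full weight on every
# rank; the assertion above forbids combining it with pre-shuffling.
w = weights.get_multi_weights_row(prefix, quantize="gptq", noshard=True)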
@@ -183,7 +254,7 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
                    scales = self.get_tensor(f"{prefix}.scales")

                # For tp > 1, at this point we know we do not use act-order
-                if self.process_group.size() == 1:
+                if (self.process_group.size() == 1 or is_masked_matmul) and not is_preshuffle:
                    g_idx = self.get_tensor(f"{prefix}.g_idx")
                else:
                    g_idx = None
@@ -197,7 +268,7 @@ def get_multi_weights_row(self, prefix: str, quantize: str):

            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
        else:
-            weight = self.get_sharded(f"{prefix}.weight", dim=1)
+            weight = self.get_sharded(f"{prefix}.weight", dim=1) if not noshard else self.get_tensor(f"{prefix}.weight")
        return weight

    def _get_gptq_params(self) -> Tuple[int, int]: