
Commit 539375e: TP-aware optimizations draft
1 parent dacfe50

4 files changed: +210 -39 lines changed

server/text_generation_server/models/custom_modeling/flash_llama_modeling.py

Lines changed: 34 additions & 1 deletion

@@ -222,12 +222,22 @@ def __init__(
             weights=weights,
             bias=False,
         )
+
+        noshard_o_proj = False
+        if config.quantize == 'gptq':
+            from text_generation_server.utils.layers import IS_TP_AWARE_GPTQ
+            noshard_o_proj = IS_TP_AWARE_GPTQ
+
         self.o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
             weights=weights,
             bias=False,
+            noshard=noshard_o_proj,  # Don't shard o_proj weight matrix if TP-aware optimization is desired
         )
+        self.noshard_o_proj = noshard_o_proj
+        self.world_size = weights.process_group.size()
+        self.rank = weights.process_group.rank()
 
     def forward(
         self,
@@ -285,9 +295,19 @@ def forward(
                 1,
                 False,
             )
+        attn_output = attn_output.reshape(-1, self.num_heads * self.head_size)
 
-        return self.o_proj(attn_output.reshape(-1, self.num_heads * self.head_size))
+        # TP-aware masked-matmul optimization: zero-fill the activation and
+        # multiply with the full weight matrix in o_proj
+        if self.noshard_o_proj:
+            shard_size = attn_output.shape[1]
+            assert shard_size * self.world_size == self.o_proj.linear.height
+            zf_attn_output = torch.zeros((attn_output.shape[0], shard_size * self.world_size), dtype=attn_output.dtype, device=attn_output.device)
+            start_idx = self.rank * shard_size
+            zf_attn_output[:, start_idx:start_idx + shard_size] = attn_output
+            attn_output = zf_attn_output
 
+        return self.o_proj(attn_output)
 
 class LlamaMLP(nn.Module):
     def __init__(self, prefix, config, weights):
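Note on the masked-matmul path above: zero-filling this rank's activation shard and multiplying by the full, unsharded o_proj weight yields exactly the partial product that the usual row-parallel sharded matmul would produce, so the all-reduce inside TensorParallelRowLinear still sums to the correct result. A minimal sketch of that identity with plain float tensors and hypothetical sizes (not the quantized kernel path):

import torch

hidden, world_size, rank = 8, 2, 1          # hypothetical sizes
shard = hidden // world_size

w_full = torch.randn(hidden, hidden, dtype=torch.double)   # full o_proj weight, [in, out] layout
x_full = torch.randn(3, hidden, dtype=torch.double)        # attention output as if gathered from all ranks
x_local = x_full[:, rank * shard:(rank + 1) * shard]       # this rank's shard

# Row-parallel sharded path: local activation times this rank's rows of the weight.
partial_sharded = x_local @ w_full[rank * shard:(rank + 1) * shard]

# Masked-matmul path: zero-fill the activation and use the full weight matrix.
x_zf = torch.zeros_like(x_full)
x_zf[:, rank * shard:(rank + 1) * shard] = x_local
partial_masked = x_zf @ w_full

torch.testing.assert_close(partial_sharded, partial_masked)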
@@ -303,19 +323,32 @@ def __init__(self, prefix, config, weights):
                 else "none",
             )
         )
+
+        # For the TP-aware preshuffle optimization, load down_proj's g_idx to compute perm.
+        # When perm is None, the original unoptimized control path is taken.
+        perm = None
+        if config.quantize == "gptq":
+            from text_generation_server.utils.layers import IS_TP_AWARE_GPTQ
+            if IS_TP_AWARE_GPTQ:
+                down_proj_g_idx = weights.get_tensor(f"{prefix}.down_proj.g_idx")
+                if down_proj_g_idx is not None:
+                    perm = torch.argsort(down_proj_g_idx)
+
         # Fuse gate and up proj
         self.gate_up_proj = TensorParallelColumnLinear.load_multi(
             config,
             prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
             weights=weights,
             dim=0,
             bias=False,
+            col_perm=perm,
         )
         self.down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
             weights=weights,
             bias=False,
+            row_perm=perm,
         )
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
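The preshuffle path above relies on the fact that permuting the intermediate dimension consistently on both sides of the MLP leaves its output unchanged: permute the output columns of the fused gate/up projection (col_perm) and the input rows of down_proj (row_perm) with the same permutation, here the argsort of down_proj's g_idx, and the product is identical, while down_proj's weights end up in the activation order the exllama kernel expects without any runtime reordering. A small sketch of that invariance with unquantized stand-in matrices (the sizes and the plain SiLU standing in for the gated activation are illustrative assumptions):

import torch
import torch.nn.functional as F

d_model, d_ff = 16, 32                                    # hypothetical sizes
x = torch.randn(4, d_model, dtype=torch.double)
w_up = torch.randn(d_model, d_ff, dtype=torch.double)    # stands in for the fused gate/up projection
w_down = torch.randn(d_ff, d_model, dtype=torch.double)  # stands in for down_proj

perm = torch.randperm(d_ff)                               # e.g. torch.argsort(down_proj_g_idx)

ref = F.silu(x @ w_up) @ w_down

# Permute the producer's output columns and the consumer's input rows with the same perm.
# Elementwise activations commute with the permutation, so the result is unchanged.
shuffled = F.silu(x @ w_up[:, perm]) @ w_down[perm, :]

torch.testing.assert_close(ref, shuffled)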
Lines changed: 65 additions & 0 deletions

@@ -0,0 +1,65 @@
+import torch
+
+# Shuffle columns of scales
+def shuffle_and_replace_scales(state_dict, scales_name, col_perm):
+    scales = state_dict[scales_name]
+    assert len(col_perm) == scales.shape[1]
+
+    shuffled_scales = scales[:, col_perm]
+    state_dict[scales_name] = shuffled_scales
+
+def unpack_shuffle_repack_and_replace_qzeros(state_dict, bits, qzeros_name, col_perm):
+    qzeros = state_dict[qzeros_name]
+    mask = 2**bits - 1
+    pack_size = 32 // bits
+    assert len(col_perm) == qzeros.shape[1] * pack_size
+
+    # unpack
+    unpacked_qzeros = torch.zeros((qzeros.shape[0], qzeros.shape[1]*pack_size), dtype=torch.int)
+    for i in range(pack_size):
+        unpacked_qzeros[:, i::pack_size] = (qzeros >> (i*bits)) & mask
+
+    # shuffle
+    shuffled_qzeros = unpacked_qzeros[:, col_perm]
+
+    # repack
+    packed_qzeros = torch.zeros_like(qzeros)
+    for i in range(pack_size):
+        packed_qzeros |= (shuffled_qzeros[:, i::pack_size] & mask) << (i*bits)
+
+    state_dict[qzeros_name] = packed_qzeros
+
+def shuffle_and_replace_qweight(state_dict, bits, group_size, qweight_name, g_idx_name=None, next_g_idx_name=None, stable=False):
+    qweight = state_dict[qweight_name]
+
+    # unpack qweight
+    mask = 2**bits - 1
+    pack_size = 32 // bits
+    unpacked_qweight = torch.zeros((qweight.shape[0]*pack_size, qweight.shape[1]), dtype=torch.int)
+    for i in range(pack_size):
+        unpacked_qweight[i::pack_size] = (qweight >> (i*bits)) & mask
+
+    # reorder rows conditionally
+    if g_idx_name is not None:
+        g_idx = state_dict[g_idx_name]
+        row_perm = torch.argsort(g_idx, stable=stable)
+        unpacked_qweight = unpacked_qweight[row_perm]
+
+    # reorder columns conditionally
+    if next_g_idx_name is not None:
+        next_g_idx = state_dict[next_g_idx_name]
+        col_perm = torch.argsort(next_g_idx, stable=stable)
+        unpacked_qweight = unpacked_qweight[:, col_perm]
+
+    # pack qweight
+    packed_qweight = torch.zeros_like(qweight)
+    for i in range(pack_size):
+        packed_qweight |= (unpacked_qweight[i::pack_size] & mask) << (i*bits)
+
+    # replace qweight with the new reordered one in state_dict
+    print(f'replacing {qweight_name}')
+    state_dict[qweight_name] = packed_qweight
+
+    if g_idx_name is not None:
+        print(f'replacing {g_idx_name}')
+        state_dict[g_idx_name] = torch.arange(0, len(g_idx), dtype=torch.int) // group_size
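These helpers look like the offline companion to the loader changes: they rewrite a GPTQ checkpoint so the act-order permutation is baked into the packed tensors ahead of time. A hedged sketch of how they might be applied to one layer's gate/up projections (the module name gptq_shuffle, the checkpoint paths, the layer prefix, and the bits/group_size values are assumptions for illustration, not taken from this commit):

import torch
from safetensors.torch import load_file, save_file

# Hypothetical import path for the helpers defined above.
from gptq_shuffle import (
    shuffle_and_replace_scales,
    unpack_shuffle_repack_and_replace_qzeros,
    shuffle_and_replace_qweight,
)

bits, group_size = 4, 128                                  # assumed quantize_config values
state_dict = load_file("model.safetensors")                # hypothetical checkpoint

prefix = "model.layers.0.mlp"
col_perm = torch.argsort(state_dict[f"{prefix}.down_proj.g_idx"])

# Reorder gate_proj/up_proj output columns into down_proj's activation order.
for name in ("gate_proj", "up_proj"):
    shuffle_and_replace_scales(state_dict, f"{prefix}.{name}.scales", col_perm)
    unpack_shuffle_repack_and_replace_qzeros(state_dict, bits, f"{prefix}.{name}.qzeros", col_perm)
    shuffle_and_replace_qweight(state_dict, bits, group_size, f"{prefix}.{name}.qweight",
                                next_g_idx_name=f"{prefix}.down_proj.g_idx")

save_file(state_dict, "model-preshuffled.safetensors")     # hypothetical output path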

server/text_generation_server/utils/layers.py

Lines changed: 8 additions & 6 deletions

@@ -13,6 +13,8 @@
 HAS_BITS_AND_BYTES = False
 HAS_EXLLAMA = False
 EXLLAMA_VERSION = None
+# TODO: should disable TP-aware GPTQ automatically if deployment is single GPU
+IS_TP_AWARE_GPTQ = (os.getenv("DISABLE_TP_AWARE_GPTQ","False").lower() == "false")
 
 if torch.cuda.is_available():
     try:
@@ -265,13 +267,13 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 
 class TensorParallelColumnLinear(SuperLayer):
     @classmethod
-    def load(cls, config, prefix: str, weights, bias: bool):
-        return cls.load_multi(config, [prefix], weights, bias, dim=0)
+    def load(cls, config, prefix: str, weights, bias: bool, col_perm=None):
+        return cls.load_multi(config, [prefix], weights, bias, dim=0, col_perm=col_perm)
 
     @classmethod
-    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
+    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int, col_perm=None):
         weight = weights.get_multi_weights_col(
-            prefixes, quantize=config.quantize, dim=dim
+            prefixes, quantize=config.quantize, dim=dim, col_perm=col_perm
        )
 
         if bias:
@@ -289,8 +291,8 @@ def __init__(self, linear, process_group):
         self.process_group = process_group
 
     @classmethod
-    def load(cls, config, prefix: str, weights, bias: bool):
-        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
+    def load(cls, config, prefix: str, weights, bias: bool, row_perm=None, noshard=False):
+        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize, row_perm=row_perm, noshard=noshard)
         if bias and weights.process_group.rank() == 0:
             # Rank is only on the first rank process
             bias = weights.get_tensor(f"{prefix}.bias")
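One operational note on the new flag: IS_TP_AWARE_GPTQ is opt-out, so the TP-aware paths are active for gptq-quantized models unless the server is started with the variable set to something other than "false" (the launch command below is illustrative only):

import os

# Mirrors the gating added above: unset, or any casing of "false", keeps the optimization on.
IS_TP_AWARE_GPTQ = os.getenv("DISABLE_TP_AWARE_GPTQ", "False").lower() == "false"

# e.g. starting the server with DISABLE_TP_AWARE_GPTQ=True falls back to the original
# sharded GPTQ path (and, per the TODO, single-GPU deployments may want that as well).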

server/text_generation_server/utils/weights.py

Lines changed: 103 additions & 32 deletions

@@ -10,6 +10,44 @@
 
 QUANTIZE_CONFIG_FILENAME = "quantize_config.json"
 
+def unpack(x, dim, bits=4):
+    return unpack_row(x, bits) if dim == 0 else unpack_col(x, bits)
+
+def unpack_col(x, bits):
+    mask = 2**bits - 1
+    pack_size = 32 // bits
+    unpacked_x = torch.zeros((x.shape[0], x.shape[1]*pack_size), dtype=torch.int)
+    for i in range(pack_size):
+        unpacked_x[:, i::pack_size] = (x >> (i*bits)) & mask
+    return unpacked_x
+
+def unpack_row(x, bits):
+    mask = 2**bits - 1
+    pack_size = 32 // bits
+    unpacked_x = torch.zeros((x.shape[0]*pack_size, x.shape[1]), dtype=torch.int)
+    for i in range(pack_size):
+        unpacked_x[i::pack_size] = (x >> (i*bits)) & mask
+    return unpacked_x
+
+
+def pack(x, dim, bits=4):
+    return pack_row(x, bits) if dim == 0 else pack_col(x, bits)
+
+def pack_col(x, bits):
+    mask = 2**bits - 1
+    pack_size = 32 // bits
+    packed_x = torch.zeros((x.shape[0], x.shape[1]//pack_size), dtype=torch.int)
+    for i in range(pack_size):
+        packed_x |= (x[:, i::pack_size] & mask) << (i*bits)
+    return packed_x
+
+def pack_row(x, bits):
+    mask = 2**bits - 1
+    pack_size = 32 // bits
+    packed_x = torch.zeros((x.shape[0]//pack_size, x.shape[1]), dtype=torch.int)
+    for i in range(pack_size):
+        packed_x |= (x[i::pack_size] & mask) << (i*bits)
+    return packed_x
 
 class Weights:
     def __init__(
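The pack/unpack helpers above follow the GPTQ convention of packing 32 // bits values little-endian into each int32 along the chosen dimension. A quick round-trip check under that assumption, assuming this branch's weights.py is importable:

import torch
from text_generation_server.utils.weights import pack, unpack

torch.manual_seed(0)
bits = 4
pack_size = 32 // bits

# Pack/unpack along dim 0 (rows), as used for qweight.
rows = torch.randint(0, 2**bits, (4 * pack_size, 6), dtype=torch.int)
torch.testing.assert_close(unpack(pack(rows, dim=0), dim=0), rows)

# Pack/unpack along dim 1 (columns), as used for qzeros.
cols = torch.randint(0, 2**bits, (5, 3 * pack_size), dtype=torch.int)
torch.testing.assert_close(unpack(pack(cols, dim=1), dim=1), cols)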
@@ -101,7 +139,7 @@ def get_partial_sharded(self, tensor_name: str, dim: int):
         tensor = tensor.to(device=self.device)
         return tensor
 
-    def get_sharded(self, tensor_name: str, dim: int):
+    def get_sharded(self, tensor_name: str, dim: int, perm=None, packed=False):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         slice_ = f.get_slice(tensor_name)
@@ -110,17 +148,53 @@ def get_sharded(self, tensor_name: str, dim: int):
         assert (
             size % world_size == 0
         ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
-        return self.get_partial_sharded(tensor_name, dim)
+        if perm is None:
+            return self.get_partial_sharded(tensor_name, dim)
+        else:
+            return self.get_shuffle_sharded(tensor_name, dim, perm, packed)
+
+    def get_shuffle_sharded(self, tensor_name: str, dim: int, perm, packed: bool):
+        filename, tensor_name = self.get_filename(tensor_name)
+        world_size = self.process_group.size()
+        rank = self.process_group.rank()
 
-    def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
+        f = self._get_handle(filename)
+        tensor = f.get_tensor(tensor_name)
+        perm = perm.to(device=tensor.device)
+        size = tensor.shape[dim]
+        block_size = size // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+
+        # TODO: pack-unpack on cuda to speed up this part
+        if dim == 0:
+            if packed:
+                tensor = pack(unpack(tensor, dim)[perm], dim)[start:stop]
+            else:
+                tensor = tensor[perm][start:stop]
+        elif dim == 1:
+            if packed:
+                tensor = pack(unpack(tensor, dim)[:, perm], dim)[:, start:stop]
+            else:
+                tensor = tensor[:, perm][:, start:stop]
+        else:
+            raise NotImplementedError("Let's make that generic when needed")
+        # Special case for gptq which shouldn't convert
+        # u4 which are disguised as int32
+        if tensor.dtype != torch.int32:
+            tensor = tensor.to(dtype=self.dtype)
+        tensor = tensor.to(device=self.device)
+        return tensor
+
+    def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int, col_perm=None):
         if quantize == "gptq":
             try:
-                qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1)
+                qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1, perm=col_perm, packed=False) for p in prefixes], dim=1)
             except RuntimeError:
                 raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
 
-            qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1)
-            scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1)
+            qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1, perm=col_perm, packed=True) for p in prefixes], dim=1)
+            scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1, perm=col_perm, packed=False) for p in prefixes], dim=1)
             w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
             for w2 in w[1:]:
                 torch.testing.assert_close(w2, w[0])
@@ -141,39 +215,36 @@ def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
             weight = torch.cat(w, dim=dim)
         return weight
 
-    def get_multi_weights_row(self, prefix: str, quantize: str):
+    def get_multi_weights_row(self, prefix: str, quantize: str, row_perm=None, noshard=False):
         if quantize == "gptq":
             bits, groupsize = self._get_gptq_params()
 
-            use_exllama = bits == 4
-
-            if self.process_group.size() > 1:
-                g_idx = self.get_tensor(f"{prefix}.g_idx")
-                if g_idx is not None:
-                    if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
-                        # Exllama implementation does not support row tensor parallelism with act-order, as
-                        # it would require to reorder input activations that are split unto several GPUs
-                        use_exllama = False
+            from text_generation_server.utils.layers import HAS_EXLLAMA
+            is_preshuffle = row_perm is not None
+            is_masked_matmul = noshard
+            assert (is_preshuffle != is_masked_matmul
+                    or not (is_preshuffle or is_masked_matmul)), f"TP-aware optimizations can't both be enabled at the same time {is_preshuffle=}, {is_masked_matmul=}"
+            use_exllama = (bits == 4) and HAS_EXLLAMA or (is_preshuffle or is_masked_matmul)
+            if self.process_group.rank == 0:
+                if use_exllama:
+                    logger.info(f"Using exllama kernels for row {prefix}")
+                else:
+                    logger.warning(
+                        "Exllama GPTQ cuda kernels (which are faster) could have been used, but are disabled via the DISABLE_EXLLAMA env var,"
+                        " or not currently installed, try using BUILD_EXTENSIONS=True"
+                    )
 
             try:
-                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
+                qweight = self.get_sharded(f"{prefix}.qweight",
+                                           dim=0,
+                                           perm=row_perm if use_exllama else None,
+                                           packed=True,
+                                           ) if not is_masked_matmul else self.get_tensor(f"{prefix}.qweight")
             except RuntimeError:
                 raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
 
-            from text_generation_server.utils.layers import HAS_EXLLAMA
-            if use_exllama:
-                use_exllama = HAS_EXLLAMA
-                if self.process_group.rank == 0:
-                    if use_exllama:
-                        logger.info(f"Using exllama kernels for row {prefix}")
-                    else:
-                        logger.warning(
-                            "Exllama GPTQ cuda kernels (which are faster) could have been used, but are disabled via the DISABLE_EXLLAMA env var,"
-                            " or not currently installed, try using BUILD_EXTENSIONS=True"
-                        )
-
             if use_exllama:
-                if groupsize >= 0:
+                if groupsize >= 0 and not is_masked_matmul:
                     # Exllama reorders the weights in advance and the activations on the fly, thus
                     # the scales and zero-points do not need to be reordered.
                     qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
@@ -183,7 +254,7 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
                     scales = self.get_tensor(f"{prefix}.scales")
 
                 # For tp > 1, at this point we know we do not use act-order
-                if self.process_group.size() == 1:
+                if (self.process_group.size() == 1 or is_masked_matmul) and not is_preshuffle:
                     g_idx = self.get_tensor(f"{prefix}.g_idx")
                 else:
                     g_idx = None
@@ -197,7 +268,7 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
 
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         else:
-            weight = self.get_sharded(f"{prefix}.weight", dim=1)
+            weight = self.get_sharded(f"{prefix}.weight", dim=1) if not noshard else self.get_tensor(f"{prefix}.weight")
         return weight
 
     def _get_gptq_params(self) -> Tuple[int, int]:
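A note on get_shuffle_sharded: applying the permutation first and then taking this rank's contiguous block is just a way of gathering this rank's slice of the permuted index set, which is what lets act-order (g_idx) checkpoints be row-sharded without reordering activations at runtime. A tiny sketch of that selection identity on an unpacked tensor (the shapes are illustrative):

import torch

world_size, rank = 2, 0                                    # hypothetical deployment
t = torch.arange(24., dtype=torch.double).reshape(8, 3)    # stand-in for an unpacked weight
perm = torch.randperm(t.shape[0])                          # e.g. torch.argsort(g_idx)

block = t.shape[0] // world_size
start, stop = rank * block, (rank + 1) * block

# What the dim=0, unpacked branch of get_shuffle_sharded computes ...
shard = t[perm][start:stop]
# ... is the same as directly gathering this rank's slice of the permutation.
torch.testing.assert_close(shard, t[perm[start:stop]])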
