Distributed inference of 70B awq model #2531

Merged 2 commits on Dec 4, 2023
Changes from all commits
29 changes: 4 additions & 25 deletions onmt/bin/translate.py
@@ -1,10 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from onmt.utils.logging import init_logger
from onmt.translate.translator import build_translator
from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
from onmt.transforms import get_transforms_cls
from onmt.constants import CorpusTask
from onmt.inference_engine import InferenceEnginePY
from onmt.opts import config_opts, translate_opts
from onmt.utils.parse import ArgumentParser
from onmt.utils.misc import use_gpu, set_random_seed
@@ -17,29 +13,12 @@ def translate(opt):
ArgumentParser._get_all_transform_translate(opt)
ArgumentParser._validate_transforms_opts(opt)
ArgumentParser.validate_translate_opts_dynamic(opt)
logger = init_logger(opt.log_file)

set_random_seed(opt.seed, use_gpu(opt))

translator = build_translator(opt, logger=logger, report_score=False)

transforms_cls = get_transforms_cls(opt._all_transform)

infer_iter = build_dynamic_dataset_iter(
opt,
transforms_cls,
translator.vocabs,
task=CorpusTask.INFER,
copy=translator.copy_attn,
device_id=opt.gpu,
)

_, _ = translator._translate(
infer_iter,
transform=infer_iter.transforms,
attn_debug=opt.attn_debug,
align_debug=opt.align_debug,
)
engine = InferenceEnginePY(opt)
_, _ = engine.infer_file()
engine.terminate()


def _get_parser():
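With this refactor, translate() no longer wires up the logger, translator, transforms, and dataset iterator by hand; InferenceEnginePY owns all of that, including the spawned worker processes when world_size > 1. A minimal sketch of the resulting flow (assuming opt comes from _get_parser() as before; infer_file's return values mirror the tuple unpacked above):

from onmt.inference_engine import InferenceEnginePY

def translate(opt):
    # The engine encapsulates what this script used to build by hand:
    # logger, translator, transforms, and the dynamic dataset iterator.
    engine = InferenceEnginePY(opt)
    scores, preds = engine.infer_file()  # run inference over opt.src
    engine.terminate()  # shuts down spawned processes when world_size > 1
    return scores, preds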
16 changes: 6 additions & 10 deletions onmt/inference_engine.py
@@ -2,7 +2,7 @@
from onmt.constants import CorpusTask, DefaultTokens, ModelTask
from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
from onmt.utils.distributed import ErrorHandler, spawned_infer
from onmt.utils.logging import logger
from onmt.utils.logging import init_logger
from onmt.transforms import get_transforms_cls, make_transforms, TransformPipe


@@ -82,6 +82,7 @@ def __init__(self, opt):

super().__init__(opt)
self.opt = opt
self.logger = init_logger(opt.log_file)

if opt.world_size > 1:
mp = torch.multiprocessing.get_context("spawn")
@@ -92,10 +93,6 @@ def __init__(self, opt):
self.queue_result = []
self.procs = []

print("world_size: ", opt.world_size)
print("gpu_ranks: ", opt.gpu_ranks)
print("opt.gpu: ", opt.gpu)

for device_id in range(opt.world_size):
self.queue_instruct.append(mp.Queue())
self.queue_result.append(mp.Queue())
@@ -113,12 +110,11 @@ def __init__(self, opt):
)
)
self.procs[device_id].start()
print(" Starting process pid: %d " % self.procs[device_id].pid)
self.error_handler.add_child(self.procs[device_id].pid)
else:
self.device_id = 0 if opt.world_size == 1 else -1
self.device_id = opt.gpu
self.translator = build_translator(
opt, self.device_id, logger=logger, report_score=True
opt, self.device_id, logger=self.logger, report_score=True
)
self.transforms_cls = get_transforms_cls(opt._all_transform)
self.vocabs = self.translator.vocabs
@@ -168,9 +164,9 @@ def __init__(self, opt):

super().__init__(opt)
self.opt = opt
self.logger = logger
self.logger = init_logger(opt.log_file)
assert self.opt.world_size <= 1, "World size must be less than or equal to 1."
self.device_id = 0 if opt.world_size == 1 else -1
self.device_id = opt.gpu
if opt.world_size == 1:
self.device_index = opt.gpu_ranks
self.device = "cuda"
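The multi-GPU branch above follows a standard torch.multiprocessing recipe: a spawn context, one instruction queue and one result queue per rank, and an ErrorHandler tracking child pids. A self-contained sketch of that pattern, with worker standing in for spawned_infer (illustrative, not the PR's code):

import torch.multiprocessing

def worker(device_id, queue_instruct, queue_result):
    # Stand-in for spawned_infer: block on instructions until told to stop.
    while True:
        instruction = queue_instruct.get()
        if instruction[0] == "stop":
            break
        queue_result.put((device_id, instruction))

if __name__ == "__main__":
    world_size = 2  # illustrative
    mp = torch.multiprocessing.get_context("spawn")
    queue_instruct, queue_result, procs = [], [], []
    for device_id in range(world_size):
        queue_instruct.append(mp.Queue())
        queue_result.append(mp.Queue())
        procs.append(
            mp.Process(
                target=worker,
                args=(device_id, queue_instruct[device_id], queue_result[device_id]),
                daemon=True,
            )
        )
        procs[device_id].start()
    for q in queue_instruct:
        q.put(("stop",))
    for p in procs:
        p.join()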
22 changes: 17 additions & 5 deletions onmt/inputters/text_utils.py
@@ -246,10 +246,19 @@ def tensorify(vocabs, minibatch, device, left_pad=False):
)

if minibatch[0][0]["tgt"] is not None:
tbatchtgt = [
torch.tensor(ex["tgt"]["tgt_ids"], dtype=torch.long, device=device)
for ex, indice in minibatch
]
if left_pad:
tbatchtgt = [
torch.tensor(
ex["tgt"]["tgt_ids"], dtype=torch.long, device=device
).flip(dims=[0])
for ex, indice in minibatch
]
else:
tbatchtgt = [
torch.tensor(ex["tgt"]["tgt_ids"], dtype=torch.long, device=device)
for ex, indice in minibatch
]

padidx = vocabs["tgt"][DefaultTokens.PAD]
tbatchtgt = pad_sequence(tbatchtgt, batch_first=True, padding_value=padidx)
tbatchtgt = tbatchtgt[:, :, None]
@@ -258,7 +267,10 @@ def tensorify(vocabs, minibatch, device, left_pad=False):
dtype=torch.long,
device=device,
)
tensor_batch["tgt"] = tbatchtgt
if left_pad:
tensor_batch["tgt"] = tbatchtgt.flip(dims=[1])
else:
tensor_batch["tgt"] = tbatchtgt
tensor_batch["tgtlen"] = tbatchtgtlen

if "align" in minibatch[0][0].keys() and minibatch[0][0]["align"] is not None:
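The left_pad branches rely on a double-flip trick, since pad_sequence can only right-pad: flip each target sequence, pad, then flip the padded batch back along the time dimension so the padding ends up on the left. A small runnable demo of the idea:

import torch
from torch.nn.utils.rnn import pad_sequence

PAD = 0
seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]

# Default behaviour: padding goes on the right.
right = pad_sequence(seqs, batch_first=True, padding_value=PAD)
print(right)  # tensor([[5, 6, 7], [8, 9, 0]])

# Flip each sequence, pad, flip the time axis back: padding lands on the left.
left = pad_sequence(
    [s.flip(dims=[0]) for s in seqs], batch_first=True, padding_value=PAD
).flip(dims=[1])
print(left)  # tensor([[5, 6, 7], [0, 8, 9]])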
6 changes: 5 additions & 1 deletion onmt/model_builder.py
@@ -100,7 +100,11 @@ def load_test_model(opt, device_id=0, model_path=None):
"aawq_gemm",
"aawq_gemv",
]: # if the loaded model is an awq quantized one, the inference config cannot overwrite this
if hasattr(opt, "quant_type") and opt.quant_type != model_opt.quant_type:
if (
hasattr(opt, "quant_type")
and opt.quant_type != ""
and opt.quant_type != model_opt.quant_type
):
raise ValueError(
"Model is a awq quantized model, cannot overwrite with another quant method"
)
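The extra condition makes the empty string (the new -quant_type default, see onmt/opts.py below) mean "inherit whatever quantization the checkpoint was saved with", while still rejecting an explicit conflicting method. A restatement of the rule as a standalone function (illustrative, not the PR's code):

def resolve_quant_type(opt_quant_type, model_quant_type):
    # Empty inference-time quant_type: keep the checkpoint's method.
    # A different non-empty value: refuse to re-quantize an awq model.
    if opt_quant_type == "" or opt_quant_type == model_quant_type:
        return model_quant_type
    raise ValueError(
        "Model is an awq quantized model, cannot overwrite with another quant method"
    )

assert resolve_quant_type("", "aawq_gemm") == "aawq_gemm"
assert resolve_quant_type("aawq_gemm", "aawq_gemm") == "aawq_gemm"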
27 changes: 20 additions & 7 deletions onmt/models/model.py
@@ -46,6 +46,10 @@ def count_parameters(self, log=print):
raise NotImplementedError

def _load_param(self, name, module, param_name, param, buf_list, ckpt_t, offset):
if module.__class__.__name__ == "WQLinear_GEMM":
# workaround: in_features and out_features are reversed in WQLinear_GEMM
param.data = param.data.transpose(0, 1)
ckpt_t = ckpt_t.transpose(0, 1)
if name.split(".")[-1] in [
"linear_keys",
"linear_values",
@@ -73,13 +77,22 @@ def _load_param(self, name, module, param_name, param, buf_list, ckpt_t, offset)
].size()
), "An error in model's partition and checkpoint's slice was detected"
if name + "." + param_name in buf_list:
module.register_buffer(
param_name,
ckpt_t[
col_slice_start:col_slice_end,
row_slice_start:row_slice_end,
],
)
if module.__class__.__name__ == "WQLinear_GEMM":
module.register_buffer(
param_name,
ckpt_t[
col_slice_start:col_slice_end,
row_slice_start:row_slice_end,
].transpose(0, 1),
)
else:
module.register_buffer(
param_name,
ckpt_t[
col_slice_start:col_slice_end,
row_slice_start:row_slice_end,
],
)
else:
param.data = ckpt_t[
col_slice_start:col_slice_end,
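The transposes are needed because AutoAWQ's WQLinear_GEMM stores its packed weights as [in_features, out_features // pack], the opposite orientation from nn.Linear's [out_features, in_features]. Transposing first lets the shared tensor-parallel slicing treat both layouts identically, and the slice is transposed back when the buffer is registered. A shape-only sketch (assuming 4-bit weights packed eight per int32):

import torch

in_features, out_features, world_size = 64, 32, 2
pack = 8  # eight 4-bit values per int32, assuming w_bit=4

# AWQ layout: [in_features, out_features // pack]
qweight = torch.zeros(in_features, out_features // pack, dtype=torch.int32)

t = qweight.transpose(0, 1)  # nn.Linear-like orientation for slicing
shard = t[:, : in_features // world_size]  # slice this rank's share
shard = shard.transpose(0, 1)  # back to the AWQ layout
print(shard.shape)  # torch.Size([32, 4])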
12 changes: 10 additions & 2 deletions onmt/opts.py
@@ -1564,8 +1564,16 @@ def _add_quant_opts(parser):
group.add(
"--quant_type",
"-quant_type",
default="bnb_8bit",
choices=["bnb_8bit", "bnb_FP4", "bnb_NF4", "llm_awq", "aawq_gemm", "aawq_gemv"],
default="",
choices=[
"",
"bnb_8bit",
"bnb_FP4",
"bnb_NF4",
"llm_awq",
"aawq_gemm",
"aawq_gemv",
],
type=str,
help="Type of compression.",
)
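Because the default is now empty rather than bnb_8bit, loading a quantized checkpoint without passing -quant_type no longer looks like a request for a different quant method (which the model_builder guard above would reject). A standalone argparse sketch of the new behaviour (plain argparse here, not onmt's parser):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--quant_type", "-quant_type", default="",
    choices=["", "bnb_8bit", "bnb_FP4", "bnb_NF4",
             "llm_awq", "aawq_gemm", "aawq_gemv"],
    type=str, help="Type of compression.",
)
opt = parser.parse_args([])
assert opt.quant_type == ""  # empty default: inherit from the checkpoint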
5 changes: 4 additions & 1 deletion onmt/translate/translator.py
@@ -619,7 +619,10 @@ def _report_score(self, name, score_total, nb_sentences):
msg = "%s No translations" % (name,)
else:
score = score_total / nb_sentences
ppl = exp(-score_total / nb_sentences)
try:
ppl = exp(-score_total / nb_sentences)
except OverflowError:
ppl = float("inf")
msg = "%s SCORE: %.4f, %s PPL: %.2f NB SENTENCES: %d" % (
name,
score,
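The guard matters when scoring long files: a large negative total score makes exp() overflow a Python float instead of returning infinity. A short reproduction:

from math import exp

score_total, nb_sentences = -80000.0, 100
try:
    ppl = exp(-score_total / nb_sentences)  # exp(800.0) overflows a double
except OverflowError:
    ppl = float("inf")
print(ppl)  # inf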
3 changes: 1 addition & 2 deletions onmt/utils/distributed.py
@@ -197,7 +197,6 @@ def spawned_infer(opt, device_id, error_queue, queue_instruct, queue_result):
init_logger(opt.log_file)
translator = build_translator(opt, device_id, logger=logger, report_score=True)
transforms_cls = get_transforms_cls(opt._all_transform)
print("Device_id: ", device_id, " translator built")
while True:
instruction = queue_instruct.get()
if instruction[0] == "stop":
Expand Down Expand Up @@ -227,7 +226,7 @@ def spawned_infer(opt, device_id, error_queue, queue_instruct, queue_result):
device_id=device_id,
)
scores, preds = translator._translate(
infer_iter, infer_iter.transform, opt.attn_debug, opt.align_debug
infer_iter, infer_iter.transforms, opt.attn_debug, opt.align_debug
)
queue_result.put(scores)
queue_result.put(preds)
Expand Down
15 changes: 12 additions & 3 deletions tools/LM_scoring.py
@@ -86,12 +86,12 @@ def main():
cumul_length = 0
# Now we can pipe the full file through the model using the Iterator

for i, batch in enumerate(infer_iter):
for i, (batch, bucket_idx) in enumerate(infer_iter):
# reminder a batch includes .src .tgt .indices and it is sorted
batch_size = len(batch["srclen"])
src = batch["src"]
src_len = batch["srclen"]

# print(batch)
outputs, attns = model(src, None, src_len, with_align=False)
# Compute and retrieve the loss for EACH sentence
loss, _ = valid_loss(batch, outputs, attns)
Expand All @@ -102,7 +102,16 @@ def main():
cumul_length += batch["tgt"][:, 1:, 0].ne(padding_idx).sum().cpu()
# Now we need to rearrange the batch of ppl
# in the original order with indices
sent_ppl_orig = ppl.gather(0, batch["cid_line_number"].argsort(0))
sent_ppl_orig = ppl.gather(
0,
torch.tensor(
sorted(
range(len(batch["cid_line_number"])),
key=lambda k: batch["cid_line_number"][k],
),
device=ppl.device,
),
)
for j in range(batch_size):
ppl_file.write(str(sent_ppl_orig[j].item()) + "\n")
logger.info(
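The rewritten gather does the same job as the old argsort call, but works when batch["cid_line_number"] is a plain Python list rather than a tensor (presumably why the PR spells it out with sorted()). Batches arrive sorted for efficiency, and cid_line_number records each example's original line, so gathering with the indices that sort it restores file order. A small demo:

import torch

ppl = torch.tensor([10.0, 20.0, 30.0])  # per-example ppl, batch order
cid_line_number = [2, 0, 1]             # original line of each example

order = torch.tensor(
    sorted(range(len(cid_line_number)), key=lambda k: cid_line_number[k]),
    device=ppl.device,
)
sent_ppl_orig = ppl.gather(0, order)
print(sent_ppl_orig)  # tensor([20., 30., 10.]) -> ppl for lines 0, 1, 2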