
Commit

update
helloyongyang committed Dec 13, 2024
1 parent 796a23b commit b7fb234
Showing 2 changed files with 37 additions and 11 deletions.
31 changes: 25 additions & 6 deletions llmc/eval/eval_base.py
@@ -184,16 +184,29 @@ def eval(self, model_llmc, model_org=None, eval_pos=None):
         if self.inference_per_block:
             handles = self.register_hooks(model_llmc)
         else:
-            model_llmc.model.cuda()
-        model_llmc.model.eval()
+            if model_llmc.mm_model:
+                model_llmc.mm_model.cuda()
+            else:
+                model_llmc.model.cuda()
+
+        if model_llmc.mm_model:
+            model_llmc.mm_model.eval()
+        else:
+            model_llmc.model.eval()
 
         if model_org is not None:
             if self.inference_per_block:
                 handles_org = self.register_hooks(model_org)
             else:
-                model_org.model.cuda()
+                if model_org.mm_model:
+                    model_org.mm_model.cuda()
+                else:
+                    model_org.model.cuda()
 
-            model_org.model.eval()
+            if model_org.mm_model:
+                model_org.mm_model.eval()
+            else:
+                model_org.model.eval()
 
         eval_res = self.eval_func(
             model_org,
@@ -207,9 +220,15 @@ def eval(self, model_llmc, model_org=None, eval_pos=None):
             for h in handles + handles_org:
                 h.remove()
 
-        model_llmc.model.cpu()
+        if model_llmc.mm_model:
+            model_llmc.mm_model.cpu()
+        else:
+            model_llmc.model.cpu()
         if model_org is not None:
-            model_org.model.cpu()
+            if model_org.mm_model:
+                model_org.mm_model.cpu()
+            else:
+                model_org.model.cpu()
 
         gc.collect()
         torch.cuda.empty_cache()
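Both hunks follow the same pattern: device placement (cuda()/cpu()) and eval-mode switching are routed to the wrapper's mm_model when one is present, falling back to model otherwise. A minimal sketch of that dispatch factored into a helper (hypothetical, not part of this commit; it only assumes the wrapper exposes mm_model and model, as the diff above does):

def _target_model(wrapper):
    # Prefer the multimodal model when the wrapper carries one;
    # otherwise fall back to the plain language model.
    return wrapper.mm_model if wrapper.mm_model else wrapper.model

With such a helper, each branch above would collapse to a single call, e.g. _target_model(model_llmc).cuda(), .eval(), or .cpu().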
17 changes: 12 additions & 5 deletions llmc/eval/eval_custom_generate.py
@@ -23,11 +23,18 @@ def eval_func(self, org_model, model, testenc, seq_len, bs, eval_pos):
                 k: (v.cuda() if torch.is_tensor(v) else v)
                 for k, v in data.items()
             }
-            generated_ids = model.model.generate(
-                **data,
-                max_new_tokens=self.max_new_tokens,
-                do_sample=False
-            )
+            if model.mm_model:
+                generated_ids = model.mm_model.generate(
+                    **data,
+                    max_new_tokens=self.max_new_tokens,
+                    do_sample=False
+                )
+            else:
+                generated_ids = model.model.generate(
+                    **data,
+                    max_new_tokens=self.max_new_tokens,
+                    do_sample=False
+                )
             response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
             responses.append(response)
         responses = self.flatten_2d_to_1d(responses)
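The generation path applies the same fallback: if mm_model is set, its generate() is called with identical arguments; otherwise the plain model generates. A hypothetical, self-contained sketch of that dispatch written once instead of duplicating the call (the names mirror the diff above; any HuggingFace-style .generate() target works):

def generate_ids(wrapper, data, max_new_tokens):
    # Pick the multimodal model if the wrapper has one, else the base model.
    target = wrapper.mm_model if wrapper.mm_model else wrapper.model
    # Greedy decoding, matching do_sample=False in the diff above.
    return target.generate(**data, max_new_tokens=max_new_tokens, do_sample=False)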
