Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
helloyongyang committed Dec 13, 2024
1 parent e180c1f commit 2654cc7
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 24 deletions.
22 changes: 1 addition & 21 deletions llmc/eval/eval_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def register_hooks(self, model):
return handles

@torch.no_grad()
def eval(self, model_llmc, model_org=None, eval_pos=None):
def eval(self, model_llmc, eval_pos=None):
handles, handles_org = [], []
if self.inference_per_block:
handles = self.register_hooks(model_llmc)
Expand All @@ -194,22 +194,7 @@ def eval(self, model_llmc, model_org=None, eval_pos=None):
else:
model_llmc.model.eval()

if model_org is not None:
if self.inference_per_block:
handles_org = self.register_hooks(model_org)
else:
if model_org.mm_model:
model_org.mm_model.cuda()
else:
model_org.model.cuda()

if model_org.mm_model:
model_org.mm_model.eval()
else:
model_org.model.eval()

eval_res = self.eval_func(
model_org,
model_llmc,
self.testenc,
self.seq_len,
Expand All @@ -224,11 +209,6 @@ def eval(self, model_llmc, model_org=None, eval_pos=None):
model_llmc.mm_model.cpu()
else:
model_llmc.model.cpu()
if model_org is not None:
if model_org.mm_model:
model_org.mm_model.cpu()
else:
model_org.model.cpu()

gc.collect()
torch.cuda.empty_cache()
Expand Down
2 changes: 1 addition & 1 deletion llmc/eval/eval_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self, model, config):
self.add_chat_temp = self.eval_cfg.get('add_chat_temp', False)

@torch.no_grad()
def eval_func(self, org_model, model, testenc, seq_len, bs, eval_pos):
def eval_func(self, model, testenc, seq_len, bs, eval_pos):
samples = []
pbar = tqdm(total=len(testenc) * bs, dynamic_ncols=True, position=0, desc='Evaluating')

Expand Down
2 changes: 1 addition & 1 deletion llmc/eval/eval_custom_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, model, config):
self.max_new_tokens = self.eval_cfg.get('max_new_tokens', 32)

@torch.no_grad()
def eval_func(self, org_model, model, testenc, seq_len, bs, eval_pos):
def eval_func(self, model, testenc, seq_len, bs, eval_pos):
responses = []
for data in testenc:
data = {
Expand Down
2 changes: 1 addition & 1 deletion llmc/eval/eval_ppl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
class PerplexityEval(BaseEval):

@torch.no_grad()
def eval_func(self, org_model, model, testenc, seq_len, bs, eval_pos):
def eval_func(self, model, testenc, seq_len, bs, eval_pos):
testenc = testenc.input_ids
nsamples = testenc.numel() // seq_len

Expand Down

0 comments on commit 2654cc7

Please sign in to comment.