remove some code
helloyongyang committed Dec 14, 2024
1 parent 3a0006b commit cf802af
Showing 1 changed file with 0 additions and 172 deletions: llmc/data/dataset/base_dataset.py
@@ -76,178 +76,6 @@ def build_calib_dataset(self):
else:
self.calib_dataset = load_from_disk(self.calib_dataset_path)

def get_calib_samples(self):
if self.calib_dataset_name in ('custom_txt', 'custom_mm'):
samples = self.calib_dataset
else:
preproc = PREPROC_REGISTRY[self.preproc]
preproc_param_dict = {
'calib_dataset': self.calib_dataset,
'tokenizer': self.tokenizer,
'n_samples': self.n_samples,
'seq_len': self.seq_len
}
if self.preproc == 'txt_general_preproc':
preproc_param_dict['key'] = self.key
samples = preproc(**preproc_param_dict)
return samples
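
# A minimal sketch (not part of the original file) of the registry
# dispatch used in get_calib_samples above: PREPROC_REGISTRY maps a
# preprocessing name to a callable that is invoked with keyword
# arguments. The entry name and body below are illustrative
# assumptions, not the project's actual registrations.
#
#   def txt_window_preproc(calib_dataset, tokenizer, n_samples, seq_len):
#       # tokenize, then draw n_samples windows of seq_len tokens each
#       return samples  # e.g. a list of (1, seq_len) LongTensors
#
#   PREPROC_REGISTRY = {'txt_window_preproc': txt_window_preproc}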

def get_pad_setting(self, length):
if self.tokenizer.padding_side == 'left':
return [length, 0]
elif self.tokenizer.padding_side == 'right':
return [0, length]
else:
raise Exception(f'Unsupported padding_side: {self.tokenizer.padding_side}.')
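
# A quick illustration (comment only; not in the original file) of how
# the [left, right] pair returned by get_pad_setting feeds
# torch.nn.functional.pad: F.pad pads the last dimension, so
# [length, 0] prepends pad tokens (left padding) and [0, length]
# appends them (right padding).
#
#   ids = torch.tensor([[5, 6, 7]])      # shape (1, 3)
#   F.pad(ids, [2, 0], value=0)          # -> [[0, 0, 5, 6, 7]]
#   F.pad(ids, [0, 2], value=0)          # -> [[5, 6, 7, 0, 0]]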

def txt_group_samples_with_mask(self, samples):
calib_samples = []
input_ids = []
attention_mask = []
pad_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token)
if self.calib_bs < 0:
samples_len = [sample.shape[-1] for sample in samples]
max_len = max(samples_len)
samples_tmp = []
attention_mask_tmp = []
for sample in samples:
samples_tmp.append(
F.pad(
sample,
self.get_pad_setting(max_len - sample.shape[-1]),
value=pad_token_id
)
)
attention_mask_tmp.append(
F.pad(
torch.ones(1, sample.shape[-1], dtype=torch.int64),
self.get_pad_setting(max_len - sample.shape[-1]),
value=0
)
)
batch_input_ids = torch.cat(samples_tmp, dim=0)
batch_attention_mask = torch.cat(attention_mask_tmp, dim=0)
calib_samples.append(
{'input_ids': batch_input_ids, 'attention_mask': batch_attention_mask}
)
elif self.calib_bs == 1:
input_ids = samples
attention_mask = [torch.ones(1, sample.shape[-1], dtype=torch.int64) for sample in samples] # noqa
for i in range(len(samples)):
calib_samples.append(
{'input_ids': input_ids[i], 'attention_mask': attention_mask[i]}
)
elif self.calib_bs > 1:
for i in range(0, len(samples), self.calib_bs):
start = i
end = min(i + self.calib_bs, len(samples))
batch_samples = samples[start:end]
batch_samples_len = [sample.shape[-1] for sample in batch_samples]
batch_max_len = max(batch_samples_len)
samples_tmp = []
attention_mask_tmp = []
for sample in batch_samples:
samples_tmp.append(
F.pad(
sample,
self.get_pad_setting(batch_max_len - sample.shape[-1]),
value=pad_token_id
)
)
attention_mask_tmp.append(
F.pad(
torch.ones(1, sample.shape[-1], dtype=torch.int64),
self.get_pad_setting(batch_max_len - sample.shape[-1]),
value=0
)
)
batch_input_ids = torch.cat(samples_tmp, dim=0)
batch_attention_mask = torch.cat(attention_mask_tmp, dim=0)
calib_samples.append(
{
'input_ids': batch_input_ids,
'attention_mask': batch_attention_mask
}
)
return calib_samples
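
# A worked example (comment only; not in the original file) of the
# calib_bs < 0 branch above, assuming left padding and a pad_token_id
# of 0: samples of lengths 3 and 5 are padded to max_len = 5, stacked
# into a (2, 5) input_ids batch, and the attention mask marks real
# tokens with 1 and padding with 0.
#
#   samples = [torch.tensor([[11, 12, 13]]),
#              torch.tensor([[21, 22, 23, 24, 25]])]
#   # input_ids      -> [[ 0,  0, 11, 12, 13],
#   #                    [21, 22, 23, 24, 25]]
#   # attention_mask -> [[ 0,  0,  1,  1,  1],
#   #                    [ 1,  1,  1,  1,  1]]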

def txt_group_samples_wo_mask(self, samples): # without mask
calib_samples = []
if self.calib_bs < 0:
batch = torch.cat(samples, dim=0)
calib_samples.append({'input_ids': batch})
elif self.calib_bs == 1:
for i in range(len(samples)):
calib_samples.append({'input_ids': samples[i]})
elif self.calib_bs > 1:
for i in range(0, len(samples), self.calib_bs):
start = i
end = min(i + self.calib_bs, len(samples))
batch = samples[start:end]
batch = torch.cat(batch, dim=0)
calib_samples.append({'input_ids': batch})
return calib_samples
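
# An implicit assumption in the method above, noted here as a comment
# (not in the original file): torch.cat(..., dim=0) only succeeds when
# every sample has the same sequence length, which holds when the
# preprocessing truncates each sample to self.seq_len. A defensive
# check would be:
#
#   assert len({s.shape[-1] for s in samples}) == 1, \
#       'grouping without a mask requires equal-length samples'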

def img_txt_group_samples_with_mask(self, samples):
calib_samples = []
if self.calib_bs < 0:
calib_samples.append(self.batch_process(samples, calib_or_eval='calib'))
elif self.calib_bs == 1:
calib_samples = [self.batch_process([sample], calib_or_eval='calib') for sample in samples] # noqa
elif self.calib_bs > 1:
for i in range(0, len(samples), self.calib_bs):
start = i
end = min(i + self.calib_bs, len(samples))
batch = samples[start:end]
calib_samples.append(self.batch_process(batch, calib_or_eval='calib'))
return calib_samples

def audio_txt_group_samples_with_mask(self, samples):
calib_samples = []
if self.calib_bs < 0:
calib_samples.append(self.batch_process(samples, calib_or_eval='calib'))
elif self.calib_bs == 1:
calib_samples = [self.batch_process([sample], calib_or_eval='calib') for sample in samples] # noqa
elif self.calib_bs > 1:
for i in range(0, len(samples), self.calib_bs):
start = i
end = min(i + self.calib_bs, len(samples))
batch = samples[start:end]
calib_samples.append(self.batch_process(batch, calib_or_eval='calib'))
return calib_samples

def audio_img_txt_group_samples_with_mask(self, samples):
calib_samples = []
if self.calib_bs < 0:
calib_samples.append(self.batch_process(samples, calib_or_eval='calib'))
elif self.calib_bs == 1:
calib_samples = [self.batch_process([sample], calib_or_eval='calib') for sample in samples] # noqa
elif self.calib_bs > 1:
for i in range(0, len(samples), self.calib_bs):
start = i
end = min(i + self.calib_bs, len(samples))
batch = samples[start:end]
calib_samples.append(self.batch_process(batch, calib_or_eval='calib'))
return calib_samples

def img_group_samples_wo_mask(self, samples): # without mask
calib_samples = []
if self.calib_bs < 0:
batch = {'pixel_values': torch.cat([sample['pixel_values']
for sample in samples], dim=0)}
calib_samples.append(batch)
elif self.calib_bs == 1:
calib_samples = samples
elif self.calib_bs > 1:
for i in range(0, len(samples), self.calib_bs):
start = i
end = min(i + self.calib_bs, len(samples))
batch = samples[start:end]
batch = {'pixel_values': torch.cat([sample['pixel_values']
for sample in batch], dim=0)}
calib_samples.append(batch)
return calib_samples
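
# A small sketch (not part of the original file) of the pixel_values
# batching above: each sample carries a (1, C, H, W) tensor and
# torch.cat(..., dim=0) stacks them into a (B, C, H, W) batch. The
# concrete shape below is an illustrative assumption.
#
#   samples = [{'pixel_values': torch.randn(1, 3, 224, 224)}
#              for _ in range(4)]
#   batch = {'pixel_values': torch.cat(
#       [s['pixel_values'] for s in samples], dim=0)}  # (4, 3, 224, 224)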

def get_calib_model_inputs(self, samples):
if not self.padding:
assert self.calib_dataset_name != 'custom_mm'