diff --git a/llmc/data/dataset/base_dataset.py b/llmc/data/dataset/base_dataset.py
index 27753707..b135ff49 100644
--- a/llmc/data/dataset/base_dataset.py
+++ b/llmc/data/dataset/base_dataset.py
@@ -76,178 +76,6 @@ def build_calib_dataset(self):
         else:
             self.calib_dataset = load_from_disk(self.calib_dataset_path)
 
-    def get_calib_samples(self):
-        if self.calib_dataset_name == 'custom_txt' or self.calib_dataset_name == 'custom_mm':
-            samples = self.calib_dataset
-        else:
-            preproc = PREPROC_REGISTRY[self.preproc]
-            preproc_param_dict = {
-                'calib_dataset': self.calib_dataset,
-                'tokenizer': self.tokenizer,
-                'n_samples': self.n_samples,
-                'seq_len': self.seq_len
-            }
-            if self.preproc == 'txt_general_preproc':
-                preproc_param_dict['key'] = self.key
-            samples = preproc(**preproc_param_dict)
-        return samples
-
-    def get_pad_setting(self, length):
-        if self.tokenizer.padding_side == 'left':
-            return [length, 0]
-        elif self.tokenizer.padding_side == 'right':
-            return [0, length]
-        else:
-            raise Exception(f'Not support padding_side: {self.tokenizer.padding_side}.')
-
-    def txt_group_samples_with_mask(self, samples):
-        calib_samples = []
-        input_ids = []
-        attention_mask = []
-        pad_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token)
-        if self.calib_bs < 0:
-            samples_len = [sample.shape[-1] for sample in samples]
-            max_len = max(samples_len)
-            samples_tmp = []
-            attention_mask_tmp = []
-            for sample in samples:
-                samples_tmp.append(
-                    F.pad(
-                        sample,
-                        self.get_pad_setting(max_len - sample.shape[-1]),
-                        value=pad_token_id
-                    )
-                )
-                attention_mask_tmp.append(
-                    F.pad(
-                        torch.ones(1, sample.shape[-1], dtype=torch.int64),
-                        self.get_pad_setting(max_len - sample.shape[-1]),
-                        value=0
-                    )
-                )
-            batch_input_ids = torch.cat(samples_tmp, dim=0)
-            batch_attention_mask = torch.cat(attention_mask_tmp, dim=0)
-            calib_samples.append(
-                {'input_ids': batch_input_ids, 'attention_mask': batch_attention_mask}
-            )
-        elif self.calib_bs == 1:
-            input_ids = samples
-            attention_mask = [torch.ones(1, sample.shape[-1], dtype=torch.int64) for sample in samples]  # noqa
-            for i in range(len(samples)):
-                calib_samples.append(
-                    {'input_ids': input_ids[i], 'attention_mask': attention_mask[i]}
-                )
-        elif self.calib_bs > 1:
-            for i in range(0, len(samples), self.calib_bs):
-                start = i
-                end = min(i + self.calib_bs, len(samples))
-                batch_samples = samples[start:end]
-                batch_samples_len = [sample.shape[-1] for sample in batch_samples]
-                batch_max_len = max(batch_samples_len)
-                samples_tmp = []
-                attention_mask_tmp = []
-                for sample in batch_samples:
-                    samples_tmp.append(
-                        F.pad(
-                            sample,
-                            self.get_pad_setting(batch_max_len - sample.shape[-1]),
-                            value=pad_token_id
-                        )
-                    )
-                    attention_mask_tmp.append(
-                        F.pad(
-                            torch.ones(1, sample.shape[-1], dtype=torch.int64),
-                            self.get_pad_setting(batch_max_len - sample.shape[-1]),
-                            value=0
-                        )
-                    )
-                batch_input_ids = torch.cat(samples_tmp, dim=0)
-                batch_attention_mask = torch.cat(attention_mask_tmp, dim=0)
-                calib_samples.append(
-                    {
-                        'input_ids': batch_input_ids,
-                        'attention_mask': batch_attention_mask
-                    }
-                )
-        return calib_samples
-
-    def txt_group_samples_wo_mask(self, samples):  # without mask
-        calib_samples = []
-        if self.calib_bs < 0:
-            batch = torch.cat(samples, dim=0)
-            calib_samples.append({'input_ids': batch})
-        elif self.calib_bs == 1:
-            for i in range(len(samples)):
-                calib_samples.append({'input_ids': samples[i]})
-        elif self.calib_bs > 1:
-            for i in range(0, len(samples), self.calib_bs):
-                start = i
-                end = min(i + self.calib_bs, len(samples))
-                batch = samples[start:end]
-                batch = torch.cat(batch, dim=0)
-                calib_samples.append({'input_ids': batch})
-        return calib_samples
-
-    def img_txt_group_samples_with_mask(self, samples):
-        calib_samples = []
-        if self.calib_bs < 0:
-            calib_samples.append(self.batch_process(samples, calib_or_eval='calib'))
-        elif self.calib_bs == 1:
-            calib_samples = [self.batch_process([sample], calib_or_eval='calib') for sample in samples]  # noqa
-        elif self.calib_bs > 1:
-            for i in range(0, len(samples), self.calib_bs):
-                start = i
-                end = min(i + self.calib_bs, len(samples))
-                batch = samples[start:end]
-                calib_samples.append(self.batch_process(batch, calib_or_eval='calib'))
-        return calib_samples
-
-    def audio_txt_group_samples_with_mask(self, samples):
-        calib_samples = []
-        if self.calib_bs < 0:
-            calib_samples.append(self.batch_process(samples, calib_or_eval='calib'))
-        elif self.calib_bs == 1:
-            calib_samples = [self.batch_process([sample], calib_or_eval='calib') for sample in samples]  # noqa
-        elif self.calib_bs > 1:
-            for i in range(0, len(samples), self.calib_bs):
-                start = i
-                end = min(i + self.calib_bs, len(samples))
-                batch = samples[start:end]
-                calib_samples.append(self.batch_process(batch, calib_or_eval='calib'))
-        return calib_samples
-
-    def audio_img_txt_group_samples_with_mask(self, samples):
-        calib_samples = []
-        if self.calib_bs < 0:
-            calib_samples.append(self.batch_process(samples, calib_or_eval='calib'))
-        elif self.calib_bs == 1:
-            calib_samples = [self.batch_process([sample], calib_or_eval='calib') for sample in samples]  # noqa
-        elif self.calib_bs > 1:
-            for i in range(0, len(samples), self.calib_bs):
-                start = i
-                end = min(i + self.calib_bs, len(samples))
-                batch = samples[start:end]
-                calib_samples.append(self.batch_process(batch, calib_or_eval='calib'))
-        return calib_samples
-
-    def img_group_samples_wo_mask(self, samples):  # without mask
-        calib_samples = []
-        if self.calib_bs < 0:
-            batch = {'pixel_values': torch.cat([sample['pixel_values']
-                                                for sample in samples], dim=0)}
-            calib_samples.append(batch)
-        elif self.calib_bs == 1:
-            calib_samples = samples
-        elif self.calib_bs > 1:
-            for i in range(0, len(samples), self.calib_bs):
-                start = i
-                end = min(i + self.calib_bs, len(samples))
-                batch = samples[start:end]
-                batch = {'pixel_values': torch.cat([sample['pixel_values']
-                                                    for sample in batch], dim=0)}
-                calib_samples.append(batch)
-        return calib_samples
-
     def get_calib_model_inputs(self, samples):
         if not self.padding:
             assert not self.calib_dataset_name == 'custom_mm'
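For reference, the following is a minimal standalone sketch of the pad-and-mask batching pattern that the removed get_pad_setting and txt_group_samples_with_mask helpers implemented. The function name group_with_mask and its signature are illustrative, not part of llmc; pad_token_id and padding_side stand in for the tokenizer attributes the original code read. Passing batch_size=len(samples) reproduces the removed calib_bs < 0 behavior (one batch containing everything).

import torch
import torch.nn.functional as F


def group_with_mask(samples, pad_token_id, padding_side='left', batch_size=2):
    """Pad 1 x seq_len input_ids tensors to a common per-batch length and
    build the matching attention masks (1 = real token, 0 = padding)."""
    batches = []
    for start in range(0, len(samples), batch_size):
        chunk = samples[start:start + batch_size]
        max_len = max(s.shape[-1] for s in chunk)
        ids, masks = [], []
        for s in chunk:
            pad_len = max_len - s.shape[-1]
            # Mirrors get_pad_setting: left padding prepends pad tokens,
            # right padding appends them (F.pad takes [left, right] for the last dim).
            pad = [pad_len, 0] if padding_side == 'left' else [0, pad_len]
            ids.append(F.pad(s, pad, value=pad_token_id))
            masks.append(F.pad(torch.ones(1, s.shape[-1], dtype=torch.int64), pad, value=0))
        batches.append({'input_ids': torch.cat(ids, dim=0),
                        'attention_mask': torch.cat(masks, dim=0)})
    return batches


# Usage: three variable-length "tokenized" samples grouped into batches of two.
# Token ids are drawn from [5, 100) so they never collide with pad id 0.
samples = [torch.randint(5, 100, (1, n)) for n in (7, 4, 9)]
for batch in group_with_mask(samples, pad_token_id=0):
    print(batch['input_ids'].shape, batch['attention_mask'].sum(dim=-1))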