-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathrun_commonsense.py
656 lines (574 loc) · 25.4 KB
/
run_commonsense.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
#!/usr/bin/env python
# coding=utf-8
import copy
import json
import logging
import os
import re
import sys
from dataclasses import dataclass
from dataclasses import field
from typing import Optional
import datasets
import torch
import transformers
from datasets import load_dataset
from peft import LoraConfig
from peft import PeftModel
from peft import get_peft_model
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import GenerationConfig
from transformers import HfArgumentParser
from transformers import Trainer
from transformers import TrainingArguments
from transformers import set_seed
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from nncf import NNCFConfig
from nncf.experimental.torch.nas.bootstrapNAS.elasticity.elasticity_dim import ElasticityDim
from nncf.experimental.torch.nas.bootstrapNAS.training.model_creator_helpers import (
create_compressed_model_from_algo_names,
)
from nncf.experimental.torch.nas.bootstrapNAS import BaseSearchAlgorithm
from nncf.torch.model_creation import create_nncf_network
# Abort early if the installed transformers is older than this script requires.
check_min_version("4.31.0")

logger = logging.getLogger(__name__)

# Benchmarks evaluated when --do_test is set; each is read from datasets/<name>/test.json.
TEST_DATASETS = [
    "boolq",
    "piqa",
    "social_i_qa",
    "winogrande",
    "ARC-Easy",
    "ARC-Challenge",
    "openbookqa",
    "hellaswag",
]
@dataclass
class LonasTrainingArguments(TrainingArguments):
    """``TrainingArguments`` plus the LoRA and test options used by this script.

    NOTE(review): main() also reads ``nncf_config`` and ``do_search`` from these
    arguments. They are not declared here, so presumably the installed (patched)
    ``TrainingArguments`` base provides them -- confirm.
    """

    # Passed to LoraConfig(r=...) when new adapters are created.
    lora_r: int = field(default=32, metadata={"help": "Lora R dimension."})
    # Passed to LoraConfig(lora_alpha=...).
    lora_alpha: float = field(default=64, metadata={"help": " Lora alpha."})
    # Passed to LoraConfig(lora_dropout=...).
    lora_dropout: float = field(default=0.0, metadata={"help": "Lora dropout."})
    # Comma-separated module names; main() splits on "," for LoraConfig(target_modules=...).
    target_modules: str = field(
        default="q_proj,v_proj", metadata={"help": "Which module will be added the lora adapter."}
    )
    # Enable LoRA: create fresh adapters, or load ModelArguments.lora_weights when that is set.
    lora: bool = field(default=False, metadata={"help": "Whether to apply lora or not."})
    # When False, label ids covering the prompt are set to -100 (the id ignored by the
    # HF causal-LM loss), so only the response tokens are supervised.
    train_on_inputs: bool = field(default=True)
    # Evaluate on every dataset in TEST_DATASETS.
    do_test: bool = field(default=False)
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    Using `HfArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.

    In this script the training data is read as JSON from ``dataset_path``. ``val_set_size``
    rows are held out for validation, and prompts are tokenized to at most ``cutoff_len`` tokens.
    """

    # JSON file(s) passed to ``load_dataset("json", data_files=...)`` in main().
    dataset_path: Optional[str] = field(default=None, metadata={"help": "The path of the dataset to use."})
    # NOTE(review): not read elsewhere in this script.
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    # Number of rows split off (seed 42) as the validation set; 0 disables the split.
    val_set_size: int = field(default=120)
    # NOTE(review): not read elsewhere in this script; tokenization length is cutoff_len.
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": (
                "The maximal total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    # NOTE(review): not read elsewhere in this script.
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
    )
    # NOTE(review): not read elsewhere in this script.
    pad_to_max_length: bool = field(
        default=True,
        metadata={
            "help": (
                "Whether to pad all samples to `max_seq_length`. "
                "If False, will pad the samples dynamically when batching to the maximal length in the batch."
            )
        },
    )
    # Also used by main() when reporting the "train_samples" training metric.
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    # NOTE(review): not read elsewhere in this script.
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of prediction examples to this "
                "value if set."
            )
        },
    )
    # NOTE(review): train_file / validation_file / test_file are not read elsewhere in this
    # script; training data comes from dataset_path, test data from datasets/<name>/test.json.
    train_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the training data."}
    )
    validation_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the validation data."}
    )
    test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."})
    # Maximum token length used by tokenize() (truncation and the EOS-append check).
    cutoff_len: int = field(
        default=256,
    )
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    # Base model. Also used to load the tokenizer and recorded as ``finetuned_from``.
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    # Existing LoRA adapter loaded with ``PeftModel.from_pretrained`` when --lora is set.
    # When None and --lora is set, fresh adapters are created instead.
    lora_weights: Optional[str] = field(default=None)
    # NOTE(review): config_name, tokenizer_name, use_fast_tokenizer, model_revision,
    # use_auth_token and ignore_mismatched_sizes are not read elsewhere in this script.
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    # Passed as cache_dir when loading both the model and the tokenizer.
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": (
                "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
                "with private models)."
            )
        },
    )
    ignore_mismatched_sizes: bool = field(
        default=False,
        metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
    )
def main():
    """Fine-tune, test and/or run a BootstrapNAS search for a causal LM on commonsense data.

    Arguments come either from a single ``.json`` file passed as the only CLI argument, or
    from regular command-line flags. Both are parsed into ``ModelArguments``,
    ``DataTrainingArguments`` and ``LonasTrainingArguments``.
    """
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, LonasTrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # A lone ``*.json`` argument is a file holding every argument.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    if training_args.should_log:
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary.
    # The two f-strings are concatenated, so the first one ends with ", ". Without it the
    # n_gpu value and "distributed training" run together in the log line.
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")
# Look for a checkpoint left in output_dir by an earlier run so training can resume from it.
last_checkpoint = None
can_resume = (
    os.path.isdir(training_args.output_dir)
    and training_args.do_train
    and not training_args.overwrite_output_dir
)
if can_resume:
    last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is None and os.listdir(training_args.output_dir):
        # Non-empty directory without a recognisable checkpoint: refuse to clobber it.
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )
    if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(
            f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
            "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
        )

# Seed before the model is created so that initialisation is reproducible.
set_seed(training_args.seed)
# Load the base model in fp16 with device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    load_in_8bit=False,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
    cache_dir=model_args.cache_dir,
)
if training_args.lora:
    if model_args.lora_weights is None:
        # Fresh adapters on the comma-separated modules named by --target_modules.
        logger.info("adding LoRA modules...")
        lora_config = LoraConfig(
            r=training_args.lora_r,
            lora_alpha=training_args.lora_alpha,
            lora_dropout=training_args.lora_dropout,
            target_modules=training_args.target_modules.split(","),
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()
    else:
        # Re-use previously trained adapters.
        logger.info("Loading LoRA modules...")
        model = PeftModel.from_pretrained(
            model, model_args.lora_weights, torch_dtype=torch.float16, device_map="auto"
        )
# Optional NNCF / BootstrapNAS setup, driven by training_args.nncf_config (a path given to
# NNCFConfig.from_json).
nncf_config = None
if training_args.nncf_config is not None:
    nncf_config = NNCFConfig.from_json(training_args.nncf_config)
    if nncf_config.get("log_dir") is None:
        nncf_config["log_dir"] = training_args.output_dir
    # Create the directory NNCF logs to, on the main process only.
    # The old guard tested whether *output_dir* existed but then created *log_dir*. That meant
    # a custom log_dir was never created when output_dir already existed. It also meant
    # makedirs raised FileExistsError when log_dir existed but output_dir did not.
    if training_args.local_rank in [-1, 0]:
        os.makedirs(nncf_config["log_dir"], exist_ok=True)

compression_ctrl = None
if nncf_config is not None:
    # Wrap the model for NNCF and build the BootstrapNAS training algorithm named in the
    # config (default "progressive_shrinking").
    nncf_network = create_nncf_network(model, nncf_config)
    algo_name = nncf_config.get("bootstrapNAS", {}).get("training", {}).get("algorithm", "progressive_shrinking")
    compression_ctrl, model = create_compressed_model_from_algo_names(
        nncf_network, nncf_config, algo_names=[algo_name]
    )
# Tokenizer shared by training-data preparation and evaluation-time generation.
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
tokenizer.padding_side = "left"  # Allow batched inference
tokenizer.pad_token_id = 0  # token id 0 is used as the padding id
# Load data: helpers that turn raw instruction records into training features.
def tokenize(prompt, add_eos_token=True):
    """Tokenize ``prompt`` (truncated to ``data_args.cutoff_len``) and attach causal-LM labels.

    An EOS id is appended only when all three hold: the sequence does not already end with
    EOS, it is shorter than ``cutoff_len``, and ``add_eos_token`` is truthy. ``labels`` is a
    copy of the final ``input_ids``.
    """
    encoded = tokenizer(
        prompt,
        truncation=True,
        max_length=data_args.cutoff_len,
        padding=True,
        return_tensors=None,
    )
    ids = encoded["input_ids"]
    eos_id = tokenizer.eos_token_id
    needs_eos = ids[-1] != eos_id and len(ids) < data_args.cutoff_len and add_eos_token
    if needs_eos:
        ids.append(eos_id)
        encoded["attention_mask"].append(1)
    encoded["labels"] = ids.copy()
    return encoded
def generate_prompt(data_point):
    """Render a training record as an instruction / (input) / response prompt.

    ``data_point`` must provide ``instruction``, ``input`` and ``output``. The
    ``### Input:`` section is emitted only when ``input`` is truthy.
    """
    context = data_point["input"]
    instruction = data_point["instruction"]
    response = data_point["output"]
    if not context:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Response:
{response}"""
    return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Input:
{context}
### Response:
{response}"""
def generate_and_tokenize_prompt(data_point):
    """Build and tokenize the full training prompt for ``data_point``.

    When ``training_args.train_on_inputs`` is False, the labels covering the prompt (the
    same record rendered with an empty ``output``) are replaced by -100. That is the label
    id ignored by the HF causal-LM loss, so only the response tokens are supervised.
    """
    tokenized = tokenize(generate_prompt(data_point))
    if training_args.train_on_inputs:
        return tokenized
    prompt_only = generate_prompt({**data_point, "output": ""})
    prompt_len = len(tokenize(prompt_only, add_eos_token=False)["input_ids"])
    tokenized["labels"] = [-100] * prompt_len + tokenized["labels"][prompt_len:]
    return tokenized
# Build train / validation splits from the JSON instruction data (needed for training and search).
train_dataset, eval_dataset = None, None
if training_args.do_train or training_args.do_search:

    def _shuffle_cap_tokenize(split, max_samples):
        """Shuffle ``split``, keep at most ``max_samples`` rows (all when None), then tokenize."""
        split = split.shuffle()
        if max_samples is not None:
            split = split.select(range(min(len(split), max_samples)))
        return split.map(generate_and_tokenize_prompt)

    data = load_dataset("json", data_files=data_args.dataset_path)
    val_set_size = data_args.val_set_size
    # Apply the --max_train_samples / --max_eval_samples caps promised by the
    # DataTrainingArguments help text. The "train_samples" metric reported after
    # training already assumes this truncation.
    if val_set_size > 0:
        # Fixed seed so the held-out validation rows are identical on every run.
        train_val = data["train"].train_test_split(test_size=val_set_size, shuffle=True, seed=42)
        train_dataset = _shuffle_cap_tokenize(train_val["train"], data_args.max_train_samples)
        eval_dataset = _shuffle_cap_tokenize(train_val["test"], data_args.max_eval_samples)
    else:
        train_dataset = _shuffle_cap_tokenize(data["train"], data_args.max_train_samples)
        eval_dataset = None
# Initialize our Trainer. It is given a compression_ctrl keyword, which stock transformers'
# Trainer does not take -- presumably a patched Trainer is installed (confirm).
collator = transformers.DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset if training_args.do_train else None,
    eval_dataset=eval_dataset if training_args.do_eval else None,
    data_collator=collator,
    compression_ctrl=compression_ctrl,
)
# With NNCF, processes whose local_rank is not -1/0 switch the controller to distributed
# mode, unless CUDA is disabled.
# NOTE(review): rank 0 never calls distributed() -- confirm that is intended.
if nncf_config is not None and training_args.local_rank not in [-1, 0] and not training_args.no_cuda:
    compression_ctrl.distributed()
# Turn off use_cache (the generation KV cache) in the model config.
model.config.use_cache = False
# Training
if training_args.do_train:
    # An explicit --resume_from_checkpoint wins over one auto-detected in output_dir.
    checkpoint = training_args.resume_from_checkpoint
    if checkpoint is None:
        checkpoint = last_checkpoint
    train_result = trainer.train(resume_from_checkpoint=checkpoint)

    metrics = train_result.metrics
    cap = data_args.max_train_samples
    n_train = len(train_dataset)
    metrics["train_samples"] = n_train if cap is None else min(cap, n_train)

    # NOTE(review): no tokenizer is passed to Trainer above, so save_model() presumably
    # saves only the model, not the tokenizer -- confirm.
    trainer.save_model()
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()
def extract_answer(dataset_name, sentence: str) -> str:
    """Return the first answer token for ``dataset_name`` that occurs in ``sentence``.

    Each supported dataset has a fixed set of answer tokens: ``true``/``false`` (boolq),
    ``solution1``/``solution2`` (piqa), ``answer1``..``answer5`` (social_i_qa, ARC-*,
    openbookqa), ``ending1``..``ending4`` (hellaswag) and ``option1``/``option2``
    (winogrande). Matching is case-sensitive.

    Returns:
        The leftmost matching token. Returns ``""`` when no token occurs, or when
        ``dataset_name`` is not a supported dataset (e.g. ``None``).
    """
    multiple_choice = r"answer1|answer2|answer3|answer4|answer5"
    choice_patterns = {
        "boolq": r"true|false",
        "piqa": r"solution1|solution2",
        "social_i_qa": multiple_choice,
        "ARC-Challenge": multiple_choice,
        "ARC-Easy": multiple_choice,
        "openbookqa": multiple_choice,
        "hellaswag": r"ending1|ending2|ending3|ending4",
        "winogrande": r"option1|option2",
    }
    pattern = choice_patterns.get(dataset_name)
    if pattern is None:
        # Unknown dataset: report "no prediction" with the same value as "no match".
        return ""
    match = re.search(pattern, sentence.strip())
    return match.group(0) if match else ""
def load_test_data(test_dataset) -> list:
    """Read the test split of ``test_dataset`` from ``datasets/<test_dataset>/test.json``.

    The path is relative to the current working directory. The file is decoded as UTF-8
    (the JSON standard encoding) and closed before returning.

    Raises:
        FileNotFoundError: if the dataset file does not exist.
    """
    file_path = f"datasets/{test_dataset}/test.json"
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"can not find dataset file : {file_path}")
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)
def generate_prompt_eval(instruction, input=None):
    """Render an evaluation prompt that stops right after ``### Response:`` and a newline.

    This is the training template without the response, so the model writes the answer.
    The ``### Input:`` section is included only when ``input`` is truthy.
    """
    if not input:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Response:
"""
    return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Input:
{input}
### Response:
"""
def evaluate_one_sample(
    instruction,
    input=None,
    model=None,
    temperature=0.1,
    top_p=0.75,
    top_k=40,
    num_beams=4,
    max_new_tokens=32,
    **kwargs,
):
    """Generate a response for one instruction and return the answer text.

    Args:
        instruction: Task text inserted into the evaluation prompt.
        input: Optional context; adds an ``### Input:`` section when truthy.
        model: Model whose ``generate`` is called; inputs are moved to ``model.device``.
        temperature, top_p, top_k, num_beams, **kwargs: Forwarded to ``GenerationConfig``.
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        The decoded text between the first ``### Response:`` marker and the next one
        (if any), stripped of surrounding whitespace.
    """
    prompt = generate_prompt_eval(instruction, input)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(model.device)
    # Give generate() the tokenizer's real attention mask rather than letting it infer one
    # from the pad token id. transformers warns that an inferred mask may be unreliable.
    attention_mask = inputs.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(model.device)
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        **kwargs,
    )
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
            use_cache=True,
        )
    sequence = generation_output.sequences[0]
    output = tokenizer.decode(sequence)
    return output.split("### Response:")[1].strip()
def evaluate(model, dataset_name, save_file):
    """Score ``model`` on the test split of ``dataset_name`` and return its accuracy.

    Each sample is deep-copied and annotated with the raw generation (``output_pred``), the
    extracted prediction (``pred``) and a correctness flag (``flag``). The accumulated list
    is re-written to ``save_file`` as JSON after every sample, and progress is printed.
    """
    model.eval()
    samples = load_test_data(dataset_name)
    n_samples = len(samples)
    n_correct = 0
    records = []
    for step, sample in enumerate(samples, start=1):
        output = evaluate_one_sample(sample.get("instruction"), model=model)
        label = sample.get("answer")
        predict = extract_answer(dataset_name, output)
        hit = label == predict
        if hit:
            n_correct += 1

        record = copy.deepcopy(sample)
        record.update(output_pred=output, pred=predict, flag=hit)
        records.append(record)

        print(sample["instruction"])
        print(output)
        print("prediction:", predict)
        print("label:", label)
        print(f"\rtest:{step}/{n_samples} | accuracy {n_correct} {n_correct / step}")
        # Re-written every step, so results so far are on disk if the run stops early.
        with open(save_file, "w+") as f:
            json.dump(records, f, indent=4)
    return n_correct / n_samples
def test_subnetwork(subnetwork, name):
    """Log size statistics for ``subnetwork`` and evaluate it on every dataset in TEST_DATASETS.

    All metric keys are prefixed with ``name``. Per-dataset results go to
    ``<output_dir>/<name>.<dataset>.res.json``. Accuracy metrics are saved after each
    dataset, and an average over all datasets is added at the end.
    """
    logger.info(f"*** Evaluation - {name} ***")
    non_zero_params = sum((p.data != 0).sum().item() for _, p in subnetwork.named_parameters())
    handler = trainer.compression_ctrl.multi_elasticity_handler
    macs, weights = handler.count_flops_and_weights_for_active_subnet()
    size_metrics = {
        f"{name}_non_zero_params": non_zero_params,
        # NOTE(review): first value is divided by 2e6 -- presumably FLOPs -> millions of MACs; confirm.
        f"{name}_macs": str(macs / 2000000),
        f"{name}_weights": str(weights),
    }
    trainer.save_metrics("eval", size_metrics)
    trainer.log_metrics("eval", size_metrics)

    accuracies = []
    metrics = {}
    for dataset_name in TEST_DATASETS:
        logger.info(f"*** Evaluation on {dataset_name} ***")
        result_path = os.path.join(training_args.output_dir, f"{name}.{dataset_name}.res.json")
        accuracy = evaluate(subnetwork, dataset_name, result_path)
        accuracies.append(accuracy)
        metrics[f"{name}_{dataset_name}_accuracy"] = accuracy
        trainer.save_metrics("eval", metrics)
    metrics[f"{name}_avg_accuracy"] = sum(accuracies) / len(accuracies)
    trainer.save_metrics("eval", metrics)
    trainer.log_metrics("eval", metrics)
# Test on the commonsense benchmarks (main process only).
if training_args.do_test and training_args.local_rank <= 0:
    if compression_ctrl is not None:
        trainer.compression_ctrl.multi_elasticity_handler.enable_all()
        compression_ctrl.multi_elasticity_handler.width_handler.width_num_params_indicator = -1
        # Heuristic subnetwork: the middle element of every width search-space list.
        heuristic_config = {
            k: v[(len(v) - 1) // 2] for k, v in compression_ctrl.multi_elasticity_handler.width_search_space.items()
        }
        heuristic_config = {ElasticityDim.WIDTH: heuristic_config}
        trainer.compression_ctrl.multi_elasticity_handler.activate_subnet_for_config(heuristic_config)
        test_subnetwork(trainer.model, "Heuristic")
    else:
        # evaluate() only runs generate() under torch.no_grad(), so the weights -- and hence
        # the non-zero count -- are the same for every dataset. Count once, not per dataset.
        non_zero_params = sum((param.data != 0).sum().item() for _, param in trainer.model.named_parameters())
        all_results = []
        for test_dataset in TEST_DATASETS:
            logger.info(f"*** Evaluation on {test_dataset} ***")
            save_file = os.path.join(training_args.output_dir, f"{test_dataset}.res.json")
            accuracy = evaluate(trainer.model, test_dataset, save_file)
            all_results.append(accuracy)
            metrics = {
                f"{test_dataset}_accuracy": accuracy,
                "non_zero_params": non_zero_params,
            }
            trainer.save_metrics("eval", metrics)
        avg_metrics = {
            "avg_accuracy": sum(all_results) / len(all_results),
        }
        trainer.save_metrics("eval", avg_metrics)
        trainer.log_metrics("eval", avg_metrics)
# Searching
if training_args.do_search and nncf_config is not None and training_args.local_rank <= 0:
    logger.info("*** Search ***")
    if eval_dataset is None:
        # Every candidate is scored on the held-out split. Without one (val_set_size <= 0),
        # validate_model_fn would fail later with an opaque TypeError on None.
        raise ValueError("--do_search needs a validation split; set --val_set_size to a positive value.")
    trainer.compression_ctrl.multi_elasticity_handler.enable_all()
    search_algo = BaseSearchAlgorithm.from_config(trainer.model, trainer.compression_ctrl, nncf_config)

    def validate_model_fn(model_, eval_dataset):
        """Return the exact-match accuracy of ``model_`` on ``eval_dataset``.

        Each record needs ``instruction`` and ``answer``. The answer's format selects the
        dataset-specific pattern that ``extract_answer`` applies to the generated text.
        """
        correct = 0
        for data in eval_dataset:
            instruction = data.get("instruction")
            output = evaluate_one_sample(instruction, model=model_)
            label = data.get("answer")
            dataset_name = None
            # Infer the source dataset from the label format.
            # TODO: Refactor hard-coded values for better flexibility and maintainability.
            if label in ["true", "false"]:
                dataset_name = "boolq"
            elif "solution" in label:
                dataset_name = "piqa"
            elif "answer" in label:
                dataset_name = "social_i_qa"  # "ARC-Challenge", "ARC-Easy", "openbookqa"
            elif "ending" in label:
                dataset_name = "hellaswag"
            elif "option" in label:
                dataset_name = "winogrande"
            predict = extract_answer(dataset_name, output)
            if label == predict:
                correct += 1
        acc = correct / len(eval_dataset)
        return acc

    # Test Maximal subnetwork and Heuristic subnetwork on the validation dataset
    # Maximal
    trainer.compression_ctrl.multi_elasticity_handler.activate_supernet()
    max_eval_acc = validate_model_fn(trainer.model, eval_dataset)
    # Heuristic: the middle element of every width search-space list.
    compression_ctrl.multi_elasticity_handler.width_handler.width_num_params_indicator = -1
    heuristic_config = {
        k: v[(len(v) - 1) // 2] for k, v in compression_ctrl.multi_elasticity_handler.width_search_space.items()
    }
    heuristic_config = {ElasticityDim.WIDTH: heuristic_config}
    trainer.compression_ctrl.multi_elasticity_handler.activate_subnet_for_config(heuristic_config)
    heu_eval_acc = validate_model_fn(trainer.model, eval_dataset)
    metrics = {
        "val_maximal_accuracy": max_eval_acc,
        "val_heuristic_accuracy": heu_eval_acc,
    }
    trainer.save_metrics("eval", metrics)
    trainer.log_metrics("eval", metrics)

    # Run the search; validate_model_fn is handed to it (presumably to score candidates).
    elasticity_ctrl, best_config, performance_metrics = search_algo.run(
        validate_model_fn, eval_dataset, training_args.output_dir
    )
    search_algo.search_progression_to_csv()
    search_algo.evaluators_to_csv()
    search_algo.visualize_search_progression()
    logger.info("Best config: {best_config}".format(best_config=best_config))
    logger.info("Performance metrics: {performance_metrics}".format(performance_metrics=performance_metrics))
    trainer.save_metrics("eval", {"performance_metrics": list(performance_metrics)})

    # test best config
    trainer.compression_ctrl.multi_elasticity_handler.activate_subnet_for_config(best_config)
    best_eval_acc = validate_model_fn(trainer.model, eval_dataset)
    trainer.save_metrics("eval", {"val_best_accuracy": best_eval_acc})
# Record the base model, then either push to the Hub or only write a local model card.
publish = trainer.push_to_hub if training_args.push_to_hub else trainer.create_model_card
publish(finetuned_from=model_args.model_name_or_path)
def _mp_fn(index):
    """Entry point for ``xla_spawn`` (TPUs); ``index`` is supplied by the spawner and ignored."""
    return main()
if __name__ == "__main__":
    # Script entry point: run the full parse / train / test / search pipeline.
    main()