-
Notifications
You must be signed in to change notification settings - Fork 59
/
Copy pathcli_lean4_random.yaml
51 lines (49 loc) · 1.24 KB
/
cli_lean4_random.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# Global random seed for the run; the choice of 3407 follows the paper linked
# on this line. (Presumably consumed by the Lightning CLI's `seed_everything`
# option -- confirm against the entry point.)
seed_everything: 3407 # https://arxiv.org/abs/2109.08203
# Trainer settings: one GPU, bf16 mixed precision, DeepSpeed stage 2.
# Nesting matters here: each `class_path`/`init_args` pair must sit under its
# own parent key, otherwise the keys collide as duplicates at a single level.
trainer:
  accelerator: gpu
  devices: 1
  precision: bf16-mixed
  # DeepSpeed stage 2, with optimizer offload and CPU checkpointing disabled.
  strategy:
    class_path: pytorch_lightning.strategies.DeepSpeedStrategy
    init_args:
      stage: 2
      offload_optimizer: false
      cpu_checkpointing: false
  # Weights & Biases logger; run name and save directory are left unset.
  logger:
    class_path: pytorch_lightning.loggers.WandbLogger
    init_args:
      name: null
      save_dir: null
  gradient_clip_val: 1.0
  max_steps: 800000
  callbacks:
    # Log the learning rate at every step.
    - class_path: pytorch_lightning.callbacks.LearningRateMonitor
      init_args:
        logging_interval: step
    # Keep the single best checkpoint by Recall@10_val (higher is better),
    # plus the most recent checkpoint.
    - class_path: pytorch_lightning.callbacks.ModelCheckpoint
      init_args:
        verbose: true
        save_top_k: 1
        save_last: true
        monitor: Recall@10_val
        mode: max
    # Early stopping on the same metric, with a patience of 5.
    - class_path: pytorch_lightning.callbacks.EarlyStopping
      init_args:
        monitor: Recall@10_val
        patience: 5
        mode: max
        verbose: true
# Model arguments. `model_name` is presumably a Hugging Face model id --
# confirm against the model class that consumes this section.
model:
  model_name: google/byt5-small
  # Written with an explicit decimal point: YAML 1.1 loaders (e.g. PyYAML)
  # read a bare `1e-4` as a string, whereas `1.0e-4` is always a float.
  lr: 1.0e-4
  warmup_steps: 2000
  num_retrieved: 100
# Data arguments: dataset and corpus locations, negative-example counts,
# train/eval batch sizes, sequence-length cap and number of loader workers.
# Exact semantics are defined by the consuming data module (not shown here).
data:
  data_path: data/leandojo_benchmark_4/random/
  corpus_path: data/leandojo_benchmark_4/corpus.jsonl
  num_negatives: 3
  num_in_file_negatives: 1
  batch_size: 8
  eval_batch_size: 64
  max_seq_len: 1024
  num_workers: 4