-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathgist.py
101 lines (97 loc) · 3.58 KB
/
gist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class GISTArguments:
    """
    GIST-specific training options (all fields are prefixed with ``gist_``).

    The fields cover loss selection, embedding normalization, the contrastive
    loss temperature (fixed or scheduled), the training data file, base-model
    freezing, the optional guide model and negative-sampling mode.

    Every field carries a ``help`` entry in its ``metadata``; presumably this
    is consumed by an argument parser that builds CLI flags from the dataclass
    (TODO confirm against the caller). Fields whose default is ``None`` are
    annotated ``Optional[...]`` so that the annotation matches the default.
    """

    # --- Loss selection and output shape ---------------------------------
    # NOTE(review): the help text says "defaults to contrastive" but the field
    # default is None; presumably the consumer maps None to contrastive.
    gist_loss_type: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The type of loss to use, defaults to contrastive. Can be contrastive, improved_contrastive, triplet_contrastive, orthogonal, hierarchical_contrastive, guided, guided-triplet, or guided-triplet-soft."
            )
        },
    )
    gist_output_dim: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "The output dimension of the model. When None, it will be the same as the experts."
            )
        },
    )

    # --- Normalization ----------------------------------------------------
    gist_normalize: bool = field(
        default=True,
        metadata={"help": "Whether to normalize the embeddings."},
    )
    # NOTE(review): the help text refers to `gist_normalize_experts`, which is
    # not a field of this class -- possibly `gist_normalize` was meant; confirm.
    gist_denormalize_experts: bool = field(
        default=False,
        metadata={"help": "Whether to denormalize the experts. This invalidates the use of gist_normalize_experts."},
    )

    # --- Contrastive-loss temperature schedule ----------------------------
    # NOTE(review): the help text contains a typo ("overrider"); it is left
    # unchanged here because it is a runtime string.
    gist_schedule_cl_temperature: bool = field(
        default=False,
        metadata={"help": "Whether to schedule the contrastive loss temperature. This will overrider the cl_temperature argument."},
    )
    gist_cl_temperature_decay_rate: float = field(
        default=0.9999,
        metadata={"help": "The decay rate for the contrastive loss temperature."},
    )
    gist_cl_temperature_init: float = field(
        default=1.0,
        metadata={"help": "The initial contrastive loss temperature."},
    )
    gist_cl_temperature_min: float = field(
        default=0.001,
        metadata={"help": "The minimum contrastive loss temperature."},
    )

    # --- Miscellaneous loss / data / training options ---------------------
    gist_orthogonal_loss_margin: float = field(
        default=0.0,
        metadata={"help": "The margin for the cosine/orthogonal loss."},
    )
    gist_use_query_instruction: bool = field(
        default=False,
        metadata={"help": "Whether to use query instruction."},
    )
    gist_medi_data_name: str = field(
        default="medi-data.json",
        metadata={"help": "The name of the medi data."},
    )
    gist_hcl_num_subembeddings: int = field(
        default=1,
        metadata={"help": "The number of subembeddings for the hierarchical contrastive loss."},
    )
    gist_freeze_base_num_steps: int = field(
        default=0,
        metadata={"help": "The number of steps to freeze the base model. If 0, the base model will not be frozen."},
    )

    # --- Optional string settings (default None, hence Optional[str]) ------
    gist_guide_model_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "The guide model for identifying hard negatives. If this is provided, the `MixEmbGuidedTrainer` will be used."},
    )
    gist_medi_data_name_revision: Optional[str] = field(
        default=None,
        metadata={"help": "The revision for the dataset if medi_data_name is from Hf Hub."},
    )
    gist_script_id: Optional[str] = field(
        default=None,
        metadata={"help": "The script id is for validating the parameters."},
    )

    # --- Fixed temperature / margin / pooling / negatives ------------------
    gist_cl_temperature: Optional[float] = field(
        default=None,
        metadata={"help": "contrastive temperature"},
    )
    gist_tl_margin: Optional[float] = field(
        default=None,
        metadata={"help": "margin for triplet loss"},
    )
    gist_auto_model_pooling: Optional[str] = field(
        default="mean",
        metadata={"help": "auto model pooling"},
    )
    gist_negative_mode: Optional[str] = field(
        default="all",
        metadata={"help": "negative mode. Can be all, hard, or hard+random"},
    )

    def __post_init__(self):
        """Placeholder hook: no post-initialization processing is performed."""
        pass