-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathprompts_optimization_comfig.py
50 lines (48 loc) · 1.6 KB
/
prompts_optimization_comfig.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from dataclasses import dataclass, field
from typing import Optional, List
@dataclass
class PromptsOptimizationConfig:
    """Hyper-parameters for optimizing a set of prompts.

    Each field carries a ``metadata["help"]`` string, presumably consumed by an
    argparse-style dataclass parser (e.g. ``transformers.HfArgumentParser``) —
    NOTE(review): confirm against the caller.

    Construction validates the values that would otherwise only fail later
    during optimization (see ``__post_init__``).
    """

    # How many prompts are optimized.
    num_prompts: Optional[int] = field(
        default=3,
        metadata={"help": "Number of optimized prompts."},
    )
    # Length of every optimized prompt (unit presumably tokens — TODO confirm).
    prompt_len: Optional[int] = field(
        default=128,
        metadata={"help": "Prompt length of each prompt"},
    )
    # Weight of the auxiliary penalty on similarity between prompts.
    dissim_coef: Optional[float] = field(
        default=0.3,
        metadata={"help": "Used in aux loss for prompts similarity penalty"},
    )
    # Weight of the auxiliary penalty on using forbidden (special) tokens.
    special_token_coef: Optional[float] = field(
        default=0.8,
        metadata={
            "help": "Used in aux loss for penalty of using forbidden (special) tokens"
        },
    )
    # Softmax temperature for the Gumbel-softmax trick; must be > 0.
    gumbel_temp: Optional[float] = field(
        default=0.5,
        metadata={"help": "Temperature for gumbel softmax trick"},
    )
    # Multiplier applied to the Gumbel noise added inside the softmax.
    gumbel_noise_scale: Optional[float] = field(
        default=0.05,
        metadata={"help": "Multiplier of added gumbel noise inside softmax"},
    )
    # Token ids that must not appear in generated prompts; None means no
    # explicit list (None default avoids a shared mutable default).
    forbidden_token_ids: Optional[List[int]] = field(
        default=None,
        metadata={"help": "List of ids of forbidden tokens in created prompts"},
    )
    # Chat role under which the optimized prompts are inserted when templating.
    inserted_chat_role: str = field(
        default="system",
        metadata={"help": "Chat role used for templating of created prompts insertion"},
    )
    # True: a single batched forward pass (uses more memory);
    # False: a Python loop instead.
    fused_forward: bool = field(
        default=True,
        metadata={
            "help": "Use full in-batch forward, instead of for loop, memory usage increase."
        },
    )
    # Optional text used as the starting point of the optimization.
    init_prompt: Optional[str] = field(
        default=None,
        metadata={"help": "Prompt to init optimization from"},
    )

    def __post_init__(self):
        """Reject values that cannot produce a meaningful optimization run.

        ``None`` is still accepted for the ``Optional`` numeric fields so that
        existing callers passing ``None`` explicitly are not broken; only
        concrete out-of-range values are rejected.

        Raises:
            ValueError: if ``num_prompts`` or ``prompt_len`` is less than 1,
                or if ``gumbel_temp`` is not strictly positive.
        """
        if self.num_prompts is not None and self.num_prompts < 1:
            raise ValueError(f"num_prompts must be >= 1, got {self.num_prompts!r}")
        if self.prompt_len is not None and self.prompt_len < 1:
            raise ValueError(f"prompt_len must be >= 1, got {self.prompt_len!r}")
        # A softmax temperature of zero or below is not meaningful; zero would
        # divide by zero if logits are scaled by 1/temp (the scaling code
        # lives outside this file — assumption, verify).
        if self.gumbel_temp is not None and self.gumbel_temp <= 0:
            raise ValueError(f"gumbel_temp must be > 0, got {self.gumbel_temp!r}")