|
1 | 1 | from __future__ import annotations
|
2 |
| -import sys, os |
| 2 | + |
| 3 | +import os |
| 4 | +import sys |
| 5 | + |
3 | 6 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
4 | 7 | from human_eval.data import write_jsonl, read_problems
|
5 | 8 | from exllamav2 import model_init
|
|
21 | 24 | parser.add_argument("-pf", "--prompt_format", type = str, help = "Instruct format to apply. Default is raw completion (for base models) ")
|
22 | 25 | parser.add_argument("-v", "--verbose", action = "store_true", help = "Spam completions to console while generating")
|
23 | 26 | parser.add_argument("-e", "--eval", action = "store_true", help = "Run evaluation script on output file after sampling")
|
24 |
| -parser.add_argument("-temp", "--temperature", type = float, help = "Sampling temperature (0 for greedy), default: 0.6") |
| 27 | +parser.add_argument("-temp", "--temperature", type = float, help = "Sampling temperature (0 for greedy), default: 0.6", default = 0.6) |
| 28 | +parser.add_argument("--top_k", type = int, help = "Top-k sampling, default: 50", default = 50) |
| 29 | +parser.add_argument("--top_p", type = float, help = "Top-p sampling, default: 0.6", default = 0.6) |
| 30 | +parser.add_argument("-trp", "--token_repetition_penalty", type = float, help = "Token repetition penalty, default: 1.0", default = 1.0) |
25 | 31 | model_init.add_args(parser)
|
26 | 32 | args = parser.parse_args()
|
27 | 33 |
|
|
118 | 124 | )
|
119 | 125 |
|
120 | 126 | gen_settings = ExLlamaV2Sampler.Settings(
|
121 |
| - token_repetition_penalty = 1.0, |
122 |
| - temperature = 0.6, |
123 |
| - top_k = 50, |
124 |
| - top_p = 0.6 |
| 127 | + token_repetition_penalty=args.token_repetition_penalty, |
| 128 | + temperature=args.temperature, |
| 129 | + top_k=args.top_k, |
| 130 | + top_p=args.top_p |
125 | 131 | )
|
126 | 132 |
|
127 | 133 | # Get problems
|
|
0 commit comments