Skip to content

Commit 3a38913

Browse files
committed
Add more arguments to accept values passed via the cmd line.
1 parent e960dfd commit 3a38913

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

eval/humaneval.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22
import sys, os
3+
from email.policy import default
4+
35
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
46
from human_eval.data import write_jsonl, read_problems
57
from exllamav2 import model_init
@@ -21,7 +23,10 @@
2123
parser.add_argument("-pf", "--prompt_format", type = str, help = "Instruct format to apply. Default is raw completion (for base models) ")
2224
parser.add_argument("-v", "--verbose", action = "store_true", help = "Spam completions to console while generating")
2325
parser.add_argument("-e", "--eval", action = "store_true", help = "Run evaluation script on output file after sampling")
24-
parser.add_argument("-temp", "--temperature", type = float, help = "Sampling temperature (0 for greedy), default: 0.6")
26+
parser.add_argument("-temp", "--temperature", type = float, help = "Sampling temperature (0 for greedy), default: 0.6", default = 0.6)
27+
parser.add_argument("--top_k", type = int, help = "Top-k sampling, default: 50", default = 50)
28+
parser.add_argument("--top_p", type = float, help = "Top-p sampling, default: 0.6", default = 0.6)
29+
parser.add_argument("-trp", "--token_repetition_penalty", type = float, help = "Token repetition penalty, default: 1.0", default = 1.0)
2530
model_init.add_args(parser)
2631
args = parser.parse_args()
2732

@@ -118,10 +123,10 @@
118123
)
119124

120125
gen_settings = ExLlamaV2Sampler.Settings(
121-
token_repetition_penalty = 1.0,
122-
temperature = args.temperature if args.temperature is not None else 0.6,
123-
top_k = 50,
124-
top_p = 0.6
126+
token_repetition_penalty=args.token_repetition_penalty,
127+
temperature=args.temperature,
128+
top_k=args.top_k,
129+
top_p=args.top_p
125130
)
126131

127132
# Get problems

0 commit comments

Comments
 (0)