Commit 10149dc: Added jsonl dump and costing
Parent: 36e5d24

2 files changed (+81 / -16 lines)


src/target_tools/llms/src/runner.py

Lines changed: 28 additions & 16 deletions
```diff
@@ -199,6 +199,10 @@ def model_evaluation_openai(

     prompts = [x["prompt"] for x in id_mapping.values()]

+    utils.get_prompt_cost(prompts)
+    utils.dump_ft_jsonl(id_mapping, f"{results_dst}/ft_dataset.jsonl")
+    utils.dump_batch_prompt_jsonl(id_mapping, f"{results_dst}/batch_prompt.jsonl")
+
     request_outputs = openai_helpers.process_requests(
         model_name,
         prompts,
@@ -272,23 +276,31 @@ def main_runner(args, runner_config, models_to_run, openai_models_models_to_run)
         else:
             model_path = model["lora_repo"]

-        pipe = transformers_helpers.load_model_and_configurations(
-            args.hf_token, model_path, model["quantization"], TEMPARATURE
-        )
-        model_start_time = time.time()
-        model_evaluation_transformers(
-            model["name"],
-            args.prompt_id,
-            python_files,
-            pipe,
-            results_dst,
-            use_system_prompt=model["use_system_prompt"],
-            batch_size=model["batch_size"],
-        )
+        pipe = None
+        try:
+            pipe = transformers_helpers.load_model_and_configurations(
+                args.hf_token, model_path, model["quantization"], TEMPARATURE
+            )
+            model_start_time = time.time()
+            model_evaluation_transformers(
+                model["name"],
+                args.prompt_id,
+                python_files,
+                pipe,
+                results_dst,
+                use_system_prompt=model["use_system_prompt"],
+                batch_size=model["batch_size"],
+            )

-        del pipe
-        gc.collect()
-        torch.cuda.empty_cache()
+        except Exception as e:
+            logger.error(f"Error in model {model['name']}: {e}")
+            error_count += 1
+            traceback.print_exc()
+        finally:
+            if pipe is not None:
+                del pipe
+                gc.collect()
+                torch.cuda.empty_cache()

         logger.info(
             f"Model {model['name']} finished in {time.time()-model_start_time:.2f} seconds"
```

src/target_tools/llms/src/utils.py

Lines changed: 53 additions & 0 deletions
```diff
@@ -9,6 +9,7 @@
 import logging
 import prompts
 import copy
+import tiktoken

 logger = logging.getLogger("runner")
 logger.setLevel(logging.DEBUG)
@@ -262,6 +263,58 @@ def get_prompt(prompt_id, file_path, answers_placeholders=True, use_system_promp
     return prompt


+def dump_ft_jsonl(id_mapping, output_file):
+    mappings = copy.deepcopy(id_mapping)
+    for _m in mappings.values():
+        print(_m)
+        assistant_message = {
+            "role": "assistant",
+            "content": generate_answers_for_fine_tuning(_m["json_filepath"]),
+        }
+        _m["prompt"].append(assistant_message)
+
+    prompts = [x["prompt"] for x in mappings.values()]
+
+    with open(output_file, "w") as output:
+        for _m in prompts:
+            output.write(json.dumps(_m))
+            output.write("\n")
+
+
+def dump_batch_prompt_jsonl(id_mapping, output_file):
+    prompts = [x["prompt"] for x in id_mapping.values()]
+
+    with open(output_file, "w") as output:
+        for _m in prompts:
+            output.write(json.dumps(_m))
+            output.write("\n")
+
+
+def get_prompt_cost(prompts):
+    """
+    Logs an estimated token count and input cost for the given prompts.
+
+    Args:
+        prompts (list): The prompts to be tokenized.
+    """
+
+    prices_per_token = {
+        "gpt-4o": 0.000005,
+        "gpt-4o-mini": 0.00000015,
+    }
+
+    for model, price in prices_per_token.items():
+        encoding = tiktoken.encoding_for_model(model)
+        number_of_tokens = len(encoding.encode(str(prompts)))
+        logger.info(
+            f"Number of tokens for model `{model}`: {number_of_tokens}"
+            + f" Cost: {number_of_tokens * price:.5f}"
+        )
+
+
 # Example usage:
 # loader = ConfigLoader("models_config.yaml")
 # loader.load_config()
```
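As a sanity check on the hardcoded pricing table, a small standalone sketch of the same estimate; the sample prompt is hypothetical, and the per-token prices are the ones from this diff, not an authoritative price list:

```python
import tiktoken

# Prices mirror the table in get_prompt_cost (USD per input token).
prices_per_token = {"gpt-4o": 0.000005, "gpt-4o-mini": 0.00000015}
prompts = [[{"role": "user", "content": "def add(a, b): return a + b"}]]

for model, price in prices_per_token.items():
    encoding = tiktoken.encoding_for_model(model)
    # Tokenizing str(prompts) counts the list/dict punctuation too,
    # so this is a rough estimate rather than an exact request size.
    n_tokens = len(encoding.encode(str(prompts)))
    print(f"{model}: {n_tokens} tokens, approx ${n_tokens * price:.6f}")

# Scale check: 1,000,000 input tokens at 0.000005 USD/token is 5.00 USD
# on gpt-4o, and 0.15 USD on gpt-4o-mini.
```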
