Skip to content

Commit e82676d

Browse files
Fixes for batch prompt creation
1 parent 10149dc commit e82676d

File tree

5 files changed

+161
-9
lines changed

5 files changed

+161
-9
lines changed

extras/openai_batch_inference.ipynb

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"https://platform.openai.com/docs/guides/batch/model-availability"
8+
]
9+
},
10+
{
11+
"cell_type": "code",
12+
"execution_count": 7,
13+
"metadata": {},
14+
"outputs": [],
15+
"source": [
16+
"from openai import OpenAI\n",
17+
"import os\n",
18+
"import yaml\n",
19+
"\n",
20+
"# os.environ.get(\"OPENAI_API_KEY\")\n",
21+
"# Get config\n",
22+
"CONFIG_FILE = \"../src/config.yaml\"\n",
23+
"\n",
24+
"if os.path.exists(CONFIG_FILE):\n",
25+
" with open(CONFIG_FILE, \"r\") as file:\n",
26+
" config = yaml.safe_load(file)\n",
27+
"else:\n",
28+
" config = {}\n",
29+
"\n",
30+
"os.environ[\"OPENAI_API_KEY\"] = config[\"llm\"][\"openai_key\"]\n",
31+
"client = OpenAI()\n"
32+
]
33+
},
34+
{
35+
"cell_type": "code",
36+
"execution_count": 8,
37+
"metadata": {},
38+
"outputs": [],
39+
"source": [
40+
"\n",
41+
"file_path = \"../.scrapy/results/gpt-4o-mini/autogen_typeevalpy_benchmark/batch_prompt.jsonl\"\n",
42+
"\n",
43+
"batch_input_file = client.files.create(\n",
44+
" file=open(file_path, \"rb\"),\n",
45+
" purpose=\"batch\"\n",
46+
")"
47+
]
48+
},
49+
{
50+
"cell_type": "code",
51+
"execution_count": 9,
52+
"metadata": {},
53+
"outputs": [
54+
{
55+
"data": {
56+
"text/plain": [
57+
"Batch(id='batch_Ufc5EkBcAbSceD4XzNF67KjO', completion_window='24h', created_at=1725549306, endpoint='/v1/chat/completions', input_file_id='file-PyayymAZ5TCcbTZUZ5nPN5zI', object='batch', status='validating', cancelled_at=None, cancelling_at=None, completed_at=None, error_file_id=None, errors=None, expired_at=None, expires_at=1725635706, failed_at=None, finalizing_at=None, in_progress_at=None, metadata={'description': 'Batch Trial'}, output_file_id=None, request_counts=BatchRequestCounts(completed=0, failed=0, total=0))"
58+
]
59+
},
60+
"execution_count": 9,
61+
"metadata": {},
62+
"output_type": "execute_result"
63+
}
64+
],
65+
"source": [
66+
"batch_input_file_id = batch_input_file.id\n",
67+
"\n",
68+
"client.batches.create(\n",
69+
" input_file_id=batch_input_file_id,\n",
70+
" endpoint=\"/v1/chat/completions\",\n",
71+
" completion_window=\"24h\",\n",
72+
" metadata={\n",
73+
" \"description\": \"Batch Trial\"\n",
74+
" }\n",
75+
")"
76+
]
77+
},
78+
{
79+
"cell_type": "code",
80+
"execution_count": 11,
81+
"metadata": {},
82+
"outputs": [],
83+
"source": [
84+
"from openai import OpenAI\n",
85+
"client = OpenAI()\n",
86+
"\n",
87+
"file_info = client.batches.retrieve(\"batch_Ufc5EkBcAbSceD4XzNF67KjO\")"
88+
]
89+
},
90+
{
91+
"cell_type": "code",
92+
"execution_count": 14,
93+
"metadata": {},
94+
"outputs": [
95+
{
96+
"name": "stdout",
97+
"output_type": "stream",
98+
"text": [
99+
"{\"id\": \"batch_req_HWbEoplsNE1CoXKmYvRKUxK0\", \"custom_id\": \"request-types-0\", \"response\": {\"status_code\": 200, \"request_id\": \"e55227a57343f3bfedc16b45f392647e\", \"body\": {\"id\": \"chatcmpl-A48UNoLfHd5nmm1SRA3E3GEEH0RNc\", \"object\": \"chat.completion\", \"created\": 1725549307, \"model\": \"gpt-4o-mini-2024-07-18\", \"choices\": [{\"index\": 0, \"message\": {\"role\": \"assistant\", \"content\": \"1. int \\n2. int \\n3. function \\n4. function \\n5. int \", \"refusal\": null}, \"logprobs\": null, \"finish_reason\": \"stop\"}], \"usage\": {\"prompt_tokens\": 357, \"completion_tokens\": 20, \"total_tokens\": 377}, \"system_fingerprint\": \"fp_f33667828e\"}}, \"error\": null}\n",
100+
"{\"id\": \"batch_req_GdlSnHJQ0dBcoL8SKDoP9d0F\", \"custom_id\": \"request-types-1\", \"response\": {\"status_code\": 200, \"request_id\": \"9e01043de3caf7e095f7d799c7f3279b\", \"body\": {\"id\": \"chatcmpl-A48UNq9mYTJbpRYQJQNfPdyFu4Whj\", \"object\": \"chat.completion\", \"created\": 1725549307, \"model\": \"gpt-4o-mini-2024-07-18\", \"choices\": [{\"index\": 0, \"message\": {\"role\": \"assistant\", \"content\": \"1. float \\n2. float \\n3. function \\n4. function \\n5. float \", \"refusal\": null}, \"logprobs\": null, \"finish_reason\": \"stop\"}], \"usage\": {\"prompt_tokens\": 359, \"completion_tokens\": 20, \"total_tokens\": 379}, \"system_fingerprint\": \"fp_f33667828e\"}}, \"error\": null}\n",
101+
"{\"id\": \"batch_req_05mDnLSDPh9f3cO5BDtuEnzJ\", \"custom_id\": \"request-types-2\", \"response\": {\"status_code\": 200, \"request_id\": \"54eda3aa77bc48b7ebf99431901c86f7\", \"body\": {\"id\": \"chatcmpl-A48UNwnSBiX2lmAbptZz4yyUMQgrp\", \"object\": \"chat.completion\", \"created\": 1725549307, \"model\": \"gpt-4o-mini-2024-07-18\", \"choices\": [{\"index\": 0, \"message\": {\"role\": \"assistant\", \"content\": \"1. str \\n2. str \\n3. function \\n4. function \\n5. str \", \"refusal\": null}, \"logprobs\": null, \"finish_reason\": \"stop\"}], \"usage\": {\"prompt_tokens\": 359, \"completion_tokens\": 20, \"total_tokens\": 379}, \"system_fingerprint\": \"fp_5bd87c427a\"}}, \"error\": null}\n",
102+
"\n"
103+
]
104+
}
105+
],
106+
"source": [
107+
"from openai import OpenAI\n",
108+
"client = OpenAI()\n",
109+
"\n",
110+
"file_response = client.files.content(file_info.output_file_id)\n",
111+
"print(file_response.text)"
112+
]
113+
}
114+
],
115+
"metadata": {
116+
"kernelspec": {
117+
"display_name": ".venv",
118+
"language": "python",
119+
"name": "python3"
120+
},
121+
"language_info": {
122+
"codemirror_mode": {
123+
"name": "ipython",
124+
"version": 3
125+
},
126+
"file_extension": ".py",
127+
"mimetype": "text/x-python",
128+
"name": "python",
129+
"nbconvert_exporter": "python",
130+
"pygments_lexer": "ipython3",
131+
"version": "3.10.12"
132+
}
133+
},
134+
"nbformat": 4,
135+
"nbformat_minor": 2
136+
}

requirements.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,6 @@ docker==6.1.3
77
tabulate==0.9.0
88
PyYAML
99
requests==2.31.0
10-
tqdm
10+
tqdm
11+
tiktoken
12+
openai

src/target_tools/llms/requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,5 @@ packaging
1212
peft
1313
torch==2.1.2
1414
torchvision==0.16.2
15-
torchaudio
15+
torchaudio
16+
tiktoken

src/target_tools/llms/src/runner.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,11 @@ def model_evaluation_openai(
201201

202202
utils.get_prompt_cost(prompts)
203203
utils.dump_ft_jsonl(id_mapping, f"{results_dst}/ft_dataset.jsonl")
204-
utils.dump_batch_prompt_jsonl(id_mapping, f"{results_dst}/batch_prompt.jsonl")
204+
utils.dump_batch_prompt_jsonl(
205+
id_mapping,
206+
f"{results_dst}/batch_prompt.jsonl",
207+
model=model_name,
208+
)
205209

206210
request_outputs = openai_helpers.process_requests(
207211
model_name,

src/target_tools/llms/src/utils.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,6 @@ def get_prompt(prompt_id, file_path, answers_placeholders=True, use_system_promp
266266
def dump_ft_jsonl(id_mapping, output_file):
267267
mappings = copy.deepcopy(id_mapping)
268268
for _m in mappings.values():
269-
print(_m)
270269
assistant_message = {
271270
"role": "assistant",
272271
"content": generate_answers_for_fine_tuning(_m["json_filepath"]),
@@ -281,12 +280,22 @@ def dump_ft_jsonl(id_mapping, output_file):
281280
output.write("\n")
282281

283282

284-
def dump_batch_prompt_jsonl(
    id_mapping, output_file, id_prefix="types", model="gpt-4o-mini", max_tokens=250
):
    """Write one OpenAI Batch API request per prompt to a JSONL file.

    Each value of ``id_mapping`` must carry a ``"prompt"`` entry holding a
    chat-completions ``messages`` list; the mapping key becomes part of the
    request's ``custom_id`` so responses can be matched back to prompts.

    Args:
        id_mapping: Mapping of prompt id -> dict with a ``"prompt"`` key
            (a list of chat messages).  # assumes this shape — set up by the caller in runner.py
        output_file: Path of the JSONL file to (over)write.
        id_prefix: Tag embedded in each ``custom_id`` (``request-<prefix>-<id>``).
        model: Model name placed in each request body.
        max_tokens: Completion-token cap per request (default 250, matching
            the previous hard-coded value).
    """
    with open(output_file, "w") as output:
        for prompt_id, entry in id_mapping.items():
            # One Batch API request object per line, per the OpenAI batch
            # input-file format.
            request = {
                "custom_id": f"request-{id_prefix}-{prompt_id}",
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {
                    "model": model,
                    "messages": entry["prompt"],
                    "max_tokens": max_tokens,
                },
            }
            output.write(json.dumps(request))
            output.write("\n")
291300

292301

0 commit comments

Comments
 (0)