-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Auto labeling using GPT-4 #35
base: main
Are you sure you want to change the base?
Changes from all commits
98aad61
884713b
31b1995
d8555b5
e777219
98a3bee
6defc43
9b859fc
df8abbc
b84de92
2b556fc
2fa181c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -59,3 +59,4 @@ datasets = "^2.18.0" | |
wandb = "^0.16.5" | ||
loguru = "^0.7.2" | ||
scikit-learn = "^1.4.2" | ||
openai = "^1.34.0" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,286 @@ | ||
"""Script for pre-labeling the dataset. | ||
|
||
Run with: | ||
``` | ||
poetry run python -m scripts.pre_labeling_dataset | ||
``` | ||
""" | ||
|
||
import os | ||
import random | ||
import json | ||
import argparse | ||
import base64 | ||
from io import BytesIO | ||
|
||
from PIL import Image | ||
import jsonlines | ||
from huggingface_hub import HfApi | ||
from loguru import logger | ||
from openai import AzureOpenAI | ||
|
||
from scripts.constants import ( | ||
ASSETS_FOLDER, | ||
CLASS_CONCEPTS_VALUES, | ||
DATASET_NAME, | ||
HF_TOKEN, | ||
SPLITS, | ||
CONCEPTS, | ||
LABELED_CLASSES, | ||
) | ||
|
||
from dotenv import load_dotenv | ||
|
||
load_dotenv() | ||
|
||
|
||
def save_metadata(hf_api: HfApi, metadata: dict, split: str, push_to_hub: bool = False):
    """Write a split's metadata to its local JSONL file and optionally upload it.

    Args:
        hf_api: Authenticated Hugging Face API client.
        metadata: List of per-item metadata rows for this split.
        split: Dataset split name (e.g. "train" or "test").
        push_to_hub: When True, also upload the written file to the dataset repo.
    """
    local_path = f"{ASSETS_FOLDER}/{DATASET_NAME}/data/{split}/metadata.jsonl"
    with jsonlines.open(local_path, mode="w") as writer:
        writer.write_all(metadata)

    if push_to_hub:
        hf_api.upload_file(
            # Bug fix: the original passed the literal string
            # "{ASSETS_FOLDER}/{DATASET_NAME}/data/{split}" (missing f-prefix),
            # so the upload pointed at a non-existent path.
            path_or_fileobj=local_path,
            path_in_repo=f"data/{split}/metadata.jsonl",
            repo_id=DATASET_NAME,
            repo_type="dataset",
        )
|
||
|
||
def get_votes(hf_api: HfApi):
    """Download the dataset snapshot and load metadata and vote files.

    Args:
        hf_api: Authenticated Hugging Face API client.

    Returns:
        Tuple ``(metadata, votes)`` where ``metadata`` maps split name to a
        list of metadata rows and ``votes`` maps item id (vote filename stem)
        to the parsed vote dict.
    """
    hf_api.snapshot_download(
        local_dir=f"{ASSETS_FOLDER}/{DATASET_NAME}",
        repo_id=DATASET_NAME,
        repo_type="dataset",
    )
    metadata = {}
    for split in SPLITS:
        metadata[split] = []
        with jsonlines.open(f"{ASSETS_FOLDER}/{DATASET_NAME}/data/{split}/metadata.jsonl") as reader:
            for row in reader:
                metadata[split].append(row)
    votes = {}
    votes_dir = f"{ASSETS_FOLDER}/{DATASET_NAME}/votes"
    for filename in os.listdir(votes_dir):
        # Bug fix: open the file currently being iterated; the original open()
        # path did not interpolate the filename and could never succeed.
        with open(os.path.join(votes_dir, filename)) as f:
            key = filename.split(".")[0]  # item id = filename without extension
            votes[key] = json.load(f)
    return metadata, votes
|
||
|
||
def get_pre_labeled_concepts(item: dict):
    """Return a concept -> bool mapping derived from the item's class.

    A concept is marked True when it appears in the set of concepts that
    CLASS_CONCEPTS_VALUES associates with the item's class.
    """
    class_concepts = set(CLASS_CONCEPTS_VALUES[item["class"]])
    labels = {}
    for concept in CONCEPTS:
        labels[concept] = concept in class_concepts
    return labels
|
||
|
||
def compute_concepts(votes):
    """Aggregate per-voter labels into one consensus value per concept.

    Each vote contributes +1 when True and -1 when False (votes missing a
    concept are skipped). Returns, per concept, True/False for a non-zero
    tally and None when the tally is exactly zero (no consensus).
    """
    consensus = {}
    for concept in CONCEPTS:
        tally = sum(
            2 * vote[concept] - 1
            for vote in votes.values()
            if concept in vote
        )
        consensus[concept] = None if tally == 0 else tally > 0
    return consensus
|
||
|
||
class OpenAIRequest:
    """Thin wrapper around the Azure OpenAI chat API for concept annotation.

    Builds a one-shot in-context-learning prompt (system instructions, one
    labeled example, then the query image) and requests concept labels.
    """

    def __init__(self, model: str = "gpt-4o"):
        self.client = AzureOpenAI(
            api_version=os.environ["AZURE_OPENAI_API_VERSION"],
            azure_deployment=os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"],
        )
        # Bug fix: `model` was never stored, but __call__ reads self.model,
        # which raised AttributeError on every request.
        self.model = model
        self.concepts = ",".join(CONCEPTS)

    def __call__(self, item: dict, icl: dict, **kwargs):
        """Send an annotation request to the OpenAI API.

        Args:
            item: Dict with "class" and base64 "image" of the item to label.
            icl: In-context example dict with "class", "image" and "concepts".
            **kwargs: Extra parameters forwarded to chat.completions.create
                (e.g. max_tokens, temperature).

        Returns:
            The raw chat completion response object.
        """
        # Bug fix: iterate CONCEPTS (the list) rather than self.concepts (a
        # comma-joined string), which produced one schema entry per character.
        concept_schema = {
            "properties": {concept: {"type": "boolean"} for concept in CONCEPTS}
        }
        message = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": """\
You are a helpful assistant that can help annotating images. Answer by giving the list of concepts you can see in the provided image.

Given an image and its class, annotate the concepts' presence in the image using a JSON format.

The labels must be provided according to the following JSON schema:
{concept_schema}
""".format(concept_schema=concept_schema)
                    }
                ],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": f"""\
Here is an image and its class:

Class: {icl["class"]}\nImage:
"""
                    },
                    {
                        # NOTE(review): the OpenAI chat API documents image
                        # parts as {"type": "image_url", "image_url": {...}}
                        # with a data URI; confirm this "image" form is
                        # accepted by the Azure deployment in use.
                        "type": "image",
                        "image": icl["image"]
                    }
                ],
            },
            {
                "role": "assistant",
                "content": [
                    {
                        "type": "text",
                        "text": f"Concepts: {icl['concepts']}"
                    }
                ],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": f"Now here is another image and its class, provide the concepts: \nClass: {item['class']}\nImage:"
                    },
                    {
                        "type": "image",
                        "image": item["image"]
                    }
                ],
            },
            {
                "role": "assistant",
                "content": [
                    {
                        "type": "text",
                        "text": "Concepts:"
                    }
                ]
            }
        ]

        return self.client.chat.completions.create(
            model=self.model,
            messages=message,
            **kwargs
        )
|
||
|
||
def image2base64(image: BytesIO) -> str:
    """Return the buffer's raw bytes encoded as a UTF-8 base64 string."""
    raw_bytes = image.getvalue()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
|
||
|
||
def get_icl_example_dict() -> dict:
    """Build the in-context-learning example (class, base64 image, gold concepts)."""
    # TODO: update path
    # Bug fix: use a context manager so the file handle is closed instead of
    # leaking via the bare open(...).read() expression.
    with open("images/00000000.jpg", "rb") as f:
        example_image = image2base64(BytesIO(f.read()))
    return {
        "class": "lettuce",
        "image": example_image,
        "concepts": {
            "leaf": True,
            "green": True,
            "stem": False,
            "red": False,
            "black": False,
            "blue": False,
            "ovaloid": False,
            "sphere": False,
            "cylinder": False,
            "cube": False,
            "brown": False,
            "orange": False,
            "yellow": False,
            "white": False,
            "tail": False,
            "seed": False,
            "pulp": False,
            "soil": False,
            "tree": False,
        }
    }
|
||
|
||
def main(args):
    """Pre-label unlabeled dataset items and persist the results as model votes.

    Downloads metadata and existing votes, queries the model for each item
    whose class is not already labeled, records labels as a vote keyed by the
    model name, then rewrites the vote files and recomputed concept columns
    (optionally pushing both to the Hub).
    """
    hf_api = HfApi(token=HF_TOKEN)

    logger.info("Download metadata and votes")
    metadata, votes = get_votes(hf_api)

    # Hoisted out of the loops: the client and the ICL example are
    # loop-invariant, so build them once instead of per item.
    openai_request = OpenAIRequest(model=args.model)
    icl_dict = get_icl_example_dict()

    for split in SPLITS:
        for item in metadata[split]:
            if item["class"] in LABELED_CLASSES:
                continue
            key = item["id"]

            item_dict = {
                "class": item["class"],
                "image": image2base64(item["image"]),  # TODO: fix open image
            }

            response = openai_request(
                item=item_dict,
                icl=icl_dict,
                max_tokens=200,
                temperature=0,
            )

            pred = response.choices[0].message.content
            # Bug fix (off-by-one): include the closing brace via rfind("}") + 1;
            # the original slice dropped the final "}" and yielded invalid JSON.
            pred = pred[pred.rfind("{"):pred.rfind("}") + 1]
            print(pred)

            concepts = get_pre_labeled_concepts(item)
            # NOTE(review): `pred` (the model output) is never used — the
            # stored vote comes from the class-based pre-labels; confirm
            # whether the GPT answer should be parsed here instead.
            # Robustness fix: votes.get(...) avoids a KeyError for items that
            # have no vote file yet (the original indexed votes[key] directly).
            if args.model not in votes.get(key, {}):
                continue
            votes[key] = {args.model: concepts}

    logger.info("Save votes locally")
    for key in votes:
        with open(f"{ASSETS_FOLDER}/{DATASET_NAME}/votes/{key}.json", "w") as f:
            json.dump(votes[key], f)

    if args.push_to_hub:
        logger.info("Upload votes to Hub")
        hf_api.upload_folder(
            folder_path=f"{ASSETS_FOLDER}/{DATASET_NAME}",
            repo_id=DATASET_NAME,
            repo_type="dataset",
            allow_patterns=["votes/*"],
        )

    new_metadata = {}
    # NOTE(review): hard-coded splits here vs. SPLITS elsewhere — confirm
    # whether SPLITS should be used for consistency.
    for split in ["train", "test"]:
        new_metadata[split] = []
        with jsonlines.open(f"{ASSETS_FOLDER}/{DATASET_NAME}/data/{split}/metadata.jsonl") as reader:
            for row in reader:
                s_id = row["id"]
                if s_id in votes:
                    row.update(compute_concepts(votes[s_id]))
                new_metadata[split].append(row)
        with jsonlines.open(f"{ASSETS_FOLDER}/{DATASET_NAME}/data/{split}/metadata.jsonl", mode="w") as writer:
            writer.write_all(new_metadata[split])

    if args.push_to_hub:
        logger.info("Upload metadata to Hub")
        for split in SPLITS:
            save_metadata(hf_api, new_metadata[split], split, push_to_hub=True)
|
||
|
||
def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for the auto-labeling script.

    Returns:
        Namespace with `model` (str, default "gpt-4o") and `push_to_hub`
        (bool, default False).
    """
    parser = argparse.ArgumentParser("auto-label-dataset")
    # Bug fix: the diff contained duplicated, conflicting add_argument calls
    # for both options (a merge artifact); each option is defined exactly once.
    parser.add_argument(
        "--model", type=str, default="gpt-4o",
        help="Specify the model to use, e.g., 'gpt-4o'")
    parser.add_argument(
        "--push_to_hub", action="store_true",
        help="Flag to push the results to the hub")
    return parser.parse_args()
|
||
|
||
if __name__ == "__main__":
    # Script entry point: parse CLI options, then run the labeling pipeline.
    args = parse_args()
    main(args)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
issue: Unused import 'random'
The 'random' module is imported but not used anywhere in the script. Consider removing it to keep the code clean.