From 98aad610f98645a856ba7a3103a1e53b55d40ebe Mon Sep 17 00:00:00 2001 From: Imene Kerboua Date: Mon, 17 Jun 2024 10:08:14 +0200 Subject: [PATCH 01/12] add script for labelling with GPT4 --- poetry.lock | 181 ++++++++++++++++++++++++++++- pyproject.toml | 1 + scripts/auto_labeling_using_llm.py | 180 ++++++++++++++++++++++++++++ 3 files changed, 360 insertions(+), 2 deletions(-) create mode 100644 scripts/auto_labeling_using_llm.py diff --git a/poetry.lock b/poetry.lock index 14acc17..86214eb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "aiofiles" @@ -764,6 +764,17 @@ files = [ {file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"}, ] +[[package]] +name = "distro" +version = "1.9.0" +description = "Distro - an OS platform information API" +optional = false +python-versions = ">=3.6" +files = [ + {file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"}, + {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, +] + [[package]] name = "docker-pycreds" version = "0.4.0" @@ -1398,6 +1409,31 @@ files = [ [package.dependencies] attrs = ">=19.2.0" +[[package]] +name = "jsonpatch" +version = "1.33" +description = "Apply JSON-Patches (RFC 6902)" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" +files = [ + {file = "jsonpatch-1.33-py2.py3-none-any.whl", hash = "sha256:0ae28c0cd062bbd8b8ecc26d7d164fbbea9652a1a3693f3b956c1eae5145dade"}, + {file = "jsonpatch-1.33.tar.gz", hash = "sha256:9fcd4009c41e6d12348b4a0ff2563ba56a2923a7dfee731d004e212e1ee5030c"}, +] + +[package.dependencies] +jsonpointer = ">=1.9" + +[[package]] +name = "jsonpointer" +version = "3.0.0" +description = "Identify specific nodes in a JSON document (RFC 6901)" +optional = false +python-versions = ">=3.7" +files = [ + {file = "jsonpointer-3.0.0-py2.py3-none-any.whl", hash = "sha256:13e088adc14fca8b6aa8177c044e12701e6ad4b28ff10e65f2267a90109c9942"}, + {file = "jsonpointer-3.0.0.tar.gz", hash = "sha256:2b2d729f2091522d61c3b31f82e11870f60b68f43fbc705cb76bf4b832af59ef"}, +] + [[package]] name = "jsonschema" version = "4.21.1" @@ -1546,6 +1582,57 @@ files = [ {file = "kiwisolver-1.4.5.tar.gz", hash = "sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec"}, ] +[[package]] +name = "langchain-core" +version = "0.2.7" +description = "Building applications with LLMs through composability" +optional = false +python-versions = "<4.0,>=3.8.1" +files = [ + {file = "langchain_core-0.2.7-py3-none-any.whl", hash = "sha256:fd02e153c898486dd728d634684ffc64bc257ff2ba443dc7e53d017ac0bf4658"}, + {file = "langchain_core-0.2.7.tar.gz", hash = "sha256:b0b1b6dfbdedb39426fcb8bd3f07e40eec7964856e3fc384c420ca6dba61b34e"}, +] + +[package.dependencies] +jsonpatch = ">=1.33,<2.0" +langsmith = ">=0.1.75,<0.2.0" +packaging = ">=23.2,<25" +pydantic = ">=1,<3" +PyYAML = ">=5.3" +tenacity = ">=8.1.0,<9.0.0" + +[[package]] +name = "langchain-openai" +version = "0.1.8" +description = "An integration package connecting OpenAI and LangChain" +optional = false +python-versions = "<4.0,>=3.8.1" +files = [ + {file = "langchain_openai-0.1.8-py3-none-any.whl", hash = "sha256:8125c84223e9f43b05defbca64eedbcf362fd78a680de6c25e64f973b34a8063"}, + {file = "langchain_openai-0.1.8.tar.gz", hash = "sha256:a11fcce15def7917c44232abda6baaa63dfc79fe44be1531eea650d39a44cd95"}, +] + +[package.dependencies] +langchain-core = ">=0.2.2,<0.3" +openai = ">=1.26.0,<2.0.0" +tiktoken = ">=0.7,<1" + +[[package]] +name = "langsmith" +version = "0.1.77" +description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." +optional = false +python-versions = "<4.0,>=3.8.1" +files = [ + {file = "langsmith-0.1.77-py3-none-any.whl", hash = "sha256:2202cc21b1ed7e7b9e5d2af2694be28898afa048c09fdf09f620cbd9301755ae"}, + {file = "langsmith-0.1.77.tar.gz", hash = "sha256:4ace09077a9a4e412afeb4b517ca68e7de7b07f36e4792dc8236ac5207c0c0c7"}, +] + +[package.dependencies] +orjson = ">=3.9.14,<4.0.0" +pydantic = ">=1,<3" +requests = ">=2,<3" + [[package]] name = "loguru" version = "0.7.2" @@ -2075,6 +2162,29 @@ files = [ {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"}, ] +[[package]] +name = "openai" +version = "1.34.0" +description = "The official Python library for the openai API" +optional = false +python-versions = ">=3.7.1" +files = [ + {file = "openai-1.34.0-py3-none-any.whl", hash = "sha256:018623c2f795424044675c6230fa3bfbf98d9e0aab45d8fd116f2efb2cfb6b7e"}, + {file = "openai-1.34.0.tar.gz", hash = "sha256:95c8e2da4acd6958e626186957d656597613587195abd0fb2527566a93e76770"}, +] + +[package.dependencies] +anyio = ">=3.5.0,<5" +distro = ">=1.7.0,<2" +httpx = ">=0.23.0,<1" +pydantic = ">=1.9.0,<3" +sniffio = "*" +tqdm = ">4" +typing-extensions = ">=4.7,<5" + +[package.extras] +datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] + [[package]] name = "orjson" version = "3.10.1" @@ -3626,6 +3736,21 @@ files = [ [package.dependencies] mpmath = ">=0.19" +[[package]] +name = "tenacity" +version = "8.4.0" +description = "Retry code until it succeeds" +optional = false +python-versions = ">=3.8" +files = [ + {file = "tenacity-8.4.0-py3-none-any.whl", hash = "sha256:6c72aeb4d300b9858be0d52f1894ee952f0b391ec583cdc19c2523030efcc4eb"}, + {file = "tenacity-8.4.0.tar.gz", hash = "sha256:5ea66b27e881eec324b15adc7be9ec35001d231d4cfb5d10eee1bf9323d723a9"}, +] + +[package.extras] +doc = ["reno", "sphinx"] +test = ["pytest", "tornado (>=4.5)", "typeguard"] + [[package]] name = "tensordict" version = "0.3.2" @@ -3673,6 +3798,58 @@ files = [ {file = "threadpoolctl-3.4.0.tar.gz", hash = "sha256:f11b491a03661d6dd7ef692dd422ab34185d982466c49c8f98c8f716b5c93196"}, ] +[[package]] +name = "tiktoken" +version = "0.7.0" +description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models" +optional = false +python-versions = ">=3.8" +files = [ + {file = "tiktoken-0.7.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:485f3cc6aba7c6b6ce388ba634fbba656d9ee27f766216f45146beb4ac18b25f"}, + {file = "tiktoken-0.7.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e54be9a2cd2f6d6ffa3517b064983fb695c9a9d8aa7d574d1ef3c3f931a99225"}, + {file = "tiktoken-0.7.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79383a6e2c654c6040e5f8506f3750db9ddd71b550c724e673203b4f6b4b4590"}, + {file = "tiktoken-0.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d4511c52caacf3c4981d1ae2df85908bd31853f33d30b345c8b6830763f769c"}, + {file = "tiktoken-0.7.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:13c94efacdd3de9aff824a788353aa5749c0faee1fbe3816df365ea450b82311"}, + {file = "tiktoken-0.7.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8e58c7eb29d2ab35a7a8929cbeea60216a4ccdf42efa8974d8e176d50c9a3df5"}, + {file = "tiktoken-0.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:21a20c3bd1dd3e55b91c1331bf25f4af522c525e771691adbc9a69336fa7f702"}, + {file = "tiktoken-0.7.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:10c7674f81e6e350fcbed7c09a65bca9356eaab27fb2dac65a1e440f2bcfe30f"}, + {file = "tiktoken-0.7.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:084cec29713bc9d4189a937f8a35dbdfa785bd1235a34c1124fe2323821ee93f"}, + {file = "tiktoken-0.7.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:811229fde1652fedcca7c6dfe76724d0908775b353556d8a71ed74d866f73f7b"}, + {file = "tiktoken-0.7.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:86b6e7dc2e7ad1b3757e8a24597415bafcfb454cebf9a33a01f2e6ba2e663992"}, + {file = "tiktoken-0.7.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1063c5748be36344c7e18c7913c53e2cca116764c2080177e57d62c7ad4576d1"}, + {file = "tiktoken-0.7.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:20295d21419bfcca092644f7e2f2138ff947a6eb8cfc732c09cc7d76988d4a89"}, + {file = "tiktoken-0.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:959d993749b083acc57a317cbc643fb85c014d055b2119b739487288f4e5d1cb"}, + {file = "tiktoken-0.7.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:71c55d066388c55a9c00f61d2c456a6086673ab7dec22dd739c23f77195b1908"}, + {file = "tiktoken-0.7.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:09ed925bccaa8043e34c519fbb2f99110bd07c6fd67714793c21ac298e449410"}, + {file = "tiktoken-0.7.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03c6c40ff1db0f48a7b4d2dafeae73a5607aacb472fa11f125e7baf9dce73704"}, + {file = "tiktoken-0.7.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d20b5c6af30e621b4aca094ee61777a44118f52d886dbe4f02b70dfe05c15350"}, + {file = "tiktoken-0.7.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d427614c3e074004efa2f2411e16c826f9df427d3c70a54725cae860f09e4bf4"}, + {file = "tiktoken-0.7.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:8c46d7af7b8c6987fac9b9f61041b452afe92eb087d29c9ce54951280f899a97"}, + {file = "tiktoken-0.7.0-cp312-cp312-win_amd64.whl", hash = "sha256:0bc603c30b9e371e7c4c7935aba02af5994a909fc3c0fe66e7004070858d3f8f"}, + {file = "tiktoken-0.7.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2398fecd38c921bcd68418675a6d155fad5f5e14c2e92fcf5fe566fa5485a858"}, + {file = "tiktoken-0.7.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8f5f6afb52fb8a7ea1c811e435e4188f2bef81b5e0f7a8635cc79b0eef0193d6"}, + {file = "tiktoken-0.7.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:861f9ee616766d736be4147abac500732b505bf7013cfaf019b85892637f235e"}, + {file = "tiktoken-0.7.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54031f95c6939f6b78122c0aa03a93273a96365103793a22e1793ee86da31685"}, + {file = "tiktoken-0.7.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:fffdcb319b614cf14f04d02a52e26b1d1ae14a570f90e9b55461a72672f7b13d"}, + {file = "tiktoken-0.7.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:c72baaeaefa03ff9ba9688624143c858d1f6b755bb85d456d59e529e17234769"}, + {file = "tiktoken-0.7.0-cp38-cp38-win_amd64.whl", hash = "sha256:131b8aeb043a8f112aad9f46011dced25d62629091e51d9dc1adbf4a1cc6aa98"}, + {file = "tiktoken-0.7.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:cabc6dc77460df44ec5b879e68692c63551ae4fae7460dd4ff17181df75f1db7"}, + {file = "tiktoken-0.7.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8d57f29171255f74c0aeacd0651e29aa47dff6f070cb9f35ebc14c82278f3b25"}, + {file = "tiktoken-0.7.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2ee92776fdbb3efa02a83f968c19d4997a55c8e9ce7be821ceee04a1d1ee149c"}, + {file = "tiktoken-0.7.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e215292e99cb41fbc96988ef62ea63bb0ce1e15f2c147a61acc319f8b4cbe5bf"}, + {file = "tiktoken-0.7.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:8a81bac94769cab437dd3ab0b8a4bc4e0f9cf6835bcaa88de71f39af1791727a"}, + {file = "tiktoken-0.7.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:d6d73ea93e91d5ca771256dfc9d1d29f5a554b83821a1dc0891987636e0ae226"}, + {file = "tiktoken-0.7.0-cp39-cp39-win_amd64.whl", hash = "sha256:2bcb28ddf79ffa424f171dfeef9a4daff61a94c631ca6813f43967cb263b83b9"}, + {file = "tiktoken-0.7.0.tar.gz", hash = "sha256:1077266e949c24e0291f6c350433c6f0971365ece2b173a23bc3b9f9defef6b6"}, +] + +[package.dependencies] +regex = ">=2022.1.18" +requests = ">=2.26.0" + +[package.extras] +blobfile = ["blobfile (>=2)"] + [[package]] name = "tokenizers" version = "0.15.2" @@ -4542,4 +4719,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "f477d409961cf7a3873085d20e3e21b9e53a394faaf0fad4a6a819345e073b02" +content-hash = "5bd51b69a103505d61d641c9f1c4dc4cc9609983f9fc23ea9a008210a37dc2af" diff --git a/pyproject.toml b/pyproject.toml index b41141f..cf554f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,3 +59,4 @@ datasets = "^2.18.0" wandb = "^0.16.5" loguru = "^0.7.2" scikit-learn = "^1.4.2" +langchain-openai = "^0.1.8" diff --git a/scripts/auto_labeling_using_llm.py b/scripts/auto_labeling_using_llm.py new file mode 100644 index 0000000..4201996 --- /dev/null +++ b/scripts/auto_labeling_using_llm.py @@ -0,0 +1,180 @@ +"""Script for pre-labeling the dataset. + +Run with: +``` +poetry run python -m scripts.pre_labeling_dataset +``` +""" + +import os +import json +import argparse + +import jsonlines +from huggingface_hub import HfApi +from loguru import logger + +from langchain_core.messages import HumanMessage +from langchain_openai import AzureChatOpenAI + +from scripts.constants import ( + ASSETS_FOLDER, + CLASS_CONCEPTS_VALUES, + DATASET_NAME, + HF_TOKEN, + SPLITS, + CONCEPTS, + LABELED_CLASSES, +) + +from dotenv import load_dotenv + +load_dotenv() + +PROMPT = """\ +Given an image and its class, provide the cnocepts that are present in the image in the following format: + +You may choose from the following concepts only: +{concepts} + +Provide the classification in the following format: +Classification::: +Concepts: (concept: e.g., red, sphere, stem, etc.) + +Examples: +Image: {imaage_example_1} +Class: {class_example_1} +Concepts: {concepts_example_1} + +Image: {imaage_example_2} +Class: {class_example_2} +Concepts: {concepts_example_2} + +Now here is an image and its class: +Image: {image} +Class: {class_} + +Classification::: +Concepts: +""" + + +def save_metadata(hf_api: HfApi, metadata: dict, split: str, push_to_hub: bool = False): + with jsonlines.open(f"{ASSETS_FOLDER}/{DATASET_NAME}/data/{split}/metadata.jsonl", mode="w") as writer: + writer.write_all(metadata) + + if push_to_hub: + hf_api.upload_file( + path_or_fileobj=f"{ASSETS_FOLDER}/{DATASET_NAME}/data/{split}/" "metadata.jsonl", + path_in_repo=f"data/{split}/metadata.jsonl", + repo_id=DATASET_NAME, + repo_type="dataset", + ) + + +def get_votes(hf_api: HfApi): + hf_api.snapshot_download( + local_dir=f"{ASSETS_FOLDER}/{DATASET_NAME}", + repo_id=DATASET_NAME, + repo_type="dataset", + ) + metadata = {} + for split in SPLITS: + metadata[split] = [] + with jsonlines.open(f"{ASSETS_FOLDER}/{DATASET_NAME}/data/{split}/metadata.jsonl") as reader: + for row in reader: + metadata[split].append(row) + votes = {} + for filename in os.listdir(f"{ASSETS_FOLDER}/{DATASET_NAME}/votes"): + with open(f"{ASSETS_FOLDER}/{DATASET_NAME}/votes/{filename}") as f: + key = filename.split(".")[0] + votes[key] = json.load(f) + return metadata, votes + + +def get_pre_labeled_concepts(item: dict): + active_concepts = CLASS_CONCEPTS_VALUES[item["class"]] + return {c: c in active_concepts for c in CONCEPTS} + + +def compute_concepts(votes): + vote_sum = {c: 0 for c in CONCEPTS} + for vote in votes.values(): + for c in CONCEPTS: + if c not in vote: + continue + vote_sum[c] += 2 * vote[c] - 1 + return {c: vote_sum[c] > 0 if vote_sum[c] != 0 else None for c in CONCEPTS} + + +def main(args): + hf_api = HfApi(token=HF_TOKEN) + + model = AzureChatOpenAI( + openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"], + azure_deployment=os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"], + ) + + logger.info("Download metadata and votes") + metadata, votes = get_votes(hf_api) + + for split in SPLITS: + for item in metadata[split]: + if item["class"] in LABELED_CLASSES: + continue + key = item["id"] + + # Call VLM + message = HumanMessage( + content=PROMPT.format(concepts=", ".join(CONCEPTS), class_=item["class"]), + temperature=0, + ) + response = model.invoke([message]) + pred = response.content.split(":")[1].strip() if ":" in response.content else response.content + print(pred) + concepts = get_pre_labeled_concepts(item) + if "imenelydiaker" not in votes[key]: + continue + votes[key] = {"imenelydiaker": concepts} + + logger.info("Save votes locally") + for key in votes: + with open(f"{ASSETS_FOLDER}/{DATASET_NAME}/votes/{key}.json", "w") as f: + json.dump(votes[key], f) + + if args.push_to_hub: + logger.info("Upload votes to Hub") + hf_api.upload_folder( + folder_path=f"{ASSETS_FOLDER}/{DATASET_NAME}", + repo_id=DATASET_NAME, + repo_type="dataset", + allow_patterns=["votes/*"], + ) + + new_metadata = {} + for split in ["train", "test"]: + new_metadata[split] = [] + with jsonlines.open(f"{ASSETS_FOLDER}/{DATASET_NAME}/data/{split}/metadata.jsonl") as reader: + for row in reader: + s_id = row["id"] + if s_id in votes: + row.update(compute_concepts(votes[s_id])) + new_metadata[split].append(row) + with jsonlines.open(f"{ASSETS_FOLDER}/{DATASET_NAME}/data/{split}/metadata.jsonl", mode="w") as writer: + writer.write_all(new_metadata[split]) + + if args.push_to_hub: + logger.info("Upload metadata to Hub") + for split in SPLITS: + save_metadata(hf_api, new_metadata[split], split, push_to_hub=True) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser("pre-label-dataset") + parser.add_argument("--push_to_hub", action=argparse.BooleanOptionalAction, default=False) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + main(args) From 884713b7586bce38c4b5c285edd0f7418b1aa10c Mon Sep 17 00:00:00 2001 From: Imene Kerboua Date: Mon, 17 Jun 2024 11:57:59 +0200 Subject: [PATCH 02/12] update autolaveking script --- poetry.lock | 147 +------------------------ pyproject.toml | 2 +- scripts/auto_labeling_using_llm.py | 171 ++++++++++++++++++++++------- 3 files changed, 133 insertions(+), 187 deletions(-) diff --git a/poetry.lock b/poetry.lock index 86214eb..ff86e9f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "aiofiles" @@ -1409,31 +1409,6 @@ files = [ [package.dependencies] attrs = ">=19.2.0" -[[package]] -name = "jsonpatch" -version = "1.33" -description = "Apply JSON-Patches (RFC 6902)" -optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" -files = [ - {file = "jsonpatch-1.33-py2.py3-none-any.whl", hash = "sha256:0ae28c0cd062bbd8b8ecc26d7d164fbbea9652a1a3693f3b956c1eae5145dade"}, - {file = "jsonpatch-1.33.tar.gz", hash = "sha256:9fcd4009c41e6d12348b4a0ff2563ba56a2923a7dfee731d004e212e1ee5030c"}, -] - -[package.dependencies] -jsonpointer = ">=1.9" - -[[package]] -name = "jsonpointer" -version = "3.0.0" -description = "Identify specific nodes in a JSON document (RFC 6901)" -optional = false -python-versions = ">=3.7" -files = [ - {file = "jsonpointer-3.0.0-py2.py3-none-any.whl", hash = "sha256:13e088adc14fca8b6aa8177c044e12701e6ad4b28ff10e65f2267a90109c9942"}, - {file = "jsonpointer-3.0.0.tar.gz", hash = "sha256:2b2d729f2091522d61c3b31f82e11870f60b68f43fbc705cb76bf4b832af59ef"}, -] - [[package]] name = "jsonschema" version = "4.21.1" @@ -1582,57 +1557,6 @@ files = [ {file = "kiwisolver-1.4.5.tar.gz", hash = "sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec"}, ] -[[package]] -name = "langchain-core" -version = "0.2.7" -description = "Building applications with LLMs through composability" -optional = false -python-versions = "<4.0,>=3.8.1" -files = [ - {file = "langchain_core-0.2.7-py3-none-any.whl", hash = "sha256:fd02e153c898486dd728d634684ffc64bc257ff2ba443dc7e53d017ac0bf4658"}, - {file = "langchain_core-0.2.7.tar.gz", hash = "sha256:b0b1b6dfbdedb39426fcb8bd3f07e40eec7964856e3fc384c420ca6dba61b34e"}, -] - -[package.dependencies] -jsonpatch = ">=1.33,<2.0" -langsmith = ">=0.1.75,<0.2.0" -packaging = ">=23.2,<25" -pydantic = ">=1,<3" -PyYAML = ">=5.3" -tenacity = ">=8.1.0,<9.0.0" - -[[package]] -name = "langchain-openai" -version = "0.1.8" -description = "An integration package connecting OpenAI and LangChain" -optional = false -python-versions = "<4.0,>=3.8.1" -files = [ - {file = "langchain_openai-0.1.8-py3-none-any.whl", hash = "sha256:8125c84223e9f43b05defbca64eedbcf362fd78a680de6c25e64f973b34a8063"}, - {file = "langchain_openai-0.1.8.tar.gz", hash = "sha256:a11fcce15def7917c44232abda6baaa63dfc79fe44be1531eea650d39a44cd95"}, -] - -[package.dependencies] -langchain-core = ">=0.2.2,<0.3" -openai = ">=1.26.0,<2.0.0" -tiktoken = ">=0.7,<1" - -[[package]] -name = "langsmith" -version = "0.1.77" -description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." -optional = false -python-versions = "<4.0,>=3.8.1" -files = [ - {file = "langsmith-0.1.77-py3-none-any.whl", hash = "sha256:2202cc21b1ed7e7b9e5d2af2694be28898afa048c09fdf09f620cbd9301755ae"}, - {file = "langsmith-0.1.77.tar.gz", hash = "sha256:4ace09077a9a4e412afeb4b517ca68e7de7b07f36e4792dc8236ac5207c0c0c7"}, -] - -[package.dependencies] -orjson = ">=3.9.14,<4.0.0" -pydantic = ">=1,<3" -requests = ">=2,<3" - [[package]] name = "loguru" version = "0.7.2" @@ -3736,21 +3660,6 @@ files = [ [package.dependencies] mpmath = ">=0.19" -[[package]] -name = "tenacity" -version = "8.4.0" -description = "Retry code until it succeeds" -optional = false -python-versions = ">=3.8" -files = [ - {file = "tenacity-8.4.0-py3-none-any.whl", hash = "sha256:6c72aeb4d300b9858be0d52f1894ee952f0b391ec583cdc19c2523030efcc4eb"}, - {file = "tenacity-8.4.0.tar.gz", hash = "sha256:5ea66b27e881eec324b15adc7be9ec35001d231d4cfb5d10eee1bf9323d723a9"}, -] - -[package.extras] -doc = ["reno", "sphinx"] -test = ["pytest", "tornado (>=4.5)", "typeguard"] - [[package]] name = "tensordict" version = "0.3.2" @@ -3798,58 +3707,6 @@ files = [ {file = "threadpoolctl-3.4.0.tar.gz", hash = "sha256:f11b491a03661d6dd7ef692dd422ab34185d982466c49c8f98c8f716b5c93196"}, ] -[[package]] -name = "tiktoken" -version = "0.7.0" -description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models" -optional = false -python-versions = ">=3.8" -files = [ - {file = "tiktoken-0.7.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:485f3cc6aba7c6b6ce388ba634fbba656d9ee27f766216f45146beb4ac18b25f"}, - {file = "tiktoken-0.7.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e54be9a2cd2f6d6ffa3517b064983fb695c9a9d8aa7d574d1ef3c3f931a99225"}, - {file = "tiktoken-0.7.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79383a6e2c654c6040e5f8506f3750db9ddd71b550c724e673203b4f6b4b4590"}, - {file = "tiktoken-0.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d4511c52caacf3c4981d1ae2df85908bd31853f33d30b345c8b6830763f769c"}, - {file = "tiktoken-0.7.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:13c94efacdd3de9aff824a788353aa5749c0faee1fbe3816df365ea450b82311"}, - {file = "tiktoken-0.7.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8e58c7eb29d2ab35a7a8929cbeea60216a4ccdf42efa8974d8e176d50c9a3df5"}, - {file = "tiktoken-0.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:21a20c3bd1dd3e55b91c1331bf25f4af522c525e771691adbc9a69336fa7f702"}, - {file = "tiktoken-0.7.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:10c7674f81e6e350fcbed7c09a65bca9356eaab27fb2dac65a1e440f2bcfe30f"}, - {file = "tiktoken-0.7.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:084cec29713bc9d4189a937f8a35dbdfa785bd1235a34c1124fe2323821ee93f"}, - {file = "tiktoken-0.7.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:811229fde1652fedcca7c6dfe76724d0908775b353556d8a71ed74d866f73f7b"}, - {file = "tiktoken-0.7.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:86b6e7dc2e7ad1b3757e8a24597415bafcfb454cebf9a33a01f2e6ba2e663992"}, - {file = "tiktoken-0.7.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1063c5748be36344c7e18c7913c53e2cca116764c2080177e57d62c7ad4576d1"}, - {file = "tiktoken-0.7.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:20295d21419bfcca092644f7e2f2138ff947a6eb8cfc732c09cc7d76988d4a89"}, - {file = "tiktoken-0.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:959d993749b083acc57a317cbc643fb85c014d055b2119b739487288f4e5d1cb"}, - {file = "tiktoken-0.7.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:71c55d066388c55a9c00f61d2c456a6086673ab7dec22dd739c23f77195b1908"}, - {file = "tiktoken-0.7.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:09ed925bccaa8043e34c519fbb2f99110bd07c6fd67714793c21ac298e449410"}, - {file = "tiktoken-0.7.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03c6c40ff1db0f48a7b4d2dafeae73a5607aacb472fa11f125e7baf9dce73704"}, - {file = "tiktoken-0.7.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d20b5c6af30e621b4aca094ee61777a44118f52d886dbe4f02b70dfe05c15350"}, - {file = "tiktoken-0.7.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d427614c3e074004efa2f2411e16c826f9df427d3c70a54725cae860f09e4bf4"}, - {file = "tiktoken-0.7.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:8c46d7af7b8c6987fac9b9f61041b452afe92eb087d29c9ce54951280f899a97"}, - {file = "tiktoken-0.7.0-cp312-cp312-win_amd64.whl", hash = "sha256:0bc603c30b9e371e7c4c7935aba02af5994a909fc3c0fe66e7004070858d3f8f"}, - {file = "tiktoken-0.7.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2398fecd38c921bcd68418675a6d155fad5f5e14c2e92fcf5fe566fa5485a858"}, - {file = "tiktoken-0.7.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8f5f6afb52fb8a7ea1c811e435e4188f2bef81b5e0f7a8635cc79b0eef0193d6"}, - {file = "tiktoken-0.7.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:861f9ee616766d736be4147abac500732b505bf7013cfaf019b85892637f235e"}, - {file = "tiktoken-0.7.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54031f95c6939f6b78122c0aa03a93273a96365103793a22e1793ee86da31685"}, - {file = "tiktoken-0.7.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:fffdcb319b614cf14f04d02a52e26b1d1ae14a570f90e9b55461a72672f7b13d"}, - {file = "tiktoken-0.7.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:c72baaeaefa03ff9ba9688624143c858d1f6b755bb85d456d59e529e17234769"}, - {file = "tiktoken-0.7.0-cp38-cp38-win_amd64.whl", hash = "sha256:131b8aeb043a8f112aad9f46011dced25d62629091e51d9dc1adbf4a1cc6aa98"}, - {file = "tiktoken-0.7.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:cabc6dc77460df44ec5b879e68692c63551ae4fae7460dd4ff17181df75f1db7"}, - {file = "tiktoken-0.7.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8d57f29171255f74c0aeacd0651e29aa47dff6f070cb9f35ebc14c82278f3b25"}, - {file = "tiktoken-0.7.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2ee92776fdbb3efa02a83f968c19d4997a55c8e9ce7be821ceee04a1d1ee149c"}, - {file = "tiktoken-0.7.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e215292e99cb41fbc96988ef62ea63bb0ce1e15f2c147a61acc319f8b4cbe5bf"}, - {file = "tiktoken-0.7.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:8a81bac94769cab437dd3ab0b8a4bc4e0f9cf6835bcaa88de71f39af1791727a"}, - {file = "tiktoken-0.7.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:d6d73ea93e91d5ca771256dfc9d1d29f5a554b83821a1dc0891987636e0ae226"}, - {file = "tiktoken-0.7.0-cp39-cp39-win_amd64.whl", hash = "sha256:2bcb28ddf79ffa424f171dfeef9a4daff61a94c631ca6813f43967cb263b83b9"}, - {file = "tiktoken-0.7.0.tar.gz", hash = "sha256:1077266e949c24e0291f6c350433c6f0971365ece2b173a23bc3b9f9defef6b6"}, -] - -[package.dependencies] -regex = ">=2022.1.18" -requests = ">=2.26.0" - -[package.extras] -blobfile = ["blobfile (>=2)"] - [[package]] name = "tokenizers" version = "0.15.2" @@ -4719,4 +4576,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "5bd51b69a103505d61d641c9f1c4dc4cc9609983f9fc23ea9a008210a37dc2af" +content-hash = "af3df0483ec9af4baa1ef92c0e88357263e6e11b119e8b7ed78ef66622f438fa" diff --git a/pyproject.toml b/pyproject.toml index cf554f0..3344de7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,4 +59,4 @@ datasets = "^2.18.0" wandb = "^0.16.5" loguru = "^0.7.2" scikit-learn = "^1.4.2" -langchain-openai = "^0.1.8" +openai = "^1.34.0" diff --git a/scripts/auto_labeling_using_llm.py b/scripts/auto_labeling_using_llm.py index 4201996..b96f433 100644 --- a/scripts/auto_labeling_using_llm.py +++ b/scripts/auto_labeling_using_llm.py @@ -7,15 +7,17 @@ """ import os +import random import json import argparse +import base64 +from io import BytesIO +from PIL import Image import jsonlines from huggingface_hub import HfApi from loguru import logger - -from langchain_core.messages import HumanMessage -from langchain_openai import AzureChatOpenAI +from openai import AzureOpenAI, ChatCompletion from scripts.constants import ( ASSETS_FOLDER, @@ -31,33 +33,6 @@ load_dotenv() -PROMPT = """\ -Given an image and its class, provide the cnocepts that are present in the image in the following format: - -You may choose from the following concepts only: -{concepts} - -Provide the classification in the following format: -Classification::: -Concepts: (concept: e.g., red, sphere, stem, etc.) - -Examples: -Image: {imaage_example_1} -Class: {class_example_1} -Concepts: {concepts_example_1} - -Image: {imaage_example_2} -Class: {class_example_2} -Concepts: {concepts_example_2} - -Now here is an image and its class: -Image: {image} -Class: {class_} - -Classification::: -Concepts: -""" - def save_metadata(hf_api: HfApi, metadata: dict, split: str, push_to_hub: bool = False): with jsonlines.open(f"{ASSETS_FOLDER}/{DATASET_NAME}/data/{split}/metadata.jsonl", mode="w") as writer: @@ -107,30 +82,144 @@ def compute_concepts(votes): return {c: vote_sum[c] > 0 if vote_sum[c] != 0 else None for c in CONCEPTS} +class OpenAIRequest: + def __init__(self): + self.client = AzureOpenAI( + openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"], + azure_deployment=os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"], + ) + self.concepts = ",".join(CONCEPTS) + print(self.concepts) + + def __call__(self, item: dict, icl: dict, **kwargs) -> ChatCompletion: + message = [ + { + "role": "system", + "content": [ + { + "type": "text", + "text": """\ +You are a helpful assistant that can help annotating images. Answer by giving the list of concepts you can see in the provided image. + +Given an image and its class, provide the concepts that are present in the image. + +You may choose from the following concepts only: +{self.concepts} + +Provide the classification in the following format: +Concepts: (concept: e.g., red, sphere, stem, etc.) +""" + } + ], + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": f"""\ +Here is an image and its class: + +Class: {icl["class"]}\nImage: +""" + }, + { + "type": "image", + "image": icl["image"] + } + ], + }, + { + "role": "assistant", + "content": [ + { + "type": "text", + "text": f"Concepts: {icl["concepts"]}" + } + ], + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": f"Now here is another image and its class, provide the concepts: \nClass: {item["class"]}\nImage:" + }, + { + "type": "image", + "image": item["image"] + } + ], + }, + { + "role": "assistant", + "content": [ + { + "type": "text", + "text": "Concepts:" + } + ] + } + ] + + return self.client.chat.completions.create( + model="gpt-4o", + messages=[message], + **kwargs + ) + +def image2base64(image: BytesIO) -> str: + # buffered = BytesIO() + # image.save(buffered, format="JPEG") + # return base64.b64encode(buffered.getvalue()).decode("utf-8") + return base64.b64encode(image.getvalue()).decode() + +def get_icl_example_dict(metadata: dict, split: str) -> dict: + labeled_items_classes = ["tomato", "lemon", "kiwi", "lettuce", "cabbage", "paprika", "beetroots", "bell pepper"] + labeled_items = [item for item in metadata[split] if item["class"] in labeled_items_classes] + + images = [item["image"] for item in labeled_items] + classes = [item["class"] for item in labeled_items] + concepts = [get_pre_labeled_concepts(item) for item in labeled_items] #TODO: remove and replace with correct function + + rand_idx = random.randint(0, len(labeled_items) - 1) # TODO: remove + + return { + "class": classes[rand_idx], + "image": image2base64(images[rand_idx]), + "concepts": ",".join([c for c in concepts[rand_idx] if concepts[rand_idx][c]]), + } + def main(args): hf_api = HfApi(token=HF_TOKEN) - model = AzureChatOpenAI( - openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"], - azure_deployment=os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"], - ) - logger.info("Download metadata and votes") metadata, votes = get_votes(hf_api) for split in SPLITS: for item in metadata[split]: + print(type(item)) + print(item.keys()) + if item["class"] in LABELED_CLASSES: continue key = item["id"] - # Call VLM - message = HumanMessage( - content=PROMPT.format(concepts=", ".join(CONCEPTS), class_=item["class"]), + item_dict = { + "class": item["class"], + "image": image2base64(item["image"]), # TODO: fix open image + } + + icl_dict = get_icl_example_dict(metadata=metadata, split=split) + + response = OpenAIRequest( + item=item_dict, + icl=icl_dict, # TODO: build the ICL dict manually + max_tokens=100, temperature=0, ) - response = model.invoke([message]) - pred = response.content.split(":")[1].strip() if ":" in response.content else response.content + + pred = response.choices[0].message.split(":")[1].strip() if ":" in response.choices[0].message else response.choices[0].message print(pred) concepts = get_pre_labeled_concepts(item) if "imenelydiaker" not in votes[key]: @@ -170,7 +259,7 @@ def main(args): def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser("pre-label-dataset") + parser = argparse.ArgumentParser("auto-label-dataset") parser.add_argument("--push_to_hub", action=argparse.BooleanOptionalAction, default=False) return parser.parse_args() From 31b1995f6df9cefaff7be74eff9de8563e473a58 Mon Sep 17 00:00:00 2001 From: Imene Kerboua Date: Mon, 17 Jun 2024 15:49:18 +0200 Subject: [PATCH 03/12] update script --- scripts/auto_labeling_using_llm.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/scripts/auto_labeling_using_llm.py b/scripts/auto_labeling_using_llm.py index b96f433..8a4b070 100644 --- a/scripts/auto_labeling_using_llm.py +++ b/scripts/auto_labeling_using_llm.py @@ -81,15 +81,13 @@ def compute_concepts(votes): vote_sum[c] += 2 * vote[c] - 1 return {c: vote_sum[c] > 0 if vote_sum[c] != 0 else None for c in CONCEPTS} - class OpenAIRequest: def __init__(self): self.client = AzureOpenAI( - openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"], + api_version=os.environ["AZURE_OPENAI_API_VERSION"], azure_deployment=os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"], ) self.concepts = ",".join(CONCEPTS) - print(self.concepts) def __call__(self, item: dict, icl: dict, **kwargs) -> ChatCompletion: message = [ @@ -134,7 +132,7 @@ def __call__(self, item: dict, icl: dict, **kwargs) -> ChatCompletion: "content": [ { "type": "text", - "text": f"Concepts: {icl["concepts"]}" + "text": f"Concepts: {icl['concepts']}" } ], }, @@ -143,7 +141,7 @@ def __call__(self, item: dict, icl: dict, **kwargs) -> ChatCompletion: "content": [ { "type": "text", - "text": f"Now here is another image and its class, provide the concepts: \nClass: {item["class"]}\nImage:" + "text": f"Now here is another image and its class, provide the concepts: \nClass: {item['class']}\nImage:" }, { "type": "image", @@ -164,15 +162,13 @@ def __call__(self, item: dict, icl: dict, **kwargs) -> ChatCompletion: return self.client.chat.completions.create( model="gpt-4o", - messages=[message], + messages=message, **kwargs ) def image2base64(image: BytesIO) -> str: - # buffered = BytesIO() - # image.save(buffered, format="JPEG") - # return base64.b64encode(buffered.getvalue()).decode("utf-8") - return base64.b64encode(image.getvalue()).decode() + # Call example: image2base64(BytesIO(open("images/00000000.jpg", "rb").read())) + return base64.b64encode(image.getvalue()).decode("utf-8") def get_icl_example_dict(metadata: dict, split: str) -> dict: labeled_items_classes = ["tomato", "lemon", "kiwi", "lettuce", "cabbage", "paprika", "beetroots", "bell pepper"] @@ -198,9 +194,6 @@ def main(args): for split in SPLITS: for item in metadata[split]: - print(type(item)) - print(item.keys()) - if item["class"] in LABELED_CLASSES: continue key = item["id"] From d8555b55f09348de3756e8f1edbd8e3a1c6fd87d Mon Sep 17 00:00:00 2001 From: Imene Kerboua Date: Mon, 17 Jun 2024 15:59:54 +0200 Subject: [PATCH 04/12] fix call of openai request --- scripts/auto_labeling_using_llm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/auto_labeling_using_llm.py b/scripts/auto_labeling_using_llm.py index 8a4b070..7d466d0 100644 --- a/scripts/auto_labeling_using_llm.py +++ b/scripts/auto_labeling_using_llm.py @@ -205,7 +205,8 @@ def main(args): icl_dict = get_icl_example_dict(metadata=metadata, split=split) - response = OpenAIRequest( + openai_request = OpenAIRequest() + response = openai_request( item=item_dict, icl=icl_dict, # TODO: build the ICL dict manually max_tokens=100, From e7772190b2bdedaa661e79195039364a74624a99 Mon Sep 17 00:00:00 2001 From: Imene Kerboua Date: Mon, 17 Jun 2024 17:39:48 +0200 Subject: [PATCH 05/12] update response format in prompt --- scripts/auto_labeling_using_llm.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/scripts/auto_labeling_using_llm.py b/scripts/auto_labeling_using_llm.py index 7d466d0..894e648 100644 --- a/scripts/auto_labeling_using_llm.py +++ b/scripts/auto_labeling_using_llm.py @@ -104,8 +104,8 @@ def __call__(self, item: dict, icl: dict, **kwargs) -> ChatCompletion: You may choose from the following concepts only: {self.concepts} -Provide the classification in the following format: -Concepts: (concept: e.g., red, sphere, stem, etc.) +Provide the classification in the following JSON format: +{"red": True, "sphere": True, "stem": False, ...} """ } ], @@ -167,7 +167,6 @@ def __call__(self, item: dict, icl: dict, **kwargs) -> ChatCompletion: ) def image2base64(image: BytesIO) -> str: - # Call example: image2base64(BytesIO(open("images/00000000.jpg", "rb").read())) return base64.b64encode(image.getvalue()).decode("utf-8") def get_icl_example_dict(metadata: dict, split: str) -> dict: @@ -209,12 +208,14 @@ def main(args): response = openai_request( item=item_dict, icl=icl_dict, # TODO: build the ICL dict manually - max_tokens=100, + max_tokens=200, temperature=0, ) - pred = response.choices[0].message.split(":")[1].strip() if ":" in response.choices[0].message else response.choices[0].message + pred = response.choices[0].message.content + pred = pred[pred.rfind("{"):pred.rfind("}")] print(pred) + concepts = get_pre_labeled_concepts(item) if "imenelydiaker" not in votes[key]: continue From 98a3bee239e0da49dc008594507a3b18dec2484c Mon Sep 17 00:00:00 2001 From: Imene Kerboua Date: Mon, 17 Jun 2024 17:43:40 +0200 Subject: [PATCH 06/12] add icl exmaple manually --- scripts/auto_labeling_using_llm.py | 61 ++++++++++++++++++++---------- 1 file changed, 41 insertions(+), 20 deletions(-) diff --git a/scripts/auto_labeling_using_llm.py b/scripts/auto_labeling_using_llm.py index 894e648..86cecde 100644 --- a/scripts/auto_labeling_using_llm.py +++ b/scripts/auto_labeling_using_llm.py @@ -81,6 +81,7 @@ def compute_concepts(votes): vote_sum[c] += 2 * vote[c] - 1 return {c: vote_sum[c] > 0 if vote_sum[c] != 0 else None for c in CONCEPTS} + class OpenAIRequest: def __init__(self): self.client = AzureOpenAI( @@ -88,8 +89,9 @@ def __init__(self): azure_deployment=os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"], ) self.concepts = ",".join(CONCEPTS) - + def __call__(self, item: dict, icl: dict, **kwargs) -> ChatCompletion: + """Send a request to the OpenAI API.""" message = [ { "role": "system", @@ -166,24 +168,42 @@ def __call__(self, item: dict, icl: dict, **kwargs) -> ChatCompletion: **kwargs ) + def image2base64(image: BytesIO) -> str: + """Convert image to base64 string.""" return base64.b64encode(image.getvalue()).decode("utf-8") -def get_icl_example_dict(metadata: dict, split: str) -> dict: - labeled_items_classes = ["tomato", "lemon", "kiwi", "lettuce", "cabbage", "paprika", "beetroots", "bell pepper"] - labeled_items = [item for item in metadata[split] if item["class"] in labeled_items_classes] - images = [item["image"] for item in labeled_items] - classes = [item["class"] for item in labeled_items] - concepts = [get_pre_labeled_concepts(item) for item in labeled_items] #TODO: remove and replace with correct function +def get_icl_example_dict() -> dict: + """Build ICL example manually.""" + icl_dict = { + "class": "lettuce", + # TODO: update path + "image": image2base64(BytesIO(open("images/00000000.jpg", "rb").read())), + "concepts": { + "leaf": True, + "green": True, + "stem": False, + "red": False, + "black": False, + "blue": False, + "ovaloid": False, + "sphere": False, + "cylinder": False, + "cube": False, + "brown": False, + "orange": False, + "yellow": False, + "white": False, + "tail": False, + "seed": False, + "pulp": False, + "soil": False, + "tree": False, + } + } + return icl_dict - rand_idx = random.randint(0, len(labeled_items) - 1) # TODO: remove - - return { - "class": classes[rand_idx], - "image": image2base64(images[rand_idx]), - "concepts": ",".join([c for c in concepts[rand_idx] if concepts[rand_idx][c]]), - } def main(args): hf_api = HfApi(token=HF_TOKEN) @@ -199,19 +219,19 @@ def main(args): item_dict = { "class": item["class"], - "image": image2base64(item["image"]), # TODO: fix open image + "image": image2base64(item["image"]), # TODO: fix open image } - icl_dict = get_icl_example_dict(metadata=metadata, split=split) + icl_dict = get_icl_example_dict() openai_request = OpenAIRequest() - response = openai_request( + response = openai_request( item=item_dict, - icl=icl_dict, # TODO: build the ICL dict manually + icl=icl_dict, max_tokens=200, temperature=0, ) - + pred = response.choices[0].message.content pred = pred[pred.rfind("{"):pred.rfind("}")] print(pred) @@ -255,7 +275,8 @@ def main(args): def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser("auto-label-dataset") - parser.add_argument("--push_to_hub", action=argparse.BooleanOptionalAction, default=False) + parser.add_argument( + "--push_to_hub", action=argparse.BooleanOptionalAction, default=False) return parser.parse_args() From 6defc4366b9852e8cd232629ebb19fb2325ac636 Mon Sep 17 00:00:00 2001 From: Imene Kerboua Date: Mon, 17 Jun 2024 17:49:10 +0200 Subject: [PATCH 07/12] add model_name argument --- scripts/auto_labeling_using_llm.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/scripts/auto_labeling_using_llm.py b/scripts/auto_labeling_using_llm.py index 86cecde..38a992b 100644 --- a/scripts/auto_labeling_using_llm.py +++ b/scripts/auto_labeling_using_llm.py @@ -83,7 +83,7 @@ def compute_concepts(votes): class OpenAIRequest: - def __init__(self): + def __init__(self, model: str="gpt-4o"): self.client = AzureOpenAI( api_version=os.environ["AZURE_OPENAI_API_VERSION"], azure_deployment=os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"], @@ -163,7 +163,7 @@ def __call__(self, item: dict, icl: dict, **kwargs) -> ChatCompletion: ] return self.client.chat.completions.create( - model="gpt-4o", + model=self.model, messages=message, **kwargs ) @@ -275,6 +275,8 @@ def main(args): def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser("auto-label-dataset") + parser.add_argument( + "--model", type=str, default="gpt-4o") parser.add_argument( "--push_to_hub", action=argparse.BooleanOptionalAction, default=False) return parser.parse_args() From 9b859fcc76f722856d0ff70b5eb1bcd2c1427204 Mon Sep 17 00:00:00 2001 From: Imene Kerboua <33312980+imenelydiaker@users.noreply.github.com> Date: Tue, 18 Jun 2024 15:50:42 +0200 Subject: [PATCH 08/12] Fix concepts parsing in prompt --- scripts/auto_labeling_using_llm.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/scripts/auto_labeling_using_llm.py b/scripts/auto_labeling_using_llm.py index 38a992b..d844181 100644 --- a/scripts/auto_labeling_using_llm.py +++ b/scripts/auto_labeling_using_llm.py @@ -17,7 +17,7 @@ import jsonlines from huggingface_hub import HfApi from loguru import logger -from openai import AzureOpenAI, ChatCompletion +from openai import AzureOpenAI from scripts.constants import ( ASSETS_FOLDER, @@ -40,7 +40,7 @@ def save_metadata(hf_api: HfApi, metadata: dict, split: str, push_to_hub: bool = if push_to_hub: hf_api.upload_file( - path_or_fileobj=f"{ASSETS_FOLDER}/{DATASET_NAME}/data/{split}/" "metadata.jsonl", + path_or_fileobj=os.path.join("{ASSETS_FOLDER}/{DATASET_NAME}/data/{split}", "metadata.jsonl"), path_in_repo=f"data/{split}/metadata.jsonl", repo_id=DATASET_NAME, repo_type="dataset", @@ -90,7 +90,7 @@ def __init__(self, model: str="gpt-4o"): ) self.concepts = ",".join(CONCEPTS) - def __call__(self, item: dict, icl: dict, **kwargs) -> ChatCompletion: + def __call__(self, item: dict, icl: dict, **kwargs): """Send a request to the OpenAI API.""" message = [ { @@ -104,11 +104,11 @@ def __call__(self, item: dict, icl: dict, **kwargs) -> ChatCompletion: Given an image and its class, provide the concepts that are present in the image. You may choose from the following concepts only: -{self.concepts} +{concepts} Provide the classification in the following JSON format: -{"red": True, "sphere": True, "stem": False, ...} -""" +{{"red": True, "sphere": True, "stem": False, ...}} +""".format(concepts=self.concepts) } ], }, @@ -176,7 +176,7 @@ def image2base64(image: BytesIO) -> str: def get_icl_example_dict() -> dict: """Build ICL example manually.""" - icl_dict = { + return { "class": "lettuce", # TODO: update path "image": image2base64(BytesIO(open("images/00000000.jpg", "rb").read())), @@ -202,7 +202,6 @@ def get_icl_example_dict() -> dict: "tree": False, } } - return icl_dict def main(args): From df8abbc50ce8095b8c730a7879c5f11010ed2268 Mon Sep 17 00:00:00 2001 From: Imene Kerboua <33312980+imenelydiaker@users.noreply.github.com> Date: Wed, 10 Jul 2024 20:57:44 +0100 Subject: [PATCH 09/12] Update scripts/auto_labeling_using_llm.py Co-authored-by: Yoann Poupart <66315201+Xmaster6y@users.noreply.github.com> --- scripts/auto_labeling_using_llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/auto_labeling_using_llm.py b/scripts/auto_labeling_using_llm.py index d844181..c17d100 100644 --- a/scripts/auto_labeling_using_llm.py +++ b/scripts/auto_labeling_using_llm.py @@ -223,7 +223,7 @@ def main(args): icl_dict = get_icl_example_dict() - openai_request = OpenAIRequest() + openai_request = OpenAIRequest(model=args.model) response = openai_request( item=item_dict, icl=icl_dict, From b84de9225becfb0fbadd24bd47ed95b19caa0aa3 Mon Sep 17 00:00:00 2001 From: Imene Kerboua <33312980+imenelydiaker@users.noreply.github.com> Date: Wed, 10 Jul 2024 20:58:24 +0100 Subject: [PATCH 10/12] Update scripts/auto_labeling_using_llm.py Co-authored-by: Yoann Poupart <66315201+Xmaster6y@users.noreply.github.com> --- scripts/auto_labeling_using_llm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/auto_labeling_using_llm.py b/scripts/auto_labeling_using_llm.py index c17d100..d043f33 100644 --- a/scripts/auto_labeling_using_llm.py +++ b/scripts/auto_labeling_using_llm.py @@ -236,9 +236,9 @@ def main(args): print(pred) concepts = get_pre_labeled_concepts(item) - if "imenelydiaker" not in votes[key]: + if args.model not in votes[key]: continue - votes[key] = {"imenelydiaker": concepts} + votes[key] = {args.model: concepts} logger.info("Save votes locally") for key in votes: From 2b556fc74193f808b1ffb0c5193ea0fbd6327d3e Mon Sep 17 00:00:00 2001 From: Imene Kerboua <33312980+imenelydiaker@users.noreply.github.com> Date: Wed, 10 Jul 2024 20:58:47 +0100 Subject: [PATCH 11/12] Update scripts/auto_labeling_using_llm.py Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> --- scripts/auto_labeling_using_llm.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/scripts/auto_labeling_using_llm.py b/scripts/auto_labeling_using_llm.py index d043f33..b8c581a 100644 --- a/scripts/auto_labeling_using_llm.py +++ b/scripts/auto_labeling_using_llm.py @@ -275,6 +275,9 @@ def main(args): def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser("auto-label-dataset") parser.add_argument( + "--model", type=str, default="gpt-4o", help="Specify the model to use, e.g., 'gpt-4o'") + parser.add_argument( + "--push_to_hub", action="store_true", help="Flag to push the results to the hub") "--model", type=str, default="gpt-4o") parser.add_argument( "--push_to_hub", action=argparse.BooleanOptionalAction, default=False) From 2fa181c85e7a65a93f91d9b6c1bedd9771938258 Mon Sep 17 00:00:00 2001 From: Imene Kerboua <33312980+imenelydiaker@users.noreply.github.com> Date: Wed, 10 Jul 2024 20:59:02 +0100 Subject: [PATCH 12/12] Update scripts/auto_labeling_using_llm.py Co-authored-by: Yoann Poupart <66315201+Xmaster6y@users.noreply.github.com> --- scripts/auto_labeling_using_llm.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/scripts/auto_labeling_using_llm.py b/scripts/auto_labeling_using_llm.py index b8c581a..9d00f66 100644 --- a/scripts/auto_labeling_using_llm.py +++ b/scripts/auto_labeling_using_llm.py @@ -101,14 +101,11 @@ def __call__(self, item: dict, icl: dict, **kwargs): "text": """\ You are a helpful assistant that can help annotating images. Answer by giving the list of concepts you can see in the provided image. -Given an image and its class, provide the concepts that are present in the image. +Given an image and its class, annotate the concepts' presence in the image using a JSON format. -You may choose from the following concepts only: -{concepts} - -Provide the classification in the following JSON format: -{{"red": True, "sphere": True, "stem": False, ...}} -""".format(concepts=self.concepts) +The labels must be provided according to the following JSON schema: +{concept_schema} +""".format(concept_schema={"properties":{concept: {"type":"boolean"} for concept in self.concepts}}) } ], },