Skip to content

Add yolo-world pre-annotator example #444

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions label_studio_ml/examples/yolo_world/.env.dev
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# yolo-world ultralytics configuration
CHECKPOINT="yolov8l-world.pt"
CONF_THRESHOLD=0.1
IOU_THRESHOLD=0.3

# your label studio host and credentals
LABEL_STUDIO_HOST="host_name:port"
LABEL_STUDIO_ACCESS_TOKEN="token"

# your s3 endpoint
AWS_ENDPOINT_URL="host_name:port"
AWS_ACCESS_KEY="your-access-key"
AWS_SECRET_ACCESS="your-secret-key"


15 changes: 15 additions & 0 deletions label_studio_ml/examples/yolo_world/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
FROM python:3.11-slim

WORKDIR /app

RUN apt-get -y update \
&& apt-get install -y git \
&& apt-get install -y wget \
&& apt-get -y install ffmpeg libsm6 libxext6 libffi-dev python3-dev gcc

# Install Base
COPY requirements.txt .
RUN pip install --upgrade pip
RUN pip install --no-cache-dir -r requirements.txt

WORKDIR /app
27 changes: 27 additions & 0 deletions label_studio_ml/examples/yolo_world/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# YOLO-WORLD
This is an example of running YOLO-World as a pre-annotator with prompts. The example shows how to make a backend in case if you are using s3 (minio) storage.
### Run pre-annotator
```bash
docker compose up -d
```
Don't forget change default port on available one in docker-compose.yml.

### Configure label studio prompt window in project_name/Settings/Labeling Interface:
```html
<View>
<Image name="image" value="$image"/>
<Style>
.lsf-main-content.lsf-requesting .prompt::before { content: ' loading...'; color: #808080; }
</Style>
<View className="prompt">
<TextArea name="prompt" toName="image" editable="true" rows="2" maxSubmissions="1" showSubmitButton="true"/>
</View>
<RectangleLabels name="label" toName="image">
<Label value="dog" background="blue"/>
<Label value="cat" background="#FFA39E"/>
</RectangleLabels>
</View>
```

### Input window example
![labeling](assets/yolo-world-example.gif)
114 changes: 114 additions & 0 deletions label_studio_ml/examples/yolo_world/_wsgi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import os
import argparse
import json
import logging
import logging.config

logging.config.dictConfig({
"version": 1,
"formatters": {
"standard": {
"format": "[%(asctime)s] [%(levelname)s] [%(name)s::%(funcName)s::%(lineno)d] %(message)s"
}
},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"level": os.getenv('LOG_LEVEL'),
"stream": "ext://sys.stdout",
"formatter": "standard"
}
},
"root": {
"level": os.getenv('LOG_LEVEL'),
"handlers": [
"console"
],
"propagate": True
}
})

from label_studio_ml.api import init_app
from yolo_world import YOLOWorldBackend


_DEFAULT_CONFIG_PATH = os.path.join(os.path.dirname(__file__), 'config.json')


def get_kwargs_from_config(config_path=_DEFAULT_CONFIG_PATH):
if not os.path.exists(config_path):
return dict()
with open(config_path) as f:
config = json.load(f)
assert isinstance(config, dict)
return config


if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Label studio')
parser.add_argument(
'-p', '--port', dest='port', type=int, default=9090,
help='Server port')
parser.add_argument(
'--host', dest='host', type=str, default='0.0.0.0',
help='Server host')
parser.add_argument(
'--kwargs', '--with', dest='kwargs', metavar='KEY=VAL', nargs='+', type=lambda kv: kv.split('='),
help='Additional LabelStudioMLBase model initialization kwargs')
parser.add_argument(
'-d', '--debug', dest='debug', action='store_true',
help='Switch debug mode')
parser.add_argument(
'--log-level', dest='log_level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], default=None,
help='Logging level')
parser.add_argument(
'--model-dir', dest='model_dir', default=os.path.dirname(__file__),
help='Directory where models are stored (relative to the project directory)')
parser.add_argument(
'--check', dest='check', action='store_true',
help='Validate model instance before launching server')

args = parser.parse_args()

# setup logging level
if args.log_level:
logging.root.setLevel(args.log_level)

def isfloat(value):
try:
float(value)
return True
except ValueError:
return False

def parse_kwargs():
param = dict()
for k, v in args.kwargs:
if v.isdigit():
param[k] = int(v)
elif v == 'True' or v == 'true':
param[k] = True
elif v == 'False' or v == 'false':
param[k] = False
elif isfloat(v):
param[k] = float(v)
else:
param[k] = v
return param

kwargs = get_kwargs_from_config()

if args.kwargs:
kwargs.update(parse_kwargs())

if args.check:
print('Check "' + YOLOWorldBackend.__name__ + '" instance creation..')
model = YOLOWorldBackend(**kwargs)

app = init_app(model_class=YOLOWorldBackend)

app.run(host=args.host, port=args.port, debug=args.debug)

else:
# for uWSGI use
app = init_app(model_class=YOLOWorldBackend)
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
16 changes: 16 additions & 0 deletions label_studio_ml/examples/yolo_world/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
version: "3.11"

services:
ml-yolo-world-backend:
build:
context: .
runtime: nvidia
env_file:
- ./.env.dev
volumes:
- ./yolo_world.py:/app/yolo_world.py
- ./_wsgi.py:/app/_wsgi.py
ports:
- "9090:9090"
command: gunicorn --preload --bind :9090 --workers 2 --threads 4 --timeout 0 _wsgi:app

13 changes: 13 additions & 0 deletions label_studio_ml/examples/yolo_world/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# base
gunicorn==20.1.0
label-studio-ml @ git+https://github.com/HumanSignal/label-studio-ml-backend.git
Flask~=2.3
colorama~=0.4
requests~=2.31
dill==0.3.8
boto3==1.34.29
label-studio-tools

# yolo-world
torch==2.0.1
ultralytics==8.1.15
143 changes: 143 additions & 0 deletions label_studio_ml/examples/yolo_world/yolo_world.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import os
import logging
from uuid import uuid4
from urllib.parse import urlparse
from typing import List, Dict, Optional

from ultralytics import YOLOWorld
import boto3
from botocore.exceptions import ClientError

from label_studio_ml.model import LabelStudioMLBase
from label_studio_ml.utils import get_image_local_path


APPLICATION_NAME = "task_label_studio_backend_yolo_world"
logger = logging.getLogger(APPLICATION_NAME)

# Model checkpoint
CHECKPOINT = os.environ.get("CHECKPOINT", "yolov8l-world.pt")

# Label Studio
LABEL_STUDIO_HOST = os.environ.get("LABEL_STUDIO_HOST")
LABEL_STUDIO_ACCESS_TOKEN = os.environ.get("LABEL_STUDIO_ACCESS_TOKEN")

# S3 credentials
AWS_ENDPOINT_URL = os.environ.get('AWS_ENDPOINT_URL')
AWS_ACCESS_KEY = os.environ.get('AWS_ACCESS_KEY')
AWS_SECRET_ACCESS = os.environ.get('AWS_SECRET_ACCESS')

# Model thresholds
CONF_THRESHOLD = os.environ.get("CONF_THRESHOLD", 0.1)
IOU_THRESHOLD = os.environ.get("IOU_THRESHOLD", 0.3)


class YOLOWorldBackend(LabelStudioMLBase):

def __init__(self, project_id, **kwargs):
# don't forget to initialize base class...
super().__init__(**kwargs)
self.model = YOLOWorld(CHECKPOINT)
self.conf_thres = float(CONF_THRESHOLD)
self.iou_thres = float(IOU_THRESHOLD)

@staticmethod
def _get_image_url(image_url):
if image_url.startswith('s3://'):
r = urlparse(image_url, allow_fragments=False)
bucket_name = r.netloc
key = r.path.lstrip('/')
client = boto3.client('s3',
endpoint_url=AWS_ENDPOINT_URL,
aws_access_key_id=AWS_ACCESS_KEY,
aws_secret_access_key=AWS_SECRET_ACCESS,
)
try:
image_url = client.generate_presigned_url(
ClientMethod='get_object',
Params={'Bucket': bucket_name, 'Key': key}
)
except ClientError as exc:
logger.warning(f'Can\'t generate presigned URL for {image_url}. Reason: {exc}')
return image_url

def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs):
if not context or not context.get('result'):
return []
prompt = context['result'][0]['value']['text'][0]
logger.info(f"PROMPT: {prompt}")
self.from_name_r, self.to_name_r, self.value_r = self.get_first_tag_occurence('RectangleLabels', 'Image')
return self._predict(tasks, prompt)

def _predict(self, tasks: List[Dict], prompt: str):

# parse prompt
labels = prompt.split(", ")
self.model.set_classes(labels)

image_paths = []
for task in tasks:
raw_img_path = task['data']['image']
try:
image_url = self._get_image_url(raw_img_path)
img_path = get_image_local_path(
image_url,
label_studio_access_token=LABEL_STUDIO_ACCESS_TOKEN,
label_studio_host=LABEL_STUDIO_HOST
)
except:
img_path = raw_img_path
image_paths.append(img_path)

predictions = []
for image_path in image_paths:
result = self.model.predict(image_path, conf=self.conf_thres, iou=self.iou_thres)[0]
img_height, img_width = result.orig_shape
box_by_task = result.boxes.xyxy.cpu().numpy().astype(float)
scores = result.boxes.conf.cpu().numpy().astype(float)
classes =result.boxes.cls.cpu().numpy().astype(int)
all_points = []
all_scores = []
all_lengths = []
all_classes = []

for box, score, cls in zip(box_by_task, scores, classes):
all_points.append(box)
all_scores.append(score)
all_classes.append(labels[cls])
all_lengths.append((img_height, img_width))

predictions.append(self.get_results(all_points, all_scores, all_classes, all_lengths))

return predictions

def get_results(self, all_points, all_scores, all_classes, all_lengths):
results = []
for box, score, cls, length in zip(all_points, all_scores, all_classes, all_lengths):
# random ID
label_id = str(uuid4())[:9]

height, width = length
results.append({
'id': label_id,
'from_name': self.from_name_r,
'to_name': self.to_name_r,
'original_width': width,
'original_height': height,
'image_rotation': 0,
'value': {
'rotation': 0,
'rectanglelabels': [cls],
'width': (box[2] - box[0]) / width * 100,
'height': (box[3] - box[1]) / height * 100,
'x': box[0] / width * 100,
'y': box[1] / height * 100
},
'score': score,
'type': 'rectanglelabels',
'readonly': False
})

return {
'result': results
}
Loading