Skip to content

Commit 0e8e0b6

Browse files
committed
Merge remote-tracking branch 'origin/main' into bugfix-trailing-slash
2 parents 8ab273e + 2a8917a commit 0e8e0b6

File tree

9 files changed

+208
-30
lines changed

9 files changed

+208
-30
lines changed

roboflow/config.py

+1
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def get_conditional_configuration_variable(key, default):
6969
TYPE_OBJECT_DETECTION = "object-detection"
7070
TYPE_INSTANCE_SEGMENTATION = "instance-segmentation"
7171
TYPE_SEMANTIC_SEGMENTATION = "semantic-segmentation"
72+
TYPE_KEYPOINT_DETECTION = "keypoint-detection"
7273

7374
DEFAULT_BATCH_NAME = "Pip Package Upload"
7475
DEFAULT_JOB_NAME = "Annotated via API"

roboflow/core/version.py

+4
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@
1919
DEMO_KEYS,
2020
TYPE_CLASSICATION,
2121
TYPE_INSTANCE_SEGMENTATION,
22+
TYPE_KEYPOINT_DETECTION,
2223
TYPE_OBJECT_DETECTION,
2324
TYPE_SEMANTIC_SEGMENTATION,
2425
UNIVERSE_URL,
2526
)
2627
from roboflow.core.dataset import Dataset
2728
from roboflow.models.classification import ClassificationModel
2829
from roboflow.models.instance_segmentation import InstanceSegmentationModel
30+
from roboflow.models.keypoint_detection import KeypointDetectionModel
2931
from roboflow.models.object_detection import ObjectDetectionModel
3032
from roboflow.models.semantic_segmentation import SemanticSegmentationModel
3133
from roboflow.util.annotations import amend_data_yaml
@@ -124,6 +126,8 @@ def __init__(
124126
)
125127
elif self.type == TYPE_SEMANTIC_SEGMENTATION:
126128
self.model = SemanticSegmentationModel(self.__api_key, self.id)
129+
elif self.type == TYPE_KEYPOINT_DETECTION:
130+
self.model = KeypointDetectionModel(self.__api_key, self.id, version=version_without_workspace)
127131
else:
128132
self.model = None
129133

roboflow/models/classification.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88
from PIL import Image
99

1010
from roboflow.config import CLASSIFICATION_MODEL
11+
from roboflow.models.inference import InferenceModel
1112
from roboflow.util.image_utils import check_image_url
1213
from roboflow.util.prediction import PredictionGroup
1314

1415

15-
class ClassificationModel:
16+
class ClassificationModel(InferenceModel):
1617
"""
1718
Run inference on a classification model hosted on Roboflow or served through
1819
Roboflow Inference.
@@ -44,6 +45,7 @@ def __init__(
4445
ClassificationModel Object
4546
"""
4647
# Instantiate different API URL parameters
48+
super(ClassificationModel, self).__init__(api_key, id, version=version)
4749
self.__api_key = api_key
4850
self.id = id
4951
self.name = name

roboflow/models/inference.py

+2
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,8 @@ def predict_video(
191191
self.type = "gaze-detection"
192192
elif model_class == "CLIPModel":
193193
self.type = "clip-embed-image"
194+
elif model_class == "KeypointDetectionModel":
195+
self.type = "keypoint-detection"
194196
else:
195197
raise Exception("Model type not supported for video inference.")
196198

roboflow/models/keypoint_detection.py

+180
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
import base64
2+
import io
3+
import json
4+
import os
5+
import urllib
6+
7+
import requests
8+
from PIL import Image
9+
10+
from roboflow.config import CLASSIFICATION_MODEL
11+
from roboflow.models.inference import InferenceModel
12+
from roboflow.util.image_utils import check_image_url
13+
from roboflow.util.prediction import PredictionGroup
14+
15+
16+
class KeypointDetectionModel(InferenceModel):
17+
"""
18+
Run inference on a keypoint detection model hosted on Roboflow or served through
19+
Roboflow Inference.
20+
"""
21+
22+
def __init__(
23+
self,
24+
api_key: str,
25+
id: str,
26+
name: str = None,
27+
version: int = None,
28+
local: bool = False,
29+
):
30+
"""
31+
Create a KeypointDetectionModel object through which you can run inference.
32+
33+
Args:
34+
api_key (str): private roboflow api key
35+
id (str): the workspace/project id
36+
name (str): is the name of the project
37+
version (int): version number
38+
local (bool): whether the image is local or hosted
39+
colors (dict): colors to use for the image
40+
preprocessing (dict): preprocessing to use for the image
41+
42+
Returns:
43+
KeypointDetectionModel Object
44+
"""
45+
# Instantiate different API URL parameters
46+
super(KeypointDetectionModel, self).__init__(api_key, id, version=version)
47+
self.__api_key = api_key
48+
self.id = id
49+
self.name = name
50+
self.version = version
51+
self.base_url = "https://detect.roboflow.com/"
52+
53+
if self.name is not None and version is not None:
54+
self.__generate_url()
55+
56+
if local:
57+
print("initializing local keypoint detection model hosted at :" + local)
58+
self.base_url = local
59+
60+
def predict(self, image_path, hosted=False):
61+
"""
62+
Run inference on an image.
63+
64+
Args:
65+
image_path (str): path to the image you'd like to perform prediction on
66+
hosted (bool): whether the image you're providing is hosted on Roboflow
67+
68+
Returns:
69+
PredictionGroup Object
70+
71+
Example:
72+
>>> import roboflow
73+
74+
>>> rf = roboflow.Roboflow(api_key="")
75+
76+
>>> project = rf.workspace().project("PROJECT_ID")
77+
78+
>>> model = project.version("1").model
79+
80+
>>> prediction = model.predict("YOUR_IMAGE.jpg")
81+
"""
82+
self.__generate_url()
83+
self.__exception_check(image_path_check=image_path)
84+
# If image is local image
85+
if not hosted:
86+
# Open Image in RGB Format
87+
image = Image.open(image_path).convert("RGB")
88+
# Create buffer
89+
buffered = io.BytesIO()
90+
image.save(buffered, quality=90, format="JPEG")
91+
img_dims = image.size
92+
# Base64 encode image
93+
img_str = base64.b64encode(buffered.getvalue())
94+
img_str = img_str.decode("ascii")
95+
# Post to API and return response
96+
resp = requests.post(
97+
self.api_url,
98+
data=img_str,
99+
headers={"Content-Type": "application/x-www-form-urlencoded"},
100+
)
101+
else:
102+
# Create API URL for hosted image (slightly different)
103+
self.api_url += "&image=" + urllib.parse.quote_plus(image_path)
104+
# POST to the API
105+
resp = requests.post(self.api_url)
106+
img_dims = {"width": "0", "height": "0"}
107+
108+
if resp.status_code != 200:
109+
raise Exception(resp.text)
110+
111+
return PredictionGroup.create_prediction_group(
112+
resp.json(),
113+
image_dims=img_dims,
114+
image_path=image_path,
115+
prediction_type=CLASSIFICATION_MODEL,
116+
colors=self.colors,
117+
)
118+
119+
def load_model(self, name, version):
120+
"""
121+
Load a model.
122+
123+
Args:
124+
name (str): is the name of the model you'd like to load
125+
version (int): version number
126+
"""
127+
# Load model based on user defined characteristics
128+
self.name = name
129+
self.version = version
130+
self.__generate_url()
131+
132+
def __generate_url(self):
133+
"""
134+
Generate a Roboflow API URL on which to run inference.
135+
136+
Returns:
137+
url (str): the url on which to run inference
138+
"""
139+
140+
# Generates URL based on all parameters
141+
splitted = self.id.rsplit("/")
142+
without_workspace = splitted[1]
143+
version = self.version
144+
if not version and len(splitted) > 2:
145+
version = splitted[2]
146+
147+
self.api_url = "".join(
148+
[
149+
self.base_url + without_workspace + "/" + str(version),
150+
"?api_key=" + self.__api_key,
151+
"&name=YOUR_IMAGE.jpg",
152+
]
153+
)
154+
155+
def __exception_check(self, image_path_check=None):
156+
"""
157+
Check to see if an image exists.
158+
159+
Args:
160+
image_path_check (str): path to the image to check
161+
162+
Raises:
163+
Exception: if image does not exist
164+
"""
165+
# Checks if image exists
166+
if image_path_check is not None:
167+
if not os.path.exists(image_path_check) and not check_image_url(image_path_check):
168+
raise Exception("Image does not exist at " + image_path_check + "!")
169+
170+
def __str__(self):
171+
"""
172+
String representation of keypoint detection object
173+
"""
174+
json_value = {
175+
"name": self.name,
176+
"version": self.version,
177+
"base_url": self.base_url,
178+
}
179+
180+
return json.dumps(json_value, indent=2)

roboflow/models/video.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,7 @@
99
from roboflow.config import API_URL
1010
from roboflow.models.inference import InferenceModel
1111

12-
SUPPORTED_ROBOFLOW_MODELS = [
13-
"object-detection",
14-
"classification",
15-
"instance-segmentation",
16-
]
12+
SUPPORTED_ROBOFLOW_MODELS = ["object-detection", "classification", "instance-segmentation", "keypoint-detection"]
1713

1814
SUPPORTED_ADDITIONAL_MODELS = {
1915
"clip": {
@@ -97,7 +93,7 @@ def predict(
9793

9894
for model in additional_models:
9995
if model not in SUPPORTED_ADDITIONAL_MODELS:
100-
raise Exception(f"Model {model} is no t supported for video inference.")
96+
raise Exception(f"Model {model} is not supported for video inference.")
10197

10298
if inference_type not in SUPPORTED_ROBOFLOW_MODELS:
10399
raise Exception(f"Model {inference_type} is not supported for video inference.")

roboflow/roboflowpy.py

+2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from roboflow.config import APP_URL, get_conditional_configuration_variable, load_roboflow_api_key
1010
from roboflow.models.classification import ClassificationModel
1111
from roboflow.models.instance_segmentation import InstanceSegmentationModel
12+
from roboflow.models.keypoint_detection import KeypointDetectionModel
1213
from roboflow.models.object_detection import ObjectDetectionModel
1314
from roboflow.models.semantic_segmentation import SemanticSegmentationModel
1415

@@ -133,6 +134,7 @@ def infer(args):
133134
"classification": ClassificationModel,
134135
"instance-segmentation": InstanceSegmentationModel,
135136
"semantic-segmentation": SemanticSegmentationModel,
137+
"keypoint-detection": KeypointDetectionModel,
136138
}[projectType]
137139
model = modelClass(api_key, project_url)
138140
kwargs = {}

roboflow/util/folderparser.py

+7-16
Original file line numberDiff line numberDiff line change
@@ -129,19 +129,17 @@ def _filterIndividualAnnotations(image, annotation, format):
129129
"iscrowd": 0,
130130
}
131131
imgReference = imgReferences[0]
132-
_annotation = {
133-
"name": "annotation.coco.json",
134-
"parsedType": "coco",
135-
"parsed": {
132+
_annotation = {"name": "annotation.coco.json"}
133+
_annotation["rawText"] = json.dumps(
134+
{
136135
"info": parsed["info"],
137136
"licenses": parsed["licenses"],
138137
"categories": parsed["categories"],
139138
"images": [imgReference],
140139
"annotations": [a for a in parsed["annotations"] if a["image_id"] == imgReference["id"]]
141140
or [fake_annotation],
142-
},
143-
}
144-
_annotation["rawText"] = json.dumps(_annotation["parsed"])
141+
}
142+
)
145143
return _annotation
146144
elif format == "createml":
147145
imgReferences = [i for i in parsed if i["image"] == image["name"]]
@@ -151,22 +149,15 @@ def _filterIndividualAnnotations(image, annotation, format):
151149
imgReference = imgReferences[0]
152150
_annotation = {
153151
"name": "annotation.createml.json",
154-
"parsedType": "createml",
155-
"parsed": [imgReference],
156152
"rawText": json.dumps([imgReference]),
157153
}
158154
return _annotation
159155
elif format == "csv":
160-
imgLines = [l["line"] for l in parsed["lines"] if l["file_name"] == image["name"]]
156+
imgLines = [ld["line"] for ld in parsed["lines"] if ld["file_name"] == image["name"]]
161157
if imgLines:
162158
headers = parsed["headers"]
163159
_annotation = {
164160
"name": "annotation.csv",
165-
"parsedType": "csv",
166-
"parsed": {
167-
"headers": headers,
168-
"lines": imgLines,
169-
},
170161
"rawText": "".join([headers] + imgLines),
171162
}
172163
return _annotation
@@ -198,7 +189,7 @@ def _parseAnnotationCSV(filename):
198189
with open(filename, "r") as f:
199190
lines = f.readlines()
200191
headers = lines[0]
201-
lines = [{"file_name": l.split(",")[0].strip(), "line": l} for l in lines[1:]]
192+
lines = [{"file_name": ld.split(",")[0].strip(), "line": ld} for ld in lines[1:]]
202193
return {
203194
"headers": headers,
204195
"lines": lines,

tests/util/test_folderparser.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,16 @@ def test_parse_sharks_coco(self):
1818
parsed = folderparser.parsefolder(sharksfolder)
1919
testImagePath = "/train/sharks_mp4-20_jpg.rf.90ba2e8e9ca0613f71359efb7ed48b26.jpg"
2020
testImage = [i for i in parsed["images"] if i["file"] == testImagePath][0]
21-
assert len(testImage["annotationfile"]["parsed"]["annotations"]) == 5
21+
assert len(json.loads(testImage["annotationfile"]["rawText"])["annotations"]) == 5
2222

2323
def test_parse_sharks_createml(self):
2424
sharksfolder = f"{thisdir}/../datasets/sharks-tiny-createml"
2525
parsed = folderparser.parsefolder(sharksfolder)
2626
testImagePath = "/train/sharks_mp4-20_jpg.rf.5359121123e86e016401ea2731e47949.jpg"
2727
testImage = [i for i in parsed["images"] if i["file"] == testImagePath][0]
28-
assert len(testImage["annotationfile"]["parsed"]) == 1
29-
imgReference = testImage["annotationfile"]["parsed"][0]
28+
imgParsedAnnotations = json.loads(testImage["annotationfile"]["rawText"])
29+
assert len(imgParsedAnnotations) == 1
30+
imgReference = imgParsedAnnotations[0]
3031
assert len(imgReference["annotations"]) == 5
3132

3233
def test_parse_sharks_yolov9(self):
@@ -47,10 +48,9 @@ def test_parse_mosquitos_csv(self):
4748
testImagePath = "/train_10308.jpeg"
4849
testImage = [i for i in parsed["images"] if i["file"] == testImagePath][0]
4950
assert testImage["annotationfile"]["name"] == "annotation.csv"
50-
headers = testImage["annotationfile"]["parsed"]["headers"]
51-
lines = testImage["annotationfile"]["parsed"]["lines"]
52-
assert headers == "img_fName,img_w,img_h,class_label,bbx_xtl,bbx_ytl,bbx_xbr,bbx_ybr\n"
53-
assert lines == ["train_10308.jpeg,1058,943,japonicus/koreicus,28,187,908,815\n"]
51+
expected = "img_fName,img_w,img_h,class_label,bbx_xtl,bbx_ytl,bbx_xbr,bbx_ybr\n"
52+
expected += "train_10308.jpeg,1058,943,japonicus/koreicus,28,187,908,815\n"
53+
assert testImage["annotationfile"]["rawText"] == expected
5454

5555

5656
def _assertJsonMatchesFile(actual, filename):

0 commit comments

Comments
 (0)