import io
import json
import os
import time
import urllib.parse
from typing import Optional, Tuple
from urllib.parse import urljoin

import requests
from PIL import Image
from requests_toolbelt.multipart.encoder import MultipartEncoder
from tqdm import tqdm

from roboflow.config import API_URL
from roboflow.util.image_utils import validate_image_path
from roboflow.util.prediction import PredictionGroup

SUPPORTED_ROBOFLOW_MODELS = ["batch-video"]

SUPPORTED_ADDITIONAL_MODELS = {
    "clip": {
        "model_id": "clip",
        "model_version": "1",
        "inference_type": "clip-embed-image",
    },
    "gaze": {
        "model_id": "gaze",
        "model_version": "1",
        "inference_type": "gaze-detection",
    },
}


class InferenceModel:
    def __init__(
        self,
        api_key,
        version_id,
        colors=None,
        *args,
        **kwargs,
    ):
        """
        Create an InferenceModel object through which you can run inference.

        Args:
            api_key (str): private roboflow api key
            version_id (str): the ID of the dataset version to use for inference,
                in the form "workspace/project/version"
        """
        self.__api_key = api_key
        self.id = version_id

        if version_id != "BASE_MODEL":
            # version_id looks like "workspace/project/version"
            version_info = self.id.rsplit("/")
            self.dataset_id = version_info[1]
            self.version = version_info[2]

        self.colors = {} if colors is None else colors

    def __get_image_params(self, image_path):
        """
        Get parameters about an image (i.e. dimensions) for use in an inference request.

        Args:
            image_path (str): path to the image you'd like to perform prediction on

        Returns:
            Tuple containing a dict of querystring params, a dict of requests kwargs,
            and a dict of image dimensions

        Raises:
            Exception: Image path is not valid
        """
        validate_image_path(image_path)

        hosted_image = urllib.parse.urlparse(image_path).scheme in ("http", "https")

        if hosted_image:
            image_dims = {"width": "Undefined", "height": "Undefined"}
            return {"image": image_path}, {}, image_dims

        image = Image.open(image_path)
        dimensions = image.size
        image_dims = {"width": str(dimensions[0]), "height": str(dimensions[1])}
        buffered = io.BytesIO()
        image.save(buffered, quality=90, format="JPEG")
        data = MultipartEncoder(fields={"file": ("imageToUpload", buffered.getvalue(), "image/jpeg")})
        return (
            {},
            {"data": data, "headers": {"Content-Type": data.content_type}},
            image_dims,
        )
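
    # A minimal sketch of what this helper returns for each branch (the paths
    # are hypothetical; the name-mangled attribute is how Python exposes the
    # private method outside the class):
    #
    #   get_params = model._InferenceModel__get_image_params
    #   params, kwargs, dims = get_params("https://example.com/dog.jpg")
    #   # params == {"image": "https://example.com/dog.jpg"}, kwargs == {}
    #   params, kwargs, dims = get_params("local_dog.jpg")
    #   # params == {}, kwargs carries multipart JPEG data and its headers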

    def predict(self, image_path, prediction_type=None, **kwargs):
        """
        Infers detections based on image from a specified model and image path.

        Args:
            image_path (str): path to the image you'd like to perform prediction on
            prediction_type (str): type of prediction to perform
            **kwargs: Any additional kwargs will be turned into querystring params

        Returns:
            PredictionGroup Object

        Raises:
            Exception: Image path is not valid

        Example:
            >>> import roboflow
            >>> rf = roboflow.Roboflow(api_key="")
            >>> project = rf.workspace().project("PROJECT_ID")
            >>> model = project.version("1").model
            >>> prediction = model.predict("YOUR_IMAGE.jpg")
        """
        params, request_kwargs, image_dims = self.__get_image_params(image_path)

        params["api_key"] = self.__api_key
        params.update(**kwargs)

        url = f"{self.api_url}?{urllib.parse.urlencode(params)}"  # type: ignore[attr-defined]
        response = requests.post(url, **request_kwargs)
        response.raise_for_status()

        return PredictionGroup.create_prediction_group(
            response.json(),
            image_path=image_path,
            prediction_type=prediction_type,
            image_dims=image_dims,
            colors=self.colors,
        )
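
    # A hedged usage sketch: any extra kwargs are forwarded as querystring
    # params, so hosted-API options such as "confidence" and "overlap" can be
    # passed directly (the values here are illustrative):
    #
    #   prediction = model.predict("YOUR_IMAGE.jpg", confidence=40, overlap=30)
    #   prediction.save("annotated.jpg")  # assumes PredictionGroup.save()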

    def predict_video(
        self,
        video_path: str,
        fps: int = 5,
        additional_models: Optional[list] = None,
        prediction_type: str = "batch-video",
    ) -> Tuple[str, str, Optional[str]]:
        """
        Infers detections from a specified model on a video, referenced by local path or URL.

        Args:
            video_path (str): path (or hosted URL) of the video to run inference on
            fps (int): frames per second at which to run inference
            additional_models (list): additional models to run on the video ("clip", "gaze")
            prediction_type (str): type of the model to run

        Returns:
            Tuple of the job id, the signed url, and the signed url expiry

        Example:
            >>> import roboflow
            >>> rf = roboflow.Roboflow(api_key="")
            >>> project = rf.workspace().project("PROJECT_ID")
            >>> model = project.version("1").model
            >>> job_id, signed_url, signed_url_expires = model.predict_video(
            ...     "video.mp4", fps=5
            ... )
        """
        # Avoid a mutable default argument; treat None as "no additional models".
        if additional_models is None:
            additional_models = []

        signed_url_expires = None

        url = urljoin(API_URL, "/video_upload_signed_url?api_key=" + self.__api_key)

        # if fps > 5:
        #     raise Exception("FPS must be less than or equal to 5.")

        for model in additional_models:
            if model not in SUPPORTED_ADDITIONAL_MODELS:
                raise Exception(f"Model {model} is not supported for video inference.")

        if prediction_type not in SUPPORTED_ROBOFLOW_MODELS:
            raise Exception(f"{prediction_type} is not supported for video inference.")

        # Map the concrete model class to the inference type the API expects.
        model_class = self.__class__.__name__
        if model_class == "ObjectDetectionModel":
            self.type = "object-detection"
        elif model_class == "ClassificationModel":
            self.type = "classification"
        elif model_class == "InstanceSegmentationModel":
            self.type = "instance-segmentation"
        elif model_class == "GazeModel":
            self.type = "gaze-detection"
        elif model_class == "CLIPModel":
            self.type = "clip-embed-image"
        elif model_class == "KeypointDetectionModel":
            self.type = "keypoint-detection"
        else:
            raise Exception("Model type not supported for video inference.")

        payload = json.dumps(
            {
                "file_name": os.path.basename(video_path),
            }
        )

        if not video_path.startswith(("http://", "https://")):
            # Local file: request a signed upload URL, then PUT the video bytes to it.
            headers = {"Content-Type": "application/json"}

            try:
                response = requests.request("POST", url, headers=headers, data=payload)
            except Exception as e:
                raise Exception(f"Error uploading video: {e}")

            if not response.ok:
                raise Exception(f"Error uploading video: {response.text}")

            signed_url = response.json()["signed_url"]
            signed_url_expires = signed_url.split("&X-Goog-Expires")[1].split("&")[0].strip("=")

            # PUT the video bytes to the signed URL
            headers = {"Content-Type": "application/octet-stream"}

            try:
                with open(video_path, "rb") as f:
                    video_data = f.read()
            except Exception as e:
                raise Exception(f"Error reading video: {e}")

            try:
                result = requests.put(signed_url, data=video_data, headers=headers)
            except Exception as e:
                raise Exception(f"There was an error uploading the video: {e}")

            if not result.ok:
                raise Exception(f"There was an error uploading the video: {result.text}")
        else:
            # Hosted video: the URL itself is passed to the inference endpoint.
            signed_url = video_path

        url = urljoin(API_URL, "/videoinfer/?api_key=" + self.__api_key)

        if model_class in ("CLIPModel", "GazeModel"):
            model = "clip" if model_class == "CLIPModel" else "gaze"
            models = [SUPPORTED_ADDITIONAL_MODELS[model]]
        else:
            models = [
                {
                    "model_id": self.dataset_id,
                    "model_version": self.version,
                    "inference_type": self.type,
                }
            ]

        for model in additional_models:
            models.append(SUPPORTED_ADDITIONAL_MODELS[model])

        payload = json.dumps({"input_url": signed_url, "infer_fps": fps, "models": models})
        headers = {"Content-Type": "application/json"}

        try:
            response = requests.request("POST", url, headers=headers, data=payload)
        except Exception as e:
            raise Exception(f"Error starting video inference: {e}")

        if not response.ok:
            raise Exception(f"Error starting video inference: {response.text}")

        job_id = response.json()["job_id"]
        self.job_id = job_id

        return job_id, signed_url, signed_url_expires
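
    # End-to-end sketch using only the methods in this module: start a
    # batch-video job, optionally attaching CLIP, then block on the results
    # ("video.mp4" is a placeholder path):
    #
    #   job_id, signed_url, expires = model.predict_video(
    #       "video.mp4", fps=5, additional_models=["clip"]
    #   )
    #   results = model.poll_until_video_results(job_id)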

    def poll_for_video_results(self, job_id: Optional[str] = None) -> dict:
        """
        Polls the Roboflow API once to check whether video inference is complete.

        Args:
            job_id (str): the id of the inference job to check; defaults to the
                job started by the most recent predict_video() call

        Returns:
            Inference results as a dict if the job is done, else an empty dict

        Example:
            >>> import roboflow
            >>> rf = roboflow.Roboflow(api_key="")
            >>> project = rf.workspace().project("PROJECT_ID")
            >>> model = project.version("1").model
            >>> job_id, signed_url, expires = model.predict_video("video.mp4")
            >>> results = model.poll_for_video_results()
        """
        if job_id is None:
            job_id = self.job_id

        url = urljoin(API_URL, "/videoinfer/?api_key=" + self.__api_key + "&job_id=" + job_id)

        try:
            response = requests.get(url, headers={"Content-Type": "application/json"})
        except Exception as e:
            raise Exception(f"Error getting video inference results: {e}")

        if not response.ok:
            raise Exception(f"Error getting video inference results: {response.text}")

        data = response.json()
        if "status" not in data:
            return {}  # No status available
        if data.get("status") > 1:
            return data  # Error
        elif data.get("status") == 1:
            return {}  # Still running
        else:  # done
            output_signed_url = data["output_signed_url"]
            inference_data = requests.get(output_signed_url, headers={"Content-Type": "application/json"})

            # frame_offset and model name are top-level keys
            return inference_data.json()
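
    # Non-blocking alternative to poll_until_video_results(): call this once per
    # tick of your own scheduler; an empty dict means "not ready yet".
    #
    #   results = model.poll_for_video_results(job_id)
    #   if results:
    #       process(results)  # process() is a hypothetical consumer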

    def poll_until_video_results(self, job_id: Optional[str] = None) -> dict:
        """
        Polls the Roboflow API until video inference is complete, then returns
        the results.

        Args:
            job_id (str): the id of the inference job to wait on; defaults to the
                job started by the most recent predict_video() call

        Returns:
            Inference results as a dict

        Example:
            >>> import roboflow
            >>> rf = roboflow.Roboflow(api_key="")
            >>> project = rf.workspace().project("PROJECT_ID")
            >>> model = project.version("1").model
            >>> job_id, signed_url, expires = model.predict_video("video.mp4")
            >>> results = model.poll_until_video_results(job_id)
        """
        if job_id is None:
            job_id = self.job_id

        attempts = 0
        print(f"Checking for video inference results for job {job_id} every 60s")
        while True:
            time.sleep(60)
            print(f"({attempts * 60}s): Checking for inference results")

            response = self.poll_for_video_results(job_id)
            attempts += 1

            if response != {}:
                return response

    def download(self, format="pt", location="."):
        """
        Download the weights associated with a model.

        Args:
            format (str): The format of the output.
                - 'pt': returns a PyTorch weights file
            location (str): The location to save the weights file to
        """
        supported_formats = ["pt"]
        if format not in supported_formats:
            raise Exception(f"Unsupported format {format}. Must be one of {supported_formats}")

        workspace, project, version = self.id.rsplit("/")

        # get pt url
        pt_api_url = f"{API_URL}/{workspace}/{project}/{version}/ptFile"
        r = requests.get(pt_api_url, params={"api_key": self.__api_key})
        r.raise_for_status()
        pt_weights_url = r.json()["weightsUrl"]

        response = requests.get(pt_weights_url, stream=True)

        # stream the weights file to the desired location with a progress bar
        with open(os.path.join(location, "weights.pt"), "wb") as f:
            total_length = int(response.headers.get("content-length"))  # type: ignore[arg-type]
            for chunk in tqdm(
                response.iter_content(chunk_size=1024),
                desc=f"Downloading weights to {location}/weights.pt",
                total=int(total_length / 1024) + 1,
            ):
                if chunk:
                    f.write(chunk)
                    f.flush()
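
    # Usage sketch: fetch the trained PyTorch weights for this version into a
    # local directory ("./weights_dir" is a placeholder that must already exist):
    #
    #   model.download(format="pt", location="./weights_dir")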