Skip to content

Commit 20bdab5

Browse files
authored
Merge pull request #276 from roboflow/fix-version-model
version.model should return None if no model
2 parents 1920a94 + f8576ba commit 20bdab5

File tree

3 files changed

+63
-6
lines changed

3 files changed

+63
-6
lines changed

roboflow/core/version.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,15 @@ def __init__(
101101

102102
version_without_workspace = os.path.basename(str(version))
103103

104-
if self.type == TYPE_OBJECT_DETECTION:
104+
version_info = requests.get(f"{API_URL}/{workspace}/{project}/{self.version}?api_key={self.__api_key}")
105+
106+
# check if version has a model
107+
if version_info.status_code == 200:
108+
version_info = version_info.json()["version"]
109+
110+
if ("models" in version_info) and (not version_info["models"]):
111+
self.model = None
112+
elif self.type == TYPE_OBJECT_DETECTION:
105113
self.model = ObjectDetectionModel(
106114
self.__api_key,
107115
self.id,

tests/__init__.py

+53-1
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def setUp(self):
116116
},
117117
"versions": [
118118
{
119-
"id": f"{WORKSPACE_NAME}/{PROJECT_NAME}/2",
119+
"id": f"{WORKSPACE_NAME}/{PROJECT_NAME}/{PROJECT_VERSION}",
120120
"name": "augmented-416x416",
121121
"created": 1663104679.539,
122122
"images": 240,
@@ -158,6 +158,58 @@ def setUp(self):
158158
status=200,
159159
)
160160

161+
# Get version
162+
responses.add(
163+
responses.GET,
164+
f"{API_URL}/{WORKSPACE_NAME}/{PROJECT_NAME}/{PROJECT_VERSION}?api_key={ROBOFLOW_API_KEY}",
165+
json={
166+
"workspace": {"name": WORKSPACE_NAME, "url": WORKSPACE_NAME, "members": 1},
167+
"project": {
168+
"id": f"{WORKSPACE_NAME}/{PROJECT_NAME}",
169+
"type": "object-detection",
170+
"name": "Hard Hat Sample",
171+
"created": 1593802673.521,
172+
"updated": 1663269501.654,
173+
"images": 100,
174+
"unannotated": 3,
175+
"annotation": "Workers",
176+
"versions": 2,
177+
"public": False,
178+
"splits": {"test": 10, "train": 70, "valid": 20},
179+
"colors": {
180+
"person": "#FF00FF",
181+
"helmet": "#C7FC00",
182+
"head": "#8622FF",
183+
},
184+
"classes": {"person": 9, "helmet": 287, "head": 90},
185+
},
186+
"version": {
187+
"id": f"{WORKSPACE_NAME}/{PROJECT_NAME}/{PROJECT_VERSION}",
188+
"name": "augmented-416x416",
189+
"created": 1663104679.539,
190+
"images": 240,
191+
"splits": {"train": 210, "test": 10, "valid": 20},
192+
"generating": False,
193+
"progress": 1,
194+
"preprocessing": {
195+
"resize": {"height": "416", "enabled": True, "width": "416", "format": "Stretch to"},
196+
"auto-orient": {"enabled": True},
197+
},
198+
"augmentation": {
199+
"blur": {"enabled": True, "pixels": 1.5},
200+
"image": {"enabled": True, "versions": 3},
201+
"rotate": {"degrees": 15, "enabled": True},
202+
"crop": {"enabled": True, "percent": 40, "min": 0},
203+
"flip": {"horizontal": True, "enabled": True, "vertical": False},
204+
},
205+
"exports": [],
206+
"models": {},
207+
"classes": [],
208+
},
209+
},
210+
status=200,
211+
)
212+
161213
# Upload image
162214
responses.add(
163215
responses.POST,

tests/test_queries.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,5 @@ def test_version_fields(self):
6565
@ordered
6666
def test_version_methods(self):
6767
self.assertTrue(
68-
(
69-
isinstance(self.version.model, ClassificationModel)
70-
or (isinstance(self.version.model, ObjectDetectionModel))
71-
)
68+
self.version.model is None or isinstance(self.version.model, (ClassificationModel, ObjectDetectionModel))
7269
)

0 commit comments

Comments
 (0)