Skip to content

Commit 2082bfe

Browse files
authored
Merge pull request #23 from WildMeOrg/add-classes-method
Add classes method
2 parents ca724bf + 7a2c09c commit 2082bfe

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed

scoutbot/__init__.py

+21
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262

6363

6464
from scoutbot import agg, loc, tile, wic, tile_batched # NOQA
65+
from scoutbot.loc import CONFIGS as LOC_CONFIGS # NOQA
6566

6667
# from tile_batched.models import Yolov8DetectionModel
6768
# from tile_batched import get_sliced_prediction_batched
@@ -460,6 +461,26 @@ def batch_v3(
460461
return wic_list, detects_list
461462

462463

464+
def get_classes():
465+
classes = set()
466+
for config in ['v3', 'v3-cls']:
467+
yolov8_model_path = loc.fetch(config=config)
468+
model = tile_batched.Yolov8DetectionModel(
469+
model_path=yolov8_model_path,
470+
confidence_threshold=0.5,
471+
device='cpu',
472+
)
473+
model_classes = list(model.category_names)
474+
classes.update(model_classes)
475+
476+
for config in ['phase1', 'mvp']:
477+
model_classes = LOC_CONFIGS[config]['classes']
478+
classes.update(model_classes)
479+
480+
classes = list(classes)
481+
return classes
482+
483+
463484
def example():
464485
"""
465486
Run the pipeline on an example image with the default configuration

scoutbot/scoutbot.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,16 @@ def pipeline_filepath_validator(ctx, param, value):
2929
),
3030
]
3131

32-
shared_options = [
32+
output_option = [
3333
click.option(
3434
'--output',
3535
help='Path to output JSON (if unspecified, results are printed to screen)',
3636
default=None,
3737
type=str,
3838
),
39+
]
40+
41+
shared_options = [
3942
click.option(
4043
'--backend_device', # torch backend device
4144
help='Specifies the device for inference (see YOLO and PyTorch documentation for more information).',
@@ -93,6 +96,7 @@ def _add_options(func):
9396
)
9497
@add_options(model_option)
9598
@add_options(shared_options)
99+
@add_options(output_option)
96100
def pipeline(
97101
filepath,
98102
config,
@@ -183,6 +187,7 @@ def pipeline(
183187
)
184188
@add_options(model_option)
185189
@add_options(shared_options)
190+
@add_options(output_option)
186191
def batch(
187192
filepaths,
188193
config,
@@ -300,6 +305,21 @@ def example():
300305
scoutbot.example()
301306

302307

308+
@click.command('get_classes')
309+
@add_options(output_option)
310+
def get_classes(output):
311+
"""
312+
Run a test of the pipeline on an example image with the default configuration.
313+
"""
314+
classes = scoutbot.get_classes()
315+
log.debug('Outputting classes list...')
316+
if output:
317+
with open(output, 'w') as outfile:
318+
json.dump(classes, outfile)
319+
else:
320+
print(ut.repr3(classes))
321+
322+
303323
@click.group()
304324
def cli():
305325
"""
@@ -312,6 +332,8 @@ def cli():
312332
cli.add_command(batch)
313333
cli.add_command(fetch)
314334
cli.add_command(example)
335+
cli.add_command(get_classes)
336+
315337

316338
if __name__ == '__main__':
317339
cli()

0 commit comments

Comments
 (0)