@@ -63,7 +63,7 @@ my-training/
63
63
64
64
Add the following code to ` setup.py ` and add additional required packages on line 9:
65
65
66
- ``` python {class="line-numbers linkable-line-numbers" data-line="11 "}
66
+ ``` python {class="line-numbers linkable-line-numbers" data-line="9 "}
67
67
from setuptools import find_packages, setup
68
68
69
69
setup(
72
72
packages = find_packages(),
73
73
include_package_data = True ,
74
74
install_requires = [
75
- " google-cloud-aiplatform" ,
76
- " google-cloud-storage" ,
77
75
# TODO : Add additional required packages
78
76
],
79
77
)
@@ -90,15 +88,18 @@ If you haven't already, create a folder called <file>model</file> and create an
90
88
91
89
<p ><strong >4. Add <code >training.py</code > code</strong ></p >
92
90
93
- <p >Copy this template into <file >training.py</file >:</p >
91
+ <p >You can set up your training script to use a hard-coded set of labels or allow users to pass in a set of labels when using the training script. Allowing users to pass in labels when using training scripts makes your training script more flexible for reuse.</p >
92
+ <p >Copy one of the following templates into <file >training.py</file >, depending on how you want to handle labels:</p >
94
93
95
- {{% expand "Click to see the template" %}}
94
+ {{% expand "Click to see the template without parsing labels (recommended for use with UI) " %}}
96
95
97
- ``` python {class="line-numbers linkable-line-numbers" data-line="126,170 " }
96
+ ``` python {class="line-numbers linkable-line-numbers" data-line="134 " }
98
97
import argparse
99
98
import json
100
99
import os
101
100
import typing as ty
101
+ from tensorflow.keras import Model # Add proper import
102
+ import tensorflow as tf # Add proper import
102
103
103
104
single_label = " MODEL_TYPE_SINGLE_LABEL_CLASSIFICATION"
104
105
multi_label = " MODEL_TYPE_MULTI_LABEL_CLASSIFICATION"
@@ -108,23 +109,29 @@ unknown_label = "UNKNOWN"
108
109
API_KEY = os.environ[' API_KEY' ]
109
110
API_KEY_ID = os.environ[' API_KEY_ID' ]
110
111
112
+ DEFAULT_EPOCHS = 200
111
113
112
114
# This parses the required args for the training script.
113
115
# The model_dir variable will contain the output directory where
114
116
# the ML model that this script creates should be stored.
115
117
# The data_json variable will contain the metadata for the dataset
116
118
# that you should use to train the model.
119
+
120
+
117
121
def parse_args ():
118
- """ Returns dataset file, model output directory, and num_epochs if present.
119
- These must be parsed as command line arguments and then used as the model
120
- input and output, respectively. The number of epochs can be used to
121
- optionally override the default.
122
+ """ Returns dataset file, model output directory, and num_epochs
123
+ if present. These must be parsed as command line arguments and then used
124
+ as the model input and output, respectively. The number of epochs can be
125
+ used to optionally override the default.
122
126
"""
123
127
parser = argparse.ArgumentParser()
124
- parser.add_argument(" --dataset_file" , dest = " data_json" , type = str )
125
- parser.add_argument(" --model_output_directory" , dest = " model_dir" , type = str )
128
+ parser.add_argument(" --dataset_file" , dest = " data_json" ,
129
+ type = str , required = True )
130
+ parser.add_argument(" --model_output_directory" , dest = " model_dir" ,
131
+ type = str , required = True )
126
132
parser.add_argument(" --num_epochs" , dest = " num_epochs" , type = int )
127
133
args = parser.parse_args()
134
+
128
135
return args.data_json, args.model_dir, args.num_epochs
129
136
130
137
@@ -250,12 +257,17 @@ def save_model(
250
257
model_dir: output directory for model artifacts
251
258
model_name: name of saved model
252
259
"""
253
- file_type = " "
254
-
255
- # Save the model to the output directory.
260
+ # Save the model to the output directory
261
+ file_type = " tflite" # Add proper file type
256
262
filename = os.path.join(model_dir, f " { model_name} . { file_type} " )
263
+
264
+ # Example: Convert to TFLite
265
+ converter = tf.lite.TFLiteConverter.from_keras_model(model)
266
+ tflite_model = converter.convert()
267
+
268
+ # Save the model
257
269
with open (filename, " wb" ) as f:
258
- f.write(model )
270
+ f.write(tflite_model )
259
271
260
272
261
273
if __name__ == " __main__" :
@@ -273,14 +285,244 @@ if __name__ == "__main__":
273
285
image_filenames, image_labels = parse_filenames_and_labels_from_json(
274
286
DATA_JSON , LABELS , model_type)
275
287
288
+ # Validate epochs
289
+ epochs = (
290
+ DEFAULT_EPOCHS if NUM_EPOCHS is None
291
+ or NUM_EPOCHS <= 0 else int (NUM_EPOCHS )
292
+ )
293
+
276
294
# Build and compile model on data
277
- model = build_and_compile_model()
295
+ model = build_and_compile_model(image_labels, model_type, IMG_SIZE + ( 3 ,) )
278
296
279
297
# Save labels.txt file
280
298
save_labels(LABELS + [unknown_label], MODEL_DIR )
281
299
# Convert the model to tflite
282
300
save_model(
283
- model, MODEL_DIR , " classification_model" , IMG_SIZE + (3 ,)
301
+ model, MODEL_DIR , " classification_model"
302
+ )
303
+ ```
304
+
305
+ {{% /expand %}}
306
+
307
+ {{% expand "Click to see the template with parsed labels" %}}
308
+
309
+ ``` python {class="line-numbers linkable-line-numbers" data-line="148" }
310
+ import argparse
311
+ import json
312
+ import os
313
+ import typing as ty
314
+ from tensorflow.keras import Model # Add proper import
315
+ import tensorflow as tf # Add proper import
316
+
317
+ single_label = " MODEL_TYPE_SINGLE_LABEL_CLASSIFICATION"
318
+ multi_label = " MODEL_TYPE_MULTI_LABEL_CLASSIFICATION"
319
+ labels_filename = " labels.txt"
320
+ unknown_label = " UNKNOWN"
321
+
322
+ API_KEY = os.environ[' API_KEY' ]
323
+ API_KEY_ID = os.environ[' API_KEY_ID' ]
324
+
325
+ DEFAULT_EPOCHS = 200
326
+
327
+ # This parses the required args for the training script.
328
+ # The model_dir variable will contain the output directory where
329
+ # the ML model that this script creates should be stored.
330
+ # The data_json variable will contain the metadata for the dataset
331
+ # that you should use to train the model.
332
+
333
+
334
+ def parse_args ():
335
+ """ Returns dataset file, model output directory, labels, and num_epochs
336
+ if present. These must be parsed as command line arguments and then used
337
+ as the model input and output, respectively. The number of epochs can be
338
+ used to optionally override the default.
339
+ """
340
+ parser = argparse.ArgumentParser()
341
+ parser.add_argument(" --dataset_file" , dest = " data_json" ,
342
+ type = str , required = True )
343
+ parser.add_argument(" --model_output_directory" , dest = " model_dir" ,
344
+ type = str , required = True )
345
+ parser.add_argument(" --num_epochs" , dest = " num_epochs" , type = int )
346
+ parser.add_argument(
347
+ " --labels" ,
348
+ dest = " labels" ,
349
+ type = str ,
350
+ required = True ,
351
+ help = " Space-separated list of labels, \
352
+ enclosed in single quotes (e.g., 'label1 label2')." ,
353
+ )
354
+ args = parser.parse_args()
355
+
356
+ if not args.labels:
357
+ raise ValueError (" Labels must be provided" )
358
+
359
+ labels = [label.strip() for label in args.labels.strip(" '" ).split()]
360
+ return args.data_json, args.model_dir, args.num_epochs, labels
361
+
362
+
363
+ # This is used for parsing the dataset file (produced and stored in Viam),
364
+ # parse it to get the label annotations
365
+ # Used for training classification models
366
+
367
+
368
+ def parse_filenames_and_labels_from_json (
369
+ filename : str , all_labels : ty.List[str ], model_type : str
370
+ ) -> ty.Tuple[ty.List[str ], ty.List[str ]]:
371
+ """ Load and parse JSON file to return image filenames and corresponding
372
+ labels. The JSON file contains lines, where each line has the key
373
+ "image_path" and "classification_annotations".
374
+ Args:
375
+ filename: JSONLines file containing filenames and labels
376
+ all_labels: list of all N_LABELS
377
+ model_type: string single_label or multi_label
378
+ """
379
+ image_filenames = []
380
+ image_labels = []
381
+
382
+ with open (filename, " rb" ) as f:
383
+ for line in f:
384
+ json_line = json.loads(line)
385
+ image_filenames.append(json_line[" image_path" ])
386
+
387
+ annotations = json_line[" classification_annotations" ]
388
+ labels = [unknown_label]
389
+ for annotation in annotations:
390
+ if model_type == multi_label:
391
+ if annotation[" annotation_label" ] in all_labels:
392
+ labels.append(annotation[" annotation_label" ])
393
+ # For single label model, we want at most one label.
394
+ # If multiple valid labels are present, we arbitrarily select
395
+ # the last one.
396
+ if model_type == single_label:
397
+ if annotation[" annotation_label" ] in all_labels:
398
+ labels = [annotation[" annotation_label" ]]
399
+ image_labels.append(labels)
400
+ return image_filenames, image_labels
401
+
402
+
403
+ # Parse the dataset file (produced and stored in Viam) to get
404
+ # bounding box annotations
405
+ # Used for training object detection models
406
+ def parse_filenames_and_bboxes_from_json (
407
+ filename : str ,
408
+ all_labels : ty.List[str ],
409
+ ) -> ty.Tuple[ty.List[str ], ty.List[str ], ty.List[ty.List[float ]]]:
410
+ """ Load and parse JSON file to return image filenames
411
+ and corresponding labels with bboxes.
412
+ Args:
413
+ filename: JSONLines file containing filenames and bboxes
414
+ all_labels: list of all N_LABELS
415
+ """
416
+ image_filenames = []
417
+ bbox_labels = []
418
+ bbox_coords = []
419
+
420
+ with open (filename, " rb" ) as f:
421
+ for line in f:
422
+ json_line = json.loads(line)
423
+ image_filenames.append(json_line[" image_path" ])
424
+ annotations = json_line[" bounding_box_annotations" ]
425
+ labels = []
426
+ coords = []
427
+ for annotation in annotations:
428
+ if annotation[" annotation_label" ] in all_labels:
429
+ labels.append(annotation[" annotation_label" ])
430
+ # Store coordinates in rel_yxyx format so that
431
+ # we can use the keras_cv function
432
+ coords.append(
433
+ [
434
+ annotation[" y_min_normalized" ],
435
+ annotation[" x_min_normalized" ],
436
+ annotation[" y_max_normalized" ],
437
+ annotation[" x_max_normalized" ],
438
+ ]
439
+ )
440
+ bbox_labels.append(labels)
441
+ bbox_coords.append(coords)
442
+ return image_filenames, bbox_labels, bbox_coords
443
+
444
+
445
+ # Build the model
446
+ def build_and_compile_model (
447
+ labels : ty.List[str ], model_type : str , input_shape : ty.Tuple[int , int , int ]
448
+ ) -> Model:
449
+ """ Builds and compiles a model
450
+ Args:
451
+ labels: list of string lists, where each string list contains up to
452
+ N_LABEL labels associated with an image
453
+ model_type: string single_label or multi_label
454
+ input_shape: 3D shape of input
455
+ """
456
+
457
+ # TODO : Add logic to build and compile model
458
+
459
+ return model
460
+
461
+
462
+ def save_labels (labels : ty.List[str ], model_dir : str ) -> None :
463
+ """ Saves a label.txt of output labels to the specified model directory.
464
+ Args:
465
+ labels: list of string lists, where each string list contains up to
466
+ N_LABEL labels associated with an image
467
+ model_dir: output directory for model artifacts
468
+ """
469
+ filename = os.path.join(model_dir, labels_filename)
470
+ with open (filename, " w" ) as f:
471
+ for label in labels[:- 1 ]:
472
+ f.write(label + " \n " )
473
+ f.write(labels[- 1 ])
474
+
475
+
476
+ def save_model (
477
+ model : Model,
478
+ model_dir : str ,
479
+ model_name : str ,
480
+ ) -> None :
481
+ """ Save model as a TFLite model.
482
+ Args:
483
+ model: trained model
484
+ model_dir: output directory for model artifacts
485
+ model_name: name of saved model
486
+ """
487
+ # Save the model to the output directory
488
+ file_type = " tflite" # Add proper file type
489
+ filename = os.path.join(model_dir, f " { model_name} . { file_type} " )
490
+
491
+ # Example: Convert to TFLite
492
+ converter = tf.lite.TFLiteConverter.from_keras_model(model)
493
+ tflite_model = converter.convert()
494
+
495
+ # Save the model
496
+ with open (filename, " wb" ) as f:
497
+ f.write(tflite_model)
498
+
499
+
500
+ if __name__ == " __main__" :
501
+ DATA_JSON , MODEL_DIR , NUM_EPOCHS , LABELS = parse_args()
502
+
503
+ IMG_SIZE = (256 , 256 )
504
+
505
+ # Read dataset file.
506
+ # The model type can be changed based on whether you want the model to
507
+ # output one label per image or multiple labels per image
508
+ model_type = multi_label
509
+ image_filenames, image_labels = parse_filenames_and_labels_from_json(
510
+ DATA_JSON , LABELS , model_type)
511
+
512
+ # Validate epochs
513
+ epochs = (
514
+ DEFAULT_EPOCHS if NUM_EPOCHS is None
515
+ or NUM_EPOCHS <= 0 else int (NUM_EPOCHS )
516
+ )
517
+
518
+ # Build and compile model on data
519
+ model = build_and_compile_model(image_labels, model_type, IMG_SIZE + (3 ,))
520
+
521
+ # Save labels.txt file
522
+ save_labels(LABELS + [unknown_label], MODEL_DIR )
523
+ # Convert the model to tflite
524
+ save_model(
525
+ model, MODEL_DIR , " classification_model"
284
526
)
285
527
```
286
528
@@ -300,6 +542,10 @@ The script you are creating must take the following command line inputs:
300
542
- ` dataset_file ` : a file containing the data and metadata for the training job
301
543
- ` model_output_directory ` : the location where the produced model artifacts are saved to
302
544
545
+ If you used the training script template that allows users to pass in labels, it will also take the following command line inputs:
546
+
547
+ - ` labels ` : space separated list of labels, enclosed in single quotes
548
+
303
549
The ` parse_args() ` function in the template parses your arguments.
304
550
305
551
You can add additional custom command line inputs by adding them to the ` parse_args() ` function.
@@ -547,6 +793,11 @@ In the Viam app, navigate to your list of [**DATASETS**](https://app.viam.com/da
547
793
548
794
Click ** Train model** and select ** Train on a custom training script** , then follow the prompts.
549
795
796
+ {{% alert title="Tip" color="tip" %}}
797
+ If you used the version of <file >training.py</file > that allows users to pass in labels, your training job will fail with the error ` ERROR training.py: error: the following arguments are required: --labels ` .
798
+ To use labels, you must use the CLI.
799
+ {{% /alert %}}
800
+
550
801
{{% /tab %}}
551
802
{{% tab name="CLI" %}}
552
803
0 commit comments