Skip to content

Commit eaf077a

Browse files
authored
DOCS-3423: Update training script to parse labels based on feedback from etai/tahiya (#3941)
1 parent b7ce749 commit eaf077a

File tree

1 file changed

+269
-18
lines changed

1 file changed

+269
-18
lines changed

docs/data-ai/ai/train.md

+269-18
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ my-training/
6363

6464
Add the following code to `setup.py` and add additional required packages on line 11:
6565

66-
```python {class="line-numbers linkable-line-numbers" data-line="11"}
66+
```python {class="line-numbers linkable-line-numbers" data-line="9"}
6767
from setuptools import find_packages, setup
6868

6969
setup(
@@ -72,8 +72,6 @@ setup(
7272
packages=find_packages(),
7373
include_package_data=True,
7474
install_requires=[
75-
"google-cloud-aiplatform",
76-
"google-cloud-storage",
7775
# TODO: Add additional required packages
7876
],
7977
)
@@ -90,15 +88,18 @@ If you haven't already, create a folder called <file>model</file> and create an
9088

9189
<p><strong>4. Add <code>training.py</code> code</strong></p>
9290

93-
<p>Copy this template into <file>training.py</file>:</p>
91+
<p>You can set up your training script to use a hard-coded set of labels or to allow users to pass in a set of labels when running the training script. Accepting labels as an argument makes your training script more flexible to reuse.</p>
92+
<p>Copy one of the following templates into <file>training.py</file>, depending on how you want to handle labels:</p>
9493

95-
{{% expand "Click to see the template" %}}
94+
{{% expand "Click to see the template without parsing labels (recommended for use with UI)" %}}
9695

97-
```python {class="line-numbers linkable-line-numbers" data-line="126,170" }
96+
```python {class="line-numbers linkable-line-numbers" data-line="134" }
9897
import argparse
9998
import json
10099
import os
101100
import typing as ty
101+
from tensorflow.keras import Model # Add proper import
102+
import tensorflow as tf # Add proper import
102103

103104
single_label = "MODEL_TYPE_SINGLE_LABEL_CLASSIFICATION"
104105
multi_label = "MODEL_TYPE_MULTI_LABEL_CLASSIFICATION"
@@ -108,23 +109,29 @@ unknown_label = "UNKNOWN"
108109
API_KEY = os.environ['API_KEY']
109110
API_KEY_ID = os.environ['API_KEY_ID']
110111

112+
DEFAULT_EPOCHS = 200
111113

112114
# This parses the required args for the training script.
113115
# The model_dir variable will contain the output directory where
114116
# the ML model that this script creates should be stored.
115117
# The data_json variable will contain the metadata for the dataset
116118
# that you should use to train the model.
119+
120+
117121
def parse_args():
    """Parse the command line arguments for the training script.

    Returns the dataset file, the model output directory, and the number
    of epochs (or None when the flag is absent). The dataset file is the
    model input, the output directory receives the produced artifacts, and
    num_epochs can be used to optionally override the default.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--dataset_file", dest="data_json", type=str, required=True)
    arg_parser.add_argument(
        "--model_output_directory", dest="model_dir", type=str, required=True)
    arg_parser.add_argument("--num_epochs", dest="num_epochs", type=int)
    parsed = arg_parser.parse_args()

    return parsed.data_json, parsed.model_dir, parsed.num_epochs
129136

130137

@@ -250,12 +257,17 @@ def save_model(
250257
model_dir: output directory for model artifacts
251258
model_name: name of saved model
252259
"""
253-
file_type = ""
254-
255-
# Save the model to the output directory.
260+
# Save the model to the output directory
261+
file_type = "tflite" # Add proper file type
256262
filename = os.path.join(model_dir, f"{model_name}.{file_type}")
263+
264+
# Example: Convert to TFLite
265+
converter = tf.lite.TFLiteConverter.from_keras_model(model)
266+
tflite_model = converter.convert()
267+
268+
# Save the model
257269
with open(filename, "wb") as f:
258-
f.write(model)
270+
f.write(tflite_model)
259271

260272

261273
if __name__ == "__main__":
@@ -273,14 +285,244 @@ if __name__ == "__main__":
273285
image_filenames, image_labels = parse_filenames_and_labels_from_json(
274286
DATA_JSON, LABELS, model_type)
275287

288+
# Validate epochs
289+
epochs = (
290+
DEFAULT_EPOCHS if NUM_EPOCHS is None
291+
or NUM_EPOCHS <= 0 else int(NUM_EPOCHS)
292+
)
293+
276294
# Build and compile model on data
277-
model = build_and_compile_model()
295+
model = build_and_compile_model(image_labels, model_type, IMG_SIZE + (3,))
278296

279297
# Save labels.txt file
280298
save_labels(LABELS + [unknown_label], MODEL_DIR)
281299
# Convert the model to tflite
282300
save_model(
283-
model, MODEL_DIR, "classification_model", IMG_SIZE + (3,)
301+
model, MODEL_DIR, "classification_model"
302+
)
303+
```
304+
305+
{{% /expand %}}
306+
307+
{{% expand "Click to see the template with parsed labels" %}}
308+
309+
```python {class="line-numbers linkable-line-numbers" data-line="148" }
310+
import argparse
311+
import json
312+
import os
313+
import typing as ty
314+
from tensorflow.keras import Model # Add proper import
315+
import tensorflow as tf # Add proper import
316+
317+
single_label = "MODEL_TYPE_SINGLE_LABEL_CLASSIFICATION"
318+
multi_label = "MODEL_TYPE_MULTI_LABEL_CLASSIFICATION"
319+
labels_filename = "labels.txt"
320+
unknown_label = "UNKNOWN"
321+
322+
API_KEY = os.environ['API_KEY']
323+
API_KEY_ID = os.environ['API_KEY_ID']
324+
325+
DEFAULT_EPOCHS = 200
326+
327+
# This parses the required args for the training script.
328+
# The model_dir variable will contain the output directory where
329+
# the ML model that this script creates should be stored.
330+
# The data_json variable will contain the metadata for the dataset
331+
# that you should use to train the model.
332+
333+
334+
def parse_args():
    """Parse the command line arguments for the training script.

    Returns the dataset file, the model output directory, the number of
    epochs (or None when the flag is absent), and the list of labels.
    The dataset file is the model input, the output directory receives the
    produced artifacts, and num_epochs can optionally override the default.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--dataset_file", dest="data_json", type=str, required=True)
    arg_parser.add_argument(
        "--model_output_directory", dest="model_dir", type=str, required=True)
    arg_parser.add_argument("--num_epochs", dest="num_epochs", type=int)
    arg_parser.add_argument(
        "--labels",
        dest="labels",
        type=str,
        required=True,
        help="Space-separated list of labels, \
            enclosed in single quotes (e.g., 'label1 label2').",
    )
    parsed = arg_parser.parse_args()

    # argparse enforces presence; this also rejects an explicitly empty
    # string such as --labels ''.
    if not parsed.labels:
        raise ValueError("Labels must be provided")

    labels = [token.strip() for token in parsed.labels.strip("'").split()]
    return parsed.data_json, parsed.model_dir, parsed.num_epochs, labels
361+
362+
363+
# This is used for parsing the dataset file (produced and stored in Viam),
# parse it to get the label annotations
# Used for training classification models


def parse_filenames_and_labels_from_json(
    filename: str, all_labels: ty.List[str], model_type: str
) -> ty.Tuple[ty.List[str], ty.List[str]]:
    """Load and parse a JSONLines file, returning image filenames and the
    corresponding labels. Each line of the file has the keys "image_path"
    and "classification_annotations".

    Args:
        filename: JSONLines file containing filenames and labels
        all_labels: list of all N_LABELS
        model_type: string single_label or multi_label
    """
    image_filenames = []
    image_labels = []

    with open(filename, "rb") as f:
        for raw_line in f:
            record = json.loads(raw_line)
            image_filenames.append(record["image_path"])

            labels = [unknown_label]
            for annotation in record["classification_annotations"]:
                tag = annotation["annotation_label"]
                if tag not in all_labels:
                    continue
                if model_type == multi_label:
                    labels.append(tag)
                # For a single label model we want at most one label; when
                # several valid labels are present, the last one seen wins.
                if model_type == single_label:
                    labels = [tag]
            image_labels.append(labels)
    return image_filenames, image_labels
401+
402+
403+
# Parse the dataset file (produced and stored in Viam) to get
# bounding box annotations
# Used for training object detection models
def parse_filenames_and_bboxes_from_json(
    filename: str,
    all_labels: ty.List[str],
) -> ty.Tuple[ty.List[str], ty.List[str], ty.List[ty.List[float]]]:
    """Load and parse a JSONLines file, returning image filenames together
    with the bounding-box labels and coordinates for each image.

    Args:
        filename: JSONLines file containing filenames and bboxes
        all_labels: list of all N_LABELS
    """
    image_filenames = []
    bbox_labels = []
    bbox_coords = []

    with open(filename, "rb") as f:
        for raw_line in f:
            record = json.loads(raw_line)
            image_filenames.append(record["image_path"])
            labels = []
            coords = []
            for annotation in record["bounding_box_annotations"]:
                tag = annotation["annotation_label"]
                if tag not in all_labels:
                    continue
                labels.append(tag)
                # Store coordinates in rel_yxyx format so that
                # we can use the keras_cv function
                coords.append([
                    annotation["y_min_normalized"],
                    annotation["x_min_normalized"],
                    annotation["y_max_normalized"],
                    annotation["x_max_normalized"],
                ])
            bbox_labels.append(labels)
            bbox_coords.append(coords)
    return image_filenames, bbox_labels, bbox_coords
443+
444+
445+
# Build the model
def build_and_compile_model(
    labels: ty.List[str], model_type: str, input_shape: ty.Tuple[int, int, int]
) -> Model:
    """Builds and compiles a model.

    Args:
        labels: list of string lists, where each string list contains up to
            N_LABEL labels associated with an image
        model_type: string single_label or multi_label
        input_shape: 3D shape of input

    Raises:
        NotImplementedError: always, until the template TODO below is
            filled in with real model-building logic.
    """
    # TODO: Add logic to build and compile model
    #
    # The original template ended with `return model` even though `model`
    # was never assigned, which crashed with an opaque NameError. Fail
    # loudly and clearly instead until the TODO is implemented.
    raise NotImplementedError(
        "build_and_compile_model is a template stub: add logic to build "
        "and compile your model, then return it."
    )
460+
461+
462+
def save_labels(labels: ty.List[str], model_dir: str) -> None:
    """Saves a labels.txt of output labels to the specified model directory.

    Labels are written one per line with no trailing newline after the
    last label.

    Args:
        labels: list of label strings to write
        model_dir: output directory for model artifacts
    """
    filename = os.path.join(model_dir, labels_filename)
    with open(filename, "w") as f:
        # join() produces the same bytes as the original write-loop for a
        # non-empty list, and — unlike indexing labels[-1] — does not raise
        # IndexError when the label list is empty.
        f.write("\n".join(labels))
474+
475+
476+
def save_model(
    model: Model,
    model_dir: str,
    model_name: str,
) -> None:
    """Serialize a trained Keras model to a TFLite file in the output
    directory.

    Args:
        model: trained model
        model_dir: output directory for model artifacts
        model_name: name of saved model
    """
    # Save the model to the output directory
    file_type = "tflite"  # Add proper file type
    out_path = os.path.join(model_dir, f"{model_name}.{file_type}")

    # Example: Convert to TFLite
    tflite_bytes = tf.lite.TFLiteConverter.from_keras_model(model).convert()

    # Save the model
    with open(out_path, "wb") as out_file:
        out_file.write(tflite_bytes)
498+
499+
500+
if __name__ == "__main__":
    # Entry point: parse CLI arguments, load the dataset, build the model,
    # and export the labels file plus the TFLite model artifact.
    DATA_JSON, MODEL_DIR, NUM_EPOCHS, LABELS = parse_args()

    IMG_SIZE = (256, 256)

    # Read dataset file.
    # The model type can be changed based on whether you want the model to
    # output one label per image or multiple labels per image
    model_type = multi_label
    image_filenames, image_labels = parse_filenames_and_labels_from_json(
        DATA_JSON, LABELS, model_type)

    # Validate epochs: fall back to DEFAULT_EPOCHS when --num_epochs is
    # absent or non-positive.
    epochs = (
        DEFAULT_EPOCHS if NUM_EPOCHS is None
        or NUM_EPOCHS <= 0 else int(NUM_EPOCHS)
    )
    # NOTE(review): `epochs` (and `image_filenames`) are computed but never
    # consumed here — presumably the training loop added at the
    # build_and_compile_model TODO should use them; confirm when filling in
    # the template.

    # Build and compile model on data
    model = build_and_compile_model(image_labels, model_type, IMG_SIZE + (3,))

    # Save labels.txt file
    save_labels(LABELS + [unknown_label], MODEL_DIR)
    # Convert the model to tflite
    save_model(
        model, MODEL_DIR, "classification_model"
    )
285527
```
286528

@@ -300,6 +542,10 @@ The script you are creating must take the following command line inputs:
300542
- `dataset_file`: a file containing the data and metadata for the training job
301543
- `model_output_directory`: the location where the produced model artifacts are saved to
302544

545+
If you used the training script template that allows users to pass in labels, it also takes the following command line input:
546+
547+
- `labels`: a space-separated list of labels, enclosed in single quotes
548+
303549
The `parse_args()` function in the template parses your arguments.
304550

305551
You can add additional custom command line inputs by adding them to the `parse_args()` function.
@@ -547,6 +793,11 @@ In the Viam app, navigate to your list of [**DATASETS**](https://app.viam.com/da
547793

548794
Click **Train model** and select **Train on a custom training script**, then follow the prompts.
549795

796+
{{% alert title="Tip" color="tip" %}}
797+
If you used the version of <file>training.py</file> that allows users to pass in labels, your training job will fail with the error `ERROR training.py: error: the following arguments are required: --labels`.
798+
To use labels, you must use the CLI.
799+
{{% /alert %}}
800+
550801
{{% /tab %}}
551802
{{% tab name="CLI" %}}
552803

0 commit comments

Comments
 (0)