diff --git a/docs/data-ai/ai/train.md b/docs/data-ai/ai/train.md
index a91301554d..266cf6ef72 100644
--- a/docs/data-ai/ai/train.md
+++ b/docs/data-ai/ai/train.md
@@ -63,7 +63,7 @@ my-training/
Add the following code to `setup.py` and add additional required packages on line 11:
-```python {class="line-numbers linkable-line-numbers" data-line="11"}
+```python {class="line-numbers linkable-line-numbers" data-line="9"}
from setuptools import find_packages, setup
setup(
@@ -72,8 +72,6 @@ setup(
packages=find_packages(),
include_package_data=True,
install_requires=[
- "google-cloud-aiplatform",
- "google-cloud-storage",
# TODO: Add additional required packages
],
)
@@ -90,15 +88,18 @@ If you haven't already, create a folder called model and create an
4. Add training.py
code
-Copy this template into training.py:
+You can set up your training script to use a hard coded set of labels or allow users to pass in a set of labels when using the training script. Allowing users to pass in labels when using training scripts makes your training script more flexible for reuse.
+Copy one of the following templates into training.py, depending on how you want to handle labels:
-{{% expand "Click to see the template" %}}
+{{% expand "Click to see the template without parsing labels (recommended for use with UI)" %}}
-```python {class="line-numbers linkable-line-numbers" data-line="126,170" }
+```python {class="line-numbers linkable-line-numbers" data-line="134" }
import argparse
import json
import os
import typing as ty
+from tensorflow.keras import Model # Add proper import
+import tensorflow as tf # Add proper import
single_label = "MODEL_TYPE_SINGLE_LABEL_CLASSIFICATION"
multi_label = "MODEL_TYPE_MULTI_LABEL_CLASSIFICATION"
@@ -108,23 +109,29 @@ unknown_label = "UNKNOWN"
API_KEY = os.environ['API_KEY']
API_KEY_ID = os.environ['API_KEY_ID']
+DEFAULT_EPOCHS = 200
# This parses the required args for the training script.
# The model_dir variable will contain the output directory where
# the ML model that this script creates should be stored.
# The data_json variable will contain the metadata for the dataset
# that you should use to train the model.
+
+
def parse_args():
- """Returns dataset file, model output directory, and num_epochs if present.
- These must be parsed as command line arguments and then used as the model
- input and output, respectively. The number of epochs can be used to
- optionally override the default.
+ """Returns dataset file, model output directory, and num_epochs
+ if present. These must be parsed as command line arguments and then used
+ as the model input and output, respectively. The number of epochs can be
+ used to optionally override the default.
"""
parser = argparse.ArgumentParser()
- parser.add_argument("--dataset_file", dest="data_json", type=str)
- parser.add_argument("--model_output_directory", dest="model_dir", type=str)
+ parser.add_argument("--dataset_file", dest="data_json",
+ type=str, required=True)
+ parser.add_argument("--model_output_directory", dest="model_dir",
+ type=str, required=True)
parser.add_argument("--num_epochs", dest="num_epochs", type=int)
args = parser.parse_args()
+
return args.data_json, args.model_dir, args.num_epochs
@@ -250,12 +257,17 @@ def save_model(
model_dir: output directory for model artifacts
model_name: name of saved model
"""
- file_type = ""
-
- # Save the model to the output directory.
+ # Save the model to the output directory
+ file_type = "tflite" # Add proper file type
filename = os.path.join(model_dir, f"{model_name}.{file_type}")
+
+ # Example: Convert to TFLite
+ converter = tf.lite.TFLiteConverter.from_keras_model(model)
+ tflite_model = converter.convert()
+
+ # Save the model
with open(filename, "wb") as f:
- f.write(model)
+ f.write(tflite_model)
if __name__ == "__main__":
@@ -273,14 +285,244 @@ if __name__ == "__main__":
image_filenames, image_labels = parse_filenames_and_labels_from_json(
DATA_JSON, LABELS, model_type)
+ # Validate epochs
+ epochs = (
+ DEFAULT_EPOCHS if NUM_EPOCHS is None
+ or NUM_EPOCHS <= 0 else int(NUM_EPOCHS)
+ )
+
# Build and compile model on data
- model = build_and_compile_model()
+ model = build_and_compile_model(image_labels, model_type, IMG_SIZE + (3,))
# Save labels.txt file
save_labels(LABELS + [unknown_label], MODEL_DIR)
# Convert the model to tflite
save_model(
- model, MODEL_DIR, "classification_model", IMG_SIZE + (3,)
+ model, MODEL_DIR, "classification_model"
+ )
+```
+
+{{% /expand %}}
+
+{{% expand "Click to see the template with parsed labels" %}}
+
+```python {class="line-numbers linkable-line-numbers" data-line="148" }
+import argparse
+import json
+import os
+import typing as ty
+from tensorflow.keras import Model # Add proper import
+import tensorflow as tf # Add proper import
+
+single_label = "MODEL_TYPE_SINGLE_LABEL_CLASSIFICATION"
+multi_label = "MODEL_TYPE_MULTI_LABEL_CLASSIFICATION"
+labels_filename = "labels.txt"
+unknown_label = "UNKNOWN"
+
+API_KEY = os.environ['API_KEY']
+API_KEY_ID = os.environ['API_KEY_ID']
+
+DEFAULT_EPOCHS = 200
+
+# This parses the required args for the training script.
+# The model_dir variable will contain the output directory where
+# the ML model that this script creates should be stored.
+# The data_json variable will contain the metadata for the dataset
+# that you should use to train the model.
+
+
+def parse_args():
+ """Returns dataset file, model output directory, labels, and num_epochs
+ if present. These must be parsed as command line arguments and then used
+ as the model input and output, respectively. The number of epochs can be
+ used to optionally override the default.
+ """
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--dataset_file", dest="data_json",
+ type=str, required=True)
+ parser.add_argument("--model_output_directory", dest="model_dir",
+ type=str, required=True)
+ parser.add_argument("--num_epochs", dest="num_epochs", type=int)
+ parser.add_argument(
+ "--labels",
+ dest="labels",
+ type=str,
+ required=True,
+ help="Space-separated list of labels, \
+ enclosed in single quotes (e.g., 'label1 label2').",
+ )
+ args = parser.parse_args()
+
+ if not args.labels:
+ raise ValueError("Labels must be provided")
+
+ labels = [label.strip() for label in args.labels.strip("'").split()]
+ return args.data_json, args.model_dir, args.num_epochs, labels
+
+
+# This is used for parsing the dataset file (produced and stored in Viam),
+# parse it to get the label annotations
+# Used for training classifiction models
+
+
+def parse_filenames_and_labels_from_json(
+ filename: str, all_labels: ty.List[str], model_type: str
+) -> ty.Tuple[ty.List[str], ty.List[str]]:
+ """Load and parse JSON file to return image filenames and corresponding
+ labels. The JSON file contains lines, where each line has the key
+ "image_path" and "classification_annotations".
+ Args:
+ filename: JSONLines file containing filenames and labels
+ all_labels: list of all N_LABELS
+ model_type: string single_label or multi_label
+ """
+ image_filenames = []
+ image_labels = []
+
+ with open(filename, "rb") as f:
+ for line in f:
+ json_line = json.loads(line)
+ image_filenames.append(json_line["image_path"])
+
+ annotations = json_line["classification_annotations"]
+ labels = [unknown_label]
+ for annotation in annotations:
+ if model_type == multi_label:
+ if annotation["annotation_label"] in all_labels:
+ labels.append(annotation["annotation_label"])
+ # For single label model, we want at most one label.
+ # If multiple valid labels are present, we arbitrarily select
+ # the last one.
+ if model_type == single_label:
+ if annotation["annotation_label"] in all_labels:
+ labels = [annotation["annotation_label"]]
+ image_labels.append(labels)
+ return image_filenames, image_labels
+
+
+# Parse the dataset file (produced and stored in Viam) to get
+# bounding box annotations
+# Used for training object detection models
+def parse_filenames_and_bboxes_from_json(
+ filename: str,
+ all_labels: ty.List[str],
+) -> ty.Tuple[ty.List[str], ty.List[str], ty.List[ty.List[float]]]:
+ """Load and parse JSON file to return image filenames
+ and corresponding labels with bboxes.
+ Args:
+ filename: JSONLines file containing filenames and bboxes
+ all_labels: list of all N_LABELS
+ """
+ image_filenames = []
+ bbox_labels = []
+ bbox_coords = []
+
+ with open(filename, "rb") as f:
+ for line in f:
+ json_line = json.loads(line)
+ image_filenames.append(json_line["image_path"])
+ annotations = json_line["bounding_box_annotations"]
+ labels = []
+ coords = []
+ for annotation in annotations:
+ if annotation["annotation_label"] in all_labels:
+ labels.append(annotation["annotation_label"])
+ # Store coordinates in rel_yxyx format so that
+ # we can use the keras_cv function
+ coords.append(
+ [
+ annotation["y_min_normalized"],
+ annotation["x_min_normalized"],
+ annotation["y_max_normalized"],
+ annotation["x_max_normalized"],
+ ]
+ )
+ bbox_labels.append(labels)
+ bbox_coords.append(coords)
+ return image_filenames, bbox_labels, bbox_coords
+
+
+# Build the model
+def build_and_compile_model(
+ labels: ty.List[str], model_type: str, input_shape: ty.Tuple[int, int, int]
+) -> Model:
+ """Builds and compiles a model
+ Args:
+ labels: list of string lists, where each string list contains up to
+ N_LABEL labels associated with an image
+ model_type: string single_label or multi_label
+ input_shape: 3D shape of input
+ """
+
+ # TODO: Add logic to build and compile model
+
+ return model
+
+
+def save_labels(labels: ty.List[str], model_dir: str) -> None:
+ """Saves a label.txt of output labels to the specified model directory.
+ Args:
+ labels: list of string lists, where each string list contains up to
+ N_LABEL labels associated with an image
+ model_dir: output directory for model artifacts
+ """
+ filename = os.path.join(model_dir, labels_filename)
+ with open(filename, "w") as f:
+ for label in labels[:-1]:
+ f.write(label + "\n")
+ f.write(labels[-1])
+
+
+def save_model(
+ model: Model,
+ model_dir: str,
+ model_name: str,
+) -> None:
+ """Save model as a TFLite model.
+ Args:
+ model: trained model
+ model_dir: output directory for model artifacts
+ model_name: name of saved model
+ """
+ # Save the model to the output directory
+ file_type = "tflite" # Add proper file type
+ filename = os.path.join(model_dir, f"{model_name}.{file_type}")
+
+ # Example: Convert to TFLite
+ converter = tf.lite.TFLiteConverter.from_keras_model(model)
+ tflite_model = converter.convert()
+
+ # Save the model
+ with open(filename, "wb") as f:
+ f.write(tflite_model)
+
+
+if __name__ == "__main__":
+ DATA_JSON, MODEL_DIR, NUM_EPOCHS, LABELS = parse_args()
+
+ IMG_SIZE = (256, 256)
+
+ # Read dataset file.
+ # The model type can be changed based on whether you want the model to
+ # output one label per image or multiple labels per image
+ model_type = multi_label
+ image_filenames, image_labels = parse_filenames_and_labels_from_json(
+ DATA_JSON, LABELS, model_type)
+
+ # Validate epochs
+ epochs = (
+ DEFAULT_EPOCHS if NUM_EPOCHS is None
+ or NUM_EPOCHS <= 0 else int(NUM_EPOCHS)
+ )
+
+ # Build and compile model on data
+ model = build_and_compile_model(image_labels, model_type, IMG_SIZE + (3,))
+
+ # Save labels.txt file
+ save_labels(LABELS + [unknown_label], MODEL_DIR)
+ # Convert the model to tflite
+ save_model(
+ model, MODEL_DIR, "classification_model"
)
```
@@ -300,6 +542,10 @@ The script you are creating must take the following command line inputs:
- `dataset_file`: a file containing the data and metadata for the training job
- `model_output_directory`: the location where the produced model artifacts are saved to
+If you used the training script template that allows users to pass in labels, it will also take the following command line inputs:
+
+- `labels`: space separated list of labels, enclosed in single quotes
+
The `parse_args()` function in the template parses your arguments.
You can add additional custom command line inputs by adding them to the `parse_args()` function.
@@ -547,6 +793,11 @@ In the Viam app, navigate to your list of [**DATASETS**](https://app.viam.com/da
Click **Train model** and select **Train on a custom training script**, then follow the prompts.
+{{% alert title="Tip" color="tip" %}}
+If you used the version of training.py that allows users to pass in labels, your training job will fail with the error `ERROR training.py: error: the following arguments are required: --labels`.
+To use labels, you must use the CLI.
+{{% /alert %}}
+
{{% /tab %}}
{{% tab name="CLI" %}}