Skip to content

Commit 827f0e1

Browse files
sdenton4copybara-github
authored andcommitted
Allow configuring and testing XLA in model export.
PiperOrigin-RevId: 723135244
1 parent ac6e047 commit 827f0e1

File tree

3 files changed

+22
-5
lines changed

3 files changed

+22
-5
lines changed

Diff for: chirp/train/classifier.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,7 @@ def export_tf_model(
400400
tf_lite_dtype: str = "float16",
401401
tf_lite_select_ops: bool = True,
402402
export_dir: str | None = None,
403+
enable_xla: bool = False,
403404
):
404405
"""Export SavedModel and TFLite."""
405406
# Get model_ouput keys from output_head_metadatas and add the 'embedding' key
@@ -427,7 +428,7 @@ def infer_fn(audio_batch, variables):
427428
else:
428429
shape = (1,) + input_shape
429430
converted_model = export_utils.Jax2TfModelWrapper(
430-
infer_fn, variables, shape, False
431+
infer_fn, variables, shape, enable_xla=enable_xla
431432
)
432433
class_lists = {
433434
md.key: md.class_list for md in model_bundle.output_head_metadatas

Diff for: chirp/train_tests/frontend_test.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def test_inverse(self, module_type, inverse_module_type, module_kwargs):
132132
"freq_range": (60, 10_000),
133133
},
134134
"atol": 1e-4,
135+
"enable_xla": False,
135136
},
136137
{
137138
"module_type": frontend.SimpleMelspec,
@@ -169,7 +170,12 @@ def test_inverse(self, module_type, inverse_module_type, module_kwargs):
169170
},
170171
)
171172
def test_tflite_stft_export(
172-
self, module_type, module_kwargs, signal_shape=None, atol=1e-6
173+
self,
174+
module_type,
175+
module_kwargs,
176+
signal_shape=None,
177+
atol=1e-6,
178+
enable_xla=False,
173179
):
174180
# Note that the TFLite stft requires power-of-two nfft, given by:
175181
# nfft = 2 * (features - 1).
@@ -182,7 +188,8 @@ def test_tflite_stft_export(
182188

183189
tf_predict = tf.function(
184190
jax2tf.convert(
185-
lambda signal: fe.apply(params, signal), enable_xla=False
191+
lambda signal: fe.apply(params, signal),
192+
enable_xla=enable_xla,
186193
),
187194
input_signature=[
188195
tf.TensorSpec(shape=signal.shape, dtype=tf.float32, name="input")

Diff for: chirp/train_tests/train_test.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -157,12 +157,20 @@ def test_config_structure(self):
157157
jax.tree_util.tree_structure(test_config.to_dict()),
158158
)
159159

160-
def test_export_model(self):
160+
@parameterized.named_parameters(
161+
# Note that b0 tests tend to timeout.
162+
# ("xla", True, False),
163+
("no_xla", False, False),
164+
)
165+
def test_export_model(self, enable_xla, test_b0):
161166
# NOTE: This test might fail when run on a machine that has a GPU but when
162167
# CUDA is not linked (JAX will detect the GPU so jax2tf will try to create
163168
# a TF graph on the GPU and fail)
164169
config = self._get_test_config()
165-
config = self._add_const_model_config(config)
170+
if test_b0:
171+
config = self._add_b0_model_config(config)
172+
else:
173+
config = self._add_const_model_config(config)
166174
config = self._add_pcen_melspec_frontend(config)
167175

168176
model_bundle, train_state = classifier.initialize_model(
@@ -177,6 +185,7 @@ def test_export_model(self):
177185
config.init_config.input_shape,
178186
num_train_steps=0,
179187
eval_sleep_s=0,
188+
enable_xla=enable_xla,
180189
)
181190
self.assertTrue(
182191
tf.io.gfile.exists(os.path.join(self.train_dir, "model.tflite"))

0 commit comments

Comments
 (0)