
Commit bae62f8

Chirp Team authored and copybara-github committed
Cache labeled embeddings in active_learning colab.
PiperOrigin-RevId: 608650393
1 parent 3cdb69a commit bae62f8

3 files changed: +248 -28 lines changed


active_learning.ipynb

Lines changed: 6 additions & 1 deletion

@@ -92,6 +92,10 @@
 "config = bootstrap.BootstrapConfig.load_from_embedding_config(\n",
 "    embeddings_path=embeddings_path,\n",
 "    annotated_path=labeled_data_path)\n",
+"\n",
+"# Get hash to identify relevant existing embeddings.\n",
+"config_hash = config.embedding_config_hash()\n",
+"\n",
 "embedding_hop_size_s = config.embedding_hop_size_s\n",
 "project_state = bootstrap.BootstrapState(config)\n",
 "embedding_model = project_state.embedding_model"
@@ -125,7 +129,8 @@
 "    time_pooling=time_pooling,\n",
 "    load_audio=False,\n",
 "    target_sample_rate=-2,\n",
-"    audio_file_pattern='*'\n",
+"    audio_file_pattern='*',\n",
+"    embedding_config_hash=config_hash\n",
 ")\n",
 "\n",
 "# Label distribution\n",

chirp/projects/bootstrap/bootstrap.py

Lines changed: 16 additions & 1 deletion

@@ -16,6 +16,7 @@
 """Configuration and init library for Search Bootstrap projects."""
 
 import dataclasses
+import hashlib
 from typing import Sequence
 
 from chirp.inference import embed_lib
@@ -95,15 +96,19 @@ class BootstrapConfig:
   audio_globs: Sequence[str] | None = None
   model_key: str | None = None
   model_config: config_dict.ConfigDict | None = None
+  tf_record_shards: int | None = None
 
   @classmethod
   def load_from_embedding_config(
-      cls, embeddings_path: str, annotated_path: str
+      cls, embeddings_path: str, annotated_path: str, tf_record_shards: int = 1
   ):
     """Instantiate from a configuration written alongside embeddings."""
     embedding_config = embed_lib.load_embedding_config(embeddings_path)
     embed_fn_config = embedding_config.embed_fn_config
     tensor_dtype = embed_fn_config.get('tensor_dtype', 'float32')
+    tf_record_shards = embedding_config.get(
+        'tf_record_shards', tf_record_shards
+    )
 
     # Extract the embedding model config from the embedding_config.
     if embed_fn_config.model_key == 'separate_embed_model':
@@ -122,4 +127,14 @@ def load_from_embedding_config(
         file_id_depth=embed_fn_config.file_id_depth,
         audio_globs=embedding_config.source_file_patterns,
         tensor_dtype=tensor_dtype,
+        tf_record_shards=tf_record_shards,
     )
+
+  def embedding_config_hash(self, digest_size: int = 10) -> str:
+    """Returns a stable hash of the model key and config."""
+    config_str = self.model_config.to_json(sort_keys=True)
+    encoded_str = f'{self.model_key};{config_str}'.encode('utf-8')
+
+    hash_obj = hashlib.blake2b(digest_size=digest_size)
+    hash_obj.update(encoded_str)
+    return hash_obj.hexdigest()
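As a quick check on the new helper, the snippet below reproduces its hashing scheme with a plain dict standing in for config_dict.ConfigDict, whose to_json(sort_keys=True) serializes keys in sorted order. The model key and config fields are illustrative values, not taken from the commit. Sorting keys makes the hash insensitive to key insertion order, and digest_size=10 yields a compact 20-character hex digest.

import hashlib
import json


def config_hash(model_key: str, config: dict, digest_size: int = 10) -> str:
  # Mirrors BootstrapConfig.embedding_config_hash, with json.dumps standing
  # in for config_dict.ConfigDict.to_json(sort_keys=True).
  config_str = json.dumps(config, sort_keys=True)
  hash_obj = hashlib.blake2b(digest_size=digest_size)
  hash_obj.update(f'{model_key};{config_str}'.encode('utf-8'))
  return hash_obj.hexdigest()


# Illustrative configs: same contents, different insertion order.
config_a = {'model_path': 'perch_4', 'sample_rate': 32000}
config_b = {'sample_rate': 32000, 'model_path': 'perch_4'}
assert config_hash('taxonomy_model_tf', config_a) == config_hash(
    'taxonomy_model_tf', config_b
)
print(config_hash('taxonomy_model_tf', config_a))  # 20 hex characters.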
