Cache labeled embeddings in active_learning colab.
PiperOrigin-RevId: 608650393
Chirp Team authored and copybara-github committed Feb 20, 2024
1 parent 3cdb69a commit bae62f8
Showing 3 changed files with 248 additions and 28 deletions.
7 changes: 6 additions & 1 deletion active_learning.ipynb
@@ -92,6 +92,10 @@
"config = bootstrap.BootstrapConfig.load_from_embedding_config(\n",
" embeddings_path=embeddings_path,\n",
" annotated_path=labeled_data_path)\n",
"\n",
"# Get hash to identify relevant existing embeddings.\n",
"config_hash = config.embedding_config_hash()\n",
"\n",
"embedding_hop_size_s = config.embedding_hop_size_s\n",
"project_state = bootstrap.BootstrapState(config)\n",
"embedding_model = project_state.embedding_model"
@@ -125,7 +129,8 @@
" time_pooling=time_pooling,\n",
" load_audio=False,\n",
" target_sample_rate=-2,\n",
" audio_file_pattern='*'\n",
" audio_file_pattern='*',\n",
" embedding_config_hash=config_hash\n",
")\n",
"\n",
"# Label distribution\n",
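The notebook now threads the config hash into the labeled-dataset merge call, so previously computed labeled embeddings are reused only when they came from the same embedding model and configuration. The caching logic itself presumably lives in the remaining changed file, which is not shown on this page; the following is a hypothetical sketch (not the Chirp implementation) of how a config hash can gate cache reuse. The names workdir, compute_labeled_embeddings, and the pickle file layout are illustrative placeholders.

import os
import pickle


def load_or_compute(workdir: str, config_hash: str, compute_labeled_embeddings):
  """Reuses a cached result only if it was written under the same config hash."""
  cache_path = os.path.join(workdir, f'labeled_embeddings_{config_hash}.pkl')
  if os.path.exists(cache_path):
    # The hash embedded in the filename ties the cache to one model + config,
    # so a changed embedding setup automatically misses and triggers recompute.
    with open(cache_path, 'rb') as f:
      return pickle.load(f)
  embeddings = compute_labeled_embeddings()
  with open(cache_path, 'wb') as f:
    pickle.dump(embeddings, f)
  return embeddings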
17 changes: 16 additions & 1 deletion chirp/projects/bootstrap/bootstrap.py
@@ -16,6 +16,7 @@
"""Configuration and init library for Search Bootstrap projects."""

import dataclasses
import hashlib
from typing import Sequence

from chirp.inference import embed_lib
@@ -95,15 +96,19 @@ class BootstrapConfig:
   audio_globs: Sequence[str] | None = None
   model_key: str | None = None
   model_config: config_dict.ConfigDict | None = None
+  tf_record_shards: int | None = None

   @classmethod
   def load_from_embedding_config(
-      cls, embeddings_path: str, annotated_path: str
+      cls, embeddings_path: str, annotated_path: str, tf_record_shards: int = 1
   ):
     """Instantiate from a configuration written alongside embeddings."""
     embedding_config = embed_lib.load_embedding_config(embeddings_path)
     embed_fn_config = embedding_config.embed_fn_config
     tensor_dtype = embed_fn_config.get('tensor_dtype', 'float32')
+    tf_record_shards = embedding_config.get(
+        'tf_record_shards', tf_record_shards
+    )

     # Extract the embedding model config from the embedding_config.
     if embed_fn_config.model_key == 'separate_embed_model':
@@ -122,4 +127,14 @@ def load_from_embedding_config(
         file_id_depth=embed_fn_config.file_id_depth,
         audio_globs=embedding_config.source_file_patterns,
         tensor_dtype=tensor_dtype,
+        tf_record_shards=tf_record_shards,
     )
+
+  def embedding_config_hash(self, digest_size: int = 10) -> str:
+    """Returns a stable hash of the model key and config."""
+    config_str = self.model_config.to_json(sort_keys=True)
+    encoded_str = f'{self.model_key};{config_str}'.encode('utf-8')
+
+    hash_obj = hashlib.blake2b(digest_size=digest_size)
+    hash_obj.update(encoded_str)
+    return hash_obj.hexdigest()
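The new embedding_config_hash helper digests the model key together with the JSON-serialized model config using BLAKE2b, so identical embedding setups always map to the same short hex string. A minimal usage sketch, assuming embeddings_path and labeled_data_path are defined as in the colab snippet above:

config = bootstrap.BootstrapConfig.load_from_embedding_config(
    embeddings_path=embeddings_path,
    annotated_path=labeled_data_path)

# With the default digest_size=10 this is a 20-character hex string derived
# from the model key and JSON-serialized model config.
config_hash = config.embedding_config_hash()
print(config_hash)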