1616"""Configuration and init library for Search Bootstrap projects."""
1717
1818import dataclasses
19+ import hashlib
1920from typing import Sequence
2021
2122from chirp .inference import embed_lib
@@ -95,15 +96,19 @@ class BootstrapConfig:
9596 audio_globs : Sequence [str ] | None = None
9697 model_key : str | None = None
9798 model_config : config_dict .ConfigDict | None = None
99+ tf_record_shards : int | None = None
98100
99101 @classmethod
100102 def load_from_embedding_config (
101- cls , embeddings_path : str , annotated_path : str
103+ cls , embeddings_path : str , annotated_path : str , tf_record_shards : int = 1
102104 ):
103105 """Instantiate from a configuration written alongside embeddings."""
104106 embedding_config = embed_lib .load_embedding_config (embeddings_path )
105107 embed_fn_config = embedding_config .embed_fn_config
106108 tensor_dtype = embed_fn_config .get ('tensor_dtype' , 'float32' )
109+ tf_record_shards = embedding_config .get (
110+ 'tf_record_shards' , tf_record_shards
111+ )
107112
108113 # Extract the embedding model config from the embedding_config.
109114 if embed_fn_config .model_key == 'separate_embed_model' :
@@ -122,4 +127,14 @@ def load_from_embedding_config(
122127 file_id_depth = embed_fn_config .file_id_depth ,
123128 audio_globs = embedding_config .source_file_patterns ,
124129 tensor_dtype = tensor_dtype ,
130+ tf_record_shards = tf_record_shards ,
125131 )
132+
133+ def embedding_config_hash (self , digest_size : int = 10 ) -> str :
134+ """Returns a stable hash of the model key and config."""
135+ config_str = self .model_config .to_json (sort_keys = True )
136+ encoded_str = f'{ self .model_key } ;{ config_str } ' .encode ('utf-8' )
137+
138+ hash_obj = hashlib .blake2b (digest_size = digest_size )
139+ hash_obj .update (encoded_str )
140+ return hash_obj .hexdigest ()
0 commit comments