16
16
"""Configuration and init library for Search Bootstrap projects."""
17
17
18
18
import dataclasses
19
+ import hashlib
19
20
from typing import Sequence
20
21
21
22
from chirp .inference import embed_lib
@@ -95,15 +96,19 @@ class BootstrapConfig:
95
96
audio_globs : Sequence [str ] | None = None
96
97
model_key : str | None = None
97
98
model_config : config_dict .ConfigDict | None = None
99
+ tf_record_shards : int | None = None
98
100
99
101
@classmethod
100
102
def load_from_embedding_config (
101
- cls , embeddings_path : str , annotated_path : str
103
+ cls , embeddings_path : str , annotated_path : str , tf_record_shards : int = 1
102
104
):
103
105
"""Instantiate from a configuration written alongside embeddings."""
104
106
embedding_config = embed_lib .load_embedding_config (embeddings_path )
105
107
embed_fn_config = embedding_config .embed_fn_config
106
108
tensor_dtype = embed_fn_config .get ('tensor_dtype' , 'float32' )
109
+ tf_record_shards = embedding_config .get (
110
+ 'tf_record_shards' , tf_record_shards
111
+ )
107
112
108
113
# Extract the embedding model config from the embedding_config.
109
114
if embed_fn_config .model_key == 'separate_embed_model' :
@@ -122,4 +127,14 @@ def load_from_embedding_config(
122
127
file_id_depth = embed_fn_config .file_id_depth ,
123
128
audio_globs = embedding_config .source_file_patterns ,
124
129
tensor_dtype = tensor_dtype ,
130
+ tf_record_shards = tf_record_shards ,
125
131
)
132
+
133
def embedding_config_hash(self, digest_size: int = 10) -> str:
  """Returns a stable hash of the model key and config.

  The model config is serialized to JSON with sorted keys so that two
  logically-equal configs always produce the same hash, regardless of
  insertion order.

  Args:
    digest_size: Size of the blake2b digest in bytes; the returned hex
      string is twice this length.

  Returns:
    Hex digest of `'<model_key>;<config_json>'`.
  """
  serialized_config = self.model_config.to_json(sort_keys=True)
  payload = f'{self.model_key};{serialized_config}'.encode('utf-8')
  digest = hashlib.blake2b(payload, digest_size=digest_size)
  return digest.hexdigest()
0 commit comments