diff --git a/lit_nlp/examples/lm_salience_demo.py b/lit_nlp/examples/lm_salience_demo.py
index 05f6c1a4..510d90e7 100644
--- a/lit_nlp/examples/lm_salience_demo.py
+++ b/lit_nlp/examples/lm_salience_demo.py
@@ -8,7 +8,7 @@
 To run with the default configuration (Gemma on TensorFlow via Keras):
 
   blaze run -c opt examples:lm_salience_demo -- \
-    --models=gemma_instruct_2b_en:gemma_instruct_2b_en \
+    --models=gemma_1.1_instruct_2b_en:gemma_1.1_instruct_2b_en \
     --port=8890 --alsologtostderr
 
 MODELS:
@@ -64,7 +64,7 @@
 _MODELS = flags.DEFINE_list(
     "models",
-    ["gemma_instruct_2b_en:gemma_instruct_2b_en"],
+    ["gemma_1.1_instruct_2b_en:gemma_1.1_instruct_2b_en"],
     "Models to load, as <name>:<path>. Path can be a URL, a local file path, or"
     " the name of a preset for the configured Deep Learning framework (either"
     " KerasNLP or HuggingFace Transformers; see --dl_framework for more). This"
@@ -91,6 +91,10 @@
     ),
 )
 
+_BATCH_SIZE = flags.DEFINE_integer(
+    "batch_size", 4, "The number of examples to process per batch.",
+)
+
 _DL_BACKEND = flags.DEFINE_enum(
     "dl_backend",
     "tensorflow",
@@ -278,10 +282,9 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
     path = file_cache.cached_path(
         path,
         extract_compressed_file=path.endswith(".tar.gz"),
-        copy_directories=True,
     )
 
-    if _DL_FRAMEWORK.value == "keras":
+    if _DL_FRAMEWORK.value == "kerasnlp":
       # pylint: disable=g-import-not-at-top
       from keras_nlp import models as keras_models
       from lit_nlp.examples.models import instrumented_keras_lms as lit_keras
@@ -289,7 +292,7 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
       # Load the weights once for the underlying Keras model.
       model = keras_models.CausalLM.from_preset(path)
       models |= lit_keras.initialize_model_group_for_salience(
-          model_name, model, max_length=512, batch_size=4
+          model_name, model, max_length=512, batch_size=_BATCH_SIZE.value
      )
       # Disable embeddings from the generation model.
       # TODO(lit-dev): re-enable embeddings if we can figure out why UMAP was
@@ -301,7 +304,11 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
       # Assuming a valid decoder model name supported by
       # `transformers.AutoModelForCausalLM` is provided to "path".
       models |= pretrained_lms.initialize_model_group_for_salience(
-          model_name, path, framework=_DL_BACKEND.value, max_new_tokens=512
+          model_name,
+          path,
+          batch_size=_BATCH_SIZE.value,
+          framework=_DL_BACKEND.value,
+          max_new_tokens=512,
       )
 
   for name in datasets:
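
Note: with the new _BATCH_SIZE flag added above, the per-batch example count is configurable from the command line instead of being hard-coded to 4 in the Keras path. A sketch of an invocation using the new flag, following the docstring example; the model and port values come from the diff, while the batch_size value of 8 is illustrative:

  blaze run -c opt examples:lm_salience_demo -- \
    --models=gemma_1.1_instruct_2b_en:gemma_1.1_instruct_2b_en \
    --batch_size=8 \
    --port=8890 --alsologtostderr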