Skip to content

Commit 7fe2c86

Browse files
authored
raise error with non-existence image prompts (#1322)
* print non-existence image prompt * reformat
1 parent 76c1cd2 commit 7fe2c86

File tree

2 files changed

+24
-13
lines changed

2 files changed

+24
-13
lines changed

assets/view.jpg

93.3 KB
Loading

torchchat/generate.py

+24-13
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,6 @@
2424

2525
from PIL import Image
2626

27-
# torchtune model definition dependencies
28-
from torchtune.data import Message, padded_collate_tiled_images_and_mask
29-
30-
from torchtune.generation import sample as tune_sample
31-
from torchtune.models.llama3 import llama3_tokenizer
32-
33-
from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
34-
from torchtune.training import set_default_dtype
35-
3627
from torchchat.cli.builder import (
3728
_initialize_model,
3829
_initialize_tokenizer,
@@ -43,6 +34,15 @@
4334
from torchchat.utils.build_utils import device_sync, set_precision
4435
from torchchat.utils.device_info import get_device_info
4536

37+
# torchtune model definition dependencies
38+
from torchtune.data import Message, padded_collate_tiled_images_and_mask
39+
40+
from torchtune.generation import sample as tune_sample
41+
from torchtune.models.llama3 import llama3_tokenizer
42+
43+
from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
44+
from torchtune.training import set_default_dtype
45+
4646

4747
class _ChatFormatter(ABC):
4848
def __init__(self, tokenizer):
@@ -179,8 +179,15 @@ def from_args(cls, args):
179179

180180
# Validate that all image prompts exist before expensive model load
181181
if image_prompts := getattr(args, "image_prompts", None):
182-
if not all(os.path.exists(image_prompt) for image_prompt in image_prompts):
183-
raise RuntimeError(f"Image prompt {image_prompt} does not exist")
182+
non_existent_image_prompts = [
183+
image_prompt
184+
for image_prompt in image_prompts
185+
if (not os.path.exists(image_prompt))
186+
]
187+
if len(non_existent_image_prompts):
188+
raise RuntimeError(
189+
f"Image prompt {non_existent_image_prompts} does not exist"
190+
)
184191

185192
return cls(
186193
prompt=getattr(args, "prompt", ""),
@@ -938,7 +945,8 @@ def chat(
938945
TransformerCrossAttentionLayer,
939946
TransformerSelfAttentionLayer,
940947
)
941-
decoder = self.model.model.decoder
948+
949+
decoder = self.model.model.decoder
942950
for m in reversed(list(decoder.modules())):
943951
if isinstance(m, TransformerSelfAttentionLayer) or isinstance(
944952
m, TransformerCrossAttentionLayer
@@ -984,7 +992,10 @@ def chat(
984992
# `is_torchtune_model` is a misnomer since it doesn't capture all
985993
# torchtune models (i.e. Flamingo)
986994
# See Issue: https://github.com/pytorch/torchchat/issues/1273
987-
elif not generator_args.is_torchtune_model and self.model.config.model_type != ModelType.Flamingo:
995+
elif (
996+
not generator_args.is_torchtune_model
997+
and self.model.config.model_type != ModelType.Flamingo
998+
):
988999
max_seq_length = min(
9891000
encoded.size(0) + generator_args.max_new_tokens,
9901001
(

0 commit comments

Comments
 (0)