|
@@ -24,15 +24,6 @@
 
 from PIL import Image
 
-# torchtune model definition dependencies
-from torchtune.data import Message, padded_collate_tiled_images_and_mask
-
-from torchtune.generation import sample as tune_sample
-from torchtune.models.llama3 import llama3_tokenizer
-
-from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
-from torchtune.training import set_default_dtype
-
 from torchchat.cli.builder import (
     _initialize_model,
     _initialize_tokenizer,
@@ -43,6 +34,15 @@
 from torchchat.utils.build_utils import device_sync, set_precision
 from torchchat.utils.device_info import get_device_info
 
+# torchtune model definition dependencies
+from torchtune.data import Message, padded_collate_tiled_images_and_mask
+
+from torchtune.generation import sample as tune_sample
+from torchtune.models.llama3 import llama3_tokenizer
+
+from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
+from torchtune.training import set_default_dtype
+
 
 class _ChatFormatter(ABC):
     def __init__(self, tokenizer):
|
@@ -179,8 +179,15 @@ def from_args(cls, args):
 
         # Validate that all image prompts exist before expensive model load
         if image_prompts := getattr(args, "image_prompts", None):
-            if not all(os.path.exists(image_prompt) for image_prompt in image_prompts):
-                raise RuntimeError(f"Image prompt {image_prompt} does not exist")
+            non_existent_image_prompts = [
+                image_prompt
+                for image_prompt in image_prompts
+                if (not os.path.exists(image_prompt))
+            ]
+            if len(non_existent_image_prompts):
+                raise RuntimeError(
+                    f"Image prompt {non_existent_image_prompts} does not exist"
+                )
 
         return cls(
             prompt=getattr(args, "prompt", ""),
|
@@ -938,7 +945,8 @@ def chat(
             TransformerCrossAttentionLayer,
             TransformerSelfAttentionLayer,
         )
-        decoder = self.model.model.decoder
+
+        decoder = self.model.model.decoder
         for m in reversed(list(decoder.modules())):
            if isinstance(m, TransformerSelfAttentionLayer) or isinstance(
                 m, TransformerCrossAttentionLayer
|
@@ -984,7 +992,10 @@ def chat(
         # `is_torchtune_model` is a misnomer since it doesn't capture all
         # torchtune models (i.e. Flamingo)
         # See Issue: https://github.com/pytorch/torchchat/issues/1273
-        elif not generator_args.is_torchtune_model and self.model.config.model_type != ModelType.Flamingo:
+        elif (
+            not generator_args.is_torchtune_model
+            and self.model.config.model_type != ModelType.Flamingo
+        ):
             max_seq_length = min(
                 encoded.size(0) + generator_args.max_new_tokens,
                 (
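
Note on the validation change in from_args: the removed check had a scoping bug. A generator expression's loop variable is local to the expression, so the `image_prompt` referenced in the old error message would raise NameError (unless a variable of the same name happened to exist in the enclosing scope), and even then it could name at most one path. The replacement collects every missing path and reports them together. A minimal standalone sketch of the pattern, with hypothetical paths not taken from the repo:

import os

image_prompts = ["images/cat.png", "missing.png"]  # hypothetical inputs

# Old pattern: `p` is not visible outside the generator expression,
# so building the error message itself fails with NameError.
#   if not all(os.path.exists(p) for p in image_prompts):
#       raise RuntimeError(f"Image prompt {p} does not exist")

# New pattern: collect all missing paths first, then report them at once.
missing = [p for p in image_prompts if not os.path.exists(p)]
if missing:
    raise RuntimeError(f"Image prompt {missing} does not exist")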
|
|