diff --git a/mmsg/integrations/chameleon_utils.py b/mmsg/integrations/chameleon_utils.py
index cb2e80e..c945193 100644
--- a/mmsg/integrations/chameleon_utils.py
+++ b/mmsg/integrations/chameleon_utils.py
@@ -85,9 +85,9 @@ def build_response_from_segments(
         token_ids for modality, token_ids in segments if modality == "text"
     ]
     image_tokens_list = [
-        token_ids[:1024]
-        if len(token_ids) > 1024
-        else [1] * (1024 - len(token_ids)) + token_ids
+        token_ids[: processor.image_seq_length]
+        if len(token_ids) > processor.image_seq_length
+        else [1] * (processor.image_seq_length - len(token_ids)) + token_ids
         for modality, token_ids in segments
         if modality == "image"
     ]
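Note on the hunk above: the hard-coded 1024 is replaced by processor.image_seq_length, so each image-token run is truncated or left-padded to whatever length the processor expects. A minimal sketch of that pad-or-truncate rule, using a hypothetical helper name; the pad id 1 is taken directly from the diff:

    def fit_image_tokens(token_ids: list[int], image_seq_length: int) -> list[int]:
        if len(token_ids) > image_seq_length:
            # Too long: keep only the first image_seq_length ids.
            return token_ids[:image_seq_length]
        # Too short: left-pad with token id 1 up to the target length.
        return [1] * (image_seq_length - len(token_ids)) + token_ids

    fit_image_tokens([7, 8, 9], 5)  # -> [1, 1, 7, 8, 9]
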
logger.info(f"Prompt: {prompt}") - inputs = processor(prompt, return_tensors="pt").to( - model.device, dtype=model.dtype - ) + inputs = processor( + prompt, + padding=True, + return_tensors="pt", + return_for_text_completion=True, + ).to(model.device, dtype=model.dtype) else: raise ValueError(f"Invalid inference_id: {inference_mode}") diff --git a/scripts/modal_inference.py b/scripts/modal_inference.py index 5fb8679..26d17e8 100644 --- a/scripts/modal_inference.py +++ b/scripts/modal_inference.py @@ -31,7 +31,7 @@ def run_inference( regex_pattern: Optional[str] = None, model_cache_dir: str = MODEL_DIR, outputs_dir: str = GENERATED_IMAGES_DIR, - fast: bool = False, + fast: bool = True, seed: Optional[int] = None, ) -> Optional[str]: if inference_mode in ["text-to-text", "text-image-to-text", "multi-image-to-text"]: @@ -130,7 +130,7 @@ def main( regex_pattern: Optional[str] = None, model_cache_dir: str = MODEL_DIR, outputs_dir: str = GENERATED_IMAGES_DIR, - fast: bool = False, + fast: bool = True, seed: Optional[int] = None, local: bool = False, ): diff --git a/scripts/modal_requirements.txt b/scripts/modal_requirements.txt index ad4d9f6..2d6bbd4 100644 --- a/scripts/modal_requirements.txt +++ b/scripts/modal_requirements.txt @@ -1,4 +1,5 @@ accelerate==0.31.0 +bitsandbytes==0.43.1 deepspeed==0.14.4 jsonlines==4.0.0 mpi4py_mpich==3.1.5 diff --git a/scripts/structured_generation.py b/scripts/structured_generation.py index 6a21463..f9075ae 100644 --- a/scripts/structured_generation.py +++ b/scripts/structured_generation.py @@ -18,7 +18,7 @@ def run_structured_generation( json_schema_path: Optional[str] = None, regex_pattern: Optional[str] = None, max_new_tokens: int = 2400, - fast: bool = False, + fast: bool = True, model_cache_dir: str = "/pretrained", outputs_dir: str = ".", seed: Optional[int] = None, @@ -123,16 +123,13 @@ def run_structured_generation( if json_schema is not None: json_schema_str = json.dumps(json_schema) + logger.info(f"JSON schema: {json_schema_str}") regex_pattern = build_regex_from_schema(json_schema_str) logger.info(f"Built regex pattern from json schema: {regex_pattern}") prompt = f"{prompt} Please follow this schema: {json.dumps(json_schema)}" - else: - prompt = f"{prompt} Please follow this regex pattern: {regex_pattern}" logger.info(f"Prompt: {prompt}") - images = None - logger.info("Building regex guide...") regex_guide = RegexWithMultimodalMarkersGuide( regex_pattern, @@ -149,7 +146,7 @@ def run_structured_generation( [FSMLogitsProcessor(mmsg_tokenizer, regex_guide)] ) - inputs = processor(prompt, images=images, return_tensors="pt").to(model.device) + inputs = processor(prompt, padding=True, return_tensors="pt").to(model.device) logger.info("Starting generation...") with torch.inference_mode(): diff --git a/scripts/text_only_generation.py b/scripts/text_only_generation.py index b599958..8b72728 100644 --- a/scripts/text_only_generation.py +++ b/scripts/text_only_generation.py @@ -20,7 +20,7 @@ def run_text_only_generation( image_1_path: Optional[str] = None, image_2_path: Optional[str] = None, max_new_tokens: int = 40, - fast: bool = False, + fast: bool = True, model_cache_dir: str = "/pretrained", seed: Optional[int] = None, ) -> str: @@ -67,9 +67,12 @@ def run_text_only_generation( prompt = "Is a banana a fruit or a vegetable? Please answer with yes or no." 
logger.info(f"Prompt: {prompt}") - inputs = processor(prompt, return_tensors="pt").to( - model.device, dtype=model.dtype - ) + inputs = processor( + prompt, + padding=True, + return_tensors="pt", + return_for_text_completion=True, + ).to(model.device, dtype=model.dtype) elif inference_mode == "text-image-to-text": logger.info("TASK: Text-Image to Text generation") @@ -83,9 +86,13 @@ def run_text_only_generation( image = load_image(image_1_path) logger.info("Image 1 loaded.", image_1_path) - inputs = processor(prompt, image, return_tensors="pt").to( - model.device, dtype=model.dtype - ) + inputs = processor( + prompt, + image, + padding=True, + return_tensors="pt", + return_for_text_completion=True, + ).to(model.device, dtype=model.dtype) elif inference_mode == "multi-image-to-text": logger.info("TASK: Multi-Image generation") @@ -108,6 +115,7 @@ def run_text_only_generation( images=images, padding=True, return_tensors="pt", + return_for_text_completion=True, ).to(model.device, dtype=model.dtype) else: raise ValueError(f"Invalid inference_mode: {inference_mode}")