Skip to content

Commit

Permalink
fix build_response_from_segments
Browse files: browse the repository at this point in the history
  • Loading branch information
leloykun committed Aug 6, 2024
1 parent bd2b056 commit def24fd
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 26 deletions.
6 changes: 3 additions & 3 deletions mmsg/integrations/chameleon_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ def build_response_from_segments(
token_ids for modality, token_ids in segments if modality == "text"
]
image_tokens_list = [
token_ids[:1024]
if len(token_ids) > 1024
else [1] * (1024 - len(token_ids)) + token_ids
token_ids[: processor.image_seq_length]
if len(token_ids) > processor.image_seq_length
else [1] * (processor.image_seq_length - len(token_ids)) + token_ids
for modality, token_ids in segments
if modality == "image"
]
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ dynamic = ["version"]

[project.optional-dependencies]
test = [
"bitsandbytes",
"modal",
"numpy",
"term-image",
Expand Down
10 changes: 6 additions & 4 deletions scripts/image_only_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def run_image_only_generation(
image_1_path: Optional[str] = None,
image_2_path: Optional[str] = None,
max_new_tokens: int = 2400,
fast: bool = False,
fast: bool = True,
model_cache_dir: Optional[str] = None,
outputs_dir: str = ".",
seed: Optional[int] = None,
Expand Down Expand Up @@ -70,9 +70,11 @@ def run_image_only_generation(
prompt = "Please draw an apple!"
logger.info(f"Prompt: {prompt}")

inputs = processor(prompt, return_tensors="pt").to(
model.device, dtype=model.dtype
)
inputs = processor(
prompt,
padding=True,
return_tensors="pt",
).to(model.device, dtype=model.dtype)
elif inference_mode == "text-image-to-image":
logger.info("TASK: Text-Image to Image generation")

Expand Down
11 changes: 7 additions & 4 deletions scripts/interleaved_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def run_interleaved_generation(
inference_mode: Literal["text-to-interleaved-text-image"] = "text-to-interleaved-text-image",
prompt: Optional[str] = None,
max_new_tokens: int = 2400,
fast: bool = False,
fast: bool = True,
model_cache_dir: Optional[str] = None,
outputs_dir: str = ".",
seed: Optional[int] = None,
Expand Down Expand Up @@ -65,9 +65,12 @@ def run_interleaved_generation(
prompt = "Please draw an apple!"
logger.info(f"Prompt: {prompt}")

inputs = processor(prompt, return_tensors="pt").to(
model.device, dtype=model.dtype
)
inputs = processor(
prompt,
padding=True,
return_tensors="pt",
return_for_text_completion=True,
).to(model.device, dtype=model.dtype)
else:
raise ValueError(f"Invalid inference_id: {inference_mode}")

Expand Down
4 changes: 2 additions & 2 deletions scripts/modal_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def run_inference(
regex_pattern: Optional[str] = None,
model_cache_dir: str = MODEL_DIR,
outputs_dir: str = GENERATED_IMAGES_DIR,
fast: bool = False,
fast: bool = True,
seed: Optional[int] = None,
) -> Optional[str]:
if inference_mode in ["text-to-text", "text-image-to-text", "multi-image-to-text"]:
Expand Down Expand Up @@ -130,7 +130,7 @@ def main(
regex_pattern: Optional[str] = None,
model_cache_dir: str = MODEL_DIR,
outputs_dir: str = GENERATED_IMAGES_DIR,
fast: bool = False,
fast: bool = True,
seed: Optional[int] = None,
local: bool = False,
):
Expand Down
1 change: 1 addition & 0 deletions scripts/modal_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
accelerate==0.31.0
bitsandbytes==0.43.1
deepspeed==0.14.4
jsonlines==4.0.0
mpi4py_mpich==3.1.5
Expand Down
9 changes: 3 additions & 6 deletions scripts/structured_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def run_structured_generation(
json_schema_path: Optional[str] = None,
regex_pattern: Optional[str] = None,
max_new_tokens: int = 2400,
fast: bool = False,
fast: bool = True,
model_cache_dir: str = "/pretrained",
outputs_dir: str = ".",
seed: Optional[int] = None,
Expand Down Expand Up @@ -123,16 +123,13 @@ def run_structured_generation(

if json_schema is not None:
json_schema_str = json.dumps(json_schema)
logger.info(f"JSON schema: {json_schema_str}")
regex_pattern = build_regex_from_schema(json_schema_str)
logger.info(f"Built regex pattern from json schema: {regex_pattern}")

prompt = f"{prompt} Please follow this schema: {json.dumps(json_schema)}"
else:
prompt = f"{prompt} Please follow this regex pattern: {regex_pattern}"
logger.info(f"Prompt: {prompt}")

images = None

logger.info("Building regex guide...")
regex_guide = RegexWithMultimodalMarkersGuide(
regex_pattern,
Expand All @@ -149,7 +146,7 @@ def run_structured_generation(
[FSMLogitsProcessor(mmsg_tokenizer, regex_guide)]
)

inputs = processor(prompt, images=images, return_tensors="pt").to(model.device)
inputs = processor(prompt, padding=True, return_tensors="pt").to(model.device)

logger.info("Starting generation...")
with torch.inference_mode():
Expand Down
22 changes: 15 additions & 7 deletions scripts/text_only_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def run_text_only_generation(
image_1_path: Optional[str] = None,
image_2_path: Optional[str] = None,
max_new_tokens: int = 40,
fast: bool = False,
fast: bool = True,
model_cache_dir: str = "/pretrained",
seed: Optional[int] = None,
) -> str:
Expand Down Expand Up @@ -67,9 +67,12 @@ def run_text_only_generation(
prompt = "Is a banana a fruit or a vegetable? Please answer with yes or no."
logger.info(f"Prompt: {prompt}")

inputs = processor(prompt, return_tensors="pt").to(
model.device, dtype=model.dtype
)
inputs = processor(
prompt,
padding=True,
return_tensors="pt",
return_for_text_completion=True,
).to(model.device, dtype=model.dtype)
elif inference_mode == "text-image-to-text":
logger.info("TASK: Text-Image to Text generation")

Expand All @@ -83,9 +86,13 @@ def run_text_only_generation(
image = load_image(image_1_path)
logger.info("Image 1 loaded.", image_1_path)

inputs = processor(prompt, image, return_tensors="pt").to(
model.device, dtype=model.dtype
)
inputs = processor(
prompt,
image,
padding=True,
return_tensors="pt",
return_for_text_completion=True,
).to(model.device, dtype=model.dtype)
elif inference_mode == "multi-image-to-text":
logger.info("TASK: Multi-Image generation")

Expand All @@ -108,6 +115,7 @@ def run_text_only_generation(
images=images,
padding=True,
return_tensors="pt",
return_for_text_completion=True,
).to(model.device, dtype=model.dtype)
else:
raise ValueError(f"Invalid inference_mode: {inference_mode}")
Expand Down

0 comments on commit def24fd

Please sign in to comment.