update scripts to support latest version of PR
leloykun committed Jul 21, 2024
1 parent 1b55363 commit bd2b056
Showing 7 changed files with 37 additions and 39 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,4 +1,5 @@
 .vscode
 **/__pycache__/
+dist/
 mmsg/_version.py
 mmsg.egg-info
10 changes: 8 additions & 2 deletions mmsg/integrations/chameleon_utils.py
@@ -2,6 +2,7 @@
 import uuid
 from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, TypedDict, Union

+import torch
 from typing_extensions import NotRequired

 from ..utils import pil_to_base64
@@ -84,12 +85,17 @@ def build_response_from_segments(
         token_ids for modality, token_ids in segments if modality == "text"
     ]
     image_tokens_list = [
-        token_ids for modality, token_ids in segments if modality == "image"
+        token_ids[:1024]
+        if len(token_ids) > 1024
+        else [1] * (1024 - len(token_ids)) + token_ids
+        for modality, token_ids in segments
+        if modality == "image"
     ]

     text_str_list = processor.batch_decode(text_tokens_list, skip_special_tokens=True)

-    pixel_values = model.decode_image_tokens(image_tokens_list)
+    image_tokens_tensor = torch.tensor(image_tokens_list, device=model.device)
+    pixel_values = model.decode_image_tokens(image_tokens_tensor)
     images = processor.postprocess_pixel_values(
         pixel_values.float().detach().cpu().numpy()
     )
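Two related fixes here: `model.decode_image_tokens` now receives a `torch.Tensor` on the model's device rather than a Python list, and building that tensor requires every image segment to have the same length. The new comprehension truncates each segment to 1024 tokens or left-pads it with token id 1 (both literals taken from the diff above). A sketch of the rule as a hypothetical standalone helper:

    # Hypothetical helper mirroring the comprehension above; the constants
    # restate the literals 1024 and 1 from the diff, not a documented API.
    IMAGE_SEQ_LEN = 1024
    PAD_TOKEN_ID = 1

    def normalize_image_tokens(token_ids: list[int]) -> list[int]:
        """Truncate over-long segments; left-pad short ones to IMAGE_SEQ_LEN."""
        if len(token_ids) > IMAGE_SEQ_LEN:
            return token_ids[:IMAGE_SEQ_LEN]
        return [PAD_TOKEN_ID] * (IMAGE_SEQ_LEN - len(token_ids)) + token_ids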
4 changes: 2 additions & 2 deletions scripts/image_only_generation.py
@@ -37,8 +37,8 @@ def run_image_only_generation(
     from mmsg.integrations.chameleon_utils import postprocess_token_sequence
     from mmsg.utils import load_image

-    if seed:
-        set_seed(42)
+    if seed is not None:
+        set_seed(seed)
     torch.set_printoptions(threshold=10_000)

     if fast:
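A note on the seed fix (the same change recurs in interleaved_generation.py, structured_generation.py, and text_only_generation.py below): the old guard had two bugs. `if seed:` treats an explicit `seed=0` as falsy and silently skips seeding, and the body passed a hard-coded constant instead of the argument. A minimal illustration:

    from transformers import set_seed

    seed = 0                  # a legitimate, explicit seed

    if seed:                  # bug: 0 is falsy, so seeding is skipped...
        set_seed(42)          # ...and the constant ignored `seed` anyway

    if seed is not None:      # fix: honors any explicit seed, including 0
        set_seed(seed)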
29 changes: 3 additions & 26 deletions scripts/interleaved_generation.py
@@ -29,16 +29,11 @@ def run_interleaved_generation(
         ChameleonProcessor,
         set_seed,
     )
-    from transformers.generation.logits_process import LogitsProcessorList

-    from mmsg.integrations.chameleon_logits_processor import (
-        ChameleonFSMLogitsProcessor,
-        ChameleonModalityFSMGuide,
-    )
     from mmsg.integrations.chameleon_utils import postprocess_token_sequence

-    if seed:
-        set_seed(42)
+    if seed is not None:
+        set_seed(seed)
     torch.set_printoptions(threshold=10_000)

     if fast:
@@ -76,30 +71,12 @@ def run_interleaved_generation(
     else:
         raise ValueError(f"Invalid inference_id: {inference_mode}")

-    max_length = max_new_tokens + inputs["input_ids"].shape[-1]
-
-    logits_processor = LogitsProcessorList([
-        ChameleonFSMLogitsProcessor(
-            fsm=ChameleonModalityFSMGuide(
-                all_token_ids=model.vocabulary_mapping.vocab_map.values(),
-                image_token_ids=model.vocabulary_mapping.image_token_ids,
-                eos_token_id=model.config.eos_token_id,
-                boi_token_id=model.vocabulary_mapping.boi_token_id,
-                eoi_token_id=model.vocabulary_mapping.eoi_token_id,
-                device=model.device,
-                multimodal_generation_mode="interleaved-text-image",
-            ),
-            max_length=max_length,
-        )
-    ])
-
     logger.info("Generating response...")
     with torch.inference_mode():
         output_token_ids_batch = model.generate(
             **inputs,
-            multimodal_generation_mode="free",
+            multimodal_generation_mode="interleaved-text-image",
             max_new_tokens=max_new_tokens,
-            logits_processor=logits_processor,
             do_sample=True,
         )
     logger.info("Finished generation.")
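With the latest version of the PR, passing `multimodal_generation_mode="interleaved-text-image"` makes `model.generate` enforce the modality constraint itself, so the script no longer wires up a `ChameleonFSMLogitsProcessor` by hand (previously it passed mode "free" plus the manual processor). For intuition, a rough sketch of the rule the deleted FSM encoded, inferred from its constructor arguments; every id below is an illustrative placeholder, not the model's real vocabulary:

    # Pseudologic only; the real guide lives in
    # mmsg.integrations.chameleon_logits_processor and is more involved.
    BOI_TOKEN_ID = 8197                     # begin-of-image (placeholder id)
    EOI_TOKEN_ID = 8196                     # end-of-image (placeholder id)
    EOS_TOKEN_ID = 2                        # end-of-sequence (placeholder id)
    IMAGE_TOKEN_IDS = set(range(10, 8196))  # image-token range (placeholder)

    def allowed_next_tokens(generated: list[int], vocab: set[int]) -> set[int]:
        """Tokens the modality FSM permits at the next step."""
        in_image_span = generated.count(BOI_TOKEN_ID) > generated.count(EOI_TOKEN_ID)
        if in_image_span:
            # Inside <boi>...<eoi>: only image tokens, or <eoi> to close.
            return IMAGE_TOKEN_IDS | {EOI_TOKEN_ID}
        # Outside a span: any non-image token, <boi> to open an image, or <eos>.
        return (vocab - IMAGE_TOKEN_IDS) | {BOI_TOKEN_ID, EOS_TOKEN_ID}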
22 changes: 18 additions & 4 deletions scripts/modal_convert_chameleon_weights.py
@@ -1,9 +1,17 @@
+import logging
 import os
 from typing import Optional

 import modal
 from modal_commons import GPU_CONFIG, MODEL_DIR, VOLUME_CONFIG, app

+logging.basicConfig(
+    format="%(asctime)s %(levelname)-8s %(message)s",
+    level=logging.INFO,
+    datefmt="%Y-%m-%d %H:%M:%S",
+)
+logger = logging.getLogger()
+

 @app.function(
     cpu=4.0,
@@ -38,7 +46,7 @@ def run_convert_chameleon_weights(
     continue_download = True
     if os.path.exists(local_input_model_path):
         if force_download:
-            print("Removing original model...")
+            logger.info("Removing original model...")
             subprocess.run(
                 [
                     "rm",
@@ -49,10 +57,11 @@ def run_convert_chameleon_weights(
             )
         else:
             continue_download = False
-            print(
+            logger.info(
                 f"Original model already downloaded at {local_input_model_path}. Skipping redownload."
             )
     if continue_download:
+        logger.info("Downloading original model...")
         subprocess.run(
             [
                 "huggingface-cli",
@@ -66,11 +75,12 @@ def run_convert_chameleon_weights(
             ],
             check=True,
         )
+    logger.info("Downloaded original model.")

     continue_conversion = True
     if os.path.exists(local_output_model_path):
         if force_conversion:
-            print("Removing converted model...")
+            logger.info("Removing converted model...")
             subprocess.run(
                 [
                     "rm",
@@ -81,19 +91,22 @@ def run_convert_chameleon_weights(
             )
         else:
             continue_conversion = False
-            print(f"Converted model already exists at {local_output_model_path}. Skipping.")
+            logger.info(f"Converted model already exists at {local_output_model_path}. Skipping.")
     if continue_conversion:
         if model_size not in NUM_SHARDS:
             raise ValueError(
                 f"Model size {model_size} not supported. Choose from {NUM_SHARDS.keys()}"
             )
+        logger.info("Converting model to Transformers-compatible version...")
         write_model(
             model_path=local_output_model_path,
             input_base_path=local_input_model_path,
             model_size=model_size,
         )
+        logger.info("Finished converting model.")

     if upload_to_hf:
+        logger.info("Uploading converted model to HF...")
         subprocess.run(
             [
                 "huggingface-cli",
@@ -106,6 +119,7 @@ def run_convert_chameleon_weights(
             ],
             check=True,
         )
+    logger.info("Finished uploading converted model to HF.")


 @app.local_entrypoint()
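Beyond replacing `print` with `logger.info`, the new `basicConfig` call means every status line carries a timestamp and level, which helps when tailing long downloads and conversions. A small self-contained sketch of the configured output (the printed timestamp is illustrative):

    import logging

    logging.basicConfig(
        format="%(asctime)s %(levelname)-8s %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    logger = logging.getLogger()

    logger.info("Downloading original model...")
    # prints e.g.: 2024-07-21 08:30:00 INFO     Downloading original model...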
6 changes: 3 additions & 3 deletions scripts/structured_generation.py
@@ -38,8 +38,8 @@ def run_structured_generation(
     from mmsg.integrations.chameleon_utils import postprocess_token_sequence
     from mmsg.integrations.multimodal_tokenizer import MultimodalTokenizer

-    if seed:
-        set_seed(42)
+    if seed is not None:
+        set_seed(seed)
     torch.set_printoptions(threshold=10_000)

     if fast:
@@ -155,7 +155,7 @@ def run_structured_generation(
     with torch.inference_mode():
         output_token_ids_batch = model.generate(
             **inputs,
-            multimodal_generation_mode="free",
+            multimodal_generation_mode="unrestricted",
             logits_processor=logits_processor,
             max_new_tokens=max_new_tokens,
             do_sample=True,
4 changes: 2 additions & 2 deletions scripts/text_only_generation.py
@@ -33,8 +33,8 @@ def run_text_only_generation(

     from mmsg.utils import load_image

-    if seed:
-        set_seed(0)
+    if seed is not None:
+        set_seed(seed)
     torch.set_printoptions(threshold=10_000)

     if fast:
