12 changes: 9 additions & 3 deletions torchtune/models/llama4/_model_builders.py
@@ -84,12 +84,14 @@ def llama4_scout_17b_16e(
)
return EarlyFusionModel(
decoder,
encoders={"vision": vision_encoder},
encoders={"vision": vision_encoder, "video": vision_encoder},
encoder_tokens={
"vision": LLAMA4_SPECIAL_TOKENS["<|patch|>"],
"video": LLAMA4_SPECIAL_TOKENS["<|video|>"],
},
encoders_trainable={
"vision": encoder_trainable,
"video": encoder_trainable,
},
decoder_trainable=decoder_trainable,
fusion_trainable=fusion_trainable,
@@ -154,12 +156,14 @@ def llama4_maverick_17b_128e(
)
return EarlyFusionModel(
decoder,
encoders={"vision": vision_encoder},
encoders={"vision": vision_encoder, "video": vision_encoder},
encoder_tokens={
"vision": LLAMA4_SPECIAL_TOKENS["<|patch|>"],
"video": LLAMA4_SPECIAL_TOKENS["<|video|>"],
},
encoders_trainable={
"vision": encoder_trainable,
"video": encoder_trainable,
},
decoder_trainable=decoder_trainable,
fusion_trainable=fusion_trainable,
@@ -314,12 +318,14 @@ def lora_llama4_scout_17b_16e(
)
return EarlyFusionModel(
decoder,
encoders={"vision": vision_encoder},
encoders={"vision": vision_encoder, "video": vision_encoder},
encoder_tokens={
"vision": LLAMA4_SPECIAL_TOKENS["<|patch|>"],
"video": LLAMA4_SPECIAL_TOKENS["<|video|>"],
},
encoders_trainable={
"vision": encoder_trainable != TrainableParams.FROZEN,
"video": encoder_trainable != TrainableParams.FROZEN,
},
decoder_trainable=decoder_trainable != TrainableParams.FROZEN,
fusion_trainable=fusion_trainable != TrainableParams.FROZEN,
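Note on the three builder changes above: the same `vision_encoder` module is registered under both the "vision" and "video" keys, so no weights are duplicated; `EarlyFusionModel` simply gains a second routing entry whose token (`<|video|>`) resolves to the shared encoder. A minimal sketch of why double registration is cheap (the toy encoder and its dimensions are illustrative assumptions, not torchtune's API):

import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    """Stand-in for the shared Llama4 vision encoder (illustrative only)."""

    def __init__(self, embed_dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(3, embed_dim)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        # (batch, num_tiles, channels, h, w) -> (batch, num_tiles, embed_dim)
        return self.proj(images.mean(dim=(-1, -2)))

shared = ToyEncoder()
encoders = nn.ModuleDict({"vision": shared, "video": shared})

# One set of weights behind two routing keys: Module.parameters() deduplicates
# shared submodules, so the parameter count is unchanged. This is also why both
# trainable flags in the builders are driven by the single `encoder_trainable`
# argument.
assert encoders["vision"] is encoders["video"]
assert len(list(encoders.parameters())) == len(list(shared.parameters()))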
51 changes: 44 additions & 7 deletions torchtune/models/llama4/_tokenizer.py
@@ -59,16 +59,16 @@ def get_reserved_special_tokens(start_id, end_id, name=None, start_reserved=0):
VISION_SPECIAL_TOKENS = {
"<|image_start|>": 200080,
"<|image_end|>": 200081,
"<|vision_reserved_special_token_0|>": 200082,
"<|vision_reserved_special_token_1|>": 200083,
"<|vid_start|>": 200082,
"<|vid_end|>": 200083,
"<|tile_x_separator|>": 200084,
"<|tile_y_separator|>": 200085,
"<|vision_reserved_special_token_2|>": 200086,
"<|vision_reserved_special_token_3|>": 200087,
"<|vision_reserved_special_token_4|>": 200088,
"<|vision_reserved_special_token_5|>": 200089,
"<|vision_reserved_special_token_0|>": 200086,
"<|vid_frame_separator|>": 200087,
"<|vision_reserved_special_token_1|>": 200088,
"<|vision_reserved_special_token_2|>": 200089,
"<|image|>": 200090,
"<|vision_reserved_special_token_6|>": 200091,
"<|video|>": 200091,
"<|patch|>": 200092,
} | get_reserved_special_tokens(200093, 201134, "vision", 7)

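The remapping above only repurposes previously reserved slots; every token that existed before this change keeps its ID. An illustrative check (assuming the `VISION_SPECIAL_TOKENS` dict above is in scope):

# IDs unchanged by this PR, copied from the table above
unchanged = {
    "<|image_start|>": 200080,
    "<|image_end|>": 200081,
    "<|tile_x_separator|>": 200084,
    "<|tile_y_separator|>": 200085,
    "<|image|>": 200090,
    "<|patch|>": 200092,
}
assert all(VISION_SPECIAL_TOKENS[name] == tok_id for name, tok_id in unchanged.items())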
@@ -166,6 +166,12 @@ def __init__(
self.tile_x_separator = self.special_tokens["<|tile_x_separator|>"]
self.tile_y_separator = self.special_tokens["<|tile_y_separator|>"]

# Video tokens
self.video_id = self.special_tokens["<|patch|>"]  # video frames reuse the image patch token
self.video_start = self.special_tokens["<|vid_start|>"]
self.video_end = self.special_tokens["<|vid_end|>"]
self.frame_separator = self.special_tokens["<|vid_frame_separator|>"]

# Reasoning tokens
self.reasoning_start = self.special_tokens["<|reasoning_thinking_start|>"]
self.reasoning_end = self.special_tokens["<|reasoning_thinking_end|>"]
@@ -302,6 +308,31 @@ def _get_tile_grid_tokens(
tokens.extend(single_tile_tokens)
return tokens

def _get_video_tokens(self, num_frames: int, patch_tokens_per_frame: int) -> list[int]:
"""
Tokenize video content with a frame structure similar to the Hugging Face implementation.

Args:
num_frames (int): Number of frames in the video
patch_tokens_per_frame (int): Number of patch tokens per frame

Returns:
list[int]: Video tokens with frame separators
"""
tokens = []
tokens.append(self.video_start)

for frame_idx in range(num_frames):
# Add video patch tokens for this frame
tokens.extend([self.video_id] * patch_tokens_per_frame)

# Add frame separator (except for the last frame)
if frame_idx < num_frames - 1:
tokens.append(self.frame_separator)

tokens.append(self.video_end)
return tokens

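The sequence produced by `_get_video_tokens` has length `2 + num_frames * patch_tokens_per_frame + (num_frames - 1)`: start and end delimiters, the patch tokens for every frame, and one separator between consecutive frames. A small sketch of the expected layout (assuming a tokenizer instance `tok`):

tokens = tok._get_video_tokens(num_frames=2, patch_tokens_per_frame=3)
assert tokens == [
    tok.video_start,                           # <|vid_start|>
    tok.video_id, tok.video_id, tok.video_id,  # frame 0 patches
    tok.frame_separator,                       # <|vid_frame_separator|>
    tok.video_id, tok.video_id, tok.video_id,  # frame 1 patches
    tok.video_end,                             # <|vid_end|>
]
assert len(tokens) == 2 + 2 * 3 + (2 - 1)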
def _tokenize_header(self, message: Message) -> list[int]:
"""
Tokenize header start, message role, and header end as list of ids
Expand Down Expand Up @@ -335,6 +366,12 @@ def _tokenize_body(self, message: Message) -> list[int]:
tokenized_body += self._get_tile_grid_tokens(
patch_tokens_per_tile, aspect_ratio
)
elif item["type"] == "video":
num_frames = item.get("num_frames", 1)
patch_tokens_per_frame = item.get("patch_tokens_per_frame", 1)
tokenized_body += self._get_video_tokens(
num_frames, patch_tokens_per_frame
)

return tokenized_body

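For reference, the `item` handled by the new "video" branch is a content dict that the transform (next file) annotates with frame metadata before tokenization; a hypothetical example with made-up values:

item = {
    "type": "video",
    "content": ...,  # raw clips; consumed by the transform, not the tokenizer
    "num_frames": 8,  # attached by the transform's __call__
    "patch_tokens_per_frame": 144,  # set to the transform's patch_tokens_per_tile
}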
120 changes: 115 additions & 5 deletions torchtune/models/llama4/_transform.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

- from typing import Any, Mapping, Optional
+ from typing import Any, List, Mapping, Optional

import torch
from PIL import Image
@@ -154,6 +154,88 @@ def transform_image(
tiles = torch.cat((tiles, thumbnail.unsqueeze(0)), dim=0)
return tiles, ar

def transform_video(
self, video: List[torch.Tensor], inference: bool = False
) -> tuple[torch.Tensor, int]:
"""
Transform video content for video processing.

Args:
video: List of 4D torch.Tensors, each with shape [frames_per_clip, height, width, channels].
Length of list should equal clips_per_video.
inference: Whether running in inference mode

Returns:
tuple: (processed_frames, total_num_frames)
"""
if not isinstance(video, list):
raise ValueError(f"Expected list of 4D tensors, got {type(video)}")

if not video:
raise ValueError("Empty video clips list")

# Validate input format
for i, clip in enumerate(video):
if not isinstance(clip, torch.Tensor):
raise ValueError(f"Clip {i} is not a torch.Tensor, got {type(clip)}")
if clip.dim() != 4:
raise ValueError(f"Clip {i} has {clip.dim()} dimensions, expected 4D tensor [frames_per_clip, height, width, channels]")

processed_clips = []
total_frames = 0

# Process each clip
for clip in video:
frames_per_clip, height, width, channels = clip.shape
total_frames += frames_per_clip

# Normalize tensor values to [0, 1] range if needed
if clip.dtype == torch.uint8:
clip = clip.float() / 255.0
elif clip.min() < 0 or clip.max() > 1:
    # Values fall outside [0, 1] (e.g. [-1, 1]); min-max normalize,
    # guarding against division by zero for constant-valued clips
    denom = (clip.max() - clip.min()).clamp(min=1e-6)
    clip = (clip - clip.min()) / denom

processed_frames_in_clip = []

# Process each frame in the clip
for frame_idx in range(frames_per_clip):
frame = clip[frame_idx] # Shape: [height, width, channels]

# Convert the HWC frame to a PIL Image so it can reuse the existing
# image transform: scale to uint8 and move to numpy, channels last
frame_uint8 = (frame * 255).to(torch.uint8).cpu().numpy()
if channels == 3:  # RGB
    frame_pil = Image.fromarray(frame_uint8, mode="RGB")
elif channels == 1:  # grayscale
    frame_pil = Image.fromarray(frame_uint8.squeeze(-1), mode="L")
else:
    raise ValueError(f"Unsupported number of channels: {channels}")

# Process the PIL image through existing transform
tiles, _ = self.transform_image(frame_pil, inference=inference)
processed_frames_in_clip.append(tiles)

# Stack frames in this clip: (frames_per_clip, num_tiles, channels, height, width)
clip_processed = torch.stack(processed_frames_in_clip, dim=0)
processed_clips.append(clip_processed)

# Stack all clips: (clips_per_video, frames_per_clip, num_tiles, channels, height, width)
# Then reshape to: (total_frames, num_tiles, channels, height, width)
all_clips = torch.stack(processed_clips, dim=0)
clips_per_video, frames_per_clip, num_tiles, channels, height, width = all_clips.shape

# Reshape to flatten clips and frames into a single frames dimension
processed_frames = all_clips.view(clips_per_video * frames_per_clip, num_tiles, channels, height, width)

return processed_frames, total_frames

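A hedged usage sketch for `transform_video`: only the list-of-4D-clips contract comes from the code above, while the transform instance `transform`, the frame counts, and the resolution are assumptions.

import torch

# Two clips of four 224x224 RGB frames each, channels last, uint8
clips = [
    torch.randint(0, 256, (4, 224, 224, 3), dtype=torch.uint8)
    for _ in range(2)
]
frames, total_frames = transform.transform_video(clips)
# frames: (8, num_tiles, 3, tile_h, tile_w); total_frames == 8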
def encode(
self,
text: str,
Expand Down Expand Up @@ -235,7 +317,7 @@ def __call__(
self, sample: Mapping[str, Any], inference: bool = False
) -> Mapping[str, Any]:
"""
- Apply image decoding, transformations and tokenization to messages in the sample.
+ Apply image/video decoding, transformations, and tokenization to messages in the sample.

Args:
sample (Mapping[str, Any]): A sample with a "messages" field.
@@ -245,21 +327,49 @@
Mapping[str, Any]: The transformed sample with the following fields:
- tokens: list[int] of tokenized messages
- mask: list[bool] of masks for the tokenized messages
- - encoder_input: dict[str, Any] of transformed images
+ - encoder_input: dict[str, Any] of transformed images and videos
"""
encoder_input = {"vision": {"images": []}}
images_list = []
videos_list = []

messages = sample["messages"]
for message in messages:
for content in message.content:
if content["type"] == "image":
image = content["content"]
tiles, ar = self.transform_image(image, inference=inference)
encoder_input["vision"]["images"].append(tiles)
images_list.append(tiles)

# Add number of patch tokens, tiles, and aspect ratio to metadata
# so tokenizer can add the corresponding special tokens
content["patch_tokens_per_tile"] = self.patch_tokens_per_tile
content["aspect_ratio"] = ar
elif content["type"] == "video":
video = content["content"]
processed_frames, num_frames = self.transform_video(video, inference=inference)

# Flatten video frames to individual images for the vision encoder
# processed_frames shape: (num_frames, num_tiles, channels, height, width)
# We need to flatten to: (num_frames * num_tiles, channels, height, width)
flattened_frames = processed_frames.view(-1, *processed_frames.shape[2:])
videos_list.append(flattened_frames)

# Add metadata for video tokenization
# Each frame is treated like an image for patch token calculation
content["num_frames"] = num_frames
content["patch_tokens_per_frame"] = self.patch_tokens_per_tile

# Create encoder_input in the format expected by EarlyFusionModel.
# Images and video frames share the same vision encoder and data structure;
# a video is distinguished at the token level by its <|vid_start|>/<|vid_end|>
# and frame-separator tokens, not at the data level.
encoder_input = {}
if images_list:
    encoder_input["vision"] = {"images": images_list}

if videos_list:
    # Videos use the "images" key because they are processed as sequences of
    # images by the same vision encoder; extend rather than overwrite so a
    # sample containing both images and videos keeps both
    encoder_input.setdefault("vision", {"images": []})["images"].extend(videos_list)

sample["encoder_input"] = encoder_input
sample = self.tokenizer(sample, inference=inference)
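Putting it together, a hedged end-to-end sketch of a video sample flowing through the transform: the `Message` structure follows torchtune's multimodal message convention, while the "video" content type and all field values here are illustrative.

from torchtune.data import Message

sample = {
    "messages": [
        Message(
            role="user",
            content=[
                {"type": "video", "content": clips},  # clips as in the earlier sketch
                {"type": "text", "content": "Describe what happens in the video."},
            ],
        ),
    ],
}
out = transform(sample)
# out["tokens"] contains <|vid_start|> ... <|patch|> ... <|vid_end|> ids and
# out["encoder_input"]["vision"]["images"] holds the flattened frame tiles.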