@SpenserCai SpenserCai commented Dec 16, 2025

Summary

This PR adds support for the Mistral3 (Mistral-Small-3.x) vision-language model to candle-transformers. Mistral3 combines the Pixtral vision encoder with the Mistral language model, enabling multimodal image-text understanding.

Note: This PR is a preparatory step for the upcoming Flux2 model migration, as Flux2 shares similar multimodal architecture patterns with Mistral3.

Changes

New files in candle-transformers/src/models/mistral3/:

  • mod.rs - Module exports and documentation
  • config.rs - Mistral3Config with vision, text, and projector settings
  • model.rs - Mistral3Model and Mistral3ForConditionalGeneration
  • patch_merger.rs - PatchMerger for reducing image tokens
  • projector.rs - MultiModalProjector (RMSNorm + PatchMerger + MLP)

Modified files:

  • candle-transformers/src/models/mod.rs - Added mistral3 module export
  • candle-transformers/src/models/pixtral/vision_model.rs - Added forward_with_hidden_states() and VisionModelOutput struct
  • candle-transformers/src/models/mistral.rs - Added forward_embeds_hidden() for multimodal integration

Architecture

```
Mistral3ForConditionalGeneration
├── Mistral3Model
│   ├── vision_tower (Pixtral Vision Model, 24 layers)
│   ├── multi_modal_projector
│   │   ├── norm (RMSNorm)
│   │   ├── patch_merger (spatial_merge_size=2, reduces tokens by 4x)
│   │   ├── linear_1
│   │   ├── act (GELU)
│   │   └── linear_2
│   └── language_model (Mistral, 40 layers)
└── lm_head
```
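The projector's per-token data flow (RMSNorm → linear_1 → GELU → linear_2) can be sketched in plain Rust. This is an illustrative single-vector version, not the PR's code: the real `MultiModalProjector` operates on candle Tensors and interposes the PatchMerger between the norm and `linear_1`; the no-bias linears and the tanh-approximate GELU here are assumptions for the sketch.

```rust
// Sketch of the MultiModalProjector data flow from the tree above:
// RMSNorm -> linear_1 -> GELU -> linear_2, on a single patch embedding.

fn rms_norm(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    let ms = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv = 1.0 / (ms + eps).sqrt();
    x.iter().zip(weight).map(|(v, w)| v * inv * w).collect()
}

fn gelu(x: f32) -> f32 {
    // tanh approximation of GELU
    0.5 * x * (1.0 + ((2.0 / std::f32::consts::PI).sqrt() * (x + 0.044715 * x.powi(3))).tanh())
}

fn linear(x: &[f32], w: &[Vec<f32>]) -> Vec<f32> {
    // Rows of `w` are output neurons; no bias term (assumed for this sketch).
    w.iter()
        .map(|row| row.iter().zip(x).map(|(a, b)| a * b).sum())
        .collect()
}

fn project(patch: &[f32], norm_w: &[f32], w1: &[Vec<f32>], w2: &[Vec<f32>]) -> Vec<f32> {
    let h = rms_norm(patch, norm_w, 1e-5);
    let h: Vec<f32> = linear(&h, w1).into_iter().map(gelu).collect();
    linear(&h, w2)
}

fn main() {
    // Tiny hypothetical weights: 2-dim patch projected to a 1-dim output.
    let patch = [1.0f32, -2.0];
    let norm_w = [1.0f32, 1.0];
    let w1 = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let w2 = vec![vec![1.0, 1.0]];
    let out = project(&patch, &norm_w, &w1, &w2);
    assert_eq!(out.len(), 1);
    println!("{:?}", out);
}
```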

Key Implementation Details

  1. PatchMerger: Uses reshape + permute to implement PyTorch's unfold operation (kernel_size == stride, no overlap), merging 2x2 patches into one.

  2. Image Token Replacement: Implements replace_image_tokens() as Candle equivalent of PyTorch's masked_scatter.

  3. Vision Tower Integration: Uses forward_with_hidden_states() to obtain output with the batch dimension preserved, matching the behavior of PyTorch Transformers.
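The unfold-as-reshape trick in point 1 can be illustrated with plain-Rust index arithmetic (the real code uses candle `reshape`/`permute` on Tensors, but the element mapping is the same). The function name and row-major layout are assumptions of this sketch:

```rust
// Sketch of PyTorch's `unfold` with kernel_size == stride == 2 (no
// overlap), as used by the PatchMerger: merge non-overlapping 2x2
// blocks of an (h, w) grid of d-dim patch embeddings into h/2 * w/2
// merged tokens of dimension 4*d, giving the 4x token reduction.

fn merge_2x2(grid: &[f32], h: usize, w: usize, d: usize) -> Vec<f32> {
    assert!(h % 2 == 0 && w % 2 == 0);
    let mut out = Vec::with_capacity(h / 2 * w / 2 * 4 * d);
    for by in 0..h / 2 {
        for bx in 0..w / 2 {
            // Concatenate the 4 patches of this block along the feature dim.
            for (dy, dx) in [(0, 0), (0, 1), (1, 0), (1, 1)] {
                let (y, x) = (2 * by + dy, 2 * bx + dx);
                let start = (y * w + x) * d;
                out.extend_from_slice(&grid[start..start + d]);
            }
        }
    }
    out
}

fn main() {
    // 4x4 grid of 1-dim "embeddings", each holding its own flat index.
    let grid: Vec<f32> = (0..16).map(|i| i as f32).collect();
    let merged = merge_2x2(&grid, 4, 4, 1);
    // 16 patch tokens -> 4 merged tokens (4x reduction), each of dim 4.
    assert_eq!(merged.len(), 16);
    // First merged token = top-left 2x2 block: flat indices 0, 1, 4, 5.
    assert_eq!(&merged[..4], &[0.0, 1.0, 4.0, 5.0]);
}
```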

Supported Models

Differences from Pixtral LLaVA

| Feature               | Pixtral LLaVA | Mistral3                   |
|-----------------------|---------------|----------------------------|
| PatchMerger           | ❌            | ✅ (spatial_merge_size=2)  |
| Projector RMSNorm     | ❌            | ✅                         |
| Projector bias        | ✅            | ❌                         |
| Image token reduction | 1x            | 4x                         |

Usage

```rust
use candle_transformers::models::mistral3::{Mistral3Config, Mistral3ForConditionalGeneration};

let config: Mistral3Config = serde_json::from_str(&config_str)?;
let model = Mistral3ForConditionalGeneration::new(&config, vb)?;
let logits = model.forward(&input_ids, Some(&pixel_values), Some(&image_sizes), 0)?;
```

Verification

The implementation has been verified against the PyTorch Transformers reference implementation:

  • Vision Tower: avg_diff = 2.29e-4
  • MultiModal Projector: avg_diff = 3.61e-8
  • Full Forward Pass: Predicted token matches (token ID: 1784 "The")
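For readers reproducing the parity check, a minimal sketch of the kind of metric reported above is shown below. The exact metric the author used is not stated in the PR; this assumes avg_diff means the mean absolute element-wise difference between the candle output and a flattened PyTorch reference dump:

```rust
// Hypothetical helper for a candle-vs-PyTorch parity check: mean
// absolute element-wise difference between two flattened f32 outputs.

fn avg_diff(ours: &[f32], reference: &[f32]) -> f32 {
    assert_eq!(ours.len(), reference.len());
    let sum: f32 = ours.iter().zip(reference).map(|(a, b)| (a - b).abs()).sum();
    sum / ours.len() as f32
}

fn main() {
    let ours = [1.0f32, 2.0, 3.0];
    let reference = [1.0f32, 2.5, 3.0];
    // Differences are 0.0, 0.5, 0.0 -> mean 0.5 / 3.
    assert!((avg_diff(&ours, &reference) - 0.5 / 3.0).abs() < 1e-6);
}
```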

Checklist

  • New model implementation follows existing patterns in candle-transformers
  • Configuration uses serde for JSON deserialization
  • Reuses existing components (Pixtral vision, Mistral language model)
  • Documentation comments included
  • Verified against PyTorch reference implementation

@SpenserCai SpenserCai (Author) commented

mistral3 examples added!

@SpenserCai SpenserCai (Author) commented

Fixed clippy and fmt.
