Commit b10c340

Apply suggestions from code review
1 parent 5593686 commit b10c340

1 file changed: 3 additions, 1 deletion

optimum/exporters/onnx/model_patcher.py

Lines changed: 3 additions & 1 deletion

@@ -1427,6 +1427,7 @@ def __init__(
         super().__init__(config, model, model_kwargs)
 
 
+# https://github.com/huggingface/transformers/blob/v4.53.0/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L228
 def qwen3_moe_forward_patched(self, hidden_states: torch.Tensor) -> torch.Tensor:
     batch_size, sequence_length, hidden_dim = hidden_states.shape
     hidden_states = hidden_states.view(-1, hidden_dim)
@@ -1448,7 +1449,8 @@ def qwen3_moe_forward_patched(self, hidden_states: torch.Tensor) -> torch.Tensor
     # this will be used to easily index which expert is going to be sollicitated
     expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
 
-    # TODO: we loop over all possible experts to avoid issues in graph execution.
+    # TODO: we loop over all possible experts instead of hitted ones to avoid issues in graph execution.
+    # expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
     # Loop over all available experts in the model and perform the computation on each expert
     for expert_idx in range(self.num_experts):
         expert_layer = self.experts[expert_idx]
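For context on why the TODO keeps the loop over range(self.num_experts): ONNX export traces the Python loop, so its trip count must not depend on runtime tensor values; iterating only over expert_hitted would bake whichever experts the example input happened to route to into the exported graph. Below is a minimal, self-contained sketch of that idea, using a hypothetical ToyMoE module rather than the actual Optimum patcher or the Qwen3 MoE block:

import torch
import torch.nn as nn


class ToyMoE(nn.Module):
    # Hypothetical toy mixture-of-experts block for illustration only;
    # the routing/dispatch pattern mirrors the patched forward above.
    def __init__(self, hidden_dim=8, num_experts=4, top_k=2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.gate = nn.Linear(hidden_dim, num_experts)
        self.experts = nn.ModuleList(nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (num_tokens, hidden_dim), already flattened
        routing_weights = torch.softmax(self.gate(hidden_states), dim=-1)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        # (num_experts, top_k, num_tokens) mask of which token/slot picked which expert
        expert_mask = nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        final_hidden_states = torch.zeros_like(hidden_states)
        # Data-dependent variant (what the commented-out expert_hitted line would enable):
        #   expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
        #   for expert_idx in expert_hitted: ...
        # Tracing would hard-code the experts hit by the example input, so the
        # patch instead loops over every expert with a static trip count.
        for expert_idx in range(self.num_experts):
            idx, top_x = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[top_x]
            current_hidden_states = self.experts[expert_idx](current_state) * routing_weights[top_x, idx, None]
            final_hidden_states.index_add_(0, top_x, current_hidden_states)
        return final_hidden_states


if __name__ == "__main__":
    out = ToyMoE()(torch.randn(5, 8))
    print(out.shape)  # torch.Size([5, 8])

Tracing ToyMoE (for example with torch.onnx.export) then records all expert branches regardless of routing, at the cost of running experts that receive no tokens for a given input.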
