From 4a8c624590c32c24bf1371018d62e3f149d6a255 Mon Sep 17 00:00:00 2001 From: wentiange Date: Tue, 3 Mar 2026 12:37:53 +0000 Subject: [PATCH 1/3] [Adapt] add DominoEP for InternS1 Pro VL --- .../v1/engine/vision_compose_train_engine.py | 7 +- .../compose/qwen3_vl/modeling_qwen3_vl.py | 134 +++++++++++------- xtuner/v1/model/moe/moe.py | 4 +- xtuner/v1/train/trainer.py | 2 +- 4 files changed, 90 insertions(+), 57 deletions(-) diff --git a/xtuner/v1/engine/vision_compose_train_engine.py b/xtuner/v1/engine/vision_compose_train_engine.py index 4aa98cd58..0bad6f091 100644 --- a/xtuner/v1/engine/vision_compose_train_engine.py +++ b/xtuner/v1/engine/vision_compose_train_engine.py @@ -169,8 +169,11 @@ def train_step(self, data_batches: List[ModelItem]) -> tuple[LossLog, OtherLog]: efficient_forward_tokens += (num_tokens.long() ** 2).sum() total_forward_tokens += (num_tokens.long().sum()) ** 2 - # todo: support intra_layer_micro_batch - output = self.model(seq_ctx=seq_ctx_list[0], loss_ctx=loss_ctx_list[0]) + if len(seq_ctx_list) == 1: + output = self.model(seq_ctx=seq_ctx_list[0],loss_ctx=loss_ctx_list[0]) + else: + output = self.model(seq_ctx=seq_ctx_list,loss_ctx=loss_ctx_list) + # llm loss has been global averaged llm_loss = output["loss"] step_llm_loss += llm_loss.detach().clone() diff --git a/xtuner/v1/model/compose/qwen3_vl/modeling_qwen3_vl.py b/xtuner/v1/model/compose/qwen3_vl/modeling_qwen3_vl.py index 865cbcfb1..1d1e5e604 100644 --- a/xtuner/v1/model/compose/qwen3_vl/modeling_qwen3_vl.py +++ b/xtuner/v1/model/compose/qwen3_vl/modeling_qwen3_vl.py @@ -137,63 +137,93 @@ def forward( seq_ctx: SequenceContext, loss_ctx: CELossContext ) -> MoEModelOutputs: - input_ids = seq_ctx.input_ids - pixel_values = seq_ctx.pixel_values - image_grid_thw = seq_ctx.image_grid_thw - sequence_parallel_mesh = seq_ctx.sequence_parallel_mesh - - inputs_embeds = self.language_model.embed_tokens(input_ids) # type: ignore - - if pixel_values is not None: - assert image_grid_thw is not None - assert input_ids is not None - visual_embeds, deepstack_visual_embeds = self.get_visual_features(pixel_values, - image_grid_thw, - sequence_parallel_mesh) - try: - # To simplify and facilitate the processing of deepstack_visual_embeds inside language_model, - # we all-gather visual_embeds, and then split them based on the input_ids, - # then non-uniformly split them based on the input_ids - visual_pos_masks, visual_features, deepstack_visual_embeds = self.get_placeholder_mask( - input_ids, - visual_features=visual_embeds, - deepstack_visual_embeds=deepstack_visual_embeds, - sequence_parallel_mesh=sequence_parallel_mesh, - origin_pixel_len=pixel_values.size(0) - ) - inputs_embeds[visual_pos_masks] = inputs_embeds[visual_pos_masks] * 0.0 + visual_features - except Exception as e: - logger.error(f"!!!Warning: {e}, but continue anyway!!!!") - inputs_embeds = inputs_embeds + visual_embeds.sum() * 0.0 + def _prepare_llm_inputs(single_seq_ctx): + sp_mesh = single_seq_ctx.sequence_parallel_mesh + img_thw = single_seq_ctx.image_grid_thw + pixel_val = single_seq_ctx.pixel_values + input_ids = single_seq_ctx.input_ids + inputs_embeds = self.language_model.embed_tokens(input_ids) # type: ignore + if pixel_val is not None: + assert img_thw is not None + assert input_ids is not None + visual_embeds, deepstack_visual_embeds = self.get_visual_features(pixel_val, + img_thw, + sp_mesh) + try: + # To simplify and facilitate the processing of deepstack_visual_embeds inside language_model, + # we all-gather visual_embeds, and then split them based on the input_ids, + # then non-uniformly split them based on the input_ids + visual_pos_masks, visual_features, deepstack_visual_embeds = self.get_placeholder_mask( + input_ids, + visual_features=visual_embeds, + deepstack_visual_embeds=deepstack_visual_embeds, + sequence_parallel_mesh=sp_mesh, + origin_pixel_len=pixel_val.size(0) + ) + inputs_embeds[visual_pos_masks] = inputs_embeds[visual_pos_masks] * 0.0 + visual_features + except Exception as e: + logger.error(f"!!!Warning: {e}, but continue anyway!!!!") + inputs_embeds = inputs_embeds + visual_embeds.sum() * 0.0 + for deepstack_visual_embed in deepstack_visual_embeds: + inputs_embeds = inputs_embeds + deepstack_visual_embed.sum() * 0.0 + deepstack_visual_embeds = None + visual_pos_masks = None + else: + pixel_values_dump = torch.randn(4, 1536, device=inputs_embeds.device, dtype=inputs_embeds.dtype) + img_thw = torch.tensor([[1, 2, 2]], device=inputs_embeds.device) + viusal_embeds, deepstack_visual_embeds = self.get_visual_features(pixel_values_dump, img_thw) + inputs_embeds = inputs_embeds + viusal_embeds.sum() * 0.0 for deepstack_visual_embed in deepstack_visual_embeds: inputs_embeds = inputs_embeds + deepstack_visual_embed.sum() * 0.0 + deepstack_visual_embeds = None visual_pos_masks = None + + if deepstack_visual_embeds is not None and len(deepstack_visual_embeds) == 0: + assert single_seq_ctx.position_ids is not None + assert single_seq_ctx.position_ids.ndim == 2, f"position_ids must be 2-dim when deepstack_visual_embeds is None," \ + f" but got {single_seq_ctx.position_ids.ndim}" + + deepstack_visual_embeds = None + visual_pos_masks = None + return inputs_embeds + + if isinstance(seq_ctx, list): + lang_seq_ctx = [] + for single_seq_ctx in seq_ctx: + inputs_embeds = _prepare_llm_inputs(single_seq_ctx) + lang_seq_ctx.append( + SequenceContext( + input_ids=None, + cu_seq_lens_q=single_seq_ctx.cu_seq_lens_q, + cu_seq_lens_k=single_seq_ctx.cu_seq_lens_k, + max_length_q=single_seq_ctx.max_length_q, + max_length_k=single_seq_ctx.max_length_k, + position_ids=single_seq_ctx.position_ids, + num_padding=single_seq_ctx.num_padding, + sequence_parallel_mesh=single_seq_ctx.sequence_parallel_mesh, + inputs_embeds=inputs_embeds, + rollout_routed_experts=single_seq_ctx.rollout_routed_experts, + deepstack_visual_embeds=None, + visual_pos_masks=None + ) + ) + elif isinstance(seq_ctx, SequenceContext): + inputs_embeds = _prepare_llm_inputs(seq_ctx) + lang_seq_ctx = SequenceContext(input_ids=None, + cu_seq_lens_q=seq_ctx.cu_seq_lens_q, + cu_seq_lens_k=seq_ctx.cu_seq_lens_k, + max_length_q=seq_ctx.max_length_q, + max_length_k=seq_ctx.max_length_k, + position_ids=seq_ctx.position_ids, + num_padding=seq_ctx.num_padding, + sequence_parallel_mesh=seq_ctx.sequence_parallel_mesh, + inputs_embeds=inputs_embeds, + rollout_routed_experts=seq_ctx.rollout_routed_experts, + deepstack_visual_embeds=None, + visual_pos_masks=None) else: - pixel_values_dump = torch.randn(4, 1536, device=inputs_embeds.device, dtype=inputs_embeds.dtype) - image_grid_thw = torch.tensor([[1, 2, 2]], device=inputs_embeds.device) - viusal_embeds, deepstack_visual_embeds = self.get_visual_features(pixel_values_dump, image_grid_thw) - inputs_embeds = inputs_embeds + viusal_embeds.sum() * 0.0 - for deepstack_visual_embed in deepstack_visual_embeds: - inputs_embeds = inputs_embeds + deepstack_visual_embed.sum() * 0.0 - - deepstack_visual_embeds = None - visual_pos_masks = None - - if deepstack_visual_embeds is not None and len(deepstack_visual_embeds) == 0: - assert seq_ctx.position_ids is not None - assert seq_ctx.position_ids.ndim == 2, f"position_ids must be 2-dim when deepstack_visual_embeds is None," \ - f" but got {seq_ctx.position_ids.ndim}" - deepstack_visual_embeds = None - visual_pos_masks = None - - # NOTE: 一定不要原地覆盖,否则第二次 forward 会缺少数据 - lang_seq_ctx = seq_ctx.copy( - input_ids=None, - inputs_embeds=inputs_embeds, - deepstack_visual_embeds=deepstack_visual_embeds, - visual_pos_masks=visual_pos_masks, - ) + raise NotImplementedError outputs = self.language_model( lang_seq_ctx, loss_ctx diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py index 962129e1d..a36909438 100644 --- a/xtuner/v1/model/moe/moe.py +++ b/xtuner/v1/model/moe/moe.py @@ -335,8 +335,8 @@ def _micro_batch_forward( cat_position_embeddings = self.rotary_emb(cat_hidden_states, cat_position_ids) # type: ignore position_embeddings_list = list( zip( - cat_position_embeddings[0].chunk(len(seq_ctx_list), dim=1), - cat_position_embeddings[1].chunk(len(seq_ctx_list), dim=1), + cat_position_embeddings[0].chunk(len(seq_ctx_list), dim=2), + cat_position_embeddings[1].chunk(len(seq_ctx_list), dim=2), ) ) diff --git a/xtuner/v1/train/trainer.py b/xtuner/v1/train/trainer.py index 45187b15d..1d35dd8b1 100644 --- a/xtuner/v1/train/trainer.py +++ b/xtuner/v1/train/trainer.py @@ -578,7 +578,7 @@ def __init__( self.sp_mesh = self.data_mesh["sp"] if global_batch_size is None: - global_batch_size = self.data_mesh["dp"].size() + global_batch_size = self.data_mesh["dp"].size() * intra_layer_micro_batch self._global_batch_size = global_batch_size self._resolve_config_conflicts(self.tokenizer, model_cfg, dataloader_cfg, fsdp_cfg) From c796e83525c6e658120029f028f8cf1a42790952 Mon Sep 17 00:00:00 2001 From: wentiange Date: Tue, 3 Mar 2026 12:46:47 +0000 Subject: [PATCH 2/3] [Optimization] remove some D2H transfers --- xtuner/v1/engine/vision_compose_train_engine.py | 14 +++++++------- xtuner/v1/loss/ce_loss.py | 10 ++++------ xtuner/v1/train/trainer.py | 2 +- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/xtuner/v1/engine/vision_compose_train_engine.py b/xtuner/v1/engine/vision_compose_train_engine.py index 0bad6f091..23dce6c43 100644 --- a/xtuner/v1/engine/vision_compose_train_engine.py +++ b/xtuner/v1/engine/vision_compose_train_engine.py @@ -211,25 +211,25 @@ def train_step(self, data_batches: List[ModelItem]) -> tuple[LossLog, OtherLog]: maxvio = maxvio_all_layers.mean() if moe_need_update_bias: self.model.language_model.update_bias(tokens_per_expert_global_for_bias, avg_count_load) # type: ignore - other_log["maxvio"] = maxvio.item() + other_log["maxvio"] = maxvio reduced_llm_loss = step_llm_loss dist.all_reduce(reduced_llm_loss.div_(dist.get_world_size())) - loss_log["local_loss"] = step_loss.item() - loss_log["reduced_llm_loss"] = reduced_llm_loss.item() + loss_log["local_loss"] = step_loss + loss_log["reduced_llm_loss"] = reduced_llm_loss if step_balancing_loss is not None: reduced_balancing_loss = step_balancing_loss dist.all_reduce(reduced_balancing_loss.div_(dist.get_world_size())) - loss_log["reduced_balancing_loss"] = reduced_balancing_loss.item() + loss_log["reduced_balancing_loss"] = reduced_balancing_loss if step_z_loss is not None: reduced_z_loss = step_z_loss dist.all_reduce(reduced_z_loss.div_(dist.get_world_size())) - loss_log["reduced_z_loss"] = reduced_z_loss.item() + loss_log["reduced_z_loss"] = reduced_z_loss - other_log["step_consumed_tokens"] = int(step_consumed_tokens.item()) + other_log["step_consumed_tokens"] = int(step_consumed_tokens) other_log["extra_info"] = train_engine_extra_info # type: ignore[assignment] - other_log["efficient_attn_ratio"] = (efficient_forward_tokens / total_forward_tokens).item() + other_log["efficient_attn_ratio"] = (efficient_forward_tokens / total_forward_tokens) other_log["step_consumed_img_tokens"] = int(step_consumed_img_tokens) extra_info = other_log.get("extra_info", {}) # type: ignore diff --git a/xtuner/v1/loss/ce_loss.py b/xtuner/v1/loss/ce_loss.py index ea128ad3c..552993f5b 100644 --- a/xtuner/v1/loss/ce_loss.py +++ b/xtuner/v1/loss/ce_loss.py @@ -169,12 +169,10 @@ def loss_fn( loss_weight = loss_weight.flatten() rank_grad_tokens = (shifted_labels != self.loss_cfg.ignore_idx).sum() - if rank_grad_tokens == 0: - loss = logits.sum() * 0 - else: - loss = F.cross_entropy(logits, shifted_labels, reduction="none", ignore_index=self.loss_cfg.ignore_idx) - # Step 2.b in the loss calculation: sum the loss over all tokens - loss = (loss * loss_weight).sum() + loss_when_zero = logits.sum() * 0 + loss_normal = F.cross_entropy(logits, shifted_labels, reduction="none", ignore_index=self.loss_cfg.ignore_idx) + loss_normal = (loss_normal * loss_weight.to(loss_normal.device)).sum() + loss = torch.where(rank_grad_tokens == 0, loss_when_zero, loss_normal) return loss, (logits, {}) diff --git a/xtuner/v1/train/trainer.py b/xtuner/v1/train/trainer.py index 1d35dd8b1..a0c96a3b3 100644 --- a/xtuner/v1/train/trainer.py +++ b/xtuner/v1/train/trainer.py @@ -764,7 +764,7 @@ def fit(self): self._log_step( loss_log=loss_log, training_metrics=training_metrics, - grad_norm=grad_norm.item(), + grad_norm=grad_norm, data_time=data_time, step_time=step_time, internal_metrics=internal_metrics, From dc0337858eb40c6c0c6f42ac721d433d52718110 Mon Sep 17 00:00:00 2001 From: wentiange Date: Thu, 5 Mar 2026 04:18:17 +0000 Subject: [PATCH 3/3] [Feature] Layer-wise MoE balance loss computation --- xtuner/v1/model/moe/moe.py | 194 +++++++++++++++++++++++-------------- 1 file changed, 119 insertions(+), 75 deletions(-) diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py index a36909438..72152acb5 100644 --- a/xtuner/v1/model/moe/moe.py +++ b/xtuner/v1/model/moe/moe.py @@ -50,13 +50,14 @@ from xtuner.v1.utils import ( get_device, get_logger, + get_torch_device_module, ) from xtuner.v1.utils.activation_offload import async_save_on_cpu DEVICE = get_device() logger = get_logger() - +DEVICE_MODULE = get_torch_device_module() MOE_NON_EP_COMPILE_CFG: dict[str, TorchCompileOption] = { "xtuner.v1.module.decoder_layer.moe_decoder_layer.MoEBlock.forward": TorchCompileOption(fullgraph=True), @@ -339,13 +340,12 @@ def _micro_batch_forward( cat_position_embeddings[1].chunk(len(seq_ctx_list), dim=2), ) ) + mask_list = torch.cat([ctx.mask for ctx in seq_ctx_list], dim=1) + indices = torch.nonzero(mask_list, as_tuple=True)[1] # Initialize output containers output: dict = {} - router_logits_list: list[dict[str, torch.Tensor]] = [{} for _ in range(len(seq_ctx_list))] - router_weights_list: list[dict[str, torch.Tensor]] = [{} for _ in range(len(seq_ctx_list))] - # Process through layers cat_seq_ctx: SequenceContext | None = None @@ -354,6 +354,12 @@ def _micro_batch_forward( for seq_ctx in seq_ctx_list: self._mark_dynamic(seq_ctx) + num_layers = len(self.layers) + local_load = torch.zeros(num_layers, self.config.n_routed_experts).to(DEVICE_MODULE.current_device()) # (nlayers, ne) + local_gating_sum = torch.zeros(num_layers, self.config.n_routed_experts).to(DEVICE_MODULE.current_device()) + local_load_logits = torch.zeros(num_layers, self.config.n_routed_experts, + dtype = torch.int64).to(DEVICE_MODULE.current_device()) + for idx, decoder_layer in self.layers.items(): layer_idx = int(idx) @@ -397,17 +403,39 @@ def _micro_batch_forward( *hidden_states_list, position_embeddings=position_embeddings_list, seq_ctx=seq_ctx_list, - ) - hidden_states = layer_results[: len(hidden_states_list)] - router_logits = layer_results[len(hidden_states_list) : len(hidden_states_list) * 2] - router_weights = layer_results[len(hidden_states_list) * 2 :] - - # Update hidden states and collect router results - for i, hidden_states in enumerate(hidden_states): - hidden_states_list[i] = hidden_states - router_logits_list[i][f"layer{idx}"] = router_logits[i] - router_weights_list[i][f"layer{idx}"] = router_weights[i] - + ) + hidden_states_list = layer_results[:len(hidden_states_list)] + if self.balancing_loss: + router_weights_list = torch.cat(layer_results[len(hidden_states_list)*2:], dim=0).unsqueeze(0) + router_weights = ( + torch.index_select(router_weights_list, 1, indices).contiguous().float() + ) + _, selected_experts = torch.topk(router_weights, self.config.num_experts_per_tok, dim=-1) + tokens_per_expert = torch.histc( + selected_experts.view(-1), + bins=self.config.n_routed_experts, + min=0, + max=self.config.n_routed_experts, + ).float() + local_load[layer_idx] = tokens_per_expert + local_gating_sum[layer_idx] = router_weights.sum(dim = 1) + + router_logits_list = torch.cat(layer_results[len(hidden_states_list):len(hidden_states_list)*2], dim=0).unsqueeze(0) + router_logits = ( + torch.index_select(router_logits_list, 1, indices).contiguous().float() + ) + _, selected_experts = torch.topk(router_logits, self.config.num_experts_per_tok, dim=-1) + tokens_per_expert = torch.histc( + selected_experts.view(-1), + bins=self.config.n_routed_experts, + min=0, + max=self.config.n_routed_experts, + ).to(torch.long) + local_load_logits[layer_idx] = tokens_per_expert + + if self.z_loss: + z_loss = self.z_loss(router_logits=router_logits) + output["z_loss"] = output.get("z_loss", 0) + z_loss # Apply final norm to all micro-batches cat_hidden_states = torch.cat(hidden_states_list, dim=1) cat_hidden_states = self.norm(cat_hidden_states) @@ -423,48 +451,27 @@ def _micro_batch_forward( moe_extra_info.append(extra_info) output["extra_info"] = moe_extra_info - # Handle router results for all micro-batches - all_router_logits = [] - all_router_weights = [] - - for micro_batch_idx, (micro_batch_router_logits, micro_batch_router_weights) in enumerate( - zip(router_logits_list, router_weights_list) - ): - if micro_batch_router_logits: - _router_logits_list = list(micro_batch_router_logits.values()) - _router_weights_list = list(micro_batch_router_weights.values()) - - attn_mask = seq_ctx_list[micro_batch_idx].mask - router_logits = self._select_non_pad_router_logits(_router_logits_list, attn_mask) - router_weights = self._select_non_pad_router_logits(_router_weights_list, attn_mask) - all_router_logits.append(router_logits) - all_router_weights.append(router_weights) - - if all_router_logits: - # Concatenate router logits from all micro-batches - combined_router_logits = torch.cat(all_router_logits, dim=1) # [num_layers, total_seq, num_experts] - combined_router_weights = torch.cat(all_router_weights, dim=1) - - # Calculate balancing loss across all micro-batches - if self.balancing_loss: - balancing_loss = self.balancing_loss( - router_weights=combined_router_weights, - n_routed_experts=self.config.n_routed_experts, - num_experts_per_tok=self.config.num_experts_per_tok, - ) - output["balancing_loss"] = balancing_loss - - # Calculate z-loss across all micro-batches - if self.z_loss: - z_loss = self.z_loss(router_logits=combined_router_logits) - output["z_loss"] = z_loss - - # Calculate tokens per expert for bias update (if applicable) - tokens_per_expert_global = self._cal_tokens_per_expert(combined_router_logits) - output["tokens_per_expert_global"] = tokens_per_expert_global - - del combined_router_logits + if self.balancing_loss: + if self.balancing_loss.global_average and dist.is_initialized(): + tokens_per_expert_global = all_reduce(local_load, "sum", dist.group.WORLD) + tokens_global = tokens_per_expert_global.sum(-1) + seqlen_global = tokens_global // self.config.num_experts_per_tok + dist.all_reduce(local_gating_sum,) + routing_weights_mean_global = local_gating_sum / seqlen_global.unsqueeze(-1) + scale_global = self.config.n_routed_experts / tokens_global + else: + scale_global = self.config.n_routed_experts / (router_weights.shape[1] * self.config.num_experts_per_tok) + routing_weights_mean_global = local_gating_sum / router_weights.shape[1] + loss = scale_global * (tokens_per_expert_global * routing_weights_mean_global).sum(-1) + loss = loss.sum() + output["balancing_loss"] = loss * self.balancing_loss.loss_weight + # Calculate tokens per expert for bias update (if applicable) + if dist.is_initialized(): + tokens_per_expert_global = all_reduce(local_load_logits, "sum", dist.group.WORLD) + else: + tokens_per_expert_global = local_load_logits + output["tokens_per_expert_global"] = tokens_per_expert_global if self.config.return_router_results or return_router_logits: # raise NotImplementedError @@ -511,6 +518,13 @@ def _forward( self._mark_dynamic(seq_ctx) + indices = torch.nonzero(seq_ctx.mask, as_tuple=True)[1] + + num_layers = len(self.layers) + local_load = torch.zeros(num_layers, self.config.n_routed_experts).to(DEVICE_MODULE.current_device()) + local_gating_sum = torch.zeros(num_layers, self.config.n_routed_experts).to(DEVICE_MODULE.current_device()) + local_load_logits = torch.zeros(num_layers, self.config.n_routed_experts, + dtype = torch.int64).to(DEVICE_MODULE.current_device()) for idx, decoder_layer in self.layers.items(): if int(idx) < self.config.first_k_dense_replace: hidden_states = decoder_layer( @@ -540,8 +554,36 @@ def _forward( seq_ctx=seq_ctx, ) hidden_states, router_results, router_weights = layer_results - output["router_logits"][f"layer{idx}"] = router_results - output["router_weights"][f"layer{idx}"] = router_weights + if self.balancing_loss: + router_weights = ( + torch.index_select(router_weights, 0, indices).contiguous().float() + ) + _, selected_experts = torch.topk(router_weights, self.config.num_experts_per_tok, dim=-1) + tokens_per_expert = torch.histc( + selected_experts.view(-1), + bins=self.config.n_routed_experts, + min=0, + max=self.config.n_routed_experts, + ).float() + layer_idx = int(idx) + local_load[layer_idx] = tokens_per_expert + local_gating_sum[layer_idx] = router_weights.sum(dim = 0) + + router_logits = ( + torch.index_select(router_results, 0, indices).contiguous().float() + ) + _, selected_experts = torch.topk(router_logits, self.config.num_experts_per_tok, dim=-1) + tokens_per_expert = torch.histc( + selected_experts.view(-1), + bins=self.config.n_routed_experts, + min=0, + max=self.config.n_routed_experts, + ).to(torch.long) + local_load_logits[layer_idx] = tokens_per_expert + + if self.z_loss: + z_loss = self.z_loss(router_logits=router_logits) + output["z_loss"] = output.get("z_loss", 0) + z_loss if self.config.return_hidden_states: output["hidden_states"].append(hidden_states) @@ -553,27 +595,29 @@ def _forward( output["logits"] = logits output["extra_info"] = extra_info - router_logits_list = list(output["router_logits"].values()) # type: ignore - router_weights_list = list(output["router_weights"].values()) # type: ignore - router_logits = self._select_non_pad_router_logits(router_logits_list, seq_ctx.mask) - router_weights = self._select_non_pad_router_logits(router_weights_list, seq_ctx.mask) - if self.balancing_loss: - balancing_loss = self.balancing_loss( - router_weights=router_weights, - n_routed_experts=self.config.n_routed_experts, - num_experts_per_tok=self.config.num_experts_per_tok, - ) - output["balancing_loss"] = balancing_loss - - if self.z_loss: - z_loss = self.z_loss(router_logits=router_logits) - output["z_loss"] = z_loss + if self.balancing_loss.global_average and dist.is_initialized(): + tokens_per_expert_global = all_reduce(local_load, "sum", dist.group.WORLD) + tokens_global = tokens_per_expert_global.sum(-1) + seqlen_global = tokens_global // self.config.num_experts_per_tok + + dist.all_reduce(local_gating_sum,) + routing_weights_mean_global = local_gating_sum / seqlen_global.unsqueeze(-1) + scale_global = self.config.n_routed_experts / tokens_global + else: + scale_global = self.config.n_routed_experts / (router_weights.shape[0] * self.config.num_experts_per_tok) + routing_weights_mean_global = local_gating_sum / router_weights.shape[0] + loss = scale_global * (tokens_per_expert_global * routing_weights_mean_global).sum(-1) + loss = loss.sum() + output["balancing_loss"] = loss * self.balancing_loss.loss_weight - tokens_per_expert_global = self._cal_tokens_per_expert(router_logits) + # Calculate tokens per expert for bias update (if applicable) + if dist.is_initialized(): + tokens_per_expert_global = all_reduce(local_load_logits, "sum", dist.group.WORLD) + else: + tokens_per_expert_global = local_load_logits output["tokens_per_expert_global"] = tokens_per_expert_global - del router_logits if self.config.return_router_results or return_router_logits: # raise NotImplementedError