From 8cedf397eefc613800437727f67a4015e4494a9a Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Thu, 12 Mar 2026 20:38:46 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- tests/engine/test_dense_train_engine.py | 4 +- tests/engine/test_moe_train_engine.py | 8 +- tests/engine/test_moe_train_engine_float8.py | 12 +- tests/loss/test_ce_loss.py | 8 +- tests/loss/test_grpo_loss.py | 2 +- tests/loss/test_oreal_loss.py | 3 +- tests/model/test_gpt_oss_moe.py | 8 +- tests/model/test_intern_s1.py | 16 +- tests/model/test_moe.py | 10 +- tests/model/test_qwen3_5.py | 2 +- tests/model/test_qwen3_dense.py | 12 +- tests/model/test_qwen3_moe.py | 12 +- tests/model/test_qwen3_tile_embedding.py | 8 +- tests/model/test_qwen3_vl.py | 2 +- xtuner/v1/engine/train_engine.py | 2 +- xtuner/v1/loss/__init__.py | 17 +- xtuner/v1/loss/base_loss_ctx.py | 21 +- xtuner/v1/loss/ce_loss.py | 20 +- xtuner/v1/loss/moe_loss.py | 237 +++++++++++++++++- xtuner/v1/model/base.py | 104 +++++++- xtuner/v1/model/compose/base.py | 12 +- .../compose/intern_s1/modeling_intern_s1.py | 2 +- .../compose/qwen3_vl/modeling_qwen3_vl.py | 2 +- xtuner/v1/model/dense/dense.py | 23 +- xtuner/v1/model/moe/moe.py | 175 +++++++------ xtuner/v1/rl/base/loss.py | 38 ++- xtuner/v1/rl/base/worker.py | 17 +- xtuner/v1/train/trainer.py | 46 ++-- 28 files changed, 648 insertions(+), 175 deletions(-) diff --git a/tests/engine/test_dense_train_engine.py b/tests/engine/test_dense_train_engine.py index f25282750..88bff9e0f 100644 --- a/tests/engine/test_dense_train_engine.py +++ b/tests/engine/test_dense_train_engine.py @@ -83,13 +83,13 @@ def warmup_fn(x): seq_ctx = seq_ctx.split(sequence_parallel_mesh=sp_mesh) seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=sp_mesh) + loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=sp_mesh) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) seq_ctx = seq_ctx_list[0] loss_ctx = loss_ctx_list[0] - engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)] + engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})] loss_log = engine.train_step(engine_input)["logs_info"] grad_norm = engine.clip_grad_norm() engine.step_optimizer(grad_norm) diff --git a/tests/engine/test_moe_train_engine.py b/tests/engine/test_moe_train_engine.py index 1dc868509..4302aa160 100644 --- a/tests/engine/test_moe_train_engine.py +++ b/tests/engine/test_moe_train_engine.py @@ -93,12 +93,12 @@ def warmup_fn(x): seq_ctx.num_padding = pack_len seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] seq_ctx = seq_ctx_list[0] - engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)] + engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})] loss_log = engine.train_step(engine_input)["logs_info"] grad_norm = engine.clip_grad_norm() engine.step_optimizer(grad_norm) @@ -184,12 +184,12 @@ def warmup_fn(x): seq_ctx.num_padding = pack_len seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] seq_ctx = seq_ctx_list[0] - engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)] + engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})] loss_log = engine.train_step(engine_input)["logs_info"] grad_norm = engine.clip_grad_norm() engine.step_optimizer(grad_norm) diff --git a/tests/engine/test_moe_train_engine_float8.py b/tests/engine/test_moe_train_engine_float8.py index 15ea4b673..7246b2514 100644 --- a/tests/engine/test_moe_train_engine_float8.py +++ b/tests/engine/test_moe_train_engine_float8.py @@ -87,12 +87,12 @@ def warmup_fn(x): seq_ctx.num_padding = pack_len seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] seq_ctx = seq_ctx_list[0] - engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)] + engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})] loss_log = engine.train_step(engine_input)["logs_info"] grad_norm = engine.clip_grad_norm() engine.step_optimizer(grad_norm) @@ -165,12 +165,12 @@ def warmup_fn(x): seq_ctx.num_padding = pack_len seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] seq_ctx = seq_ctx_list[0] - engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)] + engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})] loss_log = engine.train_step(engine_input)["logs_info"] grad_norm = engine.clip_grad_norm() engine.step_optimizer(grad_norm) @@ -264,12 +264,12 @@ def warmup_fn(x): seq_ctx.to('cuda') seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] seq_ctx = seq_ctx_list[0] - engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)] + engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})] logs_info = engine.train_step(engine_input)["logs_info"] grad_norm = engine.clip_grad_norm() engine.step_optimizer(grad_norm) diff --git a/tests/loss/test_ce_loss.py b/tests/loss/test_ce_loss.py index 863eae4da..195bc08f4 100644 --- a/tests/loss/test_ce_loss.py +++ b/tests/loss/test_ce_loss.py @@ -72,7 +72,7 @@ def test_global_loss_reduction(self, loss_mode, grad_accumulation_steps, chunk_s for data in data_batch: seq_ctx = data["seq_ctx"] seq_ctx_list.append(seq_ctx) - loss_ctx = loss_cfg.build(shifted_labels=data["shifted_labels"], sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": data["shifted_labels"]}, sp_mesh=None) loss_ctx_list.append(loss_ctx) loss_ctx_list = CELossContext.build_batches(loss_ctx_list, cu_seq_lens_list=[seq_ctx.cu_seq_lens_q for seq_ctx in seq_ctx_list]) @@ -172,7 +172,7 @@ def test_other_loss_reduction(self, loss_reduction, loss_mode, grad_accumulation for data in data_batch: seq_ctx = data["seq_ctx"] seq_ctx_list.append(seq_ctx) - loss_ctx = loss_cfg.build(shifted_labels=data["shifted_labels"], sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": data["shifted_labels"]}, sp_mesh=None) loss_ctx_list.append(loss_ctx) loss_ctx_list = CELossContext.build_batches(loss_ctx_list, cu_seq_lens_list=[seq_ctx.cu_seq_lens_q for seq_ctx in seq_ctx_list]) @@ -310,7 +310,7 @@ def test_sp_global_loss_reduction(self, loss_mode, sp_size, grad_accumulation_st sp_mesh = data_mesh['sp'] seq_ctx.sequence_parallel_mesh = sp_mesh seq_ctx_list = [seq_ctx] - loss_ctx = loss_cfg.build(shifted_labels=target, sp_mesh=sp_mesh) + loss_ctx = loss_cfg.build(data={"shifted_labels": target}, sp_mesh=sp_mesh) loss_ctx_list = [loss_ctx] if sp_size > 1: seq_ctx_list[0] = seq_ctx_list[0].split(sequence_parallel_mesh=sp_mesh) @@ -397,7 +397,7 @@ def test_sp_others_loss_reduction(self, loss_reduction, loss_mode, sp_size, grad sp_mesh = data_mesh['sp'] seq_ctx.sequence_parallel_mesh = sp_mesh seq_ctx_list = [seq_ctx] - loss_ctx = loss_cfg.build(shifted_labels=target, sp_mesh=sp_mesh) + loss_ctx = loss_cfg.build(data={"shifted_labels": target}, sp_mesh=sp_mesh) loss_ctx_list = [loss_ctx] if sp_size > 1: seq_ctx_list[0] = seq_ctx_list[0].split(sequence_parallel_mesh=sp_mesh) diff --git a/tests/loss/test_grpo_loss.py b/tests/loss/test_grpo_loss.py index 0ee237894..7a1747a12 100644 --- a/tests/loss/test_grpo_loss.py +++ b/tests/loss/test_grpo_loss.py @@ -147,7 +147,7 @@ def test_grpo_loss(self, grad_acc, sp_size, kl_loss_coef, loss_mode, chunk_size, if sp_size > 1: seq_ctx = seq_ctx.split(sp_mesh) seq_ctx_list.append(seq_ctx) - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels_list_rank[iter_idx], advantages=advantages_list_rank[iter_idx], sp_mesh=sp_mesh) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels_list_rank[iter_idx], "advantages": advantages_list_rank[iter_idx]}, sp_mesh=sp_mesh) loss_ctx_list.append(loss_ctx) with torch.no_grad(): diff --git a/tests/loss/test_oreal_loss.py b/tests/loss/test_oreal_loss.py index 1e50b2e48..2ceae417d 100644 --- a/tests/loss/test_oreal_loss.py +++ b/tests/loss/test_oreal_loss.py @@ -216,8 +216,7 @@ def test_grpo_loss(self, grad_acc, sp_size, kl_loss_coef, loss_mode, chunk_size, seq_ctx = seq_ctx.split(sp_mesh) seq_ctx_list.append(seq_ctx) loss_ctx = loss_cfg.build( - shifted_labels=shifted_labels_list_rank[iter_idx], - advantages=advantages_list_rank[iter_idx], + data={"shifted_labels": shifted_labels_list_rank[iter_idx], "advantages": advantages_list_rank[iter_idx]}, sp_mesh=sp_mesh, ) loss_ctx_list.append(loss_ctx) diff --git a/tests/model/test_gpt_oss_moe.py b/tests/model/test_gpt_oss_moe.py index d0f011bc4..edabb9ea0 100644 --- a/tests/model/test_gpt_oss_moe.py +++ b/tests/model/test_gpt_oss_moe.py @@ -78,7 +78,7 @@ def test_gpt_oss_run(self, device, dispatcher, ep_size, compile, tol, loss_class loss_cfg = CELossConfig() seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] @@ -87,7 +87,7 @@ def test_gpt_oss_run(self, device, dispatcher, ep_size, compile, tol, loss_class with torch.no_grad(): output = gpt_oss_model( seq_ctx=seq_ctx, - loss_ctx=loss_ctx, + loss_ctx={"lm": loss_ctx}, ) loss = output["loss"] self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=tol, rtol=tol)) @@ -141,7 +141,7 @@ def test_fsdp_accuracy(self, device, dispatcher, ep_size): loss_cfg = CELossConfig() seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] @@ -152,7 +152,7 @@ def test_fsdp_accuracy(self, device, dispatcher, ep_size): with torch.no_grad(): output = gpt_oss_model( seq_ctx=seq_ctx, - loss_ctx=loss_ctx, + loss_ctx={"lm": loss_ctx}, ) loss = output["loss"] self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=1e-2, rtol=1e-2)) diff --git a/tests/model/test_intern_s1.py b/tests/model/test_intern_s1.py index 00f96c8b2..db76848f2 100644 --- a/tests/model/test_intern_s1.py +++ b/tests/model/test_intern_s1.py @@ -78,7 +78,7 @@ def test_interns1_text_run(self, device, tol): seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] @@ -87,7 +87,7 @@ def test_interns1_text_run(self, device, tol): with torch.no_grad(): output = interns1_model( seq_ctx=seq_ctx, - loss_ctx=loss_ctx, + loss_ctx={"lm": loss_ctx}, ) loss = output["loss"] self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=tol, rtol=tol)) @@ -186,7 +186,7 @@ def test_interns1_image_run(self, device, sp_size, tol): seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=sp_mesh) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=sp_mesh) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] @@ -195,7 +195,7 @@ def test_interns1_image_run(self, device, sp_size, tol): with torch.no_grad(): output = interns1_model( seq_ctx=seq_ctx, - loss_ctx=loss_ctx, + loss_ctx={"lm": loss_ctx}, ) loss = output["loss"] self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=tol, rtol=tol)) @@ -256,7 +256,7 @@ def test_fsdp_text_accuracy(self, device, tol): seq_ctx_list = [seq_ctx] loss_cfg = CELossConfig() LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] @@ -265,7 +265,7 @@ def test_fsdp_text_accuracy(self, device, tol): with torch.no_grad(): output = interns1_model( seq_ctx=seq_ctx, - loss_ctx=loss_ctx, + loss_ctx={"lm": loss_ctx}, ) loss = output["loss"] self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=tol, rtol=tol)) @@ -370,7 +370,7 @@ def test_fsdp_image_accuracy(self, device, sp_size, compile, tol): seq_ctx_list = [seq_ctx] loss_cfg = CELossConfig() LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=sp_mesh) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=sp_mesh) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] @@ -379,7 +379,7 @@ def test_fsdp_image_accuracy(self, device, sp_size, compile, tol): with torch.no_grad(): output = interns1_model( seq_ctx=seq_ctx, - loss_ctx=loss_ctx, + loss_ctx={"lm": loss_ctx}, ) loss = output["loss"] self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=tol, rtol=tol)) diff --git a/tests/model/test_moe.py b/tests/model/test_moe.py index e34b58bd6..c7c1c3137 100644 --- a/tests/model/test_moe.py +++ b/tests/model/test_moe.py @@ -62,12 +62,12 @@ def test_moe_config(self, dtype, device): seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] seq_ctx = seq_ctx_list[0] - model(seq_ctx=seq_ctx, loss_ctx=loss_ctx) + model(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx}) class TestDistributedMoE(DeterministicDDPTestCase): @@ -135,15 +135,15 @@ def test_parallel_accuracy(self, dtype, device, dispatcher, n_shared_experts, fi seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] seq_ctx = seq_ctx_list[0] - loss_parallel = parallel_model(seq_ctx=seq_ctx, loss_ctx=loss_ctx)["loss"] + loss_parallel = parallel_model(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})["loss"] - loss_expected = model(seq_ctx=seq_ctx, loss_ctx=loss_ctx)["loss"] + loss_expected = model(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})["loss"] torch.allclose(loss_expected, loss_parallel, atol=1e-6, rtol=1e-4) diff --git a/tests/model/test_qwen3_5.py b/tests/model/test_qwen3_5.py index afe348aa9..58a0f484d 100644 --- a/tests/model/test_qwen3_5.py +++ b/tests/model/test_qwen3_5.py @@ -133,7 +133,7 @@ def _forward(self, model, type, device, sp_size): seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=sp_mesh) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=sp_mesh) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] diff --git a/tests/model/test_qwen3_dense.py b/tests/model/test_qwen3_dense.py index efe327cb9..94aded680 100644 --- a/tests/model/test_qwen3_dense.py +++ b/tests/model/test_qwen3_dense.py @@ -64,7 +64,7 @@ def test_qwen3_dense_run(self, device, tp_size, compile, tol, loss_class): loss_cfg = CELossConfig() seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] @@ -75,7 +75,7 @@ def test_qwen3_dense_run(self, device, tp_size, compile, tol, loss_class): with torch.no_grad(): output = qwen_model( seq_ctx=seq_ctx, - loss_ctx=loss_ctx, + loss_ctx={"lm": loss_ctx}, ) loss = output["loss"] self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=tol, rtol=tol)) @@ -120,7 +120,7 @@ def test_fsdp_accuracy(self, device, tp_size): seq_ctx = SequenceContext.from_input_ids(input_ids=(shift_input_ids.to('cuda'),)) seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] @@ -131,7 +131,7 @@ def test_fsdp_accuracy(self, device, tp_size): with torch.no_grad(): output = qwen_model( seq_ctx=seq_ctx, - loss_ctx=loss_ctx, + loss_ctx={"lm": loss_ctx}, ) loss = output["loss"] self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=1e-2, rtol=1e-2)) @@ -196,7 +196,7 @@ def test_sliding_windows(self, use_sliding_window, max_window_layers, sliding_wi seq_ctx = SequenceContext.from_input_ids(input_ids=(shift_input_ids.to('cuda'),)) seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] @@ -207,7 +207,7 @@ def test_sliding_windows(self, use_sliding_window, max_window_layers, sliding_wi with torch.no_grad(): output = qwen_model( seq_ctx=seq_ctx, - loss_ctx=loss_ctx, + loss_ctx={"lm": loss_ctx}, ) assert "loss" in output diff --git a/tests/model/test_qwen3_moe.py b/tests/model/test_qwen3_moe.py index 4113be5a8..dbf5e7366 100644 --- a/tests/model/test_qwen3_moe.py +++ b/tests/model/test_qwen3_moe.py @@ -98,7 +98,7 @@ def test_qwen3_moe_run(self, device, dispatcher, ep_size, compile, tol, loss_mod loss_cfg = CELossConfig(mode=loss_mode) seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] @@ -107,7 +107,7 @@ def test_qwen3_moe_run(self, device, dispatcher, ep_size, compile, tol, loss_mod with torch.no_grad(): output = qwen_model( seq_ctx=seq_ctx, - loss_ctx=loss_ctx, + loss_ctx={"lm": loss_ctx}, ) loss = output["loss"] losses.append(loss) @@ -181,7 +181,7 @@ def test_fsdp_accuracy(self, device, dispatcher, ep_size, model_type): loss_cfg = CELossConfig() seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] @@ -190,7 +190,7 @@ def test_fsdp_accuracy(self, device, dispatcher, ep_size, model_type): with torch.no_grad(): output = qwen_model( seq_ctx=seq_ctx, - loss_ctx=loss_ctx, + loss_ctx={"lm": loss_ctx}, ) loss = output["loss"] losses.append(loss) @@ -257,7 +257,7 @@ def test_sliding_windows(self, use_sliding_window, max_window_layers, sliding_wi seq_ctx = SequenceContext.from_input_ids(input_ids=(shift_input_ids.to('cuda'),)) seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] @@ -268,7 +268,7 @@ def test_sliding_windows(self, use_sliding_window, max_window_layers, sliding_wi with torch.no_grad(): output = qwen_model( seq_ctx=seq_ctx, - loss_ctx=loss_ctx, + loss_ctx={"lm": loss_ctx}, ) assert "loss" in output diff --git a/tests/model/test_qwen3_tile_embedding.py b/tests/model/test_qwen3_tile_embedding.py index 79869da87..b931395c5 100644 --- a/tests/model/test_qwen3_tile_embedding.py +++ b/tests/model/test_qwen3_tile_embedding.py @@ -78,12 +78,12 @@ def warmup_fn(x): seq_ctx.num_padding = pack_len seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] seq_ctx = seq_ctx_list[0] - engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)] + engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})] engine.train_step(engine_input) grad_norm = engine.clip_grad_norm() engine.step_optimizer(grad_norm) @@ -153,12 +153,12 @@ def warmup_fn(x): seq_ctx.num_padding = pack_len seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] seq_ctx = seq_ctx_list[0] - engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)] + engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})] engine.train_step(engine_input) grad_norm = engine.clip_grad_norm() engine.step_optimizer(grad_norm) diff --git a/tests/model/test_qwen3_vl.py b/tests/model/test_qwen3_vl.py index f6027e879..4c6dc7ab4 100644 --- a/tests/model/test_qwen3_vl.py +++ b/tests/model/test_qwen3_vl.py @@ -139,7 +139,7 @@ def _test_all(self, hf_model, qwen3vl_model, type, device, sp_size, tol): seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=sp_mesh) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=sp_mesh) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] diff --git a/xtuner/v1/engine/train_engine.py b/xtuner/v1/engine/train_engine.py index 7d12bc27d..7510a080b 100644 --- a/xtuner/v1/engine/train_engine.py +++ b/xtuner/v1/engine/train_engine.py @@ -216,7 +216,7 @@ def train_step(self, data_batches: list[ModelItem]) -> TrainStepInfo: # Here we assume that the model can handle a list of seq_ctx and loss_ctx. output = self.model( seq_ctx=seq_ctx_list, - loss_ctx=loss_ctx_list, + loss_ctx=loss_ctx_list, # type: ignore[arg-type] ) output.free_nongrad_feature() diff --git a/xtuner/v1/loss/__init__.py b/xtuner/v1/loss/__init__.py index 96973e0b9..1aa575601 100644 --- a/xtuner/v1/loss/__init__.py +++ b/xtuner/v1/loss/__init__.py @@ -1,12 +1,27 @@ from .base_loss_ctx import BaseLossConfig, BaseLossContext, BaseLossKwargs from .ce_loss import CELossConfig, CELossContext from .chunk_loss import ChunkLoss -from .moe_loss import BalancingLoss, ZLoss +from .moe_loss import ( + BalancingLoss, + BalancingLossConfig, + BalancingLossContext, + BalancingLossKwargs, + ZLoss, + ZLossConfig, + ZLossContext, + ZLossKwargs, +) __all__ = [ "BalancingLoss", + "BalancingLossConfig", + "BalancingLossContext", + "BalancingLossKwargs", "ZLoss", + "ZLossConfig", + "ZLossContext", + "ZLossKwargs", "CELossContext", "CELossConfig", "ChunkLoss", diff --git a/xtuner/v1/loss/base_loss_ctx.py b/xtuner/v1/loss/base_loss_ctx.py index 8ae297b6f..860873c41 100644 --- a/xtuner/v1/loss/base_loss_ctx.py +++ b/xtuner/v1/loss/base_loss_ctx.py @@ -121,8 +121,25 @@ def loss_ctx_cls(self) -> type["BaseLossContext"]: def _loss_kwargs_cls(self) -> type["BaseLossKwargs"]: raise NotImplementedError - def build(self, *args, **kwargs) -> "BaseLossContext": - raise NotImplementedError + @abstractmethod + def build( + self, + data: dict, + sp_mesh: "DeviceMesh | None" = None, + ) -> "BaseLossContext | None": + """Build loss context from data dict. + + Subclasses should extract required fields from data dict and construct loss_kwargs. + + Args: + data (dict): Data dict containing all possible loss-related fields. + Different loss configs extract different fields as needed. + sp_mesh (DeviceMesh | None): Sequence parallel mesh. + + Returns: + BaseLossContext: Built loss context. + """ + ... # NOTE: Self type for BaseLossContext subclasses (F-bounded polymorphism) diff --git a/xtuner/v1/loss/ce_loss.py b/xtuner/v1/loss/ce_loss.py index 7d70a43e0..da305783e 100644 --- a/xtuner/v1/loss/ce_loss.py +++ b/xtuner/v1/loss/ce_loss.py @@ -43,9 +43,25 @@ def model_post_init(self, __context: Any) -> None: def build( self, - shifted_labels: torch.Tensor, + data: dict, sp_mesh: DeviceMesh | None = None, - ) -> "CELossContext": + ) -> "CELossContext | None": + """Build CELossContext from data dict. + + Args: + data (dict): Data dict containing loss-related fields. + Required: shifted_labels + sp_mesh (DeviceMesh | None): Sequence parallel mesh. + + Returns: + CELossContext | None: Built loss context. Returns None if shifted_labels + is not present in data dict. + """ + if "shifted_labels" not in data: + return None + # Extract required fields from data + shifted_labels = data["shifted_labels"] + loss_kwargs = CELossKwargs(shifted_labels=shifted_labels).to(DEVICE) if sp_mesh is not None and sp_mesh.size() > 1: loss_kwargs = loss_kwargs.sp_split(sp_mesh) diff --git a/xtuner/v1/loss/moe_loss.py b/xtuner/v1/loss/moe_loss.py index d5aa64263..bd3d776ef 100644 --- a/xtuner/v1/loss/moe_loss.py +++ b/xtuner/v1/loss/moe_loss.py @@ -1,9 +1,17 @@ -from typing import Literal +from typing import Annotated, Any, Literal import torch import torch.nn as nn +from cyclopts import Parameter +from pydantic import BaseModel, ConfigDict from torch import distributed as dist from torch.distributed._functional_collectives import all_reduce +from torch.distributed.device_mesh import DeviceMesh + +from xtuner.v1.utils.device import get_device + + +DEVICE = get_device() class _AllReduce(torch.autograd.Function): @@ -106,3 +114,230 @@ def forward(self, router_logits): return torch.tensor(0.0, device=router_logits.device, dtype=torch.float32) loss = z_loss(router_logits, self.global_average) return loss * self.loss_weight + + +# ==================== New LossContext-based implementation ==================== + + +class BalancingLossConfig(BaseModel): + """Balancing loss configuration for MoE models. + + Args: + balancing_loss_alpha (float): Weight for the balancing loss. Defaults to 0.001. + balancing_loss_global_average (bool): Whether to perform global averaging across all ranks. + Defaults to True. + router_scoring_func (str): Router scoring function type. Options are "sigmoid" and "softmax". + Defaults to "softmax". + """ + + model_config = ConfigDict(extra="forbid") + balancing_loss_alpha: Annotated[float, Parameter(help="weight for balancing loss")] = 0.001 + balancing_loss_global_average: Annotated[bool, Parameter(help="global average for balancing loss")] = True + router_scoring_func: Annotated[Literal["sigmoid", "softmax"], Parameter(help="router scoring function")] = ( + "softmax" + ) + + def build(self) -> "BalancingLossContext": + """Build BalancingLossContext. + + Returns: + BalancingLossContext: Built loss context. + """ + loss_kwargs = BalancingLossKwargs() + return BalancingLossContext(self, loss_kwargs) + + +class BalancingLossKwargs(BaseModel): + """Keyword arguments for balancing loss computation. + + This class is empty as all parameters are passed to forward(). + """ + + model_config = ConfigDict(title="balancing loss keyword arguments", extra="forbid", arbitrary_types_allowed=True) + + +class BalancingLossContext(nn.Module): + """Balancing loss context for MoE models. + + Args: + loss_cfg (BalancingLossConfig): The configuration for the balancing loss. + loss_kwargs (BalancingLossKwargs): The keyword arguments for the balancing loss. + """ + + def __init__(self, loss_cfg: BalancingLossConfig, loss_kwargs: BalancingLossKwargs): + super().__init__() + self.loss_cfg = loss_cfg + self.loss_kwargs = loss_kwargs + self._batch_size = 1 + + @staticmethod + def build_batches( + loss_ctx_list: list["BalancingLossContext"], + ) -> list["BalancingLossContext"]: + """Build batches for balancing loss contexts. + + For balancing loss, we set the batch size for proper gradient accumulation. + + Args: + loss_ctx_list (list[BalancingLossContext]): List of loss contexts. + + Returns: + list[BalancingLossContext]: The same list with batch_size set. + """ + for loss_ctx in loss_ctx_list: + loss_ctx._batch_size = len(loss_ctx_list) + return loss_ctx_list + + def forward( + self, + router_weights: torch.Tensor, + n_routed_experts: int, + num_experts_per_tok: int, + ) -> torch.Tensor: + """Compute balancing loss. + + Args: + router_weights (torch.Tensor): Router weights. Shape: (num_layers, seq_len, num_experts). + n_routed_experts (int): Number of routed experts. + num_experts_per_tok (int): Number of experts per token. + + Returns: + torch.Tensor: Balancing loss value. + """ + if self.loss_cfg.balancing_loss_alpha == 0: + return torch.tensor(0.0, device=router_weights.device, dtype=torch.float32) + + num_layers = router_weights.shape[0] + router_weights = router_weights.float() # (nlayers, seq, ne) + _, selected_experts = torch.topk(router_weights, num_experts_per_tok, dim=-1) + selected_experts_flat = selected_experts.view(num_layers, -1) + offset = torch.arange(num_layers, device=router_weights.device).unsqueeze(1) * n_routed_experts + selected_experts_offset = selected_experts_flat + offset + tokens_per_expert_flat = torch.histc( + selected_experts_offset.view(-1), + bins=num_layers * n_routed_experts, + min=0, + max=num_layers * n_routed_experts, + ) + tokens_per_expert = tokens_per_expert_flat.view(num_layers, n_routed_experts) # (nlayers, ne) + + tokens_per_expert_global = tokens_per_expert.to(router_weights.dtype) # (nlayers, ne) + if self.loss_cfg.balancing_loss_global_average and dist.is_initialized(): + tokens_per_expert_global = all_reduce(tokens_per_expert_global, "sum", dist.group.WORLD) + tokens_global = tokens_per_expert_global.sum(-1) # (nlayers, ) + seqlen_global = tokens_global // num_experts_per_tok + routing_weights_sum_global = all_reduce_autograd(router_weights.sum(dim=1), "sum", dist.group.WORLD) + routing_weights_mean_global = routing_weights_sum_global / seqlen_global.unsqueeze(-1) + scale_global = n_routed_experts / tokens_global + else: + scale_global = n_routed_experts / (router_weights.shape[1] * num_experts_per_tok) + routing_weights_mean_global = router_weights.mean(dim=1) + + loss = scale_global * (tokens_per_expert_global * routing_weights_mean_global).sum(-1) + loss = loss.sum() * self.loss_cfg.balancing_loss_alpha + + # Normalize by batch size for proper gradient accumulation + loss = loss / self._batch_size + + return loss + + @property + def batch_size(self) -> int: + return self._batch_size + + +class ZLossConfig(BaseModel): + """Z-loss configuration for MoE models. + + Args: + z_loss_alpha (float): Weight for the z-loss. Defaults to 0.001. + z_loss_global_average (bool): Whether to perform global averaging across all ranks. + Defaults to True. + """ + + model_config = ConfigDict(extra="forbid") + z_loss_alpha: Annotated[float, Parameter(help="weight for z-loss")] = 0.001 + z_loss_global_average: Annotated[bool, Parameter(help="global average for z-loss")] = True + + def build(self) -> "ZLossContext": + """Build ZLossContext. + + Returns: + ZLossContext: Built loss context. + """ + loss_kwargs = ZLossKwargs() + return ZLossContext(self, loss_kwargs) + + +class ZLossKwargs(BaseModel): + """Keyword arguments for z-loss computation.""" + + model_config = ConfigDict(title="z-loss keyword arguments", extra="forbid", arbitrary_types_allowed=True) + + +class ZLossContext(nn.Module): + """Z-loss context for MoE models. + + Args: + loss_cfg (ZLossConfig): The configuration for the z-loss. + loss_kwargs (ZLossKwargs): The keyword arguments for the z-loss. + """ + + def __init__(self, loss_cfg: ZLossConfig, loss_kwargs: ZLossKwargs): + super().__init__() + self.loss_cfg = loss_cfg + self.loss_kwargs = loss_kwargs + self._batch_size = 1 + + @staticmethod + def build_batches( + loss_ctx_list: list["ZLossContext"], + ) -> list["ZLossContext"]: + """Build batches for z-loss contexts. + + For z-loss, we set the batch size for proper gradient accumulation. + + Args: + loss_ctx_list (list[ZLossContext]): List of loss contexts. + + Returns: + list[ZLossContext]: The same list with batch_size set. + """ + for loss_ctx in loss_ctx_list: + loss_ctx._batch_size = len(loss_ctx_list) + return loss_ctx_list + + def forward(self, router_logits: torch.Tensor) -> torch.Tensor: + """Compute z-loss. + + Args: + router_logits (torch.Tensor): Router logits. Shape: (num_layers, seq_len, num_experts). + + Returns: + torch.Tensor: Z-loss value. + """ + if self.loss_cfg.z_loss_alpha == 0: + return torch.tensor(0.0, device=router_logits.device, dtype=torch.float32) + + router_logits = router_logits.float() # (nlayers, seq, ne) + num_seq = max(1, router_logits.shape[1]) + logsum_square = torch.logsumexp(router_logits, dim=-1).square() + loss = (logsum_square.sum(dim=-1) / num_seq).sum() + + if self.loss_cfg.z_loss_global_average and dist.is_initialized(): + unmasked_num = router_logits.shape[1] + unmasked_num_rank = torch.tensor(unmasked_num, device=router_logits.device, dtype=torch.int64) + unmasked_num_global = all_reduce(unmasked_num_rank, "sum", dist.group.WORLD) + world_size = dist.get_world_size() + loss = loss * unmasked_num * world_size / unmasked_num_global + + loss = loss * self.loss_cfg.z_loss_alpha + + # Normalize by batch size for proper gradient accumulation + loss = loss / self._batch_size + + return loss + + @property + def batch_size(self) -> int: + return self._batch_size diff --git a/xtuner/v1/model/base.py b/xtuner/v1/model/base.py index d097800c7..92d39caee 100644 --- a/xtuner/v1/model/base.py +++ b/xtuner/v1/model/base.py @@ -8,7 +8,7 @@ from itertools import chain from pathlib import Path from shutil import copy, copytree -from typing import Annotated, Generator, Iterable, Literal, Mapping, Sequence, cast +from typing import Annotated, Any, Generator, Iterable, Literal, Mapping, Sequence, TypedDict, cast import torch import torch.distributed as dist @@ -39,7 +39,7 @@ WeightWithDynamicTensorWiseFloat8CastTensor, WeightWithDynamicTilewiseFloat8CastTensor, ) -from xtuner.v1.loss import BaseLossContext +from xtuner.v1.loss import BaseLossConfig, BaseLossContext, CELossConfig from xtuner.v1.module.attention import GatedDeltaNetConfig, MHAConfig, MLAConfig from xtuner.v1.module.rope import RopeScalingConfig from xtuner.v1.ops.comm.foreach_allgather import foreach_all_gather @@ -98,6 +98,7 @@ class XTunerBaseModelConfig(PydanticBaseModel): ] = None hf_key_mapping: Annotated[dict[str, str] | None, "Remapping hf key based on the `to_hf_key_list`"] = None dcp_ignore_frozen_params: bool = True + lm_loss_cfg: BaseLossConfig = CELossConfig() @property def hf_config(self) -> PretrainedConfig | None: @@ -242,7 +243,7 @@ def _is_float8_available(): class ModelItem(TypedDict): seq_ctx: SequenceContext - loss_ctx: BaseLossContext + loss_ctx: dict[str, BaseLossContext] | None def is_float8_weight(tensor): @@ -594,6 +595,81 @@ def _to_float8( name_list_new.extend([name, f"{name}_scale_inv"]) return gathered_tensor_list_new, name_list_new + def build_loss_ctx_batch( + self, + data_batch: list[dict], + sp_mesh: DeviceMesh | None = None, + ) -> list[dict[str, dict]]: + """Build and calibrate loss contexts for the entire batch. + + For Dense model, only LM loss is needed. + + Args: + data_batch (list[dict]): All microbatch data + sp_mesh (DeviceMesh | None): Sequence parallel mesh + cu_seq_lens_list (list[torch.IntTensor] | None): For calibration + + Returns: + list[dict[str, BaseLossContext]]: Loss context dict for each microbatch + """ + cu_seq_lens_list = [data["seq_ctx"].cu_seq_lens_k for data in data_batch] + res: list[dict] = [{} for _ in range(len(data_batch))] + + lm_loss_ctx_list = self._build_loss_ctx(self.config.lm_loss_cfg, data_batch, sp_mesh) + + if lm_loss_ctx_list is not None: + loss_ctx_cls = lm_loss_ctx_list[0].__class__ + lm_loss_ctx_list = loss_ctx_cls.build_batches( + lm_loss_ctx_list, cu_seq_lens_list=cu_seq_lens_list, sp_mesh=sp_mesh) + + if lm_loss_ctx_list is not None: + for i, lm_loss_ctx in enumerate(lm_loss_ctx_list): + res[i]["lm"] = lm_loss_ctx + + return res + + def _add_auxiliary_loss( + self, + loss_name: str, + loss_cfg: Any, + data_batch: list[dict], + res: list[dict], + ) -> None: + """Add auxiliary loss contexts to result. + + This helper builds loss contexts, calibrates them across the batch, + and adds them to the result dictionary. If loss_cfg is None, does nothing. + + Args: + loss_name (str): Name of the loss (e.g., "balancing", "z_loss"). + loss_cfg (Any): Loss configuration with a build() method. If None, skipped. + data_batch (list[dict]): Batch data. + res (list[dict]): Result dictionary to populate. Modified in-place. + + Example: + def build_loss_ctx_batch(self, data_batch, sp_mesh): + res = super().build_loss_ctx_batch(data_batch, sp_mesh) + + # One line per auxiliary loss + self._add_auxiliary_loss("balancing", self.config.balancing_loss_cfg, data_batch, res) + self._add_auxiliary_loss("z_loss", self.config.z_loss_cfg, data_batch, res) + + return res + """ + if loss_cfg is None: + return + + # Build loss contexts for all microbatches + ctx_list = [loss_cfg.build() for _ in data_batch] + + # Calibrate across batch + ctx_cls = ctx_list[0].__class__ + ctx_list = ctx_cls.build_batches(ctx_list) + + # Add to result + for i, ctx in enumerate(ctx_list): + res[i][loss_name] = ctx # type: ignore + def pre_micro_batch_forward(self, data_batches: Sequence[ModelItem]) -> DataBatchInfo: step_consumed_tokens = torch.tensor(0, device=DEVICE) efficient_forward_tokens = torch.tensor(0, device=DEVICE, dtype=torch.long) @@ -1634,19 +1710,37 @@ def _collect_full_state_dict(self, module: nn.Module): ret[name] = param return ret + def _build_loss_ctx( + self, + loss_ctx_cfg: BaseLossConfig | None, + data_batch: list[dict], + sp_mesh: DeviceMesh | None + ) -> list[BaseLossContext] | None: + if loss_ctx_cfg is None: + return None + + first_loss_ctx = loss_ctx_cfg.build(data=data_batch[0], sp_mesh=sp_mesh) + # If first build returns None, assume all data in the batch have the same schema + # and will also return None (e.g., missing required fields like shifted_labels) + if first_loss_ctx is None: + return None + else: + ret = [first_loss_ctx] + [ + loss_ctx_cfg.build(data=data, sp_mesh=sp_mesh) for data in data_batch[1:]] + return ret # NOTE: Add this overload for inferring the return type for easier type checking and using @overload # type: ignore def __call__( # type: ignore self, seq_ctx: SequenceContext, - loss_ctx: BaseLossContext | None, + loss_ctx: dict[str, BaseLossContext] | None, ) -> ModelOutputs: ... @overload # type: ignore def __call__( # type: ignore self, seq_ctx: list[SequenceContext], - loss_ctx: list[BaseLossContext], + loss_ctx: list[dict[str, BaseLossContext]], ) -> ModelOutputs: ... __call__ = nn.Module.__call__ diff --git a/xtuner/v1/model/compose/base.py b/xtuner/v1/model/compose/base.py index c9cb432a0..7181d521a 100644 --- a/xtuner/v1/model/compose/base.py +++ b/xtuner/v1/model/compose/base.py @@ -5,7 +5,7 @@ import torch import torch.distributed as dist from pydantic import ConfigDict -from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torch.distributed.fsdp import ( CPUOffloadPolicy, FSDPModule, @@ -15,6 +15,7 @@ from typing_extensions import override from xtuner.v1.config import FSDPConfig +from xtuner.v1.loss import BaseLossContext from xtuner.v1.model import BaseModel from xtuner.v1.model.base import DataBatchInfo, ModelItem, XTunerBaseModelConfig from xtuner.v1.utils import get_device, get_logger @@ -169,6 +170,15 @@ def save_hf(self, hf_dir: Path | str, save_dtype: torch.dtype = torch.bfloat16, def scale_and_reduce_grad(self): self.language_model.scale_and_reduce_grad() + @override + def build_loss_ctx_batch( # type: ignore[override] + self, + data_batch: list[dict], + sp_mesh: DeviceMesh | None = None, + ) -> list[dict[str, BaseLossContext]]: + """Delegate loss_ctx building to the language model.""" + return self.language_model.build_loss_ctx_batch(data_batch, sp_mesh=sp_mesh) + def pre_micro_batch_forward(self, data_batches: Sequence[ModelItem]) -> DataBatchInfo: data_batch_info = cast(ComposeDataBatchInfo, super().pre_micro_batch_forward(data_batches)) step_consumed_img_tokens = 0.0 diff --git a/xtuner/v1/model/compose/intern_s1/modeling_intern_s1.py b/xtuner/v1/model/compose/intern_s1/modeling_intern_s1.py index c82028a21..d5fb87507 100644 --- a/xtuner/v1/model/compose/intern_s1/modeling_intern_s1.py +++ b/xtuner/v1/model/compose/intern_s1/modeling_intern_s1.py @@ -122,7 +122,7 @@ def extract_feature(self, pixel_values): def forward( self, seq_ctx: SequenceContext, - loss_ctx: CELossContext + loss_ctx: dict[str, CELossContext] | None = None ) -> MoEModelOutputs: input_ids = seq_ctx.input_ids pixel_values = seq_ctx.pixel_values diff --git a/xtuner/v1/model/compose/qwen3_vl/modeling_qwen3_vl.py b/xtuner/v1/model/compose/qwen3_vl/modeling_qwen3_vl.py index 453fa2ba7..bade4c6c6 100644 --- a/xtuner/v1/model/compose/qwen3_vl/modeling_qwen3_vl.py +++ b/xtuner/v1/model/compose/qwen3_vl/modeling_qwen3_vl.py @@ -136,7 +136,7 @@ def get_placeholder_mask( def forward( self, seq_ctx: SequenceContext, - loss_ctx: CELossContext + loss_ctx: dict[str, CELossContext] | None = None ) -> MoEModelOutputs: input_ids = seq_ctx.input_ids pixel_values = seq_ctx.pixel_values diff --git a/xtuner/v1/model/dense/dense.py b/xtuner/v1/model/dense/dense.py index 46b123401..3bba65470 100644 --- a/xtuner/v1/model/dense/dense.py +++ b/xtuner/v1/model/dense/dense.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from pathlib import Path -from typing import Self, cast +from typing import Self, cast, Literal import torch import torch.distributed as dist @@ -20,7 +20,7 @@ from xtuner.v1.config import FSDPConfig from xtuner.v1.data_proto import SequenceContext from xtuner.v1.float8.float8_handler import Float8Handler -from xtuner.v1.loss import CELossContext +from xtuner.v1.loss import CELossContext, BaseLossContext from xtuner.v1.model.base import ( DEFAULT_FLOAT8_CFG, BaseModel, @@ -79,7 +79,7 @@ def __init__(self, config: TransformerConfig): def forward( self, seq_ctx: SequenceContext, # todo(@yehaochen): support intra layer micro-batch - loss_ctx: CELossContext, + loss_ctx: dict[Literal["lm"], BaseLossContext] | None = None, ) -> ModelOutputs: input_ids = seq_ctx.input_ids position_ids = seq_ctx.position_ids @@ -111,10 +111,17 @@ def forward( hidden_states = self.norm(hidden_states) - loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx) - output["loss"] = loss - output["logits"] = logits - output["extra_info"] = extra_info + if loss_ctx is None: + # Inference mode + logits = F.linear(hidden_states, self.lm_head.weight, self.lm_head.bias) + output["logits"] = logits + else: + # Training mode + loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx["lm"]) + output["loss"] = loss + output["logits"] = logits + output["extra_info"] = extra_info + return ModelOutputs(**output) def build_embeddings(self, config: TransformerConfig): @@ -167,7 +174,7 @@ def default_compile_cfg(self) -> dict[str, TorchCompileOption]: def __call__( # type: ignore self, seq_ctx: SequenceContext, - loss_ctx: CELossContext, + loss_ctx: dict[str, CELossContext] | None = None, ) -> ModelOutputs: ... __call__ = nn.Module.__call__ diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py index 6da9f0921..7a4edd67e 100644 --- a/xtuner/v1/model/moe/moe.py +++ b/xtuner/v1/model/moe/moe.py @@ -2,7 +2,7 @@ import os import types from pathlib import Path -from typing import Annotated, Literal, Self, Sequence, cast +from typing import Annotated, Literal, Self, Sequence, cast, TypedDict import torch import torch.distributed as dist @@ -27,7 +27,18 @@ from xtuner.v1.config import FSDPConfig from xtuner.v1.data_proto import SequenceContext from xtuner.v1.float8.float8_handler import Float8Handler -from xtuner.v1.loss import BalancingLoss, CELossContext, ZLoss +from xtuner.v1.loss import ( + BalancingLoss, + BalancingLossConfig, + BalancingLossContext, + BalancingLossKwargs, + CELossContext, + ZLoss, + ZLossConfig, + ZLossContext, + ZLossKwargs, + BaseLossContext, +) from xtuner.v1.model.base import ( DEFAULT_FLOAT8_CFG, BaseModel, @@ -109,31 +120,10 @@ class MoEBatchForwardInfo(BatchForwardInfo): maxvio: float -class BalancingLossConfig(PydanticBaseModel): - model_config = ConfigDict(extra="forbid") - balancing_loss_alpha: float = 0.001 - balancing_loss_global_average: bool = True - - def build(self, router_scoring_func) -> BalancingLoss: - return BalancingLoss( - self.balancing_loss_alpha, - self.balancing_loss_global_average, - router_scoring_func=router_scoring_func, - ) - - -class ZLossConfig(PydanticBaseModel): - model_config = ConfigDict(extra="forbid") - z_loss_alpha: float = 0.001 - z_loss_global_average: bool = True - - def build(self) -> "ZLoss": - from xtuner.v1.loss import ZLoss - - return ZLoss( - self.z_loss_alpha, - self.z_loss_global_average, - ) +class MoELossContextDict(TypedDict): + lm: BaseLossContext + balancing: BalancingLossContext | None + z_loss: ZLossContext | None class MoEConfig(TransformerConfig): @@ -205,17 +195,6 @@ def __init__(self, config: MoEConfig): self._init_load_spec() self._maybe_enable_compile(self.compile_cfg) - self.balancing_loss: BalancingLoss | None - self.z_loss: ZLoss | None - if self.config.balancing_loss_cfg is not None: - self.balancing_loss = self.config.balancing_loss_cfg.build(self.config.router.scoring_func) - else: - self.balancing_loss = None - if self.config.z_loss_cfg is not None: - self.z_loss = self.config.z_loss_cfg.build() - else: - self.z_loss = None - self.offload_stream = torch.cuda.Stream() def _select_non_pad_router_logits( @@ -304,10 +283,44 @@ def update_bias(self, total_expert_counts_pre_iter, expected_loads): e_score_correction_bias.add_(updates) + def build_loss_ctx_batch( + self, + data_batch: list["ColateItem"], + sp_mesh: DeviceMesh | None = None, + ) -> list[MoELossContextDict]: + """Build and calibrate loss contexts for MoE model. + + Args: + data_batch (list[dict]): All microbatch data + sp_mesh (DeviceMesh | None): Sequence parallel mesh + cu_seq_lens_list (list[torch.IntTensor] | None): For calibration + + Returns: + list[dict]: Loss context dict for each microbatch. + Each dict contains: + - "lm": LM loss context + - "balancing": Balancing loss context (if configured) + - "z_loss": Z-loss context (if configured) + + Note: + Auxiliary loss contexts are built without parameters. + All data is passed to forward() at runtime: + - balancing_ctx(router_weights, n_routed_experts, num_experts_per_tok) + - z_loss_ctx(router_logits) + """ + # Build LM loss context + res = super().build_loss_ctx_batch(data_batch, sp_mesh) + + # Add auxiliary losses + self._add_auxiliary_loss("balancing", self.config.balancing_loss_cfg, data_batch, res) + self._add_auxiliary_loss("z_loss", self.config.z_loss_cfg, data_batch, res) + + return res + def forward( self, seq_ctx: list[SequenceContext] | SequenceContext, - loss_ctx: list[CELossContext] | CELossContext | None, + loss_ctx: list[MoELossContextDict] | MoELossContextDict | None, return_router_logits: bool = False, ): # TODO: caoweihan: Recover this assertion after the refactor of LossContext @@ -362,7 +375,7 @@ def post_micro_batch_forward(self, batch_outputs: Sequence[MoEModelOutputs]) -> def _micro_batch_forward( self, seq_ctx_list: list[SequenceContext], - loss_ctx_list: list[CELossContext], + loss_ctx_list: list[MoELossContextDict], return_router_logits: bool = False, ) -> MoEModelOutputs: """Micro-batch forward pass for MoE model. @@ -463,7 +476,9 @@ def _micro_batch_forward( cat_hidden_states = self.norm(cat_hidden_states) # Process final outputs for each micro-batch - cat_loss_ctx = CELossContext.cat(loss_ctx_list) + # Extract LM loss context from dict + lm_loss_ctx_list = [loss_ctx_dict["lm"] for loss_ctx_dict in loss_ctx_list] + cat_loss_ctx = type(lm_loss_ctx_list[0]).cat(lm_loss_ctx_list) loss, (logits, extra_info) = self.lm_head(cat_hidden_states, cat_loss_ctx) # Aggregate losses (mean across micro-batches) @@ -495,23 +510,35 @@ def _micro_batch_forward( combined_router_logits = torch.cat(all_router_logits, dim=1) # [num_layers, total_seq, num_experts] combined_router_weights = torch.cat(all_router_weights, dim=1) - # Calculate balancing loss across all micro-batches - batch_size = loss_ctx_list[0].batch_size if loss_ctx_list else 1 - if self.balancing_loss: - balancing_loss = ( - self.balancing_loss( - router_weights=combined_router_weights, - n_routed_experts=self.config.n_routed_experts, - num_experts_per_tok=self.config.num_experts_per_tok, + # Build balancing loss contexts + balancing_loss_ctx_list: list[BalancingLossContext] = [] + for loss_ctx_dict in loss_ctx_list: + bal_ctx = loss_ctx_dict.get("balancing") + if bal_ctx is not None: + balancing_loss_ctx_list.append(bal_ctx) + + if balancing_loss_ctx_list: + # Compute balancing loss by passing all parameters to forward + balancing_loss = sum( + ctx( + combined_router_weights, + self.config.n_routed_experts, + self.config.num_experts_per_tok, ) - / batch_size - * len(seq_ctx_list) + for ctx in balancing_loss_ctx_list ) output["balancing_loss"] = balancing_loss - # Calculate z-loss across all micro-batches - if self.z_loss: - z_loss = self.z_loss(router_logits=combined_router_logits) / batch_size * len(seq_ctx_list) + # Calculate z-loss across all micro-batches using loss context + z_loss_ctx_list: list[ZLossContext] = [] + for loss_ctx_dict in loss_ctx_list: + z_ctx = loss_ctx_dict.get("z_loss") + if z_ctx is not None: + z_loss_ctx_list.append(z_ctx) + + if z_loss_ctx_list: + # Compute z-loss by passing router_logits to forward + z_loss = sum(ctx(combined_router_logits) for ctx in z_loss_ctx_list) output["z_loss"] = z_loss # Calculate tokens per expert for bias update (if applicable) @@ -542,7 +569,7 @@ def _micro_batch_forward( def _forward( self, seq_ctx: SequenceContext, # todo(@yehaochen): support intra layer micro-batch - loss_ctx: CELossContext | None, + loss_ctx: MoELossContextDict | None, return_router_logits: bool = False, ) -> MoEModelOutputs: input_ids = seq_ctx.input_ids @@ -603,7 +630,9 @@ def _forward( hidden_states = self.norm(hidden_states) - loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx) # type: ignore + # Get LM loss context from dict + lm_loss_ctx = loss_ctx["lm"] if loss_ctx is not None else None + loss, (logits, extra_info) = self.lm_head(hidden_states, lm_loss_ctx) # type: ignore output["loss"] = loss output["logits"] = logits output["extra_info"] = extra_info @@ -613,21 +642,25 @@ def _forward( router_logits = self._select_non_pad_router_logits(router_logits_list, seq_ctx.mask) router_weights = self._select_non_pad_router_logits(router_weights_list, seq_ctx.mask) - batch_size = loss_ctx.batch_size if loss_ctx is not None else 1 - if self.balancing_loss: - balancing_loss = ( - self.balancing_loss( - router_weights=router_weights, - n_routed_experts=self.config.n_routed_experts, - num_experts_per_tok=self.config.num_experts_per_tok, + # Calculate balancing loss using loss context + if loss_ctx is not None: + balancing_ctx = loss_ctx.get("balancing") + if balancing_ctx is not None: + # Compute balancing loss by passing all parameters to forward + balancing_loss = balancing_ctx( + router_weights, + self.config.n_routed_experts, + self.config.num_experts_per_tok, ) - / batch_size - ) - output["balancing_loss"] = balancing_loss + output["balancing_loss"] = balancing_loss - if self.z_loss: - z_loss = self.z_loss(router_logits=router_logits) / batch_size - output["z_loss"] = z_loss + # Calculate z-loss using loss context + if loss_ctx is not None: + z_loss_ctx = loss_ctx.get("z_loss") + if z_loss_ctx is not None: + # Compute z-loss by passing router_logits to forward + z_loss = z_loss_ctx(router_logits) + output["z_loss"] = z_loss tokens_per_expert_global = self._cal_tokens_per_expert(router_logits) output["tokens_per_expert_global"] = tokens_per_expert_global @@ -983,14 +1016,14 @@ def patched_emb_forward(self, input): def __call__( # type: ignore self, seq_ctx: SequenceContext, - loss_ctx: CELossContext | None, + loss_ctx: MoELossContextDict | None, ) -> MoEModelOutputs: ... @overload # type: ignore def __call__( # type: ignore self, seq_ctx: list[SequenceContext], - loss_ctx: list[CELossContext], + loss_ctx: list[MoELossContextDict], ) -> MoEModelOutputs: ... __call__ = nn.Module.__call__ diff --git a/xtuner/v1/rl/base/loss.py b/xtuner/v1/rl/base/loss.py index 8164538b9..5767e7a76 100644 --- a/xtuner/v1/rl/base/loss.py +++ b/xtuner/v1/rl/base/loss.py @@ -96,14 +96,36 @@ def _loss_kwargs_cls(self) -> type["BaseRLLossKwargs"]: def build( self, - sp_mesh: DeviceMesh | None, - shifted_labels: torch.Tensor, - advantages: torch.Tensor, - rollout_logprobs: torch.Tensor | None = None, - old_logprobs: torch.Tensor | None = None, - rollout_is_weights: torch.Tensor | None = None, - ref_logprobs: torch.Tensor | None = None, - ) -> "BaseRLLossContext": + data: dict, + sp_mesh: DeviceMesh | None = None, + ) -> "BaseRLLossContext | None": + """Build RL loss context from data dict. + + Args: + data (dict): Data dictionary containing RL-specific fields: + - shifted_labels (torch.Tensor): The shifted labels + - advantages (torch.Tensor): Advantage estimates + - rollout_logprobs (torch.Tensor | None): Rollout log probabilities + - old_logprobs (torch.Tensor | None): Old policy log probabilities (optional, can be set later) + - rollout_is_weights (torch.Tensor | None): Importance sampling weights + - ref_logprobs (torch.Tensor | None): Reference model log probabilities + sp_mesh (DeviceMesh | None): Sequence parallel device mesh + + Returns: + BaseRLLossContext | None: The built loss context, or None if required fields are missing + """ + # Check for required fields + if "shifted_labels" not in data or "advantages" not in data: + return None + + # Extract RL-specific fields from data + shifted_labels = data["shifted_labels"] + advantages = data["advantages"] + rollout_logprobs = data.get("rollout_logprobs", None) + old_logprobs = data.get("old_logprobs", None) + rollout_is_weights = data.get("rollout_is_weights", None) + ref_logprobs = data.get("ref_logprobs", None) + LossKwargs = self._loss_kwargs_cls loss_kwargs = LossKwargs( shifted_labels=shifted_labels, diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py index b5dc7bb2b..61eca0e64 100644 --- a/xtuner/v1/rl/base/worker.py +++ b/xtuner/v1/rl/base/worker.py @@ -497,10 +497,16 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLo rollout_logprobs = data.get("rollout_logprobs", None) rollout_logprobs = rollout_logprobs.to(DEVICE) if rollout_logprobs is not None else None loss_ctx = loss_cfg.build( - self.sp_mesh, shifted_labels=shifted_labels, advantages=advantages, rollout_logprobs=rollout_logprobs + data={ + "shifted_labels": shifted_labels, + "advantages": advantages, + "rollout_logprobs": rollout_logprobs, + }, + sp_mesh=self.sp_mesh, ) seq_ctx_list.append(seq_ctx) + assert loss_ctx is not None loss_ctx_list.append(loss_ctx) del data_batches @@ -592,8 +598,9 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLo LossContext = loss_cfg.loss_ctx_cls for i in range(0, len(loss_ctx_list), iters_per_step): batches_loss_ctx = loss_ctx_list[i : i + iters_per_step] - batches_loss_ctx = LossContext.build_batches(batches_loss_ctx) - batched_loss_ctx_list.extend(batches_loss_ctx) + batched_loss_ctx_list.extend( + LossContext.build_batches(batches_loss_ctx) # type: ignore[arg-type] + ) # train optimizer steps for i in range(0, len(seq_ctx_list), iters_per_step): @@ -601,7 +608,7 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLo batches_loss_ctx = batched_loss_ctx_list[i : i + iters_per_step] engine_input = [ - ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx) + ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx) # type: ignore[typeddict-item] for seq_ctx, loss_ctx in zip(batches_seq_ctx, batches_loss_ctx) ] @@ -703,7 +710,7 @@ def _train_one_step_sft(self, data_batch): if self.sp_mesh.size() > 1: seq_ctx = seq_ctx.split(sequence_parallel_mesh=self.sp_mesh) seq_ctx_list.append(seq_ctx) - loss_ctx = loss_cfg.build(shifted_labels=data["shifted_labels"], sp_mesh=self.sp_mesh) + loss_ctx = loss_cfg.build(data={"shifted_labels": data["shifted_labels"]}, sp_mesh=self.sp_mesh) loss_ctx_list.append(loss_ctx) del data_batch diff --git a/xtuner/v1/train/trainer.py b/xtuner/v1/train/trainer.py index 38d46b3a0..b61d27dac 100644 --- a/xtuner/v1/train/trainer.py +++ b/xtuner/v1/train/trainer.py @@ -32,7 +32,7 @@ from xtuner.v1.datasets.config import BaseDataloaderConfig, DataloaderConfig, DatasetConfigList from xtuner.v1.engine import TrainEngine from xtuner.v1.engine.train_engine import TrainStepInfo -from xtuner.v1.loss import CELossConfig, CELossContext +from xtuner.v1.loss import CELossConfig from xtuner.v1.model.base import ModelItem, XTunerBaseModelConfig from xtuner.v1.model.moe.moe import MoEConfig from xtuner.v1.patch import patch_default_save_plan @@ -579,7 +579,11 @@ def __init__( global_batch_size = self.data_mesh["dp"].size() self._global_batch_size = global_batch_size + if loss_cfg is None: + loss_cfg = CELossConfig() + self._resolve_config_conflicts(self.tokenizer, model_cfg, dataloader_cfg, fsdp_cfg) + self._resolve_model_loss_cfg(model_cfg, loss_cfg) if dataset_cfg is not None: # TODO: Removed in version 1.1.0 logger.warning("`dataset_cfg` is deprecated, please use `dataloader_cfg.dataset_config_list` instead") @@ -618,8 +622,6 @@ def __init__( self._lr_cfg = lr_cfg self._lr_scheduler = self.build_lr_scheduler(lr_cfg, self.total_step) - if loss_cfg is None: - loss_cfg = CELossConfig() self.loss_cfg = loss_cfg if debug: @@ -784,28 +786,26 @@ def fit(self): self.logger.info(f"Training finished in {time.time() - train_begin:.2f} seconds") def _prepare_model_input(self, data_batch) -> list[ModelItem]: - loss_cfg: CELossConfig = self.loss_cfg seq_ctx_list: list[SequenceContext] = [] - loss_ctx_list: list[CELossContext] = [] + # 1. Extract seq_ctx for data in data_batch: - seq_ctx = data.pop("seq_ctx").to(DEVICE) + seq_ctx = data["seq_ctx"].to(DEVICE) if self.sp_mesh.size() > 1: seq_ctx = seq_ctx.split(sequence_parallel_mesh=self.sp_mesh) seq_ctx_list.append(seq_ctx) - loss_ctx = loss_cfg.build(shifted_labels=data["shifted_labels"], sp_mesh=self.sp_mesh) - loss_ctx_list.append(loss_ctx) + + # 2. Compute cu_seq_lens_list (for calibration) + # 3. Call model's interface to build and calibrate all loss_ctx (done in one shot) + loss_ctx_dict_list = self._engine.model.build_loss_ctx_batch(data_batch, sp_mesh=self.sp_mesh) # TODO: Consider moving data_batch deletion to the caller for better memory management. del data_batch - cu_seq_lens_list = [seq_ctx.cu_seq_lens_q for seq_ctx in seq_ctx_list] - loss_ctx_list = CELossContext.build_batches( - loss_ctx_list, cu_seq_lens_list=cu_seq_lens_list, sp_mesh=self.sp_mesh - ) - + # 4. Return ModelItem engine_input = [ - ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx) for seq_ctx, loss_ctx in zip(seq_ctx_list, loss_ctx_list) + ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx_dict) + for seq_ctx, loss_ctx_dict in zip(seq_ctx_list, loss_ctx_dict_list) ] return engine_input @@ -1721,6 +1721,24 @@ def _resolve_deprecated_resume_cfg(self, resume_cfg: ResumeConfig, auto_resume: return True return auto_resume + def _resolve_model_loss_cfg(self, model_cfg: XTunerBaseModelConfig, loss_cfg: CELossConfig): + """Backward compatibility: set Trainer's loss_cfg to model's lm_loss_cfg if not already set. + + Args: + model_cfg (XTunerBaseModelConfig): Model configuration + loss_cfg (CELossConfig): Loss configuration from Trainer + """ + if loss_cfg is not None: + if hasattr(model_cfg, "text_config"): + model_cfg.text_config.lm_loss_cfg = loss_cfg + else: + model_cfg.lm_loss_cfg = loss_cfg + if self.rank == 0: + logger.warning( + "Setting model_cfg.lm_loss_cfg from Trainer's loss_cfg for backward compatibility. " + "In the future, please set lm_loss_cfg directly in model_cfg instead of Trainer." + ) + def _resolve_load_checkpoint_cfg( self, auto_resume: bool, load_checkpoint_cfg: LoadCheckpointConfig ) -> LoadCheckpointConfig: