4 changes: 2 additions & 2 deletions tests/engine/test_dense_train_engine.py
@@ -83,13 +83,13 @@ def warmup_fn(x):
seq_ctx = seq_ctx.split(sequence_parallel_mesh=sp_mesh)
seq_ctx_list = [seq_ctx]
LossContext = loss_cfg.loss_ctx_cls
-loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=sp_mesh)
+loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=sp_mesh)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)

seq_ctx = seq_ctx_list[0]
loss_ctx = loss_ctx_list[0]
-engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
+engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})]
loss_log = engine.train_step(engine_input)["logs_info"]
grad_norm = engine.clip_grad_norm()
engine.step_optimizer(grad_norm)
8 changes: 4 additions & 4 deletions tests/engine/test_moe_train_engine.py
@@ -93,12 +93,12 @@ def warmup_fn(x):
seq_ctx.num_padding = pack_len
seq_ctx_list = [seq_ctx]
LossContext = loss_cfg.loss_ctx_cls
-loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=None)
+loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=None)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
loss_ctx = loss_ctx_list[0]
seq_ctx = seq_ctx_list[0]
-engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
+engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})]
loss_log = engine.train_step(engine_input)["logs_info"]
grad_norm = engine.clip_grad_norm()
engine.step_optimizer(grad_norm)
@@ -184,12 +184,12 @@ def warmup_fn(x):
seq_ctx.num_padding = pack_len
seq_ctx_list = [seq_ctx]
LossContext = loss_cfg.loss_ctx_cls
-loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=None)
+loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=None)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
loss_ctx = loss_ctx_list[0]
seq_ctx = seq_ctx_list[0]
-engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
+engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})]
loss_log = engine.train_step(engine_input)["logs_info"]
grad_norm = engine.clip_grad_norm()
engine.step_optimizer(grad_norm)
12 changes: 6 additions & 6 deletions tests/engine/test_moe_train_engine_float8.py
@@ -87,12 +87,12 @@ def warmup_fn(x):
seq_ctx.num_padding = pack_len
seq_ctx_list = [seq_ctx]
LossContext = loss_cfg.loss_ctx_cls
-loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=None)
+loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=None)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
loss_ctx = loss_ctx_list[0]
seq_ctx = seq_ctx_list[0]
-engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
+engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})]
loss_log = engine.train_step(engine_input)["logs_info"]
grad_norm = engine.clip_grad_norm()
engine.step_optimizer(grad_norm)
@@ -165,12 +165,12 @@ def warmup_fn(x):
seq_ctx.num_padding = pack_len
seq_ctx_list = [seq_ctx]
LossContext = loss_cfg.loss_ctx_cls
-loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=None)
+loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=None)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
loss_ctx = loss_ctx_list[0]
seq_ctx = seq_ctx_list[0]
-engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
+engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})]
loss_log = engine.train_step(engine_input)["logs_info"]
grad_norm = engine.clip_grad_norm()
engine.step_optimizer(grad_norm)
@@ -264,12 +264,12 @@ def warmup_fn(x):
seq_ctx.to('cuda')
seq_ctx_list = [seq_ctx]
LossContext = loss_cfg.loss_ctx_cls
-loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=None)
+loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=None)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
loss_ctx = loss_ctx_list[0]
seq_ctx = seq_ctx_list[0]
-engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
+engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})]
logs_info = engine.train_step(engine_input)["logs_info"]
grad_norm = engine.clip_grad_norm()
engine.step_optimizer(grad_norm)
8 changes: 4 additions & 4 deletions tests/loss/test_ce_loss.py
@@ -72,7 +72,7 @@ def test_global_loss_reduction(self, loss_mode, grad_accumulation_steps, chunk_s
for data in data_batch:
seq_ctx = data["seq_ctx"]
seq_ctx_list.append(seq_ctx)
-loss_ctx = loss_cfg.build(shifted_labels=data["shifted_labels"], sp_mesh=None)
+loss_ctx = loss_cfg.build(data={"shifted_labels": data["shifted_labels"]}, sp_mesh=None)
loss_ctx_list.append(loss_ctx)
loss_ctx_list = CELossContext.build_batches(loss_ctx_list, cu_seq_lens_list=[seq_ctx.cu_seq_lens_q for seq_ctx in seq_ctx_list])

@@ -172,7 +172,7 @@ def test_other_loss_reduction(self, loss_reduction, loss_mode, grad_accumulation
for data in data_batch:
seq_ctx = data["seq_ctx"]
seq_ctx_list.append(seq_ctx)
-loss_ctx = loss_cfg.build(shifted_labels=data["shifted_labels"], sp_mesh=None)
+loss_ctx = loss_cfg.build(data={"shifted_labels": data["shifted_labels"]}, sp_mesh=None)
loss_ctx_list.append(loss_ctx)
loss_ctx_list = CELossContext.build_batches(loss_ctx_list, cu_seq_lens_list=[seq_ctx.cu_seq_lens_q for seq_ctx in seq_ctx_list])

@@ -310,7 +310,7 @@ def test_sp_global_loss_reduction(self, loss_mode, sp_size, grad_accumulation_st
sp_mesh = data_mesh['sp']
seq_ctx.sequence_parallel_mesh = sp_mesh
seq_ctx_list = [seq_ctx]
-loss_ctx = loss_cfg.build(shifted_labels=target, sp_mesh=sp_mesh)
+loss_ctx = loss_cfg.build(data={"shifted_labels": target}, sp_mesh=sp_mesh)
loss_ctx_list = [loss_ctx]
if sp_size > 1:
seq_ctx_list[0] = seq_ctx_list[0].split(sequence_parallel_mesh=sp_mesh)
@@ -397,7 +397,7 @@ def test_sp_others_loss_reduction(self, loss_reduction, loss_mode, sp_size, grad
sp_mesh = data_mesh['sp']
seq_ctx.sequence_parallel_mesh = sp_mesh
seq_ctx_list = [seq_ctx]
-loss_ctx = loss_cfg.build(shifted_labels=target, sp_mesh=sp_mesh)
+loss_ctx = loss_cfg.build(data={"shifted_labels": target}, sp_mesh=sp_mesh)
loss_ctx_list = [loss_ctx]
if sp_size > 1:
seq_ctx_list[0] = seq_ctx_list[0].split(sequence_parallel_mesh=sp_mesh)
2 changes: 1 addition & 1 deletion tests/loss/test_grpo_loss.py
@@ -147,7 +147,7 @@ def test_grpo_loss(self, grad_acc, sp_size, kl_loss_coef, loss_mode, chunk_size,
if sp_size > 1:
seq_ctx = seq_ctx.split(sp_mesh)
seq_ctx_list.append(seq_ctx)
-loss_ctx = loss_cfg.build(shifted_labels=shifted_labels_list_rank[iter_idx], advantages=advantages_list_rank[iter_idx], sp_mesh=sp_mesh)
+loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels_list_rank[iter_idx], "advantages": advantages_list_rank[iter_idx]}, sp_mesh=sp_mesh)
loss_ctx_list.append(loss_ctx)

with torch.no_grad():
3 changes: 1 addition & 2 deletions tests/loss/test_oreal_loss.py
@@ -216,8 +216,7 @@ def test_grpo_loss(self, grad_acc, sp_size, kl_loss_coef, loss_mode, chunk_size,
seq_ctx = seq_ctx.split(sp_mesh)
seq_ctx_list.append(seq_ctx)
loss_ctx = loss_cfg.build(
-shifted_labels=shifted_labels_list_rank[iter_idx],
-advantages=advantages_list_rank[iter_idx],
+data={"shifted_labels": shifted_labels_list_rank[iter_idx], "advantages": advantages_list_rank[iter_idx]},
sp_mesh=sp_mesh,
)
loss_ctx_list.append(loss_ctx)
8 changes: 4 additions & 4 deletions tests/model/test_gpt_oss_moe.py
@@ -78,7 +78,7 @@ def test_gpt_oss_run(self, device, dispatcher, ep_size, compile, tol, loss_class
loss_cfg = CELossConfig()
seq_ctx_list = [seq_ctx]
LossContext = loss_cfg.loss_ctx_cls
-loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None)
+loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
loss_ctx = loss_ctx_list[0]
@@ -87,7 +87,7 @@ def test_gpt_oss_run(self, device, dispatcher, ep_size, compile, tol, loss_class
with torch.no_grad():
output = gpt_oss_model(
seq_ctx=seq_ctx,
-loss_ctx=loss_ctx,
+loss_ctx={"lm": loss_ctx},
)
loss = output["loss"]
self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=tol, rtol=tol))
@@ -141,7 +141,7 @@ def test_fsdp_accuracy(self, device, dispatcher, ep_size):
loss_cfg = CELossConfig()
seq_ctx_list = [seq_ctx]
LossContext = loss_cfg.loss_ctx_cls
-loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None)
+loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
loss_ctx = loss_ctx_list[0]
@@ -152,7 +152,7 @@ def test_fsdp_accuracy(self, device, dispatcher, ep_size):
with torch.no_grad():
output = gpt_oss_model(
seq_ctx=seq_ctx,
-loss_ctx=loss_ctx,
+loss_ctx={"lm": loss_ctx},
)
loss = output["loss"]
self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=1e-2, rtol=1e-2))
16 changes: 8 additions & 8 deletions tests/model/test_intern_s1.py
@@ -78,7 +78,7 @@ def test_interns1_text_run(self, device, tol):

seq_ctx_list = [seq_ctx]
LossContext = loss_cfg.loss_ctx_cls
-loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None)
+loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
loss_ctx = loss_ctx_list[0]
@@ -87,7 +87,7 @@ def test_interns1_text_run(self, device, tol):
with torch.no_grad():
output = interns1_model(
seq_ctx=seq_ctx,
-loss_ctx=loss_ctx,
+loss_ctx={"lm": loss_ctx},
)
loss = output["loss"]
self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=tol, rtol=tol))
@@ -186,7 +186,7 @@ def test_interns1_image_run(self, device, sp_size, tol):

seq_ctx_list = [seq_ctx]
LossContext = loss_cfg.loss_ctx_cls
-loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=sp_mesh)
+loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=sp_mesh)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
loss_ctx = loss_ctx_list[0]
@@ -195,7 +195,7 @@ def test_interns1_image_run(self, device, sp_size, tol):
with torch.no_grad():
output = interns1_model(
seq_ctx=seq_ctx,
-loss_ctx=loss_ctx,
+loss_ctx={"lm": loss_ctx},
)
loss = output["loss"]
self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=tol, rtol=tol))
@@ -256,7 +256,7 @@ def test_fsdp_text_accuracy(self, device, tol):
seq_ctx_list = [seq_ctx]
loss_cfg = CELossConfig()
LossContext = loss_cfg.loss_ctx_cls
-loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None)
+loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
loss_ctx = loss_ctx_list[0]
@@ -265,7 +265,7 @@ def test_fsdp_text_accuracy(self, device, tol):
with torch.no_grad():
output = interns1_model(
seq_ctx=seq_ctx,
-loss_ctx=loss_ctx,
+loss_ctx={"lm": loss_ctx},
)
loss = output["loss"]
self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=tol, rtol=tol))
@@ -370,7 +370,7 @@ def test_fsdp_image_accuracy(self, device, sp_size, compile, tol):
seq_ctx_list = [seq_ctx]
loss_cfg = CELossConfig()
LossContext = loss_cfg.loss_ctx_cls
-loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=sp_mesh)
+loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=sp_mesh)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
loss_ctx = loss_ctx_list[0]
@@ -379,7 +379,7 @@ def test_fsdp_image_accuracy(self, device, sp_size, compile, tol):
with torch.no_grad():
output = interns1_model(
seq_ctx=seq_ctx,
-loss_ctx=loss_ctx,
+loss_ctx={"lm": loss_ctx},
)
loss = output["loss"]
self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=tol, rtol=tol))
10 changes: 5 additions & 5 deletions tests/model/test_moe.py
@@ -62,12 +62,12 @@ def test_moe_config(self, dtype, device):

seq_ctx_list = [seq_ctx]
LossContext = loss_cfg.loss_ctx_cls
-loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None)
+loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
loss_ctx = loss_ctx_list[0]
seq_ctx = seq_ctx_list[0]
-model(seq_ctx=seq_ctx, loss_ctx=loss_ctx)
+model(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})


class TestDistributedMoE(DeterministicDDPTestCase):
@@ -135,15 +135,15 @@ def test_parallel_accuracy(self, dtype, device, dispatcher, n_shared_experts, fi

seq_ctx_list = [seq_ctx]
LossContext = loss_cfg.loss_ctx_cls
-loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None)
+loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
loss_ctx = loss_ctx_list[0]
seq_ctx = seq_ctx_list[0]

-loss_parallel = parallel_model(seq_ctx=seq_ctx, loss_ctx=loss_ctx)["loss"]
+loss_parallel = parallel_model(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})["loss"]

-loss_expected = model(seq_ctx=seq_ctx, loss_ctx=loss_ctx)["loss"]
+loss_expected = model(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})["loss"]

torch.allclose(loss_expected, loss_parallel, atol=1e-6, rtol=1e-4)

2 changes: 1 addition & 1 deletion tests/model/test_qwen3_5.py
@@ -133,7 +133,7 @@ def _forward(self, model, type, device, sp_size):

seq_ctx_list = [seq_ctx]
LossContext = loss_cfg.loss_ctx_cls
-loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=sp_mesh)
+loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=sp_mesh)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
loss_ctx = loss_ctx_list[0]
12 changes: 6 additions & 6 deletions tests/model/test_qwen3_dense.py
@@ -64,7 +64,7 @@ def test_qwen3_dense_run(self, device, tp_size, compile, tol, loss_class):
loss_cfg = CELossConfig()
seq_ctx_list = [seq_ctx]
LossContext = loss_cfg.loss_ctx_cls
-loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None)
+loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
loss_ctx = loss_ctx_list[0]
@@ -75,7 +75,7 @@ def test_qwen3_dense_run(self, device, tp_size, compile, tol, loss_class):
with torch.no_grad():
output = qwen_model(
seq_ctx=seq_ctx,
-loss_ctx=loss_ctx,
+loss_ctx={"lm": loss_ctx},
)
loss = output["loss"]
self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=tol, rtol=tol))
@@ -120,7 +120,7 @@ def test_fsdp_accuracy(self, device, tp_size):
seq_ctx = SequenceContext.from_input_ids(input_ids=(shift_input_ids.to('cuda'),))
seq_ctx_list = [seq_ctx]
LossContext = loss_cfg.loss_ctx_cls
-loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None)
+loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
loss_ctx = loss_ctx_list[0]
@@ -131,7 +131,7 @@ def test_fsdp_accuracy(self, device, tp_size):
with torch.no_grad():
output = qwen_model(
seq_ctx=seq_ctx,
-loss_ctx=loss_ctx,
+loss_ctx={"lm": loss_ctx},
)
loss = output["loss"]
self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=1e-2, rtol=1e-2))
@@ -196,7 +196,7 @@ def test_sliding_windows(self, use_sliding_window, max_window_layers, sliding_wi
seq_ctx = SequenceContext.from_input_ids(input_ids=(shift_input_ids.to('cuda'),))
seq_ctx_list = [seq_ctx]
LossContext = loss_cfg.loss_ctx_cls
-loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None)
+loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
loss_ctx = loss_ctx_list[0]
@@ -207,7 +207,7 @@ def test_sliding_windows(self, use_sliding_window, max_window_layers, sliding_wi
with torch.no_grad():
output = qwen_model(
seq_ctx=seq_ctx,
-loss_ctx=loss_ctx,
+loss_ctx={"lm": loss_ctx},
)
assert "loss" in output

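Taken together, the hunks above apply one migration in two parts: loss_cfg.build now receives its tensors through a data dict instead of individual keyword arguments, and each ModelItem (or model forward call) carries its loss context in a dict keyed by loss name, "lm" for the language-modeling loss. A minimal sketch of the new calling pattern, pieced together from the hunks above; the surrounding setup (labels, seq_ctx, engine) and the import paths are not shown in the diff and are assumed here:

    # Sketch only: loss_cfg, labels, seq_ctx and engine are assumed to be set up as in the tests above.
    loss_cfg = CELossConfig()
    LossContext = loss_cfg.loss_ctx_cls

    # Labels now travel inside a `data` dict rather than as a keyword argument.
    loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=None)
    loss_ctx_list = LossContext.build_batches([loss_ctx])
    loss_ctx = loss_ctx_list[0]

    # Loss contexts are keyed by loss name when handed to the engine or model.
    engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})]
    loss_log = engine.train_step(engine_input)["logs_info"]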