21 changes: 12 additions & 9 deletions xtuner/v1/engine/vision_compose_train_engine.py
@@ -169,8 +169,11 @@ def train_step(self, data_batches: List[ModelItem]) -> tuple[LossLog, OtherLog]:
             efficient_forward_tokens += (num_tokens.long() ** 2).sum()
             total_forward_tokens += (num_tokens.long().sum()) ** 2

-        # todo: support intra_layer_micro_batch
-        output = self.model(seq_ctx=seq_ctx_list[0], loss_ctx=loss_ctx_list[0])
+        if len(seq_ctx_list) == 1:
+            output = self.model(seq_ctx=seq_ctx_list[0], loss_ctx=loss_ctx_list[0])
+        else:
+            output = self.model(seq_ctx=seq_ctx_list, loss_ctx=loss_ctx_list)

         # llm loss has been global averaged
         llm_loss = output["loss"]
         step_llm_loss += llm_loss.detach().clone()
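
The branch above passes the whole list to the model when more than one micro-batch is present, so the model's forward must accept either a single context or a list of them (the modeling_qwen3_vl.py change below implements this). A minimal sketch of that contract, with a toy stand-in for SequenceContext rather than xtuner's real class:

```python
from dataclasses import dataclass
from typing import List, Union

import torch


@dataclass
class Ctx:  # hypothetical stand-in for xtuner's SequenceContext
    inputs: torch.Tensor


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, seq_ctx: Union[Ctx, List[Ctx]]) -> torch.Tensor:
        # Accept both a single context and a list of micro-batch contexts,
        # mirroring the len(seq_ctx_list) == 1 dispatch in train_step above.
        ctxs = seq_ctx if isinstance(seq_ctx, list) else [seq_ctx]
        return torch.stack([self.proj(c.inputs).mean() for c in ctxs]).sum()
```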
@@ -208,25 +211,25 @@ def train_step(self, data_batches: List[ModelItem]) -> tuple[LossLog, OtherLog]:
                 maxvio = maxvio_all_layers.mean()
                 if moe_need_update_bias:
                     self.model.language_model.update_bias(tokens_per_expert_global_for_bias, avg_count_load)  # type: ignore
-                other_log["maxvio"] = maxvio.item()
+                other_log["maxvio"] = maxvio

         reduced_llm_loss = step_llm_loss
         dist.all_reduce(reduced_llm_loss.div_(dist.get_world_size()))

-        loss_log["local_loss"] = step_loss.item()
-        loss_log["reduced_llm_loss"] = reduced_llm_loss.item()
+        loss_log["local_loss"] = step_loss
+        loss_log["reduced_llm_loss"] = reduced_llm_loss
         if step_balancing_loss is not None:
             reduced_balancing_loss = step_balancing_loss
             dist.all_reduce(reduced_balancing_loss.div_(dist.get_world_size()))
-            loss_log["reduced_balancing_loss"] = reduced_balancing_loss.item()
+            loss_log["reduced_balancing_loss"] = reduced_balancing_loss
         if step_z_loss is not None:
             reduced_z_loss = step_z_loss
             dist.all_reduce(reduced_z_loss.div_(dist.get_world_size()))
-            loss_log["reduced_z_loss"] = reduced_z_loss.item()
+            loss_log["reduced_z_loss"] = reduced_z_loss

-        other_log["step_consumed_tokens"] = int(step_consumed_tokens.item())
+        other_log["step_consumed_tokens"] = int(step_consumed_tokens)
         other_log["extra_info"] = train_engine_extra_info  # type: ignore[assignment]
-        other_log["efficient_attn_ratio"] = (efficient_forward_tokens / total_forward_tokens).item()
+        other_log["efficient_attn_ratio"] = efficient_forward_tokens / total_forward_tokens
         other_log["step_consumed_img_tokens"] = int(step_consumed_img_tokens)

         extra_info = other_log.get("extra_info", {})  # type: ignore
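
A likely motivation for dropping the per-metric `.item()` calls above: each `.item()` forces a host-device synchronization, stalling the CPU until the GPU queue drains, so keeping the logged values as 0-dim tensors defers that cost to the moment they are actually written out. A minimal sketch of the consumer-side conversion this implies (the helper name is hypothetical, not part of xtuner):

```python
import torch


def to_scalars(log: dict) -> dict:
    # Convert 0-dim tensor metrics to Python scalars in one pass, so the
    # host-device sync happens once at logging time rather than once per
    # metric inside the hot train_step loop.
    return {k: v.item() if isinstance(v, torch.Tensor) else v for k, v in log.items()}


# Usage: loss_log / other_log now hold tensors; sync only when reporting, e.g.
# writer.add_scalars("train", to_scalars(loss_log), step)
```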
10 changes: 4 additions & 6 deletions xtuner/v1/loss/ce_loss.py
@@ -169,12 +169,10 @@ def loss_fn(
         loss_weight = loss_weight.flatten()

         rank_grad_tokens = (shifted_labels != self.loss_cfg.ignore_idx).sum()
-        if rank_grad_tokens == 0:
-            loss = logits.sum() * 0
-        else:
-            loss = F.cross_entropy(logits, shifted_labels, reduction="none", ignore_index=self.loss_cfg.ignore_idx)
-            # Step 2.b in the loss calculation: sum the loss over all tokens
-            loss = (loss * loss_weight).sum()
+        loss_when_zero = logits.sum() * 0
+        loss_normal = F.cross_entropy(logits, shifted_labels, reduction="none", ignore_index=self.loss_cfg.ignore_idx)
+        loss_normal = (loss_normal * loss_weight.to(loss_normal.device)).sum()
+        loss = torch.where(rank_grad_tokens == 0, loss_when_zero, loss_normal)

         return loss, (logits, {})
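
The rewrite above trades data-dependent Python branching for a branchless torch.where, so every rank executes the same graph even when it holds zero gradient tokens; with reduction="none", cross_entropy yields per-token zeros at ignored positions rather than NaN, which makes the unconditional call safe. (The added loss_weight.to(loss_normal.device) also guards against a device mismatch.) A small self-contained repro of the pattern, assuming nothing beyond plain PyTorch:

```python
import torch
import torch.nn.functional as F

ignore_idx = -100
logits = torch.randn(4, 10, requires_grad=True)
labels = torch.full((4,), ignore_idx)  # a rank with no gradient tokens at all

# Unconditional cross_entropy: per-token losses are 0.0 at ignored positions.
per_token = F.cross_entropy(logits, labels, reduction="none", ignore_index=ignore_idx)
has_no_tokens = (labels != ignore_idx).sum() == 0
loss = torch.where(has_no_tokens, logits.sum() * 0, per_token.sum())

# backward() succeeds and produces (zero) gradients, so collective ops over
# gradients stay aligned with ranks that did have valid tokens.
loss.backward()
```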

134 changes: 82 additions & 52 deletions xtuner/v1/model/compose/qwen3_vl/modeling_qwen3_vl.py
@@ -137,63 +137,93 @@ def forward(
         seq_ctx: SequenceContext,
         loss_ctx: CELossContext
     ) -> MoEModelOutputs:
-        input_ids = seq_ctx.input_ids
-        pixel_values = seq_ctx.pixel_values
-        image_grid_thw = seq_ctx.image_grid_thw
-        sequence_parallel_mesh = seq_ctx.sequence_parallel_mesh
-
-        inputs_embeds = self.language_model.embed_tokens(input_ids)  # type: ignore
-
-        if pixel_values is not None:
-            assert image_grid_thw is not None
-            assert input_ids is not None
-            visual_embeds, deepstack_visual_embeds = self.get_visual_features(pixel_values,
-                                                                              image_grid_thw,
-                                                                              sequence_parallel_mesh)
-            try:
-                # To simplify the processing of deepstack_visual_embeds inside language_model,
-                # we all-gather visual_embeds and then split them non-uniformly based on the input_ids
-                visual_pos_masks, visual_features, deepstack_visual_embeds = self.get_placeholder_mask(
-                    input_ids,
-                    visual_features=visual_embeds,
-                    deepstack_visual_embeds=deepstack_visual_embeds,
-                    sequence_parallel_mesh=sequence_parallel_mesh,
-                    origin_pixel_len=pixel_values.size(0)
-                )
-                inputs_embeds[visual_pos_masks] = inputs_embeds[visual_pos_masks] * 0.0 + visual_features
-            except Exception as e:
-                logger.error(f"!!!Warning: {e}, but continue anyway!!!!")
-                inputs_embeds = inputs_embeds + visual_embeds.sum() * 0.0
+        def _prepare_llm_inputs(single_seq_ctx):
+            sp_mesh = single_seq_ctx.sequence_parallel_mesh
+            img_thw = single_seq_ctx.image_grid_thw
+            pixel_val = single_seq_ctx.pixel_values
+            input_ids = single_seq_ctx.input_ids
+            inputs_embeds = self.language_model.embed_tokens(input_ids)  # type: ignore
+            if pixel_val is not None:
+                assert img_thw is not None
+                assert input_ids is not None
+                visual_embeds, deepstack_visual_embeds = self.get_visual_features(pixel_val,
+                                                                                  img_thw,
+                                                                                  sp_mesh)
+                try:
+                    # To simplify the processing of deepstack_visual_embeds inside language_model,
+                    # we all-gather visual_embeds and then split them non-uniformly based on the input_ids
+                    visual_pos_masks, visual_features, deepstack_visual_embeds = self.get_placeholder_mask(
+                        input_ids,
+                        visual_features=visual_embeds,
+                        deepstack_visual_embeds=deepstack_visual_embeds,
+                        sequence_parallel_mesh=sp_mesh,
+                        origin_pixel_len=pixel_val.size(0)
+                    )
+                    inputs_embeds[visual_pos_masks] = inputs_embeds[visual_pos_masks] * 0.0 + visual_features
+                except Exception as e:
+                    logger.error(f"!!!Warning: {e}, but continue anyway!!!!")
+                    inputs_embeds = inputs_embeds + visual_embeds.sum() * 0.0
+                    for deepstack_visual_embed in deepstack_visual_embeds:
+                        inputs_embeds = inputs_embeds + deepstack_visual_embed.sum() * 0.0
+                    deepstack_visual_embeds = None
+                    visual_pos_masks = None
+            else:
+                pixel_values_dump = torch.randn(4, 1536, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
+                img_thw = torch.tensor([[1, 2, 2]], device=inputs_embeds.device)
+                visual_embeds, deepstack_visual_embeds = self.get_visual_features(pixel_values_dump, img_thw)
+                inputs_embeds = inputs_embeds + visual_embeds.sum() * 0.0
+                for deepstack_visual_embed in deepstack_visual_embeds:
+                    inputs_embeds = inputs_embeds + deepstack_visual_embed.sum() * 0.0
+
+                deepstack_visual_embeds = None
+                visual_pos_masks = None
+
+            if deepstack_visual_embeds is not None and len(deepstack_visual_embeds) == 0:
+                assert single_seq_ctx.position_ids is not None
+                assert single_seq_ctx.position_ids.ndim == 2, f"position_ids must be 2-dim when deepstack_visual_embeds is None," \
+                    f" but got {single_seq_ctx.position_ids.ndim}"
+
+                deepstack_visual_embeds = None
+                visual_pos_masks = None
+            return inputs_embeds
+
+        if isinstance(seq_ctx, list):
+            lang_seq_ctx = []
+            for single_seq_ctx in seq_ctx:
+                inputs_embeds = _prepare_llm_inputs(single_seq_ctx)
+                lang_seq_ctx.append(
+                    SequenceContext(
+                        input_ids=None,
+                        cu_seq_lens_q=single_seq_ctx.cu_seq_lens_q,
+                        cu_seq_lens_k=single_seq_ctx.cu_seq_lens_k,
+                        max_length_q=single_seq_ctx.max_length_q,
+                        max_length_k=single_seq_ctx.max_length_k,
+                        position_ids=single_seq_ctx.position_ids,
+                        num_padding=single_seq_ctx.num_padding,
+                        sequence_parallel_mesh=single_seq_ctx.sequence_parallel_mesh,
+                        inputs_embeds=inputs_embeds,
+                        rollout_routed_experts=single_seq_ctx.rollout_routed_experts,
+                        deepstack_visual_embeds=None,
+                        visual_pos_masks=None
+                    )
+                )
+        elif isinstance(seq_ctx, SequenceContext):
+            inputs_embeds = _prepare_llm_inputs(seq_ctx)
+            lang_seq_ctx = SequenceContext(input_ids=None,
+                                           cu_seq_lens_q=seq_ctx.cu_seq_lens_q,
+                                           cu_seq_lens_k=seq_ctx.cu_seq_lens_k,
+                                           max_length_q=seq_ctx.max_length_q,
+                                           max_length_k=seq_ctx.max_length_k,
+                                           position_ids=seq_ctx.position_ids,
+                                           num_padding=seq_ctx.num_padding,
+                                           sequence_parallel_mesh=seq_ctx.sequence_parallel_mesh,
+                                           inputs_embeds=inputs_embeds,
+                                           rollout_routed_experts=seq_ctx.rollout_routed_experts,
+                                           deepstack_visual_embeds=None,
+                                           visual_pos_masks=None)
         else:
-            pixel_values_dump = torch.randn(4, 1536, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
-            image_grid_thw = torch.tensor([[1, 2, 2]], device=inputs_embeds.device)
-            viusal_embeds, deepstack_visual_embeds = self.get_visual_features(pixel_values_dump, image_grid_thw)
-            inputs_embeds = inputs_embeds + viusal_embeds.sum() * 0.0
-            for deepstack_visual_embed in deepstack_visual_embeds:
-                inputs_embeds = inputs_embeds + deepstack_visual_embed.sum() * 0.0
-
-            deepstack_visual_embeds = None
-            visual_pos_masks = None
-
-        if deepstack_visual_embeds is not None and len(deepstack_visual_embeds) == 0:
-            assert seq_ctx.position_ids is not None
-            assert seq_ctx.position_ids.ndim == 2, f"position_ids must be 2-dim when deepstack_visual_embeds is None," \
-                f" but got {seq_ctx.position_ids.ndim}"
-            deepstack_visual_embeds = None
-            visual_pos_masks = None
-
-        # NOTE: never overwrite seq_ctx in place, or the second forward will be missing data
-        lang_seq_ctx = seq_ctx.copy(
-            input_ids=None,
-            inputs_embeds=inputs_embeds,
-            deepstack_visual_embeds=deepstack_visual_embeds,
-            visual_pos_masks=visual_pos_masks,
-        )
+            raise NotImplementedError
         outputs = self.language_model(
             lang_seq_ctx,
             loss_ctx
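
One detail worth calling out in _prepare_llm_inputs above: on text-only batches it still runs the vision tower on dummy pixels and folds visual_embeds.sum() * 0.0 into the embeddings. A minimal sketch of why, assuming a DDP/FSDP-style setup in which every rank must produce a gradient for every parameter (the Linear stand-in below is illustrative, not the real vision tower):

```python
import torch

vision_tower = torch.nn.Linear(1536, 16)  # stand-in for the ViT

def text_only_embeds(inputs_embeds: torch.Tensor) -> torch.Tensor:
    # Run the tower on dummy pixels and attach a zero-valued term: the output
    # does not change, but every vision parameter receives a (zero) gradient,
    # so gradient reduction stays aligned with ranks that did see images.
    dummy_pixels = torch.randn(4, 1536)
    return inputs_embeds + vision_tower(dummy_pixels).sum() * 0.0
```

The same trick appears in the except branch, where deepstack_visual_embed.sum() * 0.0 keeps the deepstack projections in the autograd graph after a failed placeholder match.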