Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions xtuner/v1/datasets/sft_tokenize_fn/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def __call__(self, item: dict | list, **kwargs) -> DataItem | CacheItem:
tools = item["tools"]
if isinstance(item, dict) and "messages" in item:
item = item["messages"]
if isinstance(item, dict) and "dialogs" in item:
item = item["dialogs"]
Comment on lines +45 to +46
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: Warning — This change (adding "dialogs" key support) appears unrelated to the full graph mode feature for Qwen3.5. Per CLAUDE.md: "One logical change per PR. Do not mix bug fixes with features or refactors." Consider splitting this into a separate PR.

messages = ChatMessages(messages=item, tools=tools)
tokenized = messages.tokenize(self.tokenizer, self.chat_template)

Expand Down
2 changes: 1 addition & 1 deletion xtuner/v1/model/moe/qwen3_5_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
),
"xtuner.v1.module.attention.mha.MultiHeadAttention.forward": TorchCompileOption(fullgraph=True),
# TODO: GatedDeltaNet does not currently support torch.compile(full_graph=True); support will be added in the future.
"xtuner.v1.module.attention.gated_deltanet.GatedDeltaNet.forward": TorchCompileOption(fullgraph=False),
"xtuner.v1.module.attention.gated_deltanet.GatedDeltaNet.forward": TorchCompileOption(fullgraph=True),
"xtuner.v1.module.decoder_layer.moe_decoder_layer.MoEDecoderLayer._shared_experts_forward": TorchCompileOption(
fullgraph=True
),
Expand Down
33 changes: 9 additions & 24 deletions xtuner/v1/module/attention/gated_deltanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@
from xtuner.v1.float8.config import Float8Config
from xtuner.v1.utils import get_logger

from ...ops.gated_deltanet.gen_seq_idx import gen_seq_idx
from ...ops.gated_deltanet.chunk_gated_delta_rule import chunk_gated_delta_rule
from ...ops.gated_deltanet.causal_conv1d import causal_conv1d_fn
from ...ops.gated_deltanet.rms_norm_gated import rms_norm_gated
from ..linear import build_linear
from .attn_outputs import AttnOutputs


try:
from fla.modules import FusedRMSNormGated as FLA_FusedRMSNormGated
from fla.modules.fused_norm_gate import rms_norm_gated
from fla.ops.gated_delta_rule import chunk_gated_delta_rule

class FusedRMSNormGated(FLA_FusedRMSNormGated):
def forward(
Expand Down Expand Up @@ -50,12 +52,6 @@ def forward(

except ImportError:
FusedRMSNormGated = None # type: ignore
chunk_gated_delta_rule = None

try:
from causal_conv1d import causal_conv1d_fn
except ImportError:
causal_conv1d_fn = None

logger = get_logger()

Expand Down Expand Up @@ -132,13 +128,7 @@ def __init__(
A = torch.empty(self.num_v_heads).uniform_(0, 16)
self.A_log = nn.Parameter(torch.log(A))

assert causal_conv1d_fn is not None, (
"causal_conv1d_fn is not available. Please install causal-conv1d to use GatedDeltaNet by `https://github.com/Dao-AILab/causal-conv1d`."
)
self.causal_conv1d_fn = causal_conv1d_fn
assert chunk_gated_delta_rule is not None, (
"chunk_gated_delta_rule is not available. Please install fla to use GatedDeltaNet by `pip install flash-linear-attention`."
)
self.chunk_gated_delta_rule = chunk_gated_delta_rule
assert FusedRMSNormGated is not None, (
"FusedRMSNormGated is not available. Please install fla to use GatedDeltaNet by `pip install flash-linear-attention`."
Expand Down Expand Up @@ -176,7 +166,6 @@ def forward(
batch_size, seq_len, _ = hidden_states.shape
assert batch_size == 1, "Only batch size of 1 is supported for now in GateDeltaNet"
mixed_qkv = self.in_proj_qkv(hidden_states)
mixed_qkv = mixed_qkv.transpose(1, 2)

z = self.in_proj_z(hidden_states)
z = z.reshape(batch_size, seq_len, -1, self.head_v_dim)
Expand All @@ -191,27 +180,23 @@ def forward(
if bias and isinstance(bias, DTensor):
bias = bias.to_local()

# TODO: If full_graph mode is supported in the future, it needs to be modified to custom_op
if seq_ctx.seq_idx is None:
seq_idx = torch.cat(
[
torch.full((s,), i, dtype=torch.int32, device=mixed_qkv.device)
for i, s in enumerate(seq_ctx.seq_lens_q)
],
dim=0,
)[None]
# Use Triton kernel to fill seq_idx based on cu_seq_lens_q
# Pre-allocate empty tensor with shape (1, seq_len)
seq_idx = gen_seq_idx(seq_len, seq_ctx.cu_seq_lens_q)
seq_ctx.seq_idx = cast(torch.IntTensor, seq_idx)
else:
seq_idx = seq_ctx.seq_idx

# mixed_qkv = mixed_qkv.transpose(1, 2)
mixed_qkv = self.causal_conv1d_fn(
x=mixed_qkv,
weight=weight,
bias=bias,
activation=self.activation,
seq_idx=seq_idx,
)
mixed_qkv = mixed_qkv.transpose(1, 2)
# mixed_qkv = mixed_qkv.transpose(1, 2)
query, key, value = torch.split(
Comment on lines +199 to 200
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: Warning — Commented-out code should be removed rather than left in. Per CLAUDE.md: "Avoid backwards-compatibility hacks like ... adding // removed comments for removed code." If the transposes are no longer needed (because the custom op handles them internally), just delete these lines.

Suggested change
# mixed_qkv = mixed_qkv.transpose(1, 2)
query, key, value = torch.split(

mixed_qkv,
[
Expand Down
270 changes: 270 additions & 0 deletions xtuner/v1/ops/gated_deltanet/causal_conv1d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
# Copyright (c) 2024, Tri Dao.

import torch
import torch.nn.functional as F

import causal_conv1d_cuda


LIBRARY_NAME = "DaoAILab"


@torch.library.custom_op(f"{LIBRARY_NAME}::_causal_conv1d_fwd_cpp", mutates_args={"out", "final_states_out"})
def _causal_conv1d_fwd_cpp(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    out: torch.Tensor,
    final_states_out: torch.Tensor | None,
    silu_activation: bool,
) -> None:
    """Registered custom op wrapping the CUDA causal-conv1d forward kernel.

    Writes the convolution result into ``out`` in place (and, when given, the
    trailing conv states into ``final_states_out``); registration as a custom
    op lets torch.compile trace through the opaque CUDA call.
    """
    # The kernel expects channel-last layout whenever seq_idx is supplied.
    # transpose() returns views, so the kernel still writes into out's storage.
    needs_channel_last = seq_idx is not None
    x_arg = x.transpose(1, 2) if needs_channel_last else x
    out_arg = out.transpose(1, 2) if needs_channel_last else out
    causal_conv1d_cuda.causal_conv1d_fwd(
        x_arg,
        weight,
        bias,
        seq_idx,
        initial_states,
        out_arg,
        final_states_out,
        silu_activation,
    )


@torch.library.custom_op(f"{LIBRARY_NAME}::_causal_conv1d_bwd_cpp", mutates_args={
    "dfinal_states",
    "dx",
    "dweight",
    "dbias",
    "dinitial_states",
})
def _causal_conv1d_bwd_cpp(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    dout: torch.Tensor,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    dfinal_states: torch.Tensor | None,
    dx: torch.Tensor,
    dweight: torch.Tensor,
    dbias: torch.Tensor | None,
    # Optional: callers pass None when gradients w.r.t. initial_states are not
    # needed. The previous `torch.Tensor` annotation produced a non-optional
    # schema, which rejects None at call time under torch.library.custom_op.
    dinitial_states: torch.Tensor | None,
    silu_activation: bool,
) -> None:
    """Registered custom op wrapping the CUDA causal-conv1d backward kernel.

    Fills ``dx``, ``dweight``, ``dbias`` and ``dinitial_states`` in place;
    registration as a custom op keeps the call traceable by torch.compile.
    """
    if seq_idx is not None:
        # Channel-last layout is required when seq_idx is present; the
        # transposed tensors are views, so in-place kernel writes land in the
        # callers' buffers.
        x = x.transpose(1, 2)
        dout = dout.transpose(1, 2)
        dx = dx.transpose(1, 2)
    causal_conv1d_cuda.causal_conv1d_bwd(
        x,
        weight,
        bias,
        dout,
        seq_idx,
        initial_states,
        dfinal_states,
        dx,
        dweight,
        dbias,
        dinitial_states,
        silu_activation,
    )

def causal_conv1d_fwd_function(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    final_states_out: torch.Tensor | None,
    silu_activation: bool,
) -> torch.Tensor:
    """Run the causal-conv1d forward custom op and return a fresh output.

    Allocates the output buffer (the custom op mutates its ``out`` argument in
    place rather than returning a value) and hands everything to the kernel.
    """
    # The op writes into this buffer; same shape/dtype/device as the input.
    result = torch.empty_like(x)
    _causal_conv1d_fwd_cpp(
        x=x,
        weight=weight,
        bias=bias,
        seq_idx=seq_idx,
        initial_states=initial_states,
        out=result,
        final_states_out=final_states_out,
        silu_activation=silu_activation,
    )
    return result

def causal_conv1d_bwd_function(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    dout: torch.Tensor,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    dfinal_states: torch.Tensor | None,
    dx: torch.Tensor | None,
    # Fixed: this is a flag, not a tensor. The old `torch.Tensor` annotation
    # was wrong for the `if return_dinitial_states:` use below and breaks
    # schema inference if this path is ever wrapped as a custom op.
    return_dinitial_states: bool,
    silu_activation: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Compute causal-conv1d gradients via the CUDA backward custom op.

    Args:
        x: Saved forward input. NOTE(review): layout appears to depend on
            seq_idx — (batch, dim, ...) without it, channel-last with it;
            confirm against the forward path.
        weight: (dim, width) depthwise filter.
        bias: Optional (dim,) bias.
        dout: Gradient w.r.t. the forward output.
        seq_idx: Optional per-token sequence indices for packed batches.
        initial_states: Optional carry-in conv states from the forward pass.
        dfinal_states: Optional gradient w.r.t. the final conv states.
        dx: Optional pre-allocated input-gradient buffer (e.g. for fusing with
            another backward); allocated here when None.
        return_dinitial_states: Whether to allocate and fill the gradient
            w.r.t. initial_states.
        silu_activation: Whether the forward applied SiLU/Swish.

    Returns:
        Tuple ``(dx, dweight, dbias, dinitial_states)``; ``dbias`` /
        ``dinitial_states`` are None when bias is absent / the flag is False.
        (Fixed: the old annotation ``tuple[torch.Tensor | None]`` described a
        1-tuple, not the 4-tuple actually returned.)
    """
    if seq_idx is None:
        batch_size, dim = x.size()[:2]
    else:
        batch_size, _, dim = x.size()[:3]
    width = weight.size(-1)

    if dx is None:
        dx = torch.empty_like(x)
    # Accumulate weight/bias gradients in fp32 for numerical stability, then
    # cast back to the parameter dtype after the kernel runs.
    dweight = torch.zeros_like(weight, dtype=torch.float32)
    dbias = torch.zeros_like(bias, dtype=torch.float32) if bias is not None else None
    dinitial_states = None
    if return_dinitial_states:
        # Allocate (batch, width-1, dim) and view it as channel-last
        # (batch, dim, width-1), matching what the kernel expects to fill.
        dinitial_states = torch.empty(
            batch_size, width - 1, dim, device=x.device, dtype=x.dtype
        ).transpose(1, 2)

    _causal_conv1d_bwd_cpp(
        x=x,
        weight=weight,
        bias=bias,
        dout=dout,
        seq_idx=seq_idx,
        initial_states=initial_states,
        dfinal_states=dfinal_states,
        dx=dx,
        dweight=dweight,
        dbias=dbias,
        dinitial_states=dinitial_states,
        silu_activation=silu_activation,
    )

    dweight = dweight.type_as(weight)
    if dbias is not None:
        dbias = dbias.type_as(bias)
    return dx, dweight, dbias, dinitial_states

class CausalConv1dFn(torch.autograd.Function):
    """Autograd wrapper tying the causal-conv1d forward and backward ops together."""

    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias=None,
        seq_idx=None,
        initial_states=None,
        return_final_states=False,
        final_states_out=None,
        activation=None,
    ):
        """Run the causal conv1d forward pass and stash tensors for backward.

        Returns the conv output, or ``(out, final_states_out)`` when
        return_final_states is True. Raises NotImplementedError for an
        unsupported activation; asserts enforce the seq_idx / layout
        constraints below.
        """
        if activation not in [None, "silu", "swish"]:
            raise NotImplementedError("activation must be None, silu, or swish")
        # Kernel requires one of the last two dims to be contiguous
        # (channel-last or channel-first layout).
        if x.stride(2) != 1 and x.stride(1) != 1:
            x = x.contiguous()
        bias = bias.contiguous() if bias is not None else None
        if seq_idx is not None:
            # Packed varlen mode: carry-in/carry-out states are unsupported.
            assert (
                initial_states is None
            ), "initial_states must be None if seq_idx is not None"
            assert (
                not return_final_states
            ), "If seq_idx is not None, we don't return final_states_out"
        seq_idx = seq_idx.contiguous() if seq_idx is not None else None
        if initial_states is not None and (
            initial_states.stride(2) != 1 and initial_states.stride(1) != 1
        ):
            initial_states = initial_states.contiguous()
        if return_final_states:
            assert (
                x.stride(1) == 1
            ), "Only channel-last layout support returning final_states_out"
            if final_states_out is not None:
                assert (
                    final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
                )
            else:
                # Allocate (batch, width-1, dim) then view as channel-last
                # (batch, dim, width-1) for the kernel to fill.
                batch, dim, seqlen = x.shape
                width = weight.shape[1]
                final_states_out = torch.empty(
                    batch, width - 1, dim, device=x.device, dtype=x.dtype
                ).transpose(1, 2)
        else:
            final_states_out = None
        # Collapse "silu"/"swish" into a single bool flag for the kernel.
        ctx.activation = activation in ["silu", "swish"]
        out = causal_conv1d_fwd_function(
            x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
        )
        ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
        ctx.return_final_states = return_final_states
        # Only compute dinitial_states when the caller actually needs it.
        ctx.return_dinitial_states = (
            initial_states is not None and initial_states.requires_grad
        )
        return out if not return_final_states else (out, final_states_out)

    @staticmethod
    def backward(ctx, dout, *args):
        """Dispatch to the CUDA backward kernel; returns one grad per forward arg."""
        x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
        # When forward returned (out, final_states), args[0] is the gradient
        # flowing into final_states.
        dfinal_states = args[0] if ctx.return_final_states else None
        if dout.stride(2) != 1 and dout.stride(1) != 1:
            dout = dout.contiguous()
        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
        # backward of conv1d with the backward of chunk).
        # Here we just pass in None and dx will be allocated in the C++ code.
        dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
            x,
            weight,
            bias,
            dout,
            seq_idx,
            initial_states,
            dfinal_states,
            None,
            ctx.return_dinitial_states,
            ctx.activation,
        )
        # One entry per forward parameter (after ctx): non-tensor args get None.
        return (
            dx,
            dweight,
            dbias if bias is not None else None,
            None,
            dinitial_states if initial_states is not None else None,
            None,
            None,
            None,
        )


def causal_conv1d_fn(
    x,
    weight,
    bias=None,
    seq_idx=None,
    initial_states=None,
    return_final_states=False,
    final_states_out=None,
    activation=None,
):
    """Causal depthwise 1D convolution with an optional SiLU/Swish activation.

    Args:
        x: (batch, dim, seqlen) input.
        weight: (dim, width) depthwise filter.
        bias: optional (dim,) bias.
        seq_idx: optional (batch, seqlen) sequence indices for packed batches.
        initial_states: optional (batch, dim, width - 1) carry-in conv states.
        return_final_states: when True, also return the trailing conv states.
        final_states_out: optional (batch, dim, width - 1) buffer to be written to.
        activation: either None or "silu" or "swish".

    Returns:
        (batch, dim, seqlen) output, or an ``(out, final_states)`` pair when
        return_final_states is True.
    """
    forward_args = (
        x,
        weight,
        bias,
        seq_idx,
        initial_states,
        return_final_states,
        final_states_out,
        activation,
    )
    return CausalConv1dFn.apply(*forward_args)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: Nit — Missing newline at end of file.

Loading
Loading