-
Notifications
You must be signed in to change notification settings - Fork 410
[Feat] support full graph mode for qwen3.5 #1550
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -14,14 +14,16 @@ | |||||
| from xtuner.v1.float8.config import Float8Config | ||||||
| from xtuner.v1.utils import get_logger | ||||||
|
|
||||||
| from ...ops.gated_deltanet.gen_seq_idx import gen_seq_idx | ||||||
| from ...ops.gated_deltanet.chunk_gated_delta_rule import chunk_gated_delta_rule | ||||||
| from ...ops.gated_deltanet.causal_conv1d import causal_conv1d_fn | ||||||
| from ...ops.gated_deltanet.rms_norm_gated import rms_norm_gated | ||||||
| from ..linear import build_linear | ||||||
| from .attn_outputs import AttnOutputs | ||||||
|
|
||||||
|
|
||||||
| try: | ||||||
| from fla.modules import FusedRMSNormGated as FLA_FusedRMSNormGated | ||||||
| from fla.modules.fused_norm_gate import rms_norm_gated | ||||||
| from fla.ops.gated_delta_rule import chunk_gated_delta_rule | ||||||
|
|
||||||
| class FusedRMSNormGated(FLA_FusedRMSNormGated): | ||||||
| def forward( | ||||||
|
|
@@ -50,12 +52,6 @@ def forward( | |||||
|
|
||||||
| except ImportError: | ||||||
| FusedRMSNormGated = None # type: ignore | ||||||
| chunk_gated_delta_rule = None | ||||||
|
|
||||||
| try: | ||||||
| from causal_conv1d import causal_conv1d_fn | ||||||
| except ImportError: | ||||||
| causal_conv1d_fn = None | ||||||
|
|
||||||
| logger = get_logger() | ||||||
|
|
||||||
|
|
@@ -132,13 +128,7 @@ def __init__( | |||||
| A = torch.empty(self.num_v_heads).uniform_(0, 16) | ||||||
| self.A_log = nn.Parameter(torch.log(A)) | ||||||
|
|
||||||
| assert causal_conv1d_fn is not None, ( | ||||||
| "causal_conv1d_fn is not available. Please install causal-conv1d to use GatedDeltaNet by `https://github.com/Dao-AILab/causal-conv1d`." | ||||||
| ) | ||||||
| self.causal_conv1d_fn = causal_conv1d_fn | ||||||
| assert chunk_gated_delta_rule is not None, ( | ||||||
| "chunk_gated_delta_rule is not available. Please install fla to use GatedDeltaNet by `pip install flash-linear-attention`." | ||||||
| ) | ||||||
| self.chunk_gated_delta_rule = chunk_gated_delta_rule | ||||||
| assert FusedRMSNormGated is not None, ( | ||||||
| "FusedRMSNormGated is not available. Please install fla to use GatedDeltaNet by `pip install flash-linear-attention`." | ||||||
|
|
@@ -176,7 +166,6 @@ def forward( | |||||
| batch_size, seq_len, _ = hidden_states.shape | ||||||
| assert batch_size == 1, "Only batch size of 1 is supported for now in GateDeltaNet" | ||||||
| mixed_qkv = self.in_proj_qkv(hidden_states) | ||||||
| mixed_qkv = mixed_qkv.transpose(1, 2) | ||||||
|
|
||||||
| z = self.in_proj_z(hidden_states) | ||||||
| z = z.reshape(batch_size, seq_len, -1, self.head_v_dim) | ||||||
|
|
@@ -191,27 +180,23 @@ def forward( | |||||
| if bias and isinstance(bias, DTensor): | ||||||
| bias = bias.to_local() | ||||||
|
|
||||||
| # TODO: If full_graph mode is supported in the future, it needs to be modified to custom_op | ||||||
| if seq_ctx.seq_idx is None: | ||||||
| seq_idx = torch.cat( | ||||||
| [ | ||||||
| torch.full((s,), i, dtype=torch.int32, device=mixed_qkv.device) | ||||||
| for i, s in enumerate(seq_ctx.seq_lens_q) | ||||||
| ], | ||||||
| dim=0, | ||||||
| )[None] | ||||||
| # Use Triton kernel to fill seq_idx based on cu_seq_lens_q | ||||||
| # Pre-allocate empty tensor with shape (1, seq_len) | ||||||
| seq_idx = gen_seq_idx(seq_len, seq_ctx.cu_seq_lens_q) | ||||||
| seq_ctx.seq_idx = cast(torch.IntTensor, seq_idx) | ||||||
| else: | ||||||
| seq_idx = seq_ctx.seq_idx | ||||||
|
|
||||||
| # mixed_qkv = mixed_qkv.transpose(1, 2) | ||||||
| mixed_qkv = self.causal_conv1d_fn( | ||||||
| x=mixed_qkv, | ||||||
| weight=weight, | ||||||
| bias=bias, | ||||||
| activation=self.activation, | ||||||
| seq_idx=seq_idx, | ||||||
| ) | ||||||
| mixed_qkv = mixed_qkv.transpose(1, 2) | ||||||
| # mixed_qkv = mixed_qkv.transpose(1, 2) | ||||||
| query, key, value = torch.split( | ||||||
|
Comment on lines
+199
to
200
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Claude: Warning — Commented-out code should be removed rather than left in. Per CLAUDE.md: "Avoid backwards-compatibility hacks like ... adding // removed comments for removed code." If the transposes are no longer needed (because the custom op handles them internally), just delete these lines.
Suggested change
|
||||||
| mixed_qkv, | ||||||
| [ | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,270 @@ | ||||||
| # Copyright (c) 2024, Tri Dao. | ||||||
|
|
||||||
| import torch | ||||||
| import torch.nn.functional as F | ||||||
|
|
||||||
| import causal_conv1d_cuda | ||||||
|
|
||||||
|
|
||||||
| LIBRARY_NAME = "DaoAILab" | ||||||
|
|
||||||
|
|
||||||
| @torch.library.custom_op(f"{LIBRARY_NAME}::_causal_conv1d_fwd_cpp", mutates_args={"out", "final_states_out"}) | ||||||
| def _causal_conv1d_fwd_cpp( | ||||||
| x: torch.Tensor, | ||||||
| weight: torch.Tensor, | ||||||
| bias: torch.Tensor | None, | ||||||
| seq_idx: torch.Tensor | None, | ||||||
| initial_states: torch.Tensor | None, | ||||||
| out: torch.Tensor, | ||||||
| final_states_out: torch.Tensor | None, | ||||||
| silu_activation: bool, | ||||||
| ) -> None: | ||||||
| if seq_idx is not None: | ||||||
| # If seq_idx is provided, we must use channel last layout | ||||||
| x = x.transpose(1, 2) | ||||||
| out = out.transpose(1, 2) | ||||||
| causal_conv1d_cuda.causal_conv1d_fwd( | ||||||
| x, | ||||||
| weight, | ||||||
| bias, | ||||||
| seq_idx, | ||||||
| initial_states, | ||||||
| out, | ||||||
| final_states_out, | ||||||
| silu_activation, | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
| @torch.library.custom_op(f"{LIBRARY_NAME}::_causal_conv1d_bwd_cpp", mutates_args={ | ||||||
| "dfinal_states", | ||||||
| "dx", | ||||||
| "dweight", | ||||||
| "dbias", | ||||||
| "dinitial_states", | ||||||
| }) | ||||||
| def _causal_conv1d_bwd_cpp( | ||||||
| x: torch.Tensor, | ||||||
| weight: torch.Tensor, | ||||||
| bias: torch.Tensor | None, | ||||||
| dout: torch.Tensor, | ||||||
| seq_idx: torch.Tensor | None, | ||||||
| initial_states: torch.Tensor | None, | ||||||
| dfinal_states: torch.Tensor | None, | ||||||
| dx: torch.Tensor, | ||||||
| dweight: torch.Tensor, | ||||||
| dbias: torch.Tensor | None, | ||||||
| dinitial_states: torch.Tensor, | ||||||
| silu_activation: bool, | ||||||
| ) -> None: | ||||||
| if seq_idx is not None: | ||||||
| # If seq_idx is provided, we must use channel last layout | ||||||
| x = x.transpose(1, 2) | ||||||
| dout = dout.transpose(1, 2) | ||||||
| dx = dx.transpose(1, 2) | ||||||
| causal_conv1d_cuda.causal_conv1d_bwd( | ||||||
| x, | ||||||
| weight, | ||||||
| bias, | ||||||
| dout, | ||||||
| seq_idx, | ||||||
| initial_states, | ||||||
| dfinal_states, | ||||||
| dx, | ||||||
| dweight, | ||||||
| dbias, | ||||||
| dinitial_states, | ||||||
| silu_activation, | ||||||
| ) | ||||||
|
|
||||||
| def causal_conv1d_fwd_function( | ||||||
| x: torch.Tensor, | ||||||
| weight: torch.Tensor, | ||||||
| bias: torch.Tensor | None, | ||||||
| seq_idx: torch.Tensor | None, | ||||||
| initial_states: torch.Tensor | None, | ||||||
| final_states_out: torch.Tensor | None, | ||||||
| silu_activation: bool, | ||||||
| ) -> torch.Tensor: | ||||||
| out = torch.empty_like(x) | ||||||
| _causal_conv1d_fwd_cpp( | ||||||
| x=x, | ||||||
| weight=weight, | ||||||
| bias=bias, | ||||||
| seq_idx=seq_idx, | ||||||
| initial_states=initial_states, | ||||||
| out=out, | ||||||
| final_states_out=final_states_out, | ||||||
| silu_activation=silu_activation, | ||||||
| ) | ||||||
| return out | ||||||
|
|
||||||
| def causal_conv1d_bwd_function( | ||||||
| x: torch.Tensor, | ||||||
| weight: torch.Tensor, | ||||||
| bias: torch.Tensor | None, | ||||||
| dout: torch.Tensor, | ||||||
| seq_idx: torch.Tensor | None, | ||||||
| initial_states: torch.Tensor | None, | ||||||
| dfinal_states: torch.Tensor | None, | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Claude: Critical —
Suggested change
|
||||||
| dx: torch.Tensor | None, | ||||||
| return_dinitial_states: torch.Tensor, | ||||||
| silu_activation: bool, | ||||||
| ) -> tuple[torch.Tensor | None]: | ||||||
| if seq_idx is None: | ||||||
| batch_size, dim = x.size()[:2] | ||||||
| else: | ||||||
| batch_size, _, dim = x.size()[:3] | ||||||
| width = weight.size(-1) | ||||||
|
|
||||||
| if dx is None: | ||||||
| dx = torch.empty_like(x) | ||||||
| dweight = torch.zeros_like(weight, dtype=torch.float32) | ||||||
| dbias = None | ||||||
| if bias is not None: | ||||||
| dbias = torch.zeros_like(bias, dtype=torch.float32) | ||||||
| dinitial_states = None | ||||||
| if return_dinitial_states: | ||||||
| dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) | ||||||
|
|
||||||
| _causal_conv1d_bwd_cpp( | ||||||
| x=x, | ||||||
| weight=weight, | ||||||
| bias=bias, | ||||||
| dout=dout, | ||||||
| seq_idx=seq_idx, | ||||||
| initial_states=initial_states, | ||||||
| dfinal_states=dfinal_states, | ||||||
| dx=dx, | ||||||
| dweight=dweight, | ||||||
| dbias=dbias, | ||||||
| dinitial_states=dinitial_states, | ||||||
| silu_activation=silu_activation, | ||||||
| ) | ||||||
|
|
||||||
| dweight = dweight.type_as(weight) | ||||||
| if dbias is not None: | ||||||
| dbias = dbias.type_as(bias) | ||||||
| return dx, dweight, dbias, dinitial_states | ||||||
|
|
||||||
| class CausalConv1dFn(torch.autograd.Function): | ||||||
| @staticmethod | ||||||
| def forward( | ||||||
| ctx, | ||||||
| x, | ||||||
| weight, | ||||||
| bias=None, | ||||||
| seq_idx=None, | ||||||
| initial_states=None, | ||||||
| return_final_states=False, | ||||||
| final_states_out=None, | ||||||
| activation=None, | ||||||
| ): | ||||||
| if activation not in [None, "silu", "swish"]: | ||||||
| raise NotImplementedError("activation must be None, silu, or swish") | ||||||
| if x.stride(2) != 1 and x.stride(1) != 1: | ||||||
| x = x.contiguous() | ||||||
| bias = bias.contiguous() if bias is not None else None | ||||||
| if seq_idx is not None: | ||||||
| assert ( | ||||||
| initial_states is None | ||||||
| ), "initial_states must be None if seq_idx is not None" | ||||||
| assert ( | ||||||
| not return_final_states | ||||||
| ), "If seq_idx is not None, we don't return final_states_out" | ||||||
| seq_idx = seq_idx.contiguous() if seq_idx is not None else None | ||||||
| if initial_states is not None and ( | ||||||
| initial_states.stride(2) != 1 and initial_states.stride(1) != 1 | ||||||
| ): | ||||||
| initial_states = initial_states.contiguous() | ||||||
| if return_final_states: | ||||||
| assert ( | ||||||
| x.stride(1) == 1 | ||||||
| ), "Only channel-last layout support returning final_states_out" | ||||||
| if final_states_out is not None: | ||||||
| assert ( | ||||||
| final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 | ||||||
| ) | ||||||
| else: | ||||||
| batch, dim, seqlen = x.shape | ||||||
| width = weight.shape[1] | ||||||
| final_states_out = torch.empty( | ||||||
| batch, width - 1, dim, device=x.device, dtype=x.dtype | ||||||
| ).transpose(1, 2) | ||||||
| else: | ||||||
| final_states_out = None | ||||||
| ctx.activation = activation in ["silu", "swish"] | ||||||
| out = causal_conv1d_fwd_function( | ||||||
| x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation | ||||||
| ) | ||||||
| ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) | ||||||
| ctx.return_final_states = return_final_states | ||||||
| ctx.return_dinitial_states = ( | ||||||
| initial_states is not None and initial_states.requires_grad | ||||||
| ) | ||||||
| return out if not return_final_states else (out, final_states_out) | ||||||
|
|
||||||
| @staticmethod | ||||||
| def backward(ctx, dout, *args): | ||||||
| x, weight, bias, seq_idx, initial_states = ctx.saved_tensors | ||||||
| dfinal_states = args[0] if ctx.return_final_states else None | ||||||
| if dout.stride(2) != 1 and dout.stride(1) != 1: | ||||||
| dout = dout.contiguous() | ||||||
| # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the | ||||||
| # backward of conv1d with the backward of chunk). | ||||||
| # Here we just pass in None and dx will be allocated in the C++ code. | ||||||
| dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( | ||||||
| x, | ||||||
| weight, | ||||||
| bias, | ||||||
| dout, | ||||||
| seq_idx, | ||||||
| initial_states, | ||||||
| dfinal_states, | ||||||
| None, | ||||||
| ctx.return_dinitial_states, | ||||||
| ctx.activation, | ||||||
| ) | ||||||
| return ( | ||||||
| dx, | ||||||
| dweight, | ||||||
| dbias if bias is not None else None, | ||||||
| None, | ||||||
| dinitial_states if initial_states is not None else None, | ||||||
| None, | ||||||
| None, | ||||||
| None, | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
| def causal_conv1d_fn( | ||||||
| x, | ||||||
| weight, | ||||||
| bias=None, | ||||||
| seq_idx=None, | ||||||
| initial_states=None, | ||||||
| return_final_states=False, | ||||||
| final_states_out=None, | ||||||
| activation=None, | ||||||
| ): | ||||||
| """ | ||||||
| x: (batch, dim, seqlen) | ||||||
| weight: (dim, width) | ||||||
| bias: (dim,) | ||||||
| seq_idx: (batch, seqlen) | ||||||
| initial_states: (batch, dim, width - 1) | ||||||
| final_states_out: (batch, dim, width - 1), to be written to | ||||||
| activation: either None or "silu" or "swish" | ||||||
|
|
||||||
| out: (batch, dim, seqlen) | ||||||
| """ | ||||||
| return CausalConv1dFn.apply( | ||||||
| x, | ||||||
| weight, | ||||||
| bias, | ||||||
| seq_idx, | ||||||
| initial_states, | ||||||
| return_final_states, | ||||||
| final_states_out, | ||||||
| activation, | ||||||
| ) | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Claude: Nit — Missing newline at end of file. |
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Claude: Warning — This change (adding
"dialogs" key support) appears unrelated to the full graph mode feature for Qwen3.5. Per CLAUDE.md: "One logical change per PR. Do not mix bug fixes with features or refactors." Consider splitting this into a separate PR.