Changes from all commits (130 commits)
fa2c2d2
add num_kv_splits_indptr to mla for mtp<=4 case for now
valarLip Jun 26, 2025
15f6155
update
valarLip Jun 27, 2025
8dd5617
update new kernel
valarLip Jul 1, 2025
c871e8d
infrastructures
ruanjm Jul 14, 2025
3750b5f
1st version of split kernel
ruanjm Jul 16, 2025
7ca2598
Fix issues raised by Lingpeng and fix the issue on batch_size
ruanjm Jul 16, 2025
7c5891c
update mla
valarLip Jul 16, 2025
12def78
update mla_stage2
valarLip Jul 18, 2025
5dc5a6d
Merge branch 'main' into mla_splitkv_enhance
valarLip Jul 18, 2025
eae14ae
Merge branch 'main' into mla_splitkv_enhance
valarLip Jul 18, 2025
f244f11
Merge branch 'mla_splitkv_enhance' into jruan/mla_splitkv_enhance_spl…
ruanjm Jul 22, 2025
224f89f
1st draft of v1 split program
ruanjm Jul 22, 2025
ef442fd
add kv_offset
ruanjm Jul 28, 2025
f10235e
mla_splitkv_enhance_split_alg_inte
Zzz9990 Jul 29, 2025
600b5dd
splitkv debug
Zzz9990 Jul 29, 2025
5c58ae8
1st version of reduce kernel
ruanjm Jul 29, 2025
9700bc5
metadata & kernel finish
Zzz9990 Jul 30, 2025
4a86304
Merge branch 'jruan/mla_splitkv_enhance_split_alg' into mla_splitkv_e…
Zzz9990 Jul 30, 2025
d49c0cd
add reduce
Zzz9990 Jul 30, 2025
e4bf891
final_lse is optional now.
ruanjm Jul 30, 2025
7bf6aa4
update kernel
Zzz9990 Jul 30, 2025
2411f1f
bug fix
ruanjm Jul 30, 2025
e21600d
Merge branch 'jruan/mla_splitkv_enhance_split_alg' into mla_splitkv_e…
Zzz9990 Jul 30, 2025
ffcc113
bug fix 1
ruanjm Jul 30, 2025
07e4ed1
modify reduce api
Zzz9990 Jul 30, 2025
3f2bf25
Merge branch 'jruan/mla_splitkv_enhance_split_alg' into mla_splitkv_e…
Zzz9990 Jul 30, 2025
7c877c4
update kernel
Zzz9990 Jul 30, 2025
d10cdab
fix max splits
Zzz9990 Jul 30, 2025
bac5750
bug fix 3
ruanjm Jul 30, 2025
f59a3e6
fix s80 early return
Zzz9990 Jul 30, 2025
1ae58d1
Merge branch 'jruan/mla_splitkv_enhance_split_alg' into mla_splitkv_e…
Zzz9990 Jul 30, 2025
5680c26
update calculation of partial_indx
ruanjm Jul 30, 2025
fa87c91
Merge branch 'jruan/mla_splitkv_enhance_split_alg' into mla_splitkv_e…
Zzz9990 Jul 31, 2025
0dad74c
add per split test
Zzz9990 Jul 31, 2025
a8fa0b1
make lse support by ref
ruanjm Jul 31, 2025
56e964f
test split
Zzz9990 Jul 31, 2025
a76610a
fix redundant calculation of head offset in reduce kernel
ruanjm Jul 31, 2025
4ffd393
add custom test
Zzz9990 Jul 31, 2025
b3747df
Merge branch 'jruan/mla_splitkv_enhance_split_alg' into mla_splitkv_e…
Zzz9990 Jul 31, 2025
ba36541
Add support of 128 head size
ruanjm Jul 31, 2025
e5a1b17
update comments
ruanjm Aug 1, 2025
a68879c
1. Let large work be assigned first.
ruanjm Aug 1, 2025
7209c36
Merge branch 'jruan/mla_splitkv_enhance_split_alg' into mla_splitkv_e…
Zzz9990 Aug 4, 2025
4494b36
Calculate kv_limit dynamically
ruanjm Aug 4, 2025
09c4ca8
Merge branch 'jruan/mla_splitkv_enhance_split_alg' into mla_splitkv_e…
Zzz9990 Aug 4, 2025
1e5e71a
Fix bug about difference in split_kv(bool)
ruanjm Aug 4, 2025
f35cf04
Merge branch 'jruan/mla_splitkv_enhance_split_alg' into mla_splitkv_e…
Zzz9990 Aug 4, 2025
f7cf2b9
add test
Zzz9990 Aug 5, 2025
5b91267
fix seed
Zzz9990 Aug 5, 2025
59af206
Add global tolerance 16 in kv seqlen because main kernel cannot handl…
ruanjm Aug 5, 2025
e1b9065
Fix warp=1 error
ruanjm Aug 8, 2025
2adf050
Add redundant mode to make the size of output of metadata be fixed ad…
ruanjm Aug 8, 2025
c0df46b
Merge branch 'jruan/mla_splitkv_enhance_split_alg' into mla_splitkv_e…
Zzz9990 Aug 12, 2025
fbff664
fp8 setup
Zzz9990 Aug 12, 2025
1d36311
first version of device metadata
ruanjm Aug 12, 2025
4212a41
Add work_ptrs
ruanjm Aug 12, 2025
818229e
Compatibility to CUDA Graph
ruanjm Aug 13, 2025
704324a
Refactor code. Merge 2 iterations of generate work together.
ruanjm Aug 14, 2025
6be798a
Make sure that each batch of workload can never be splited to more th…
ruanjm Aug 14, 2025
1b0e26f
Adjust metadata. Get 1% perf gain.
ruanjm Aug 14, 2025
36e9b53
Parallelize most of metadata kernel
ruanjm Aug 15, 2025
4403c82
add scale
Zzz9990 Aug 18, 2025
fcb36f0
1. Use warp-level bitonic sort to sort batch idx based on their cost …
ruanjm Aug 18, 2025
5dc1eb7
fp8 function pass
Zzz9990 Aug 19, 2025
b46a8e3
Fix issues:
ruanjm Aug 19, 2025
d8d92bc
fp8 ready
Zzz9990 Aug 19, 2025
ead163a
fix
Zzz9990 Aug 19, 2025
7fefc29
Merge remote-tracking branch 'origin/jruan/mla_splitkv_enhance_split_…
Zzz9990 Aug 19, 2025
cc7ffdc
persistent ready
Zzz9990 Aug 19, 2025
5e32d5d
add nv acc test
Zzz9990 Aug 21, 2025
a97fcf8
rename
Zzz9990 Sep 1, 2025
e0c72f8
update metashape
Zzz9990 Sep 1, 2025
7220b04
update reduce cu num
Zzz9990 Sep 1, 2025
07bf6bb
update optest for mla
Zzz9990 Sep 1, 2025
3a7bd04
fix cu num
Zzz9990 Sep 1, 2025
88c8a0d
Update metadata and reduce kernels.
ruanjm Sep 1, 2025
7f86b0b
rename kernels
Zzz9990 Sep 1, 2025
018798d
Add new param kv_granularity to metadata kernel.
ruanjm Sep 2, 2025
3bf1623
Introduce cal_workload_limit_global_v2
ruanjm Sep 9, 2025
907dbed
Support qhead=128 cases.
ruanjm Sep 11, 2025
b2bed66
Change get_mla_metadata() api. Make some not important parameters be …
ruanjm Sep 12, 2025
a658ad8
fix potential problem on calculating tot_qo_tiles
ruanjm Sep 12, 2025
325e03f
refactor metadata files
ruanjm Sep 15, 2025
7072d90
update metadata v1_2
Zzz9990 Sep 18, 2025
851a888
update gqa_128 mla_ps & fix metadata v1_2
Zzz9990 Sep 18, 2025
b56eb25
Optimize mla metadata v1.2
ruanjm Sep 19, 2025
8ea8f73
Optimize mla metadata v1.2 Part.2
ruanjm Sep 19, 2025
9020ce8
Optimize mla metadata v1.2 Part.3
ruanjm Sep 19, 2025
59d8e33
update qlen <=4
Zzz9990 Sep 19, 2025
b401744
fix mla qlen1
Zzz9990 Sep 19, 2025
3bf8b2b
Optimize mla metadata v1.2 Part.4
ruanjm Sep 22, 2025
3f376b5
Make reduce_final_map be optional in mla_reduce_v1
ruanjm Sep 23, 2025
7c865a5
Slightly increase reduce perf
ruanjm Sep 23, 2025
8a17f56
Add persistent mode for mla reduce kernel
ruanjm Sep 24, 2025
75ebf74
add mla_a16w8_qh16_m16x4_n16x1_coex0_mask1_ps.co
fangche123 Sep 28, 2025
3f67dbe
update deepseekv32 sparse mla metadata
Zzz9990 Oct 1, 2025
84e9616
update mla_a16w8_qh16_m16x4_n16x1_coex0_mask1_ps.co
fangche123 Oct 9, 2025
ce9096f
Adjust code for sparse attn
ruanjm Oct 10, 2025
71abd03
Optimize the a16w8 kernel
fangche123 Oct 10, 2025
ebb2591
Improve metadata v1.1 perf
ruanjm Oct 10, 2025
9afba8f
Make metadata v1.1 support sparse attn
ruanjm Oct 10, 2025
2150d8f
Remove redundant code in mla_reduce
ruanjm Oct 11, 2025
363707f
futile struggle
ruanjm Oct 11, 2025
0b874cb
Merge branch 'main' into mla_splitkv_enhance_split_alg_inte
ruanjm Oct 13, 2025
ce9abd8
Fix issue after merge. aiter main branch is using torch.library.infer…
ruanjm Oct 13, 2025
64c3e29
Adjust metadata v1.1 and make this branch be ready to be merged to ma…
ruanjm Oct 14, 2025
57b9d57
Merge branch 'main' into mla_splitkv_enhance_split_alg_inte
ruanjm Oct 14, 2025
b70d8d4
remove invalid co kernel
Zzz9990 Oct 14, 2025
f668d60
Fix issue brought from f794ae4 which disabled hipify by default.
ruanjm Oct 14, 2025
33ea0e8
support qolen>1 for sparse mla
Zzz9990 Oct 14, 2025
6e2c4ff
make code become prettier
ruanjm Oct 14, 2025
c3813fb
Fix issue in metadata v1.1
ruanjm Oct 14, 2025
bcd219a
Merge branch 'main' into mla_splitkv_enhance_split_alg_inte
ruanjm Oct 15, 2025
33b0499
Fix issue in test_mla.py
ruanjm Oct 16, 2025
53f5826
Fix lint fails
ruanjm Oct 16, 2025
41576e1
Fix sub-test fails in op_test/test_mla.py
ruanjm Oct 16, 2025
68ef089
Fix regression in test_mla.py where mtp>1
ruanjm Oct 16, 2025
f7efe97
Add head_dim=128 support to reduce
ruanjm Oct 16, 2025
8440195
Merge branch 'main' into mla_splitkv_enhance_split_alg_inte
ruanjm Oct 17, 2025
1c5b77b
Add nhead=8 for pa and add assert to make sure the input tensors are in
ruanjm Oct 17, 2025
69d41a0
fix issue in vllm benchmark for deepseek: remove metadata v0 because …
ruanjm Oct 17, 2025
0cf3db2
fix lint
ruanjm Oct 17, 2025
ae96787
Revert all the change about mi350 gemm.
ruanjm Oct 17, 2025
be55ef5
add a8w8 and a16w8 kernel in mla mi350
fangche123 Oct 20, 2025
600d993
add A8W8 Non-persistent mode kernel
fangche123 Oct 21, 2025
6c7f795
Fix issue reported by Copilot
ruanjm Oct 22, 2025
573c3cd
add mla non-persistent test
fangche123 Oct 22, 2025
0cfc1a3
script: update a16w8 kernel
fangche123 Oct 23, 2025
0490f21
rm test_mla_persistent_mi350.py and support mi350 in test_mla_persist…
fangche123 Oct 24, 2025
8ca7679
Merge branch 'main' into mla_splitkv_enhance_split_alg_inte
valarLip Oct 24, 2025
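Many of the commits above ("add num_kv_splits_indptr to mla", "1st version of split kernel", "1st version of reduce kernel", "Calculate kv_limit dynamically") revolve around the same split-KV scheme: each request's KV sequence is carved into at most num_kv_splits chunks of at least mgc tokens, the attention kernel writes one partial output and one log-sum-exp per chunk, and a reduce step merges them. The sketch below mirrors the chunking rule used by the stage-2 Triton kernel in aiter/mla.py further down; the helper name split_ranges is illustrative and not part of the aiter API.

def split_ranges(kv_len: int, num_kv_splits: int, mgc: int):
    """Carve a KV sequence of kv_len tokens into at most num_kv_splits chunks,
    each spanning at least mgc tokens (mirrors the stage-2 kernel's
    kv_len_per_split = max(mgc, ceil(kv_len / num_kv_splits)))."""
    kv_len_per_split = max(mgc, -(-kv_len // num_kv_splits))  # ceil division
    ranges = []
    for split_id in range(num_kv_splits):
        start = kv_len_per_split * split_id
        end = min(start + kv_len_per_split, kv_len)
        if end > start:
            ranges.append((start, end))
    return ranges


# Example: a 1000-token KV sequence with 16 splits and granularity 16
# yields 15 chunks of 63 tokens and a final chunk of 55 tokens.
print(split_ranges(1000, 16, 16))

In the uniform case, mla_decode_fwd records the per-batch split counts as a prefix sum, num_kv_splits_indptr = torch.arange(0, (bs + 1) * num_kv_splits, num_kv_splits), which is the tensor the stage-2 kernel indexes with cur_batch.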
4 changes: 2 additions & 2 deletions aiter/dist/custom_all_reduce_utils.py
@@ -1,6 +1,6 @@
"""
* Copyright © Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2024, The vLLM team.
* Copyright (C) Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2024-2025, The vLLM team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
4 changes: 2 additions & 2 deletions aiter/dist/utils.py
@@ -1,6 +1,6 @@
"""
* Copyright © Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2024, The vLLM team.
* Copyright (C) Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2024-2025, The vLLM team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
28 changes: 28 additions & 0 deletions aiter/jit/optCompilerConfig.json
@@ -901,5 +901,33 @@
],
"verbose": "False",
"blob_gen_cmd": "''"
},
"module_mla_metadata": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/mla_metadata_pybind.cu'",
"f'{AITER_CSRC_DIR}/kernels/mla/metadata.cu'",
"f'{AITER_CSRC_DIR}/kernels/mla/metadata/v1_comm.cuh'",
"f'{AITER_CSRC_DIR}/kernels/mla/metadata/v1_1_device.cuh'",
"f'{AITER_CSRC_DIR}/kernels/mla/metadata/v1_1_host.cuh'",
"f'{AITER_CSRC_DIR}/kernels/mla/metadata/v1_2_device.cuh'"
],
"flags_extra_cc": [],
"flags_extra_hip": [],
"extra_ldflags": "None",
"extra_include": [],
"verbose": "False",
"blob_gen_cmd": "''"
},
"module_mla_reduce": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/mla_reduce_pybind.cu'",
"f'{AITER_CSRC_DIR}/kernels/mla/reduce.cu'"
],
"flags_extra_cc": [],
"flags_extra_hip": [],
"extra_ldflags": "None",
"extra_include": [],
"verbose": "False",
"blob_gen_cmd": "''"
}
}
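The two new entries follow the same pattern as the existing modules in this file: each lists the C++/HIP sources that aiter's JIT layer compiles the first time the module is requested. On the Python side such a module is normally exposed through the compile_ops decorator (the same mechanism imported in aiter/ops/activation.py further down). The following is a rough, hypothetical sketch of what a binding for module_mla_reduce could look like, with the argument order taken from the mla_reduce_v1 call in mla_decode_fwd below; the authoritative signature lives in mla_reduce_pybind.cu and may differ.

# Hypothetical binding sketch only; the real signature of mla_reduce_v1 is
# defined in csrc/pybind/mla_reduce_pybind.cu and is not shown in this diff.
from typing import Optional
from torch import Tensor
from ..jit.core import compile_ops


@compile_ops("module_mla_reduce")
def mla_reduce_v1(
    logits: Tensor,                      # partial outputs, one slice per KV split
    attn_lse: Tensor,                    # per-split log-sum-exp values
    reduce_indptr: Tensor,               # prefix sums delimiting each tile's splits
    reduce_final_map: Optional[Tensor],  # optional since commit 3f376b5
    reduce_partial_map: Tensor,
    output: Tensor,                      # final attention output, written in place
    final_lse: Tensor,                   # merged log-sum-exp, written in place
): ...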
219 changes: 142 additions & 77 deletions aiter/mla.py
@@ -19,67 +19,80 @@ def _fwd_kernel_stage2_asm(
O,
qo_indptr,
kv_indptr,
num_kv_splits_indptr,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
stride_obs,
stride_oh,
bs,
nheads,
max_seqlen_q,
NUM_KV_SPLITS: tl.constexpr,
MAYBE_FINAL_OUT: tl.constexpr,
BATCH_NUM: tl.constexpr,
BLOCK_DV: tl.constexpr,
Lv: tl.constexpr,
mgc: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
cur_qo_offs = tl.program_id(2)

cur_qo_start = tl.load(qo_indptr + cur_batch)
cur_qo_end = tl.load(qo_indptr + cur_batch + 1)
cur_qo = cur_qo_start + cur_qo_offs
if cur_qo > cur_qo_end:
return
cur_split_start = tl.load(num_kv_splits_indptr + cur_batch)
cur_split_end = tl.load(num_kv_splits_indptr + cur_batch + 1)
num_max_kv_splits = tl.load(num_kv_splits_indptr + BATCH_NUM)
cur_kv_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load(kv_indptr + cur_batch)

offs_d = tl.arange(0, BLOCK_DV)
mask_d = offs_d < Lv

e_sum = 0.0
e_max = -float("inf")
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

offs_v = (cur_qo * stride_mid_ob + cur_head * stride_mid_oh) * Lv + offs_d
offs_logic = cur_qo * stride_mid_ob + cur_head * stride_mid_oh

for split_kv_id in range(0, NUM_KV_SPLITS):
kv_len_per_split = tl.maximum(mgc, tl.cdiv(cur_kv_seq_len, NUM_KV_SPLITS))
split_kv_start = kv_len_per_split * split_kv_id
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_kv_seq_len)

if split_kv_end > split_kv_start:
tv = tl.load(
Mid_O + offs_v + split_kv_id * stride_mid_os * Lv,
offs_logic = cur_qo_start * stride_mid_ob + cur_head * stride_mid_oh
offs_v = offs_logic * Lv + offs_d
num_valid_kv_splits = tl.minimum(
cur_split_end - cur_split_start, tl.cdiv(cur_kv_seq_len, mgc)
)
FINAL_OUT = MAYBE_FINAL_OUT and num_max_kv_splits == BATCH_NUM

for cur_qo in range(cur_qo_start, cur_qo_end):
if FINAL_OUT:
input_ptr = Mid_O.to(tl.pointer_type(O.type.element_ty))
out = tl.load(
# input_ptr + offs_v + stride_mid_ob * Lv,
input_ptr
+ Lv * (cur_qo * stride_mid_os + cur_head * stride_mid_oh)
+ offs_d,
mask=mask_d,
other=0.0,
)
tlogic = tl.load(Mid_lse + offs_logic + split_kv_id * stride_mid_os)
n_e_max = tl.maximum(tlogic, e_max)

old_scale = tl.exp(e_max - n_e_max)
acc *= old_scale
exp_logic = tl.exp(tlogic - n_e_max)
acc += exp_logic * tv

e_sum = e_sum * old_scale + exp_logic
e_max = n_e_max

tl.store(
O + cur_qo * stride_obs + cur_head * stride_oh + offs_d,
acc / e_sum,
mask=mask_d,
)
tl.store(
O + cur_qo * stride_obs + cur_head * stride_oh + offs_d,
out,
mask=mask_d,
)
else:
e_sum = 0.0
e_max = -float("inf")
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
for split_kv_id in range(0, num_valid_kv_splits):
tv = tl.load(
Mid_O + offs_v + split_kv_id * stride_mid_os * Lv,
mask=mask_d,
other=0.0,
)
tlogic = tl.load(Mid_lse + offs_logic + split_kv_id * stride_mid_os)
n_e_max = tl.maximum(tlogic, e_max)

old_scale = tl.exp(e_max - n_e_max)
acc *= old_scale
exp_logic = tl.exp(tlogic - n_e_max)
acc += exp_logic * tv

e_sum = e_sum * old_scale + exp_logic
e_max = n_e_max
offs_logic += stride_mid_ob
offs_v += stride_mid_ob * Lv
tl.store(
O + cur_qo * stride_obs + cur_head * stride_oh + offs_d,
acc / e_sum,
mask=mask_d,
)


@functools.lru_cache()
@@ -100,7 +113,6 @@ def get_meta_param(num_kv_splits, bs, total_kv, nhead, max_seqlen_q):
for i in range(1, 17)
]
num_kv_splits = sorted(tmp, key=lambda x: x[0], reverse=True)[0][1]
# num_kv_splits = min(16, max(1, cu_num // bs))

get_mgc = {16: 16, 128: 16}

@@ -123,6 +135,15 @@ def mla_decode_fwd(
sm_scale=None, # 1.0 / (qk_head_dim**0.5)
logit_cap=0.0,
num_kv_splits=None, # for experts only!!!
num_kv_splits_indptr=None, # for experts only!!!
work_meta_data=None,
work_indptr=None,
work_info_set=None,
reduce_indptr=None,
reduce_final_map=None,
reduce_partial_map=None,
q_scale=None,
kv_scale=None,
):
device = q.device
assert logit_cap <= 0, f"{logit_cap=} is not support yet"
@@ -134,9 +155,14 @@
bs = qo_indptr.shape[0] - 1
total_kv = kv_indices.shape[0]

num_kv_splits, mgc = get_meta_param(
num_kv_splits, bs, total_kv, nhead, max_seqlen_q
)
if num_kv_splits_indptr is None and work_meta_data is None:
num_kv_splits, mgc = get_meta_param(None, bs, total_kv, nhead, max_seqlen_q)
num_kv_splits_indptr = torch.arange(
0, (bs + 1) * num_kv_splits, num_kv_splits, dtype=torch.int, device=device
)

if num_kv_splits is None:
num_kv_splits = get_cu_num()

if nhead == 16 and max_seqlen_q == 1:
# special case for 16 heads and max_seqlen_q == 1
@@ -145,22 +171,72 @@
dtype=dtypes.fp32,
device=device,
)
MAYBE_FINAL_OUT = False
elif nhead in [16, 128]:
logits = (
o.view((total_s, num_kv_splits, nhead, v_head_dim))
if num_kv_splits == 1
else torch.empty(
(total_s, num_kv_splits, nhead, v_head_dim),
dtype=dtypes.fp32,
device=device,
)
MAYBE_FINAL_OUT = True
logits = torch.empty(
(total_s, num_kv_splits, nhead, v_head_dim),
dtype=dtypes.fp32,
device=device,
)
else:
assert False, f"{nhead=} not supported"

attn_lse = torch.empty(
(total_s, num_kv_splits, nhead, 1), dtype=dtypes.fp32, device=device
)
final_lse = torch.empty((total_s, nhead), dtype=dtypes.fp32, device=device)

if num_kv_splits_indptr is not None:
aiter.mla_decode_stage1_asm_fwd(
q,
kv_buffer,
qo_indptr,
kv_indptr,
kv_indices,
kv_last_page_lens,
num_kv_splits_indptr,
None,
None,
None,
max_seqlen_q,
sm_scale,
logits,
attn_lse,
o,
q_scale,
kv_scale,
)

# if num_kv_splits == 1 and not (max_seqlen_q == 1 and nhead == 16):
# return logits.view(total_s, nhead, v_head_dim), attn_lse
Lv = v_head_dim
BLOCK_DV = triton.next_power_of_2(Lv)
grid = (bs, nhead)
extra_kargs = {"waves_per_eu": 4}

_fwd_kernel_stage2_asm[grid](
logits,
attn_lse,
o,
qo_indptr,
kv_indptr,
num_kv_splits_indptr,
attn_lse.stride(0),
attn_lse.stride(2),
attn_lse.stride(1),
o.stride(0),
o.stride(1),
MAYBE_FINAL_OUT=MAYBE_FINAL_OUT,
BATCH_NUM=bs,
BLOCK_DV=BLOCK_DV,
Lv=Lv,
mgc=mgc,
num_warps=4,
num_stages=2,
**extra_kargs,
)
return logits, final_lse

aiter.mla_decode_stage1_asm_fwd(
q,
@@ -169,41 +245,30 @@
kv_indptr,
kv_indices,
kv_last_page_lens,
num_kv_splits_indptr,
work_meta_data,
work_indptr,
work_info_set,
max_seqlen_q,
sm_scale,
logits,
attn_lse,
o,
q_scale,
kv_scale,
)

if num_kv_splits == 1 and not (max_seqlen_q == 1 and nhead == 16):
return logits.view(total_s, nhead, v_head_dim), attn_lse
Lv = v_head_dim
BLOCK_DV = triton.next_power_of_2(Lv)
grid = (bs, nhead, max_seqlen_q)
extra_kargs = {"waves_per_eu": 4}
_fwd_kernel_stage2_asm[grid](
aiter.mla_reduce_v1(
logits,
attn_lse,
reduce_indptr,
reduce_final_map,
reduce_partial_map,
o,
qo_indptr,
kv_indptr,
attn_lse.stride(0),
attn_lse.stride(2),
attn_lse.stride(1),
o.stride(0),
o.stride(1),
bs,
nhead,
max_seqlen_q,
NUM_KV_SPLITS=num_kv_splits,
BLOCK_DV=BLOCK_DV,
Lv=Lv,
mgc=mgc,
num_warps=4,
num_stages=2,
**extra_kargs,
final_lse,
)
return logits, attn_lse

return logits, final_lse


def mla_prefill_fwd(
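Both the new Triton stage-2 path and mla_reduce_v1 perform the same reduction over KV splits: every split contributes a partial output (already normalized within that split) plus its log-sum-exp, and the final row is a numerically stable, running-max weighted combination of the partials. The NumPy reference below is illustrative only (not part of aiter) and mirrors the accumulation loop in _fwd_kernel_stage2_asm.

import numpy as np


def merge_splits(partial_out, partial_lse):
    """Merge per-split attention outputs using their log-sum-exp values.

    partial_out: [num_splits, head_dim], each row normalized within its split.
    partial_lse: [num_splits], log-sum-exp of the attention scores per split.
    """
    acc = np.zeros(partial_out.shape[1], dtype=np.float64)
    e_sum, e_max = 0.0, -np.inf
    for tv, tlogic in zip(partial_out, partial_lse):
        n_e_max = max(tlogic, e_max)        # new running maximum of the lse values
        old_scale = np.exp(e_max - n_e_max)
        acc = acc * old_scale + np.exp(tlogic - n_e_max) * tv
        e_sum = e_sum * old_scale + np.exp(tlogic - n_e_max)
        e_max = n_e_max
    return acc / e_sum                      # equals sum_i softmax(lse)_i * partial_out_i

With a single split this reduces to returning partial_out[0] unchanged, which is why the earlier code could return logits directly when num_kv_splits == 1.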
2 changes: 1 addition & 1 deletion aiter/ops/activation.py
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

from torch import Tensor
from ..jit.core import compile_ops