
Commit 3f0c5f7

new awq kernels paths (#2572)
1 parent: 39c984f

2 files changed: +8 -6 lines


onmt/modules/moe.py

Lines changed: 3 additions & 1 deletion
@@ -60,7 +60,9 @@ def forward(self, x):
         y = torch.empty_like(x)
         for i, expert in enumerate(self.experts):
             if torch.any(flat_expert_indices == i):
-                y[flat_expert_indices == i] = expert(x[flat_expert_indices == i])
+                y[flat_expert_indices == i] = expert(
+                    x[flat_expert_indices == i].unsqueeze(0)
+                )
         y = (y.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(
             dim=1
         )
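The added unsqueeze(0) gives each expert an input with a leading batch dimension, presumably so it matches what the new awq_ext kernel path expects. Below is a minimal, self-contained sketch of the same routing pattern using plain nn.Linear experts; the sizes, the random routing, and the squeeze(0) on write-back are illustrative assumptions, not part of the commit.

# Toy version of the expert-routing loop above (plain nn.Linear stands in
# for the real AWQ-backed experts).
import torch
import torch.nn as nn

hidden, num_experts, k, num_tokens = 16, 4, 2, 8
experts = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_experts))

# Each token is repeated once per selected expert, as in the MoE forward.
x = torch.randn(num_tokens * k, hidden)
flat_expert_indices = torch.randint(num_experts, (num_tokens * k,))
expert_weights = torch.softmax(torch.randn(num_tokens, k), dim=-1)

y = torch.empty_like(x)
for i, expert in enumerate(experts):
    mask = flat_expert_indices == i
    if torch.any(mask):
        # Leading batch dim added before the expert call (the change in this
        # commit); squeezed back here so the masked write has a matching shape.
        y[mask] = expert(x[mask].unsqueeze(0)).squeeze(0)

# Weighted sum over the k experts chosen for each token.
out = (y.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(dim=1)
print(out.shape)  # torch.Size([8, 16])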

onmt/modules/rmsnorm.py

Lines changed: 5 additions & 5 deletions
@@ -4,11 +4,11 @@
 import torch.nn as nn
 
 try:
-    import awq_inference_engine
+    import awq_ext
 
-    AWQ_INFERENCE_ENGINE = True
+    AWQ_EXT = True
 except ImportError:
-    AWQ_INFERENCE_ENGINE = False
+    AWQ_EXT = False
 
 
 class RMSNorm(torch.nn.Module):
@@ -24,12 +24,12 @@ def __init__(self, hidden_size: int, eps: float = 1e-6):
         self.weight = nn.Parameter(torch.ones(hidden_size))
 
     def forward(self, hidden_states):
-        if AWQ_INFERENCE_ENGINE and not self.training:
+        if AWQ_EXT and not self.training:
             inp_type = hidden_states.dtype
             output = torch.empty_like(hidden_states).to(inp_type)
             if hidden_states.dim() == 2:  # patch for multi experts
                 hidden_states = hidden_states.unsqueeze(0)
-            awq_inference_engine.layernorm_forward_cuda(
+            awq_ext.layernorm_forward_cuda(
                 hidden_states.half(), self.weight.half(), output.half(), self.eps
             )
             if hidden_states.dim() == 2:  # patch for multi experts
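For context, the renamed extension sits behind an optional-import guard: the fused kernel runs only when awq_ext is installed and the module is in eval mode. A condensed sketch of that pattern follows; the eager fallback branch is an assumption added for illustration (the commit itself only renames the import and the flag), and the fused call assumes float16 activations as in the quantized inference path.

import torch
import torch.nn as nn

try:
    import awq_ext  # fused CUDA kernels; new module name in this commit

    AWQ_EXT = True
except ImportError:
    AWQ_EXT = False


class RMSNormSketch(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        if AWQ_EXT and not self.training:
            # Fused path, mirroring the call shown in the diff; assumes the
            # activations are already float16 so the kernel writes into `output`.
            output = torch.empty_like(hidden_states)
            awq_ext.layernorm_forward_cuda(
                hidden_states.half(), self.weight.half(), output.half(), self.eps
            )
            return output
        # Eager fallback (assumed): x * rsqrt(mean(x^2) + eps), scaled by weight.
        variance = hidden_states.float().pow(2).mean(-1, keepdim=True)
        normed = hidden_states * torch.rsqrt(variance + self.eps).to(hidden_states.dtype)
        return self.weight * normed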
