From 3f0c5f7c24c0a26e7165e054cd8000e18b931799 Mon Sep 17 00:00:00 2001 From: Vincent Nguyen Date: Mon, 18 Mar 2024 14:59:44 +0100 Subject: [PATCH] new awq kernels paths (#2572) * new awq kernels paths --- onmt/modules/moe.py | 4 +++- onmt/modules/rmsnorm.py | 10 +++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/onmt/modules/moe.py b/onmt/modules/moe.py index f356130d97..a8ceba6cde 100644 --- a/onmt/modules/moe.py +++ b/onmt/modules/moe.py @@ -60,7 +60,9 @@ def forward(self, x): y = torch.empty_like(x) for i, expert in enumerate(self.experts): if torch.any(flat_expert_indices == i): - y[flat_expert_indices == i] = expert(x[flat_expert_indices == i]) + y[flat_expert_indices == i] = expert( + x[flat_expert_indices == i].unsqueeze(0) + ) y = (y.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum( dim=1 ) diff --git a/onmt/modules/rmsnorm.py b/onmt/modules/rmsnorm.py index a25d08b27e..fb8cc29065 100644 --- a/onmt/modules/rmsnorm.py +++ b/onmt/modules/rmsnorm.py @@ -4,11 +4,11 @@ import torch.nn as nn try: - import awq_inference_engine + import awq_ext - AWQ_INFERENCE_ENGINE = True + AWQ_EXT = True except ImportError: - AWQ_INFERENCE_ENGINE = False + AWQ_EXT = False class RMSNorm(torch.nn.Module): @@ -24,12 +24,12 @@ def __init__(self, hidden_size: int, eps: float = 1e-6): self.weight = nn.Parameter(torch.ones(hidden_size)) def forward(self, hidden_states): - if AWQ_INFERENCE_ENGINE and not self.training: + if AWQ_EXT and not self.training: inp_type = hidden_states.dtype output = torch.empty_like(hidden_states).to(inp_type) if hidden_states.dim() == 2: # patch for multi experts hidden_states = hidden_states.unsqueeze(0) - awq_inference_engine.layernorm_forward_cuda( + awq_ext.layernorm_forward_cuda( hidden_states.half(), self.weight.half(), output.half(), self.eps ) if hidden_states.dim() == 2: # patch for multi experts