Skip to content

Commit 8e33cb0

Browse files
authored
LLM: support speecht5_tts (#10077)
* support speecht5_tts * fix
1 parent 7382757 commit 8e33cb0

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

python/llm/src/bigdl/llm/transformers/convert.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -238,6 +238,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
238238
new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
239239
.to(device)
240240
elif qtype not in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]:
241+
if in_features % 64 != 0:
242+
# now our kernel requires in_features is a multiple of 64
243+
continue
241244
new_linear = LowBitLinear(
242245
in_features,
243246
out_features,

python/llm/src/bigdl/llm/transformers/low_bit_linear.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -478,7 +478,11 @@ def forward(self, x: torch.Tensor):
478478
if x_2d.is_contiguous() is False:
479479
x_2d = x_2d.contiguous()
480480

481-
input_seq_size = x_shape[1]
481+
if len(x_shape) == 3:
482+
input_seq_size = x_shape[1]
483+
elif len(x_shape) < 3:
484+
input_seq_size = 1
485+
482486
if is_training:
483487
# training path
484488
if x_2d.requires_grad:

0 commit comments

Comments (0)