
Commit 08f3a14

Fix bugs for minicpm3 on orange-pi (#2020)
1 parent 22221f4 commit 08f3a14

3 files changed: +17 -8 lines

mindnlp/core/nn/modules/module.py

+1 -1

@@ -1225,7 +1225,7 @@ def train(self, mode=True):
             Module: self
         """
         if ON_ORANGE_PI:
-            set_pyboost(not mode)
+            set_pyboost(False)
         self.training = mode
         for module in self.children():
             module.train(mode)
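For reference, a minimal stand-alone sketch of the behavioural change (set_pyboost is stubbed here purely for illustration): before the fix, eval() (mode=False) re-enabled pyboost on Orange Pi while train() disabled it; after the fix pyboost stays off in both modes.

def set_pyboost(enabled):              # stub standing in for the toggle used in the diff
    print("pyboost enabled:", enabled)

ON_ORANGE_PI = True

def train_before(mode=True):
    if ON_ORANGE_PI:
        set_pyboost(not mode)          # eval() -> set_pyboost(True)

def train_after(mode=True):
    if ON_ORANGE_PI:
        set_pyboost(False)             # disabled regardless of train/eval mode

train_before(mode=False)               # prints: pyboost enabled: True
train_after(mode=False)                # prints: pyboost enabled: False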

mindnlp/transformers/generation/logits_process.py

+1 -1

@@ -512,7 +512,7 @@ def tf_like_call(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) ->
         # to a 3D tensor of shape (batch_size, vocab_size, 2) containing the original score coordinate, from which we
         # can scatter (i.e. `scatter_indices[row, col, :]` is a tensor containing `[row, topk_indices[row, col]]`)
         scatter_rows = ops.tile(ops.unsqueeze(ops.range(topk_indices.shape[0]), dim=-1), (1, topk_indices.shape[-1]))
-        scatter_indices = ops.stack((scatter_rows, topk_indices), dim=-1)
+        scatter_indices = ops.stack((scatter_rows.to(topk_indices.dtype), topk_indices), dim=-1)
         next_scores = ops.tf_scatter_nd(scatter_indices, topk_next_scores, shape=topk_next_scores.shape)

         return next_scores
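A hedged numpy sketch of what this line builds (numpy stands in for the mindnlp ops wrappers, and the top-k values are made up): scatter_indices pairs each row id with a top-k column index, and the added .to(topk_indices.dtype) cast gives both halves one shared integer dtype before they are stacked.

import numpy as np

topk_indices = np.array([[5, 2], [7, 1]], dtype=np.int32)            # (batch=2, k=2), hypothetical values
rows = np.arange(topk_indices.shape[0], dtype=np.int64)               # row ids, int64 to show the dtype mismatch the cast resolves
scatter_rows = np.tile(rows[:, None], (1, topk_indices.shape[-1]))

# mirror scatter_rows.to(topk_indices.dtype): align dtypes before stacking
scatter_indices = np.stack((scatter_rows.astype(topk_indices.dtype), topk_indices), axis=-1)
print(scatter_indices.shape)   # (2, 2, 2): each entry is [row, topk_indices[row, col]]
print(scatter_indices[0, 1])   # [0 2]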

mindnlp/transformers/models/minicpm3/modeling_minicpm3.py

+15 -6

@@ -30,6 +30,7 @@
 from mindnlp.core import nn, ops
 from mindnlp.core.nn import functional as F
 from mindnlp.core.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from mindnlp.configs import ON_ORANGE_PI
 from ....common.activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...modeling_attn_mask_utils import (
@@ -265,7 +266,10 @@ def __init__(self, config):
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
-        self.act_fn = ACT2FN[config.hidden_act]
+        if ON_ORANGE_PI:
+            self.act_fn = mindspore.ops.silu
+        else:
+            self.act_fn = ACT2FN[config.hidden_act]

     def forward(self, x):
         if self.config.pretraining_tp > 1:
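MiniCPM3 configurations normally set hidden_act="silu", so the Orange Pi branch swaps the ACT2FN lookup for the equivalent built-in op; that equivalence is the assumption behind this small numpy sketch of the substituted function.

import numpy as np

def silu(x):
    # silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
    return x / (1.0 + np.exp(-x))

x = np.array([-2.0, 0.0, 2.0])
print(silu(x))   # approximately [-0.2384  0.      1.7616]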
@@ -450,12 +454,16 @@ def forward(
         q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

         query_states = ops.zeros((bsz, self.num_heads, q_len, self.q_head_dim), dtype=q_pe.dtype)
-        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
-        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
+        # query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
+        # query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
+        query_states = ops.cat([q_nope, q_pe], dim=-1)

         key_states = ops.zeros((bsz, self.num_heads, q_len, self.q_head_dim), dtype=k_pe.dtype)
-        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
-        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
+        # key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
+        # key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
+        k_pe = ops.broadcast_to(k_pe, (bsz, self.num_heads, q_len, self.qk_rope_head_dim))
+        key_states = ops.cat([k_nope, k_pe], dim=-1)
+
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(
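The rewrite drops the in-place slice writes into a preallocated zeros tensor in favour of a single concatenation along the last axis; k_pe holds a rotary key shared across heads in this layout (hence the broadcast in the diff), so it is expanded over the head dimension first. A numpy sketch, with illustrative shapes rather than the model's real dimensions, of why the two formulations agree:

import numpy as np

bsz, num_heads, q_len = 1, 2, 3                    # made-up sizes
qk_nope_head_dim, qk_rope_head_dim = 4, 2
q_head_dim = qk_nope_head_dim + qk_rope_head_dim

q_nope = np.random.rand(bsz, num_heads, q_len, qk_nope_head_dim)
q_pe = np.random.rand(bsz, num_heads, q_len, qk_rope_head_dim)

# old formulation: allocate zeros, then assign the two slices in place
query_states = np.zeros((bsz, num_heads, q_len, q_head_dim))
query_states[:, :, :, :qk_nope_head_dim] = q_nope
query_states[:, :, :, qk_nope_head_dim:] = q_pe

# new formulation: one concatenation, no in-place writes
assert np.array_equal(query_states, np.concatenate([q_nope, q_pe], axis=-1))

# k_pe is broadcast across heads before the same kind of concatenation
k_pe = np.random.rand(bsz, 1, q_len, qk_rope_head_dim)
k_pe_full = np.broadcast_to(k_pe, (bsz, num_heads, q_len, qk_rope_head_dim))
print(k_pe_full.shape)   # (1, 2, 3, 2)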
@@ -932,7 +940,8 @@ def prepare_inputs_for_generation(
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
-            position_ids = attention_mask.int().cumsum(-1) - 1
+            # position_ids = attention_mask.int().cumsum(-1) - 1
+            position_ids = ops.cumsum(attention_mask.int(), -1) - 1
             position_ids.masked_fill(attention_mask == 0, 1)
             if past_key_values:
                 position_ids = position_ids[:, -input_ids.shape[1] :]
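The computation itself is unchanged; it is only rerouted from the tensor method to the ops.cumsum wrapper. A numpy sketch of what the line produces for a left-padded batch row:

import numpy as np

attention_mask = np.array([[0, 0, 1, 1, 1]])              # two pad tokens, three real tokens
position_ids = np.cumsum(attention_mask, axis=-1) - 1     # [[-1, -1, 0, 1, 2]]
position_ids[attention_mask == 0] = 1                     # pads get a dummy position (the masked_fill step)
print(position_ids)                                       # [[1 1 0 1 2]]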
