 from mindnlp.core import nn, ops
 from mindnlp.core.nn import functional as F
 from mindnlp.core.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from mindnlp.configs import ON_ORANGE_PI
 from ....common.activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...modeling_attn_mask_utils import (
@@ -265,7 +266,10 @@ def __init__(self, config):
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
-        self.act_fn = ACT2FN[config.hidden_act]
+        if ON_ORANGE_PI:
+            self.act_fn = mindspore.ops.silu
+        else:
+            self.act_fn = ACT2FN[config.hidden_act]
 
     def forward(self, x):
         if self.config.pretraining_tp > 1:
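Note on the activation hunk above: together with the new ON_ORANGE_PI import, it bypasses the config-driven ACT2FN lookup and calls MindSpore's built-in SiLU directly when running on an Orange Pi board. A minimal sketch of the resulting MLP, assuming the standard gated layout down(act(gate(x)) * up(x)) for the non-tensor-parallel path, the torch-style nn.Module API that mindnlp.core exposes, and a boolean ON_ORANGE_PI flag (none of which are spelled out in this hunk):

import mindspore
from mindnlp.core import nn
from mindnlp.configs import ON_ORANGE_PI

# Stand-in for mindnlp's ACT2FN table; only the entry used here.
ACT2FN_STANDIN = {"silu": mindspore.ops.silu}

class SketchMLP(nn.Module):
    def __init__(self, hidden_size, intermediate_size, hidden_act="silu"):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        if ON_ORANGE_PI:
            # Call MindSpore's built-in SiLU directly instead of the config lookup.
            self.act_fn = mindspore.ops.silu
        else:
            self.act_fn = ACT2FN_STANDIN[hidden_act]

    def forward(self, x):
        # Standard gated MLP path (ignores the pretraining_tp > 1 branch in the file).
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

The switch only preserves behaviour when the model's hidden_act is "silu", which is presumably the case here.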
@@ -450,12 +454,16 @@ def forward(
         q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
 
         query_states = ops.zeros((bsz, self.num_heads, q_len, self.q_head_dim), dtype=q_pe.dtype)
-        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
-        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
+        # query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
+        # query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
+        query_states = ops.cat([q_nope, q_pe], dim=-1)
 
         key_states = ops.zeros((bsz, self.num_heads, q_len, self.q_head_dim), dtype=k_pe.dtype)
-        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
-        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
+        # key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
+        # key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
+        k_pe = ops.broadcast_to(k_pe, (bsz, self.num_heads, q_len, self.qk_rope_head_dim))
+        key_states = ops.cat([k_nope, k_pe], dim=-1)
+
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(
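Note on the attention hunk above: the in-place slice assignments into zero-initialized buffers are replaced by a single concatenation along the last (head) dimension, with k_pe explicitly broadcast across attention heads first. A small NumPy check of the equivalence; the shapes are illustrative, and the singleton head dimension on k_pe is inferred from the added broadcast_to call rather than stated in the diff:

import numpy as np

bsz, num_heads, q_len = 2, 4, 3
nope_dim, rope_dim = 6, 2   # stand-ins for qk_nope_head_dim / qk_rope_head_dim

q_nope = np.random.randn(bsz, num_heads, q_len, nope_dim).astype(np.float32)
q_pe = np.random.randn(bsz, num_heads, q_len, rope_dim).astype(np.float32)

# Old approach: write both pieces into a pre-allocated zeros buffer.
query_old = np.zeros((bsz, num_heads, q_len, nope_dim + rope_dim), dtype=np.float32)
query_old[..., :nope_dim] = q_nope
query_old[..., nope_dim:] = q_pe

# New approach: one concatenation along the head dimension.
query_new = np.concatenate([q_nope, q_pe], axis=-1)
assert np.array_equal(query_old, query_new)

# k_pe carries a single rotary head, so it is broadcast across num_heads before
# being concatenated with k_nope (mirroring ops.broadcast_to in the patch).
k_nope = np.random.randn(bsz, num_heads, q_len, nope_dim).astype(np.float32)
k_pe = np.random.randn(bsz, 1, q_len, rope_dim).astype(np.float32)
key_new = np.concatenate(
    [k_nope, np.broadcast_to(k_pe, (bsz, num_heads, q_len, rope_dim))], axis=-1
)

Building the tensors functionally avoids sliced in-place assignment, which is presumably the motivation on the targeted back-end.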
@@ -932,7 +940,8 @@ def prepare_inputs_for_generation(
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
-            position_ids = attention_mask.int().cumsum(-1) - 1
+            # position_ids = attention_mask.int().cumsum(-1) - 1
+            position_ids = ops.cumsum(attention_mask.int(), -1) - 1
             position_ids.masked_fill(attention_mask == 0, 1)
         if past_key_values:
             position_ids = position_ids[:, -input_ids.shape[1] :]
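Note on the hunk above: the tensor-method call attention_mask.int().cumsum(-1) is swapped for the functional ops.cumsum(attention_mask.int(), -1); the arithmetic is unchanged. For reference, the same computation on a left-padded mask, shown in NumPy just to make the numbers concrete:

import numpy as np

attention_mask = np.array([[0, 0, 1, 1, 1],
                           [1, 1, 1, 1, 1]], dtype=np.int32)

# Counterpart of: position_ids = ops.cumsum(attention_mask.int(), -1) - 1
position_ids = np.cumsum(attention_mask, axis=-1) - 1
# Counterpart of: position_ids.masked_fill(attention_mask == 0, 1)
position_ids[attention_mask == 0] = 1

print(position_ids)
# [[1 1 0 1 2]
#  [0 1 2 3 4]]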