Skip to content

Commit bd6ba69

Browse files
authored
modify qwen2 (#2023)
1 parent 13fc6e8 commit bd6ba69

File tree

5 files changed

+18
-14
lines changed

5 files changed

+18
-14
lines changed

mindnlp/core/ops/array.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,7 @@ def range_(start, length):
533533
updates = moveaxis(
534534
updates, range_(batch_start, batch_size), range(batch_size)
535535
)
536-
tensor = ops.tensor_scatter_update(tensor, stacked_indices, updates)
536+
tensor = ops.tensor_scatter_update(tensor, stacked_indices, updates.to(tensor.dtype))
537537
if range(len(dims)) != dims:
538538
tensor = moveaxis(tensor, range(len(dims)), dims)
539539
return strided_slice_update(

mindnlp/transformers/generation/beam_search.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def finalize(
329329
decoded[i, : sent_lengths[i]] = hypo
330330

331331
if indices is not None:
332-
indices[i, : len(best_idx)] = best_idx
332+
indices[i, : len(best_idx)] = mindspore.Tensor(best_idx, dtype=indices.dtype)
333333

334334
if sent_lengths[i] < sent_max_len:
335335
# inserting only the first eos_token_id
@@ -832,7 +832,7 @@ def finalize(
832832
decoded[i, : sent_lengths[i]] = hypo
833833

834834
if indices is not None:
835-
indices[i, : len(best_idx)] = best_idx
835+
indices[i, : len(best_idx)] = mindspore.Tensor(best_idx, dtype=indices.dtype)
836836

837837
if sent_lengths[i] < sent_max_len:
838838
# inserting only the first eos_token_id

mindnlp/transformers/models/qwen2/modeling_qwen2.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -758,11 +758,12 @@ def forward(
758758
shift_logits = logits[..., :-1, :]
759759
shift_labels = labels[..., 1:]
760760
# Flatten the tokens
761-
loss_fct = CrossEntropyLoss()
761+
loss_fct = mindspore.ops.SoftmaxCrossEntropyWithLogits()
762762
shift_logits = shift_logits.view(-1, self.config.vocab_size)
763-
shift_labels = shift_labels.view(-1)
763+
shift_labels = nn.functional.one_hot(shift_labels.view(-1), self.config.vocab_size)
764764
# Enable model parallelism
765-
loss = loss_fct(shift_logits, shift_labels)
765+
loss, _ = loss_fct(shift_logits, shift_labels.to(shift_logits.dtype))
766+
loss = loss.mean()
766767

767768
if not return_dict:
768769
output = (logits,) + outputs[1:]
@@ -934,8 +935,10 @@ def forward(
934935
else:
935936
loss = loss_fct(pooled_logits, labels)
936937
elif self.config.problem_type == "single_label_classification":
937-
loss_fct = CrossEntropyLoss()
938-
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
938+
loss_fct = mindspore.ops.SoftmaxCrossEntropyWithLogits()
939+
labels = nn.functional.one_hot(labels.view(-1), self.num_labels)
940+
loss, _ = loss_fct(pooled_logits.view(-1, self.num_labels), labels.to(pooled_logits.dtype))
941+
loss = loss.mean()
939942
elif self.config.problem_type == "multi_label_classification":
940943
loss_fct = BCEWithLogitsLoss()
941944
loss = loss_fct(pooled_logits, labels)
@@ -1014,8 +1017,10 @@ def forward(
10141017

10151018
loss = None
10161019
if labels is not None:
1017-
loss_fct = CrossEntropyLoss()
1018-
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1020+
loss_fct = mindspore.ops.SoftmaxCrossEntropyWithLogits()
1021+
labels = nn.functional.one_hot(labels.view(-1), self.num_labels)
1022+
loss, _= loss_fct(logits.view(-1, self.num_labels), labels.to(logits.dtype))
1023+
loss = loss.mean()
10191024

10201025
if not return_dict:
10211026
output = (logits,) + outputs[2:]

tests/transformers/generation/test_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1496,7 +1496,7 @@ def test_left_padding_compatibility(self):
14961496
def _prepare_model_kwargs(input_ids, attention_mask, signature):
14971497
model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
14981498
if "position_ids" in signature:
1499-
position_ids = ops.cumsum(attention_mask, dim=-1) - 1
1499+
position_ids = ops.cumsum(attention_mask.int(), dim=-1) - 1
15001500
position_ids = position_ids.masked_fill(attention_mask == 0, 1)
15011501
model_kwargs["position_ids"] = position_ids
15021502
if "cache_position" in signature:
@@ -3286,4 +3286,4 @@ def test_generate_from_inputs_embeds_with_bos_token_id_is_none(self):
32863286

32873287
# bos_token_id is required when no input ids nor inputs_embeds is passed
32883288
with self.assertRaises(ValueError):
3289-
model.generate(max_length=20, bos_token_id=None)
3289+
model.generate(max_length=20, bos_token_id=None)

tests/transformers/models/qwen2/test_modeling_qwen2.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
if is_mindspore_available():
3838
import mindspore
3939
from mindnlp.core import ops, nn, no_grad
40-
4140
from mindnlp.transformers import (
4241
Qwen2ForCausalLM,
4342
Qwen2ForSequenceClassification,
@@ -482,4 +481,4 @@ def test_speculative_generation(self):
482481
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
483482

484483
del model
485-
gc.collect()
484+
gc.collect()

0 commit comments

Comments
 (0)