@@ -758,11 +758,12 @@ def forward(
758
758
shift_logits = logits [..., :- 1 , :]
759
759
shift_labels = labels [..., 1 :]
760
760
# Flatten the tokens
761
- loss_fct = CrossEntropyLoss ()
761
+ loss_fct = mindspore . ops . SoftmaxCrossEntropyWithLogits ()
762
762
shift_logits = shift_logits .view (- 1 , self .config .vocab_size )
763
- shift_labels = shift_labels .view (- 1 )
763
+ shift_labels = nn . functional . one_hot ( shift_labels .view (- 1 ), self . config . vocab_size )
764
764
# Enable model parallelism
765
- loss = loss_fct (shift_logits , shift_labels )
765
+ loss , _ = loss_fct (shift_logits , shift_labels .to (shift_logits .dtype ))
766
+ loss = loss .mean ()
766
767
767
768
if not return_dict :
768
769
output = (logits ,) + outputs [1 :]
@@ -934,8 +935,10 @@ def forward(
934
935
else :
935
936
loss = loss_fct (pooled_logits , labels )
936
937
elif self .config .problem_type == "single_label_classification" :
937
- loss_fct = CrossEntropyLoss ()
938
- loss = loss_fct (pooled_logits .view (- 1 , self .num_labels ), labels .view (- 1 ))
938
+ loss_fct = mindspore .ops .SoftmaxCrossEntropyWithLogits ()
939
+ labels = nn .functional .one_hot (labels .view (- 1 ), self .num_labels )
940
+ loss , _ = loss_fct (pooled_logits .view (- 1 , self .num_labels ), labels .to (pooled_logits .dtype ))
941
+ loss = loss .mean ()
939
942
elif self .config .problem_type == "multi_label_classification" :
940
943
loss_fct = BCEWithLogitsLoss ()
941
944
loss = loss_fct (pooled_logits , labels )
@@ -1014,8 +1017,10 @@ def forward(
1014
1017
1015
1018
loss = None
1016
1019
if labels is not None :
1017
- loss_fct = CrossEntropyLoss ()
1018
- loss = loss_fct (logits .view (- 1 , self .num_labels ), labels .view (- 1 ))
1020
+ loss_fct = mindspore .ops .SoftmaxCrossEntropyWithLogits ()
1021
+ labels = nn .functional .one_hot (labels .view (- 1 ), self .num_labels )
1022
+ loss , _ = loss_fct (logits .view (- 1 , self .num_labels ), labels .to (logits .dtype ))
1023
+ loss = loss .mean ()
1019
1024
1020
1025
if not return_dict :
1021
1026
output = (logits ,) + outputs [2 :]
0 commit comments