Skip to content

Commit 917c455

Browse files
committed
fix and pass all the test cases in TestSoftmaxCrossEntropy
1 parent 373dac1 commit 917c455

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

test/singa/test_cross_entropy.cc

+6-4
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,10 @@ class TestSoftmaxCrossEntropy : public ::testing::Test {
4242

4343
TEST_F(TestSoftmaxCrossEntropy, CppForward) {
4444
p.CopyDataFromHostPtr(pdat, 8);
45-
t.AsType(singa::kInt);
45+
EXPECT_TRUE(p.block()->initialized());
4646
t.CopyDataFromHostPtr(tdat, 2);
47+
t.AsType(singa::kInt);
48+
4749

4850
singa::SoftmaxCrossEntropy cross_entropy;
4951
const Tensor& loss = cross_entropy.Forward(singa::kEval, p, t);
@@ -56,8 +58,8 @@ TEST_F(TestSoftmaxCrossEntropy, CppForward) {
5658

5759
TEST_F(TestSoftmaxCrossEntropy, CppForwardAryTarget) {
5860
p.CopyDataFromHostPtr(pdat, 8);
59-
ta.AsType(singa::kInt);
6061
ta.CopyDataFromHostPtr(tary, 8);
62+
ta.AsType(singa::kInt);
6163

6264
singa::SoftmaxCrossEntropy cross_entropy;
6365
const Tensor& loss = cross_entropy.Forward(singa::kEval, p, ta);
@@ -70,8 +72,8 @@ TEST_F(TestSoftmaxCrossEntropy, CppForwardAryTarget) {
7072

7173
TEST_F(TestSoftmaxCrossEntropy, CppBackward) {
7274
p.CopyDataFromHostPtr(pdat, 8);
73-
t.AsType(singa::kInt);
7475
t.CopyDataFromHostPtr(tdat, 2);
76+
t.AsType(singa::kInt);
7577

7678
singa::SoftmaxCrossEntropy cross_entropy;
7779
cross_entropy.Forward(singa::kTrain, p, t);
@@ -90,8 +92,8 @@ TEST_F(TestSoftmaxCrossEntropy, CppBackward) {
9092

9193
TEST_F(TestSoftmaxCrossEntropy, CppBackwardAryTarget) {
9294
p.CopyDataFromHostPtr(pdat, 8);
93-
ta.AsType(singa::kInt);
9495
ta.CopyDataFromHostPtr(tary, 8);
96+
ta.AsType(singa::kInt);
9597

9698
singa::SoftmaxCrossEntropy cross_entropy;
9799
cross_entropy.Forward(singa::kTrain, p, ta);

0 commit comments

Comments
 (0)