Skip to content

Commit

Permalink
fix and pass all the test cases in TestSoftmaxCrossEntropy
Browse files Browse the repository at this point in the history
  • Loading branch information
XJDKC committed Apr 8, 2020
1 parent 373dac1 commit 917c455
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions test/singa/test_cross_entropy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@ class TestSoftmaxCrossEntropy : public ::testing::Test {

TEST_F(TestSoftmaxCrossEntropy, CppForward) {
p.CopyDataFromHostPtr(pdat, 8);
t.AsType(singa::kInt);
EXPECT_TRUE(p.block()->initialized());
t.CopyDataFromHostPtr(tdat, 2);
t.AsType(singa::kInt);


singa::SoftmaxCrossEntropy cross_entropy;
const Tensor& loss = cross_entropy.Forward(singa::kEval, p, t);
Expand All @@ -56,8 +58,8 @@ TEST_F(TestSoftmaxCrossEntropy, CppForward) {

TEST_F(TestSoftmaxCrossEntropy, CppForwardAryTarget) {
p.CopyDataFromHostPtr(pdat, 8);
ta.AsType(singa::kInt);
ta.CopyDataFromHostPtr(tary, 8);
ta.AsType(singa::kInt);

singa::SoftmaxCrossEntropy cross_entropy;
const Tensor& loss = cross_entropy.Forward(singa::kEval, p, ta);
Expand All @@ -70,8 +72,8 @@ TEST_F(TestSoftmaxCrossEntropy, CppForwardAryTarget) {

TEST_F(TestSoftmaxCrossEntropy, CppBackward) {
p.CopyDataFromHostPtr(pdat, 8);
t.AsType(singa::kInt);
t.CopyDataFromHostPtr(tdat, 2);
t.AsType(singa::kInt);

singa::SoftmaxCrossEntropy cross_entropy;
cross_entropy.Forward(singa::kTrain, p, t);
Expand All @@ -90,8 +92,8 @@ TEST_F(TestSoftmaxCrossEntropy, CppBackward) {

TEST_F(TestSoftmaxCrossEntropy, CppBackwardAryTarget) {
p.CopyDataFromHostPtr(pdat, 8);
ta.AsType(singa::kInt);
ta.CopyDataFromHostPtr(tary, 8);
ta.AsType(singa::kInt);

singa::SoftmaxCrossEntropy cross_entropy;
cross_entropy.Forward(singa::kTrain, p, ta);
Expand Down

0 comments on commit 917c455

Please sign in to comment.