Skip to content

Commit cf36b38

Browse files
Feature/add t bool test (#628)
* add TBoolTest (mix of TStringTest and NumericTypesTestBase) * minor cleanup * add copyright comment, fix comment in test * make class and test methods public to be aligned with other tests * fix formatting with mvn spotless:apply --------- Co-authored-by: Winfried Gerlach <w.schoech@gmail.com>
1 parent 89703f7 commit cf36b38

File tree

2 files changed

+157
-3
lines changed

2 files changed

+157
-3
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/NativeFunction.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,7 @@
2020

2121
import com.google.protobuf.InvalidProtocolBufferException;
2222
import java.util.ArrayDeque;
23-
import java.util.ArrayList;
2423
import java.util.Collection;
25-
import java.util.Collections;
2624
import java.util.LinkedHashSet;
2725
import java.util.List;
2826
import java.util.Map;
@@ -92,7 +90,7 @@ public synchronized List<String> getDependencies() {
9290
}
9391
});
9492
}
95-
dependencies = Collections.unmodifiableList(new ArrayList<>(deps));
93+
dependencies = List.copyOf(deps);
9694
}
9795

9896
return dependencies;
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
/*
2+
* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =======================================================================
16+
*/
17+
18+
package org.tensorflow.types;
19+
20+
import static org.junit.jupiter.api.Assertions.assertEquals;
21+
import static org.junit.jupiter.api.Assertions.assertNotNull;
22+
23+
import org.junit.jupiter.api.Test;
24+
import org.tensorflow.EagerSession;
25+
import org.tensorflow.ndarray.NdArray;
26+
import org.tensorflow.ndarray.NdArrays;
27+
import org.tensorflow.ndarray.Shape;
28+
import org.tensorflow.ndarray.index.Indices;
29+
import org.tensorflow.op.Ops;
30+
import org.tensorflow.op.core.Constant;
31+
import org.tensorflow.op.math.LogicalAnd;
32+
import org.tensorflow.op.math.LogicalNot;
33+
import org.tensorflow.op.math.LogicalOr;
34+
35+
public class TBoolTest {
36+
37+
@Test
38+
public void createScalar() {
39+
TBool tensorT = TBool.scalarOf(true);
40+
assertNotNull(tensorT);
41+
assertEquals(Shape.scalar(), tensorT.shape());
42+
assertEquals(true, tensorT.getObject());
43+
44+
TBool tensorF = TBool.scalarOf(false);
45+
assertNotNull(tensorF);
46+
assertEquals(Shape.scalar(), tensorF.shape());
47+
assertEquals(false, tensorF.getObject());
48+
}
49+
50+
@Test
51+
public void createVector() {
52+
TBool tensor = TBool.vectorOf(true, false);
53+
assertNotNull(tensor);
54+
assertEquals(Shape.of(2), tensor.shape());
55+
assertEquals(true, tensor.getObject(0));
56+
assertEquals(false, tensor.getObject(1));
57+
}
58+
59+
@Test
60+
public void createCopy() {
61+
NdArray<Boolean> bools =
62+
NdArrays.ofObjects(Boolean.class, Shape.of(2, 2))
63+
.setObject(true, 0, 0)
64+
.setObject(false, 0, 1)
65+
.setObject(false, 1, 0)
66+
.setObject(true, 1, 1);
67+
68+
TBool tensor = TBool.tensorOf(bools);
69+
assertNotNull(tensor);
70+
bools.scalars().forEachIndexed((idx, s) -> assertEquals(s.getObject(), tensor.getObject(idx)));
71+
}
72+
73+
@Test
74+
public void initializeTensorsWithBools() {
75+
// Allocate a tensor of booleans of the shape (2, 3, 2)
76+
TBool tensor = TBool.tensorOf(Shape.of(2, 3, 2));
77+
78+
assertEquals(3, tensor.rank());
79+
assertEquals(12, tensor.size());
80+
NdArray<Boolean> data = (NdArray<Boolean>) tensor;
81+
82+
try (EagerSession session = EagerSession.create()) {
83+
Ops tf = Ops.create(session);
84+
85+
// Initialize tensor memory with falses and take a snapshot
86+
data.scalars().forEach(scalar -> ((NdArray<Boolean>) scalar).setObject(false));
87+
Constant<TBool> x = tf.constantOf(tensor);
88+
89+
// Initialize the same tensor memory with trues and take a snapshot
90+
data.scalars().forEach(scalar -> ((NdArray<Boolean>) scalar).setObject(true));
91+
Constant<TBool> y = tf.constantOf(tensor);
92+
93+
// Calculate x AND y and validate the result
94+
LogicalAnd xAndY = tf.math.logicalAnd(x, y);
95+
((NdArray<Boolean>) xAndY.asTensor())
96+
.scalars()
97+
.forEach(scalar -> assertEquals(false, scalar.getObject()));
98+
99+
// Calculate x OR y and validate the result
100+
LogicalOr xOrY = tf.math.logicalOr(x, y);
101+
((NdArray<Boolean>) xOrY.asTensor())
102+
.scalars()
103+
.forEach(scalar -> assertEquals(true, scalar.getObject()));
104+
105+
// Calculate !x and validate the result against y
106+
LogicalNot notX = tf.math.logicalNot(x);
107+
assertEquals(y.asTensor(), notX.asTensor());
108+
}
109+
}
110+
111+
@Test
112+
public void setAndCompute() {
113+
NdArray<Boolean> heapData =
114+
NdArrays.ofBooleans(Shape.of(4))
115+
.setObject(true, 0)
116+
.setObject(false, 1)
117+
.setObject(true, 2)
118+
.setObject(false, 3);
119+
120+
// Creates a 2x2 matrix
121+
try (TBool tensor = TBool.tensorOf(Shape.of(2, 2))) {
122+
NdArray<Boolean> data = (NdArray<Boolean>) tensor;
123+
124+
// Copy first 2 values of the vector to the first row of the matrix
125+
data.set(heapData.slice(Indices.range(0, 2)), 0);
126+
127+
// Copy values at an odd position in the vector as the second row of the matrix
128+
data.set(heapData.slice(Indices.odd()), 1);
129+
130+
assertEquals(true, data.getObject(0, 0));
131+
assertEquals(false, data.getObject(0, 1));
132+
assertEquals(false, data.getObject(1, 0));
133+
assertEquals(false, data.getObject(1, 1));
134+
135+
// Read rows of the tensor in reverse order
136+
NdArray<Boolean> flippedData = data.slice(Indices.flip(), Indices.flip());
137+
138+
assertEquals(false, flippedData.getObject(0, 0));
139+
assertEquals(false, flippedData.getObject(0, 1));
140+
assertEquals(false, flippedData.getObject(1, 0));
141+
assertEquals(true, flippedData.getObject(1, 1));
142+
143+
try (EagerSession session = EagerSession.create()) {
144+
Ops tf = Ops.create(session);
145+
146+
LogicalNot sub = tf.math.logicalNot(tf.constantOf(tensor));
147+
NdArray<Boolean> result = (NdArray<Boolean>) sub.asTensor();
148+
149+
assertEquals(false, result.getObject(0, 0));
150+
assertEquals(true, result.getObject(0, 1));
151+
assertEquals(true, result.getObject(1, 0));
152+
assertEquals(true, result.getObject(1, 1));
153+
}
154+
}
155+
}
156+
}

0 commit comments

Comments
 (0)