18
18
19
19
import static org .junit .jupiter .api .Assertions .assertEquals ;
20
20
import static org .junit .jupiter .api .Assertions .assertTrue ;
21
+ import static org .junit .jupiter .api .Assertions .assertArrayEquals ;
21
22
22
23
import org .junit .jupiter .api .Test ;
23
24
import org .tensorflow .ndarray .index .Indices ;
@@ -43,6 +44,122 @@ public void testNullConversions(){
43
44
assertTrue (Indices .slice (null , null ).endMask (),
44
45
"Passed null for slice end but didn't set end mask" );
45
46
}
47
+
48
+ @ Test
49
+ public void testIndices (){
50
+
51
+ String [][] indexData = new String [5 ][4 ];
52
+ for (int i =0 ; i < 5 ; i ++)
53
+ for (int j =0 ; j < 4 ; j ++)
54
+ indexData [i ][j ] = "(" +j +", " +i +")" ;
55
+
56
+ NdArray <String > matrix2d = StdArrays .ndCopyOf (indexData );
57
+ assertEquals (2 , matrix2d .rank ());
58
+
59
+ /*
60
+ |(0, 0), (1, 0), (2, 0), (3, 0)|
61
+ |(0, 1), (1, 1), (2, 1), (3, 1)|
62
+ |(0, 2), (1, 2), (2, 2), (3, 2)|
63
+ |(0, 3), (1, 3), (2, 3), (3, 3)|
64
+ |(0, 4), (1, 4), (2, 4), (3, 4)|
65
+ */
66
+
67
+ NdArray <String > same1 = matrix2d .slice (Indices .all ());
68
+ String [][] same1j = StdArrays .array2dCopyOf (same1 , String .class );
69
+ assertEquals (2 , same1 .rank ());
70
+ assertEquals (same1 , matrix2d );
71
+
72
+ NdArray <String > same2 = matrix2d .slice (Indices .ellipsis ());
73
+ String [][] same2j = StdArrays .array2dCopyOf (same2 , String .class );
74
+ assertEquals (2 , same2 .rank ());
75
+ assertEquals (matrix2d , same2 );
76
+
77
+ // All rows, column 1
78
+ NdArray <String > same3 = matrix2d .slice (Indices .all (), Indices .at (1 ));
79
+ assertEquals (1 , same3 .rank ());
80
+ String [] same3j = StdArrays .array1dCopyOf (same3 , String .class );
81
+ assertArrayEquals (new String [] { "(1, 0)" , "(1, 1)" , "(1, 2)" , "(1, 3)" , "(1, 4)" }, same3j );
82
+
83
+ // row 2, all columns
84
+ NdArray <String > same4 = matrix2d .slice (Indices .at (2 ), Indices .all ());
85
+ assertEquals (1 , same4 .rank ());
86
+ String [] same4j = StdArrays .array1dCopyOf (same4 , String .class );
87
+ assertArrayEquals (new String [] {"(0, 2)" , "(1, 2)" , "(2, 2)" , "(3, 2)" }, same4j );
88
+ assertEquals (NdArrays .vectorOfObjects ("(0, 2)" , "(1, 2)" , "(2, 2)" , "(3, 2)" ), same4 );
89
+
90
+ // row 2, column 1
91
+ NdArray <String > same5 = matrix2d .slice (Indices .at (2 ), Indices .at (1 ));
92
+ assertEquals (0 , same5 .rank ());
93
+ assertTrue (same5 .shape ().isScalar ());
94
+ // Don't use an index
95
+ String same5j = same5 .getObject ();
96
+ assertEquals ("(1, 2)" , same5j );
97
+
98
+ // rows 1 to 2, all columns
99
+ NdArray <String > same6 = matrix2d .slice (Indices .slice (1 ,3 ));
100
+ assertEquals (2 , same6 .rank ());
101
+ String [][] same6j = StdArrays .array2dCopyOf (same6 , String .class );
102
+ assertArrayEquals (
103
+ new String [][]
104
+ {
105
+ {"(0, 1)" , "(1, 1)" , "(2, 1)" , "(3, 1)" },
106
+ {"(0, 2)" , "(1, 2)" , "(2, 2)" , "(3, 2)" }
107
+ },
108
+ same6j
109
+ );
110
+
111
+ // Exception in thread "main" java.nio.BufferOverflowException
112
+ // all rows, columns 1 to 2
113
+ NdArray <String > same7 = matrix2d .slice (Indices .all (), Indices .slice (1 ,3 ));
114
+ assertEquals (2 , same7 .rank ());
115
+ assertEquals (Shape .of (5 ,2 ), same7 .shape ());
116
+ assertEquals (10 , same7 .size ());
117
+ NdArray <String > r7_0 = same7 .get (0 );
118
+ NdArray <String > r7_1 = same7 .get (1 );
119
+ NdArray <String > r7_2 = same7 .get (2 );
120
+ NdArray <String > r7_3 = same7 .get (3 );
121
+ NdArray <String > r7_4 = same7 .get (4 );
122
+ assertEquals (1 , r7_0 .rank ());
123
+ assertEquals (Shape .of (2 ), r7_0 .shape ());
124
+ assertEquals (2 , r7_0 .size ());
125
+ // TODO: I get a (0,0) which is not what I expected
126
+ System .out .println (r7_0 .getObject ());
127
+ //assertEquals("(1,0)", r7_0.getObject());
128
+ assertEquals ( "(1, 0)" , r7_0 .getObject (0 ));
129
+ assertEquals ( "(2, 0)" , r7_0 .getObject (1 ));
130
+ assertEquals ( "(1, 1)" , r7_1 .getObject (0 ));
131
+ assertEquals ( "(2, 1)" , r7_1 .getObject (1 ));
132
+ assertEquals ( "(1, 2)" , r7_2 .getObject (0 ));
133
+ assertEquals ( "(2, 2)" , r7_2 .getObject (1 ));
134
+ assertEquals ( "(1, 3)" , r7_3 .getObject (0 ));
135
+ assertEquals ( "(2, 3)" , r7_3 .getObject (1 ));
136
+ assertEquals ( "(1, 4)" , r7_4 .getObject (0 ));
137
+ assertEquals ( "(2, 4)" , r7_4 .getObject (1 ));
138
+ String [][] expectedr7 = new String [][]
139
+ {
140
+ {"(1, 0)" , "(2, 0)" },
141
+ {"(1, 1)" , "(2, 1)" },
142
+ {"(1, 2)" , "(2, 2)" },
143
+ {"(1, 3)" , "(2, 3)" },
144
+ {"(1, 4)" , "(2, 4)" }
145
+ };
146
+ //String[][] lArray = new String[5][2];
147
+ String [][] lArray = new String [5 ][];
148
+ lArray [0 ] = new String [2 ];
149
+ lArray [1 ] = new String [2 ];
150
+ lArray [2 ] = new String [2 ];
151
+ lArray [3 ] = new String [2 ];
152
+ lArray [4 ] = new String [2 ];
153
+ StdArrays .copyFrom (same7 , lArray );
154
+ assertArrayEquals ( expectedr7 , lArray );
155
+ String [][] same7j = StdArrays .array2dCopyOf (same7 , String .class );
156
+ assertArrayEquals ( expectedr7 , same7j );
157
+
158
+ /*
159
+ */
160
+
161
+ assertEquals (0 , 0 );
162
+ }
46
163
47
164
@ Test
48
165
public void testNewaxis (){
0 commit comments