Skip to content

Commit ff04d86

Browse files
committed
Added basic index tests (rank 2)
1 parent 0c450cc commit ff04d86

File tree

2 files changed

+259
-37
lines changed

2 files changed

+259
-37
lines changed

.devcontainer/devcontainer.json

Lines changed: 0 additions & 23 deletions
This file was deleted.

ndarray/src/test/java/org/tensorflow/ndarray/IndexTest.java

Lines changed: 259 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,8 @@ public void testNullConversions(){
4848
@Test
4949
public void testIndices(){
5050

51-
// String[][] indexData = new String[5][4];
52-
String[][] indexData = new String[5][];
51+
String[][] indexData = new String[5][4];
5352
for (int i=0 ; i < 5; i++){
54-
indexData[i] = new String[4];
5553
for (int j=0 ; j < 4; j++)
5654
indexData[i][j] = "("+j+", "+i+")";
5755
}
@@ -72,11 +70,13 @@ public void testIndices(){
7270
String[][] same1j = StdArrays.array2dCopyOf(same1, String.class);
7371
assertEquals(2, same1.rank());
7472
assertEquals(same1, matrix2d);
73+
assertEquals(matrix2d, StdArrays.ndCopyOf(same1j));
7574

7675
NdArray<String> same2 = matrix2d.slice(Indices.ellipsis());
7776
String[][] same2j = StdArrays.array2dCopyOf(same2, String.class);
7877
assertEquals(2, same2.rank());
7978
assertEquals(matrix2d, same2);
79+
assertEquals(matrix2d, StdArrays.ndCopyOf(same2j));
8080

8181
// All rows, column 1
8282
NdArray<String> same3 = matrix2d.slice(Indices.all(), Indices.at(1));
@@ -127,8 +127,8 @@ public void testIndices(){
127127
assertEquals(Shape.of(2), r7_0.shape());
128128
assertEquals(2, r7_0.size());
129129
// TODO: I get a (0,0) which is not what I expected
130-
System.out.println(r7_0.getObject());
131-
//assertEquals("(1,0)", r7_0.getObject());
130+
// System.out.println(r7_0.getObject());
131+
// assertEquals("(1,0)", r7_0.getObject());
132132
assertEquals( "(1, 0)", r7_0.getObject(0));
133133
assertEquals( "(2, 0)", r7_0.getObject(1));
134134
assertEquals( "(1, 1)", r7_1.getObject(0));
@@ -148,21 +148,266 @@ public void testIndices(){
148148
{"(1, 4)", "(2, 4)"}
149149
};
150150
String[][] lArray = new String[5][2];
151-
// String[][] lArray = new String[5][];
152-
// lArray[0] = new String[2];
153-
// lArray[1] = new String[2];
154-
// lArray[2] = new String[2];
155-
// lArray[3] = new String[2];
156-
// lArray[4] = new String[2];
157151
StdArrays.copyFrom(same7, lArray);
158152
assertArrayEquals( expectedr7, lArray);
159153
String[][] same7j = StdArrays.array2dCopyOf(same7, String.class);
160154
assertArrayEquals( expectedr7, same7j);
161155

162-
/*
163-
*/
156+
157+
// rows 1 to 2, columns 1 to 2
158+
NdArray<String> same8 = matrix2d.slice(Indices.slice(1,3), Indices.slice(1,3));
159+
assertEquals(2, same8.rank());
160+
assertEquals(Shape.of(2,2), same8.shape());
161+
assertEquals(4, same8.size());
162+
String[][] same8j = StdArrays.array2dCopyOf(same8, String.class);
163+
// print2D(same8j)
164+
String[][] expected_r8 = new String[][]
165+
{
166+
{"(1, 1)", "(2, 1)"},
167+
{"(1, 2)", "(2, 2)"}
168+
};
169+
assertArrayEquals(expected_r8, same8j);
170+
NdArray<String> r8_0 = same8.get(0);
171+
NdArray<String> r8_1 = same8.get(1);
172+
assertEquals(1, r8_0.rank());
173+
assertEquals(Shape.of(2), r8_0.shape());
174+
assertEquals(2, r8_0.size());
175+
assertEquals("(1, 1)", r8_0.getObject(0));
176+
assertEquals("(2, 1)", r8_0.getObject(1));
177+
assertEquals("(1, 2)", r8_1.getObject(0));
178+
assertEquals("(2, 2)", r8_1.getObject(1));
179+
180+
// rows 1 to 2, columns 1 to 2
181+
NdArray<String> same9 = matrix2d.slice(Indices.range(1,3), Indices.range(1,3));
182+
assertEquals(2, same9.rank());
183+
assertEquals(Shape.of(2,2), same9.shape());
184+
assertEquals(4, same9.size());
185+
String[][] same9j = StdArrays.array2dCopyOf(same9, String.class);
186+
String[][] expected_r9 = new String[][]
187+
{
188+
{"(1, 1)", "(2, 1)"},
189+
{"(1, 2)", "(2, 2)"}
190+
};
191+
assertArrayEquals(expected_r9, same9j);
192+
NdArray<String> r9_0 = same9.get(0);
193+
NdArray<String> r9_1 = same9.get(1);
194+
assertEquals(1, r9_0.rank());
195+
assertEquals(Shape.of(2), r9_0.shape());
196+
assertEquals(2, r9_0.size());
197+
assertEquals("(1, 1)", r9_0.getObject(0));
198+
assertEquals("(2, 1)", r9_0.getObject(1));
199+
assertEquals("(1, 2)", r9_1.getObject(0));
200+
assertEquals("(2, 2)", r9_1.getObject(1));
201+
202+
// rows 1, 3 and 4, columns 0 to 2
203+
NdArray<String> same10 = matrix2d.slice(Indices.odd(), Indices.even());
204+
String[][] same10j = StdArrays.array2dCopyOf(same10, String.class);
205+
assertEquals(2, same10.rank());
206+
assertEquals(Shape.of(2,2), same10.shape());
207+
assertEquals(4, same10.size());
208+
String[][] expected_r10 = new String[][]
209+
{
210+
{"(0, 1)", "(2, 1)"},
211+
{"(0, 3)", "(2, 3)"}
212+
};
213+
assertArrayEquals(expected_r10, same10j);
214+
NdArray<String> r10_0 = same10.get(0);
215+
NdArray<String> r10_1 = same10.get(1);
216+
assertEquals(1, r10_0.rank());
217+
assertEquals(Shape.of(2), r10_0.shape());
218+
assertEquals(2, r10_0.size());
219+
assertEquals("(0, 1)", r10_0.getObject(0));
220+
assertEquals("(2, 1)", r10_0.getObject(1));
221+
assertEquals("(0, 3)", r10_1.getObject(0));
222+
assertEquals("(2, 3)", r10_1.getObject(1));
223+
224+
// rows 3 and 4, columns 0 and 1. Second value is stride
225+
NdArray<String> same11 = matrix2d.slice(Indices.sliceFrom(3,1), Indices.sliceFrom(2,1));
226+
String[][] same11j = StdArrays.array2dCopyOf(same11, String.class);
227+
assertEquals(2, same11.rank());
228+
assertEquals(Shape.of(2,2), same11.shape());
229+
assertEquals(4, same11.size());
230+
String[][] expected_r11 = new String[][]
231+
{
232+
{"(2, 3)", "(3, 3)"},
233+
{"(2, 4)", "(3, 4)"}
234+
};
235+
assertArrayEquals(expected_r11, same11j);
236+
NdArray<String> r11_0 = same11.get(0);
237+
NdArray<String> r11_1 = same11.get(1);
238+
assertEquals(1, r11_0.rank());
239+
assertEquals(Shape.of(2), r11_0.shape());
240+
assertEquals(2, r11_0.size());
241+
assertEquals("(2, 3)", r11_0.getObject(0));
242+
assertEquals("(3, 3)", r11_0.getObject(1));
243+
assertEquals("(2, 4)", r11_1.getObject(0));
244+
assertEquals("(3, 4)", r11_1.getObject(1));
245+
246+
247+
// rows 0 and 2, columns 0 and 1. Second value is stride. Index non inclusive
248+
NdArray<String> same12 = matrix2d.slice(Indices.sliceTo(3,2), Indices.sliceTo(2,1));
249+
String[][] same12j = StdArrays.array2dCopyOf(same12, String.class);
250+
assertEquals(2, same12.rank());
251+
assertEquals(Shape.of(2,2), same12.shape());
252+
assertEquals(4, same12.size());
253+
String[][] expected_r12 = new String[][]
254+
{
255+
{"(0, 0)", "(1, 0)"},
256+
{"(0, 2)", "(1, 2)"}
257+
};
258+
assertArrayEquals(expected_r12, same12j);
259+
NdArray<String> r12_0 = same12.get(0);
260+
NdArray<String> r12_1 = same12.get(1);
261+
assertEquals(1, r12_0.rank());
262+
assertEquals(Shape.of(2), r12_0.shape());
263+
assertEquals(2, r12_0.size());
264+
assertEquals("(0, 0)", r12_0.getObject(0));
265+
assertEquals("(1, 0)", r12_0.getObject(1));
266+
assertEquals("(0, 2)", r12_1.getObject(0));
267+
assertEquals("(1, 2)", r12_1.getObject(1));
268+
269+
// rows 0 and 2, columns 0 and 1. Second value is stride. Index non inclusive
270+
NdArray<String> same13 = matrix2d.slice(Indices.step(2), Indices.step(2));
271+
String[][] same13j = StdArrays.array2dCopyOf(same13, String.class);
272+
assertEquals(2, same13.rank());
273+
assertEquals(Shape.of(3,2), same13.shape());
274+
assertEquals(6, same13.size());
275+
String[][] expected_r13 = new String[][]
276+
{
277+
{"(0, 0)", "(2, 0)"},
278+
{"(0, 2)", "(2, 2)"},
279+
{"(0, 4)", "(2, 4)"}
280+
};
281+
assertArrayEquals(expected_r13, same13j);
282+
NdArray<String> r13_0 = same13.get(0);
283+
NdArray<String> r13_1 = same13.get(1);
284+
NdArray<String> r13_2 = same13.get(2);
285+
assertEquals(1, r13_0.rank());
286+
assertEquals(Shape.of(2), r13_0.shape());
287+
assertEquals(2, r13_0.size());
288+
assertEquals("(0, 0)", r13_0.getObject(0));
289+
assertEquals("(2, 0)", r13_0.getObject(1));
290+
assertEquals("(0, 2)", r13_1.getObject(0));
291+
assertEquals("(2, 2)", r13_1.getObject(1));
292+
assertEquals("(0, 4)", r13_2.getObject(0));
293+
assertEquals("(2, 4)", r13_2.getObject(1));
294+
295+
296+
NdArray<String> same14 = same13.slice(Indices.flip(), Indices.flip());
297+
String[][] same14j = StdArrays.array2dCopyOf(same14, String.class);
298+
assertEquals(2, same14.rank());
299+
assertEquals(Shape.of(3,2), same14.shape());
300+
assertEquals(6, same14.size());
301+
String[][] expected_r14 = new String[][]
302+
{
303+
{"(2, 4)", "(0, 4)"},
304+
{"(2, 2)", "(0, 2)"},
305+
{"(2, 0)", "(0, 0)"}
306+
};
307+
assertArrayEquals(same14j, expected_r14);
308+
NdArray<String> r14_0 = same14.get(0);
309+
NdArray<String> r14_1 = same14.get(1);
310+
NdArray<String> r14_2 = same14.get(2);
311+
assertEquals(1, r14_0.rank());
312+
assertEquals(Shape.of(2), r14_0.shape());
313+
assertEquals(2, r14_0.size());
314+
assertEquals("(0, 0)", r14_2.getObject(1));
315+
assertEquals("(2, 0)", r14_2.getObject(0));
316+
assertEquals("(0, 2)", r14_1.getObject(1));
317+
assertEquals("(2, 2)", r14_1.getObject(0));
318+
assertEquals("(0, 4)", r14_0.getObject(1));
319+
assertEquals("(2, 4)", r14_0.getObject(0));
320+
321+
322+
NdArray<String> same15 = matrix2d.slice(Indices.slice(4,0,-2), Indices.slice(3L,null,-2));
323+
String[][] same15j = StdArrays.array2dCopyOf(same15, String.class);
324+
assertEquals(2, same15.rank());
325+
assertEquals(Shape.of(2,2), same15.shape());
326+
assertEquals(4,same15.size());
327+
String[][] expected_r15 = new String[][]
328+
{
329+
{"(3, 4)", "(1, 4)"},
330+
{"(3, 2)", "(1, 2)"},
331+
};
332+
assertArrayEquals(expected_r15, same15j);
333+
NdArray<String> r15_0 = same15.get(0);
334+
NdArray<String> r15_1 = same15.get(1);
335+
assertEquals(1, r15_0.rank());
336+
assertEquals(Shape.of(2), r15_0.shape());
337+
assertEquals(2, r15_0.size());
338+
assertEquals("(3, 4)", r15_0.getObject(0));
339+
assertEquals("(1, 4)", r15_0.getObject(1));
340+
assertEquals("(3, 2)", r15_1.getObject(0));
341+
assertEquals("(1, 2)", r15_1.getObject(1));
342+
343+
NdArray<String> same16 = matrix2d.slice(Indices.seq(4,2), Indices.seq(3,1));
344+
String[][] same16j = StdArrays.array2dCopyOf(same16, String.class);
345+
assertEquals(2, same16.rank());
346+
assertEquals(Shape.of(2,2), same16.shape());
347+
assertEquals(4, same16.size());
348+
String[][] expected_r16 = new String[][]
349+
{
350+
{"(3, 4)", "(1, 4)"},
351+
{"(3, 2)", "(1, 2)"}
352+
};
353+
assertArrayEquals(expected_r16, same16j);
354+
NdArray<String> r16_0 = same16.get(0);
355+
NdArray<String> r16_1 = same16.get(1);
356+
assertEquals(1, r16_0.rank());
357+
assertEquals(Shape.of(2), r16_0.shape());
358+
assertEquals(2, r16_0.size());
359+
assertEquals("(3, 4)", r16_0.getObject(0));
360+
assertEquals("(1, 4)", r16_0.getObject(1));
361+
assertEquals("(3, 2)", r16_1.getObject(0));
362+
assertEquals("(1, 2)", r16_1.getObject(1));
164363

165-
assertEquals(0, 0);
364+
365+
// New axis always has size 1
366+
NdArray<String> same17 = matrix2d.slice(Indices.all(), Indices.all(), Indices.newAxis());
367+
String[][][] same17j = StdArrays.array3dCopyOf(same17, String.class);
368+
assertEquals(3, same17.rank());
369+
assertEquals(Shape.of(5,4,1), same17.shape());
370+
assertEquals(20, same17.size());
371+
String[][][] expected_r17 = new String[][][]
372+
{
373+
{{"(0, 0)"}, {"(1, 0)"}, {"(2, 0)"}, {"(3, 0)"}},
374+
{{"(0, 1)"}, {"(1, 1)"}, {"(2, 1)"}, {"(3, 1)"}},
375+
{{"(0, 2)"}, {"(1, 2)"}, {"(2, 2)"}, {"(3, 2)"}},
376+
{{"(0, 3)"}, {"(1, 3)"}, {"(2, 3)"}, {"(3, 3)"}},
377+
{{"(0, 4)"}, {"(1, 4)"}, {"(2, 4)"}, {"(3, 4)"}}
378+
};
379+
assertArrayEquals(expected_r17, same17j);
380+
NdArray<String> r17_0 = same17.get(0);
381+
NdArray<String> r17_1 = same17.get(1);
382+
NdArray<String> r17_2 = same17.get(2);
383+
NdArray<String> r17_3 = same17.get(3);
384+
NdArray<String> r17_4 = same17.get(4);
385+
assertEquals(2, r17_0.rank());
386+
assertEquals(Shape.of(4,1), r17_0.shape());
387+
assertEquals(4, r17_0.size());
388+
// row 0
389+
// What use case can we have for a new index of size 1?
390+
// row 1
391+
assertEquals("(0, 1)", r17_1.getObject(0,0));
392+
assertEquals("(1, 1)", r17_1.getObject(1,0));
393+
assertEquals("(2, 1)", r17_1.getObject(2,0));
394+
assertEquals("(3, 1)", r17_1.getObject(3,0));
395+
// row 2
396+
assertEquals("(0, 2)", r17_2.getObject(0,0));
397+
assertEquals("(1, 2)", r17_2.getObject(1,0));
398+
assertEquals("(2, 2)", r17_2.getObject(2,0));
399+
assertEquals("(3, 2)", r17_2.getObject(3,0));
400+
// row 3
401+
assertEquals("(0, 3)", r17_3.getObject(0,0));
402+
assertEquals("(1, 3)", r17_3.getObject(1,0));
403+
assertEquals("(2, 3)", r17_3.getObject(2,0));
404+
assertEquals("(3, 3)", r17_3.getObject(3,0));
405+
// row 4
406+
assertEquals("(0, 4)", r17_4.getObject(0,0));
407+
assertEquals("(1, 4)", r17_4.getObject(1,0));
408+
assertEquals("(2, 4)", r17_4.getObject(2,0));
409+
assertEquals("(3, 4)", r17_4.getObject(3,0));
410+
166411
}
167412

168413
@Test

0 commit comments

Comments
 (0)