Skip to content

Commit 618d7ff

Browse files
authored
test: move common logic into TestBase (#193)
1 parent 33ca926 commit 618d7ff

9 files changed

+29
-78
lines changed

core/src/test/java/io/substrait/TestBase.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
package io.substrait;
22

3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
5+
import io.substrait.dsl.SubstraitBuilder;
6+
import io.substrait.extension.ExtensionCollector;
37
import io.substrait.extension.SimpleExtension;
8+
import io.substrait.relation.ProtoRelConverter;
9+
import io.substrait.relation.Rel;
10+
import io.substrait.relation.RelProtoConverter;
11+
import io.substrait.type.TypeCreator;
412
import java.io.IOException;
513

614
public abstract class TestBase {
@@ -14,4 +22,18 @@ public abstract class TestBase {
1422
throw new RuntimeException(e);
1523
}
1624
}
25+
26+
protected TypeCreator R = TypeCreator.REQUIRED;
27+
28+
protected SubstraitBuilder b = new SubstraitBuilder(defaultExtensionCollection);
29+
protected ExtensionCollector functionCollector = new ExtensionCollector();
30+
protected RelProtoConverter relProtoConverter = new RelProtoConverter(functionCollector);
31+
protected ProtoRelConverter protoRelConverter =
32+
new ProtoRelConverter(functionCollector, defaultExtensionCollection);
33+
34+
protected void verifyRoundTrip(Rel rel) {
35+
io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel);
36+
Rel relReturned = protoRelConverter.from(protoRel);
37+
assertEquals(rel, relReturned);
38+
}
1739
}

core/src/test/java/io/substrait/relation/ProtoRelConverterTest.java

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,14 @@
55
import static org.junit.jupiter.api.Assertions.assertThrows;
66

77
import io.substrait.TestBase;
8-
import io.substrait.dsl.SubstraitBuilder;
98
import io.substrait.extension.AdvancedExtension;
10-
import io.substrait.extension.ExtensionCollector;
119
import io.substrait.relation.utils.StringHolder;
1210
import java.util.Collections;
1311
import org.junit.jupiter.api.Nested;
1412
import org.junit.jupiter.api.Test;
1513

1614
public class ProtoRelConverterTest extends TestBase {
1715

18-
final SubstraitBuilder b = new SubstraitBuilder(defaultExtensionCollection);
19-
final ExtensionCollector functionCollector = new ExtensionCollector();
20-
final RelProtoConverter relProtoConverter = new RelProtoConverter(functionCollector);
21-
final ProtoRelConverter protoRelConverter =
22-
new ProtoRelConverter(functionCollector, defaultExtensionCollection);
23-
2416
final NamedScan commonTable =
2517
b.namedScan(Collections.emptyList(), Collections.emptyList(), Collections.emptyList());
2618

core/src/test/java/io/substrait/type/proto/AggregateRoundtripTest.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121
import org.junit.jupiter.api.Test;
2222

2323
public class AggregateRoundtripTest extends TestBase {
24-
static final org.slf4j.Logger logger =
25-
org.slf4j.LoggerFactory.getLogger(AggregateRoundtripTest.class);
2624

2725
private void assertAggregateRoundtrip(Expression.AggregationInvocation invocation) {
2826
var expression = ExpressionCreator.decimal(false, BigDecimal.TEN, 10, 2);

core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,8 @@
33
import static org.junit.jupiter.api.Assertions.assertEquals;
44

55
import io.substrait.TestBase;
6-
import io.substrait.dsl.SubstraitBuilder;
76
import io.substrait.expression.Expression;
87
import io.substrait.extension.AdvancedExtension;
9-
import io.substrait.extension.ExtensionCollector;
10-
import io.substrait.extension.SimpleExtension;
118
import io.substrait.relation.Aggregate;
129
import io.substrait.relation.Cross;
1310
import io.substrait.relation.ExtensionLeaf;
@@ -22,7 +19,6 @@
2219
import io.substrait.relation.Project;
2320
import io.substrait.relation.ProtoRelConverter;
2421
import io.substrait.relation.Rel;
25-
import io.substrait.relation.RelProtoConverter;
2622
import io.substrait.relation.Set;
2723
import io.substrait.relation.Sort;
2824
import io.substrait.relation.VirtualTableScan;
@@ -45,15 +41,8 @@
4541
*/
4642
public class ExtensionRoundtripTest extends TestBase {
4743

48-
TypeCreator R = TypeCreator.REQUIRED;
49-
50-
final SimpleExtension.ExtensionCollection extensions = defaultExtensionCollection;
51-
52-
final SubstraitBuilder b = new SubstraitBuilder(extensions);
53-
final ExtensionCollector functionCollector = new ExtensionCollector();
54-
final RelProtoConverter relProtoConverter = new RelProtoConverter(functionCollector);
5544
final ProtoRelConverter protoRelConverter =
56-
new StringHolderHandlingProtoRelConverter(functionCollector, extensions);
45+
new StringHolderHandlingProtoRelConverter(functionCollector, defaultExtensionCollection);
5746

5847
final Rel commonTable =
5948
b.namedScan(Collections.emptyList(), Collections.emptyList(), Collections.emptyList());
@@ -72,7 +61,8 @@ public class ExtensionRoundtripTest extends TestBase {
7261
.optimization(new StringHolder("REL OPTIMIZATION"))
7362
.build();
7463

75-
void verifyRoundTrip(Rel rel) {
64+
@Override
65+
protected void verifyRoundTrip(Rel rel) {
7666
io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel);
7767
Rel relReturned = protoRelConverter.from(protoRel);
7868
assertEquals(rel, relReturned);

core/src/test/java/io/substrait/type/proto/GenericRoundtripTest.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@
2929
import org.junit.jupiter.params.provider.MethodSource;
3030

3131
public class GenericRoundtripTest extends TestBase {
32-
static final org.slf4j.Logger logger =
33-
org.slf4j.LoggerFactory.getLogger(GenericRoundtripTest.class);
3432

3533
static Random rand = new Random(123);
3634

core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,14 @@
11
package io.substrait.type.proto;
22

3-
import static org.junit.jupiter.api.Assertions.assertEquals;
4-
53
import io.substrait.TestBase;
6-
import io.substrait.dsl.SubstraitBuilder;
7-
import io.substrait.extension.ExtensionCollector;
8-
import io.substrait.extension.SimpleExtension;
9-
import io.substrait.relation.ProtoRelConverter;
104
import io.substrait.relation.Rel;
11-
import io.substrait.relation.RelProtoConverter;
125
import io.substrait.relation.physical.HashJoin;
13-
import io.substrait.relation.utils.StringHolderHandlingProtoRelConverter;
14-
import io.substrait.type.TypeCreator;
156
import java.util.Arrays;
167
import java.util.List;
178
import org.junit.jupiter.api.Test;
189

1910
public class JoinRoundtripTest extends TestBase {
2011

21-
final SimpleExtension.ExtensionCollection extensions = defaultExtensionCollection;
22-
23-
TypeCreator R = TypeCreator.REQUIRED;
24-
25-
final SubstraitBuilder b = new SubstraitBuilder(extensions);
26-
27-
final ExtensionCollector functionCollector = new ExtensionCollector();
28-
final RelProtoConverter relProtoConverter = new RelProtoConverter(functionCollector);
29-
final ProtoRelConverter protoRelConverter =
30-
new StringHolderHandlingProtoRelConverter(functionCollector, extensions);
31-
3212
final Rel leftTable =
3313
b.namedScan(
3414
Arrays.asList("T1"),
@@ -41,12 +21,6 @@ public class JoinRoundtripTest extends TestBase {
4121
Arrays.asList("d", "e", "f"),
4222
Arrays.asList(R.FP64, R.STRING, R.I64));
4323

44-
void verifyRoundTrip(Rel rel) {
45-
io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel);
46-
Rel relReturned = protoRelConverter.from(protoRel);
47-
assertEquals(rel, relReturned);
48-
}
49-
5024
@Test
5125
void hashJoin() {
5226
List<Integer> leftKeys = Arrays.asList(0, 1);

core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,16 @@
77
import io.substrait.expression.ExpressionCreator;
88
import io.substrait.expression.proto.ExpressionProtoConverter;
99
import io.substrait.expression.proto.ProtoExpressionConverter;
10-
import io.substrait.extension.ExtensionCollector;
11-
import io.substrait.relation.ProtoRelConverter;
1210
import java.math.BigDecimal;
1311
import org.junit.jupiter.api.Test;
1412

1513
public class LiteralRoundtripTest extends TestBase {
16-
static final org.slf4j.Logger logger =
17-
org.slf4j.LoggerFactory.getLogger(LiteralRoundtripTest.class);
1814

1915
@Test
2016
void decimal() {
2117
var val = ExpressionCreator.decimal(false, BigDecimal.TEN, 10, 2);
2218
var to = new ExpressionProtoConverter(null, null);
23-
var from =
24-
new ProtoExpressionConverter(
25-
null,
26-
null,
27-
EMPTY_TYPE,
28-
new ProtoRelConverter(new ExtensionCollector(), defaultExtensionCollection));
19+
var from = new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter);
2920
assertEquals(val, from.from(val.accept(to)));
3021
}
3122
}

core/src/test/java/io/substrait/type/proto/LocalFilesRoundtripTest.java

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,33 +7,20 @@
77
import io.substrait.expression.ExpressionCreator;
88
import io.substrait.expression.FieldReference;
99
import io.substrait.expression.ImmutableFieldReference;
10-
import io.substrait.extension.ExtensionCollector;
11-
import io.substrait.extension.SimpleExtension;
1210
import io.substrait.proto.ReadRel;
1311
import io.substrait.relation.LocalFiles;
14-
import io.substrait.relation.ProtoRelConverter;
15-
import io.substrait.relation.RelProtoConverter;
1612
import io.substrait.relation.files.FileOrFiles;
1713
import io.substrait.relation.files.ImmutableFileFormat;
1814
import io.substrait.relation.files.ImmutableFileOrFiles;
1915
import io.substrait.type.ImmutableNamedStruct;
2016
import io.substrait.type.Type;
2117
import io.substrait.type.TypeCreator;
22-
import java.io.IOException;
2318
import java.util.Arrays;
2419
import org.junit.jupiter.api.Test;
2520

2621
public class LocalFilesRoundtripTest extends TestBase {
2722

28-
SimpleExtension.ExtensionCollection extensions = defaultExtensionCollection;
29-
30-
public LocalFilesRoundtripTest() throws IOException {}
31-
3223
private void assertLocalFilesRoundtrip(FileOrFiles file) {
33-
ExtensionCollector functionCollector = new ExtensionCollector();
34-
RelProtoConverter to = new RelProtoConverter(functionCollector);
35-
ProtoRelConverter from = new ProtoRelConverter(functionCollector, extensions);
36-
3724
var builder =
3825
LocalFiles.builder()
3926
.initialSchema(
@@ -47,7 +34,7 @@ private void assertLocalFilesRoundtrip(FileOrFiles file) {
4734
.build())
4835
.addItems(file);
4936

50-
extensions.scalarFunctions().stream()
37+
defaultExtensionCollection.scalarFunctions().stream()
5138
.filter(s -> s.name().equalsIgnoreCase("equal"))
5239
.findFirst()
5340
.map(
@@ -63,9 +50,9 @@ private void assertLocalFilesRoundtrip(FileOrFiles file) {
6350
.ifPresent(builder::filter);
6451

6552
var localFiles = builder.build();
66-
var protoFileRel = to.toProto(localFiles);
53+
var protoFileRel = relProtoConverter.toProto(localFiles);
6754
assertTrue(protoFileRel.getRead().hasFilter());
68-
assertEquals(protoFileRel, to.toProto(from.from(protoFileRel)));
55+
assertEquals(protoFileRel, relProtoConverter.toProto(protoRelConverter.from(protoFileRel)));
6956
}
7057

7158
private ImmutableFileOrFiles.Builder setPath(

core/src/test/java/io/substrait/type/proto/TestTypeRoundtrip.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import org.junit.jupiter.params.provider.ValueSource;
1111

1212
public class TestTypeRoundtrip {
13-
static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(TestTypeRoundtrip.class);
1413

1514
@ParameterizedTest
1615
@ValueSource(booleans = {true, false})

0 commit comments

Comments
 (0)