
Commit b6703b9

akudiyar (Alexey Kuzin) authored and Alexey Kuzin committed
Close the client correctly in case of exceptions, improve integration test
1 parent d221189 commit b6703b9

5 files changed: 187 additions & 28 deletions


src/main/scala/io/tarantool/spark/connector/TarantoolSparkException.scala

Lines changed: 9 additions & 3 deletions

@@ -7,10 +7,16 @@ import io.tarantool.driver.exceptions.TarantoolException
  *
  * @author Alexey Kuzin
  */
-case class TarantoolSparkException(message: String) extends TarantoolException(message) {}
+trait TarantoolSparkException extends TarantoolException {}
 
 object TarantoolSparkException {
 
-  def TarantoolSparkException(message: String): TarantoolSparkException =
-    new TarantoolSparkException(message)
+  def apply(message: String): TarantoolSparkException =
+    new TarantoolException(message) with TarantoolSparkException
+
+  def apply(exception: Throwable): TarantoolSparkException =
+    new TarantoolException(exception) with TarantoolSparkException
+
+  def apply(message: String, exception: Throwable): TarantoolSparkException =
+    new TarantoolException(message, exception) with TarantoolSparkException
 }
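
The rewrite turns TarantoolSparkException from a case class into a trait that the companion's apply methods mix into TarantoolException instances, so callers keep the old construction syntax while catch clauses can match on either the connector type or the driver type. A minimal usage sketch (the riskyWrite helper is hypothetical, for illustration only):

import io.tarantool.spark.connector.TarantoolSparkException

// Hypothetical stand-in for any connector call that may fail
def riskyWrite(): Unit = throw TarantoolSparkException("write failed")

try {
  riskyWrite()
} catch {
  // matches because apply() mixes the trait into a TarantoolException instance
  case e: TarantoolSparkException => println("connector error: " + e.getMessage)
}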

src/main/scala/io/tarantool/spark/connector/rdd/TarantoolRDD.scala

Lines changed: 34 additions & 14 deletions

@@ -9,7 +9,11 @@ import io.tarantool.spark.connector.config.{ReadConfig, TarantoolConfig}
 import io.tarantool.spark.connector.connection.TarantoolConnection
 import io.tarantool.spark.connector.partition.TarantoolPartition
 import io.tarantool.spark.connector.rdd.converter.{FunctionBasedTupleConverter, TupleConverter}
-import io.tarantool.spark.connector.util.ScalaToJavaHelper.{toJavaConsumer, toJavaFunction}
+import io.tarantool.spark.connector.util.ScalaToJavaHelper.{
+  toJavaBiFunction,
+  toJavaConsumer,
+  toJavaFunction
+}
 import io.tarantool.spark.connector.util.TarantoolCursorIterator
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.tarantool.MapFunctions.rowToTuple
@@ -108,11 +112,12 @@
         }
         .toArray[CompletableFuture[_]]
 
-      CompletableFuture
-        .allOf(allFutures: _*)
-        .thenAccept(toJavaConsumer {
-          _: Void =>
-            try {
+      var savedException: Throwable = null
+      try {
+        CompletableFuture
+          .allOf(allFutures: _*)
+          .handle(toJavaBiFunction {
+            (_: Void, exception: Throwable) =>
               if (failedRowsExceptions.nonEmpty) {
                 val sw: StringWriter = new StringWriter()
                 val pw: PrintWriter = new PrintWriter(sw)
@@ -121,19 +126,34 @@
                   pw.append("\n\n")
                   exception.printStackTrace(pw)
                 }
-                throw new TarantoolSparkException("Dataset write failed: " + sw.toString)
+                savedException = TarantoolSparkException("Dataset write failed: " + sw.toString)
+                logError(savedException.getMessage)
               } finally {
                 pw.close()
               }
             } else {
-              logInfo(s"Dataset write success, $rowCount rows written")
+              if (Option(exception).isDefined) {
+                savedException = exception
+                logError("Dataset write failed: ", savedException)
+              } else {
+                logInfo(s"Dataset write success, $rowCount rows written")
+              }
             }
-          } finally {
-            client.close()
-          }
-        })
-        .get()
-        .asInstanceOf[Unit]
+            null
+          })
+          .join()
+      } catch {
+        case throwable: Throwable => savedException = throwable
+      } finally {
+        client.close()
+      }
+
+      if (Option(savedException).isDefined) {
+        savedException match {
+          case e: RuntimeException => throw e
+          case e: Any              => throw TarantoolSparkException(e)
+        }
+      }
     }
   )
 }
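
The substance of the fix: thenAccept only runs its callback when all futures complete normally, so a failed write skipped the callback and the client.close() buried inside it was never reached, while get() wrapped any failure in a checked ExecutionException. handle runs on both outcomes, the exception is recorded instead of thrown inside the callback, and close() now sits in an outer finally block. A self-contained sketch of the same pattern, assuming nothing beyond java.util.concurrent (the println stands in for client.close()):

import java.util.concurrent.CompletableFuture
import java.util.function.BiFunction

// A failing task standing in for the connector's batched write futures
val task: CompletableFuture[Void] = CompletableFuture.runAsync(new Runnable {
  override def run(): Unit = throw new RuntimeException("boom")
})

var savedException: Throwable = null
try {
  task
    .handle(new BiFunction[Void, Throwable, Void] {
      // invoked on success AND on failure; `exception` is null on success
      override def apply(ignored: Void, exception: Throwable): Void = {
        if (exception != null) savedException = exception
        null
      }
    })
    .join() // the handled future completes normally, so join() itself does not throw
} catch {
  // defensive: join() would still throw if the handler body itself failed
  case t: Throwable => savedException = t
} finally {
  println("client closed") // always reached; stands in for client.close()
}
if (savedException != null) throw savedException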

src/main/scala/io/tarantool/spark/connector/util/ScalaToJavaHelper.scala

Lines changed: 14 additions & 3 deletions

@@ -1,8 +1,11 @@
 package io.tarantool.spark.connector.util
 
-import java.util.function.{Consumer => JConsumer}
-import java.util.function.{Function => JFunction}
-import java.util.function.{Supplier => JSupplier}
+import java.util.function.{
+  BiFunction => JBiFunction,
+  Consumer => JConsumer,
+  Function => JFunction,
+  Supplier => JSupplier
+}
 import scala.reflect.ClassTag
 
 /**
@@ -29,6 +32,14 @@ object ScalaToJavaHelper {
       override def apply(t: T1): R = f.apply(t)
     }
 
+  /**
+   * Converts a Scala {@link Function2} to a Java {@link java.util.function.BiFunction}
+   */
+  def toJavaBiFunction[T1, T2, R](f: (T1, T2) => R): JBiFunction[T1, T2, R] =
+    new JBiFunction[T1, T2, R] {
+      override def apply(t1: T1, t2: T2): R = f.apply(t1, t2)
+    }
+
   /**
    * Converts a Scala {@link Function1} to a Java {@link java.util.function.Function}
    */
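
toJavaBiFunction follows the shape of the existing converters: the helper object exists presumably because the Scala versions this connector targets do not convert function literals to Java SAM interfaces automatically, and CompletableFuture.handle requires a java.util.function.BiFunction. A small usage sketch:

import java.util.concurrent.CompletableFuture
import io.tarantool.spark.connector.util.ScalaToJavaHelper.toJavaBiFunction

val future: CompletableFuture[Integer] =
  CompletableFuture.completedFuture(Integer.valueOf(21))

// The Scala function literal is wrapped into the BiFunction handle() expects;
// `error` is null when the future completed normally
val doubled: CompletableFuture[Integer] =
  future.handle(toJavaBiFunction { (value: Integer, error: Throwable) =>
    if (error != null) Integer.valueOf(-1) else Integer.valueOf(value * 2)
  })

assert(doubled.join() == 42)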

src/test/resources/test_teardown.lua

Lines changed: 2 additions & 1 deletion

@@ -2,10 +2,11 @@ local crud = require('crud')
 
 local function truncate_space(space)
   local ok, err
-  ok, err = crud.truncate('test_space')
+  ok, err = crud.truncate(space)
   if (not ok) then
     error("Failed to truncate space '" .. space .. "', error: " .. tostring(err))
   end
 end
 
 truncate_space('test_space')
+truncate_space('orders')

src/test/scala/io/tarantool/spark/connector/integration/TarantoolSparkWriteClusterTest.scala

Lines changed: 128 additions & 7 deletions

@@ -4,7 +4,8 @@ import io.tarantool.driver.api.conditions.Conditions
 import io.tarantool.driver.api.tuple.{DefaultTarantoolTupleFactory, TarantoolTuple}
 import io.tarantool.driver.mappers.DefaultMessagePackMapperFactory
 import io.tarantool.spark.connector.toSparkContextFunctions
-import org.apache.spark.sql.{Encoders, Row}
+import org.apache.spark.SparkException
+import org.apache.spark.sql.{Encoders, Row, SaveMode}
 import org.scalatest.funsuite.AnyFunSuite
 import org.scalatest.matchers.should.Matchers
 import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
@@ -27,24 +28,25 @@ class TarantoolSparkWriteClusterTest
 
   private val orderSchema = Encoders.product[Order].schema
 
-  test("should write a list of objects to the space") {
+  test("should write a dataset of objects to the specified space with different modes") {
 
     val orders = Range(1, 10).map(i => Order(i))
 
-    val df = spark.createDataFrame(
+    var df = spark.createDataFrame(
      spark.sparkContext.parallelize(orders.map(order => order.asRow())),
      orderSchema
    )
 
+    // Insert, the partition is empty at first
     df.write
       .format("org.apache.spark.sql.tarantool")
-      .mode("overwrite")
+      .mode(SaveMode.Append)
       .option("tarantool.space", SPACE_NAME)
       .save()
 
-    val actual = spark.sparkContext.tarantoolSpace(SPACE_NAME, Conditions.any()).collect()
-
+    var actual = spark.sparkContext.tarantoolSpace(SPACE_NAME, Conditions.any()).collect()
     actual.length should be > 0
+
     val sorted = actual.sorted[TarantoolTuple](new Ordering[TarantoolTuple]() {
       override def compare(x: TarantoolTuple, y: TarantoolTuple): Int =
         x.getInteger("id").compareTo(y.getInteger("id"))
@@ -70,20 +72,139 @@ class TarantoolSparkWriteClusterTest
       )
       actualItem.getBoolean("cleared") should equal(expectedItem.getBoolean(6))
     }
+
+    // Replace
+    df = spark.createDataFrame(
+      spark.sparkContext.parallelize(
+        orders
+          .map(order => order.changeOrderType(order.orderType + "222"))
+          .map(order => order.asRow())
+      ),
+      orderSchema
+    )
+
+    df.write
+      .format("org.apache.spark.sql.tarantool")
+      .mode(SaveMode.Overwrite)
+      .option("tarantool.space", SPACE_NAME)
+      .save()
+
+    actual = spark.sparkContext.tarantoolSpace(SPACE_NAME, Conditions.any()).collect()
+    actual.length should be > 0
+
+    actual.foreach(item => item.getString("order_type") should endWith("222"))
+
+    // Second insert with the same IDs produces an exception
+    var thrownException: Throwable = the[SparkException] thrownBy {
+      df.write
+        .format("org.apache.spark.sql.tarantool")
+        .mode(SaveMode.Append)
+        .option("tarantool.space", SPACE_NAME)
+        .save()
+    }
+    thrownException.getMessage should include("Duplicate key exists")
+
+    // ErrorIfExists mode checks that partition is empty and provides an exception if it is not
+    thrownException = the[IllegalStateException] thrownBy {
+      df.write
+        .format("org.apache.spark.sql.tarantool")
+        .mode(SaveMode.ErrorIfExists)
+        .option("tarantool.space", SPACE_NAME)
+        .save()
+    }
+    thrownException.getMessage should include("already exists in Tarantool")
+
+    // Clear the data and check that they are written in ErrorIfExists mode
+    container.executeScript("test_teardown.lua").get()
+
+    df = spark.createDataFrame(
+      spark.sparkContext.parallelize(
+        orders
+          .map(order => order.changeOrderType(order.orderType + "333"))
+          .map(order => order.asRow())
+      ),
+      orderSchema
+    )
+
+    df.write
+      .format("org.apache.spark.sql.tarantool")
+      .mode(SaveMode.ErrorIfExists)
+      .option("tarantool.space", SPACE_NAME)
+      .save()
+
+    actual = spark.sparkContext.tarantoolSpace(SPACE_NAME, Conditions.any()).collect()
+    actual.length should be > 0
+
+    actual.foreach(item => item.getString("order_type") should endWith("333"))
+
+    // Check that new data are not written in Ignore mode if the partition is not empty
+    df = spark.createDataFrame(
+      spark.sparkContext.parallelize(
+        orders
+          .map(order => order.changeOrderType(order.orderType + "444"))
+          .map(order => order.asRow())
+      ),
+      orderSchema
+    )
+
+    df.write
+      .format("org.apache.spark.sql.tarantool")
+      .mode(SaveMode.Ignore)
+      .option("tarantool.space", SPACE_NAME)
+      .save()
+
+    actual = spark.sparkContext.tarantoolSpace(SPACE_NAME, Conditions.any()).collect()
+    actual.length should be > 0
+
+    actual.foreach(item => item.getString("order_type") should endWith("333"))
+
+    // Clear the data and check if they are written in Ignore mode
+    container.executeScript("test_teardown.lua").get()
+
+    df.write
+      .format("org.apache.spark.sql.tarantool")
+      .mode(SaveMode.Ignore)
+      .option("tarantool.space", SPACE_NAME)
+      .save()
+
+    actual = spark.sparkContext.tarantoolSpace(SPACE_NAME, Conditions.any()).collect()
+    actual.length should be > 0
+
+    actual.foreach(item => item.getString("order_type") should endWith("444"))
   }
 
+  test("should throw an exception if the space name is not specified") {
+    assertThrows[IllegalArgumentException] {
+      val orders = Range(1, 10).map(i => Order(i))
+
+      val df = spark.createDataFrame(
+        spark.sparkContext.parallelize(orders.map(order => order.asRow())),
+        orderSchema
+      )
+
+      df.write
+        .format("org.apache.spark.sql.tarantool")
+        .mode(SaveMode.Overwrite)
+        .save()
+    }
+  }
 }
 
 case class Order(
   id: Int,
   bucketId: Int,
-  orderType: String,
+  var orderType: String,
   orderValue: BigDecimal,
   orderItems: List[Int],
   options: Map[String, String],
   cleared: Boolean
 ) {
 
+  def changeOrderType(newOrderType: String): Order = {
+    orderType = newOrderType
+    this
+  }
+
   def asRow(): Row =
     Row(id, bucketId, orderType, orderValue, orderItems, options, cleared)
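
Taken together, the test pins down the connector's SaveMode semantics. A compact recap of the write call, assuming a DataFrame df and the imports from the test above (a summary of what the assertions show, not official documentation):

// SaveMode behavior as exercised by this test:
//   Append        - inserts rows; re-inserting the same IDs fails with "Duplicate key exists"
//   Overwrite     - replaces the current contents of the space
//   ErrorIfExists - throws IllegalStateException if the space already contains data
//   Ignore        - writes only into an empty space, otherwise does nothing
df.write
  .format("org.apache.spark.sql.tarantool")
  .mode(SaveMode.Append) // or Overwrite / ErrorIfExists / Ignore
  .option("tarantool.space", "test_space") // required: omitting it throws IllegalArgumentException
  .save()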
