This repository was archived by the owner on Oct 12, 2023. It is now read-only.

Commit fa1cf19

Merge pull request #6 from slyons/master
Optimized the column metadata lookup to happen just once. This greatly improves performance with larger bulk loads and a larger number of partitions.
2 parents 93d2382 + eaa22ef commit fa1cf19

File tree

1 file changed: +42 −15 lines


src/main/scala/com/microsoft/azure/sqldb/spark/connect/DataFrameFunctions.scala

@@ -31,6 +31,8 @@ import com.microsoft.azure.sqldb.spark.config.{Config, SqlDBConfig}
 import com.microsoft.sqlserver.jdbc.SQLServerBulkCopy
 import org.apache.spark.sql.{DataFrame, Row}
 
+import scala.util.Try
+
 /**
  * Implicit functions for DataFrame
  */
@@ -43,8 +45,45 @@ private[spark] case class DataFrameFunctions[T](@transient dataFrame: DataFrame)
    * @param config the database connection properties and bulk copy properties
    * @param metadata the metadata of the columns - will be null if not specified
    */
-  def bulkCopyToSqlDB(config: Config, metadata: BulkCopyMetadata = null): Unit = {
-    dataFrame.foreachPartition(iterator => bulkCopy(config, iterator, metadata))
+  def bulkCopyToSqlDB(config: Config, metadata: BulkCopyMetadata = null, createTable:Boolean = false): Unit = {
+    // Ensuring the table exists in the DB already
+    if(createTable) {
+      dataFrame.limit(0).write.sqlDB(config)
+    }
+
+    val actualMetadata = if(metadata == null) {
+      getConnectionOrFail(config).recover({
+        case e: ClassNotFoundException =>
+          logError("JDBC driver not found in class path", e)
+          throw e
+        case e1: SQLException =>
+          logError("Connection cannot be established to the database", e1)
+          throw e1
+      }).flatMap(conn => {
+        inferBulkCopyMetadata(config, conn)
+      }).recover({
+        case e: SQLException =>
+          logError("Column metadata not specified and cannot retrieve metadata from database", e)
+          throw e
+      }).get
+    } else {
+      metadata
+    }
+    dataFrame.foreachPartition(iterator => bulkCopy(config, iterator, actualMetadata))
+  }
+
+  private def getConnectionOrFail(config:Config):Try[Connection] = {
+    Try {
+      ConnectionUtils.getConnection(config)
+    }
+  }
+
+  private def inferBulkCopyMetadata(config: Config, connection:Connection):Try[BulkCopyMetadata] = {
+    val dbTable = config.get[String](SqlDBConfig.DBTable).get
+    Try {
+      val resultSetMetaData = BulkCopyUtils.getTableColumns(dbTable, connection)
+      BulkCopyUtils.createBulkCopyMetadata(resultSetMetaData)
+    }
   }
 
   /**
@@ -71,19 +110,7 @@ private[spark] case class DataFrameFunctions[T](@transient dataFrame: DataFrame)
     val dbTable = config.get[String](SqlDBConfig.DBTable).get
 
     // Retrieves column metadata from external database table if user does not specify.
-    val bulkCopyMetadata =
-      if (metadata != null) {
-        metadata
-      } else {
-        try {
-          val resultSetMetaData = BulkCopyUtils.getTableColumns(dbTable, connection)
-          BulkCopyUtils.createBulkCopyMetadata(resultSetMetaData)
-        } catch {
-          case e: SQLException =>
-            logError("Column metadata not specified and cannot retrieve metadata from database", e)
-            throw e
-        }
-      }
+    val bulkCopyMetadata = metadata
 
     var committed = false
     val supportsTransactions = BulkCopyUtils.getTransactionSupport(connection)
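For reference, a minimal usage sketch of the optimized path. The SparkSession setup, the example DataFrame, and the connection and bulk-copy values ("url", "databaseName", "dbTable", batch size, and so on) are hypothetical placeholders following the config-key conventions in the library's README; only bulkCopyToSqlDB itself and the new createTable parameter come from this commit.

import org.apache.spark.sql.SparkSession
import com.microsoft.azure.sqldb.spark.config.Config
import com.microsoft.azure.sqldb.spark.connect._

object BulkCopyExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("BulkCopyExample").getOrCreate()
    import spark.implicits._

    // Small illustrative DataFrame; any DataFrame whose schema matches the target table works.
    val df = Seq((1, "Contoso"), (2, "Fabrikam")).toDF("ClientId", "ClientName")

    // Placeholder connection details and bulk copy options (hypothetical values).
    val bulkCopyConfig = Config(Map(
      "url"               -> "myserver.database.windows.net",
      "databaseName"      -> "MyDatabase",
      "user"              -> "username",
      "password"          -> "**********",
      "dbTable"           -> "dbo.Clients",
      "bulkCopyBatchSize" -> "2500",
      "bulkCopyTimeout"   -> "600"
    ))

    // With this commit, leaving BulkCopyMetadata unspecified triggers a single
    // driver-side metadata lookup instead of one lookup per partition, and
    // createTable = true first writes an empty DataFrame so the target table
    // exists before the bulk copy runs.
    df.bulkCopyToSqlDB(bulkCopyConfig, createTable = true)

    spark.stop()
  }
}

Because the inferred BulkCopyMetadata is now computed once on the driver and captured in the closure passed to foreachPartition, every partition reuses the same metadata rather than re-reading the column definitions over its own connection, which is where the speedup on large, heavily partitioned loads comes from.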
