This repository was archived by the owner on Oct 12, 2023. It is now read-only.

Commit eaa22ef

Moving bulk copy metadata inference to the driver
Moves the collection of the bulk copy metadata from the executors to the driver. This should prevent locking the table when acquiring the schema.
1 parent 93d2382 · commit eaa22ef
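For context, here is a minimal usage sketch of the changed API. It assumes the connector's documented Config / implicit-import pattern; the server, database, table, and credential values are placeholders, and df is an existing DataFrame.

import com.microsoft.azure.sqldb.spark.config.Config
import com.microsoft.azure.sqldb.spark.connect._

// Placeholder connection settings -- substitute real values.
val config = Config(Map(
  "url"          -> "myserver.database.windows.net",
  "databaseName" -> "MyDatabase",
  "dbTable"      -> "dbo.MyTable",
  "user"         -> "username",
  "password"     -> "*********"
))

// No BulkCopyMetadata is passed, so after this commit the column metadata is
// inferred once on the driver rather than in every executor; createTable = true
// first creates the target table from the DataFrame schema via a limit(0) write.
df.bulkCopyToSqlDB(config, createTable = true)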

File tree: 1 file changed, +42 -15 lines changed


src/main/scala/com/microsoft/azure/sqldb/spark/connect/DataFrameFunctions.scala (+42 -15)
@@ -31,6 +31,8 @@ import com.microsoft.azure.sqldb.spark.config.{Config, SqlDBConfig}
 import com.microsoft.sqlserver.jdbc.SQLServerBulkCopy
 import org.apache.spark.sql.{DataFrame, Row}
 
+import scala.util.Try
+
 /**
   * Implicit functions for DataFrame
   */
@@ -43,8 +45,45 @@ private[spark] case class DataFrameFunctions[T](@transient dataFrame: DataFrame)
     * @param config the database connection properties and bulk copy properties
     * @param metadata the metadata of the columns - will be null if not specified
     */
-  def bulkCopyToSqlDB(config: Config, metadata: BulkCopyMetadata = null): Unit = {
-    dataFrame.foreachPartition(iterator => bulkCopy(config, iterator, metadata))
+  def bulkCopyToSqlDB(config: Config, metadata: BulkCopyMetadata = null, createTable:Boolean = false): Unit = {
+    // Ensuring the table exists in the DB already
+    if(createTable) {
+      dataFrame.limit(0).write.sqlDB(config)
+    }
+
+    val actualMetadata = if(metadata == null) {
+      getConnectionOrFail(config).recover({
+        case e: ClassNotFoundException =>
+          logError("JDBC driver not found in class path", e)
+          throw e
+        case e1: SQLException =>
+          logError("Connection cannot be established to the database", e1)
+          throw e1
+      }).flatMap(conn => {
+        inferBulkCopyMetadata(config, conn)
+      }).recover({
+        case e: SQLException =>
+          logError("Column metadata not specified and cannot retrieve metadata from database", e)
+          throw e
+      }).get
+    } else {
+      metadata
+    }
+    dataFrame.foreachPartition(iterator => bulkCopy(config, iterator, actualMetadata))
+  }
+
+  private def getConnectionOrFail(config:Config):Try[Connection] = {
+    Try {
+      ConnectionUtils.getConnection(config)
+    }
+  }
+
+  private def inferBulkCopyMetadata(config: Config, connection:Connection):Try[BulkCopyMetadata] = {
+    val dbTable = config.get[String](SqlDBConfig.DBTable).get
+    Try {
+      val resultSetMetaData = BulkCopyUtils.getTableColumns(dbTable, connection)
+      BulkCopyUtils.createBulkCopyMetadata(resultSetMetaData)
+    }
   }
 
   /**
@@ -71,19 +110,7 @@ private[spark] case class DataFrameFunctions[T](@transient dataFrame: DataFrame)
     val dbTable = config.get[String](SqlDBConfig.DBTable).get
 
     // Retrieves column metadata from external database table if user does not specify.
-    val bulkCopyMetadata =
-      if (metadata != null) {
-        metadata
-      } else {
-        try {
-          val resultSetMetaData = BulkCopyUtils.getTableColumns(dbTable, connection)
-          BulkCopyUtils.createBulkCopyMetadata(resultSetMetaData)
-        } catch {
-          case e: SQLException =>
-            logError("Column metadata not specified and cannot retrieve metadata from database", e)
-            throw e
-        }
-      }
+    val bulkCopyMetadata = metadata
 
     var committed = false
     val supportsTransactions = BulkCopyUtils.getTransactionSupport(connection)
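A note on the explicit-metadata path: when the caller supplies BulkCopyMetadata, the new else branch passes it through unchanged and no driver-side connection is opened for inference. A hedged sketch under the same assumptions as the earlier example, using the BulkCopyMetadata builder from the connector's documentation; the column definitions are placeholders and the addColumnMetadata arguments (index, name, JDBC type, precision, scale) must match the target table.

import java.sql.Types
import com.microsoft.azure.sqldb.spark.bulkcopy.BulkCopyMetadata

// Hypothetical column layout for dbo.MyTable -- adjust to the real schema.
val bulkCopyMetadata = new BulkCopyMetadata
bulkCopyMetadata.addColumnMetadata(1, "Id", Types.INTEGER, 0, 0)
bulkCopyMetadata.addColumnMetadata(2, "Name", Types.NVARCHAR, 128, 0)

// Explicit metadata skips the driver-side schema lookup entirely.
df.bulkCopyToSqlDB(config, bulkCopyMetadata)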
