@@ -31,6 +31,8 @@ import com.microsoft.azure.sqldb.spark.config.{Config, SqlDBConfig}
3131import com .microsoft .sqlserver .jdbc .SQLServerBulkCopy
3232import org .apache .spark .sql .{DataFrame , Row }
3333
34+ import scala .util .Try
35+
3436/**
3537 * Implicit functions for DataFrame
3638 */
@@ -43,8 +45,45 @@ private[spark] case class DataFrameFunctions[T](@transient dataFrame: DataFrame)
4345 * @param config the database connection properties and bulk copy properties
4446 * @param metadata the metadata of the columns - will be null if not specified
4547 */
46- def bulkCopyToSqlDB (config : Config , metadata : BulkCopyMetadata = null ): Unit = {
47- dataFrame.foreachPartition(iterator => bulkCopy(config, iterator, metadata))
48+ def bulkCopyToSqlDB (config : Config , metadata : BulkCopyMetadata = null , createTable: Boolean = false ): Unit = {
49+ // Ensuring the table exists in the DB already
50+ if (createTable) {
51+ dataFrame.limit(0 ).write.sqlDB(config)
52+ }
53+
54+ val actualMetadata = if (metadata == null ) {
55+ getConnectionOrFail(config).recover({
56+ case e : ClassNotFoundException =>
57+ logError(" JDBC driver not found in class path" , e)
58+ throw e
59+ case e1 : SQLException =>
60+ logError(" Connection cannot be established to the database" , e1)
61+ throw e1
62+ }).flatMap(conn => {
63+ inferBulkCopyMetadata(config, conn)
64+ }).recover({
65+ case e : SQLException =>
66+ logError(" Column metadata not specified and cannot retrieve metadata from database" , e)
67+ throw e
68+ }).get
69+ } else {
70+ metadata
71+ }
72+ dataFrame.foreachPartition(iterator => bulkCopy(config, iterator, actualMetadata))
73+ }
74+
75+ private def getConnectionOrFail (config: Config ): Try [Connection ] = {
76+ Try {
77+ ConnectionUtils .getConnection(config)
78+ }
79+ }
80+
81+ private def inferBulkCopyMetadata (config : Config , connection: Connection ): Try [BulkCopyMetadata ] = {
82+ val dbTable = config.get[String ](SqlDBConfig .DBTable ).get
83+ Try {
84+ val resultSetMetaData = BulkCopyUtils .getTableColumns(dbTable, connection)
85+ BulkCopyUtils .createBulkCopyMetadata(resultSetMetaData)
86+ }
4887 }
4988
5089 /**
@@ -71,19 +110,7 @@ private[spark] case class DataFrameFunctions[T](@transient dataFrame: DataFrame)
71110 val dbTable = config.get[String ](SqlDBConfig .DBTable ).get
72111
73112 // Retrieves column metadata from external database table if user does not specify.
74- val bulkCopyMetadata =
75- if (metadata != null ) {
76- metadata
77- } else {
78- try {
79- val resultSetMetaData = BulkCopyUtils .getTableColumns(dbTable, connection)
80- BulkCopyUtils .createBulkCopyMetadata(resultSetMetaData)
81- } catch {
82- case e : SQLException =>
83- logError(" Column metadata not specified and cannot retrieve metadata from database" , e)
84- throw e
85- }
86- }
113+ val bulkCopyMetadata = metadata
87114
88115 var committed = false
89116 val supportsTransactions = BulkCopyUtils .getTransactionSupport(connection)
0 commit comments