@@ -31,6 +31,8 @@ import com.microsoft.azure.sqldb.spark.config.{Config, SqlDBConfig}
31
31
import com .microsoft .sqlserver .jdbc .SQLServerBulkCopy
32
32
import org .apache .spark .sql .{DataFrame , Row }
33
33
34
+ import scala .util .Try
35
+
34
36
/**
35
37
* Implicit functions for DataFrame
36
38
*/
@@ -43,8 +45,45 @@ private[spark] case class DataFrameFunctions[T](@transient dataFrame: DataFrame)
43
45
* @param config the database connection properties and bulk copy properties
44
46
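* @param metadata the metadata of the columns - if null, it is inferred from the columns of the target table
* @param createTable if true, an empty (zero-row) copy of the DataFrame is written first so that the target table exists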
45
47
*/
46
def bulkCopyToSqlDB(
    config: Config,
    metadata: BulkCopyMetadata = null,
    createTable: Boolean = false): Unit = {
  // Ensuring the table exists in the DB already: a zero-row write goes through the
  // regular writer without copying any data.
  if (createTable) {
    dataFrame.limit(0).write.sqlDB(config)
  }

  // Metadata is resolved once, up front, and then handed to every partition.
  val actualMetadata =
    if (metadata != null) {
      metadata
    } else {
      // Connection failures are logged exactly once here; `.get` rethrows the original
      // exception so callers still see ClassNotFoundException / SQLException.
      val connection = getConnectionOrFail(config).recover {
        case e: ClassNotFoundException =>
          logError(" JDBC driver not found in class path", e)
          throw e
        case e: SQLException =>
          logError(" Connection cannot be established to the database", e)
          throw e
      }.get

      try {
        inferBulkCopyMetadata(config, connection).recover {
          case e: SQLException =>
            logError(" Column metadata not specified and cannot retrieve metadata from database", e)
            throw e
        }.get
      } finally {
        // The connection is only needed for the metadata lookup, so release it here.
        // Closing is best-effort: a failure must not mask the lookup result or its error.
        Try(connection.close()).failed.foreach { e =>
          logError("Failed to close the connection used for metadata inference", e)
        }
      }
    }

  dataFrame.foreachPartition(iterator => bulkCopy(config, iterator, actualMetadata))
}
74
+
75
/**
 * Opens a connection through `ConnectionUtils.getConnection`, capturing any non-fatal
 * exception raised while doing so as a `Failure` instead of throwing it.
 *
 * @param config the database connection properties
 * @return the opened connection, or the failure that prevented opening it
 */
private def getConnectionOrFail(config: Config): Try[Connection] =
  Try(ConnectionUtils.getConnection(config))
80
+
81
/**
 * Builds bulk copy metadata from the columns of the table named by the `DBTable` setting.
 *
 * @param config     the configuration holding the `DBTable` setting
 * @param connection an open connection used to read the table's columns
 * @return the inferred metadata, or a `Failure` carrying whatever non-fatal exception
 *         was raised while reading the setting or the table columns
 */
private def inferBulkCopyMetadata(config: Config, connection: Connection): Try[BulkCopyMetadata] =
  Try {
    // Looked up inside the Try so that a missing DBTable setting is reported as a
    // Failure, consistent with the return type, rather than thrown from this method.
    val dbTable = config.get[String](SqlDBConfig.DBTable).get
    val resultSetMetaData = BulkCopyUtils.getTableColumns(dbTable, connection)
    BulkCopyUtils.createBulkCopyMetadata(resultSetMetaData)
  }
49
88
50
89
/**
@@ -71,19 +110,7 @@ private[spark] case class DataFrameFunctions[T](@transient dataFrame: DataFrame)
71
110
val dbTable = config.get[String ](SqlDBConfig .DBTable ).get
72
111
73
112
// Column metadata is used as supplied by the caller; it is no longer looked up from the database here.
74
- val bulkCopyMetadata =
75
- if (metadata != null ) {
76
- metadata
77
- } else {
78
- try {
79
- val resultSetMetaData = BulkCopyUtils .getTableColumns(dbTable, connection)
80
- BulkCopyUtils .createBulkCopyMetadata(resultSetMetaData)
81
- } catch {
82
- case e : SQLException =>
83
- logError(" Column metadata not specified and cannot retrieve metadata from database" , e)
84
- throw e
85
- }
86
- }
113
+ val bulkCopyMetadata = metadata
87
114
88
115
var committed = false
89
116
val supportsTransactions = BulkCopyUtils .getTransactionSupport(connection)
0 commit comments