From f30a4baa2ae9edaab0ea894d76217d5f7ebc1ac1 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Sat, 14 Dec 2019 12:58:42 -0800 Subject: [PATCH 1/4] remove type HiveWriter --- pkg/sqlfs/flush_write_closer.go | 7 +- pkg/sqlfs/hive_writer.go | 143 -------------------------------- pkg/sqlfs/sql_writer.go | 4 +- pkg/sqlfs/sqlfile_test.go | 2 + pkg/sqlfs/table.go | 10 +-- pkg/sqlfs/writer.go | 80 +----------------- 6 files changed, 19 insertions(+), 227 deletions(-) diff --git a/pkg/sqlfs/flush_write_closer.go b/pkg/sqlfs/flush_write_closer.go index 80667b6d9f..f589723d02 100644 --- a/pkg/sqlfs/flush_write_closer.go +++ b/pkg/sqlfs/flush_write_closer.go @@ -13,7 +13,12 @@ package sqlfs -// flushWriteCloser implements io.WriteCloser. +// flushWriteCloser implements io.WriteCloser with two hooks: (1) +// flush, which is supposed to be called by Write when the internal +// buffer overflows, and (2) wrapup, which is to be called by Close. +// We need flushWriteCloser to implement the SQL writer and the Hive +// writer. For more details, please read sql_writer.go and +// hive_writer.go. type flushWriteCloser struct { buf []byte flushes int // record the count of flushes. diff --git a/pkg/sqlfs/hive_writer.go b/pkg/sqlfs/hive_writer.go index c45bc588de..692e0d0562 100644 --- a/pkg/sqlfs/hive_writer.go +++ b/pkg/sqlfs/hive_writer.go @@ -22,8 +22,6 @@ import ( "os" "os/exec" "path" - - pb "sqlflow.org/sqlflow/pkg/proto" ) func flushToCSV() (func([]byte) error, *os.File, error) { @@ -95,144 +93,3 @@ func newHiveWriter(db *sql.DB, hivePath, table, user, passwd string) (io.WriteCl const bufSize = 32 * 1024 return newFlushWriteCloser(flush, upload, bufSize), nil } - -// HiveWriter implements io.WriteCloser. 
-type HiveWriter struct { - Writer - csvFile *os.File - session *pb.Session -} - -// NewHiveWriter returns a Hive Writer object -func NewHiveWriter(db *sql.DB, table string, session *pb.Session) (*HiveWriter, error) { - csvFile, e := ioutil.TempFile("/tmp", "sqlflow-sqlfs") - if e != nil { - return nil, fmt.Errorf("create temporary csv file failed: %v", e) - } - return &HiveWriter{ - Writer: Writer{ - db: db, - table: table, - buf: make([]byte, 0, bufSize), - flushID: 0, - }, - csvFile: csvFile, - session: session}, nil -} - -// Write write bytes to sqlfs and returns (num_bytes, error) -func (w *HiveWriter) Write(p []byte) (n int, e error) { - n = 0 - for len(p) > 0 { - fill := bufSize - len(w.buf) - if fill > len(p) { - fill = len(p) - } - w.buf = append(w.buf, p[:fill]...) - p = p[fill:] - n += fill - if len(w.buf) >= bufSize { - if e := w.flush(); e != nil { - return 0, e - } - } - } - return n, nil -} - -func removeHDFSDir(hdfsPath string, hdfsEnv []string) error { - cmd := exec.Command("hdfs", "dfs", "-rm", "-r", "-f", hdfsPath) - cmd.Env = hdfsEnv - if out, err := cmd.CombinedOutput(); err != nil { - fmt.Println(string(out)) - return err - } - return nil -} - -func hdfsEnvWithCredentical(username, password string) []string { - hdfsEnv := os.Environ() - if username != "" { - hdfsEnv = append(hdfsEnv, - fmt.Sprintf("HADOOP_USER_NAME=%s", username), - fmt.Sprintf("HADOOP_USER_PASSWORD=%s", password)) - } - return hdfsEnv -} - -func createHDFSDir(hdfsPath string, hdfsEnv []string) error { - cmd := exec.Command("hdfs", "dfs", "-mkdir", "-p", hdfsPath) - cmd.Env = hdfsEnv - if _, err := cmd.CombinedOutput(); err != nil { - return err - } - return nil -} - -func uploadFileToHDFS(localFilePath, hdfsPath string, hdfsEnv []string) error { - cmd := exec.Command("hdfs", "dfs", "-copyFromLocal", localFilePath, hdfsPath) - cmd.Env = hdfsEnv - if _, err := cmd.CombinedOutput(); err != nil { - return fmt.Errorf("upload local file into hdfs error: %v", err) - } - return nil 
-} - -func loadHDFSfileIntoTable(db *sql.DB, hdfsPath, table string) error { - query := fmt.Sprintf("LOAD DATA INPATH '%s' OVERWRITE INTO TABLE %s", hdfsPath, table) - if _, e := db.Exec(query); e != nil { - return fmt.Errorf("execute query: %s, error: %v", query, e) - } - return nil -} - -// Close the connection of the sqlfs -func (w *HiveWriter) Close() error { - if w.db == nil { - return nil - } - defer func() { - w.csvFile.Close() - os.Remove(w.csvFile.Name()) - w.db = nil - }() - - if e := w.flush(); e != nil { - return e - } - hdfsEnv := hdfsEnvWithCredentical(w.session.HdfsUser, w.session.HdfsPass) - - // 1. create a directory on HDFS - if err := createHDFSDir(w.hdfsPath(), hdfsEnv); err != nil { - return fmt.Errorf("create HDFDS dir: %s failed: %v", w.hdfsPath(), err) - } - defer removeHDFSDir(w.hdfsPath(), hdfsEnv) - - // 2. upload the local csv file to the HDFS directory - if err := uploadFileToHDFS(w.csvFile.Name(), w.hdfsPath(), hdfsEnv); err != nil { - return fmt.Errorf("upload local file to hdfs failed: %v", err) - } - - // 3. 
load hdfs files into hive table - if err := loadHDFSfileIntoTable(w.db, w.hdfsPath(), w.table); err != nil { - return fmt.Errorf("load hdfs file into table failed: %v", err) - } - - return nil -} - -func (w *HiveWriter) hdfsPath() string { - return fmt.Sprintf("%s/sqlfs/%s/", w.session.HiveLocation, w.table) -} - -func (w *HiveWriter) flush() error { - if len(w.buf) > 0 { - block := base64.StdEncoding.EncodeToString(w.buf) - if _, e := w.csvFile.Write([]byte(fmt.Sprintf("%d\001%s\n", w.flushID, block))); e != nil { - return fmt.Errorf("flush error, %v", e) - } - w.buf = w.buf[:0] - w.flushID++ - } - return nil -} diff --git a/pkg/sqlfs/sql_writer.go b/pkg/sqlfs/sql_writer.go index a555c154ad..5ff42b5d32 100644 --- a/pkg/sqlfs/sql_writer.go +++ b/pkg/sqlfs/sql_writer.go @@ -44,11 +44,11 @@ func noopWrapUp() error { return nil } -func newSQLWriter(db *sql.DB, driver, table string) (io.WriteCloser, error) { +func newSQLWriter(db *sql.DB, dbms, table string) (io.WriteCloser, error) { if e := dropTable(db, table); e != nil { return nil, fmt.Errorf("cannot drop table %s: %v", table, e) } - if e := createTable(db, driver, table); e != nil { + if e := createTable(db, dbms, table); e != nil { return nil, fmt.Errorf("cannot create table %s: %v", table, e) } diff --git a/pkg/sqlfs/sqlfile_test.go b/pkg/sqlfs/sqlfile_test.go index cda45a212a..fce74431a8 100644 --- a/pkg/sqlfs/sqlfile_test.go +++ b/pkg/sqlfs/sqlfile_test.go @@ -64,6 +64,8 @@ func TestWriterCreate(t *testing.T) { } func TestWriteAndRead(t *testing.T) { + const bufSize = 32 * 1024 + testDriver = getEnv("SQLFLOW_TEST_DB", "mysql") a := assert.New(t) diff --git a/pkg/sqlfs/table.go b/pkg/sqlfs/table.go index 3c62f59293..6f7f94a296 100644 --- a/pkg/sqlfs/table.go +++ b/pkg/sqlfs/table.go @@ -21,18 +21,18 @@ import ( // createTable creates a table, if it doesn't exist. If the table // name includes the database name, e.g., "db.tbl", it creates the // database if necessary. 
-func createTable(db *sql.DB, driver, table string) error { +func createTable(db *sql.DB, dbms, table string) error { // HIVE and ODPS don't support AUTO_INCREMENT // Hive and ODPS don't support BLOB, use BINARY instead var stmt string - if driver == "mysql" { + if dbms == "mysql" { stmt = fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id INT, block TEXT, PRIMARY KEY (id))", table) - } else if driver == "hive" { + } else if dbms == "hive" { stmt = fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id INT, block STRING) ROW FORMAT DELIMITED FIELDS TERMINATED BY \"\\001\" STORED AS TEXTFILE", table) - } else if driver == "maxcompute" { + } else if dbms == "maxcompute" { stmt = fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id INT, block STRING)", table) } else { - return fmt.Errorf("createTable doesn't recognize driver %s", driver) + return fmt.Errorf("createTable doesn't recognize dbms %s", dbms) } if _, e := db.Exec(stmt); e != nil { return fmt.Errorf("exec:[%s] failed: %v", stmt, e) diff --git a/pkg/sqlfs/writer.go b/pkg/sqlfs/writer.go index d03628bdd3..2752e599a9 100644 --- a/pkg/sqlfs/writer.go +++ b/pkg/sqlfs/writer.go @@ -15,88 +15,16 @@ package sqlfs import ( "database/sql" - "encoding/base64" - "fmt" "io" pb "sqlflow.org/sqlflow/pkg/proto" ) -// TEXT/STRING field support 64KB maximum storage size -const bufSize = 32 * 1024 - -// Writer implements io.WriteCloser. -type Writer struct { - db *sql.DB - table string - buf []byte - flushID int -} - // Create creates a new table or truncates an existing table and // returns a writer. 
-func Create(db *sql.DB, driver, table string, session *pb.Session) (io.WriteCloser, error) { - if e := dropTable(db, table); e != nil { - return nil, fmt.Errorf("create: %v", e) - } - if e := createTable(db, driver, table); e != nil { - return nil, fmt.Errorf("create: %v", e) - } - - if driver == "hive" { - w, err := NewHiveWriter(db, table, session) - if err != nil { - return nil, fmt.Errorf("create: %v", err) - } - return w, nil - } - // default writer implement - return &Writer{db, table, make([]byte, 0, bufSize), 0}, nil -} - -// Write write bytes to sqlfs and returns (num_bytes, error) -func (w *Writer) Write(p []byte) (n int, e error) { - n = 0 - for len(p) > 0 { - fill := bufSize - len(w.buf) - if fill > len(p) { - fill = len(p) - } - w.buf = append(w.buf, p[:fill]...) - p = p[fill:] - n += fill - if len(w.buf) >= bufSize { - if e = w.flush(); e != nil { - return n, fmt.Errorf("writer flush failed: %v", e) - } - } - } - return n, nil -} - -// Close the connection of the sqlfs -func (w *Writer) Close() error { - if e := w.flush(); e != nil { - return fmt.Errorf("close failed: %v", e) - } - w.db = nil // mark closed - return nil -} - -func (w *Writer) flush() error { - if w.db == nil { - return fmt.Errorf("bad database connection") - } - - if len(w.buf) > 0 { - block := base64.StdEncoding.EncodeToString(w.buf) - query := fmt.Sprintf("INSERT INTO %s (id, block) VALUES(%d, '%s')", - w.table, w.flushID, block) - if _, e := w.db.Exec(query); e != nil { - return fmt.Errorf("flush to %s, error:%v", w.table, e) - } - w.buf = w.buf[:0] - w.flushID++ +func Create(db *sql.DB, dbms, table string, session *pb.Session) (io.WriteCloser, error) { + if dbms == "hive" { + return newHiveWriter(db, session.HiveLocation, table, session.HdfsUser, session.HdfsPass) } - return nil + return newSQLWriter(db, dbms, table) } From 5abf98ab8683197152081190c9c761093eaed111 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Sat, 14 Dec 2019 13:01:17 -0800 Subject: [PATCH 2/4] Separate 
table_test.go from sqlfile_test.go --- pkg/sqlfs/sqlfile_test.go | 11 ----------- pkg/sqlfs/table_test.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 11 deletions(-) create mode 100644 pkg/sqlfs/table_test.go diff --git a/pkg/sqlfs/sqlfile_test.go b/pkg/sqlfs/sqlfile_test.go index fce74431a8..34d332ace6 100644 --- a/pkg/sqlfs/sqlfile_test.go +++ b/pkg/sqlfs/sqlfile_test.go @@ -36,17 +36,6 @@ var ( const testDatabaseName = `sqlfs_test` -func TestCreateHasDropTable(t *testing.T) { - a := assert.New(t) - - fn := fmt.Sprintf("%s.unittest%d", testDatabaseName, rand.Int()) - a.NoError(createTable(testDB, testDriver, fn)) - has, e := hasTable(testDB, fn) - a.NoError(e) - a.True(has) - a.NoError(dropTable(testDB, fn)) -} - func TestWriterCreate(t *testing.T) { a := assert.New(t) diff --git a/pkg/sqlfs/table_test.go b/pkg/sqlfs/table_test.go new file mode 100644 index 0000000000..80ef3c1c78 --- /dev/null +++ b/pkg/sqlfs/table_test.go @@ -0,0 +1,33 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package sqlfs + +import ( + "fmt" + "math/rand" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCreateHasDropTable(t *testing.T) { + a := assert.New(t) + + fn := fmt.Sprintf("%s.unittest%d", testDatabaseName, rand.Int()) + a.NoError(createTable(testDB, testDriver, fn)) + has, e := hasTable(testDB, fn) + a.NoError(e) + a.True(has) + a.NoError(dropTable(testDB, fn)) +} From 8673ead8f6436d808bc759a2640bdf0ff0110f3d Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Sat, 14 Dec 2019 18:01:13 -0800 Subject: [PATCH 3/4] Update bufSize parameter --- pkg/sqlfs/hive_writer.go | 3 +-- pkg/sqlfs/sql_writer.go | 4 +--- pkg/sqlfs/sql_writer_test.go | 4 ++-- pkg/sqlfs/writer.go | 6 ++++-- 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/pkg/sqlfs/hive_writer.go b/pkg/sqlfs/hive_writer.go index 692e0d0562..22fc26a5e7 100644 --- a/pkg/sqlfs/hive_writer.go +++ b/pkg/sqlfs/hive_writer.go @@ -77,7 +77,7 @@ func uploadCSVFile(csv *os.File, db *sql.DB, hivePath, table, user, passwd strin } } -func newHiveWriter(db *sql.DB, hivePath, table, user, passwd string) (io.WriteCloser, error) { +func newHiveWriter(db *sql.DB, hivePath, table, user, passwd string, bufSize int) (io.WriteCloser, error) { if e := dropTable(db, table); e != nil { return nil, fmt.Errorf("cannot drop table %s: %v", table, e) } @@ -90,6 +90,5 @@ func newHiveWriter(db *sql.DB, hivePath, table, user, passwd string) (io.WriteCl return nil, e } upload := uploadCSVFile(csv, db, hivePath, table, user, passwd) - const bufSize = 32 * 1024 return newFlushWriteCloser(flush, upload, bufSize), nil } diff --git a/pkg/sqlfs/sql_writer.go b/pkg/sqlfs/sql_writer.go index 5ff42b5d32..05d665d82a 100644 --- a/pkg/sqlfs/sql_writer.go +++ b/pkg/sqlfs/sql_writer.go @@ -44,14 +44,12 @@ func noopWrapUp() error { return nil } -func newSQLWriter(db *sql.DB, dbms, table string) (io.WriteCloser, error) { +func newSQLWriter(db *sql.DB, dbms, table string, bufSize int) (io.WriteCloser, error) { if e := dropTable(db, 
table); e != nil { return nil, fmt.Errorf("cannot drop table %s: %v", table, e) } if e := createTable(db, dbms, table); e != nil { return nil, fmt.Errorf("cannot create table %s: %v", table, e) } - - const bufSize = 32 * 1024 return newFlushWriteCloser(flushToSQLTable(db, table), noopWrapUp, bufSize), nil } diff --git a/pkg/sqlfs/sql_writer_test.go b/pkg/sqlfs/sql_writer_test.go index df45802e6f..f45402a6ad 100644 --- a/pkg/sqlfs/sql_writer_test.go +++ b/pkg/sqlfs/sql_writer_test.go @@ -33,7 +33,7 @@ func TestNewSQLWriter(t *testing.T) { a := assert.New(t) tbl := fmt.Sprintf("%s.unittest%d", testDatabaseName, rand.Int()) - w, e := newSQLWriter(testDB, testDriver, tbl) + w, e := newSQLWriter(testDB, testDriver, tbl, bufSize) a.NoError(e) a.NotNil(w) defer w.Close() @@ -55,7 +55,7 @@ func TestSQLWriterWriteAndRead(t *testing.T) { a := assert.New(t) tbl := fmt.Sprintf("%s.unittest%d", testDatabaseName, rand.Int()) - w, e := newSQLWriter(testDB, testDriver, tbl) + w, e := newSQLWriter(testDB, testDriver, tbl, bufSize) a.NoError(e) a.NotNil(w) diff --git a/pkg/sqlfs/writer.go b/pkg/sqlfs/writer.go index 2752e599a9..746f896d58 100644 --- a/pkg/sqlfs/writer.go +++ b/pkg/sqlfs/writer.go @@ -20,11 +20,13 @@ import ( pb "sqlflow.org/sqlflow/pkg/proto" ) +const bufSize = 32 * 1024 + // Create creates a new table or truncates an existing table and // returns a writer. 
func Create(db *sql.DB, dbms, table string, session *pb.Session) (io.WriteCloser, error) { if dbms == "hive" { - return newHiveWriter(db, session.HiveLocation, table, session.HdfsUser, session.HdfsPass) + return newHiveWriter(db, session.HiveLocation, table, session.HdfsUser, session.HdfsPass, bufSize) } - return newSQLWriter(db, dbms, table) + return newSQLWriter(db, dbms, table, bufSize) } From 37565af6e4a7cfcf208664abc324769c60a80134 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Sat, 14 Dec 2019 21:18:00 -0800 Subject: [PATCH 4/4] Add bufSize parameter to newHiveWriter --- pkg/sqlfs/hive_writer_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/sqlfs/hive_writer_test.go b/pkg/sqlfs/hive_writer_test.go index d8ca1706f2..666334a45f 100644 --- a/pkg/sqlfs/hive_writer_test.go +++ b/pkg/sqlfs/hive_writer_test.go @@ -33,7 +33,7 @@ func TestNewHiveWriter(t *testing.T) { a := assert.New(t) tbl := fmt.Sprintf("%s%d", testDatabaseName, rand.Int()) - w, e := newHiveWriter(testDB, "/hivepath", tbl, "", "") + w, e := newHiveWriter(testDB, "/hivepath", tbl, "", "", bufSize) a.NoError(e) a.NotNil(w) defer w.Close() @@ -55,7 +55,7 @@ func TestHiveWriterWriteAndRead(t *testing.T) { a := assert.New(t) tbl := fmt.Sprintf("%s%d", testDatabaseName, rand.Int()) - w, e := newHiveWriter(testDB, "/hivepath", tbl, "", "") + w, e := newHiveWriter(testDB, "/hivepath", tbl, "", "", bufSize) a.NoError(e) a.NotNil(w)