
Remove type HiveWriter #1421

Merged
7 changes: 6 additions & 1 deletion pkg/sqlfs/flush_write_closer.go
@@ -13,7 +13,12 @@

package sqlfs

// flushWriteCloser implements io.WriteCloser.
// flushWriteCloser implements io.WriteCloser with two hooks: (1)
// flush, which is supposed to be called by Write when the internal
// buffer overflows, and (2) wrapup, which is to be called by Close.
// We need flushWriteCloser to implement the SQL writer and the Hive
// writer. For more details, please read sql_writer.go and
// hive_writer.go.
type flushWriteCloser struct {
buf []byte
flushes int // record the count of flushes.
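For context, a minimal sketch of how the two hooks might compose. Only buf and flushes appear in the visible struct; the flush and wrapup fields and the capacity-based overflow check are assumptions for illustration, not the actual implementation. Note that flushToCSV below returns a func([]byte) error, which is the hook shape this sketch assumes.

// A sketch, not the real code: Write fills buf and calls flush when
// the buffer is full; Close drains any remaining bytes, then wrapup.
func (w *flushWriteCloser) Write(p []byte) (n int, e error) {
	for len(p) > 0 {
		fill := cap(w.buf) - len(w.buf)
		if fill > len(p) {
			fill = len(p)
		}
		w.buf = append(w.buf, p[:fill]...)
		p = p[fill:]
		n += fill
		if len(w.buf) == cap(w.buf) {
			if e = w.flush(w.buf); e != nil { // assumed hook field
				return n, e
			}
			w.buf = w.buf[:0]
			w.flushes++
		}
	}
	return n, nil
}

func (w *flushWriteCloser) Close() error {
	if len(w.buf) > 0 {
		if e := w.flush(w.buf); e != nil { // assumed hook field
			return e
		}
	}
	return w.wrapup() // assumed hook field
}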
146 changes: 1 addition & 145 deletions pkg/sqlfs/hive_writer.go
@@ -22,8 +22,6 @@ import (
"os"
"os/exec"
"path"

pb "sqlflow.org/sqlflow/pkg/proto"
)

func flushToCSV() (func([]byte) error, *os.File, error) {
@@ -79,7 +77,7 @@ func uploadCSVFile(csv *os.File, db *sql.DB, hivePath, table, user, passwd strin
}
Collaborator: Add defer removeHDFSDir(w.hdfsPath(), hdfsEnv) at L67 to fix #1345. Maybe we can merge this PR, and I can do that in the next PR.

}
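For reference, a hedged sketch of where the suggested cleanup could go, reusing the HDFS helpers shown further down in this diff; the closure shape and the csvPath and hdfsPath names are assumptions, not the actual change.

// Hypothetical wrapup closure for the Hive writer: create the HDFS
// staging dir, defer its removal (the fix proposed for #1345), upload
// the local CSV, and load it into the table. LOAD DATA INPATH moves
// the uploaded file, so removing the dir afterwards only drops leftovers.
upload := func() error {
	hdfsEnv := hdfsEnvWithCredentical(user, passwd)
	if e := createHDFSDir(hdfsPath, hdfsEnv); e != nil {
		return e
	}
	defer removeHDFSDir(hdfsPath, hdfsEnv)
	if e := uploadFileToHDFS(csvPath, hdfsPath, hdfsEnv); e != nil {
		return e
	}
	return loadHDFSfileIntoTable(db, hdfsPath, table)
}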

func newHiveWriter(db *sql.DB, hivePath, table, user, passwd string) (io.WriteCloser, error) {
func newHiveWriter(db *sql.DB, hivePath, table, user, passwd string, bufSize int) (io.WriteCloser, error) {
if e := dropTable(db, table); e != nil {
return nil, fmt.Errorf("cannot drop table %s: %v", table, e)
}
@@ -92,147 +90,5 @@ func newHiveWriter(db *sql.DB, hivePath, table, user, passwd string) (io.WriteCl
return nil, e
}
upload := uploadCSVFile(csv, db, hivePath, table, user, passwd)
const bufSize = 32 * 1024
return newFlushWriteCloser(flush, upload, bufSize), nil
}

// HiveWriter implements io.WriteCloser.
type HiveWriter struct {
Writer
csvFile *os.File
session *pb.Session
}

// NewHiveWriter returns a Hive Writer object
func NewHiveWriter(db *sql.DB, table string, session *pb.Session) (*HiveWriter, error) {
csvFile, e := ioutil.TempFile("/tmp", "sqlflow-sqlfs")
if e != nil {
return nil, fmt.Errorf("create temporary csv file failed: %v", e)
}
return &HiveWriter{
Writer: Writer{
db: db,
table: table,
buf: make([]byte, 0, bufSize),
flushID: 0,
},
csvFile: csvFile,
session: session}, nil
}

// Write write bytes to sqlfs and returns (num_bytes, error)
func (w *HiveWriter) Write(p []byte) (n int, e error) {
n = 0
for len(p) > 0 {
fill := bufSize - len(w.buf)
if fill > len(p) {
fill = len(p)
}
w.buf = append(w.buf, p[:fill]...)
p = p[fill:]
n += fill
if len(w.buf) >= bufSize {
if e := w.flush(); e != nil {
return 0, e
}
}
}
return n, nil
}

func removeHDFSDir(hdfsPath string, hdfsEnv []string) error {
cmd := exec.Command("hdfs", "dfs", "-rm", "-r", "-f", hdfsPath)
cmd.Env = hdfsEnv
if out, err := cmd.CombinedOutput(); err != nil {
fmt.Println(string(out))
return err
}
return nil
}

func hdfsEnvWithCredentical(username, password string) []string {
hdfsEnv := os.Environ()
if username != "" {
hdfsEnv = append(hdfsEnv,
fmt.Sprintf("HADOOP_USER_NAME=%s", username),
fmt.Sprintf("HADOOP_USER_PASSWORD=%s", password))
}
return hdfsEnv
}

func createHDFSDir(hdfsPath string, hdfsEnv []string) error {
cmd := exec.Command("hdfs", "dfs", "-mkdir", "-p", hdfsPath)
cmd.Env = hdfsEnv
if _, err := cmd.CombinedOutput(); err != nil {
return err
}
return nil
}

func uploadFileToHDFS(localFilePath, hdfsPath string, hdfsEnv []string) error {
cmd := exec.Command("hdfs", "dfs", "-copyFromLocal", localFilePath, hdfsPath)
cmd.Env = hdfsEnv
if _, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("upload local file into hdfs error: %v", err)
}
return nil
}

func loadHDFSfileIntoTable(db *sql.DB, hdfsPath, table string) error {
query := fmt.Sprintf("LOAD DATA INPATH '%s' OVERWRITE INTO TABLE %s", hdfsPath, table)
if _, e := db.Exec(query); e != nil {
return fmt.Errorf("execute query: %s, error: %v", query, e)
}
return nil
}

// Close the connection of the sqlfs
func (w *HiveWriter) Close() error {
if w.db == nil {
return nil
}
defer func() {
w.csvFile.Close()
os.Remove(w.csvFile.Name())
w.db = nil
}()

if e := w.flush(); e != nil {
return e
}
hdfsEnv := hdfsEnvWithCredentical(w.session.HdfsUser, w.session.HdfsPass)

// 1. create a directory on HDFS
if err := createHDFSDir(w.hdfsPath(), hdfsEnv); err != nil {
return fmt.Errorf("create HDFDS dir: %s failed: %v", w.hdfsPath(), err)
}
defer removeHDFSDir(w.hdfsPath(), hdfsEnv)

// 2. upload the local csv file to the HDFS directory
if err := uploadFileToHDFS(w.csvFile.Name(), w.hdfsPath(), hdfsEnv); err != nil {
return fmt.Errorf("upload local file to hdfs failed: %v", err)
}

// 3. load hdfs files into hive table
if err := loadHDFSfileIntoTable(w.db, w.hdfsPath(), w.table); err != nil {
return fmt.Errorf("load hdfs file into table failed: %v", err)
}

return nil
}

func (w *HiveWriter) hdfsPath() string {
return fmt.Sprintf("%s/sqlfs/%s/", w.session.HiveLocation, w.table)
}

func (w *HiveWriter) flush() error {
if len(w.buf) > 0 {
block := base64.StdEncoding.EncodeToString(w.buf)
if _, e := w.csvFile.Write([]byte(fmt.Sprintf("%d\001%s\n", w.flushID, block))); e != nil {
return fmt.Errorf("flush error, %v", e)
}
w.buf = w.buf[:0]
w.flushID++
}
return nil
}
4 changes: 2 additions & 2 deletions pkg/sqlfs/hive_writer_test.go
@@ -33,7 +33,7 @@ func TestNewHiveWriter(t *testing.T) {
a := assert.New(t)

tbl := fmt.Sprintf("%s%d", testDatabaseName, rand.Int())
w, e := newHiveWriter(testDB, "/hivepath", tbl, "", "")
w, e := newHiveWriter(testDB, "/hivepath", tbl, "", "", bufSize)
a.NoError(e)
a.NotNil(w)
defer w.Close()
@@ -55,7 +55,7 @@ func TestHiveWriterWriteAndRead(t *testing.T) {
a := assert.New(t)

tbl := fmt.Sprintf("%s%d", testDatabaseName, rand.Int())
w, e := newHiveWriter(testDB, "/hivepath", tbl, "", "")
w, e := newHiveWriter(testDB, "/hivepath", tbl, "", "", bufSize)
a.NoError(e)
a.NotNil(w)

6 changes: 2 additions & 4 deletions pkg/sqlfs/sql_writer.go
@@ -44,14 +44,12 @@ func noopWrapUp() error {
return nil
}

func newSQLWriter(db *sql.DB, driver, table string) (io.WriteCloser, error) {
func newSQLWriter(db *sql.DB, dbms, table string, bufSize int) (io.WriteCloser, error) {
if e := dropTable(db, table); e != nil {
return nil, fmt.Errorf("cannot drop table %s: %v", table, e)
}
if e := createTable(db, driver, table); e != nil {
if e := createTable(db, dbms, table); e != nil {
return nil, fmt.Errorf("cannot create table %s: %v", table, e)
}

const bufSize = 32 * 1024
return newFlushWriteCloser(flushToSQLTable(db, table), noopWrapUp, bufSize), nil
}
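With the 32*1024 constant lifted out of both constructors, callers now choose the buffer size explicitly. A minimal usage sketch under assumed names (db is an open *sql.DB, data the bytes to store):

// writeBlob is a hypothetical caller; 32*1024 matches the value that
// used to be hard-coded inside newSQLWriter and newHiveWriter.
func writeBlob(db *sql.DB, table string, data []byte) error {
	const bufSize = 32 * 1024
	w, e := newSQLWriter(db, "mysql", table, bufSize)
	if e != nil {
		return e
	}
	defer w.Close()
	_, e = w.Write(data)
	return e
}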
4 changes: 2 additions & 2 deletions pkg/sqlfs/sql_writer_test.go
@@ -33,7 +33,7 @@ func TestNewSQLWriter(t *testing.T) {
a := assert.New(t)

tbl := fmt.Sprintf("%s.unittest%d", testDatabaseName, rand.Int())
w, e := newSQLWriter(testDB, testDriver, tbl)
w, e := newSQLWriter(testDB, testDriver, tbl, bufSize)
a.NoError(e)
a.NotNil(w)
defer w.Close()
@@ -55,7 +55,7 @@ func TestSQLWriterWriteAndRead(t *testing.T) {
a := assert.New(t)

tbl := fmt.Sprintf("%s.unittest%d", testDatabaseName, rand.Int())
w, e := newSQLWriter(testDB, testDriver, tbl)
w, e := newSQLWriter(testDB, testDriver, tbl, bufSize)
a.NoError(e)
a.NotNil(w)

13 changes: 2 additions & 11 deletions pkg/sqlfs/sqlfile_test.go
@@ -36,17 +36,6 @@ var (

const testDatabaseName = `sqlfs_test`

func TestCreateHasDropTable(t *testing.T) {
a := assert.New(t)

fn := fmt.Sprintf("%s.unittest%d", testDatabaseName, rand.Int())
a.NoError(createTable(testDB, testDriver, fn))
has, e := hasTable(testDB, fn)
a.NoError(e)
a.True(has)
a.NoError(dropTable(testDB, fn))
}

func TestWriterCreate(t *testing.T) {
a := assert.New(t)

@@ -64,6 +53,8 @@ func TestWriterCreate(t *testing.T) {
}

func TestWriteAndRead(t *testing.T) {
const bufSize = 32 * 1024

testDriver = getEnv("SQLFLOW_TEST_DB", "mysql")
a := assert.New(t)

10 changes: 5 additions & 5 deletions pkg/sqlfs/table.go
@@ -21,18 +21,18 @@ import (
// createTable creates a table, if it doesn't exist. If the table
// name includes the database name, e.g., "db.tbl", it creates the
// database if necessary.
func createTable(db *sql.DB, driver, table string) error {
func createTable(db *sql.DB, dbms, table string) error {
// HIVE and ODPS don't support AUTO_INCREMENT
// Hive and ODPS don't support BLOB, use BINARY instead
var stmt string
if driver == "mysql" {
if dbms == "mysql" {
stmt = fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id INT, block TEXT, PRIMARY KEY (id))", table)
} else if driver == "hive" {
} else if dbms == "hive" {
stmt = fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id INT, block STRING) ROW FORMAT DELIMITED FIELDS TERMINATED BY \"\\001\" STORED AS TEXTFILE", table)
} else if driver == "maxcompute" {
} else if dbms == "maxcompute" {
stmt = fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id INT, block STRING)", table)
} else {
return fmt.Errorf("createTable doesn't recognize driver %s", driver)
return fmt.Errorf("createTable doesn't recognize dbms %s", dbms)
}
if _, e := db.Exec(stmt); e != nil {
return fmt.Errorf("exec:[%s] failed: %v", stmt, e)
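For reference, the \001 field delimiter in the Hive DDL above matches the row encoding the writers produce: each flushed block becomes one row pairing a block id with a base64-encoded payload, as in the flush method removed from hive_writer.go earlier in this diff. A sketch of that encoding:

// encodeRow mirrors the removed flush(): id, \001 separator, base64 block.
func encodeRow(flushID int, buf []byte) string {
	block := base64.StdEncoding.EncodeToString(buf)
	return fmt.Sprintf("%d\001%s\n", flushID, block)
}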
33 changes: 33 additions & 0 deletions pkg/sqlfs/table_test.go
@@ -0,0 +1,33 @@
// Copyright 2019 The SQLFlow Authors. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sqlfs

import (
"fmt"
"math/rand"
"testing"

"github.com/stretchr/testify/assert"
)

func TestCreateHasDropTable(t *testing.T) {
a := assert.New(t)

fn := fmt.Sprintf("%s.unittest%d", testDatabaseName, rand.Int())
a.NoError(createTable(testDB, testDriver, fn))
has, e := hasTable(testDB, fn)
a.NoError(e)
a.True(has)
a.NoError(dropTable(testDB, fn))
}