Formatting
zachmu committed Jan 3, 2025
1 parent 682352d commit 0e6b767
Showing 11 changed files with 65 additions and 63 deletions.
10 changes: 5 additions & 5 deletions core/dataloader/csvdataloader.go
@@ -172,7 +172,7 @@ func (cdl *CsvDataLoader) IsReadOnly() bool {
}

type csvRowIter struct {
-cdl *CsvDataLoader
+cdl    *CsvDataLoader
reader *csvReader
}

@@ -181,14 +181,14 @@ func (c csvRowIter) Next(ctx *sql.Context) (sql.Row, error) {
if err != nil {
return nil, err
}

// TODO: this isn't the best way to handle the count of rows, something like a RowUpdateAccumulator would be better
if hasNext {
c.cdl.results.RowsLoaded++
} else {
return nil, io.EOF
}

return row, nil
}

@@ -206,6 +206,6 @@ func (cdl *CsvDataLoader) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, err
if err != nil {
return nil, err
}

return &csvRowIter{cdl: cdl, reader: csvReader}, nil
-}
+}
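
The csvRowIter above follows the go-mysql-server row-iterator contract: each successful Next tallies a loaded row, and exhaustion is signaled with io.EOF rather than a real error. A minimal, self-contained sketch of that shape (the slice-backed iterator and results struct are hypothetical stand-ins, not the Doltgres types):

```go
package main

import (
	"fmt"
	"io"
)

// results mirrors the RowsLoaded bookkeeping on CsvDataLoader (simplified stand-in).
type results struct {
	RowsLoaded int
}

// sliceRowIter is a hypothetical iterator over in-memory rows following the same
// contract as csvRowIter.Next: count each row, signal exhaustion with io.EOF.
type sliceRowIter struct {
	rows [][]string
	pos  int
	res  *results
}

func (it *sliceRowIter) Next() ([]string, error) {
	if it.pos >= len(it.rows) {
		return nil, io.EOF // exhausted: callers treat io.EOF as "done", not a failure
	}
	row := it.rows[it.pos]
	it.pos++
	it.res.RowsLoaded++ // tally rows as they stream through, as csvRowIter does
	return row, nil
}

func main() {
	res := &results{}
	it := &sliceRowIter{rows: [][]string{{"1", "ls"}, {"2", "bash"}}, res: res}
	for {
		row, err := it.Next()
		if err == io.EOF {
			break
		}
		fmt.Println(row)
	}
	fmt.Println("rows loaded:", res.RowsLoaded) // rows loaded: 2
}
```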
6 changes: 3 additions & 3 deletions core/dataloader/dataloader.go
@@ -29,13 +29,13 @@ import (
// with the incomplete record.
type DataLoader interface {
sql.ExecSourceRel

// SetNextDataChunk sets the next data chunk to be processed by the DataLoader. Data records
// are not guaranteed to start and end cleanly on chunk boundaries, so implementations must recognize incomplete
// records and save them to prepend on the next processed chunk.
SetNextDataChunk(ctx *sql.Context, data *bufio.Reader) error

-// Finish finalizes the current load operation and cleans up any resources used. Implementations should check that
+// Finish finalizes the current load operation and cleans up any resources used. Implementations should check that
// the last call to LoadChunk did not end with an incomplete record and return an error to the caller if so. The
// returned LoadDataResults describe the load operation, including how many rows were inserted.
Finish(ctx *sql.Context) (*LoadDataResults, error)
@@ -64,7 +64,7 @@ func getColumnTypes(colNames []string, sch sql.Schema) ([]*types.DoltgresType, s
if !ok {
return nil, nil, fmt.Errorf("unsupported column type: name: %s, type: %T", col.Name, col.Type)
}

reducedSch[i] = col
}

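
The chunk-boundary caveat in the DataLoader docs is the crux of the interface: a data chunk may end mid-record, so implementations must save the incomplete tail and prepend it to the next chunk, and Finish must fail if a partial record is left over. A toy line-oriented loader illustrating that buffering (the string-based API is a hypothetical simplification; the real interface takes a *bufio.Reader and a sql.Context):

```go
package main

import (
	"fmt"
	"strings"
)

// chunkLoader is a hypothetical line-oriented loader demonstrating partial-record
// buffering: the unterminated tail of each chunk is carried into the next one.
type chunkLoader struct {
	partial string   // incomplete trailing record from the previous chunk
	records []string // completed records
}

func (l *chunkLoader) SetNextDataChunk(chunk string) {
	data := l.partial + chunk
	lines := strings.Split(data, "\n")
	// The last element is either "" (chunk ended on a record boundary) or an incomplete record.
	l.partial = lines[len(lines)-1]
	l.records = append(l.records, lines[:len(lines)-1]...)
}

func (l *chunkLoader) Finish() ([]string, error) {
	if l.partial != "" {
		return nil, fmt.Errorf("incomplete record at end of load: %q", l.partial)
	}
	return l.records, nil
}

func main() {
	l := &chunkLoader{}
	l.SetNextDataChunk("1,ls\n2,ba") // second record is split across chunks
	l.SetNextDataChunk("sh\n")
	recs, err := l.Finish()
	fmt.Println(recs, err) // [1,ls 2,bash] <nil>
}
```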
15 changes: 8 additions & 7 deletions core/dataloader/tabdataloader.go
@@ -20,8 +20,9 @@ import (
"io"
"strings"

"github.com/dolthub/doltgresql/server/types"
"github.com/dolthub/go-mysql-server/sql"

"github.com/dolthub/doltgresql/server/types"
)

const defaultTextDelimiter = "\t"
@@ -49,7 +50,7 @@ func NewTabularDataLoader(colNames []string, tableSch sql.Schema, delimiterChar,
if err != nil {
return nil, err
}

if delimiterChar == "" {
delimiterChar = defaultTextDelimiter
}
Expand Down Expand Up @@ -113,15 +114,15 @@ func (tdl *TabularDataLoader) nextRow(ctx *sql.Context, data *bufio.Reader) (sql
if len(line) == 0 {
continue
}

// Split the values by the delimiter, ensuring the correct number of values have been read
values := strings.Split(line, tdl.delimiterChar)
if len(values) > len(tdl.colTypes) {
return nil, false, fmt.Errorf("extra data after last expected column")
} else if len(values) < len(tdl.colTypes) {
return nil, false, fmt.Errorf(`missing data for column "%s"`, tdl.sch[len(values)].Name)
}

// Cast the values using I/O input
row := make(sql.Row, len(tdl.colTypes))
for i := range tdl.colTypes {
@@ -134,7 +135,7 @@
}
}
}

return row, true, nil
}
}
@@ -182,7 +183,7 @@ func (tdl *TabularDataLoader) IsReadOnly() bool {
}

type tabularRowIter struct {
-tdl *TabularDataLoader
+tdl    *TabularDataLoader
reader *bufio.Reader
}

@@ -208,4 +209,4 @@ func (t tabularRowIter) Close(context *sql.Context) error {

func (tdl *TabularDataLoader) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, error) {
return &tabularRowIter{tdl: tdl, reader: tdl.nextDataChunk}, nil
-}
+}
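
The field-count validation in nextRow mirrors Postgres COPY: after splitting on the delimiter, extra fields and missing fields produce distinct errors, and the first absent column is the one named. A standalone sketch of that check (parseTabularLine is a hypothetical helper, not the actual method):

```go
package main

import (
	"fmt"
	"strings"
)

// parseTabularLine splits a line on the delimiter and validates the field count,
// returning errors modeled on the COPY messages quoted in the diff above.
func parseTabularLine(line, delimiter string, colNames []string) ([]string, error) {
	values := strings.Split(line, delimiter)
	if len(values) > len(colNames) {
		return nil, fmt.Errorf("extra data after last expected column")
	} else if len(values) < len(colNames) {
		// The first column without a value is the one reported missing.
		return nil, fmt.Errorf("missing data for column %q", colNames[len(values)])
	}
	return values, nil
}

func main() {
	cols := []string{"id", "name", "shell"}
	fmt.Println(parseTabularLine("1\talice\tbash", "\t", cols)) // [1 alice bash] <nil>
	fmt.Println(parseTabularLine("1\talice", "\t", cols))       // missing data for column "shell"
}
```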
12 changes: 6 additions & 6 deletions server/analyzer/assign_insert_casts.go
@@ -18,7 +18,6 @@ import (
"fmt"
"strings"

"github.com/dolthub/doltgresql/server/node"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/analyzer"
"github.com/dolthub/go-mysql-server/sql/expression"
@@ -27,6 +26,7 @@
"github.com/dolthub/go-mysql-server/sql/types"

pgexprs "github.com/dolthub/doltgresql/server/expression"
"github.com/dolthub/doltgresql/server/node"
pgtypes "github.com/dolthub/doltgresql/server/types"
)

@@ -36,12 +36,12 @@ func AssignInsertCasts(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, sc
if !ok {
return node, transform.SameTree, nil
}

// We have some sources that are already postgres native, so skip them
if isDoltgresNativeSource(insertInto.Destination, insertInto.Source) {
return insertInto, transform.SameTree, nil
}

// First we'll make a map for each column, so that it's easier to match a name to a type. We also ensure that the
// types use Doltgres types, as casts rely on them. At this point, we shouldn't have any GMS types floating around
// anymore, so no need to include a lot of additional code to handle them.
@@ -53,7 +53,7 @@ func AssignInsertCasts(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, sc
}
destinationNameToType[strings.ToLower(col.Name)] = colType
}

// Create the destination type slice that will match each inserted column
destinationTypes := make([]*pgtypes.DoltgresType, len(insertInto.ColumnNames))
for i, colName := range insertInto.ColumnNames {
@@ -62,7 +62,7 @@ func AssignInsertCasts(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, sc
return nil, transform.NewTree, fmt.Errorf("INSERT: cannot find destination column with name `%s`", colName)
}
}

// Replace expressions with casts as needed
if values, ok := insertInto.Source.(*plan.Values); ok {
// Values do not return the correct Schema since each row may contain different types, so we must handle it differently
@@ -134,7 +134,7 @@ func isDoltgresNativeSource(dest sql.Node, source sql.Node) bool {
return false
}
}

switch source.(type) {
case *node.CopyFrom:
return true
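
AssignInsertCasts proceeds in two passes visible in the hunks above: build a lowercased name-to-type map for the destination schema, then resolve each inserted column against that map, erroring on unknown names, before wrapping mismatched source expressions in casts. A simplified sketch of the lookup passes, with strings standing in for *pgtypes.DoltgresType:

```go
package main

import (
	"fmt"
	"strings"
)

// column is a simplified stand-in for a schema column; the real code maps
// names to *pgtypes.DoltgresType and wraps expressions in assignment casts.
type column struct{ Name, Type string }

// castTargets returns the destination type for each inserted column, mirroring
// the name-to-type map built in AssignInsertCasts.
func castTargets(destSchema []column, insertedCols []string) ([]string, error) {
	destNameToType := make(map[string]string, len(destSchema))
	for _, col := range destSchema {
		destNameToType[strings.ToLower(col.Name)] = col.Type
	}
	targets := make([]string, len(insertedCols))
	for i, name := range insertedCols {
		t, ok := destNameToType[strings.ToLower(name)]
		if !ok {
			return nil, fmt.Errorf("INSERT: cannot find destination column with name `%s`", name)
		}
		targets[i] = t
	}
	return targets, nil
}

func main() {
	schema := []column{{"id", "int8"}, {"name", "text"}}
	fmt.Println(castTargets(schema, []string{"NAME", "id"})) // [text int8] <nil>
}
```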
2 changes: 1 addition & 1 deletion server/analyzer/type_sanitizer.go
@@ -72,7 +72,7 @@ func TypeSanitizer(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, scope
return nil, transform.NewTree, fmt.Errorf("default values must have a non-GMS OutType: `%s`", expr.OutType.String())
}
if !outType.Equals(defaultExprType) {
-// TODO (next): this
+// TODO (next): this
defaultExpr = pgexprs.NewAssignmentCast(defaultExpr, defaultExprType, outType)
}
newDefault, err := sql.NewColumnDefaultValue(defaultExpr, outType, expr.Literal, expr.Parenthesized, expr.ReturnNil)
10 changes: 5 additions & 5 deletions server/ast/copy_from.go
@@ -33,8 +33,8 @@ func nodeCopyFrom(ctx *Context, node *tree.CopyFrom) (vitess.Statement, error) {
return nil, fmt.Errorf("COPY FROM does not support format BINARY")
}

-// We start by creating a stub insert statement for the COPY FROM statement, which we will use to build a basic
-// INSERT plan for. At runtime we will swap out the bogus row values for our actual data read from STDIN.
+// We start by creating a stub insert statement for the COPY FROM statement, which we will use to build a basic
+// INSERT plan for. At runtime we will swap out the bogus row values for our actual data read from STDIN.
var columns []vitess.ColIdent
if len(node.Columns) > 0 {
columns = make([]vitess.ColIdent, len(node.Columns))
@@ -53,7 +53,7 @@ func nodeCopyFrom(ctx *Context, node *tree.CopyFrom) (vitess.Statement, error) {
for i := range columns {
stubValues[0][i] = &vitess.NullVal{}
}

return vitess.InjectedStatement{
Statement: pgnodes.NewCopyFrom(
node.Table.Catalog(),
@@ -69,8 +69,8 @@ func nodeCopyFrom(ctx *Context, node *tree.CopyFrom) (vitess.Statement, error) {
Action: vitess.InsertStr,
Table: tableName,
Columns: columns,
Rows: &vitess.AliasedValues{
Values: stubValues,
Rows: &vitess.AliasedValues{
Values: stubValues,
},
},
),
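
The stub-insert trick in nodeCopyFrom is worth spelling out: the planner needs a syntactically complete INSERT to bind a plan against, so the AST gets a single VALUES row of NULL placeholders, one per column, which the COPY pipeline swaps for real rows at execution time. A tiny sketch of the placeholder construction (buildStubRow is hypothetical; the real code fills a vitess values row with &vitess.NullVal{} entries):

```go
package main

import "fmt"

// buildStubRow produces one placeholder row sized to the column list, enough
// for the planner to bind an INSERT plan that never actually inserts NULLs.
func buildStubRow(numCols int) []any {
	row := make([]any, numCols)
	for i := range row {
		row[i] = nil // stands in for &vitess.NullVal{} in the real code
	}
	return row
}

func main() {
	fmt.Println(buildStubRow(3)) // [<nil> <nil> <nil>]
}
```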
22 changes: 11 additions & 11 deletions server/connection_handler.go
@@ -29,9 +29,6 @@ import (

"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/doltcore/sqlserver"
"github.com/dolthub/doltgresql/core/dataloader"
psql "github.com/dolthub/doltgresql/postgres/parser/parser/sql"
"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
"github.com/dolthub/go-mysql-server/server"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/plan"
@@ -44,7 +41,10 @@ import (
"github.com/mitchellh/go-ps"
"github.com/sirupsen/logrus"

"github.com/dolthub/doltgresql/core/dataloader"
"github.com/dolthub/doltgresql/postgres/parser/parser"
psql "github.com/dolthub/doltgresql/postgres/parser/parser/sql"
"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
"github.com/dolthub/doltgresql/server/ast"
"github.com/dolthub/doltgresql/server/node"
)
@@ -651,15 +651,15 @@ func (h *ConnectionHandler) copyFromFileQuery(stmt *node.CopyFrom) error {
copyState := &copyFromStdinState{
copyFromStdinNode: stmt,
}

// TODO: security check for file path
// TODO: Privilege Checking: https://www.postgresql.org/docs/15/sql-copy.html
f, err := os.Open(stmt.File)
if err != nil {
return err
}
defer f.Close()

_, _, err = h.handleCopyDataHelper(copyState, f)
if err != nil {
return err
@@ -693,7 +693,7 @@ func (h *ConnectionHandler) handleCopyDataHelper(copyState *copyFromStdinState,
if copyFromStdinNode == nil {
return false, false, fmt.Errorf("no COPY FROM STDIN node found")
}

// we build an insert node to use for the full insert plan, for which the copy from node will be the row source
builder := planbuilder.New(sqlCtx, h.doltgresHandler.e.Analyzer.Catalog, nil, psql.NewPostgresParser())
node, flags, err := builder.BindOnly(copyFromStdinNode.InsertStub, "", nil)
@@ -729,7 +729,7 @@
return false, false, err
}

-// we have to set the data loader on the copyFrom node before we analyze it, because we need the loader's
+// we have to set the data loader on the copyFrom node before we analyze it, because we need the loader's
// schema to analyze
copyState.copyFromStdinNode.DataLoader = dataLoader

@@ -739,7 +739,7 @@ func (h *ConnectionHandler) handleCopyDataHelper(copyState *copyFromStdinState,
if err != nil {
return false, false, err
}

copyState.insertNode = analyzedNode
copyState.dataLoader = dataLoader
}
@@ -754,7 +754,7 @@ func (h *ConnectionHandler) handleCopyDataHelper(copyState *copyFromStdinState,
if err != nil {
return false, false, err
}

// We expect to see more CopyData messages until we see either a CopyDone or CopyFail message, so
// return false for endOfMessages
return false, false, nil
@@ -772,7 +772,7 @@ func getInsertableTable(node sql.Node) sql.InsertableTable {
}
return true
})

return tbl
}

@@ -863,7 +863,7 @@ func startTransactionIfNecessary(ctx *sql.Context) error {
if _, err := doltSession.StartTransaction(ctx, sql.ReadWrite); err != nil {
return err
}

// When we start a transaction ourselves, we must ignore auto-commit settings for transaction
ctx.SetIgnoreAutoCommit(true)
}
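
copyFromFileQuery and handleCopyDataHelper together form a pull loop: open the source, hand it to the loader chunk by chunk, and stop cleanly at EOF. A self-contained sketch of that loop (processCopySource and the chunk size are hypothetical; the real handler also binds and executes the INSERT plan built from the stub):

```go
package main

import (
	"bufio"
	"fmt"
	"io"
	"strings"
)

// processCopySource reads the source in fixed-size chunks and hands each one to
// the supplied handler until EOF, mirroring the COPY data flow sketched above.
func processCopySource(r io.Reader, handleChunk func([]byte) error) error {
	br := bufio.NewReader(r)
	buf := make([]byte, 8) // tiny chunk size, chosen to force a record split in the demo
	for {
		n, err := br.Read(buf)
		if n > 0 {
			if herr := handleChunk(buf[:n]); herr != nil {
				return herr
			}
		}
		if err == io.EOF {
			return nil // end of source: the caller then finalizes the load
		}
		if err != nil {
			return err
		}
	}
}

func main() {
	src := strings.NewReader("1,ls\n2,bash\n") // stands in for the COPY FROM file/stream
	err := processCopySource(src, func(chunk []byte) error {
		fmt.Printf("chunk: %q\n", chunk)
		return nil
	})
	fmt.Println("err:", err)
}
```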
25 changes: 13 additions & 12 deletions server/node/copy_from.go
@@ -18,10 +18,11 @@ import (
"fmt"

"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/doltgresql/core/dataloader"
"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
"github.com/dolthub/go-mysql-server/sql"
vitess "github.com/dolthub/vitess/go/vt/sqlparser"

"github.com/dolthub/doltgresql/core/dataloader"
"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
)

// CopyFrom handles the COPY ... FROM ... statement.
Expand All @@ -41,13 +42,13 @@ var _ sql.ExecSourceRel = (*CopyFrom)(nil)

// NewCopyFrom returns a new *CopyFrom.
func NewCopyFrom(
-databaseName string,
-tableName doltdb.TableName,
-options tree.CopyOptions,
-fileName string,
-stdin bool,
-columns tree.NameList,
-insertStub *vitess.Insert,
+databaseName string,
+tableName doltdb.TableName,
+options tree.CopyOptions,
+fileName string,
+stdin bool,
+columns tree.NameList,
+insertStub *vitess.Insert,
) *CopyFrom {
switch options.CopyFormat {
case tree.CopyFormatCsv, tree.CopyFormatText:
@@ -65,7 +66,7 @@ func NewCopyFrom(
Stdin: stdin,
Columns: columns,
CopyOptions: options,
-InsertStub: insertStub,
+InsertStub:  insertStub,
}
}

@@ -91,7 +92,7 @@ func (cf *CopyFrom) RowIter(ctx *sql.Context, r sql.Row) (_ sql.RowIter, err err

// Schema implements the interface sql.ExecSourceRel.
func (cf *CopyFrom) Schema() sql.Schema {
-// For Parse calls, we need access to the schema before we have a DataLoader created, so return a stub schema.
+// For Parse calls, we need access to the schema before we have a DataLoader created, so return a stub schema.
if cf.DataLoader == nil {
return nil
}
@@ -121,4 +122,4 @@ func (cf *CopyFrom) WithResolvedChildren(children []any) (any, error) {
return nil, ErrVitessChildCount.New(0, len(children))
}
return cf, nil
-}
+}
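
NewCopyFrom gates on the copy format up front, accepting CSV and text before any data flows; the handling of other formats is folded out of the diff, but the AST layer rejects BINARY outright. A minimal sketch of that gate (the enum and names are hypothetical; the real code switches on tree.CopyFormat values):

```go
package main

import "fmt"

// copyFormat is a hypothetical stand-in for tree.CopyFormat.
type copyFormat int

const (
	formatText copyFormat = iota
	formatCSV
	formatBinary
)

// validateCopyFormat accepts CSV and text, rejecting binary with the error
// string the AST layer uses in this diff.
func validateCopyFormat(f copyFormat) error {
	switch f {
	case formatCSV, formatText:
		return nil
	default:
		return fmt.Errorf("COPY FROM does not support format BINARY")
	}
}

func main() {
	fmt.Println(validateCopyFormat(formatCSV))    // <nil>
	fmt.Println(validateCopyFormat(formatBinary)) // COPY FROM does not support format BINARY
}
```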
2 changes: 1 addition & 1 deletion testing/dataloader/csvdataloader_test.go
@@ -162,7 +162,7 @@ func TestCsvDataLoader(t *testing.T) {
{int64(2), int64(200), "bash"},
}, rows)
})

// Tests when a PSV (i.e. delimiter='|') record is split across two chunks of data,
// and a header row is present.
t.Run("delimiter='|', record split across two chunks, with header", func(t *testing.T) {