Skip to content

Commit

Permalink
feat: Implement TransformSchema support. (#1838)
Browse files Browse the repository at this point in the history
As discussed offline, Transformer plugins are going to need to explicitly provide a synchronous function to transform a schema.

Previously, I implemented an approach that made this unnecessary by reusing the existing transform function. The problem is that we cannot assume the transform function will not swallow the initial message of a given table, or that it will behave properly when given an empty record that carries only a schema.
  • Loading branch information
marianogappa authored Jul 31, 2024
1 parent d2c5c7b commit 30875d6
Show file tree
Hide file tree
Showing 9 changed files with 105 additions and 3 deletions.
5 changes: 5 additions & 0 deletions examples/simple_plugin/plugin/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ func (*Client) Transform(_ context.Context, _ <-chan arrow.Record, _ chan<- arro
return nil
}

// TransformSchema is a stub: it ignores its arguments and returns a nil schema
// with a nil error. It is not a real implementation and exists only so this
// example client can be used for testing destination packaging.
func (*Client) TransformSchema(context.Context, *arrow.Schema) (*arrow.Schema, error) {
	return nil, nil
}

func Configure(_ context.Context, logger zerolog.Logger, spec []byte, opts plugin.NewClientOptions) (plugin.Client, error) {
if opts.NoConnection {
return &Client{
Expand Down
4 changes: 4 additions & 0 deletions internal/memdb/memdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,10 @@ func (*client) Transform(_ context.Context, _ <-chan arrow.Record, _ chan<- arro
return nil
}

// TransformSchema is a no-op stub: both arguments are ignored and a nil schema
// is returned together with a nil error.
func (*client) TransformSchema(context.Context, *arrow.Schema) (*arrow.Schema, error) {
	return nil, nil
}

func evaluatePredicate(pred message.Predicate, record arrow.Record) bool {
sc := record.Schema()
indices := sc.FieldIndices(pred.Column)
Expand Down
4 changes: 4 additions & 0 deletions internal/reversertransformer/reversertransformer.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ func (c *client) Transform(ctx context.Context, recvRecords <-chan arrow.Record,
}
}

// TransformSchema returns the schema it is given, unchanged, with a nil error.
// The context is ignored.
func (*client) TransformSchema(_ context.Context, old *arrow.Schema) (*arrow.Schema, error) {
return old, nil
}

func (*client) reverseStrings(record arrow.Record) (arrow.Record, error) {
for i, column := range record.Columns() {
if column.DataType().ID() != arrow.STRING {
Expand Down
4 changes: 1 addition & 3 deletions internal/reversertransformer/reversertransformer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ import (
"google.golang.org/grpc/metadata"
)

var mem = memory.NewGoAllocator()

func TestReverserTransformer(t *testing.T) {
p := plugin.NewPlugin("test", "development", GetNewClient())
s := internalPlugin.Server{
Expand Down Expand Up @@ -58,7 +56,7 @@ func makeRequestFromString(s string) *pb.Transform_Request {
}

func makeRecordFromString(s string) arrow.Record {
str := array.NewStringBuilder(mem)
str := array.NewStringBuilder(memory.DefaultAllocator)
str.AppendString(s)
arr := str.NewStringArray()
schema := arrow.NewSchema([]arrow.Field{{Name: "col1", Type: arrow.BinaryTypes.String}}, nil)
Expand Down
16 changes: 16 additions & 0 deletions internal/servers/plugin/v3/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,22 @@ func (s *Server) Transform(stream pb.Plugin_TransformServer) error {
return eg.Wait()
}

// TransformSchema handles the TransformSchema RPC. It decodes the schema bytes
// carried by req, passes the decoded schema to the plugin's TransformSchema,
// and returns the resulting schema re-encoded as bytes.
//
// Failures are reported as gRPC statuses:
//   - codes.InvalidArgument when the request bytes cannot be decoded;
//   - codes.Internal when the plugin returns an error, returns no schema, or
//     the transformed schema cannot be encoded.
func (s *Server) TransformSchema(ctx context.Context, req *pb.TransformSchema_Request) (*pb.TransformSchema_Response, error) {
	sc, err := pb.NewSchemaFromBytes(req.Schema)
	if err != nil {
		return nil, status.Errorf(codes.InvalidArgument, "failed to create schema from bytes: %v", err)
	}

	newSchema, err := s.Plugin.TransformSchema(ctx, sc)
	if err != nil {
		return nil, status.Errorf(codes.Internal, "failed to transform schema: %v", err)
	}
	// A client may return (nil, nil), e.g. a stub implementation. Report that
	// explicitly rather than handing a nil schema to the encoder.
	if newSchema == nil {
		return nil, status.Error(codes.Internal, "failed to transform schema: plugin returned a nil schema")
	}

	encoded, err := pb.SchemaToBytes(newSchema)
	if err != nil {
		return nil, status.Errorf(codes.Internal, "failed to encode schema: %v", err)
	}
	return &pb.TransformSchema_Response{Schema: encoded}, nil
}

// Close handles the Close RPC by closing the underlying plugin. The request is
// ignored. An empty, non-nil response is always returned alongside whatever
// error the plugin's Close reports.
func (s *Server) Close(ctx context.Context, _ *pb.Close_Request) (*pb.Close_Response, error) {
	resp := &pb.Close_Response{}
	err := s.Plugin.Close(ctx)
	return resp, err
}
65 changes: 65 additions & 0 deletions internal/servers/plugin/v3/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"github.com/cloudquery/plugin-sdk/v4/internal/memdb"
"github.com/cloudquery/plugin-sdk/v4/plugin"
"github.com/cloudquery/plugin-sdk/v4/schema"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
Expand Down Expand Up @@ -183,3 +185,66 @@ func TestPluginSync(t *testing.T) {
t.Fatal(err)
}
}

// TestTransformSchema checks that Server.TransformSchema decodes the request
// schema, runs it through the plugin, and returns the encoded result. The
// plugin under test adds a "source" string column, so a one-column input
// schema must come back with two columns.
func TestTransformSchema(t *testing.T) {
	ctx := context.Background()
	s := Server{
		Plugin: plugin.NewPlugin("test", "development", getColumnAdderPlugin()),
	}

	_, err := s.Init(ctx, &pb.Init_Request{})
	require.NoError(t, err)

	table := &schema.Table{
		Name: "test",
		Columns: []schema.Column{
			{
				Name: "test",
				Type: arrow.BinaryTypes.String,
			},
		},
	}
	sc := table.ToArrowSchema()

	schemaBytes, err := pb.SchemaToBytes(sc)
	require.NoError(t, err)

	resp, err := s.TransformSchema(ctx, &pb.TransformSchema_Request{Schema: schemaBytes})
	require.NoError(t, err)

	newSchema, err := pb.NewSchemaFromBytes(resp.Schema)
	require.NoError(t, err)

	fields := newSchema.Fields()
	require.Len(t, fields, 2)
	require.Equal(t, "test", fields[0].Name)
	require.Equal(t, "source", fields[1].Name)
	// Check the type through the DataType interface so a mismatch fails the
	// test with a readable message instead of panicking on a type assertion.
	require.Equal(t, arrow.STRING, fields[1].Type.ID())
	require.Equal(t, "utf8", fields[1].Type.Name())

	_, err = s.Close(ctx, &pb.Close_Request{})
	require.NoError(t, err)
}

// mockSourceColumnAdderPluginClient is a test client whose TransformSchema adds
// a "source" string column to the schema it is given. It embeds
// plugin.UnimplementedDestination and plugin.UnimplementedSource and declares
// only Transform, TransformSchema and Close itself.
type mockSourceColumnAdderPluginClient struct {
plugin.UnimplementedDestination
plugin.UnimplementedSource
}

// getColumnAdderPlugin builds a plugin.NewClientFunc for tests. The returned
// function ignores all of its arguments and always yields the same
// mockSourceColumnAdderPluginClient instance, created once up front. Any
// plugin.Option values passed in are ignored.
func getColumnAdderPlugin(...plugin.Option) plugin.NewClientFunc {
	shared := new(mockSourceColumnAdderPluginClient)
	newClient := func(context.Context, zerolog.Logger, []byte, plugin.NewClientOptions) (plugin.Client, error) {
		return shared, nil
	}
	return newClient
}

// Transform is a no-op: it reads nothing from the receive channel, sends
// nothing on the send channel, and returns nil immediately.
func (*mockSourceColumnAdderPluginClient) Transform(context.Context, <-chan arrow.Record, chan<- arrow.Record) error {
return nil
}
// TransformSchema returns a copy of old with a "source" string field appended
// after the existing fields. The context is ignored. Any error reported by
// AddField is returned to the caller.
func (*mockSourceColumnAdderPluginClient) TransformSchema(_ context.Context, old *arrow.Schema) (*arrow.Schema, error) {
	// Insert at the end rather than at a fixed index so the mock also works for
	// schemas that do not have exactly one field.
	return old.AddField(len(old.Fields()), arrow.Field{Name: "source", Type: arrow.BinaryTypes.String})
}
// Close is a no-op that always returns nil.
func (*mockSourceColumnAdderPluginClient) Close(context.Context) error { return nil }
3 changes: 3 additions & 0 deletions plugin/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ type UnimplementedTransformer struct{}
// Transform always returns ErrNotImplemented without touching either channel.
func (UnimplementedTransformer) Transform(context.Context, <-chan arrow.Record, chan<- arrow.Record) error {
return ErrNotImplemented
}
// TransformSchema always returns a nil schema and ErrNotImplemented.
func (UnimplementedTransformer) TransformSchema(context.Context, *arrow.Schema) (*arrow.Schema, error) {
return nil, ErrNotImplemented
}

// Plugin is the base structure required to pass to sdk.serve
// We take a declarative approach to API here similar to Cobra
Expand Down
3 changes: 3 additions & 0 deletions plugin/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ func (*testPluginClient) Close(context.Context) error {
// Transform is a test stub that ignores its arguments and returns nil.
func (*testPluginClient) Transform(context.Context, <-chan arrow.Record, chan<- arrow.Record) error {
return nil
}
// TransformSchema is a test stub that ignores its arguments and returns a nil
// schema with a nil error.
func (*testPluginClient) TransformSchema(context.Context, *arrow.Schema) (*arrow.Schema, error) {
return nil, nil
}

func TestPluginSuccess(t *testing.T) {
ctx := context.Background()
Expand Down
4 changes: 4 additions & 0 deletions plugin/plugin_transformer.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@ import (

// TransformerClient is the set of methods a client provides to act as a
// transformer.
type TransformerClient interface {
// Transform is given a receive-only channel of records and a send-only
// channel of records, and reports an error on failure.
Transform(context.Context, <-chan arrow.Record, chan<- arrow.Record) error
// TransformSchema synchronously maps an input schema to an output schema.
// NOTE(review): presumably the result describes the records Transform
// emits for that input schema — confirm against implementers.
TransformSchema(context.Context, *arrow.Schema) (*arrow.Schema, error)
}

// Transform delegates to the plugin client's Transform, passing through the
// context and both record channels, and returns its error unchanged.
func (p *Plugin) Transform(ctx context.Context, recvRecords <-chan arrow.Record, sendRecords chan<- arrow.Record) error {
return p.client.Transform(ctx, recvRecords, sendRecords)
}
// TransformSchema delegates to the plugin client's TransformSchema, passing
// through the context and the old schema, and returns its results unchanged.
func (p *Plugin) TransformSchema(ctx context.Context, old *arrow.Schema) (*arrow.Schema, error) {
return p.client.TransformSchema(ctx, old)
}

1 comment on commit 30875d6

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⏱️ Benchmark results

  • Glob-8 ns/op: 91.31

Please sign in to comment.