Skip to content

Commit 30875d6

Browse files
authored
feat: Implement TransformSchema support. (#1838)
As discussed offline, Transformer plugins are going to need to explicitly provide a synchronous function to transform a schema. Previously, I implemented a method by which this wasn't necessary, by reusing the existing transform function, but the problem is that we cannot assume that the transform function is not gonna swallow the initial message of a given table, or that it's gonna behave properly with an empty record with a schema.
1 parent d2c5c7b commit 30875d6

File tree

9 files changed

+105
-3
lines changed

9 files changed

+105
-3
lines changed

examples/simple_plugin/plugin/client.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,11 @@ func (*Client) Transform(_ context.Context, _ <-chan arrow.Record, _ chan<- arro
6969
return nil
7070
}
7171

72+
func (*Client) TransformSchema(_ context.Context, _ *arrow.Schema) (*arrow.Schema, error) {
73+
// Not implemented, just used for testing destination packaging
74+
return nil, nil
75+
}
76+
7277
func Configure(_ context.Context, logger zerolog.Logger, spec []byte, opts plugin.NewClientOptions) (plugin.Client, error) {
7378
if opts.NoConnection {
7479
return &Client{

internal/memdb/memdb.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,10 @@ func (*client) Transform(_ context.Context, _ <-chan arrow.Record, _ chan<- arro
311311
return nil
312312
}
313313

314+
func (*client) TransformSchema(_ context.Context, _ *arrow.Schema) (*arrow.Schema, error) {
315+
return nil, nil
316+
}
317+
314318
func evaluatePredicate(pred message.Predicate, record arrow.Record) bool {
315319
sc := record.Schema()
316320
indices := sc.FieldIndices(pred.Column)

internal/reversertransformer/reversertransformer.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ func (c *client) Transform(ctx context.Context, recvRecords <-chan arrow.Record,
5858
}
5959
}
6060

61+
func (*client) TransformSchema(_ context.Context, old *arrow.Schema) (*arrow.Schema, error) {
62+
return old, nil
63+
}
64+
6165
func (*client) reverseStrings(record arrow.Record) (arrow.Record, error) {
6266
for i, column := range record.Columns() {
6367
if column.DataType().ID() != arrow.STRING {

internal/reversertransformer/reversertransformer_test.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ import (
1616
"google.golang.org/grpc/metadata"
1717
)
1818

19-
var mem = memory.NewGoAllocator()
20-
2119
func TestReverserTransformer(t *testing.T) {
2220
p := plugin.NewPlugin("test", "development", GetNewClient())
2321
s := internalPlugin.Server{
@@ -58,7 +56,7 @@ func makeRequestFromString(s string) *pb.Transform_Request {
5856
}
5957

6058
func makeRecordFromString(s string) arrow.Record {
61-
str := array.NewStringBuilder(mem)
59+
str := array.NewStringBuilder(memory.DefaultAllocator)
6260
str.AppendString(s)
6361
arr := str.NewStringArray()
6462
schema := arrow.NewSchema([]arrow.Field{{Name: "col1", Type: arrow.BinaryTypes.String}}, nil)

internal/servers/plugin/v3/plugin.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,22 @@ func (s *Server) Transform(stream pb.Plugin_TransformServer) error {
475475
return eg.Wait()
476476
}
477477

478+
func (s *Server) TransformSchema(ctx context.Context, req *pb.TransformSchema_Request) (*pb.TransformSchema_Response, error) {
479+
sc, err := pb.NewSchemaFromBytes(req.Schema)
480+
if err != nil {
481+
return nil, status.Errorf(codes.InvalidArgument, "failed to create schema from bytes: %v", err)
482+
}
483+
newSchema, err := s.Plugin.TransformSchema(ctx, sc)
484+
if err != nil {
485+
return nil, status.Errorf(codes.Internal, "failed to transform schema: %v", err)
486+
}
487+
encoded, err := pb.SchemaToBytes(newSchema)
488+
if err != nil {
489+
return nil, status.Errorf(codes.Internal, "failed to encode schema: %v", err)
490+
}
491+
return &pb.TransformSchema_Response{Schema: encoded}, nil
492+
}
493+
478494
func (s *Server) Close(ctx context.Context, _ *pb.Close_Request) (*pb.Close_Response, error) {
479495
return &pb.Close_Response{}, s.Plugin.Close(ctx)
480496
}

internal/servers/plugin/v3/plugin_test.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ import (
1212
"github.com/cloudquery/plugin-sdk/v4/internal/memdb"
1313
"github.com/cloudquery/plugin-sdk/v4/plugin"
1414
"github.com/cloudquery/plugin-sdk/v4/schema"
15+
"github.com/rs/zerolog"
16+
"github.com/stretchr/testify/require"
1517
"google.golang.org/grpc"
1618
"google.golang.org/grpc/metadata"
1719
)
@@ -183,3 +185,66 @@ func TestPluginSync(t *testing.T) {
183185
t.Fatal(err)
184186
}
185187
}
188+
189+
func TestTransformSchema(t *testing.T) {
190+
ctx := context.Background()
191+
s := Server{
192+
Plugin: plugin.NewPlugin("test", "development", getColumnAdderPlugin()),
193+
}
194+
195+
_, err := s.Init(ctx, &pb.Init_Request{})
196+
if err != nil {
197+
t.Fatal(err)
198+
}
199+
200+
table := &schema.Table{
201+
Name: "test",
202+
Columns: []schema.Column{
203+
{
204+
Name: "test",
205+
Type: arrow.BinaryTypes.String,
206+
},
207+
},
208+
}
209+
sc := table.ToArrowSchema()
210+
211+
schemaBytes, err := pb.SchemaToBytes(sc)
212+
require.NoError(t, err)
213+
214+
resp, err := s.TransformSchema(ctx, &pb.TransformSchema_Request{Schema: schemaBytes})
215+
if err != nil {
216+
t.Fatal(err)
217+
}
218+
219+
newSchema, err := pb.NewSchemaFromBytes(resp.Schema)
220+
require.NoError(t, err)
221+
222+
require.Len(t, newSchema.Fields(), 2)
223+
require.Equal(t, "test", newSchema.Fields()[0].Name)
224+
require.Equal(t, "source", newSchema.Fields()[1].Name)
225+
require.Equal(t, "utf8", newSchema.Fields()[1].Type.(*arrow.StringType).Name())
226+
227+
if _, err := s.Close(ctx, &pb.Close_Request{}); err != nil {
228+
t.Fatal(err)
229+
}
230+
}
231+
232+
type mockSourceColumnAdderPluginClient struct {
233+
plugin.UnimplementedDestination
234+
plugin.UnimplementedSource
235+
}
236+
237+
func getColumnAdderPlugin(...plugin.Option) plugin.NewClientFunc {
238+
c := &mockSourceColumnAdderPluginClient{}
239+
return func(context.Context, zerolog.Logger, []byte, plugin.NewClientOptions) (plugin.Client, error) {
240+
return c, nil
241+
}
242+
}
243+
244+
func (*mockSourceColumnAdderPluginClient) Transform(context.Context, <-chan arrow.Record, chan<- arrow.Record) error {
245+
return nil
246+
}
247+
func (*mockSourceColumnAdderPluginClient) TransformSchema(_ context.Context, old *arrow.Schema) (*arrow.Schema, error) {
248+
return old.AddField(1, arrow.Field{Name: "source", Type: arrow.BinaryTypes.String})
249+
}
250+
func (*mockSourceColumnAdderPluginClient) Close(context.Context) error { return nil }

plugin/plugin.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ type UnimplementedTransformer struct{}
5858
func (UnimplementedTransformer) Transform(context.Context, <-chan arrow.Record, chan<- arrow.Record) error {
5959
return ErrNotImplemented
6060
}
61+
func (UnimplementedTransformer) TransformSchema(context.Context, *arrow.Schema) (*arrow.Schema, error) {
62+
return nil, ErrNotImplemented
63+
}
6164

6265
// Plugin is the base structure required to pass to sdk.serve
6366
// We take a declarative approach to API here similar to Cobra

plugin/plugin_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ func (*testPluginClient) Close(context.Context) error {
5959
func (*testPluginClient) Transform(context.Context, <-chan arrow.Record, chan<- arrow.Record) error {
6060
return nil
6161
}
62+
func (*testPluginClient) TransformSchema(context.Context, *arrow.Schema) (*arrow.Schema, error) {
63+
return nil, nil
64+
}
6265

6366
func TestPluginSuccess(t *testing.T) {
6467
ctx := context.Background()

plugin/plugin_transformer.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@ import (
88

99
type TransformerClient interface {
1010
Transform(context.Context, <-chan arrow.Record, chan<- arrow.Record) error
11+
TransformSchema(context.Context, *arrow.Schema) (*arrow.Schema, error)
1112
}
1213

1314
func (p *Plugin) Transform(ctx context.Context, recvRecords <-chan arrow.Record, sendRecords chan<- arrow.Record) error {
1415
return p.client.Transform(ctx, recvRecords, sendRecords)
1516
}
17+
func (p *Plugin) TransformSchema(ctx context.Context, old *arrow.Schema) (*arrow.Schema, error) {
18+
return p.client.TransformSchema(ctx, old)
19+
}

0 commit comments

Comments
 (0)