diff --git a/connector/server_test.go b/connector/server_test.go index 8aa74036..14bb60fd 100644 --- a/connector/server_test.go +++ b/connector/server_test.go @@ -188,8 +188,8 @@ func (mc *mockConnector) Query(ctx context.Context, configuration *mockConfigura return schema.QueryResponse{ { Aggregates: schema.RowSetAggregates{}, - Rows: []schema.Row{ - map[string]any{ + Rows: []map[string]any{ + { "id": 1, "title": "Hello world", "author_id": 1, @@ -449,8 +449,8 @@ func TestServerConnector(t *testing.T) { assertHTTPResponse(t, res, http.StatusOK, schema.QueryResponse{ { Aggregates: schema.RowSetAggregates{}, - Rows: []schema.Row{ - map[string]any{ + Rows: []map[string]any{ + { "id": 1, "title": "Hello world", "author_id": 1, diff --git a/example/reference/connector.go b/example/reference/connector.go index 0396c2d8..d15093ba 100644 --- a/example/reference/connector.go +++ b/example/reference/connector.go @@ -4,6 +4,10 @@ import ( "context" "encoding/json" "fmt" + "reflect" + "regexp" + "slices" + "strings" "github.com/hasura/ndc-sdk-go/connector" "github.com/hasura/ndc-sdk-go/schema" @@ -25,10 +29,31 @@ type Author struct { LastName string `json:"last_name"` } +type InstitutionLocation struct { + City string `json:"city"` + Country string `json:"country"` + Campuses []string `json:"campuses"` +} + +type InstitutionStaff struct { + FirstName string `json:"first_name"` + LastName string `json:"last_name"` + Specialities []string `json:"specialities"` +} + +type Institution struct { + ID int `json:"id"` + Name string `json:"name"` + Location InstitutionLocation `json:"location"` + Staff []InstitutionStaff `json:"staff"` + Departments []string `json:"departments"` +} + type State struct { - Authors map[int]Author - Articles map[int]Article - Telemetry *connector.TelemetryState + Authors map[int]Author + Articles map[int]Article + Institutions map[int]Institution + Telemetry *connector.TelemetryState } func (s State) GetLatestArticle() *Article { @@ -72,10 +97,18 @@ func (mc *Connector) TryInitState(configuration *Configuration, metrics *connect }) } + institutions, err := readInstitutions() + if err != nil { + return nil, schema.InternalServerError("failed to read institutions from json", map[string]any{ + "cause": err.Error(), + }) + } + return &State{ - Authors: authors, - Articles: articles, - Telemetry: metrics, + Authors: authors, + Articles: articles, + Institutions: institutions, + Telemetry: metrics, }, nil } @@ -235,63 +268,24 @@ func (mc *Connector) MutationExplain(ctx context.Context, configuration *Configu } func (mc *Connector) Query(ctx context.Context, configuration *Configuration, state *State, request *schema.QueryRequest) (schema.QueryResponse, error) { - var rows []schema.Row - switch request.Collection { - case "articles": - for _, item := range state.Articles { - row, err := schema.PruneFields(request.Query.Fields, item) - if err != nil { - return nil, err - } - rows = append(rows, row) - } - case "authors": - for _, item := range state.Authors { - row, err := schema.PruneFields(request.Query.Fields, item) - if err != nil { - return nil, err - } - rows = append(rows, row) - } - case "articles_by_author": - authorIdArg, ok := request.Arguments["author_id"] - if !ok { - return nil, schema.BadGatewayError("missing argument author_id", nil) - } - for _, row := range state.Articles { - switch authorIdArg.Type { - case schema.ArgumentTypeLiteral: - if fmt.Sprint(row.AuthorID) == fmt.Sprint(authorIdArg.Value) { - r, err := schema.PruneFields(request.Query.Fields, row) - if err != nil { - return nil, err - } - rows = append(rows, r) - } - } - } - case "latest_article_id": - latestArticle := state.GetLatestArticle() - if latestArticle == nil { - return nil, schema.BadRequestError("No available article", nil) - } + variableSets := request.Variables + if variableSets == nil { + variableSets = []schema.QueryRequestVariablesElem{make(map[string]any)} + } - rows = []schema.Row{ - map[string]any{ - "__value": latestArticle.ID, - }, + rowSets := make([]schema.RowSet, 0, len(variableSets)) + + for _, variables := range variableSets { + rowSet, err := executeQueryWithVariables(request.Collection, request.Arguments, request.CollectionRelationships, &request.Query, variables, state) + if err != nil { + return nil, err } - default: - return nil, schema.BadRequestError(fmt.Sprintf("invalid collection name %s", request.Collection), nil) + + rowSets = append(rowSets, *rowSet) } - return schema.QueryResponse{ - { - Rows: rows, - Aggregates: schema.RowSetAggregates{}, - }, - }, nil + return rowSets, nil } func (mc *Connector) Mutation(ctx context.Context, configuration *Configuration, state *State, request *schema.MutationRequest) (*schema.MutationResponse, error) { @@ -341,11 +335,949 @@ func executeProcedure(ctx context.Context, state *State, collectionRelationship } state.Articles[args.Article.ID] = args.Article + row, err := schema.PruneFields(operation.Fields, args.Article) + if err != nil { + return nil, err + } return &schema.MutationOperationResults{ AffectedRows: 1, - Returning: []schema.Row{args.Article}, + Returning: []map[string]any{row}, }, nil default: return nil, schema.BadRequestError("unknown procedure", nil) } } + +func executeQueryWithVariables( + collection string, + arguments map[string]schema.Argument, + collectionRelationships map[string]schema.Relationship, + query *schema.Query, + variables map[string]any, + state *State, +) (*schema.RowSet, error) { + argumentValues := make(map[string]any) + + for argumentName, argument := range arguments { + argumentValue, err := evalArgument(variables, &argument) + if err != nil { + return nil, err + } + argumentValues[argumentName] = argumentValue + } + + // FIXME: argument + coll, err := getCollectionByName(collection, nil, state) + if err != nil { + return nil, err + } + return executeQuery(collectionRelationships, variables, state, query, nil, coll) +} + +func evalAggregate(aggregate *schema.Aggregate, paginated []map[string]any) (any, error) { + iAgg, err := aggregate.Interface() + switch agg := iAgg.(type) { + case *schema.AggregateStarCount: + return len(paginated), nil + case *schema.AggregateColumnCount: + var values []string + for _, value := range paginated { + v, ok := value[agg.Column] + if !ok { + return nil, schema.BadRequestError(fmt.Sprintf("invalid column name: %s", agg.Column), nil) + } + if v == nil { + continue + } + values = append(values, fmt.Sprint(v)) + } + if !agg.Distinct { + return len(values), nil + } + distinctValue := make(map[string]bool) + for _, v := range values { + distinctValue[v] = true + } + return len(distinctValue), nil + case *schema.AggregateSingleColumn: + var values []any + for _, value := range paginated { + v, ok := value[agg.Column] + if !ok { + return nil, schema.BadRequestError(fmt.Sprintf("invalid column name: %s", agg.Column), nil) + } + if v == nil { + continue + } + values = append(values, v) + } + return evalAggregateFunction(agg.Function, values) + default: + return nil, err + } +} + +func evalAggregateFunction(function string, values []any) (*int64, error) { + if len(values) == 0 { + return nil, nil + } + + var intValues []int64 + for _, value := range values { + switch v := value.(type) { + case int: + intValues = append(intValues, int64(v)) + case int16: + intValues = append(intValues, int64(v)) + case int32: + intValues = append(intValues, int64(v)) + case int64: + intValues = append(intValues, v) + default: + return nil, schema.BadRequestError(fmt.Sprintf("%s: column is not an integer, got %+v", function, reflect.ValueOf(v).Kind()), nil) + } + } + + slices.Sort(intValues) + + switch function { + case "min": + return &intValues[0], nil + case "max": + return &intValues[len(intValues)-1], nil + default: + return nil, schema.BadRequestError(fmt.Sprintf("%s: invalid aggregation function", function), nil) + } +} + +func executeQuery( + collectionRelationships map[string]schema.Relationship, + variables map[string]any, + state *State, + query *schema.Query, + root map[string]any, + collection []map[string]any, +) (*schema.RowSet, error) { + sorted, err := sortCollection(collectionRelationships, variables, state, collection, query.OrderBy) + if err != nil { + return nil, err + } + + filtered := sorted + if query.Predicate != nil { + filtered = []map[string]any{} + for _, item := range sorted { + rootItem := root + if rootItem == nil { + rootItem = item + } + ok, err := evalExpression(collectionRelationships, variables, state, query.Predicate, rootItem, item) + if err != nil { + return nil, err + } + if ok { + filtered = append(filtered, item) + } + } + } + + paginated := paginate(filtered, query.Limit, query.Offset) + aggregates := make(map[string]any) + + for aggKey, aggregate := range query.Aggregates { + aggValue, err := evalAggregate(&aggregate, paginated) + if err != nil { + return nil, err + } + aggregates[aggKey] = aggValue + } + + rows := make([]map[string]any, 0) + for _, item := range paginated { + row, err := evalRow(query.Fields, collectionRelationships, variables, state, item) + if err != nil { + return nil, err + } + if row != nil { + rows = append(rows, row) + } + } + return &schema.RowSet{ + Aggregates: aggregates, + Rows: rows, + }, nil +} + +func sortCollection( + collectionRelationships map[string]schema.Relationship, + variables map[string]any, + state *State, + collection []map[string]any, + orderBy *schema.OrderBy, +) ([]map[string]any, error) { + if orderBy == nil { + return collection, nil + } + + var results []map[string]any + for _, itemToInsert := range collection { + if len(results) == 0 { + results = append(results, itemToInsert) + continue + } + index := 0 + for _, other := range results { + ordering, err := evalOrderBy(collectionRelationships, variables, state, orderBy, other, itemToInsert) + if err != nil { + return nil, err + } + if ordering > 0 { + break + } + index++ + } + results = append(append(results[:index], itemToInsert), results[(index+1):]...) + } + return results, nil +} + +func paginate[R any](collection []R, limit *int, offset *int) []R { + var start int + if offset != nil { + start = *offset + } + if limit == nil { + return collection[start:] + } + + return collection[start : start+*limit] +} + +func evalOrderBy( + collectionRelationships map[string]schema.Relationship, + variables map[string]any, + state *State, + orderBy *schema.OrderBy, + t1 map[string]any, + t2 map[string]any, +) (int, error) { + ordering := 0 + for _, orderElem := range orderBy.Elements { + v1, err := evalOrderByElement(collectionRelationships, variables, state, &orderElem, t1) + if err != nil { + return 0, err + } + v2, err := evalOrderByElement(collectionRelationships, variables, state, &orderElem, t2) + if err != nil { + return 0, err + } + switch orderElem.OrderDirection { + case schema.OrderDirectionAsc: + // FIXME: compose ordering + ordering, err = compare(v1, v2) + if err != nil { + return 0, err + } + case schema.OrderDirectionDesc: + ordering, err = compare(v2, v1) + if err != nil { + return 0, err + } + } + } + + return ordering, nil +} + +func boolToInt(v bool) int { + if v { + return 1 + } + return 0 +} + +func compare(v1 any, v2 any) (int, error) { + if v1 == v2 || (v1 == nil && v2 == nil) { + return 0, nil + } + if v1 == nil { + return -1, nil + } + if v2 == nil { + return 1, nil + } + + kindV1 := reflect.ValueOf(v1).Kind() + kindV2 := reflect.ValueOf(v2).Kind() + + if kindV1 != kindV2 { + return 0, schema.InternalServerError(fmt.Sprintf("cannot compare values with different types: %s <> %s", kindV1, kindV2), nil) + } + + switch value1 := v1.(type) { + case bool: + value2 := v2.(bool) + return boolToInt(value1) - boolToInt(value2), nil + case int: + value2 := v2.(int) + return value1 - value2, nil + case int8: + value2 := v2.(int8) + return int(value1 - value2), nil + case int16: + value2 := v2.(int16) + return int(value1 - value2), nil + case int32: + value2 := v2.(int32) + return int(value1 - value2), nil + case int64: + value2 := v2.(int64) + return int(value1 - value2), nil + case string: + value2 := v2.(string) + return strings.Compare(value1, value2), nil + default: + return 0, schema.InternalServerError(fmt.Sprintf("cannot compare values with type: %s", kindV1), nil) + } +} + +func evalOrderByElement( + collectionRelationships map[string]schema.Relationship, + variables map[string]any, + state *State, + element *schema.OrderByElement, + item map[string]any, +) (any, error) { + iTarget, err := element.Target.Interface() + switch target := iTarget.(type) { + case *schema.OrderByColumn: + return evalOrderByColumn(collectionRelationships, variables, state, item, target.Path, target.Name) + case *schema.OrderBySingleColumnAggregate: + return evalOrderBySingleColumnAggregate(collectionRelationships, variables, state, item, target.Path, target.Column, target.Function) + case *schema.OrderByStarCountAggregate: + return evalOrderByStarCountAggregate(collectionRelationships, variables, state, item, target.Path) + default: + return nil, err + } +} + +func evalOrderByStarCountAggregate( + collectionRelationships map[string]schema.Relationship, + variables map[string]any, + state *State, + item map[string]any, + path []schema.PathElement, +) (int, error) { + rows, err := evalPath(collectionRelationships, variables, state, path, item) + + return len(rows), err +} + +func evalOrderBySingleColumnAggregate( + collectionRelationships map[string]schema.Relationship, + variables map[string]any, + state *State, + item map[string]any, + path []schema.PathElement, + column string, + function string, +) (any, error) { + rows, err := evalPath(collectionRelationships, variables, state, path, item) + if err != nil { + return nil, err + } + + var values []any + for _, row := range rows { + value, ok := row[column] + if !ok { + return nil, schema.BadRequestError(fmt.Sprintf("invalid column name: %s", column), nil) + } + values = append(values, value) + } + return evalAggregateFunction(function, values) +} + +func evalOrderByColumn( + collectionRelationships map[string]schema.Relationship, + variables map[string]any, + state *State, + item map[string]any, + path []schema.PathElement, + name string, +) (any, error) { + rows, err := evalPath(collectionRelationships, variables, state, path, item) + if err != nil { + return nil, err + } + if len(rows) > 1 { + return nil, schema.BadRequestError("expected one path value only", nil) + } + if len(rows) == 0 || rows[0] == nil { + return nil, nil + } + value, ok := rows[0][name] + if !ok { + return nil, schema.BadRequestError(fmt.Sprintf("invalid column name: %s", name), nil) + } + return value, nil +} + +func evalInCollection( + collectionRelationships map[string]schema.Relationship, + item map[string]any, + variables map[string]any, + state *State, + inCollection schema.ExistsInCollection, +) ([]map[string]any, error) { + iInCollection, err := inCollection.Interface() + switch inCol := iInCollection.(type) { + case *schema.ExistsInCollectionRelated: + relationship, ok := collectionRelationships[inCol.Relationship] + if !ok { + return nil, schema.BadRequestError(fmt.Sprintf("invalid in collection relationship: %s", inCol.Relationship), nil) + } + source := []map[string]any{item} + return evalPathElement(collectionRelationships, variables, state, &relationship, inCol.Arguments, source, nil) + case *schema.ExistsInCollectionUnrelated: + arguments := make(map[string]any) + for key, relArg := range inCol.Arguments { + arg, err := evalRelationshipArgument(variables, item, &relArg) + if err != nil { + return nil, err + } + arguments[key] = arg + } + // FIXME: arguments? + return getCollectionByName(inCol.Collection, nil, state) + default: + return nil, err + } +} + +func evalRow(fields map[string]schema.Field, collectionRelationships map[string]schema.Relationship, variables map[string]any, state *State, item map[string]any) (map[string]any, error) { + if len(fields) == 0 { + return nil, nil + } + row := make(map[string]any) + for fieldName, field := range fields { + fieldValue, err := evalField(collectionRelationships, variables, state, field, item) + if err != nil { + return nil, err + } + row[fieldName] = fieldValue + } + + return row, nil +} + +func evalNestedField( + collectionRelationships map[string]schema.Relationship, + variables map[string]any, + state *State, + value any, + nestedField *schema.NestedField, +) (any, error) { + iNestedField, err := nestedField.Interface() + switch nf := iNestedField.(type) { + case *schema.NestedObject: + fullRow, ok := value.(map[string]any) + if !ok { + return nil, schema.BadRequestError(fmt.Sprintf("expected object, got %s", reflect.ValueOf(value).Kind()), nil) + } + + row, err := evalRow( + nf.Fields, + collectionRelationships, + variables, + state, + fullRow, + ) + if err != nil { + return nil, err + } + + return row, nil + case *schema.NestedArray: + array, ok := value.(map[string]any) + if !ok { + return nil, schema.BadRequestError(fmt.Sprintf("expected object, got %s", reflect.ValueOf(value).Kind()), nil) + } + + result := []any{} + for _, item := range array { + val, err := evalNestedField(collectionRelationships, variables, state, item, &nf.Fields) + if err != nil { + return nil, err + } + result = append(result, val) + } + return result, nil + default: + return nil, err + } +} + +func evalField( + collectionRelationships map[string]schema.Relationship, + variables map[string]any, + state *State, + field schema.Field, + row map[string]any, +) (any, error) { + iField, err := field.Interface() + switch f := iField.(type) { + case *schema.ColumnField: + value, ok := row[f.Column] + if !ok { + return nil, schema.BadRequestError(fmt.Sprintf("invalid column name: %s", f.Column), nil) + } + if len(f.Fields) == 0 { + return value, nil + } + return evalNestedField(collectionRelationships, variables, state, value, &f.Fields) + case *schema.RelationshipField: + relationship, ok := collectionRelationships[f.Relationship] + if !ok { + return nil, schema.BadRequestError(fmt.Sprintf("invalid relationship name %s", f.Relationship), nil) + } + + collection, err := evalPathElement(collectionRelationships, variables, state, &relationship, f.Arguments, []map[string]any{row}, nil) + if err != nil { + return nil, err + } + + return executeQuery(collectionRelationships, variables, state, &f.Query, nil, collection) + + default: + return nil, err + } +} + +func evalPathElement( + collectionRelationships map[string]schema.Relationship, + variables map[string]any, + state *State, + relationship *schema.Relationship, + arguments map[string]schema.RelationshipArgument, + source []map[string]any, + predicate schema.Expression, +) ([]map[string]any, error) { + allArguments := make(map[string]any) + var matchingRows []map[string]any + + // Note: Join strategy + // + // Rows can be related in two ways: 1) via a column mapping, and + // 2) via collection arguments. Because collection arguments can be computed + // using the columns on the source side of a relationship, in general + // we need to compute the target collection once for each source row. + // This join strategy can result in some target rows appearing in the + // resulting row set more than once, if two source rows are both related + // to the same target row. + // + // In practice, this is not an issue, either because a) the relationship + // is computed in the course of evaluating a predicate, and all predicates are + // implicitly or explicitly existentially quantified, or b) if the + // relationship is computed in the course of evaluating an ordering, the path + // should consist of all object relationships, and possibly terminated by a + // single array relationship, so there should be no double counting. + for _, srcRow := range source { + for argName, arg := range relationship.Arguments { + relValue, err := evalRelationshipArgument(variables, srcRow, &arg) + if err != nil { + return nil, err + } + allArguments[argName] = relValue + } + for argName, arg := range arguments { + if _, ok := allArguments[argName]; ok { + return nil, schema.BadRequestError(fmt.Sprintf("duplicate argument name: %s", argName), nil) + } + relValue, err := evalRelationshipArgument(variables, srcRow, &arg) + if err != nil { + return nil, err + } + allArguments[argName] = relValue + } + + // FIXME: arguments? + targetRows, err := getCollectionByName(relationship.TargetCollection, nil, state) + if err != nil { + return nil, err + } + + for _, targetRow := range targetRows { + ok, err := evalColumnMapping(relationship, srcRow, targetRow) + if err != nil { + return nil, err + } + if !ok { + continue + } + if predicate != nil { + ok, err := evalExpression(collectionRelationships, variables, state, predicate, targetRow, targetRow) + if err != nil { + return nil, err + } + if !ok { + continue + } + } + matchingRows = append(matchingRows, targetRow) + } + } + + return matchingRows, nil +} + +func evalRelationshipArgument(variables map[string]any, row map[string]any, argument *schema.RelationshipArgument) (any, error) { + switch argument.Type { + case schema.RelationshipArgumentTypeColumn: + value, ok := row[argument.Name] + if !ok { + return nil, schema.BadRequestError(fmt.Sprintf("invalid column name: %s", argument.Name), nil) + } + return value, nil + case schema.RelationshipArgumentTypeLiteral: + return argument.Value, nil + case schema.RelationshipArgumentTypeVariable: + variable, ok := variables[argument.Name] + if !ok { + return nil, schema.BadRequestError(fmt.Sprintf("invalid variable name: %s", argument.Name), nil) + } + return variable, nil + default: + return nil, schema.BadRequestError(fmt.Sprintf("invalid argument type: %s", argument.Type), nil) + } +} + +func getCollectionByName(collectionName string, arguments schema.QueryRequestArguments, state *State) ([]map[string]any, error) { + var rows []map[string]any + switch collectionName { + // function + case "latest_article_id": + latestArticle := state.GetLatestArticle() + if latestArticle == nil { + return nil, schema.BadRequestError("No available article", nil) + } + return []map[string]any{ + { + "__value": latestArticle.ID, + }, + }, nil + + // collections + case "articles": + for _, item := range state.Articles { + row, err := schema.EncodeRow(item) + if err != nil { + return nil, err + } + rows = append(rows, row) + } + case "authors": + for _, item := range state.Authors { + row, err := schema.EncodeRow(item) + if err != nil { + return nil, err + } + rows = append(rows, row) + } + case "institutions": + for _, item := range state.Institutions { + row, err := schema.EncodeRow(item) + if err != nil { + return nil, err + } + rows = append(rows, row) + } + case "articles_by_author": + authorIdArg, ok := arguments["author_id"] + if !ok { + return nil, schema.BadGatewayError("missing argument author_id", nil) + } + + for _, row := range state.Articles { + switch authorIdArg.Type { + case schema.ArgumentTypeLiteral: + if fmt.Sprint(row.AuthorID) == fmt.Sprint(authorIdArg.Value) { + r, err := schema.EncodeRow(row) + if err != nil { + return nil, err + } + rows = append(rows, r) + } + } + } + default: + return nil, schema.BadRequestError(fmt.Sprintf("invalid collection name %s", collectionName), nil) + } + + return rows, nil +} + +func evalComparisonValue( + collectionRelationships map[string]schema.Relationship, + variables map[string]any, + state *State, + comparisonValue schema.ComparisonValue, + root map[string]any, + item map[string]any, +) ([]any, error) { + iCompValue, err := comparisonValue.Interface() + switch compValue := iCompValue.(type) { + case *schema.ComparisonValueColumn: + return evalComparisonTarget(collectionRelationships, variables, state, &compValue.Column, root, item) + case *schema.ComparisonValueScalar: + return []any{compValue.Value}, nil + case *schema.ComparisonValueVariable: + if len(variables) == 0 { + return nil, schema.BadRequestError(fmt.Sprintf("invalid variable name: %s", compValue.Name), nil) + } + val, ok := variables[compValue.Name] + if !ok { + return nil, schema.BadRequestError(fmt.Sprintf("invalid variable name: %s", compValue.Name), nil) + } + return []any{val}, nil + default: + return nil, err + } +} + +func evalComparisonTarget( + collectionRelationships map[string]schema.Relationship, + variables map[string]any, + state *State, + target *schema.ComparisonTarget, + root map[string]any, + item map[string]any, +) ([]any, error) { + switch target.Type { + case schema.ComparisonTargetTypeColumn: + rows, err := evalPath(collectionRelationships, variables, state, target.Path, item) + if err != nil { + return nil, err + } + var result []any + for _, row := range rows { + value, ok := row[target.Name] + if !ok { + return nil, schema.BadRequestError(fmt.Sprintf("invalid comparison target column name: %s", target.Name), nil) + } + result = append(result, value) + } + return result, nil + case schema.ComparisonTargetTypeRootCollectionColumn: + value, ok := root[target.Name] + if !ok { + return nil, schema.BadRequestError(fmt.Sprintf("invalid comparison target column name: %s", target.Name), nil) + } + return []any{value}, nil + default: + return nil, schema.BadRequestError(fmt.Sprintf("invalid comparison target type: %s", target.Type), nil) + } +} + +func evalPath( + collectionRelationships map[string]schema.Relationship, + variables map[string]any, + state *State, + path []schema.PathElement, + item map[string]any, +) ([]map[string]any, error) { + var err error + result := []map[string]any{item} + + for _, pathElem := range path { + relationshipName := pathElem.Relationship + relationship, ok := collectionRelationships[relationshipName] + if !ok { + return nil, schema.BadRequestError(fmt.Sprintf("invalid relationship name in path: %s", relationshipName), nil) + } + result, err = evalPathElement(collectionRelationships, variables, state, &relationship, pathElem.Arguments, result, pathElem.Predicate) + if err != nil { + return nil, err + } + } + return result, nil +} + +func evalExpression( + collectionRelationships map[string]schema.Relationship, + variables map[string]any, + state *State, + expr schema.Expression, + root map[string]any, + item map[string]any, +) (bool, error) { + iExpr, err := expr.Interface() + switch expression := iExpr.(type) { + case *schema.ExpressionAnd: + for _, exp := range expression.Expressions { + ok, err := evalExpression(collectionRelationships, variables, state, exp, root, item) + if err != nil || !ok { + return false, err + } + } + return true, nil + case *schema.ExpressionOr: + for _, exp := range expression.Expressions { + ok, err := evalExpression(collectionRelationships, variables, state, exp, root, item) + if err != nil { + return false, err + } + if ok { + return true, nil + } + } + return false, nil + case *schema.ExpressionNot: + ok, err := evalExpression(collectionRelationships, variables, state, expression.Expression, root, item) + if err != nil { + return false, err + } + return !ok, nil + case *schema.ExpressionUnaryComparisonOperator: + switch expression.Operator { + case schema.UnaryComparisonOperatorIsNull: + values, err := evalComparisonTarget(collectionRelationships, variables, state, &expression.Column, root, item) + if err != nil { + return false, err + } + for _, val := range values { + if val == nil { + return true, nil + } + } + return false, nil + default: + return false, schema.BadRequestError(fmt.Sprintf("invalid unary comparison operator: %s", expression.Operator), nil) + } + case *schema.ExpressionBinaryComparisonOperator: + switch expression.Operator { + case "eq": + leftValues, err := evalComparisonTarget(collectionRelationships, variables, state, &expression.Column, root, item) + if err != nil { + return false, err + } + rightValues, err := evalComparisonValue(collectionRelationships, variables, state, expression.Value, root, item) + if err != nil { + return false, err + } + for _, leftVal := range leftValues { + for _, rightVal := range rightValues { + if leftVal == rightVal { + return true, nil + } + } + } + return false, nil + case "like": + columnValues, err := evalComparisonTarget(collectionRelationships, variables, state, &expression.Column, root, item) + if err != nil { + return false, err + } + regexValues, err := evalComparisonValue(collectionRelationships, variables, state, expression.Value, root, item) + if err != nil { + return false, err + } + + for _, columnValue := range columnValues { + columnStr, ok := columnValue.(string) + if !ok { + return false, schema.BadRequestError(fmt.Sprintf("value of column %s is not a string, got %+v", expression.Column, columnValue), nil) + } + for _, rawRegex := range regexValues { + regexStr, ok := rawRegex.(string) + if !ok { + return false, schema.BadRequestError(fmt.Sprintf("invalid regular expression, got %+v", rawRegex), nil) + } + + regex, err := regexp.Compile(regexStr) + if err != nil { + return false, schema.BadRequestError(fmt.Sprintf("invalid regular expression: %s", err), nil) + } + if regex.Match([]byte(columnStr)) { + return true, nil + } + } + } + + return false, nil + case "in": + leftValues, err := evalComparisonTarget(collectionRelationships, variables, state, &expression.Column, root, item) + if err != nil { + return false, err + } + rightValueSets, err := evalComparisonValue(collectionRelationships, variables, state, expression.Value, root, item) + if err != nil { + return false, err + } + for _, rightValueSet := range rightValueSets { + rightValues, ok := rightValueSet.([]any) + if !ok { + return false, schema.BadRequestError(fmt.Sprintf("expected array, got %+v", rightValueSet), nil) + } + for _, leftVal := range leftValues { + for _, rightVal := range rightValues { + if leftVal == rightVal { + return true, nil + } + } + } + } + return false, nil + default: + return false, schema.BadRequestError(fmt.Sprintf("invalid comparison operator: %s", expression.Operator), nil) + } + case *schema.ExpressionExists: + query := &schema.Query{ + Predicate: expression.Predicate, + } + collection, err := evalInCollection(collectionRelationships, item, variables, state, expression.InCollection) + if err != nil { + return false, err + } + + rowSet, err := executeQuery(collectionRelationships, variables, state, query, root, collection) + if err != nil { + return false, err + } + return len(rowSet.Rows) > 0, nil + default: + return false, err + } +} + +func evalArgument(variables map[string]any, argument *schema.Argument) (any, error) { + switch argument.Type { + case schema.ArgumentTypeVariable: + value, ok := variables[argument.Name] + if !ok { + return nil, schema.BadRequestError(fmt.Sprintf("invalid variable name: %s", argument.Name), nil) + } + return value, nil + case schema.ArgumentTypeLiteral: + return argument.Value, nil + default: + return nil, schema.BadRequestError(fmt.Sprintf("invalid argument type: %s", argument.Type), nil) + } +} + +func evalColumnMapping(relationship *schema.Relationship, srcRow map[string]any, target map[string]any) (bool, error) { + for srcColumn, targetColumn := range relationship.ColumnMapping { + srcValue, ok := srcRow[srcColumn] + if !ok { + return false, schema.BadRequestError(fmt.Sprintf("source column does not exist: %s", srcColumn), nil) + } + targetValue, ok := target[targetColumn] + if !ok { + return false, schema.BadRequestError(fmt.Sprintf("target column does not exist: %s", targetColumn), nil) + } + if srcValue != targetValue { + return false, nil + } + } + return true, nil +} diff --git a/example/reference/connector_test.go b/example/reference/connector_test.go index d800b024..1e8d5936 100644 --- a/example/reference/connector_test.go +++ b/example/reference/connector_test.go @@ -70,11 +70,41 @@ func TestQuery(t *testing.T) { requestURL string responseURL string }{ + { + name: "aggregate_function", + requestURL: "https://raw.githubusercontent.com/hasura/ndc-spec/main/ndc-reference/tests/query/aggregate_function/request.json", + responseURL: "https://raw.githubusercontent.com/hasura/ndc-spec/main/ndc-reference/tests/query/aggregate_function/expected.json", + }, + { + name: "authors_with_article_aggregate", + requestURL: "https://raw.githubusercontent.com/hasura/ndc-spec/main/ndc-reference/tests/query/authors_with_article_aggregate/request.json", + responseURL: "https://raw.githubusercontent.com/hasura/ndc-spec/main/ndc-reference/tests/query/authors_with_article_aggregate/expected.json", + }, + { + name: "authors_with_articles", + requestURL: "https://raw.githubusercontent.com/hasura/ndc-spec/main/ndc-reference/tests/query/authors_with_articles/request.json", + responseURL: "https://raw.githubusercontent.com/hasura/ndc-spec/main/ndc-reference/tests/query/authors_with_articles/expected.json", + }, + { + name: "column_count", + requestURL: "https://raw.githubusercontent.com/hasura/ndc-spec/main/ndc-reference/tests/query/column_count/request.json", + responseURL: "https://raw.githubusercontent.com/hasura/ndc-spec/main/ndc-reference/tests/query/column_count/expected.json", + }, { name: "get_all_articles", requestURL: "https://raw.githubusercontent.com/hasura/ndc-spec/main/ndc-reference/tests/query/get_all_articles/request.json", responseURL: "https://raw.githubusercontent.com/hasura/ndc-spec/main/ndc-reference/tests/query/get_all_articles/expected.json", }, + { + name: "get_max_article_id", + requestURL: "https://raw.githubusercontent.com/hasura/ndc-spec/main/ndc-reference/tests/query/get_max_article_id/request.json", + responseURL: "https://raw.githubusercontent.com/hasura/ndc-spec/main/ndc-reference/tests/query/get_max_article_id/expected.json", + }, + // { + // name: "nested_array_select", + // requestURL: "https://raw.githubusercontent.com/hasura/ndc-spec/main/ndc-reference/tests/query/nested_array_select/request.json", + // responseURL: "https://raw.githubusercontent.com/hasura/ndc-spec/main/ndc-reference/tests/query/nested_array_select/expected.json", + // }, } for _, tc := range testCases { diff --git a/example/reference/data.go b/example/reference/data.go index 2f644ae7..cf685bb6 100644 --- a/example/reference/data.go +++ b/example/reference/data.go @@ -3,6 +3,7 @@ package main import ( _ "embed" "encoding/csv" + "encoding/json" "io" "sort" "strconv" @@ -15,6 +16,9 @@ var csvArticles string //go:embed authors.csv var csvAuthors string +//go:embed institutions.json +var jsonInstitutions []byte + func readAuthors() (map[int]Author, error) { r := csv.NewReader(strings.NewReader(csvAuthors)) results := make(map[int]Author) @@ -89,14 +93,6 @@ func readArticles() (map[int]Article, error) { return results, nil } -func getMapValues[K comparable, V any](input map[K]V) []V { - results := make([]V, 0, len(input)) - for _, v := range input { - results = append(results, v) - } - return results -} - func getMapKeys[K comparable, V any](input map[K]V) []K { results := make([]K, 0, len(input)) for k := range input { @@ -122,3 +118,16 @@ func sortArticles(input map[int]Article, key string, descending bool) []Article } return results } + +func readInstitutions() (map[int]Institution, error) { + var institutions []Institution + if err := json.Unmarshal(jsonInstitutions, &institutions); err != nil { + return nil, err + } + + results := make(map[int]Institution) + for _, inst := range institutions { + results[inst.ID] = inst + } + return results, nil +} diff --git a/example/reference/institutions.json b/example/reference/institutions.json new file mode 100644 index 00000000..4adc67bb --- /dev/null +++ b/example/reference/institutions.json @@ -0,0 +1,64 @@ +[ + { + "id": 1, + "name": "Queen Mary University of London", + "location": { + "city": "London", + "country": "UK", + "campuses": [ + "Mile End", + "Whitechapel", + "Charterhouse Square", + "West Smithfield" + ] + }, + "staff": [ + { + "first_name": "Peter", + "last_name": "Landin", + "specialities": ["Computer Science", "Education"] + } + ], + "departments": [ + "Humanities and Social Sciences", + "Science and Engineering", + "Medicine and Dentistry" + ] + }, + { + "id": 2, + "name": "Chalmers University of Technology", + "location": { + "city": "Gothenburg", + "country": "Sweden", + "campuses": ["Johanneberg", "Lindholmen"] + }, + "staff": [ + { + "first_name": "John", + "last_name": "Hughes", + "specialities": [ + "Computer Science", + "Functional Programming", + "Software Testing" + ] + }, + { + "first_name": "Koen", + "last_name": "Claessen", + "specialities": [ + "Computer Science", + "Functional Programming", + "Automated Reasoning" + ] + } + ], + "departments": [ + "Architecture and Civil Engineering", + "Computer Science and Engineering", + "Electrical Engineering", + "Physics", + "Industrial and Materials Science" + ] + } +] diff --git a/schema/extend.go b/schema/extend.go index 5b9d19c2..0d23d2a8 100644 --- a/schema/extend.go +++ b/schema/extend.go @@ -2331,15 +2331,15 @@ func (j *OrderByTarget) UnmarshalJSON(b []byte) error { } switch ty { case OrderByTargetTypeColumn: - rawColumn, ok := raw["column"] + rawName, ok := raw["name"] if !ok { - return errors.New("field column in OrderByTarget is required for column type") + return errors.New("field name in OrderByTarget is required for column type") } - var column string - if err := json.Unmarshal(rawColumn, &column); err != nil { - return fmt.Errorf("field column in OrderByTarget: %s", err) + var name string + if err := json.Unmarshal(rawName, &name); err != nil { + return fmt.Errorf("field name in OrderByTarget: %s", err) } - result["column"] = column + result["name"] = name rawPath, ok := raw["path"] if !ok { @@ -2425,9 +2425,9 @@ func (j OrderByTarget) AsColumn() (*OrderByColumn, error) { return nil, fmt.Errorf("invalid type; expected: %s, got: %s", OrderByTargetTypeColumn, t) } - column := getStringValueByKey(j, "column") - if column == "" { - return nil, errors.New("OrderByColumn.column is required") + name := getStringValueByKey(j, "name") + if name == "" { + return nil, errors.New("OrderByColumn.name is required") } rawPath, ok := j["path"] if !ok { @@ -2438,9 +2438,9 @@ func (j OrderByTarget) AsColumn() (*OrderByColumn, error) { return nil, fmt.Errorf("invalid OrderByColumn.path type; expected: []PathElement, got: %+v", rawPath) } return &OrderByColumn{ - Type: t, - Column: column, - Path: p, + Type: t, + Name: name, + Path: p, }, nil } @@ -2531,7 +2531,7 @@ type OrderByTargetEncoder interface { type OrderByColumn struct { Type OrderByTargetType `json:"type" mapstructure:"type"` // The name of the column - Column string `json:"column" mapstructure:"column"` + Name string `json:"name" mapstructure:"name"` // Any relationships to traverse to reach this column Path []PathElement `json:"path" mapstructure:"path"` } @@ -2539,9 +2539,9 @@ type OrderByColumn struct { // Encode converts the instance to raw OrderByTarget func (ob OrderByColumn) Encode() OrderByTarget { return OrderByTarget{ - "type": ob.Type, - "column": ob.Column, - "path": ob.Path, + "type": ob.Type, + "name": ob.Name, + "path": ob.Path, } } diff --git a/schema/schema.generated.go b/schema/schema.generated.go index 299efa38..9644805b 100644 --- a/schema/schema.generated.go +++ b/schema/schema.generated.go @@ -141,7 +141,7 @@ type MutationOperationResults struct { AffectedRows int `json:"affected_rows" yaml:"affected_rows" mapstructure:"affected_rows"` // The rows affected by the mutation operation - Returning []Row `json:"returning,omitempty" yaml:"returning,omitempty" mapstructure:"returning,omitempty"` + Returning []map[string]any `json:"returning,omitempty" yaml:"returning,omitempty" mapstructure:"returning,omitempty"` } type MutationRequest struct { @@ -314,7 +314,7 @@ type Relationship struct { } // Values to be provided to any collection arguments -type RelationshipArguments map[string]interface{} +type RelationshipArguments map[string]RelationshipArgument type RelationshipCapabilities struct { // Does the connector support ordering by an aggregated array relationship? @@ -777,14 +777,12 @@ func (j *RelationshipType) UnmarshalJSON(b []byte) error { // The results of the aggregates returned by the query type RowSetAggregates map[string]interface{} -type Row any - type RowSet struct { // The results of the aggregates returned by the query Aggregates RowSetAggregates `json:"aggregates,omitempty" yaml:"aggregates,omitempty" mapstructure:"aggregates,omitempty"` // The rows returned by the query, corresponding to the query's fields - Rows []Row `json:"rows,omitempty" yaml:"rows,omitempty" mapstructure:"rows,omitempty"` + Rows []map[string]any `json:"rows,omitempty" yaml:"rows,omitempty" mapstructure:"rows,omitempty"` } // UnmarshalJSON implements json.Unmarshaler. diff --git a/schema/utils.go b/schema/utils.go index 515a8505..540220eb 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -22,15 +22,6 @@ func ToAnySlice[V any](slice []V) []any { return results } -// ToRows converts a typed slice to Row slice -func ToRows[V any](slice []V) []Row { - results := make([]Row, len(slice)) - for i, v := range slice { - results[i] = v - } - return results -} - // Index returns the index of the first occurrence of item in slice, // or -1 if not present. func Index[E comparable](s []E, v E) int { @@ -95,13 +86,9 @@ func unmarshalStringFromJsonMap(collection map[string]json.RawMessage, key strin return result, nil } -// PruneFields prune unnecessary fields from selection -func PruneFields(fields map[string]Field, result any) (any, error) { - if len(fields) == 0 { - return result, nil - } - - if result == nil { +// EncodeRow encodes an object row to a map[string]any, using json tag to convert object keys +func EncodeRow(row any) (map[string]any, error) { + if row == nil { return nil, errors.New("expected object fields, got nil") } @@ -113,10 +100,24 @@ func PruneFields(fields map[string]Field, result any) (any, error) { if err != nil { return nil, err } - if err := decoder.Decode(result); err != nil { + if err := decoder.Decode(row); err != nil { return nil, err } + return outputMap, nil +} + +// PruneFields prune unnecessary fields from selection +func PruneFields(fields map[string]Field, result any) (map[string]any, error) { + outputMap, err := EncodeRow(result) + if err != nil { + return nil, err + } + + if len(fields) == 0 { + return outputMap, nil + } + output := make(map[string]any) for key, field := range fields { f, err := field.Interface() diff --git a/typegen/regenerate-schema.sh b/typegen/regenerate-schema.sh index c60cade5..b068c63a 100755 --- a/typegen/regenerate-schema.sh +++ b/typegen/regenerate-schema.sh @@ -25,10 +25,10 @@ sed -i 's/type MutationOperation interface{}//g' ../schema/schema.generated.go sed -i 's/type MutationRequestOperationsElem interface{}//g' ../schema/schema.generated.go sed -i 's/MutationRequestOperationsElem/MutationOperation/g' ../schema/schema.generated.go sed -i 's/QueryRequestArguments map\[string\]interface{}/QueryRequestArguments map[string]Argument/g' ../schema/schema.generated.go -sed -i 's/RowSetRowsElem map\[string\]interface{}/Row any/g' ../schema/schema.generated.go -sed -i 's/RowSetRowsElem/Row/g' ../schema/schema.generated.go +sed -i 's/type RowSetRowsElem map\[string\]interface{}//g' ../schema/schema.generated.go +sed -i 's/RowSetRowsElem/map[string]any/g' ../schema/schema.generated.go sed -i 's/type MutationOperationResultsReturningElem map\[string\]interface{}//g' ../schema/schema.generated.go -sed -i 's/MutationOperationResultsReturningElem/Row/g' ../schema/schema.generated.go +sed -i 's/MutationOperationResultsReturningElem/map[string]any/g' ../schema/schema.generated.go sed -i 's/Query interface{}/Query Query/g' ../schema/schema.generated.go sed -i 's/OrderBy interface{}/OrderBy *OrderBy/g' ../schema/schema.generated.go sed -i 's/type Expression interface{}//g' ../schema/schema.generated.go @@ -38,6 +38,7 @@ sed -i 's/Where interface{}/Where Expression/g' ../schema/schema.generated.go sed -i 's/type Aggregate interface{}//g' ../schema/schema.generated.go sed -i 's/type OrderByTarget interface{}//g' ../schema/schema.generated.go sed -i 's/QueryAggregates map\[string\]interface{}/QueryAggregates map[string]Aggregate/g' ../schema/schema.generated.go +sed -i 's/RelationshipArguments map\[string\]interface{}/RelationshipArguments map[string]RelationshipArgument/g' ../schema/schema.generated.go sed -i 's/Predicate interface{}/Predicate Expression/g' ../schema/schema.generated.go sed -i 's/type OrderByElementTarget interface{}//g' ../schema/schema.generated.go sed -i 's/OrderByElementTarget/OrderByTarget/g' ../schema/schema.generated.go