diff --git a/ext/dynblock/expr_wrap.go b/ext/dynblock/expr_wrap.go index 25481790..ba737c37 100644 --- a/ext/dynblock/expr_wrap.go +++ b/ext/dynblock/expr_wrap.go @@ -46,6 +46,13 @@ func (e exprWrap) Variables() []hcl.Traversal { return ret } +func (e exprWrap) Functions() []hcl.Traversal { + if fexpr, ok := e.Expression.(hcl.ExpressionWithFunctions); ok { + return fexpr.Functions() + } + return nil +} + func (e exprWrap) Value(ctx *hcl.EvalContext) (cty.Value, hcl.Diagnostics) { if e.i == nil { // If we don't have an active iteration then we can just use the diff --git a/ext/dynblock/functions.go b/ext/dynblock/functions.go new file mode 100644 index 00000000..574487d7 --- /dev/null +++ b/ext/dynblock/functions.go @@ -0,0 +1,228 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package dynblock + +import ( + "github.com/terramate-io/hcl/v2" + "github.com/terramate-io/hcl/v2/hclsyntax" + "github.com/zclconf/go-cty/cty" +) + +// This is duplicated from ext/dynblock/variables.go and modified to suit functions + +// WalkFunctions begins the recursive process of walking all expressions and +// nested blocks in the given body and its child bodies while taking into +// account any "dynamic" blocks. +// +// This function requires that the caller walk through the nested block +// structure in the given body level-by-level so that an appropriate schema +// can be provided at each level to inform further processing. This workflow +// is thus easiest to use for calling applications that have some higher-level +// schema representation available with which to drive this multi-step +// process. If your application uses the hcldec package, you may be able to +// use FunctionsHCLDec instead for a more automatic approach. +func WalkFunctions(body hcl.Body) WalkFunctionsNode { + return WalkFunctionsNode{ + body: body, + includeContent: true, + } +} + +// WalkExpandFunctions is like Functions but it includes only the functions +// required for successful block expansion, ignoring any functions referenced +// inside block contents. The result is the minimal set of all functions +// required for a call to Expand, excluding functions that would only be +// needed to subsequently call Content or PartialContent on the expanded +// body. +func WalkExpandFunctions(body hcl.Body) WalkFunctionsNode { + return WalkFunctionsNode{ + body: body, + } +} + +type WalkFunctionsNode struct { + body hcl.Body + it *iteration + + includeContent bool +} + +type WalkFunctionsChild struct { + BlockTypeName string + Node WalkFunctionsNode +} + +// Body returns the HCL Body associated with the child node, in case the caller +// wants to do some sort of inspection of it in order to decide what schema +// to pass to Visit. +// +// Most implementations should just fetch a fixed schema based on the +// BlockTypeName field and not access this. Deciding on a schema dynamically +// based on the body is a strange thing to do and generally necessary only if +// your caller is already doing other bizarre things with HCL bodies. +func (c WalkFunctionsChild) Body() hcl.Body { + return c.Node.body +} + +// exprFunctions handles the +func exprFunctions(expr hcl.Expression) []hcl.Traversal { + if ef, ok := expr.(hcl.ExpressionWithFunctions); ok { + return ef.Functions() + } + // hclsyntax Fallback + if hsexpr, ok := expr.(hclsyntax.Expression); ok { + return hclsyntax.Functions(hsexpr) + } + // Not exposed + return nil +} + +// Visit returns the function traversals required for any "dynamic" blocks +// directly in the body associated with this node, and also returns any child +// nodes that must be visited in order to continue the walk. +// +// Each child node has its associated block type name given in its BlockTypeName +// field, which the calling application should use to determine the appropriate +// schema for the content of each child node and pass it to the child node's +// own Visit method to continue the walk recursively. +func (n WalkFunctionsNode) Visit(schema *hcl.BodySchema) (vars []hcl.Traversal, children []WalkFunctionsChild) { + extSchema := n.extendSchema(schema) + container, _, _ := n.body.PartialContent(extSchema) + if container == nil { + return vars, children + } + + children = make([]WalkFunctionsChild, 0, len(container.Blocks)) + + if n.includeContent { + for _, attr := range container.Attributes { + for _, traversal := range exprFunctions(attr.Expr) { + var ours, inherited bool + if n.it != nil { + ours = traversal.RootName() == n.it.IteratorName + _, inherited = n.it.Inherited[traversal.RootName()] + } + + if !(ours || inherited) { + vars = append(vars, traversal) + } + } + } + } + + for _, block := range container.Blocks { + switch block.Type { + + case "dynamic": + blockTypeName := block.Labels[0] + inner, _, _ := block.Body.PartialContent(functionDetectionInnerSchema) + if inner == nil { + continue + } + + iteratorName := blockTypeName + if attr, exists := inner.Attributes["iterator"]; exists { + iterTraversal, _ := hcl.AbsTraversalForExpr(attr.Expr) + if len(iterTraversal) == 0 { + // Ignore this invalid dynamic block, since it'll produce + // an error if someone tries to extract content from it + // later anyway. + continue + } + iteratorName = iterTraversal.RootName() + } + blockIt := n.it.MakeChild(iteratorName, cty.DynamicVal, cty.DynamicVal) + + if attr, exists := inner.Attributes["for_each"]; exists { + // Filter out iterator names inherited from parent blocks + for _, traversal := range exprFunctions(attr.Expr) { + if _, inherited := blockIt.Inherited[traversal.RootName()]; !inherited { + vars = append(vars, traversal) + } + } + } + if attr, exists := inner.Attributes["labels"]; exists { + // Filter out both our own iterator name _and_ those inherited + // from parent blocks, since we provide _both_ of these to the + // label expressions. + for _, traversal := range exprFunctions(attr.Expr) { + ours := traversal.RootName() == iteratorName + _, inherited := blockIt.Inherited[traversal.RootName()] + + if !(ours || inherited) { + vars = append(vars, traversal) + } + } + } + + for _, contentBlock := range inner.Blocks { + // We only request "content" blocks in our schema, so we know + // any blocks we find here will be content blocks. We require + // exactly one content block for actual expansion, but we'll + // be more liberal here so that callers can still collect + // functions from erroneous "dynamic" blocks. + children = append(children, WalkFunctionsChild{ + BlockTypeName: blockTypeName, + Node: WalkFunctionsNode{ + body: contentBlock.Body, + it: blockIt, + includeContent: n.includeContent, + }, + }) + } + + default: + children = append(children, WalkFunctionsChild{ + BlockTypeName: block.Type, + Node: WalkFunctionsNode{ + body: block.Body, + it: n.it, + includeContent: n.includeContent, + }, + }) + + } + } + + return vars, children +} + +func (n WalkFunctionsNode) extendSchema(schema *hcl.BodySchema) *hcl.BodySchema { + // We augment the requested schema to also include our special "dynamic" + // block type, since then we'll get instances of it interleaved with + // all of the literal child blocks we must also include. + extSchema := &hcl.BodySchema{ + Attributes: schema.Attributes, + Blocks: make([]hcl.BlockHeaderSchema, len(schema.Blocks), len(schema.Blocks)+1), + } + copy(extSchema.Blocks, schema.Blocks) + extSchema.Blocks = append(extSchema.Blocks, dynamicBlockHeaderSchema) + + return extSchema +} + +// This is a more relaxed schema than what's in schema.go, since we +// want to maximize the amount of functions we can find even if there +// are erroneous blocks. +var functionDetectionInnerSchema = &hcl.BodySchema{ + Attributes: []hcl.AttributeSchema{ + { + Name: "for_each", + Required: false, + }, + { + Name: "labels", + Required: false, + }, + { + Name: "iterator", + Required: false, + }, + }, + Blocks: []hcl.BlockHeaderSchema{ + { + Type: "content", + }, + }, +} diff --git a/ext/dynblock/functions_hcldec.go b/ext/dynblock/functions_hcldec.go new file mode 100644 index 00000000..312d23d6 --- /dev/null +++ b/ext/dynblock/functions_hcldec.go @@ -0,0 +1,48 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package dynblock + +import ( + "github.com/terramate-io/hcl/v2" + "github.com/terramate-io/hcl/v2/hcldec" +) + +// This is duplicated from ext/dynblock/variables_hcldec.go and modified to suit functions + +// FunctionsHCLDec is a wrapper around WalkFunctions that uses the given hcldec +// specification to automatically drive the recursive walk through nested +// blocks in the given body. +// +// This is a drop-in replacement for hcldec.Functions which is able to treat +// blocks of type "dynamic" in the same special way that dynblock.Expand would, +// exposing both the functions referenced in the "for_each" and "labels" +// arguments and functions used in the nested "content" block. +func FunctionsHCLDec(body hcl.Body, spec hcldec.Spec) []hcl.Traversal { + rootNode := WalkFunctions(body) + return walkFunctionsWithHCLDec(rootNode, spec) +} + +// ExpandFunctionsHCLDec is like FunctionsHCLDec but it includes only the +// minimal set of functions required to call Expand, ignoring functions that +// are referenced only inside normal block contents. See WalkExpandFunctions +// for more information. +func ExpandFunctionsHCLDec(body hcl.Body, spec hcldec.Spec) []hcl.Traversal { + rootNode := WalkExpandFunctions(body) + return walkFunctionsWithHCLDec(rootNode, spec) +} + +func walkFunctionsWithHCLDec(node WalkFunctionsNode, spec hcldec.Spec) []hcl.Traversal { + vars, children := node.Visit(hcldec.ImpliedSchema(spec)) + + if len(children) > 0 { + childSpecs := hcldec.ChildBlockTypes(spec) + for _, child := range children { + if childSpec, exists := childSpecs[child.BlockTypeName]; exists { + vars = append(vars, walkFunctionsWithHCLDec(child.Node, childSpec)...) + } + } + } + + return vars +} diff --git a/ext/dynblock/functions_test.go b/ext/dynblock/functions_test.go new file mode 100644 index 00000000..b19234c2 --- /dev/null +++ b/ext/dynblock/functions_test.go @@ -0,0 +1,152 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package dynblock + +import ( + "reflect" + "testing" + + "github.com/terramate-io/hcl/v2/hcldec" + "github.com/zclconf/go-cty/cty" + + "github.com/davecgh/go-spew/spew" + + "github.com/terramate-io/hcl/v2" + "github.com/terramate-io/hcl/v2/hclsyntax" +) + +// This is heavily based on ext/dynblock/variables_test.go + +func TestFunctions(t *testing.T) { + const src = ` + +# We have some references to things inside the "val" attribute inside each +# of our "b" blocks, which should be included in the result of WalkFunctions +# but not WalkExpandFunctions. + +a { + dynamic "b" { + for_each = [for i, v in some_list_0: "${i}=${v},${baz}"] + labels = [b_label_func_0("${b.value}")] + content { + val = "${b_val_func_0(b.value)}" + } + } +} + +dynamic "a" { + for_each = b_fe_func_1(some_list_1) + + content { + b "foo" { + val = b_val_func_1("${a.value}") + } + + dynamic "b" { + for_each = b_fe_func_2(some_list_2) + iterator = dyn_b + labels = [b_label_func_2("${a.value} ${dyn_b.value}")] + content { + val = b_val_func_2("${a.value} ${dyn_b.value}") + } + } + } +} + +dynamic "a" { + for_each = b_fe_func_3(some_list_3) + iterator = dyn_a + + content { + b "foo" { + val = b_val_func_3("${dyn_a.value}") + } + + dynamic "b" { + for_each = b_fe_func_4(some_list_4) + labels = [b_label_func_4("${dyn_a.value} ${b.value}")] + content { + val = b_val_func_4("${dyn_a.value} ${b.value}") + } + } + } +} +` + + f, diags := hclsyntax.ParseConfig([]byte(src), "", hcl.Pos{}) + if len(diags) != 0 { + t.Errorf("unexpected diagnostics during parse") + for _, diag := range diags { + t.Logf("- %s", diag) + } + return + } + + spec := &hcldec.BlockListSpec{ + TypeName: "a", + Nested: &hcldec.BlockMapSpec{ + TypeName: "b", + LabelNames: []string{"key"}, + Nested: &hcldec.AttrSpec{ + Name: "val", + Type: cty.String, + }, + }, + } + + t.Run("WalkFunctions", func(t *testing.T) { + traversals := FunctionsHCLDec(f.Body, spec) + got := make([]string, len(traversals)) + for i, traversal := range traversals { + got[i] = traversal.RootName() + } + + // The block structure is traversed one level at a time, so the ordering + // here is reflecting first a pass of the root, then the first child + // under the root, then the first child under that, etc. + want := []string{ + "b_fe_func_1", + "b_fe_func_3", + "b_label_func_0", + "b_val_func_0", + "b_fe_func_2", + "b_label_func_2", + "b_val_func_1", + "b_val_func_2", + "b_fe_func_4", + "b_label_func_4", + "b_val_func_3", + "b_val_func_4", + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("wrong result\ngot: %swant: %s", spew.Sdump(got), spew.Sdump(want)) + } + }) + + t.Run("WalkExpandFunctions", func(t *testing.T) { + traversals := ExpandFunctionsHCLDec(f.Body, spec) + got := make([]string, len(traversals)) + for i, traversal := range traversals { + got[i] = traversal.RootName() + } + + // The block structure is traversed one level at a time, so the ordering + // here is reflecting first a pass of the root, then the first child + // under the root, then the first child under that, etc. + want := []string{ + "b_fe_func_1", + "b_fe_func_3", + "b_label_func_0", + "b_fe_func_2", + "b_label_func_2", + "b_fe_func_4", + "b_label_func_4", + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("wrong result\ngot: %swant: %s", spew.Sdump(got), spew.Sdump(want)) + } + }) +} diff --git a/gohcl/decode.go b/gohcl/decode.go index 2e3b1dde..ed69d8aa 100644 --- a/gohcl/decode.go +++ b/gohcl/decode.go @@ -123,6 +123,13 @@ func decodeBodyToStruct(body hcl.Body, ctx *hcl.EvalContext, val reflect.Value) fieldV.Set(reflect.ValueOf(attr)) case exprType.AssignableTo(field.Type): fieldV.Set(reflect.ValueOf(attr.Expr)) + case field.Type.Kind() == reflect.Ptr && field.Type.Elem().Kind() == reflect.Struct: + // TODO might want to check for nil here + rn := reflect.New(field.Type.Elem()) + fieldV.Set(rn) + diags = append(diags, DecodeExpression( + attr.Expr, ctx, fieldV.Interface(), + )...) default: diags = append(diags, DecodeExpression( attr.Expr, ctx, fieldV.Addr().Interface(), @@ -276,7 +283,9 @@ func decodeBlockToValue(block *hcl.Block, ctx *hcl.EvalContext, v reflect.Value) // DecodeExpression extracts the value of the given expression into the given // value. This value must be something that gocty is able to decode into, -// since the final decoding is delegated to that package. +// since the final decoding is delegated to that package. If a reference to +// a struct is provided which contains gohcl tags, it will be decoded using +// the attr and optional tags. // // The given EvalContext is used to resolve any variables or functions in // expressions encountered while decoding. This may be nil to require only @@ -290,20 +299,59 @@ func decodeBlockToValue(block *hcl.Block, ctx *hcl.EvalContext, v reflect.Value) // integration use-cases. func DecodeExpression(expr hcl.Expression, ctx *hcl.EvalContext, val interface{}) hcl.Diagnostics { srcVal, diags := expr.Value(ctx) + if diags.HasErrors() { + return diags + } + + return append(diags, DecodeValue(srcVal, expr.StartRange(), expr.Range(), val)...) +} + +// DecodeValue extracts the given value into the provided target. +// This value must be something that gocty is able to decode into, +// since the final decoding is delegated to that package. If a reference to +// a struct is provided which contains gohcl tags, it will be decoded using +// the attr and optional tags. +// +// The returned diagnostics should be inspected with its HasErrors method to +// determine if the populated value is valid and complete. If error diagnostics +// are returned then the given value may have been partially-populated but +// may still be accessed by a careful caller for static analysis and editor +// integration use-cases. +func DecodeValue(srcVal cty.Value, subject hcl.Range, context hcl.Range, val interface{}) hcl.Diagnostics { + rv := reflect.ValueOf(val) + if rv.Type().Kind() == reflect.Ptr && rv.Type().Elem().Kind() == reflect.Struct && hasFieldTags(rv.Elem().Type()) { + attrs := make(hcl.Attributes) + for k, v := range srcVal.AsValueMap() { + attrs[k] = &hcl.Attribute{ + Name: k, + Expr: hcl.StaticExpr(v, context), + Range: subject, + } + + } + return decodeBodyToStruct(synthBody{ + attrs: attrs, + subject: subject, + context: context, + }, nil, rv.Elem()) + + } convTy, err := gocty.ImpliedType(val) if err != nil { panic(fmt.Sprintf("unsuitable DecodeExpression target: %s", err)) } + var diags hcl.Diagnostics + srcVal, err = convert.Convert(srcVal, convTy) if err != nil { diags = append(diags, &hcl.Diagnostic{ Severity: hcl.DiagError, Summary: "Unsuitable value type", Detail: fmt.Sprintf("Unsuitable value: %s", err.Error()), - Subject: expr.StartRange().Ptr(), - Context: expr.Range().Ptr(), + Subject: subject.Ptr(), + Context: context.Ptr(), }) return diags } @@ -314,10 +362,80 @@ func DecodeExpression(expr hcl.Expression, ctx *hcl.EvalContext, val interface{} Severity: hcl.DiagError, Summary: "Unsuitable value type", Detail: fmt.Sprintf("Unsuitable value: %s", err.Error()), - Subject: expr.StartRange().Ptr(), - Context: expr.Range().Ptr(), + Subject: subject.Ptr(), + Context: context.Ptr(), }) } return diags } + +type synthBody struct { + attrs hcl.Attributes + subject hcl.Range + context hcl.Range +} + +func (s synthBody) Content(schema *hcl.BodySchema) (*hcl.BodyContent, hcl.Diagnostics) { + body, partial, diags := s.PartialContent(schema) + + attrs, _ := partial.JustAttributes() + for name := range attrs { + diags = append(diags, &hcl.Diagnostic{ + Severity: hcl.DiagError, + Summary: "Unsupported argument", + Detail: fmt.Sprintf("An argument named %q is not expected here.", name), + Subject: s.subject.Ptr(), + Context: s.context.Ptr(), + }) + } + + return body, diags +} + +func (s synthBody) PartialContent(schema *hcl.BodySchema) (*hcl.BodyContent, hcl.Body, hcl.Diagnostics) { + var diags hcl.Diagnostics + + for _, block := range schema.Blocks { + panic("hcl block tags are not allowed in attribute structs: " + block.Type) + } + + attrs := make(hcl.Attributes) + remainder := make(hcl.Attributes) + + for _, attr := range schema.Attributes { + v, ok := s.attrs[attr.Name] + if !ok { + if attr.Required { + diags = append(diags, &hcl.Diagnostic{ + Severity: hcl.DiagError, + Summary: "Missing required argument", + Detail: fmt.Sprintf("The argument %q is required, but no definition was found.", attr.Name), + Subject: s.subject.Ptr(), + Context: s.context.Ptr(), + }) + } + continue + } + + attrs[attr.Name] = v + } + + for k, v := range s.attrs { + if _, ok := attrs[k]; !ok { + remainder[k] = v + } + } + + return &hcl.BodyContent{ + Attributes: attrs, + MissingItemRange: s.context, + }, synthBody{attrs: remainder}, diags +} + +func (s synthBody) JustAttributes() (hcl.Attributes, hcl.Diagnostics) { + return s.attrs, nil +} +func (s synthBody) MissingItemRange() hcl.Range { + return s.context +} diff --git a/gohcl/decode_test.go b/gohcl/decode_test.go index 02f7cacd..eb3b46f7 100644 --- a/gohcl/decode_test.go +++ b/gohcl/decode_test.go @@ -806,6 +806,10 @@ func (e *fixedExpression) Variables() []hcl.Traversal { return nil } +func (e *fixedExpression) Functions() []hcl.Traversal { + return nil +} + func makeInstantiateType(target interface{}) func() interface{} { return func() interface{} { return reflect.New(reflect.TypeOf(target)).Interface() diff --git a/gohcl/doc.go b/gohcl/doc.go index cfec2530..5e103a6a 100644 --- a/gohcl/doc.go +++ b/gohcl/doc.go @@ -10,18 +10,18 @@ // A struct field tag scheme is used, similar to other decoding and // unmarshalling libraries. The tags are formatted as in the following example: // -// ThingType string `hcl:"thing_type,attr"` +// ThingType string `hcl:"thing_type,attr"` // // Within each tag there are two comma-separated tokens. The first is the // name of the corresponding construct in configuration, while the second // is a keyword giving the kind of construct expected. The following // kind keywords are supported: // -// attr (the default) indicates that the value is to be populated from an attribute -// block indicates that the value is to populated from a block -// label indicates that the value is to populated from a block label -// optional is the same as attr, but the field is optional -// remain indicates that the value is to be populated from the remaining body after populating other fields +// attr (the default) indicates that the value is to be populated from an attribute +// block indicates that the value is to populated from a block +// label indicates that the value is to populated from a block label +// optional is the same as attr, but the field is optional +// remain indicates that the value is to be populated from the remaining body after populating other fields // // "attr" fields may either be of type *hcl.Expression, in which case the raw // expression is assigned, or of any type accepted by gocty, in which case diff --git a/gohcl/schema.go b/gohcl/schema.go index 98b994aa..37bdb002 100644 --- a/gohcl/schema.go +++ b/gohcl/schema.go @@ -111,6 +111,18 @@ func ImpliedBodySchema(val interface{}) (schema *hcl.BodySchema, partial bool) { return schema, partial } +func hasFieldTags(ty reflect.Type) bool { + ct := ty.NumField() + for i := 0; i < ct; i++ { + field := ty.Field(i) + tag := field.Tag.Get("hcl") + if tag != "" { + return true + } + } + return false +} + type fieldTags struct { Attributes map[string]int Blocks map[string]int diff --git a/gohcl/vardecode.go b/gohcl/vardecode.go new file mode 100644 index 00000000..7f998fa1 --- /dev/null +++ b/gohcl/vardecode.go @@ -0,0 +1,99 @@ +// Copyright (c) The OpenTofu Authors +// SPDX-License-Identifier: MPL-2.0 +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package gohcl + +import ( + "fmt" + "reflect" + + "github.com/terramate-io/hcl/v2" +) + +func VariablesInBody(body hcl.Body, val interface{}) ([]hcl.Traversal, hcl.Diagnostics) { + rv := reflect.ValueOf(val) + if rv.Kind() != reflect.Ptr { + panic(fmt.Sprintf("target value must be a pointer, not %s", rv.Type().String())) + } + + return findVariablesInBody(body, rv.Elem()) +} + +func findVariablesInBody(body hcl.Body, val reflect.Value) ([]hcl.Traversal, hcl.Diagnostics) { + et := val.Type() + switch et.Kind() { + case reflect.Struct: + return findVariablesInBodyStruct(body, val) + case reflect.Map: + return findVariablesInBodyMap(body, val) + default: + panic(fmt.Sprintf("target value must be pointer to struct or map, not %s", et.String())) + } +} + +func findVariablesInBodyStruct(body hcl.Body, val reflect.Value) ([]hcl.Traversal, hcl.Diagnostics) { + var variables []hcl.Traversal + + schema, partial := ImpliedBodySchema(val.Interface()) + + var content *hcl.BodyContent + var diags hcl.Diagnostics + if partial { + content, _, diags = body.PartialContent(schema) + } else { + content, diags = body.Content(schema) + } + if content == nil { + return variables, diags + } + + tags := getFieldTags(val.Type()) + + for name := range tags.Attributes { + attr := content.Attributes[name] + if attr != nil { + variables = append(variables, attr.Expr.Variables()...) + } + } + + blocksByType := content.Blocks.ByType() + + for typeName, fieldIdx := range tags.Blocks { + blocks := blocksByType[typeName] + field := val.Type().Field(fieldIdx) + + ty := field.Type + if ty.Kind() == reflect.Slice { + ty = ty.Elem() + } + if ty.Kind() == reflect.Ptr { + ty = ty.Elem() + } + + for _, block := range blocks { + blockVars, blockDiags := findVariablesInBody(block.Body, reflect.New(ty).Elem()) + variables = append(variables, blockVars...) + diags = append(diags, blockDiags...) + } + + } + + return variables, diags +} + +func findVariablesInBodyMap(body hcl.Body, v reflect.Value) ([]hcl.Traversal, hcl.Diagnostics) { + var variables []hcl.Traversal + + attrs, diags := body.JustAttributes() + if attrs == nil { + return variables, diags + } + + for _, attr := range attrs { + variables = append(variables, attr.Expr.Variables()...) + } + + return variables, diags +} diff --git a/gohcl/vardecode_test.go b/gohcl/vardecode_test.go new file mode 100644 index 00000000..f9cf8a88 --- /dev/null +++ b/gohcl/vardecode_test.go @@ -0,0 +1,84 @@ +// Copyright (c) The OpenTofu Authors +// SPDX-License-Identifier: MPL-2.0 +// Copyright (c) 2023 HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package gohcl_test + +import ( + "fmt" + "testing" + + "github.com/terramate-io/hcl/v2" + "github.com/terramate-io/hcl/v2/gohcl" + "github.com/terramate-io/hcl/v2/hclsyntax" + "github.com/zclconf/go-cty/cty" +) + +var data = ` +inner "foo" "bar" { + val = magic.foo.bar + data = { + "z" = nested.value + } +} +` + +type InnerBlock struct { + Type string `hcl:"type,label"` + Name string `hcl:"name,label"` + Value string `hcl:"val"` + Data map[string]string `hcl:"data"` +} + +type OuterBlock struct { + Contents InnerBlock `hcl:"inner,block"` +} + +func Test(t *testing.T) { + + println("> Parse HCL") + file, diags := hclsyntax.ParseConfig([]byte(data), "INLINE", hcl.Pos{Byte: 0, Line: 1, Column: 1}) + + println(diags.Error()) + + ob := &OuterBlock{} + + println() + println("> Detect Variables") + vars, diags := gohcl.VariablesInBody(file.Body, ob) + println(diags.Error()) + for _, v := range vars { + ident := "" + for _, p := range v { + if root, ok := p.(hcl.TraverseRoot); ok { + ident += root.Name + } + if attr, ok := p.(hcl.TraverseAttr); ok { + ident += "." + attr.Name + } + } + println("Required: " + ident) + } + + println() + println("> Decode Body") + + ctx := &hcl.EvalContext{ + Variables: map[string]cty.Value{ + "magic": cty.ObjectVal(map[string]cty.Value{ + "foo": cty.ObjectVal(map[string]cty.Value{ + "bar": cty.StringVal("BAR IS BEST BAR"), + }), + }), + "nested": cty.ObjectVal(map[string]cty.Value{ + "value": cty.StringVal("ZISHERE"), + }), + }, + } + + diags = gohcl.DecodeBody(file.Body, ctx, ob) + println(diags.Error()) + + fmt.Printf("%#v\n", ob) +} diff --git a/hcldec/functions.go b/hcldec/functions.go new file mode 100644 index 00000000..eeee6502 --- /dev/null +++ b/hcldec/functions.go @@ -0,0 +1,41 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package hcldec + +import ( + "github.com/terramate-io/hcl/v2" +) + +// This is based off of hcldec/variables.go + +// Functions processes the given body with the given spec and returns a +// list of the function traversals that would be required to decode +// the same pairing of body and spec. +// +// This can be used to conditionally populate the functions in the EvalContext +// passed to Decode, for applications where a static scope is insufficient. +// +// If the given body is not compliant with the given schema, the result may +// be incomplete, but that's assumed to be okay because the eventual call +// to Decode will produce error diagnostics anyway. +func Functions(body hcl.Body, spec Spec) []hcl.Traversal { + var funcs []hcl.Traversal + schema := ImpliedSchema(spec) + content, _, _ := body.PartialContent(schema) + + if vs, ok := spec.(specNeedingFunctions); ok { + funcs = append(funcs, vs.functionsNeeded(content)...) + } + + var visitFn visitFunc + visitFn = func(s Spec) { + if vs, ok := s.(specNeedingFunctions); ok { + funcs = append(funcs, vs.functionsNeeded(content)...) + } + s.visitSameBodyChildren(visitFn) + } + spec.visitSameBodyChildren(visitFn) + + return funcs +} diff --git a/hcldec/functions_test.go b/hcldec/functions_test.go new file mode 100644 index 00000000..5590075d --- /dev/null +++ b/hcldec/functions_test.go @@ -0,0 +1,217 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package hcldec + +import ( + "fmt" + "reflect" + "testing" + + "github.com/terramate-io/hcl/v2" + "github.com/terramate-io/hcl/v2/hclsyntax" + "github.com/zclconf/go-cty/cty" +) + +// This is inspired by hcldec/variables_test.go + +func TestFunctions(t *testing.T) { + tests := []struct { + config string + spec Spec + want []hcl.Traversal + }{ + { + ``, + &ObjectSpec{}, + nil, + }, + { + "a = foo()\n", + &ObjectSpec{}, + nil, // "a" is not actually used, so "foo" is not required + }, + { + "a = foo()\n", + &AttrSpec{ + Name: "a", + }, + []hcl.Traversal{ + { + hcl.TraverseRoot{ + Name: "foo", + SrcRange: hcl.Range{ + Start: hcl.Pos{Line: 1, Column: 5, Byte: 4}, + End: hcl.Pos{Line: 1, Column: 8, Byte: 7}, + }, + }, + }, + }, + }, + { + "a = foo()\nb = bar()\n", + &DefaultSpec{ + Primary: &AttrSpec{ + Name: "a", + }, + Default: &AttrSpec{ + Name: "b", + }, + }, + []hcl.Traversal{ + { + hcl.TraverseRoot{ + Name: "foo", + SrcRange: hcl.Range{ + Start: hcl.Pos{Line: 1, Column: 5, Byte: 4}, + End: hcl.Pos{Line: 1, Column: 8, Byte: 7}, + }, + }, + }, + { + hcl.TraverseRoot{ + Name: "bar", + SrcRange: hcl.Range{ + Start: hcl.Pos{Line: 2, Column: 5, Byte: 14}, + End: hcl.Pos{Line: 2, Column: 8, Byte: 17}, + }, + }, + }, + }, + }, + { + "a = foo()\n", + &ObjectSpec{ + "a": &AttrSpec{ + Name: "a", + }, + }, + []hcl.Traversal{ + { + hcl.TraverseRoot{ + Name: "foo", + SrcRange: hcl.Range{ + Start: hcl.Pos{Line: 1, Column: 5, Byte: 4}, + End: hcl.Pos{Line: 1, Column: 8, Byte: 7}, + }, + }, + }, + }, + }, + { + ` +b { + a = foo() +} +`, + &BlockSpec{ + TypeName: "b", + Nested: &AttrSpec{ + Name: "a", + }, + }, + []hcl.Traversal{ + { + hcl.TraverseRoot{ + Name: "foo", + SrcRange: hcl.Range{ + Start: hcl.Pos{Line: 3, Column: 7, Byte: 11}, + End: hcl.Pos{Line: 3, Column: 10, Byte: 14}, + }, + }, + }, + }, + }, + { + ` +b { + a = foo() + b = bar() +} + `, + &BlockAttrsSpec{ + TypeName: "b", + ElementType: cty.String, + }, + []hcl.Traversal{ + { + hcl.TraverseRoot{ + Name: "foo", + SrcRange: hcl.Range{ + Start: hcl.Pos{Line: 3, Column: 7, Byte: 11}, + End: hcl.Pos{Line: 3, Column: 10, Byte: 14}, + }, + }, + }, + { + hcl.TraverseRoot{ + Name: "bar", + SrcRange: hcl.Range{ + Start: hcl.Pos{Line: 4, Column: 7, Byte: 23}, + End: hcl.Pos{Line: 4, Column: 10, Byte: 26}, + }, + }, + }, + }, + }, + { + ` +b { + a = foo() +} +b { + a = bar() +} +c { + a = baz() +} +`, + &BlockListSpec{ + TypeName: "b", + Nested: &AttrSpec{ + Name: "a", + }, + }, + []hcl.Traversal{ + { + hcl.TraverseRoot{ + Name: "foo", + SrcRange: hcl.Range{ + Start: hcl.Pos{Line: 3, Column: 7, Byte: 11}, + End: hcl.Pos{Line: 3, Column: 10, Byte: 14}, + }, + }, + }, + { + hcl.TraverseRoot{ + Name: "bar", + SrcRange: hcl.Range{ + Start: hcl.Pos{Line: 6, Column: 7, Byte: 29}, + End: hcl.Pos{Line: 6, Column: 10, Byte: 32}, + }, + }, + }, + }, + }, /**/ + } + + for i, test := range tests { + t.Run(fmt.Sprintf("%02d-%s", i, test.config), func(t *testing.T) { + file, diags := hclsyntax.ParseConfig([]byte(test.config), "", hcl.Pos{Line: 1, Column: 1, Byte: 0}) + if len(diags) != 0 { + t.Errorf("wrong number of diagnostics from ParseConfig %d; want %d", len(diags), 0) + for _, diag := range diags { + t.Logf(" - %s", diag.Error()) + } + } + body := file.Body + + got := Functions(body, test.spec) + + if !reflect.DeepEqual(got, test.want) { + t.Errorf("wrong result\ngot: %#v\nwant: %#v", got, test.want) + } + }) + } + +} diff --git a/hcldec/spec.go b/hcldec/spec.go index b50cd305..84f69c16 100644 --- a/hcldec/spec.go +++ b/hcldec/spec.go @@ -67,6 +67,12 @@ type specNeedingVariables interface { variablesNeeded(content *hcl.BodyContent) []hcl.Traversal } +// specNeedingFunctions is implemented by specs that can use functions +// from the EvalContext, to declare which functions they need. +type specNeedingFunctions interface { + functionsNeeded(content *hcl.BodyContent) []hcl.Traversal +} + // UnknownBody can be optionally implemented by an hcl.Body instance which may // be entirely unknown. type UnknownBody interface { @@ -182,6 +188,19 @@ func (s *AttrSpec) variablesNeeded(content *hcl.BodyContent) []hcl.Traversal { return attr.Expr.Variables() } +// specNeedingFunctions implementation +func (s *AttrSpec) functionsNeeded(content *hcl.BodyContent) []hcl.Traversal { + attr, exists := content.Attributes[s.Name] + if !exists { + return nil + } + + if fexpr, ok := attr.Expr.(hcl.ExpressionWithFunctions); ok { + return fexpr.Functions() + } + return nil +} + // attrSpec implementation func (s *AttrSpec) attrSchemata() []hcl.AttributeSchema { return []hcl.AttributeSchema{ @@ -288,6 +307,14 @@ func (s *ExprSpec) variablesNeeded(content *hcl.BodyContent) []hcl.Traversal { return s.Expr.Variables() } +// specNeedingFunctions implementation +func (s *ExprSpec) functionsNeeded(content *hcl.BodyContent) []hcl.Traversal { + if fexpr, ok := s.Expr.(hcl.ExpressionWithFunctions); ok { + return fexpr.Functions() + } + return nil +} + func (s *ExprSpec) decode(content *hcl.BodyContent, blockLabels []blockLabel, ctx *hcl.EvalContext) (cty.Value, hcl.Diagnostics) { return s.Expr.Value(ctx) } @@ -351,6 +378,25 @@ func (s *BlockSpec) variablesNeeded(content *hcl.BodyContent) []hcl.Traversal { return Variables(childBlock.Body, s.Nested) } +// specNeedingFunctions implementation +func (s *BlockSpec) functionsNeeded(content *hcl.BodyContent) []hcl.Traversal { + var childBlock *hcl.Block + for _, candidate := range content.Blocks { + if candidate.Type != s.TypeName { + continue + } + + childBlock = candidate + break + } + + if childBlock == nil { + return nil + } + + return Functions(childBlock.Body, s.Nested) +} + func (s *BlockSpec) decode(content *hcl.BodyContent, blockLabels []blockLabel, ctx *hcl.EvalContext) (cty.Value, hcl.Diagnostics) { var diags hcl.Diagnostics @@ -463,6 +509,21 @@ func (s *BlockListSpec) variablesNeeded(content *hcl.BodyContent) []hcl.Traversa return ret } +// specNeedingFunctions implementation +func (s *BlockListSpec) functionsNeeded(content *hcl.BodyContent) []hcl.Traversal { + var ret []hcl.Traversal + + for _, childBlock := range content.Blocks { + if childBlock.Type != s.TypeName { + continue + } + + ret = append(ret, Functions(childBlock.Body, s.Nested)...) + } + + return ret +} + func (s *BlockListSpec) decode(content *hcl.BodyContent, blockLabels []blockLabel, ctx *hcl.EvalContext) (cty.Value, hcl.Diagnostics) { var diags hcl.Diagnostics @@ -626,6 +687,21 @@ func (s *BlockTupleSpec) variablesNeeded(content *hcl.BodyContent) []hcl.Travers return ret } +// specNeedingFunctions implementation +func (s *BlockTupleSpec) functionsNeeded(content *hcl.BodyContent) []hcl.Traversal { + var ret []hcl.Traversal + + for _, childBlock := range content.Blocks { + if childBlock.Type != s.TypeName { + continue + } + + ret = append(ret, Functions(childBlock.Body, s.Nested)...) + } + + return ret +} + func (s *BlockTupleSpec) decode(content *hcl.BodyContent, blockLabels []blockLabel, ctx *hcl.EvalContext) (cty.Value, hcl.Diagnostics) { var diags hcl.Diagnostics @@ -749,6 +825,21 @@ func (s *BlockSetSpec) variablesNeeded(content *hcl.BodyContent) []hcl.Traversal return ret } +// specNeedingFunctions implementation +func (s *BlockSetSpec) functionsNeeded(content *hcl.BodyContent) []hcl.Traversal { + var ret []hcl.Traversal + + for _, childBlock := range content.Blocks { + if childBlock.Type != s.TypeName { + continue + } + + ret = append(ret, Functions(childBlock.Body, s.Nested)...) + } + + return ret +} + func (s *BlockSetSpec) decode(content *hcl.BodyContent, blockLabels []blockLabel, ctx *hcl.EvalContext) (cty.Value, hcl.Diagnostics) { var diags hcl.Diagnostics @@ -911,6 +1002,21 @@ func (s *BlockMapSpec) variablesNeeded(content *hcl.BodyContent) []hcl.Traversal return ret } +// specNeedingFunctions implementation +func (s *BlockMapSpec) functionsNeeded(content *hcl.BodyContent) []hcl.Traversal { + var ret []hcl.Traversal + + for _, childBlock := range content.Blocks { + if childBlock.Type != s.TypeName { + continue + } + + ret = append(ret, Functions(childBlock.Body, s.Nested)...) + } + + return ret +} + func (s *BlockMapSpec) decode(content *hcl.BodyContent, blockLabels []blockLabel, ctx *hcl.EvalContext) (cty.Value, hcl.Diagnostics) { var diags hcl.Diagnostics @@ -1069,6 +1175,21 @@ func (s *BlockObjectSpec) variablesNeeded(content *hcl.BodyContent) []hcl.Traver return ret } +// specNeedingFunctions implementation +func (s *BlockObjectSpec) functionsNeeded(content *hcl.BodyContent) []hcl.Traversal { + var ret []hcl.Traversal + + for _, childBlock := range content.Blocks { + if childBlock.Type != s.TypeName { + continue + } + + ret = append(ret, Functions(childBlock.Body, s.Nested)...) + } + + return ret +} + func (s *BlockObjectSpec) decode(content *hcl.BodyContent, blockLabels []blockLabel, ctx *hcl.EvalContext) (cty.Value, hcl.Diagnostics) { var diags hcl.Diagnostics @@ -1246,6 +1367,36 @@ func (s *BlockAttrsSpec) variablesNeeded(content *hcl.BodyContent) []hcl.Travers return vars } +// specNeedingFunctions implementation +func (s *BlockAttrsSpec) functionsNeeded(content *hcl.BodyContent) []hcl.Traversal { + + block, _ := s.findBlock(content) + if block == nil { + return nil + } + + var funcs []hcl.Traversal + + attrs, diags := block.Body.JustAttributes() + if diags.HasErrors() { + return nil + } + + for _, attr := range attrs { + if fexpr, ok := attr.Expr.(hcl.ExpressionWithFunctions); ok { + funcs = append(funcs, fexpr.Functions()...) + } + } + + // We'll return the functions references in source order so that any + // error messages that result are also in source order. + sort.Slice(funcs, func(i, j int) bool { + return funcs[i].SourceRange().Start.Byte < funcs[j].SourceRange().Start.Byte + }) + + return funcs +} + func (s *BlockAttrsSpec) decode(content *hcl.BodyContent, blockLabels []blockLabel, ctx *hcl.EvalContext) (cty.Value, hcl.Diagnostics) { var diags hcl.Diagnostics diff --git a/hclsyntax/expression_funcs.go b/hclsyntax/expression_funcs.go new file mode 100755 index 00000000..82af07cc --- /dev/null +++ b/hclsyntax/expression_funcs.go @@ -0,0 +1,83 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package hclsyntax + +// Generated by expression_vars_gen.go. DO NOT EDIT. +// Run 'go generate' on this package to update the set of functions here. + +import ( + "github.com/terramate-io/hcl/v2" +) + +func (e *AnonSymbolExpr) Functions() []hcl.Traversal { + return Functions(e) +} + +func (e *BinaryOpExpr) Functions() []hcl.Traversal { + return Functions(e) +} + +func (e *ConditionalExpr) Functions() []hcl.Traversal { + return Functions(e) +} + +func (e *ExprSyntaxError) Functions() []hcl.Traversal { + return Functions(e) +} + +func (e *ForExpr) Functions() []hcl.Traversal { + return Functions(e) +} + +func (e *FunctionCallExpr) Functions() []hcl.Traversal { + return Functions(e) +} + +func (e *IndexExpr) Functions() []hcl.Traversal { + return Functions(e) +} + +func (e *LiteralValueExpr) Functions() []hcl.Traversal { + return Functions(e) +} + +func (e *ObjectConsExpr) Functions() []hcl.Traversal { + return Functions(e) +} + +func (e *ObjectConsKeyExpr) Functions() []hcl.Traversal { + return Functions(e) +} + +func (e *RelativeTraversalExpr) Functions() []hcl.Traversal { + return Functions(e) +} + +func (e *ScopeTraversalExpr) Functions() []hcl.Traversal { + return Functions(e) +} + +func (e *SplatExpr) Functions() []hcl.Traversal { + return Functions(e) +} + +func (e *TemplateExpr) Functions() []hcl.Traversal { + return Functions(e) +} + +func (e *TemplateJoinExpr) Functions() []hcl.Traversal { + return Functions(e) +} + +func (e *TemplateWrapExpr) Functions() []hcl.Traversal { + return Functions(e) +} + +func (e *TupleConsExpr) Functions() []hcl.Traversal { + return Functions(e) +} + +func (e *UnaryOpExpr) Functions() []hcl.Traversal { + return Functions(e) +} diff --git a/hclsyntax/expression_funcs_gen.go b/hclsyntax/expression_funcs_gen.go new file mode 100644 index 00000000..c77f951e --- /dev/null +++ b/hclsyntax/expression_funcs_gen.go @@ -0,0 +1,106 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +// This is a 'go generate'-oriented program for producing the "Functions" +// method on every Expression implementation found within this package. +// All expressions share the same implementation for this method, which +// just wraps the package-level function "Functions" and uses an AST walk +// to do its work. + +//go:build ignore +// +build ignore + +package main + +import ( + "fmt" + "go/ast" + "go/parser" + "go/token" + "os" + "sort" +) + +func main() { + fs := token.NewFileSet() + pkgs, err := parser.ParseDir(fs, ".", nil, 0) + if err != nil { + fmt.Fprintf(os.Stderr, "error while parsing: %s\n", err) + os.Exit(1) + } + pkg := pkgs["hclsyntax"] + + // Walk all the files and collect the receivers of any "Value" methods + // that look like they are trying to implement Expression. + var recvs []string + for _, f := range pkg.Files { + for _, decl := range f.Decls { + fd, ok := decl.(*ast.FuncDecl) + if !ok { + continue + } + if fd.Name.Name != "Value" { + continue + } + results := fd.Type.Results.List + if len(results) != 2 { + continue + } + valResult := fd.Type.Results.List[0].Type.(*ast.SelectorExpr).X.(*ast.Ident) + diagsResult := fd.Type.Results.List[1].Type.(*ast.SelectorExpr).X.(*ast.Ident) + + if valResult.Name != "cty" && diagsResult.Name != "hcl" { + continue + } + + // If we have a method called Value and it returns something in + // "cty" followed by something in "hcl" then that's specific enough + // for now, even though this is not 100% exact as a correct + // implementation of Value. + + recvTy := fd.Recv.List[0].Type + + switch rtt := recvTy.(type) { + case *ast.StarExpr: + name := rtt.X.(*ast.Ident).Name + recvs = append(recvs, fmt.Sprintf("*%s", name)) + default: + fmt.Fprintf(os.Stderr, "don't know what to do with a %T receiver\n", recvTy) + } + + } + } + + sort.Strings(recvs) + + of, err := os.OpenFile("expression_funcs.go", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to open output file: %s\n", err) + os.Exit(1) + } + + fmt.Fprint(of, outputPreamble) + for _, recv := range recvs { + fmt.Fprintf(of, outputMethodFmt, recv) + } + fmt.Fprint(of, "\n") + +} + +const outputPreamble = `// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package hclsyntax + +// Generated by expression_vars_gen.go. DO NOT EDIT. +// Run 'go generate' on this package to update the set of functions here. + +import ( + "github.com/terramate-io/hcl/v2" +)` + +const outputMethodFmt = ` + +func (e %s) Functions() []hcl.Traversal { + return Functions(e) +}` diff --git a/hclsyntax/functions.go b/hclsyntax/functions.go new file mode 100644 index 00000000..cbd45f31 --- /dev/null +++ b/hclsyntax/functions.go @@ -0,0 +1,29 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package hclsyntax + +import ( + "github.com/terramate-io/hcl/v2" +) + +func Functions(expr Expression) []hcl.Traversal { + walker := make(fnWalker, 0) + Walk(expr, &walker) + return walker +} + +type fnWalker []hcl.Traversal + +func (w *fnWalker) Enter(node Node) hcl.Diagnostics { + if fn, ok := node.(*FunctionCallExpr); ok { + *w = append(*w, hcl.Traversal{hcl.TraverseRoot{ + Name: fn.Name, + SrcRange: fn.NameRange, + }}) + } + return nil +} +func (w *fnWalker) Exit(node Node) hcl.Diagnostics { + return nil +} diff --git a/hclsyntax/functions_test.go b/hclsyntax/functions_test.go new file mode 100644 index 00000000..fe2d046b --- /dev/null +++ b/hclsyntax/functions_test.go @@ -0,0 +1,227 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package hclsyntax + +import ( + "fmt" + "reflect" + "testing" + + "github.com/davecgh/go-spew/spew" + "github.com/terramate-io/hcl/v2" + "github.com/zclconf/go-cty/cty" +) + +// Covers similar cases to hclsyntax/variables_test.go + +func TestFunctions(t *testing.T) { + tests := []struct { + Expr Expression + Want []hcl.Traversal + }{ + { + &LiteralValueExpr{ + Val: cty.True, + }, + []hcl.Traversal{}, + }, + { + &FunctionCallExpr{ + Name: "funky", + }, + []hcl.Traversal{ + { + hcl.TraverseRoot{ + Name: "funky", + }, + }, + }, + }, + { + &BinaryOpExpr{ + LHS: &FunctionCallExpr{ + Name: "lhs", + }, + Op: OpAdd, + RHS: &FunctionCallExpr{ + Name: "rhs", + }, + }, + []hcl.Traversal{ + { + hcl.TraverseRoot{ + Name: "lhs", + }, + }, + { + hcl.TraverseRoot{ + Name: "rhs", + }, + }, + }, + }, + { + &UnaryOpExpr{ + Val: &FunctionCallExpr{ + Name: "neg", + }, + Op: OpNegate, + }, + []hcl.Traversal{ + { + hcl.TraverseRoot{ + Name: "neg", + }, + }, + }, + }, + { + &ConditionalExpr{ + Condition: &FunctionCallExpr{ + Name: "cond", + }, + TrueResult: &FunctionCallExpr{ + Name: "true", + }, + FalseResult: &FunctionCallExpr{ + Name: "false", + }, + }, + []hcl.Traversal{ + { + hcl.TraverseRoot{ + Name: "cond", + }, + }, + { + hcl.TraverseRoot{ + Name: "true", + }, + }, + { + hcl.TraverseRoot{ + Name: "false", + }, + }, + }, + }, + { + &ForExpr{ + KeyVar: "k", + ValVar: "v", + + CollExpr: &FunctionCallExpr{ + Name: "coll", + }, + KeyExpr: &BinaryOpExpr{ + LHS: &FunctionCallExpr{ + Name: "key_lhs", + }, + Op: OpAdd, + RHS: &FunctionCallExpr{ + Name: "key_rhs", + }, + }, + ValExpr: &BinaryOpExpr{ + LHS: &FunctionCallExpr{ + Name: "val_lhs", + }, + Op: OpAdd, + RHS: &FunctionCallExpr{ + Name: "val_rhs", + }, + }, + CondExpr: &BinaryOpExpr{ + LHS: &FunctionCallExpr{ + Name: "cond_lhs", + }, + Op: OpLessThan, + RHS: &FunctionCallExpr{ + Name: "cond_rhs", + }, + }, + }, + []hcl.Traversal{ + { + hcl.TraverseRoot{ + Name: "coll", + }, + }, + { + hcl.TraverseRoot{ + Name: "key_lhs", + }, + }, + { + hcl.TraverseRoot{ + Name: "key_rhs", + }, + }, + { + hcl.TraverseRoot{ + Name: "val_lhs", + }, + }, + { + hcl.TraverseRoot{ + Name: "val_rhs", + }, + }, + { + hcl.TraverseRoot{ + Name: "cond_lhs", + }, + }, + { + hcl.TraverseRoot{ + Name: "cond_rhs", + }, + }, + }, + }, + { + &FunctionCallExpr{ + Name: "funky", + Args: []Expression{ + &FunctionCallExpr{ + Name: "sub_a", + }, + &FunctionCallExpr{ + Name: "sub_b", + }, + }, + }, + []hcl.Traversal{ + { + hcl.TraverseRoot{ + Name: "funky", + }, + }, + { + hcl.TraverseRoot{ + Name: "sub_a", + }, + }, + { + hcl.TraverseRoot{ + Name: "sub_b", + }, + }, + }, + }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("%#v", test.Expr), func(t *testing.T) { + got := Functions(test.Expr) + + if !reflect.DeepEqual(got, test.Want) { + t.Errorf( + "wrong result\ngot: %s\nwant: %s", + spew.Sdump(got), spew.Sdump(test.Want), + ) + } + }) + } +} diff --git a/hcltest/mock.go b/hcltest/mock.go index 49f3411d..4f72e80b 100644 --- a/hcltest/mock.go +++ b/hcltest/mock.go @@ -144,6 +144,10 @@ func (e mockExprLiteral) Variables() []hcl.Traversal { return nil } +func (e mockExprLiteral) Functions() []hcl.Traversal { + return nil +} + func (e mockExprLiteral) Range() hcl.Range { return hcl.Range{ Filename: "MockExprLiteral", @@ -225,6 +229,10 @@ func (e mockExprVariable) Variables() []hcl.Traversal { } } +func (e mockExprVariable) Functions() []hcl.Traversal { + return nil +} + func (e mockExprVariable) Range() hcl.Range { return hcl.Range{ Filename: "MockExprVariable", @@ -278,6 +286,10 @@ func (e mockExprTraversal) Variables() []hcl.Traversal { return []hcl.Traversal{e.Traversal} } +func (e mockExprTraversal) Functions() []hcl.Traversal { + return nil +} + func (e mockExprTraversal) Range() hcl.Range { return e.Traversal.SourceRange() } @@ -325,6 +337,16 @@ func (e mockExprList) Variables() []hcl.Traversal { return traversals } +func (e mockExprList) Functions() []hcl.Traversal { + var traversals []hcl.Traversal + for _, expr := range e.Exprs { + if fexpr, ok := expr.(hcl.ExpressionWithFunctions); ok { + traversals = append(traversals, fexpr.Functions()...) + } + } + return traversals +} + func (e mockExprList) Range() hcl.Range { return hcl.Range{ Filename: "MockExprList", diff --git a/json/structure.go b/json/structure.go index 8fce12a2..fccab54c 100644 --- a/json/structure.go +++ b/json/structure.go @@ -557,6 +557,52 @@ func (e *expression) Variables() []hcl.Traversal { return vars } +func (e *expression) Functions() []hcl.Traversal { + // This is based off of the logic in Variables() + var funcs []hcl.Traversal + + switch v := e.src.(type) { + case *stringVal: + templateSrc := v.Value + expr, diags := hclsyntax.ParseTemplate( + []byte(templateSrc), + v.SrcRange.Filename, + + // This won't produce _exactly_ the right result, since + // the hclsyntax parser can't "see" any escapes we removed + // while parsing JSON, but it's better than nothing. + hcl.Pos{ + Line: v.SrcRange.Start.Line, + + // skip over the opening quote mark + Byte: v.SrcRange.Start.Byte + 1, + Column: v.SrcRange.Start.Column + 1, + }, + ) + if diags.HasErrors() { + return funcs + } + if fexpr, ok := expr.(hcl.ExpressionWithFunctions); ok { + return fexpr.Functions() + } + case *arrayVal: + for _, jsonVal := range v.Values { + funcs = append(funcs, (&expression{src: jsonVal}).Functions()...) + } + case *objectVal: + for _, jsonAttr := range v.Attrs { + keyExpr := &stringVal{ // we're going to treat key as an expression in this context + Value: jsonAttr.Name, + SrcRange: jsonAttr.NameRange, + } + funcs = append(funcs, (&expression{src: keyExpr}).Functions()...) + funcs = append(funcs, (&expression{src: jsonAttr.Value}).Functions()...) + } + } + + return funcs +} + func (e *expression) Range() hcl.Range { return e.src.Range() } diff --git a/static_expr.go b/static_expr.go index e14d7f89..a725bc5b 100644 --- a/static_expr.go +++ b/static_expr.go @@ -34,6 +34,10 @@ func (e staticExpr) Variables() []Traversal { return nil } +func (e staticExpr) Functions() []Traversal { + return nil +} + func (e staticExpr) Range() Range { return e.rng } diff --git a/structure.go b/structure.go index 2bdf579d..9d44e5ff 100644 --- a/structure.go +++ b/structure.go @@ -126,6 +126,11 @@ type Expression interface { StartRange() Range } +type ExpressionWithFunctions interface { + Expression + Functions() []Traversal +} + // OfType filters the receiving block sequence by block type name, // returning a new block sequence including only the blocks of the // requested type.