From b1352b395ac99007326b4da6025fddd6e333e848 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Fabianski?= Date: Tue, 21 Nov 2023 14:35:02 +0100 Subject: [PATCH] feat: improve testing for rules (#1404) * feat: improve testing for rules * chore: add expected to jsonv2 * chore: add test for expected rules in JSONV2 * refactor: make it less costly to find expected rules --- ...stExpectedRule-testdata-data-expected_rule | 5 + e2e/rules/helper.go | 16 ++ e2e/rules/rules_test.go | 15 ++ e2e/rules/testdata/data/expected_rule/main.rb | 14 ++ e2e/rules/testdata/rules/expected_rule.yml | 9 + .../__template__/analyzer/analyzer.go | 218 ------------------ .../__template__/detectors/detectors_test.go | 20 -- .../__template__/detectors/object/object.go | 207 ----------------- .../detectors/object/projection.go | 85 ------- .../__template__/detectors/string/string.go | 53 ----- .../detectors/testdata/class.language | 8 - .../detectors/testdata/no_class.language | 1 - .../detectors/testdata/string.language | 15 -- internal/languages/__template__/language.go | 58 ----- .../languages/__template__/language_test.go | 22 -- .../languages/__template__/pattern/pattern.go | 197 ---------------- .../__template__/testdata/logger.yml | 10 - .../testdata/scope/scope.language | 17 -- .../__template__/testdata/scope_rule.yml | 30 --- .../testcases/flow/different-line.language | 3 - .../testdata/testcases/flow/same-line.php | 2 - internal/report/detections/detections.go | 1 + internal/report/output/dataflow/dataflow.go | 15 +- .../report/output/dataflow/risks/risks.go | 6 +- internal/report/output/security/formatter.go | 12 +- internal/report/output/security/security.go | 17 ++ .../report/output/security/types/types.go | 5 + internal/report/output/types/types.go | 12 +- .../scanner/ast/.snapshots/TestExpectedRules | 67 ++++++ internal/scanner/ast/ast.go | 34 +++ internal/scanner/ast/ast_test.go | 52 +++++ internal/scanner/ast/tree/builder.go | 16 ++ internal/scanner/ast/tree/tree.go | 16 ++ internal/scanner/detectors/types/types.go | 5 +- internal/scanner/detectorset/detectorset.go | 1 + .../languagescanner/languagescanner.go | 40 +++- internal/scanner/rulescanner/rulescanner.go | 5 +- internal/scanner/ruleset/ruleset.go | 4 + internal/scanner/scanner.go | 24 +- 39 files changed, 361 insertions(+), 976 deletions(-) create mode 100644 e2e/rules/.snapshots/TestExpectedRule-testdata-data-expected_rule create mode 100644 e2e/rules/testdata/data/expected_rule/main.rb create mode 100644 e2e/rules/testdata/rules/expected_rule.yml delete mode 100644 internal/languages/__template__/analyzer/analyzer.go delete mode 100644 internal/languages/__template__/detectors/detectors_test.go delete mode 100644 internal/languages/__template__/detectors/object/object.go delete mode 100644 internal/languages/__template__/detectors/object/projection.go delete mode 100644 internal/languages/__template__/detectors/string/string.go delete mode 100644 internal/languages/__template__/detectors/testdata/class.language delete mode 100644 internal/languages/__template__/detectors/testdata/no_class.language delete mode 100644 internal/languages/__template__/detectors/testdata/string.language delete mode 100644 internal/languages/__template__/language.go delete mode 100644 internal/languages/__template__/language_test.go delete mode 100644 internal/languages/__template__/pattern/pattern.go delete mode 100644 internal/languages/__template__/testdata/logger.yml delete mode 100644 internal/languages/__template__/testdata/scope/scope.language delete mode 100644 internal/languages/__template__/testdata/scope_rule.yml delete mode 100644 internal/languages/__template__/testdata/testcases/flow/different-line.language delete mode 100644 internal/languages/__template__/testdata/testcases/flow/same-line.php create mode 100644 internal/scanner/ast/.snapshots/TestExpectedRules diff --git a/e2e/rules/.snapshots/TestExpectedRule-testdata-data-expected_rule b/e2e/rules/.snapshots/TestExpectedRule-testdata-data-expected_rule new file mode 100644 index 000000000..df0c513f7 --- /dev/null +++ b/e2e/rules/.snapshots/TestExpectedRule-testdata-data-expected_rule @@ -0,0 +1,5 @@ +{"source":"Bearer","version":"dev","findings":[{"cwe_ids":["319"],"id":"expected_rule","title":"","description":"","documentation_url":"","line_number":3,"full_filename":"e2e/rules/testdata/data/expected_rule/main.rb","filename":"main.rb","source":{"start":3,"end":3,"column":{"start":3,"end":7}},"sink":{"start":3,"end":3,"column":{"start":3,"end":7},"content":"sink"},"parent_line_number":3,"snippet":"sink","fingerprint":"c50ecec7e1fcfba6cce5fcfab129556c_0","old_fingerprint":"6630ae26e5210b1e43bb4c02426e6be7_0","code_extract":" sink","severity":"low"},{"cwe_ids":["319"],"id":"expected_rule","title":"","description":"","documentation_url":"","line_number":8,"full_filename":"e2e/rules/testdata/data/expected_rule/main.rb","filename":"main.rb","source":{"start":8,"end":8,"column":{"start":3,"end":7}},"sink":{"start":8,"end":8,"column":{"start":3,"end":7},"content":"sink"},"parent_line_number":8,"snippet":"sink","fingerprint":"c50ecec7e1fcfba6cce5fcfab129556c_1","old_fingerprint":"6630ae26e5210b1e43bb4c02426e6be7_1","code_extract":" sink","severity":"low"}],"expected_findings":[{"rule_id":"expected_rule","location":{"start":3,"end":3,"column":{"start":3,"end":7}}},{"rule_id":"expected_rule","location":{"start":8,"end":8,"column":{"start":3,"end":7}}}]} + +-- +Analyzing codebase + diff --git a/e2e/rules/helper.go b/e2e/rules/helper.go index 0eeb778d5..5a35d0d66 100644 --- a/e2e/rules/helper.go +++ b/e2e/rules/helper.go @@ -24,6 +24,22 @@ func buildRulesTestCase(testName, path, ruleID string) testhelper.TestCase { return testhelper.NewTestCase(testName, arguments, options) } +func buildRulesTestCaseJsonV2(testName, path, ruleID string) testhelper.TestCase { + arguments := []string{ + "scan", + path, + "--only-rule=" + ruleID, + "--format=jsonv2", + "--disable-default-rules", + "--exit-code=0", + "--external-rule-dir=" + filepath.Join("e2e", "rules", "testdata", "rules"), + } + + options := testhelper.TestCaseOptions{} + + return testhelper.NewTestCase(testName, arguments, options) +} + func runRulesTest(folderPath string, ruleID string, t *testing.T) { snapshotDirectory := ".snapshots" diff --git a/e2e/rules/rules_test.go b/e2e/rules/rules_test.go index 7aec4b52e..385e82954 100644 --- a/e2e/rules/rules_test.go +++ b/e2e/rules/rules_test.go @@ -55,3 +55,18 @@ func TestRubyRailsDefaultEncryptionStructure(t *testing.T) { func TestRubyRailsDefaultEncryptionSchema(t *testing.T) { runRulesTest("ruby_rails_default_encryption_schema_rb", "ruby_rails_default_encryption", t) } + +func TestExpectedRule(t *testing.T) { + testDataDir := "testdata/data/expected_rule" + + testCases := []testhelper.TestCase{} + testCases = append(testCases, + buildRulesTestCaseJsonV2( + testDataDir, + filepath.Join("e2e", "rules", testDataDir), + "expected_rule", + ), + ) + + testhelper.RunTestsWithSnapshotSubdirectory(t, testCases, ".snapshots") +} diff --git a/e2e/rules/testdata/data/expected_rule/main.rb b/e2e/rules/testdata/data/expected_rule/main.rb new file mode 100644 index 000000000..84f16e04e --- /dev/null +++ b/e2e/rules/testdata/data/expected_rule/main.rb @@ -0,0 +1,14 @@ +def m + # bearer:expected expected_rule + sink +end + +def n + # bearer:expected expected_rule + sink +end + +def foo + bar +end + diff --git a/e2e/rules/testdata/rules/expected_rule.yml b/e2e/rules/testdata/rules/expected_rule.yml new file mode 100644 index 000000000..958f3d4ba --- /dev/null +++ b/e2e/rules/testdata/rules/expected_rule.yml @@ -0,0 +1,9 @@ +patterns: + - sink +languages: + - ruby +severity: low +metadata: + cwe_id: + - 319 + id: expected_rule diff --git a/internal/languages/__template__/analyzer/analyzer.go b/internal/languages/__template__/analyzer/analyzer.go deleted file mode 100644 index 6669c5056..000000000 --- a/internal/languages/__template__/analyzer/analyzer.go +++ /dev/null @@ -1,218 +0,0 @@ -package analyzer - -import ( - sitter "github.com/smacker/go-tree-sitter" - - "github.com/bearer/bearer/internal/scanner/ast/tree" - "github.com/bearer/bearer/internal/scanner/language" -) - -type analyzer struct { - builder *tree.Builder - scope *language.Scope -} - -func New(builder *tree.Builder) language.Analyzer { - return &analyzer{ - builder: builder, - scope: language.NewScope(nil), - } -} - -func (analyzer *analyzer) Analyze(node *sitter.Node, visitChildren func() error) error { - switch node.Type() { - case "declaration_list", "class_declaration", "anonymous_function_creation_expression", "for_statement", "block", "method_declaration": - return analyzer.withScope(language.NewScope(analyzer.scope), func() error { - return visitChildren() - }) - case "augmented_assignment_expression": - return analyzer.analyzeAugmentedAssignment(node, visitChildren) - case "assignment_expression": - return analyzer.analyzeAssignment(node, visitChildren) - case "parenthesized_expression": - return analyzer.analyzeParentheses(node, visitChildren) - case "conditional_expression": - return analyzer.analyzeConditional(node, visitChildren) - case "function_call_expression", "member_call_expression": - return analyzer.analyzeMethodInvocation(node, visitChildren) - case "member_access_expression": - return analyzer.analyzeFieldAccess(node, visitChildren) - case "simple_parameter", "variadic_parameter": - return analyzer.analyzeParameter(node, visitChildren) - case "switch_statement": - return analyzer.analyzeSwitch(node, visitChildren) - case "switch_block": - return analyzer.analyzeGenericConstruct(node, visitChildren) - case "switch_label": - return visitChildren() - case "dynamic_variable_name": - return analyzer.analyzeDynamicVariableName(node, visitChildren) - case "binary_expression", - "unary_op_expression", - "argument", - "encapsed_string", - "sequence_expression", - "array_element_initializer", - "formal_parameters", - "include_expression", - "include_once_expression", - "require_expression", - "require_once_expression": - return analyzer.analyzeGenericOperation(node, visitChildren) - case "while_statement", "do_statement", "if_statement", "expression_statement", "compound_statement": // statements don't have results - return visitChildren() - case "variable_name": - return visitChildren() - case "match_expression": - analyzer.builder.Dataflow(node, analyzer.builder.ChildrenExcept(node, node.ChildByFieldName("condition"))...) - return visitChildren() - default: - analyzer.builder.Dataflow(node, analyzer.builder.ChildrenFor(node)...) - return visitChildren() - } -} - -// $foo = a -func (analyzer *analyzer) analyzeAssignment(node *sitter.Node, visitChildren func() error) error { - left := node.ChildByFieldName("left") - right := node.ChildByFieldName("right") - analyzer.builder.Alias(node, right) - analyzer.lookupVariable(right) - - err := visitChildren() - - if left.Type() == "variable_name" { - analyzer.scope.Assign(analyzer.builder.ContentFor(left), node) - } - - return err -} - -// $foo .= a -func (analyzer *analyzer) analyzeAugmentedAssignment(node *sitter.Node, visitChildren func() error) error { - left := node.ChildByFieldName("left") - right := node.ChildByFieldName("right") - analyzer.builder.Dataflow(node, left, right) - analyzer.lookupVariable(left) - analyzer.lookupVariable(right) - - err := visitChildren() - - if left.Type() == "variable_name" { - analyzer.scope.Assign(analyzer.builder.ContentFor(left), node) - } - - return err -} - -func (analyzer *analyzer) analyzeParentheses(node *sitter.Node, visitChildren func() error) error { - analyzer.builder.Alias(node, node.NamedChild(0)) - analyzer.lookupVariable(node.NamedChild(0)) - err := visitChildren() - - return err -} - -// a ? x : y -// a ?: x -func (analyzer *analyzer) analyzeConditional(node *sitter.Node, visitChildren func() error) error { - condition := node.ChildByFieldName("condition") - consequence := node.ChildByFieldName("body") - alternative := node.ChildByFieldName("alternative") - - analyzer.lookupVariable(condition) - analyzer.lookupVariable(consequence) - analyzer.lookupVariable(alternative) - - if consequence != nil { - analyzer.builder.Alias(node, consequence, alternative) - } else { - analyzer.builder.Alias(node, condition, alternative) - } - - return visitChildren() -} - -// foo(1, 2); -// foo->bar(1, 2); -func (analyzer *analyzer) analyzeMethodInvocation(node *sitter.Node, visitChildren func() error) error { - analyzer.lookupVariable(node.ChildByFieldName("object")) // method - analyzer.lookupVariable(node.ChildByFieldName("function")) // function - - if arguments := node.ChildByFieldName("arguments"); arguments != nil { - analyzer.builder.Dataflow(node, arguments) - } - - return visitChildren() -} - -// foo->bar -func (analyzer *analyzer) analyzeFieldAccess(node *sitter.Node, visitChildren func() error) error { - analyzer.lookupVariable(node.ChildByFieldName("object")) - - return visitChildren() -} - -// method parameter declaration -// -// fn(bool $a) => $a; -// fn($x = 42) => $x; -// fn($x, ...$rest) => $rest; -func (analyzer *analyzer) analyzeParameter(node *sitter.Node, visitChildren func() error) error { - name := node.ChildByFieldName("name") - analyzer.builder.Alias(node, name) - analyzer.scope.Declare(analyzer.builder.ContentFor(name), name) - - return visitChildren() -} - -func (analyzer *analyzer) analyzeSwitch(node *sitter.Node, visitChildren func() error) error { - analyzer.builder.Alias(node, node.ChildByFieldName("body")) - - return visitChildren() -} - -func (analyzer *analyzer) analyzeDynamicVariableName(node *sitter.Node, visitChildren func() error) error { - analyzer.lookupVariable(node.NamedChild(0)) - - return visitChildren() -} - -// default analysis, where the children are assumed to be aliases -func (analyzer *analyzer) analyzeGenericConstruct(node *sitter.Node, visitChildren func() error) error { - analyzer.builder.Alias(node, analyzer.builder.ChildrenFor(node)...) - - return visitChildren() -} - -// default analysis, where the children are assumed to be data sources -func (analyzer *analyzer) analyzeGenericOperation(node *sitter.Node, visitChildren func() error) error { - children := analyzer.builder.ChildrenFor(node) - analyzer.builder.Dataflow(node, children...) - - for _, child := range children { - analyzer.lookupVariable(child) - } - - return visitChildren() -} - -func (analyzer *analyzer) withScope(newScope *language.Scope, body func() error) error { - oldScope := analyzer.scope - - analyzer.scope = newScope - err := body() - analyzer.scope = oldScope - - return err -} - -func (analyzer *analyzer) lookupVariable(node *sitter.Node) { - if node == nil || node.Type() != "variable_name" { - return - } - - if pointsToNode := analyzer.scope.Lookup(analyzer.builder.ContentFor(node)); pointsToNode != nil { - analyzer.builder.Alias(node, pointsToNode) - } -} diff --git a/internal/languages/__template__/detectors/detectors_test.go b/internal/languages/__template__/detectors/detectors_test.go deleted file mode 100644 index 5d0a798a2..000000000 --- a/internal/languages/__template__/detectors/detectors_test.go +++ /dev/null @@ -1,20 +0,0 @@ -package detectors_test - -import ( - "testing" - - "github.com/bearer/bearer/internal/scanner/detectors/testhelper" -) - -func TestObjects(t *testing.T) { - runTest(t, "object_class", "object", "testdata/class.") - runTest(t, "object_no_class", "object", "testdata/no_class.") -} - -func TestString(t *testing.T) { - runTest(t, "string", "string", "testdata/string.") -} - -func runTest(t *testing.T, name, detectorType, fileName string) { - testhelper.RunTest(t, name, .Get(), detectorType, fileName) -} diff --git a/internal/languages/__template__/detectors/object/object.go b/internal/languages/__template__/detectors/object/object.go deleted file mode 100644 index 2d53d3c0f..000000000 --- a/internal/languages/__template__/detectors/object/object.go +++ /dev/null @@ -1,207 +0,0 @@ -package object - -import ( - "github.com/bearer/bearer/internal/scanner/ast/query" - "github.com/bearer/bearer/internal/scanner/ast/traversalstrategy" - "github.com/bearer/bearer/internal/scanner/ast/tree" - - "github.com/bearer/bearer/internal/scanner/detectors/common" - detectorscommon "github.com/bearer/bearer/internal/scanner/detectors/common" - "github.com/bearer/bearer/internal/scanner/detectors/types" - "github.com/bearer/bearer/internal/scanner/ruleset" -) - -type objectDetector struct { - types.DetectorBase - // Base - classQuery *query.Query - arrayCreationQuery *query.Query - // Naming - assignmentQuery *query.Query - // Projection - fieldAccessQuery *query.Query - subscriptExpressionQuery *query.Query -} - -func New(querySet *query.Set) types.Detector { - // $user = new ; - assignmentQuery := querySet.Add(`[ - (assignment_expression left: (variable_name) @name right: (_) @value) @root - ]`) - - // class User { - // public $name; - // public $gender; - // function set_name($name) { - // $this->name = $name; - // } - // } - classQuery := querySet.Add(` - ( - class_declaration - name: (name) @class_name - body: ( - declaration_list [ - (property_declaration (property_element (variable_name) @name )) - (method_declaration name: (name) @name) - ] - ) - ) @root`) - - // $user->name; - // $user->name(); - fieldAccessQuery := querySet.Add(`[ - (member_access_expression object: (_) @object name: (name) @field) @root - (member_call_expression object: (_) @object name: (name) @field) @root - ]`) - - // array('foo' => 'bar'); - // [ 'foo' => 'bar' ]; - arrayCreationQuery := querySet.Add(` - (array_creation_expression (array_element_initializer . (_) @key . (_) @value )) @root - `) - - // $user["uuid"]; - subscriptExpressionQuery := querySet.Add(` - (subscript_expression (_) @object (_) @key) @root - `) - - return &objectDetector{ - classQuery: classQuery, - arrayCreationQuery: arrayCreationQuery, - assignmentQuery: assignmentQuery, - fieldAccessQuery: fieldAccessQuery, - subscriptExpressionQuery: subscriptExpressionQuery, - } -} - -func (detector *objectDetector) Rule() *ruleset.Rule { - return ruleset.BuiltinObjectRule -} - -func (detector *objectDetector) DetectAt( - node *tree.Node, - detectorContext types.Context, -) ([]interface{}, error) { - detections, err := detector.getAssignment(node, detectorContext) - if len(detections) != 0 || err != nil { - return detections, err - } - - detections, err = detector.getClass(node) - if len(detections) != 0 || err != nil { - return detections, err - } - - detections, err = detector.getArrayCreation(node, detectorContext) - if len(detections) != 0 || err != nil { - return detections, err - } - - return detector.getProjections(node, detectorContext) -} - -func (detector *objectDetector) getArrayCreation( - node *tree.Node, - detectorContext types.Context, -) ([]interface{}, error) { - results := detector.arrayCreationQuery.MatchAt(node) - if len(results) == 0 { - return nil, nil - } - - var properties []detectorscommon.Property - for _, result := range results { - pairNode := result["key"] - name := result["value"].Content() - - propertyObjects, err := detectorContext.Scan(result["value"], ruleset.BuiltinObjectRule, traversalstrategy.Cursor) - if err != nil { - return nil, err - } - - if len(propertyObjects) == 0 { - properties = append(properties, detectorscommon.Property{ - Name: name, - Node: pairNode, - }) - - continue - } - - for _, propertyObject := range propertyObjects { - properties = append(properties, detectorscommon.Property{ - Name: name, - Node: pairNode, - Object: propertyObject, - }) - } - } - - return []interface{}{detectorscommon.Object{Properties: properties}}, nil -} - -func (detector *objectDetector) getAssignment( - node *tree.Node, - detectorContext types.Context, -) ([]interface{}, error) { - result, err := detector.assignmentQuery.MatchOnceAt(node) - - if result == nil || err != nil { - return nil, err - } - - rightObjects, err := common.GetNonVirtualObjects( - detectorContext, - result["value"], - ) - if err != nil { - return nil, err - } - - var objects []interface{} - for _, object := range rightObjects { - objects = append(objects, common.Object{ - IsVirtual: true, - Properties: []common.Property{{ - Name: result["name"].Content(), - Node: node, - Object: object, - }}, - }) - } - - return objects, nil -} - -func (detector *objectDetector) getClass(node *tree.Node) ([]interface{}, error) { - results := detector.classQuery.MatchAt(node) - if len(results) == 0 { - return nil, nil - } - - className := results[0]["class_name"].Content() - - var properties []common.Property - for _, result := range results { - nameNode := result["name"] - - properties = append(properties, common.Property{ - Name: nameNode.Content(), - Node: nameNode, - }) - } - - return []interface{}{common.Object{ - Properties: []common.Property{{ - Name: className, - Object: &types.Detection{ - RuleID: ruleset.BuiltinObjectRule.ID(), - MatchNode: node, - Data: common.Object{ - Properties: properties, - }, - }, - }}, - }}, nil -} diff --git a/internal/languages/__template__/detectors/object/projection.go b/internal/languages/__template__/detectors/object/projection.go deleted file mode 100644 index fc4f7a008..000000000 --- a/internal/languages/__template__/detectors/object/projection.go +++ /dev/null @@ -1,85 +0,0 @@ -package object - -import ( - "github.com/bearer/bearer/internal/scanner/ast/tree" - "github.com/bearer/bearer/internal/util/stringutil" - - "github.com/bearer/bearer/internal/scanner/detectors/common" - detectorscommon "github.com/bearer/bearer/internal/scanner/detectors/common" - "github.com/bearer/bearer/internal/scanner/detectors/types" -) - -func (detector *objectDetector) getProjections( - node *tree.Node, - detectorContext types.Context, -) ([]interface{}, error) { - result, err := detector.fieldAccessQuery.MatchOnceAt(node) - if err != nil { - return nil, err - } - - if result != nil { - objectNode := result["object"] - objects, err := common.ProjectObject( - node, - detectorContext, - objectNode, - getObjectName(objectNode), - result["field"].Content(), - true, - ) - if err != nil { - return nil, err - } - - return objects, nil - } - - result, err = detector.subscriptExpressionQuery.MatchOnceAt(node) - if err != nil { - return nil, err - } - - if result != nil { - objectNode := result["object"] - propertyName := getElementProperty(result["key"]) - if propertyName == "" { - return nil, nil - } - - objects, err := detectorscommon.ProjectObject( - node, - detectorContext, - objectNode, - getObjectName(objectNode), - propertyName, - false, - ) - if err != nil { - return nil, err - } - - return objects, nil - } - - return nil, nil -} - -func getObjectName(objectNode *tree.Node) string { - // $user->name() - // $user->name - if objectNode.Type() == "variable_name" { - return objectNode.Content() - } - - return "" -} - -func getElementProperty(node *tree.Node) string { - switch node.Type() { - case "encapsed_string": - return stringutil.StripQuotes(node.Content()) - default: - return node.Content() - } -} diff --git a/internal/languages/__template__/detectors/string/string.go b/internal/languages/__template__/detectors/string/string.go deleted file mode 100644 index 1993b87dd..000000000 --- a/internal/languages/__template__/detectors/string/string.go +++ /dev/null @@ -1,53 +0,0 @@ -package string - -import ( - "github.com/bearer/bearer/internal/scanner/ast/query" - "github.com/bearer/bearer/internal/scanner/ast/tree" - "github.com/bearer/bearer/internal/scanner/ruleset" - "github.com/bearer/bearer/internal/util/stringutil" - - "github.com/bearer/bearer/internal/scanner/detectors/common" - "github.com/bearer/bearer/internal/scanner/detectors/types" -) - -type stringDetector struct { - types.DetectorBase -} - -func New(querySet *query.Set) types.Detector { - return &stringDetector{} -} - -func (detector *stringDetector) Rule() *ruleset.Rule { - return ruleset.BuiltinStringRule -} - -func (detector *stringDetector) DetectAt( - node *tree.Node, - detectorContext types.Context, -) ([]interface{}, error) { - switch node.Type() { - case "string": - value := node.Content() - if node.Parent() != nil && node.Parent().Type() != "encapsed_string" { - value = stringutil.StripQuotes(value) - } - - return []interface{}{common.String{ - Value: value, - IsLiteral: true, - }}, nil - case "encapsed_string": - return common.ConcatenateChildStrings(node, detectorContext) - case "binary_expression": - if node.Children()[1].Content() == "." { - return common.ConcatenateChildStrings(node, detectorContext) - } - case "augmented_assignment_expression": - if node.Children()[1].Content() == ".=" { - return common.ConcatenateAssignEquals(node, detectorContext) - } - } - - return nil, nil -} diff --git a/internal/languages/__template__/detectors/testdata/class.language b/internal/languages/__template__/detectors/testdata/class.language deleted file mode 100644 index 57f2add9a..000000000 --- a/internal/languages/__template__/detectors/testdata/class.language +++ /dev/null @@ -1,8 +0,0 @@ -class User -{ - public $name = ''; - - public function LowercaseName() { - echo strtolower($this->name); - } -} \ No newline at end of file diff --git a/internal/languages/__template__/detectors/testdata/no_class.language b/internal/languages/__template__/detectors/testdata/no_class.language deleted file mode 100644 index 224354883..000000000 --- a/internal/languages/__template__/detectors/testdata/no_class.language +++ /dev/null @@ -1 +0,0 @@ -$user.name(); \ No newline at end of file diff --git a/internal/languages/__template__/detectors/testdata/string.language b/internal/languages/__template__/detectors/testdata/string.language deleted file mode 100644 index 4f3713049..000000000 --- a/internal/languages/__template__/detectors/testdata/string.language +++ /dev/null @@ -1,15 +0,0 @@ -class Greet { - const Greeting = "Hello World"; - - public static function main($args) - { - $s = self::Greeting . "!"; - $s .= "!!"; - - $s2 = "hey "; - $s2 .= $args[0]; - $s2 .= " there"; - - $s3 = "foo '{$s2}' bar"; - } -} diff --git a/internal/languages/__template__/language.go b/internal/languages/__template__/language.go deleted file mode 100644 index df5dad6cf..000000000 --- a/internal/languages/__template__/language.go +++ /dev/null @@ -1,58 +0,0 @@ -package language - -import ( - sitter "github.com/smacker/go-tree-sitter" - - "github.com/bearer/bearer/internal/classification/schema" - "github.com/bearer/bearer/internal/report/detectors" - "github.com/bearer/bearer/internal/scanner/ast/query" - "github.com/bearer/bearer/internal/scanner/ast/tree" - detectortypes "github.com/bearer/bearer/internal/scanner/detectors/types" - - "github.com/bearer/bearer/internal/languages/language/analyzer" - "github.com/bearer/bearer/internal/languages/language/detectors/object" - stringdetector "github.com/bearer/bearer/internal/languages/language/detectors/string" - "github.com/bearer/bearer/internal/languages/language/pattern" - "github.com/bearer/bearer/internal/scanner/detectors/datatype" - "github.com/bearer/bearer/internal/scanner/detectors/insecureurl" - "github.com/bearer/bearer/internal/scanner/detectors/stringliteral" - "github.com/bearer/bearer/internal/scanner/language" -) - -type implementation struct { - pattern pattern.Pattern -} - -func Get() language.Language { - return &implementation{} -} - -func (*implementation) ID() string { - return "" -} - -func (*implementation) EnryLanguages() []string { - return []string{""} -} - -func (*implementation) NewBuiltInDetectors(schemaClassifier *schema.Classifier, querySet *query.Set) []detectortypes.Detector { - return []detectortypes.Detector{ - object.New(querySet), - datatype.New(detectors.Detector, schemaClassifier), - stringdetector.New(querySet), - stringliteral.New(querySet), - insecureurl.New(querySet), - } -} - -func (*implementation) SitterLanguage() *sitter.Language { - return language.GetLanguage() -} - -func (language *implementation) Pattern() language.Pattern { - return &language.pattern -} - -func (*implementation) NewAnalyzer(builder *tree.Builder) language.Analyzer { - return analyzer.New(builder) -} diff --git a/internal/languages/__template__/language_test.go b/internal/languages/__template__/language_test.go deleted file mode 100644 index f3f4305ec..000000000 --- a/internal/languages/__template__/language_test.go +++ /dev/null @@ -1,22 +0,0 @@ -package language_test - -import ( - _ "embed" - "testing" - - "github.com/bearer/bearer/internal/languages/testhelper" -) - -//go:embed testdata/logger.yml -var loggerRule []byte - -//go:embed testdata/scope_rule.yml -var scopeRule []byte - -func TestFlow(t *testing.T) { - testhelper.GetRunner(t, loggerRule, "").RunTest(t, "./testdata/testcases/flow", ".snapshots/flow/") -} - -func TestScope(t *testing.T) { - testhelper.GetRunner(t, scopeRule, "").RunTest(t, "./testdata/scope", ".snapshots/") -} diff --git a/internal/languages/__template__/pattern/pattern.go b/internal/languages/__template__/pattern/pattern.go deleted file mode 100644 index e5b8ef731..000000000 --- a/internal/languages/__template__/pattern/pattern.go +++ /dev/null @@ -1,197 +0,0 @@ -package pattern - -import ( - "fmt" - "regexp" - "slices" - "strings" - - "github.com/bearer/bearer/internal/scanner/ast/tree" - "github.com/bearer/bearer/internal/scanner/language" - "github.com/bearer/bearer/internal/util/regex" -) - -var ( - // $ or $ or $ - patternQueryVariableRegex = regexp.MustCompile(`\$<(?P[^>:!\.]+)(?::(?P[^>]+))?>`) - matchNodeRegex = regexp.MustCompile(`\$`) - ellipsisRegex = regexp.MustCompile(`\$<\.\.\.>`) - unanchoredPatternNodeTypes = []string{} - patternMatchNodeContainerTypes = []string{"formal_parameters", "simple_parameter", "argument"} - - allowedPatternQueryTypes = []string{"_"} -) - -type Pattern struct { - language.PatternBase -} - -// func (*Pattern) AdjustInput(input string) string { -// return input -// } - -// func (*Pattern) FixupMissing(node *tree.Node) string { -// if node.Type() != `";"` { -// return "" -// } - -// return ";" -// } - -func (*Pattern) FixupVariableDummyValue(input []byte, node *tree.Node, dummyValue string) string { - if slices.Contains([]string{"named_type"}, node.Parent().Type()) { - return "$" + dummyValue - } - - return dummyValue -} - -func (*Pattern) ExtractVariables(input string) (string, []language.PatternVariable, error) { - nameIndex := patternQueryVariableRegex.SubexpIndex("name") - typesIndex := patternQueryVariableRegex.SubexpIndex("types") - i := 0 - - var params []language.PatternVariable - - replaced, err := regex.ReplaceAllWithSubmatches(patternQueryVariableRegex, input, func(submatches []string) (string, error) { - nodeTypes := strings.Split(submatches[typesIndex], "|") - if nodeTypes[0] == "" { - nodeTypes = []string{"_"} - } - - for _, nodeType := range nodeTypes { - if !slices.Contains(allowedPatternQueryTypes, nodeType) { - return "", fmt.Errorf("invalid node type '%s' in pattern query", nodeType) - } - } - - dummyValue := produceDummyValue(i, nodeTypes[0]) - - params = append(params, language.PatternVariable{ - Name: submatches[nameIndex], - NodeTypes: nodeTypes, - DummyValue: dummyValue, - }) - - i += 1 - - return dummyValue, nil - }) - - if err != nil { - return "", nil, err - } - - return replaced, params, nil -} - -func produceDummyValue(i int, nodeType string) string { - return "BearerVar" + fmt.Sprint(i) -} - -func (*Pattern) FindMatchNode(input []byte) [][]int { - return matchNodeRegex.FindAllIndex(input, -1) -} - -func (*Pattern) FindUnanchoredPoints(input []byte) [][]int { - return ellipsisRegex.FindAllIndex(input, -1) -} - -func (*Pattern) IsLeaf(node *tree.Node) bool { - // Encapsed string literal - switch node.Type() { - case "encapsed_string": - namedChildren := node.NamedChildren() - if len(namedChildren) == 1 && namedChildren[0].Type() == "string" { - return true - } - } - return false -} - -func (*Pattern) LeafContentTypes() []string { - return []string{ - "encapsed_string", - "string", - "name", - "integer", - "float", - "boolean", - } -} - -func (*Pattern) IsAnchored(node *tree.Node) (bool, bool) { - if slices.Contains(unanchoredPatternNodeTypes, node.Type()) { - return false, false - } - - parent := node.Parent() - if parent == nil { - return true, true - } - - if parent.Type() == "method_declaration" { - // visibility - if node == parent.ChildByFieldName("name") { - return false, true - } - - // type - if node == parent.ChildByFieldName("parameters") { - return true, false - } - - return false, false - } - - // Associative array elements are unanchored - // eg. array("foo" => 42) - if parent.Type() == "array_creation_expression" && - node.Type() == "array_element_initializer" && - len(node.NamedChildren()) == 2 { - return false, false - } - - // Class body declaration_list - // function/block compound_statement - unAnchored := []string{ - "declaration_list", - "compound_statement", - } - - isUnanchored := !slices.Contains(unAnchored, parent.Type()) - return isUnanchored, isUnanchored -} - -func (*Pattern) IsRoot(node *tree.Node) bool { - return !slices.Contains([]string{"expression_statement", "program"}, node.Type()) && !node.IsMissing() -} - -func (patternLanguage *Pattern) NodeTypes(node *tree.Node) []string { - parent := node.Parent() - if parent == nil { - return []string{node.Type()} - } - - if (node.Type() == "string" && parent.Type() != "encapsed_string") || - (node.Type() == "encapsed_string" && patternLanguage.IsLeaf(node)) { - return []string{"encapsed_string", "string"} - } - - return []string{node.Type()} -} - -func (*Pattern) TranslateContent(fromNodeType, toNodeType, content string) string { - if fromNodeType == "string" && toNodeType == "encapsed_string" { - return fmt.Sprintf(`"%s"`, content[1:len(content)-1]) - } - if fromNodeType == "encapsed_string" && toNodeType == "string" { - return fmt.Sprintf("'%s'", content[1:len(content)-1]) - } - - return content -} - -func (*Pattern) ContainerTypes() []string { - return patternMatchNodeContainerTypes -} diff --git a/internal/languages/__template__/testdata/logger.yml b/internal/languages/__template__/testdata/logger.yml deleted file mode 100644 index 7be35f485..000000000 --- a/internal/languages/__template__/testdata/logger.yml +++ /dev/null @@ -1,10 +0,0 @@ -type: "risk" -languages: - - -patterns: - - pattern: error_log($) - filters: - - variable: DATA_TYPE - detection: datatype -metadata: - id: rule_logger_test diff --git a/internal/languages/__template__/testdata/scope/scope.language b/internal/languages/__template__/testdata/scope/scope.language deleted file mode 100644 index 73504fdc1..000000000 --- a/internal/languages/__template__/testdata/scope/scope.language +++ /dev/null @@ -1,17 +0,0 @@ -scopeCursor($_GET["oops"]); -scopeCursor(x . $_GET["ok"]); -scopeCursor(x ? $_GET["oops"] : y); -scopeCursor($_GET["ok"] ? x : y); -scopeCursor($_GET["oops"] ?: y); - -scopeNested($_GET["oops"]); -scopeNested(x . $_GET["oops"]); -scopeNested(x ? $_GET["oops"] : y); -scopeNested($_GET["oops"] ? x : y); -scopeNested($_GET["oops"] ?: y); - -scopeResult($_GET["oops"]); -scopeResult(x . $_GET["oops"]); -scopeResult(x ? $_GET["oops"] : y); -scopeResult($_GET["ok"] ? x : y); -scopeResult($_GET["oops"] ?: y); \ No newline at end of file diff --git a/internal/languages/__template__/testdata/scope_rule.yml b/internal/languages/__template__/testdata/scope_rule.yml deleted file mode 100644 index 0cf504ea8..000000000 --- a/internal/languages/__template__/testdata/scope_rule.yml +++ /dev/null @@ -1,30 +0,0 @@ -languages: - - -patterns: - - pattern: scopeCursor($) - filters: - - variable: USER_INPUT - detection: scope_test_user_input - scope: cursor - - pattern: scopeNested($) - filters: - - variable: USER_INPUT - detection: scope_test_user_input - scope: nested - - pattern: scopeResult($) - filters: - - variable: USER_INPUT - detection: scope_test_user_input - scope: result -auxiliary: - - id: scope_test_user_input - patterns: - - $_GET[$<_>] - - $_POST[$<_>] -severity: high -metadata: - description: Test detection filter scopes - remediation_message: Test detection filter scopes - cwe_id: - - 42 - id: scope_test diff --git a/internal/languages/__template__/testdata/testcases/flow/different-line.language b/internal/languages/__template__/testdata/testcases/flow/different-line.language deleted file mode 100644 index ef092479a..000000000 --- a/internal/languages/__template__/testdata/testcases/flow/different-line.language +++ /dev/null @@ -1,3 +0,0 @@ -$user = new User(); -$name = $user->name; -error_log($name); \ No newline at end of file diff --git a/internal/languages/__template__/testdata/testcases/flow/same-line.php b/internal/languages/__template__/testdata/testcases/flow/same-line.php deleted file mode 100644 index d10733f42..000000000 --- a/internal/languages/__template__/testdata/testcases/flow/same-line.php +++ /dev/null @@ -1,2 +0,0 @@ -error_log($user->name); -error_log($user->name()); \ No newline at end of file diff --git a/internal/report/detections/detections.go b/internal/report/detections/detections.go index aed594aef..ef0a02ba1 100644 --- a/internal/report/detections/detections.go +++ b/internal/report/detections/detections.go @@ -27,6 +27,7 @@ var TypeSecretleak DetectionType = "secret_leak" var TypeCustom DetectionType = "custom" var TypeCustomClassified DetectionType = "custom_classified" var TypeCustomRisk DetectionType = "custom_risk" +var TypeExpectedDetection DetectionType = "expected_detection" type ReportDetection interface { AddDetection(detectionType DetectionType, detectorType detectors.Type, source source.Source, value interface{}) diff --git a/internal/report/output/dataflow/dataflow.go b/internal/report/output/dataflow/dataflow.go index 6b18a1172..6050d8d37 100644 --- a/internal/report/output/dataflow/dataflow.go +++ b/internal/report/output/dataflow/dataflow.go @@ -30,6 +30,7 @@ var allowedDetections []detections.DetectionType = []detections.DetectionType{ detections.TypeError, detections.TypeFileList, detections.TypeFileFailed, + detections.TypeExpectedDetection, } func contains(detections []detections.DetectionType, detection detections.DetectionType) bool { @@ -48,6 +49,7 @@ func AddReportData(reportData *types.ReportData, config settings.Config, isInter return nil } + expectedHolder := risks.New(config, isInternal) dataTypesHolder := datatypes.New(config, isInternal) risksHolder := risks.New(config, isInternal) componentsHolder := components.New(isInternal) @@ -142,6 +144,8 @@ func AddReportData(reportData *types.ReportData, config settings.Config, isInter if err = dataTypesHolder.AddSchema(castDetection, detectionExtras); err != nil { return err } + case detections.TypeExpectedDetection: + expectedHolder.AddRiskPresence(castDetection) case detections.TypeCustomRisk: ruleName := string(castDetection.DetectorType) customDetector, ok := config.Rules[ruleName] @@ -224,11 +228,12 @@ func AddReportData(reportData *types.ReportData, config settings.Config, isInter reportData.Files = files reportData.Dataflow = &types.DataFlow{ - Datatypes: dataTypesHolder.ToDataFlow(), - Risks: risksHolder.ToDataFlow(), - Components: componentsHolder.ToDataFlow(), - Dependencies: componentsHolder.ToDataFlowForDependencies(), - Errors: errorsHolder.ToDataFlow(), + Datatypes: dataTypesHolder.ToDataFlow(), + ExpectedDetections: expectedHolder.ToDataFlow(), + Risks: risksHolder.ToDataFlow(), + Components: componentsHolder.ToDataFlow(), + Dependencies: componentsHolder.ToDataFlowForDependencies(), + Errors: errorsHolder.ToDataFlow(), } return nil diff --git a/internal/report/output/dataflow/risks/risks.go b/internal/report/output/dataflow/risks/risks.go index a523d2071..79155cec4 100644 --- a/internal/report/output/dataflow/risks/risks.go +++ b/internal/report/output/dataflow/risks/risks.go @@ -242,12 +242,10 @@ func (holder *Holder) ToDataFlow() []types.RiskDetector { constructedDetector := types.RiskDetector{ DetectorID: detector.id, } - locations := []types.RiskLocation{} + locations := []types.RiskLocation{} for _, file := range maputil.ToSortedSlice(detector.files) { - for _, line := range maputil.ToSortedSlice(file.startLineNumber) { - for _, source := range maputil.ToSortedSlice(line.source) { location := types.RiskLocation{ Filename: file.name, @@ -304,9 +302,7 @@ func (holder *Holder) ToDataFlow() []types.RiskDetector { locations = append(locations, location) } - } - } constructedDetector.Locations = locations diff --git a/internal/report/output/security/formatter.go b/internal/report/output/security/formatter.go index 7a965011b..91f1561cf 100644 --- a/internal/report/output/security/formatter.go +++ b/internal/report/output/security/formatter.go @@ -25,10 +25,11 @@ type Formatter struct { EndTime time.Time } -type RawFindingsOutput struct { - Source string `json:"source" yaml:"source"` - Version string `json:"version" yaml:"version"` - Findings RawFindings `json:"findings" yaml:"findings"` +type JsonV2Output struct { + Source string `json:"source" yaml:"source"` + Version string `json:"version" yaml:"version"` + Findings RawFindings `json:"findings" yaml:"findings"` + Expected ExpectedDetections `json:"expected_findings,omitempty" yaml:"expected_findings,omitempty"` } func NewFormatter(reportData *outputtypes.ReportData, config settings.Config, goclocResult *gocloc.Result, startTime time.Time, endTime time.Time) *Formatter { @@ -66,10 +67,11 @@ func (f Formatter) Format(format string) (output string, err error) { case flag.FormatJSON: return outputhandler.ReportJSON(f.ReportData.FindingsBySeverity) case flag.FormatJSONV2: - return outputhandler.ReportJSON(RawFindingsOutput{ + return outputhandler.ReportJSON(JsonV2Output{ Source: "Bearer", Version: build.Version, Findings: f.ReportData.RawFindings, + Expected: f.ReportData.ExpectedDetections, }) case flag.FormatYAML: return outputhandler.ReportYAML(f.ReportData.FindingsBySeverity) diff --git a/internal/report/output/security/security.go b/internal/report/output/security/security.go index 3e7c66ea0..3fb109968 100644 --- a/internal/report/output/security/security.go +++ b/internal/report/output/security/security.go @@ -42,6 +42,7 @@ var severityColorFns = map[string]func(x ...interface{}) string{ globaltypes.LevelWarning: color.New(color.FgCyan).SprintFunc(), } +type ExpectedDetections = []types.ExpectedDetection type RawFindings = []types.RawFinding type Findings = map[string][]types.Finding type IgnoredFindings = map[string][]types.IgnoredFinding @@ -106,6 +107,22 @@ func AddReportData( } } + for _, expectedDetectionPerRule := range dataflow.ExpectedDetections { + for _, location := range expectedDetectionPerRule.Locations { + reportData.ExpectedDetections = append(reportData.ExpectedDetections, types.ExpectedDetection{ + RuleID: expectedDetectionPerRule.DetectorID, + Location: types.Location{ + Start: location.Source.StartLineNumber, + End: location.Source.EndLineNumber, + Column: types.Column{ + Start: location.Source.StartColumnNumber, + End: location.Source.EndColumnNumber, + }, + }, + }) + } + } + if !config.Scan.Quiet { fingerprintOutput( append(fingerprints, builtInFingerprints...), diff --git a/internal/report/output/security/types/types.go b/internal/report/output/security/types/types.go index fbcc87212..ed73196cf 100644 --- a/internal/report/output/security/types/types.go +++ b/internal/report/output/security/types/types.go @@ -11,6 +11,11 @@ import ( ignoretypes "github.com/bearer/bearer/internal/util/ignore/types" ) +type ExpectedDetection struct { + RuleID string `json:"rule_id"` + Location Location `json:"location"` +} + type RawFinding struct { *Finding Severity string `json:"severity" yaml:"severity"` diff --git a/internal/report/output/types/types.go b/internal/report/output/types/types.go index 94f457a50..18c6d4651 100644 --- a/internal/report/output/types/types.go +++ b/internal/report/output/types/types.go @@ -20,14 +20,16 @@ type ReportData struct { PrivacyReport *privacytypes.Report Stats *statstypes.Stats SaasReport *saastypes.BearerReport + ExpectedDetections []securitytypes.ExpectedDetection } type DataFlow struct { - Datatypes []dataflowtypes.Datatype `json:"data_types,omitempty" yaml:"data_types,omitempty"` - Risks []dataflowtypes.RiskDetector `json:"risks,omitempty" yaml:"risks,omitempty"` - Components []dataflowtypes.Component `json:"components,omitempty" yaml:"components,omitempty"` - Dependencies []dataflowtypes.Dependency `json:"dependencies,omitempty" yaml:"dependencies,omitempty"` - Errors []dataflowtypes.Error `json:"errors,omitempty" yaml:"errors,omitempty"` + Datatypes []dataflowtypes.Datatype `json:"data_types,omitempty" yaml:"data_types,omitempty"` + ExpectedDetections []dataflowtypes.RiskDetector `json:"expected_detections,omitempty" yaml:"expected_detections,omitempty"` + Risks []dataflowtypes.RiskDetector `json:"risks,omitempty" yaml:"risks,omitempty"` + Components []dataflowtypes.Component `json:"components,omitempty" yaml:"components,omitempty"` + Dependencies []dataflowtypes.Dependency `json:"dependencies,omitempty" yaml:"dependencies,omitempty"` + Errors []dataflowtypes.Error `json:"errors,omitempty" yaml:"errors,omitempty"` } type GenericFormatter interface { diff --git a/internal/scanner/ast/.snapshots/TestExpectedRules b/internal/scanner/ast/.snapshots/TestExpectedRules new file mode 100644 index 000000000..eb5c1c998 --- /dev/null +++ b/internal/scanner/ast/.snapshots/TestExpectedRules @@ -0,0 +1,67 @@ +([]ast_test.ruleInfo) (len=1) { + (ast_test.ruleInfo) { + ID: (string) (len=5) "rule1", + Index: (int) 5 + } +} +type: program +id: 0 +range: 2:3 - 6:2 +dataflow_sources: + - 1 + - 2 +children: + - type: comment + id: 1 + range: 2:3 - 2:26 + content: '# bearer:expected rule1' + - type: method + id: 2 + range: 3:3 - 5:6 + expectedrules: + - rule1 + children: + - type: '"def"' + id: 3 + range: 3:3 - 3:6 + - type: identifier + id: 4 + range: 3:7 - 3:8 + content: m + - type: method_parameters + id: 5 + range: 3:8 - 3:11 + dataflow_sources: + - 6 + - 7 + - 8 + children: + - type: '"("' + id: 6 + range: 3:8 - 3:9 + - type: identifier + id: 7 + range: 3:9 - 3:10 + content: a + - type: '")"' + id: 8 + range: 3:10 - 3:11 + - type: call + id: 9 + range: 4:4 - 4:9 + children: + - type: identifier + id: 10 + range: 4:4 - 4:5 + content: b + - type: '"."' + id: 11 + range: 4:5 - 4:6 + - type: identifier + id: 12 + range: 4:6 - 4:9 + content: bar + - type: '"end"' + id: 13 + range: 5:3 - 5:6 + diff --git a/internal/scanner/ast/ast.go b/internal/scanner/ast/ast.go index bde69d0f3..69e6016b1 100644 --- a/internal/scanner/ast/ast.go +++ b/internal/scanner/ast/ast.go @@ -86,6 +86,7 @@ func analyzeNode( childCount := int(node.ChildCount()) var disabledRules []*ruleset.Rule + var expectedRules []*ruleset.Rule for i := 0; i < childCount; i++ { child := node.Child(i) if !child.IsNamed() { @@ -93,6 +94,7 @@ func analyzeNode( } disabledRules = addDisabledRules(ruleSet, builder, disabledRules, child) + expectedRules = addExpectedRules(ruleSet, builder, expectedRules, child) if err := analyzeNode(ctx, ruleSet, builder, analyzer, child); err != nil { return err } @@ -104,6 +106,38 @@ func analyzeNode( return analyzer.Analyze(node, visitChildren) } +func addExpectedRules( + ruleSet *ruleset.Set, + builder *tree.Builder, + expectedRules []*ruleset.Rule, + node *sitter.Node, +) []*ruleset.Rule { + if node.Type() == "comment" { + nextExpectedRules := expectedRules + + nodeContent := builder.ContentFor(node) + if strings.Contains(nodeContent, "bearer:expected") { + rawRuleIDs := strings.Split(nodeContent, "bearer:expected")[1] + + for _, ruleID := range strings.Split(rawRuleIDs, ",") { + rule, err := ruleSet.RuleByID(strings.TrimSpace(ruleID)) + if err != nil { + log.Debug().Msgf("ignoring unknown expected rule '%s': %s", ruleID, err) + continue + } + + nextExpectedRules = append(nextExpectedRules, rule) + } + } + + return nextExpectedRules + } + + builder.AddExpectedRules(node, expectedRules) + + return nil +} + func addDisabledRules( ruleSet *ruleset.Set, builder *tree.Builder, diff --git a/internal/scanner/ast/ast_test.go b/internal/scanner/ast/ast_test.go index e5c7e2d85..fe4be2f71 100644 --- a/internal/scanner/ast/ast_test.go +++ b/internal/scanner/ast/ast_test.go @@ -75,3 +75,55 @@ func TestDisabledRules(t *testing.T) { tree.RootNode().Dump(), ) } + +func TestExpectedRules(t *testing.T) { + content := ` + # bearer:expected rule1 + def m(a) + b.bar + end + ` + + language := ruby.Get() + languageIDs := []string{language.ID()} + + ruleSet, err := ruleset.New( + language.ID(), + map[string]*settings.Rule{ + "rule1": {Id: "rule1", Languages: languageIDs}, + }, + ) + if err != nil { + t.Fatalf("failed to create rule set: %s", err) + } + + var ruleDump []ruleInfo + for _, rule := range ruleSet.Rules() { + if rule.Type() != ruleset.RuleTypeBuiltin { + ruleDump = append(ruleDump, ruleInfo{ID: rule.ID(), Index: rule.Index()}) + } + } + + querySet := query.NewSet(language.ID(), language.SitterLanguage()) + if err := querySet.Compile(); err != nil { + t.Fatalf("failed to compile query set: %s", err) + } + + tree, err := ast.ParseAndAnalyze( + context.Background(), + language, + ruleSet, + querySet, + []byte(content), + ) + + if err != nil { + t.Fatalf("failed to parse and analyze input: %s", err) + } + + cupaloy.SnapshotT( + t, + ruleDump, + tree.RootNode().Dump(), + ) +} diff --git a/internal/scanner/ast/tree/builder.go b/internal/scanner/ast/tree/builder.go index c75d6cf5e..7e81e6f6f 100644 --- a/internal/scanner/ast/tree/builder.go +++ b/internal/scanner/ast/tree/builder.go @@ -96,6 +96,14 @@ func (builder *Builder) Alias(toNode *sitter.Node, fromNodes ...*sitter.Node) { ) } +func (builder *Builder) AddExpectedRules(sitterNode *sitter.Node, rules []*ruleset.Rule) { + if len(rules) == 0 { + return + } + + builder.addExpectedRulesForNode(builder.sitterToNodeID[sitterNode], rules) +} + func (builder *Builder) AddDisabledRules(sitterNode *sitter.Node, rules []*ruleset.Rule) { if len(rules) == 0 { return @@ -104,6 +112,14 @@ func (builder *Builder) AddDisabledRules(sitterNode *sitter.Node, rules []*rules builder.addDisabledRulesForNode(builder.sitterToNodeID[sitterNode], rules) } +func (builder *Builder) addExpectedRulesForNode(nodeID int, rules []*ruleset.Rule) { + node := &builder.nodes[nodeID] + + for _, rule := range rules { + node.expectedRules = append(node.expectedRules, rule.ID()) + } +} + func (builder *Builder) addDisabledRulesForNode(nodeID int, rules []*ruleset.Rule) { node := &builder.nodes[nodeID] if node.disabledRuleIndices == nil { diff --git a/internal/scanner/ast/tree/tree.go b/internal/scanner/ast/tree/tree.go index eea616cb6..2612a03bb 100644 --- a/internal/scanner/ast/tree/tree.go +++ b/internal/scanner/ast/tree/tree.go @@ -30,6 +30,7 @@ type Node struct { children, dataflowSources, aliasOf []*Node + expectedRules []string disabledRuleIndices *bitset.BitSet // FIXME: remove the need for this sitterNode *sitter.Node @@ -60,6 +61,10 @@ func (tree *Tree) NodeFromSitter(sitterNode *sitter.Node) *Node { return tree.sitterToNode[sitterNode] } +func (tree *Tree) Nodes() []Node { + return tree.nodes +} + func (node *Node) Tree() *Tree { return node.tree } @@ -133,6 +138,10 @@ func (node *Node) AliasOf() []*Node { return node.aliasOf } +func (node *Node) ExpectedRules() []string { + return node.expectedRules +} + func (node *Node) RuleDisabled(index int) bool { if node.disabledRuleIndices == nil { return false @@ -158,6 +167,7 @@ type nodeDump struct { AliasOf []int `yaml:"alias_of,omitempty"` Queries []int `yaml:",omitempty"` DisabledRules []int `yaml:",omitempty"` + ExpectedRules []string `yaml:",omitempty"` Children []nodeDump `yaml:",omitempty"` } @@ -189,6 +199,11 @@ func (node *Node) dumpValue() nodeDump { } } + var expectedRules []string + if len(node.expectedRules) > 0 { + expectedRules = append(expectedRules, node.expectedRules...) + } + contentRange := fmt.Sprintf( "%d:%d - %d:%d", node.ContentStart.Line, @@ -212,6 +227,7 @@ func (node *Node) dumpValue() nodeDump { Children: childDump, Queries: queries, DisabledRules: disabledRules, + ExpectedRules: expectedRules, } } diff --git a/internal/scanner/detectors/types/types.go b/internal/scanner/detectors/types/types.go index 1c34bf293..65109762d 100644 --- a/internal/scanner/detectors/types/types.go +++ b/internal/scanner/detectors/types/types.go @@ -24,6 +24,9 @@ type Context interface { type Detector interface { Rule() *ruleset.Rule DetectAt(node *tree.Node, detectorContext Context) ([]interface{}, error) + DetectExpectedAt(node *tree.Node, detectorContext Context) ([]interface{}, error) } -type DetectorBase struct{} +type DetectorBase interface { + DetectExpectedAt(node *tree.Node, detectorContext Context) ([]interface{}, error) +} diff --git a/internal/scanner/detectorset/detectorset.go b/internal/scanner/detectorset/detectorset.go index 0f4777ff1..17b44501b 100644 --- a/internal/scanner/detectorset/detectorset.go +++ b/internal/scanner/detectorset/detectorset.go @@ -21,6 +21,7 @@ const () type Result struct { Detections []*detectortypes.Detection Sanitized bool + Expected bool } type Set interface { diff --git a/internal/scanner/languagescanner/languagescanner.go b/internal/scanner/languagescanner/languagescanner.go index 3f41cce99..0b5fbf1a2 100644 --- a/internal/scanner/languagescanner/languagescanner.go +++ b/internal/scanner/languagescanner/languagescanner.go @@ -77,19 +77,19 @@ func (scanner *Scanner) Scan( ctx context.Context, fileStats *stats.FileStats, fileInfo *file.FileInfo, -) ([]*detectortypes.Detection, error) { +) ([]*detectortypes.Detection, []*detectortypes.Detection, error) { if !slices.Contains(scanner.language.EnryLanguages(), fileInfo.Language) { - return nil, nil + return nil, nil, nil } contentBytes, err := os.ReadFile(fileInfo.AbsolutePath) if err != nil { - return nil, fmt.Errorf("failed to read file: %w", err) + return nil, nil, fmt.Errorf("failed to read file: %w", err) } tree, err := ast.ParseAndAnalyze(ctx, scanner.language, scanner.ruleSet, scanner.querySet, contentBytes) if err != nil { - return nil, err + return nil, nil, err } if log.Trace().Enabled() { @@ -108,15 +108,41 @@ func (scanner *Scanner) Scan( cache, ) - return scanner.evaluateRules(ruleScanner, cache, tree) + detections, err := scanner.evaluateRules(ruleScanner, cache, tree) + expectedDetections, _ := scanner.ExpectedDetections(tree) + + return detections, expectedDetections, err +} + +func (scanner *Scanner) ExpectedDetections(tree *tree.Tree) ([]*detectortypes.Detection, error) { + var detections []*detectortypes.Detection + nodes := tree.Nodes() + for i := range tree.Nodes() { + node := &nodes[i] + if len(node.ExpectedRules()) > 0 { + for _, expectedRule := range node.ExpectedRules() { + rule, _ := scanner.ruleSet.RuleByID(expectedRule) + detections = append(detections, []*detectortypes.Detection{ + { + RuleID: rule.ID(), + MatchNode: node, + }, + }...) + } + } + } + + return detections, nil } func (scanner *Scanner) evaluateRules( ruleScanner *rulescanner.Scanner, cache *cache.Cache, tree *tree.Tree, -) ([]*detectortypes.Detection, error) { - +) ( + []*detectortypes.Detection, + error, +) { var detections []*detectortypes.Detection for _, rule := range scanner.ruleSet.Rules() { if rule.Type() != ruleset.RuleTypeTopLevel { diff --git a/internal/scanner/rulescanner/rulescanner.go b/internal/scanner/rulescanner/rulescanner.go index f0fb88b10..1df49dd54 100644 --- a/internal/scanner/rulescanner/rulescanner.go +++ b/internal/scanner/rulescanner/rulescanner.go @@ -47,7 +47,10 @@ func (scanner *Scanner) Scan( rootNode *tree.Node, rule *ruleset.Rule, traversalStrategy traversalstrategy.Strategy, -) ([]*detectortypes.Detection, error) { +) ( + []*detectortypes.Detection, + error, +) { if scanner.stats != nil { startTime := time.Now() defer scanner.stats.Rule(rule.ID(), startTime) diff --git a/internal/scanner/ruleset/ruleset.go b/internal/scanner/ruleset/ruleset.go index 1713bd8b0..abf36ce2b 100644 --- a/internal/scanner/ruleset/ruleset.go +++ b/internal/scanner/ruleset/ruleset.go @@ -128,6 +128,10 @@ func getRuleType(triggerRuleIDs set.Set[string], settingsRule *settings.Rule) Ru } } +func (set *Set) RuleByIndex(idx uint64) (*Rule, error) { + return set.Rules()[idx], nil +} + func (set *Set) RuleByID(id string) (*Rule, error) { rule, exists := set.rulesByID[id] if !exists { diff --git a/internal/scanner/scanner.go b/internal/scanner/scanner.go index 5f44551e0..b22235ed9 100644 --- a/internal/scanner/scanner.go +++ b/internal/scanner/scanner.go @@ -68,11 +68,33 @@ func (scanner *Scanner) Scan( } for _, languageScanner := range scanner.languageScanners { - detections, err := languageScanner.Scan(ctx, fileStats, file) + detections, expectedDetections, err := languageScanner.Scan(ctx, fileStats, file) if err != nil { return fmt.Errorf("%s scan failed: %w", languageScanner.LanguageID(), err) } + for _, detection := range expectedDetections { + detectorType := detectors.Type(detection.RuleID) + report.AddDetection(reportdetections.TypeExpectedDetection, + detectorType, + source.New( + file, + file.Path, + detection.MatchNode.ContentStart.Line, + detection.MatchNode.ContentStart.Column, + detection.MatchNode.ContentEnd.Line, + detection.MatchNode.ContentEnd.Column, + fmt.Sprintf("bearer:expected %s", detection.RuleID), + ), + reportschema.Source{ + StartLineNumber: detection.MatchNode.ContentStart.Line, + EndLineNumber: detection.MatchNode.ContentEnd.Line, + StartColumnNumber: detection.MatchNode.ContentStart.Column, + EndColumnNumber: detection.MatchNode.ContentEnd.Column, + Content: detection.MatchNode.Content(), + }) + } + for _, detection := range detections { detectorType := detectors.Type(detection.RuleID) data := detection.Data.(customruletypes.Data)