Skip to content

Commit

Permalink
fix(guided remediation): reduce memory footprint by computing depende…
Browse files Browse the repository at this point in the history
…ncy subgraphs instead of chains (#1538)

Guided remediation had been using `DependencyChains` to track paths to a
vulnerable package (for computing things like depth and which direct
dependencies to relax). It was computing *every* possible path in the
graph to a dependency, which grows roughly exponentially with depth /
connectivity. This was using an unreasonable amount of memory on some
particularly large/complex projects.

I've changed the logic to instead compute one `DependencySubgraph` - the
set of nodes and edges that would contain every path to a dependency.
This should significantly reduce memory usage (and cpu usage from
allocs) when running on larger projects.

This change has touched quite a few places in the code, and the logic is
a bit complex. I've tried my best to check that everything still behaves
as expected.
  • Loading branch information
michaelkedar authored Jan 30, 2025
1 parent c4543e9 commit d8d794b
Show file tree
Hide file tree
Showing 16 changed files with 2,008 additions and 651 deletions.
4 changes: 2 additions & 2 deletions cmd/osv-scanner/fix/noninteractive.go
Original file line number Diff line number Diff line change
Expand Up @@ -418,8 +418,8 @@ func makeResultVuln(vuln resolution.Vulnerability) vulnOutput {
}

affected := make(map[packageOutput]struct{})
for _, c := range append(vuln.ProblemChains, vuln.NonProblemChains...) {
vk, _ := c.End()
for _, sg := range vuln.Subgraphs {
vk := sg.Nodes[sg.Dependency].Version
affected[packageOutput{Name: vk.Name, Version: vk.Version}] = struct{}{}
}
v.Packages = maps.Keys(affected)
Expand Down
21 changes: 10 additions & 11 deletions internal/remediation/in_place.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,10 @@ func ComputeInPlacePatches(ctx context.Context, cl client.ResolutionClient, grap
for vk, vulns := range res.vkVulns {
reqVers := make(map[string]struct{})
for _, vuln := range vulns {
for _, c := range vuln.ProblemChains {
_, req := c.End()
reqVers[req] = struct{}{}
for _, sg := range vuln.Subgraphs {
for _, e := range sg.Nodes[sg.Dependency].Parents {
reqVers[e.Requirement] = struct{}{}
}
}
}
set, err := buildConstraintSet(vk.Semver(), maps.Keys(reqVers))
Expand Down Expand Up @@ -268,24 +269,22 @@ func inPlaceVulnsNodes(ctx context.Context, m clientinterfaces.VulnerabilityMatc
nodeIDs = append(nodeIDs, resolve.NodeID(nID))
}
}
nodeChains := resolution.ComputeChains(graph, nodeIDs)
// Computing ALL chains might be overkill...
// We only actually care about the shortest chain, the unique dependents of the vulnerable node, and maybe the unique direct dependencies.
nodeSubgraphs := resolution.ComputeSubgraphs(graph, nodeIDs)

for i, nID := range nodeIDs {
chains := nodeChains[i]
vk := graph.Nodes[nID].Version
result.vkNodes[vk] = append(result.vkNodes[vk], nID)
for _, vuln := range nodeVulns[nID] {
resVuln := resolution.Vulnerability{
OSV: *vuln,
ProblemChains: slices.Clone(chains),
DevOnly: !slices.ContainsFunc(chains, func(dc resolution.DependencyChain) bool { return !resolution.ChainIsDev(dc, nil) }),
OSV: *vuln,
Subgraphs: []*resolution.DependencySubgraph{nodeSubgraphs[i]},
DevOnly: nodeSubgraphs[i].IsDevOnly(nil),
}
idx := slices.IndexFunc(result.vkVulns[vk], func(rv resolution.Vulnerability) bool { return rv.OSV.ID == resVuln.OSV.ID })
if idx >= 0 {
result.vkVulns[vk][idx].ProblemChains = append(result.vkVulns[vk][idx].ProblemChains, resVuln.ProblemChains...)
result.vkVulns[vk][idx].DevOnly = result.vkVulns[vk][idx].DevOnly && resVuln.DevOnly

result.vkVulns[vk][idx].Subgraphs = append(result.vkVulns[vk][idx].Subgraphs, resVuln.Subgraphs...)
} else {
result.vkVulns[vk] = append(result.vkVulns[vk], resVuln)
}
Expand Down
7 changes: 2 additions & 5 deletions internal/remediation/in_place_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,8 @@ func checkInPlaceResults(t *testing.T, res remediation.InPlaceResult) {
toMinimalVuln := func(v resolution.Vulnerability) minimalVuln {
t.Helper()
nodes := make(map[resolve.NodeID]struct{})
for _, c := range v.ProblemChains {
nodes[c.Edges[0].To] = struct{}{}
}
for _, c := range v.NonProblemChains {
nodes[c.Edges[0].To] = struct{}{}
for _, sg := range v.Subgraphs {
nodes[sg.Dependency] = struct{}{}
}
sortedNodes := maps.Keys(nodes)
slices.Sort(sortedNodes)
Expand Down
39 changes: 14 additions & 25 deletions internal/remediation/override.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,31 +131,20 @@ func overridePatchVulns(ctx context.Context, cl client.ResolutionClient, result
// Keep track of VersionKeys we've seen for this vuln to avoid duplicates.
// Usually, there will only be one VersionKey per vuln, but some vulns affect multiple packages.
seenVKs := make(map[resolve.VersionKey]struct{})
// Use the DependencyChains to find all the affected nodes.
for _, c := range v.ProblemChains {
// Currently, there is no way to know if a specific classifier or type exists for a given version with deps.dev.
// Blindly updating versions can lead to compilation failures if the artifact+version+classifier+type doesn't exist.
// We can't reliably attempt remediation in these cases, so don't try.
// TODO: query Maven registry for existence of classifiers in getVersionsGreater
typ := c.Edges[0].Type
if typ.HasAttr(dep.MavenClassifier) || typ.HasAttr(dep.MavenArtifactType) {
return nil, nil, fmt.Errorf("%w: cannot fix vulns in artifacts with classifier or type", errOverrideImpossible)
}
vk, _ := c.End()
if _, seen := seenVKs[vk]; !seen {
vkVulns[vk] = append(vkVulns[vk], &result.Vulns[i])
seenVKs[vk] = struct{}{}
}
}
for _, c := range v.NonProblemChains {
typ := c.Edges[0].Type
if typ.HasAttr(dep.MavenClassifier) || typ.HasAttr(dep.MavenArtifactType) {
return nil, nil, fmt.Errorf("%w: cannot fix vulns in artifacts with classifier or type", errOverrideImpossible)
}
vk, _ := c.End()
if _, seen := seenVKs[vk]; !seen {
vkVulns[vk] = append(vkVulns[vk], &result.Vulns[i])
seenVKs[vk] = struct{}{}
// Use the Subgraphs to find all the affected nodes.
for _, sg := range v.Subgraphs {
for _, e := range sg.Nodes[sg.Dependency].Parents {
// Currently, there is no way to know if a specific classifier or type exists for a given version with deps.dev.
// Blindly updating versions can lead to compilation failures if the artifact+version+classifier+type doesn't exist.
// We can't reliably attempt remediation in these cases, so don't try.
if e.Type.HasAttr(dep.MavenClassifier) || e.Type.HasAttr(dep.MavenArtifactType) {
return nil, nil, fmt.Errorf("%w: cannot fix vulns in artifacts with classifier or type", errOverrideImpossible)
}
vk := sg.Nodes[sg.Dependency].Version
if _, seen := seenVKs[vk]; !seen {
vkVulns[vk] = append(vkVulns[vk], &result.Vulns[i])
seenVKs[vk] = struct{}{}
}
}
}
}
Expand Down
19 changes: 11 additions & 8 deletions internal/remediation/relax.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func tryRelaxRemediate(
}

newRes := orig
toRelax := reqsToRelax(newRes, vulnIDs, opts)
toRelax := reqsToRelax(ctx, cl, newRes, vulnIDs, opts)
for len(toRelax) > 0 {
// Try relaxing all necessary requirements
manif := newRes.Manifest.Clone()
Expand All @@ -109,24 +109,27 @@ func tryRelaxRemediate(
if err != nil {
return nil, err
}
toRelax = reqsToRelax(newRes, vulnIDs, opts)
toRelax = reqsToRelax(ctx, cl, newRes, vulnIDs, opts)
}

return newRes, nil
}

func reqsToRelax(res *resolution.Result, vulnIDs []string, opts Options) []int {
func reqsToRelax(ctx context.Context, cl resolve.Client, res *resolution.Result, vulnIDs []string, opts Options) []int {
toRelax := make(map[resolve.VersionKey]string)
for _, v := range res.Vulns {
// Don't do a full opts.MatchVuln() since we know we don't need to check every condition
if !slices.Contains(vulnIDs, v.OSV.ID) || (!opts.DevDeps && v.DevOnly) {
continue
}
// Only relax dependencies if their chain length is less than MaxDepth
for _, ch := range v.ProblemChains {
if opts.MaxDepth <= 0 || len(ch.Edges) <= opts.MaxDepth {
vk, req := ch.Direct()
toRelax[vk] = req
// Only relax dependencies if their distance is less than MaxDepth
for _, sg := range v.Subgraphs {
constr := sg.ConstrainingSubgraph(ctx, cl, &v.OSV)
for _, edge := range constr.Nodes[0].Children {
gNode := constr.Nodes[edge.To]
if opts.MaxDepth <= 0 || gNode.Distance+1 <= opts.MaxDepth {
toRelax[gNode.Version] = edge.Requirement
}
}
}
}
Expand Down
14 changes: 2 additions & 12 deletions internal/remediation/remediation.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,18 +102,8 @@ func (opts Options) matchDepth(v resolution.Vulnerability) bool {
return true
}

if len(v.ProblemChains)+len(v.NonProblemChains) == 0 {
panic("vulnerability with no dependency chains")
}

for _, ch := range v.ProblemChains {
if len(ch.Edges) <= opts.MaxDepth {
return true
}
}

for _, ch := range v.NonProblemChains {
if len(ch.Edges) <= opts.MaxDepth {
for _, sg := range v.Subgraphs {
if sg.Nodes[0].Distance <= opts.MaxDepth {
return true
}
}
Expand Down
55 changes: 48 additions & 7 deletions internal/remediation/remediation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,30 @@ func TestMatchVuln(t *testing.T) {
Aliases: []string{"CVE-111", "OSV-2"},
},
DevOnly: false,
ProblemChains: []resolution.DependencyChain{{
Edges: []resolve.Edge{{From: 2, To: 3}, {From: 1, To: 2}, {From: 0, To: 1}},
Subgraphs: []*resolution.DependencySubgraph{{
Dependency: 3,
Nodes: map[resolve.NodeID]resolution.GraphNode{
3: {
Distance: 0,
Parents: []resolve.Edge{{From: 2, To: 3}},
Children: []resolve.Edge{},
},
2: {
Distance: 1,
Parents: []resolve.Edge{{From: 1, To: 2}},
Children: []resolve.Edge{{From: 2, To: 3}},
},
1: {
Distance: 2,
Parents: []resolve.Edge{{From: 0, To: 1}},
Children: []resolve.Edge{{From: 1, To: 2}},
},
0: {
Distance: 3,
Parents: []resolve.Edge{},
Children: []resolve.Edge{{From: 0, To: 1}},
},
},
}},
}
// ID: VULN-002, Dev: true, Severity: N/A, Depth: 2
Expand All @@ -34,11 +56,30 @@ func TestMatchVuln(t *testing.T) {
// No severity
},
DevOnly: true,
ProblemChains: []resolution.DependencyChain{{
Edges: []resolve.Edge{{From: 2, To: 3}, {From: 1, To: 2}, {From: 0, To: 1}},
}},
NonProblemChains: []resolution.DependencyChain{{
Edges: []resolve.Edge{{From: 1, To: 3}, {From: 0, To: 1}},
Subgraphs: []*resolution.DependencySubgraph{{
Dependency: 3,
Nodes: map[resolve.NodeID]resolution.GraphNode{
3: {
Distance: 0,
Parents: []resolve.Edge{{From: 2, To: 3}, {From: 1, To: 3}},
Children: []resolve.Edge{},
},
2: {
Distance: 1,
Parents: []resolve.Edge{{From: 1, To: 2}},
Children: []resolve.Edge{{From: 2, To: 3}},
},
1: {
Distance: 1,
Parents: []resolve.Edge{{From: 0, To: 1}},
Children: []resolve.Edge{{From: 1, To: 2}, {From: 1, To: 3}},
},
0: {
Distance: 2,
Parents: []resolve.Edge{},
Children: []resolve.Edge{{From: 0, To: 1}},
},
},
}},
}
)
Expand Down
7 changes: 2 additions & 5 deletions internal/remediation/testhelpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,8 @@ func checkRemediationResults(t *testing.T, res []resolution.Difference) {
toMinimalVuln := func(v resolution.Vulnerability) minimalVuln {
t.Helper()
nodes := make(map[resolve.NodeID]struct{})
for _, c := range v.ProblemChains {
nodes[c.Edges[0].To] = struct{}{}
}
for _, c := range v.NonProblemChains {
nodes[c.Edges[0].To] = struct{}{}
for _, sg := range v.Subgraphs {
nodes[sg.Dependency] = struct{}{}
}
sortedNodes := maps.Keys(nodes)
slices.Sort(sortedNodes)
Expand Down
Loading

0 comments on commit d8d794b

Please sign in to comment.