Skip to content

Commit fe4eaea

Browse files
authored
refactor: use new VulnerabilityMatcher in guided remediation (google#1503)
Following up on google#1470 - Made `ResolutionClient` use the `VulnerabilityMatcher` interface (and added helper function to convert deps.dev graphs into inventories) - Deleted old `VulnerabilityClient` - Created `CachedOSVMatcher` to re-implement performance improvements from the original `VulnerabilityClient` w.r.t. repeated queries. - Re-enabled local database capability in `osv-scanner fix`
1 parent 6874aa6 commit fe4eaea

File tree

13 files changed

+296
-248
lines changed

13 files changed

+296
-248
lines changed

cmd/osv-scanner/fix/main.go

+34-5
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,28 @@ import (
44
"errors"
55
"fmt"
66
"io"
7+
"net/http"
78
"os"
89
"path/filepath"
910
"strings"
11+
"time"
1012

1113
"deps.dev/util/resolve"
14+
"github.com/google/osv-scanner/internal/clients/clientimpl/localmatcher"
15+
"github.com/google/osv-scanner/internal/clients/clientimpl/osvmatcher"
1216
"github.com/google/osv-scanner/internal/depsdev"
17+
"github.com/google/osv-scanner/internal/imodels/ecosystem"
18+
"github.com/google/osv-scanner/internal/osvdev"
1319
"github.com/google/osv-scanner/internal/remediation"
1420
"github.com/google/osv-scanner/internal/remediation/upgrade"
1521
"github.com/google/osv-scanner/internal/resolution"
1622
"github.com/google/osv-scanner/internal/resolution/client"
1723
"github.com/google/osv-scanner/internal/resolution/lockfile"
1824
"github.com/google/osv-scanner/internal/resolution/manifest"
25+
"github.com/google/osv-scanner/internal/resolution/util"
1926
"github.com/google/osv-scanner/internal/version"
2027
"github.com/google/osv-scanner/pkg/reporter"
28+
"github.com/ossf/osv-schema/bindings/go/osvschema"
2129
"github.com/urfave/cli/v2"
2230
"golang.org/x/term"
2331
)
@@ -364,18 +372,39 @@ func action(ctx *cli.Context, stdout, stderr io.Writer) (reporter.Reporter, erro
364372
}
365373
}
366374

375+
userAgent := "osv-scanner_fix/" + version.OSVVersion
367376
if ctx.Bool("experimental-offline-vulnerabilities") {
368-
var err error
369-
opts.Client.VulnerabilityClient, err = client.NewOSVOfflineClient(
377+
matcher, err := localmatcher.NewLocalMatcher(
370378
r,
371-
system,
379+
ctx.String("experimental-local-db-path"),
380+
userAgent,
372381
ctx.Bool("experimental-download-offline-databases"),
373-
ctx.String("experimental-local-db-path"))
382+
)
374383
if err != nil {
375384
return nil, err
376385
}
386+
387+
eco, ok := util.OSVEcosystem[system]
388+
if !ok {
389+
// Something's very wrong if we hit this
390+
panic("unhandled resolve.Ecosystem: " + system.String())
391+
}
392+
if err := matcher.LoadEcosystem(ctx.Context, ecosystem.Parsed{Ecosystem: osvschema.Ecosystem(eco)}); err != nil {
393+
return nil, err
394+
}
395+
396+
opts.Client.VulnerabilityMatcher = matcher
377397
} else {
378-
opts.Client.VulnerabilityClient = client.NewOSVClient()
398+
config := osvdev.DefaultConfig()
399+
config.UserAgent = userAgent
400+
opts.Client.VulnerabilityMatcher = &osvmatcher.CachedOSVMatcher{
401+
Client: osvdev.OSVClient{
402+
HTTPClient: http.DefaultClient,
403+
Config: config,
404+
BaseHostURL: osvdev.DefaultBaseURL,
405+
},
406+
InitialQueryTimeout: 5 * time.Minute,
407+
}
379408
}
380409

381410
if !ctx.Bool("non-interactive") {

internal/clients/clientimpl/localmatcher/localmatcher.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ func (matcher *LocalMatcher) MatchVulnerabilities(ctx context.Context, invs []*e
7777
continue
7878
}
7979

80-
results = append(results, db.VulnerabilitiesAffectingPackage(pkg))
80+
results = append(results, VulnerabilitiesAffectingPackage(db.Vulnerabilities(false), pkg))
8181
}
8282

8383
return results, nil

internal/clients/clientimpl/localmatcher/zip.go

+5-3
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,8 @@ func (db *ZipDB) Vulnerabilities(includeWithdrawn bool) []models.Vulnerability {
235235
return vulnerabilities
236236
}
237237

238-
func (db *ZipDB) VulnerabilitiesAffectingPackage(pkg imodels.PackageInfo) []*models.Vulnerability {
238+
// TODO: Move this to another file.
239+
func VulnerabilitiesAffectingPackage(allVulns []models.Vulnerability, pkg imodels.PackageInfo) []*models.Vulnerability {
239240
var vulnerabilities []*models.Vulnerability
240241

241242
// TODO (V2 Models): remove this once PackageDetails has been migrated
@@ -248,7 +249,7 @@ func (db *ZipDB) VulnerabilitiesAffectingPackage(pkg imodels.PackageInfo) []*mod
248249
DepGroups: pkg.DepGroups(),
249250
}
250251

251-
for _, vulnerability := range db.Vulnerabilities(false) {
252+
for _, vulnerability := range allVulns {
252253
if vulns.IsAffected(vulnerability, mappedPackageDetails) && !vulns.Include(vulnerabilities, vulnerability) {
253254
vulnerabilities = append(vulnerabilities, &vulnerability)
254255
}
@@ -258,10 +259,11 @@ func (db *ZipDB) VulnerabilitiesAffectingPackage(pkg imodels.PackageInfo) []*mod
258259
}
259260

260261
func (db *ZipDB) Check(pkgs []imodels.PackageInfo) ([]*models.Vulnerability, error) {
262+
allVulns := db.Vulnerabilities(false)
261263
vulnerabilities := make([]*models.Vulnerability, 0, len(pkgs))
262264

263265
for _, pkg := range pkgs {
264-
vulnerabilities = append(vulnerabilities, db.VulnerabilitiesAffectingPackage(pkg)...)
266+
vulnerabilities = append(vulnerabilities, VulnerabilitiesAffectingPackage(allVulns, pkg)...)
265267
}
266268

267269
return vulnerabilities, nil
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
package osvmatcher
2+
3+
import (
4+
"context"
5+
"errors"
6+
"maps"
7+
"slices"
8+
"sync"
9+
"time"
10+
11+
"github.com/google/osv-scalibr/extractor"
12+
"github.com/google/osv-scanner/internal/clients/clientimpl/localmatcher"
13+
"github.com/google/osv-scanner/internal/imodels"
14+
"github.com/google/osv-scanner/internal/osvdev"
15+
"github.com/google/osv-scanner/pkg/models"
16+
"golang.org/x/sync/errgroup"
17+
)
18+
19+
// CachedOSVMatcher implements the VulnerabilityMatcher interface with a osv.dev client.
20+
// It sends out requests for every vulnerability of each package, which get cached.
21+
// Checking if a specific version matches an OSV record is done locally.
22+
// This should be used when we know the same packages are going to be repeatedly
23+
// queried multiple times, as in guided remediation.
24+
// TODO: This does not support commit-based queries.
25+
type CachedOSVMatcher struct {
26+
Client osvdev.OSVClient
27+
// InitialQueryTimeout allows you to set a timeout specifically for the initial paging query
28+
// If timeout runs out, whatever pages that has been successfully queried within the timeout will
29+
// still return fully hydrated.
30+
InitialQueryTimeout time.Duration
31+
32+
vulnCache sync.Map // map[osvdev.Package][]models.Vulnerability
33+
}
34+
35+
func (matcher *CachedOSVMatcher) MatchVulnerabilities(ctx context.Context, invs []*extractor.Inventory) ([][]*models.Vulnerability, error) {
36+
// populate vulnCache with missing packages
37+
if err := matcher.doQueries(ctx, invs); err != nil {
38+
return nil, err
39+
}
40+
41+
results := make([][]*models.Vulnerability, len(invs))
42+
43+
for i, inv := range invs {
44+
if ctx.Err() != nil {
45+
return nil, ctx.Err()
46+
}
47+
48+
pkgInfo := imodels.FromInventory(inv)
49+
pkg := osvdev.Package{
50+
Name: pkgInfo.Name(),
51+
Ecosystem: pkgInfo.Ecosystem().String(),
52+
}
53+
vulns, ok := matcher.vulnCache.Load(pkg)
54+
if !ok {
55+
continue
56+
}
57+
results[i] = localmatcher.VulnerabilitiesAffectingPackage(vulns.([]models.Vulnerability), pkgInfo)
58+
}
59+
60+
return results, nil
61+
}
62+
63+
func (matcher *CachedOSVMatcher) doQueries(ctx context.Context, invs []*extractor.Inventory) error {
64+
var batchResp *osvdev.BatchedResponse
65+
deadlineExceeded := false
66+
67+
var queries []*osvdev.Query
68+
{
69+
// determine which packages aren't already cached
70+
// convert Inventory to Query for each pkgs element
71+
toQuery := make(map[*osvdev.Query]struct{})
72+
for _, inv := range invs {
73+
pkgInfo := imodels.FromInventory(inv)
74+
if pkgInfo.Name() == "" || pkgInfo.Ecosystem().IsEmpty() {
75+
continue
76+
}
77+
pkg := osvdev.Package{
78+
Name: pkgInfo.Name(),
79+
Ecosystem: pkgInfo.Ecosystem().String(),
80+
}
81+
if _, ok := matcher.vulnCache.Load(pkg); !ok {
82+
toQuery[&osvdev.Query{Package: pkg}] = struct{}{}
83+
}
84+
}
85+
queries = slices.Collect(maps.Keys(toQuery))
86+
}
87+
88+
if len(queries) == 0 {
89+
return nil
90+
}
91+
92+
var err error
93+
94+
// If there is a timeout for the initial query, set an additional context deadline here.
95+
if matcher.InitialQueryTimeout > 0 {
96+
batchQueryCtx, cancelFunc := context.WithDeadline(ctx, time.Now().Add(matcher.InitialQueryTimeout))
97+
batchResp, err = queryForBatchWithPaging(batchQueryCtx, &matcher.Client, queries)
98+
cancelFunc()
99+
} else {
100+
batchResp, err = queryForBatchWithPaging(ctx, &matcher.Client, queries)
101+
}
102+
103+
if err != nil {
104+
// Deadline being exceeded is likely caused by a long paging time
105+
// if that's the case, we can should return what we already got, and
106+
// then let the caller know it is not all the results.
107+
if errors.Is(err, context.DeadlineExceeded) {
108+
deadlineExceeded = true
109+
} else {
110+
return err
111+
}
112+
}
113+
114+
vulnerabilities := make([][]models.Vulnerability, len(batchResp.Results))
115+
g, ctx := errgroup.WithContext(ctx)
116+
g.SetLimit(maxConcurrentRequests)
117+
118+
for batchIdx, resp := range batchResp.Results {
119+
vulnerabilities[batchIdx] = make([]models.Vulnerability, len(resp.Vulns))
120+
for resultIdx, vuln := range resp.Vulns {
121+
g.Go(func() error {
122+
// exit early if another hydration request has already failed
123+
// results are thrown away later, so avoid needless work
124+
if ctx.Err() != nil {
125+
return nil //nolint:nilerr // this value doesn't matter to errgroup.Wait()
126+
}
127+
vuln, err := matcher.Client.GetVulnByID(ctx, vuln.ID)
128+
if err != nil {
129+
return err
130+
}
131+
vulnerabilities[batchIdx][resultIdx] = *vuln
132+
133+
return nil
134+
})
135+
}
136+
}
137+
138+
if err := g.Wait(); err != nil {
139+
return err
140+
}
141+
142+
if deadlineExceeded {
143+
return context.DeadlineExceeded
144+
}
145+
146+
for i, vulns := range vulnerabilities {
147+
matcher.vulnCache.Store(queries[i].Package, vulns)
148+
}
149+
150+
return nil
151+
}

internal/remediation/in_place.go

+10-4
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@ import (
99
"deps.dev/util/resolve"
1010
"deps.dev/util/resolve/dep"
1111
"deps.dev/util/semver"
12+
"github.com/google/osv-scanner/internal/clients/clientinterfaces"
1213
"github.com/google/osv-scanner/internal/remediation/upgrade"
1314
"github.com/google/osv-scanner/internal/resolution"
1415
"github.com/google/osv-scanner/internal/resolution/client"
1516
lf "github.com/google/osv-scanner/internal/resolution/lockfile"
1617
"github.com/google/osv-scanner/internal/resolution/util"
1718
"github.com/google/osv-scanner/internal/utility/vulns"
19+
"github.com/google/osv-scanner/pkg/models"
1820
"golang.org/x/exp/maps"
1921
)
2022

@@ -105,7 +107,7 @@ func (r InPlaceResult) VulnCount() VulnCount {
105107
// ComputeInPlacePatches finds all possible targeting version changes that would fix vulnerabilities in a resolved graph.
106108
// TODO: Check for introduced vulnerabilities
107109
func ComputeInPlacePatches(ctx context.Context, cl client.ResolutionClient, graph *resolve.Graph, opts Options) (InPlaceResult, error) {
108-
res, err := inPlaceVulnsNodes(cl, graph)
110+
res, err := inPlaceVulnsNodes(ctx, cl, graph)
109111
if err != nil {
110112
return InPlaceResult{}, err
111113
}
@@ -235,12 +237,16 @@ type inPlaceVulnsNodesResult struct {
235237
vkNodes map[resolve.VersionKey][]resolve.NodeID
236238
}
237239

238-
func inPlaceVulnsNodes(cl client.VulnerabilityClient, graph *resolve.Graph) (inPlaceVulnsNodesResult, error) {
239-
nodeVulns, err := cl.FindVulns(graph)
240+
func inPlaceVulnsNodes(ctx context.Context, m clientinterfaces.VulnerabilityMatcher, graph *resolve.Graph) (inPlaceVulnsNodesResult, error) {
241+
nodeVulns, err := m.MatchVulnerabilities(ctx, client.GraphToInventory(graph))
240242
if err != nil {
241243
return inPlaceVulnsNodesResult{}, err
242244
}
243245

246+
// GraphToInventory/MatchVulnerabilities excludes the root node of the graph.
247+
// Prepend an element to nodeVulns so that the indices line up with graph.Nodes[i] <=> nodeVulns[i]
248+
nodeVulns = append([][]*models.Vulnerability{nil}, nodeVulns...)
249+
244250
result := inPlaceVulnsNodesResult{
245251
nodeDependencies: make(map[resolve.NodeID][]resolve.VersionKey),
246252
vkVulns: make(map[resolve.VersionKey][]resolution.Vulnerability),
@@ -272,7 +278,7 @@ func inPlaceVulnsNodes(cl client.VulnerabilityClient, graph *resolve.Graph) (inP
272278
result.vkNodes[vk] = append(result.vkNodes[vk], nID)
273279
for _, vuln := range nodeVulns[nID] {
274280
resVuln := resolution.Vulnerability{
275-
OSV: vuln,
281+
OSV: *vuln,
276282
ProblemChains: slices.Clone(chains),
277283
DevOnly: !slices.ContainsFunc(chains, func(dc resolution.DependencyChain) bool { return !resolution.ChainIsDev(dc, nil) }),
278284
}

internal/resolution/client/client.go

+2-8
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,16 @@ import (
88
"deps.dev/util/resolve"
99
"deps.dev/util/resolve/dep"
1010
"deps.dev/util/semver"
11+
"github.com/google/osv-scanner/internal/clients/clientinterfaces"
1112
"github.com/google/osv-scanner/internal/depsdev"
12-
"github.com/google/osv-scanner/pkg/models"
1313
"github.com/google/osv-scanner/pkg/osv"
1414
"google.golang.org/grpc"
1515
"google.golang.org/grpc/credentials"
1616
)
1717

1818
type ResolutionClient struct {
1919
DependencyClient
20-
VulnerabilityClient
21-
}
22-
23-
type VulnerabilityClient interface {
24-
// FindVulns finds the vulnerabilities affecting each of Nodes in the graph.
25-
// The returned Vulnerabilities[i] corresponds to the vulnerabilities in g.Nodes[i].
26-
FindVulns(g *resolve.Graph) ([]models.Vulnerabilities, error)
20+
clientinterfaces.VulnerabilityMatcher
2721
}
2822

2923
type DependencyClient interface {

internal/resolution/client/helper.go

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package client
2+
3+
import (
4+
"deps.dev/util/resolve"
5+
"github.com/google/osv-scalibr/extractor"
6+
"github.com/google/osv-scalibr/plugin"
7+
"github.com/google/osv-scalibr/purl"
8+
)
9+
10+
// GraphToInventory is a helper function to convert a Graph into an Inventory for use with VulnerabilityMatcher.
11+
func GraphToInventory(g *resolve.Graph) []*extractor.Inventory {
12+
// g.Nodes[0] is the root node of the graph that should be excluded.
13+
inv := make([]*extractor.Inventory, len(g.Nodes)-1)
14+
for i, n := range g.Nodes[1:] {
15+
inv[i] = &extractor.Inventory{
16+
Name: n.Version.Name,
17+
Version: n.Version.Version,
18+
Extractor: mockExtractor{n.Version.System},
19+
}
20+
}
21+
22+
return inv
23+
}
24+
25+
// mockExtractor is for GraphToInventory to get the ecosystem.
26+
type mockExtractor struct {
27+
ecosystem resolve.System
28+
}
29+
30+
func (e mockExtractor) Ecosystem(*extractor.Inventory) string {
31+
switch e.ecosystem {
32+
case resolve.NPM:
33+
return "npm"
34+
case resolve.Maven:
35+
return "Maven"
36+
case resolve.UnknownSystem:
37+
return ""
38+
default:
39+
return ""
40+
}
41+
}
42+
43+
func (e mockExtractor) Name() string { return "" }
44+
func (e mockExtractor) Requirements() *plugin.Capabilities { return nil }
45+
func (e mockExtractor) ToPURL(*extractor.Inventory) *purl.PackageURL { return nil }
46+
func (e mockExtractor) Version() int { return 0 }

0 commit comments

Comments
 (0)