Skip to content

Commit 8e09d06

Browse files
committed
internal/frontend, internal/vuln: replace getVulnEntries with vuln.Client
Instead of passing around a function, getVulnEntries, pass the actual vuln client and call it directly. Update the TestClient to implement the GetByModules function so that tests can use it. The purpose of this change is to further isolate calls to the vulndb Client to the internal/vuln package, and to make the code easier to understand by removing a function parameter. For golang/go#58928 Change-Id: I8bef528034a1caa44b99da2f185990338ec9cd5f Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/474537 Reviewed-by: Jamal Carvalho <[email protected]> Run-TryBot: Tatiana Bradley <[email protected]> TryBot-Result: kokoro <[email protected]>
1 parent 0817681 commit 8e09d06

File tree

9 files changed

+44
-73
lines changed

9 files changed

+44
-73
lines changed

internal/frontend/search.go

+8-12
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,7 @@ func determineSearchAction(r *http.Request, ds internal.DataSource, vulnClient *
129129
if len(filters) > 0 {
130130
symbol = filters[0]
131131
}
132-
var getVulnEntries vuln.VulnEntriesFunc
133-
if vulnClient != nil {
134-
getVulnEntries = vulnClient.ByModule
135-
}
136-
page, err := fetchSearchPage(ctx, db, cq, symbol, pageParams, mode == searchModeSymbol, getVulnEntries)
132+
page, err := fetchSearchPage(ctx, db, cq, symbol, pageParams, mode == searchModeSymbol, vulnClient)
137133
if err != nil {
138134
// Instead of returning a 500, return a 408, since symbol searches may
139135
// timeout for very popular symbols.
@@ -236,7 +232,7 @@ type subResult struct {
236232
// fetchSearchPage fetches data matching the search query from the database and
237233
// returns a SearchPage.
238234
func fetchSearchPage(ctx context.Context, db *postgres.DB, cq, symbol string,
239-
pageParams paginationParams, searchSymbols bool, getVulnEntries vuln.VulnEntriesFunc) (*SearchPage, error) {
235+
pageParams paginationParams, searchSymbols bool, vulnClient *vuln.Client) (*SearchPage, error) {
240236
maxResultCount := maxSearchOffset + pageParams.limit
241237

242238
// Pageless search: always start from the beginning.
@@ -258,8 +254,8 @@ func fetchSearchPage(ctx context.Context, db *postgres.DB, cq, symbol string,
258254
results = append(results, sr)
259255
}
260256

261-
if getVulnEntries != nil {
262-
addVulns(ctx, results, getVulnEntries)
257+
if vulnClient != nil {
258+
addVulns(ctx, results, vulnClient)
263259
}
264260

265261
var numResults int
@@ -400,13 +396,13 @@ EntryLoop:
400396
}, nil
401397
}
402398

403-
func searchVulnAlias(ctx context.Context, mode, cq string, vulnClient *vuln.Client) (_ *searchAction, err error) {
399+
func searchVulnAlias(ctx context.Context, mode, cq string, vc *vuln.Client) (_ *searchAction, err error) {
404400
defer derrors.Wrap(&err, "searchVulnAlias(%q, %q)", mode, cq)
405401

406402
if mode != searchModeVuln || !isVulnAlias(cq) {
407403
return nil, nil
408404
}
409-
aliasEntries, err := vulnClient.ByAlias(ctx, cq)
405+
aliasEntries, err := vc.ByAlias(ctx, cq)
410406
if err != nil {
411407
return nil, err
412408
}
@@ -607,7 +603,7 @@ func elapsedTime(date time.Time) string {
607603

608604
// addVulns adds vulnerability information to search results by consulting the
609605
// vulnerability database.
610-
func addVulns(ctx context.Context, rs []*SearchResult, getVulnEntries vuln.VulnEntriesFunc) {
606+
func addVulns(ctx context.Context, rs []*SearchResult, vc *vuln.Client) {
611607
// Get all vulns concurrently.
612608
var wg sync.WaitGroup
613609
// TODO(golang/go#48223): throttle concurrency?
@@ -616,7 +612,7 @@ func addVulns(ctx context.Context, rs []*SearchResult, getVulnEntries vuln.VulnE
616612
wg.Add(1)
617613
go func() {
618614
defer wg.Done()
619-
r.Vulns = vuln.VulnsForPackage(ctx, r.ModulePath, r.Version, r.PackagePath, getVulnEntries)
615+
r.Vulns = vuln.VulnsForPackage(ctx, r.ModulePath, r.Version, r.PackagePath, vc)
620616
}()
621617
}
622618
wg.Wait()

internal/frontend/search_test.go

+2-7
Original file line numberDiff line numberDiff line change
@@ -312,12 +312,7 @@ func TestFetchSearchPage(t *testing.T) {
312312
}},
313313
}}
314314

315-
getVulnEntries = func(_ context.Context, modulePath string) ([]*osv.Entry, error) {
316-
if modulePath == moduleFoo.ModulePath {
317-
return vulnEntries, nil
318-
}
319-
return nil, nil
320-
}
315+
vc = vuln.NewTestClient(vulnEntries)
321316
)
322317

323318
for _, m := range []*internal.Module{moduleFoo, moduleBar} {
@@ -392,7 +387,7 @@ func TestFetchSearchPage(t *testing.T) {
392387
},
393388
} {
394389
t.Run(test.name, func(t *testing.T) {
395-
got, err := fetchSearchPage(ctx, testDB, test.query, "", paginationParams{limit: 20, page: 1}, false, getVulnEntries)
390+
got, err := fetchSearchPage(ctx, testDB, test.query, "", paginationParams{limit: 20, page: 1}, false, vc)
396391
if err != nil {
397392
t.Fatalf("fetchSearchPage(db, %q): %v", test.query, err)
398393
}

internal/frontend/tabs.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,14 @@ func init() {
7878
// handler.
7979
func fetchDetailsForUnit(ctx context.Context, r *http.Request, tab string, ds internal.DataSource, um *internal.UnitMeta,
8080
requestedVersion string, bc internal.BuildContext,
81-
getVulnEntries vuln.VulnEntriesFunc) (_ any, err error) {
81+
vc *vuln.Client) (_ any, err error) {
8282
defer derrors.Wrap(&err, "fetchDetailsForUnit(r, %q, ds, um=%q,%q,%q)", tab, um.Path, um.ModulePath, um.Version)
8383
switch tab {
8484
case tabMain:
8585
_, expandReadme := r.URL.Query()["readme"]
8686
return fetchMainDetails(ctx, ds, um, requestedVersion, expandReadme, bc)
8787
case tabVersions:
88-
return fetchVersionsDetails(ctx, ds, um, getVulnEntries)
88+
return fetchVersionsDetails(ctx, ds, um, vc)
8989
case tabImports:
9090
return fetchImportsDetails(ctx, ds, um.Path, um.ModulePath, um.Version)
9191
case tabImportedBy:

internal/frontend/unit.go

+3-8
Original file line numberDiff line numberDiff line change
@@ -135,11 +135,7 @@ func (s *Server) serveUnitPage(ctx context.Context, w http.ResponseWriter, r *ht
135135
// It's also okay to provide just one (e.g. GOOS=windows), which will select
136136
// the first doc with that value, ignoring the other one.
137137
bc := internal.BuildContext{GOOS: r.FormValue("GOOS"), GOARCH: r.FormValue("GOARCH")}
138-
var getVulnEntries vuln.VulnEntriesFunc
139-
if s.vulnClient != nil {
140-
getVulnEntries = s.vulnClient.ByModule
141-
}
142-
d, err := fetchDetailsForUnit(ctx, r, tab, ds, um, info.requestedVersion, bc, getVulnEntries)
138+
d, err := fetchDetailsForUnit(ctx, r, tab, ds, um, info.requestedVersion, bc, s.vulnClient)
143139
if err != nil {
144140
return err
145141
}
@@ -240,9 +236,8 @@ func (s *Server) serveUnitPage(ctx context.Context, w http.ResponseWriter, r *ht
240236
}
241237

242238
// Get vulnerability information.
243-
if s.vulnClient != nil {
244-
page.Vulns = vuln.VulnsForPackage(ctx, um.ModulePath, um.Version, um.Path, s.vulnClient.ByModule)
245-
}
239+
page.Vulns = vuln.VulnsForPackage(ctx, um.ModulePath, um.Version, um.Path, s.vulnClient)
240+
246241
s.servePage(ctx, w, tabSettings.TemplateName, page)
247242
return nil
248243
}

internal/frontend/versions.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ type VersionSummary struct {
8585
Vulns []vuln.Vuln
8686
}
8787

88-
func fetchVersionsDetails(ctx context.Context, ds internal.DataSource, um *internal.UnitMeta, getVulnEntries vuln.VulnEntriesFunc) (*VersionsDetails, error) {
88+
func fetchVersionsDetails(ctx context.Context, ds internal.DataSource, um *internal.UnitMeta, vc *vuln.Client) (*VersionsDetails, error) {
8989
db, ok := ds.(*postgres.DB)
9090
if !ok {
9191
// The proxydatasource does not support the imported by page.
@@ -114,7 +114,7 @@ func fetchVersionsDetails(ctx context.Context, ds internal.DataSource, um *inter
114114
}
115115
return constructUnitURL(versionPath, mi.ModulePath, linkVersion(mi.ModulePath, mi.Version, mi.Version))
116116
}
117-
return buildVersionDetails(ctx, um.ModulePath, um.Path, versions, sh, linkify, getVulnEntries), nil
117+
return buildVersionDetails(ctx, um.ModulePath, um.Path, versions, sh, linkify, vc), nil
118118
}
119119

120120
// pathInVersion constructs the full import path of the package corresponding
@@ -146,7 +146,7 @@ func buildVersionDetails(ctx context.Context, currentModulePath, packagePath str
146146
modInfos []*internal.ModuleInfo,
147147
sh *internal.SymbolHistory,
148148
linkify func(v *internal.ModuleInfo) string,
149-
getVulnEntries vuln.VulnEntriesFunc,
149+
vc *vuln.Client,
150150
) *VersionsDetails {
151151
// lists organizes versions by VersionListKey.
152152
lists := make(map[VersionListKey]*VersionList)
@@ -201,7 +201,7 @@ func buildVersionDetails(ctx context.Context, currentModulePath, packagePath str
201201
if mi.ModulePath == stdlib.ModulePath {
202202
pkg = packagePath
203203
}
204-
vs.Vulns = vuln.VulnsForPackage(ctx, mi.ModulePath, mi.Version, pkg, getVulnEntries)
204+
vs.Vulns = vuln.VulnsForPackage(ctx, mi.ModulePath, mi.Version, pkg, vc)
205205
vl := lists[key]
206206
if vl == nil {
207207
seenLists = append(seenLists, key)

internal/frontend/versions_test.go

+2-7
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,7 @@ func TestFetchPackageVersionsDetails(t *testing.T) {
107107
},
108108
}},
109109
}
110-
getVulnEntries := func(_ context.Context, m string) ([]*osv.Entry, error) {
111-
if m == modulePath1 {
112-
return []*osv.Entry{vulnEntry}, nil
113-
}
114-
return nil, nil
115-
}
110+
vc := vuln.NewTestClient([]*osv.Entry{vulnEntry})
116111

117112
for _, tc := range []struct {
118113
name string
@@ -201,7 +196,7 @@ func TestFetchPackageVersionsDetails(t *testing.T) {
201196
postgres.MustInsertModule(ctx, t, testDB, v)
202197
}
203198

204-
got, err := fetchVersionsDetails(ctx, testDB, &tc.pkg.UnitMeta, getVulnEntries)
199+
got, err := fetchVersionsDetails(ctx, testDB, &tc.pkg.UnitMeta, vc)
205200
if err != nil {
206201
t.Fatalf("fetchVersionsDetails(ctx, db, %q, %q): %v", tc.pkg.Path, tc.pkg.ModulePath, err)
207202
}

internal/vuln/test_client.go

+12-7
Original file line numberDiff line numberDiff line change
@@ -6,33 +6,38 @@ package vuln
66

77
import (
88
"context"
9-
"errors"
109

1110
vulnc "golang.org/x/vuln/client"
1211
"golang.org/x/vuln/osv"
1312
)
1413

14+
// NewTestClient creates an in-memory client for use in tests.
1515
func NewTestClient(entries []*osv.Entry) *Client {
1616
c := &vulndbTestClient{
17-
entries: entries,
18-
aliasToIDs: map[string][]string{},
17+
entries: entries,
18+
aliasToIDs: map[string][]string{},
19+
modulesToEntries: map[string][]*osv.Entry{},
1920
}
2021
for _, e := range entries {
2122
for _, a := range e.Aliases {
2223
c.aliasToIDs[a] = append(c.aliasToIDs[a], e.ID)
2324
}
25+
for _, affected := range e.Affected {
26+
c.modulesToEntries[affected.Package.Name] = append(c.modulesToEntries[affected.Package.Name], e)
27+
}
2428
}
2529
return &Client{c: c}
2630
}
2731

2832
type vulndbTestClient struct {
2933
vulnc.Client
30-
entries []*osv.Entry
31-
aliasToIDs map[string][]string
34+
entries []*osv.Entry
35+
aliasToIDs map[string][]string
36+
modulesToEntries map[string][]*osv.Entry
3237
}
3338

34-
func (c *vulndbTestClient) GetByModule(context.Context, string) ([]*osv.Entry, error) {
35-
return nil, errors.New("unimplemented")
39+
func (c *vulndbTestClient) GetByModule(_ context.Context, module string) ([]*osv.Entry, error) {
40+
return c.modulesToEntries[module], nil
3641
}
3742

3843
func (c *vulndbTestClient) GetByID(_ context.Context, id string) (*osv.Entry, error) {

internal/vuln/vulns.go

+9-12
Original file line numberDiff line numberDiff line change
@@ -34,27 +34,24 @@ type Vuln struct {
3434
Details string
3535
}
3636

37-
type VulnEntriesFunc func(context.Context, string) ([]*osv.Entry, error)
38-
3937
// VulnsForPackage obtains vulnerability information for the given package.
4038
// If packagePath is empty, it returns all entries for the module at version.
41-
// The getVulnEntries function should retrieve all entries for the given module path.
42-
// It is passed to facilitate testing.
4339
// If there is an error, VulnsForPackage returns a single Vuln that describes the error.
44-
func VulnsForPackage(ctx context.Context, modulePath, version, packagePath string, getVulnEntries VulnEntriesFunc) []Vuln {
45-
vs, err := vulnsForPackage(ctx, modulePath, version, packagePath, getVulnEntries)
40+
func VulnsForPackage(ctx context.Context, modulePath, version, packagePath string, vc *Client) []Vuln {
41+
if vc == nil {
42+
return nil
43+
}
44+
45+
vs, err := vulnsForPackage(ctx, modulePath, version, packagePath, vc)
4646
if err != nil {
4747
return []Vuln{{Details: fmt.Sprintf("could not get vulnerability data: %v", err)}}
4848
}
4949
return vs
5050
}
5151

52-
func vulnsForPackage(ctx context.Context, modulePath, vers, packagePath string, getVulnEntries VulnEntriesFunc) (_ []Vuln, err error) {
53-
defer derrors.Wrap(&err, "vulns(%q, %q, %q)", modulePath, vers, packagePath)
52+
func vulnsForPackage(ctx context.Context, modulePath, vers, packagePath string, vc *Client) (_ []Vuln, err error) {
53+
defer derrors.Wrap(&err, "vulnsForPackage(%q, %q, %q)", modulePath, vers, packagePath)
5454

55-
if getVulnEntries == nil {
56-
return nil, nil
57-
}
5855
// Stdlib pages requested at master will map to a pseudo version that puts
5956
// all vulns in range. We can't really tell you're at master so version.IsPseudo
6057
// is the best we can do. The result is vulns won't be reported for a pseudoversion
@@ -68,7 +65,7 @@ func vulnsForPackage(ctx context.Context, modulePath, vers, packagePath string,
6865
modulePath = vulnStdlibModulePath
6966
}
7067
// Get all the vulns for this module.
71-
entries, err := getVulnEntries(ctx, modulePath)
68+
entries, err := vc.ByModule(ctx, modulePath)
7269
if err != nil {
7370
return nil, err
7471
}

internal/vuln/vulns_test.go

+2-14
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ package vuln
66

77
import (
88
"context"
9-
"fmt"
109
"reflect"
1110
"testing"
1211

@@ -60,18 +59,7 @@ func TestVulnsForPackage(t *testing.T) {
6059
}},
6160
}
6261

63-
get := func(_ context.Context, modulePath string) ([]*osv.Entry, error) {
64-
switch modulePath {
65-
case "good.com":
66-
return nil, nil
67-
case "bad.com", "unfixable.com":
68-
return []*osv.Entry{&e}, nil
69-
case "stdlib":
70-
return []*osv.Entry{&stdlib}, nil
71-
default:
72-
return nil, fmt.Errorf("unknown module %q", modulePath)
73-
}
74-
}
62+
vc := NewTestClient([]*osv.Entry{&e, &stdlib})
7563

7664
testCases := []struct {
7765
mod, pkg, version string
@@ -118,7 +106,7 @@ func TestVulnsForPackage(t *testing.T) {
118106
},
119107
}
120108
for _, tc := range testCases {
121-
got := VulnsForPackage(ctx, tc.mod, tc.version, tc.pkg, get)
109+
got := VulnsForPackage(ctx, tc.mod, tc.version, tc.pkg, vc)
122110
if diff := cmp.Diff(tc.want, got); diff != "" {
123111
t.Errorf("VulnsForPackage(%q, %q, %q) = %+v, mismatch (-want, +got):\n%s", tc.mod, tc.version, tc.pkg, tc.want, diff)
124112
}

0 commit comments

Comments
 (0)