Skip to content

Collect go vulnerabilities from github api #578

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Feb 15, 2022
8 changes: 7 additions & 1 deletion vulnerabilities/importers/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from vulnerabilities.package_managers import MavenVersionAPI
from vulnerabilities.package_managers import NugetVersionAPI
from vulnerabilities.package_managers import PypiVersionAPI
from vulnerabilities.package_managers import GoproxyVersionAPI
from vulnerabilities.package_managers import RubyVersionAPI
from vulnerabilities.severity_systems import scoring_systems

Expand Down Expand Up @@ -196,6 +197,7 @@ def set_version_api(self, ecosystem: str) -> None:
"COMPOSER": ComposerVersionAPI,
"PIP": PypiVersionAPI,
"RUBYGEMS": RubyVersionAPI,
"GO": GoproxyVersionAPI,
}
versioner = versioners.get(ecosystem)
if versioner:
Expand All @@ -219,7 +221,7 @@ def process_name(ecosystem: str, pkg_name: str) -> Optional[Tuple[Optional[str],
return None
return vendor, name

if ecosystem == "NUGET" or ecosystem == "PIP" or ecosystem == "RUBYGEMS":
if ecosystem in ("NUGET", "PIP", "RUBYGEMS", "GO"):
return None, pkg_name

@staticmethod
Expand Down Expand Up @@ -255,6 +257,10 @@ def process_response(self) -> List[Advisory]:
unaffected_purls = []
if self.process_name(ecosystem, name):
ns, pkg_name = self.process_name(ecosystem, name)
if hasattr(self.version_api, "module_name_by_package_name"):
pkg_name = self.version_api.module_name_by_package_name.get(
name, pkg_name
)
aff_range = adv["node"]["vulnerableVersionRange"]
aff_vers, unaff_vers = self.categorize_versions(
self.version_api.package_type,
Expand Down
158 changes: 139 additions & 19 deletions vulnerabilities/package_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,30 @@
# for any legal advice.
# VulnerableCode is a free software code scanning tool from nexB Inc. and others.
# Visit https://github.com/nexB/vulnerablecode/ for support and download.
import traceback
import asyncio
import dataclasses
import os
import xml.etree.ElementTree as ET
from datetime import datetime
from json import JSONDecodeError
from subprocess import check_output
from typing import List
from typing import Mapping
from typing import Set
from typing import Set, List, MutableMapping, Optional
from django.utils.dateparse import parse_datetime

import aiohttp
from aiohttp import ClientSession
from aiohttp.client_exceptions import ClientResponseError
from aiohttp.client_exceptions import ServerDisconnectedError
from aiohttp.web_exceptions import HTTPGone
from bs4 import BeautifulSoup
from dateutil import parser as dateparser


@dataclasses.dataclass(frozen=True)
class Version:
value: str
release_date: datetime = None
release_date: Optional[datetime] = None


@dataclasses.dataclass
Expand All @@ -55,10 +56,10 @@ class GraphQLError(Exception):


class VersionAPI:
def __init__(self, cache: Mapping[str, Set[Version]] = None):
def __init__(self, cache: MutableMapping[str, Set[Version]] = None):
self.cache = cache or {}

def get(self, package_name, until=None) -> Set[str]:
def get(self, package_name, until=None) -> VersionResponse:
new_versions = set()
valid_versions = set()
for version in self.cache.get(package_name, set()):
Expand Down Expand Up @@ -104,7 +105,7 @@ async def fetch(self, pkg, session):
response = await session.request(method="GET", url=url)
resp_json = await response.json()
if resp_json["entries"] == []:
self.cache[pkg] = {}
self.cache[pkg] = set()
break
for release in resp_json["entries"]:
all_versions.add(
Expand All @@ -118,8 +119,12 @@ async def fetch(self, pkg, session):
else:
break
self.cache[pkg] = all_versions
except (ClientResponseError, asyncio.exceptions.TimeoutError, ServerDisconnectedError):
self.cache[pkg] = {}
except (
ClientResponseError,
asyncio.exceptions.TimeoutError,
ServerDisconnectedError,
):
self.cache[pkg] = set()


class PypiVersionAPI(VersionAPI):
Expand Down Expand Up @@ -242,16 +247,20 @@ async def fetch(self, pkg, session, retry_count=5):
resp_json = await response.json()

if resp_json.get("error") or not resp_json.get("versions"):
self.cache[pkg] = {}
self.cache[pkg] = set()
return
for release in resp_json["versions"]:
all_versions.add(Version(value=release["version"].replace("0:", "")))

self.cache[pkg] = all_versions
# TODO : Handle ServerDisconnectedError by using some sort of
# retry mechanism
except (ClientResponseError, asyncio.exceptions.TimeoutError, ServerDisconnectedError):
self.cache[pkg] = {}
except (
ClientResponseError,
asyncio.exceptions.TimeoutError,
ServerDisconnectedError,
):
self.cache[pkg] = set()


class MavenVersionAPI(VersionAPI):
Expand Down Expand Up @@ -295,10 +304,10 @@ def artifact_url(artifact_comps: List[str]) -> str:
return endpoint

@staticmethod
def extract_versions(xml_response: ET.ElementTree) -> Set[str]:
def extract_versions(xml_response: ET.ElementTree) -> Set[Version]:
all_versions = set()
for child in xml_response.getroot().iter():
if child.tag == "version":
if child.tag == "version" and child.text:
all_versions.add(Version(child.text))

return all_versions
Expand All @@ -321,7 +330,7 @@ def nuget_url(pkg_name: str) -> str:
return base_url.format(pkg_name)

@staticmethod
def extract_versions(resp: dict) -> Set[str]:
def extract_versions(resp: dict) -> Set[Version]:
all_versions = set()
try:
for entry_group in resp["items"]:
Expand Down Expand Up @@ -353,7 +362,7 @@ async def fetch(self, pkg, session) -> None:
self.cache[pkg] = self.extract_versions(resp, pkg)

@staticmethod
def composer_url(pkg_name: str) -> str:
def composer_url(pkg_name: str) -> Optional[str]:
try:
vendor, name = pkg_name.split("/")
except ValueError:
Expand All @@ -362,7 +371,7 @@ def composer_url(pkg_name: str) -> str:
return f"https://repo.packagist.org/p/{vendor}/{name}.json"

@staticmethod
def extract_versions(resp: dict, pkg_name: str) -> Set[str]:
def extract_versions(resp: dict, pkg_name: str) -> Set[Version]:
all_versions = set()
for version in resp["packages"][pkg_name]:
if "dev" in version:
Expand Down Expand Up @@ -412,7 +421,7 @@ class GitHubTagsAPI(VersionAPI):
}
}"""

def __init__(self, cache: Mapping[str, Set[Version]] = None):
def __init__(self, cache: MutableMapping[str, Set[Version]] = None):
self.gh_token = os.getenv("GH_TOKEN")
super().__init__(cache=cache)

Expand All @@ -427,7 +436,10 @@ async def fetch(self, owner_repo: str, session: aiohttp.ClientSession) -> None:
session.headers["Authorization"] = "token " + self.gh_token
endpoint = f"https://api.github.com/graphql"
owner, name = owner_repo.split("/")
query = {"query": self.GQL_QUERY, "variables": {"name": name, "owner": owner}}
query = {
"query": self.GQL_QUERY,
"variables": {"name": name, "owner": owner},
}

while True:
response = await session.post(endpoint, json=query)
Expand Down Expand Up @@ -489,3 +501,111 @@ async def fetch(self, pkg, session):
pass

self.cache[pkg] = versions


class GoproxyVersionAPI(VersionAPI):

package_type = "golang"
module_name_by_package_name = {}

@staticmethod
def trim_url_path(url_path: str) -> Optional[str]:
"""
Return a trimmed Go `url_path` removing trailing
package references and keeping only the module
references.

Github advisories for Go are using package names
such as "https://github.com/nats-io/nats-server/v2/server"
(e.g., https://github.com/advisories/GHSA-jp4j-47f9-2vc3 ),
yet goproxy works with module names instead such as
"https://github.com/nats-io/nats-server" (see for details
https://golang.org/ref/mod#goproxy-protocol ).
This functions trims the trailing part(s) of a package URL
and returns the remaining the module name.
For example:
>>> module = "https://github.com/xx/a"
>>> assert GoproxyVersionAPI.trim_url_path("https://github.com/xx/a/b") == module
"""
# some advisories contains this prefix in package name, e.g. https://github.com/advisories/GHSA-7h6j-2268-fhcm
if url_path.startswith("https://pkg.go.dev/"):
url_path = url_path.removeprefix("https://pkg.go.dev/")
parts = url_path.split("/")
if len(parts) >= 2:
return "/".join(parts[:-1])
else:
return None

@staticmethod
def escape_path(path: str) -> str:
"""
Return an case-encoded module path or version name.

This is done by replacing every uppercase letter with an exclamation
mark followed by the corresponding lower-case letter, in order to
avoid ambiguity when serving from case-insensitive file systems.
Refer to https://golang.org/ref/mod#goproxy-protocol.
"""
escaped_path = ""
for c in path:
if c >= "A" and c <= "Z":
# replace uppercase with !lowercase
escaped_path += "!" + chr(ord(c) + ord("a") - ord("A"))
else:
escaped_path += c
return escaped_path

@staticmethod
async def parse_version_info(
version_info: str, escaped_pkg: str, session: ClientSession
) -> Optional[Version]:
v = version_info.split()
if not v:
return None
value = v[0]
if len(v) > 1:
# get release date from the second part. see https://github.com/golang/go/blob/master/src/cmd/go/internal/modfetch/proxy.go#latest()
release_date = parse_datetime(v[1])
else:
escaped_ver = GoproxyVersionAPI.escape_path(value)
try:
response = await session.request(
method="GET",
url=f"https://proxy.golang.org/{escaped_pkg}/@v/{escaped_ver}.info",
)
resp_json = await response.json()
release_date = parse_datetime(resp_json.get("Time", ""))
except:
traceback.print_exc()
print(
f"error while fetching version info for {escaped_pkg}/{escaped_ver} from goproxy"
)
release_date = None
return Version(value=value, release_date=release_date)

async def fetch(self, pkg: str, session: ClientSession):
# escape uppercase in module path
escaped_pkg = GoproxyVersionAPI.escape_path(pkg)
trimmed_pkg = pkg
resp_text = None
# resolve module name from package name, see https://go.dev/ref/mod#resolve-pkg-mod
while escaped_pkg is not None:
url = f"https://proxy.golang.org/{escaped_pkg}/@v/list"
try:
response = await session.request(method="GET", url=url)
resp_text = await response.text()
except HTTPGone:
escaped_pkg = GoproxyVersionAPI.trim_url_path(escaped_pkg)
trimmed_pkg = GoproxyVersionAPI.trim_url_path(trimmed_pkg) or ""
continue
break
if resp_text is None or escaped_pkg is None or trimmed_pkg is None:
print(f"error while fetching versions for {pkg} from goproxy")
return
self.module_name_by_package_name[pkg] = trimmed_pkg
versions = set()
for version_info in resp_text.split("\n"):
version = await GoproxyVersionAPI.parse_version_info(version_info, escaped_pkg, session)
if version is not None:
versions.add(version)
self.cache[pkg] = versions
5 changes: 5 additions & 0 deletions vulnerabilities/tests/test_data/goproxy_api/ferretdb_versions
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
v0.0.1
v0.0.5
v0.0.3
v0.0.4
v0.0.2
1 change: 1 addition & 0 deletions vulnerabilities/tests/test_data/goproxy_api/version_info
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"Version":"v0.0.5","Time":"2022-01-04T13:54:01Z"}
49 changes: 48 additions & 1 deletion vulnerabilities/tests/test_package_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@
from unittest.mock import AsyncMock

from aiohttp.client import ClientSession
from dateutil.tz import tzlocal
from dateutil.tz import tzlocal, tzutc
from pytz import UTC

from vulnerabilities.package_managers import ComposerVersionAPI
from vulnerabilities.package_managers import GoproxyVersionAPI
from vulnerabilities.package_managers import GitHubTagsAPI
from vulnerabilities.package_managers import MavenVersionAPI
from vulnerabilities.package_managers import NugetVersionAPI
Expand All @@ -54,6 +55,7 @@ async def request(self, *args, **kwargs):
mock_response = AsyncMock()
mock_response.json = self.json
mock_response.read = self.read
mock_response.text = self.text
return mock_response

def get(self, *args, **kwargs):
Expand All @@ -70,6 +72,9 @@ async def json(self):
async def read(self):
return self.return_val

async def text(self):
return self.return_val


class RecordedClientSession:
def __init__(self, test_id, regen=False):
Expand Down Expand Up @@ -449,6 +454,48 @@ def test_fetch(self):
assert self.version_api.get("org.apache:kafka") == VersionResponse(valid_versions=expected)


class TestGoproxyVersionAPI(TestCase):
def test_trim_url_path(self):
url1 = "https://pkg.go.dev/github.com/containous/traefik/v2"
url2 = "github.com/FerretDB/FerretDB/cmd/ferretdb"
url3 = GoproxyVersionAPI.trim_url_path(url2)
assert "github.com/containous/traefik" == GoproxyVersionAPI.trim_url_path(url1)
assert "github.com/FerretDB/FerretDB/cmd" == url3
assert "github.com/FerretDB/FerretDB" == GoproxyVersionAPI.trim_url_path(url3)

def test_escape_path(self):
path = "github.com/FerretDB/FerretDB"
assert "github.com/!ferret!d!b/!ferret!d!b" == GoproxyVersionAPI.escape_path(path)

def test_parse_version_info(self):
with open(os.path.join(TEST_DATA, "goproxy_api", "version_info")) as f:
vinfo = json.load(f)
client_session = MockClientSession(vinfo)
assert asyncio.run(
GoproxyVersionAPI.parse_version_info(
"v0.0.5", "github.com/!ferret!d!b/!ferret!d!b", client_session
)
) == Version(
value="v0.0.5",
release_date=datetime(2022, 1, 4, 13, 54, 1, tzinfo=tzutc()),
)

def test_fetch(self):
version_api = GoproxyVersionAPI()
assert version_api.get("github.com/FerretDB/FerretDB") == VersionResponse()
with open(os.path.join(TEST_DATA, "goproxy_api", "ferretdb_versions")) as f:
vlist = f.read()
client_session = MockClientSession(vlist)
asyncio.run(version_api.fetch("github.com/FerretDB/FerretDB", client_session))
assert version_api.cache["github.com/FerretDB/FerretDB"] == {
Version(value="v0.0.1"),
Version(value="v0.0.2"),
Version(value="v0.0.3"),
Version(value="v0.0.4"),
Version(value="v0.0.5"),
}


class TestNugetVersionAPI(TestCase):
@classmethod
def setUpClass(cls):
Expand Down