Skip to content

Collect go vulnerabilities from github api #578

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Feb 15, 2022
2 changes: 1 addition & 1 deletion vulnerabilities/importer_yielder.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@
"data_source": "GitHubAPIDataSource",
"data_source_cfg": {
"endpoint": "https://api.github.com/graphql",
"ecosystems": ["MAVEN", "NUGET", "COMPOSER", "PIP", "RUBYGEMS"],
"ecosystems": ["MAVEN", "NUGET", "COMPOSER", "PIP", "RUBYGEMS", "GO"],
},
},
{
Expand Down
6 changes: 5 additions & 1 deletion vulnerabilities/importers/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from vulnerabilities.package_managers import NugetVersionAPI
from vulnerabilities.package_managers import ComposerVersionAPI
from vulnerabilities.package_managers import PypiVersionAPI
from vulnerabilities.package_managers import GoproxyVersionAPI
from vulnerabilities.package_managers import RubyVersionAPI
from vulnerabilities.severity_systems import scoring_systems
from vulnerabilities.helpers import nearest_patched_package
Expand Down Expand Up @@ -206,6 +207,7 @@ def set_version_api(self, ecosystem: str) -> None:
"COMPOSER": ComposerVersionAPI,
"PIP": PypiVersionAPI,
"RUBYGEMS": RubyVersionAPI,
"GO": GoproxyVersionAPI,
}
versioner = versioners.get(ecosystem)
if versioner:
Expand All @@ -229,7 +231,7 @@ def process_name(ecosystem: str, pkg_name: str) -> Optional[Tuple[Optional[str],
return None
return vendor, name

if ecosystem == "NUGET" or ecosystem == "PIP" or ecosystem == "RUBYGEMS":
if ecosystem == "NUGET" or ecosystem == "PIP" or ecosystem == "RUBYGEMS" or ecosystem == "GO":
return None, pkg_name

@staticmethod
Expand Down Expand Up @@ -265,6 +267,8 @@ def process_response(self) -> List[Advisory]:
unaffected_purls = []
if self.process_name(ecosystem, name):
ns, pkg_name = self.process_name(ecosystem, name)
if hasattr(self.version_api, "module_name_by_package_name"):
pkg_name = self.version_api.module_name_by_package_name.get(name, pkg_name)
aff_range = adv["node"]["vulnerableVersionRange"]
aff_vers, unaff_vers = self.categorize_versions(
self.version_api.package_type,
Expand Down
153 changes: 129 additions & 24 deletions vulnerabilities/package_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,22 @@
from datetime import datetime
from json import JSONDecodeError
from subprocess import check_output
from typing import List
from typing import Mapping
from typing import Set
from typing import Set, List, MutableMapping, Optional
from django.utils.dateparse import parse_datetime

from aiohttp import ClientSession
import aiohttp
from aiohttp.client_exceptions import ClientResponseError
from aiohttp.client_exceptions import ServerDisconnectedError
from aiohttp.web_exceptions import HTTPGone
from bs4 import BeautifulSoup
from dateutil import parser as dateparser


@dataclasses.dataclass(frozen=True)
class Version:
value: str
release_date: datetime = None
release_date: Optional[datetime] = None


@dataclasses.dataclass
Expand All @@ -55,10 +55,10 @@ class GraphQLError(Exception):


class VersionAPI:
def __init__(self, cache: Mapping[str, Set[Version]] = None):
def __init__(self, cache: MutableMapping[str, Set[Version]] = None):
self.cache = cache or {}

def get(self, package_name, until=None) -> Set[str]:
def get(self, package_name, until=None) -> VersionResponse:
new_versions = set()
valid_versions = set()
for version in self.cache.get(package_name, set()):
Expand All @@ -67,7 +67,9 @@ def get(self, package_name, until=None) -> Set[str]:
continue
valid_versions.add(version.value)

return VersionResponse(valid_versions=valid_versions, newer_versions=new_versions)
return VersionResponse(
valid_versions=valid_versions, newer_versions=new_versions
)

async def load_api(self, pkg_set):
async with client_session() as session:
Expand Down Expand Up @@ -104,7 +106,7 @@ async def fetch(self, pkg, session):
response = await session.request(method="GET", url=url)
resp_json = await response.json()
if resp_json["entries"] == []:
self.cache[pkg] = {}
self.cache[pkg] = set()
break
for release in resp_json["entries"]:
all_versions.add(
Expand All @@ -118,8 +120,12 @@ async def fetch(self, pkg, session):
else:
break
self.cache[pkg] = all_versions
except (ClientResponseError, asyncio.exceptions.TimeoutError, ServerDisconnectedError):
self.cache[pkg] = {}
except (
ClientResponseError,
asyncio.exceptions.TimeoutError,
ServerDisconnectedError,
):
self.cache[pkg] = set()


class PypiVersionAPI(VersionAPI):
Expand Down Expand Up @@ -242,16 +248,20 @@ async def fetch(self, pkg, session, retry_count=5):
resp_json = await response.json()

if resp_json.get("error") or not resp_json.get("versions"):
self.cache[pkg] = {}
self.cache[pkg] = set()
return
for release in resp_json["versions"]:
all_versions.add(Version(value=release["version"].replace("0:", "")))

self.cache[pkg] = all_versions
# TODO : Handle ServerDisconnectedError by using some sort of
# retry mechanism
except (ClientResponseError, asyncio.exceptions.TimeoutError, ServerDisconnectedError):
self.cache[pkg] = {}
except (
ClientResponseError,
asyncio.exceptions.TimeoutError,
ServerDisconnectedError,
):
self.cache[pkg] = set()


class MavenVersionAPI(VersionAPI):
Expand Down Expand Up @@ -295,10 +305,10 @@ def artifact_url(artifact_comps: List[str]) -> str:
return endpoint

@staticmethod
def extract_versions(xml_response: ET.ElementTree) -> Set[str]:
def extract_versions(xml_response: ET.ElementTree) -> Set[Version]:
all_versions = set()
for child in xml_response.getroot().iter():
if child.tag == "version":
if child.tag == "version" and child.text:
all_versions.add(Version(child.text))

return all_versions
Expand All @@ -321,15 +331,17 @@ def nuget_url(pkg_name: str) -> str:
return base_url.format(pkg_name)

@staticmethod
def extract_versions(resp: dict) -> Set[str]:
def extract_versions(resp: dict) -> Set[Version]:
all_versions = set()
try:
for entry_group in resp["items"]:
for entry in entry_group["items"]:
all_versions.add(
Version(
value=entry["catalogEntry"]["version"],
release_date=dateparser.parse(entry["catalogEntry"]["published"]),
release_date=dateparser.parse(
entry["catalogEntry"]["published"]
),
)
)
# FIXME: json response for YamlDotNet.Signed triggers this exception.
Expand All @@ -353,7 +365,7 @@ async def fetch(self, pkg, session) -> None:
self.cache[pkg] = self.extract_versions(resp, pkg)

@staticmethod
def composer_url(pkg_name: str) -> str:
def composer_url(pkg_name: str) -> Optional[str]:
try:
vendor, name = pkg_name.split("/")
except ValueError:
Expand All @@ -362,7 +374,7 @@ def composer_url(pkg_name: str) -> str:
return f"https://repo.packagist.org/p/{vendor}/{name}.json"

@staticmethod
def extract_versions(resp: dict, pkg_name: str) -> Set[str]:
def extract_versions(resp: dict, pkg_name: str) -> Set[Version]:
all_versions = set()
for version in resp["packages"][pkg_name]:
if "dev" in version:
Expand All @@ -374,7 +386,9 @@ def extract_versions(resp: dict, pkg_name: str) -> Set[str]:
all_versions.add(
Version(
value=version.lstrip("v"),
release_date=dateparser.parse(resp["packages"][pkg_name][version]["time"]),
release_date=dateparser.parse(
resp["packages"][pkg_name][version]["time"]
),
)
)
return all_versions
Expand Down Expand Up @@ -412,7 +426,7 @@ class GitHubTagsAPI(VersionAPI):
}
}"""

def __init__(self, cache: Mapping[str, Set[Version]] = None):
def __init__(self, cache: MutableMapping[str, Set[Version]] = None):
self.gh_token = os.getenv("GH_TOKEN")
super().__init__(cache=cache)

Expand All @@ -427,7 +441,10 @@ async def fetch(self, owner_repo: str, session: aiohttp.ClientSession) -> None:
session.headers["Authorization"] = "token " + self.gh_token
endpoint = f"https://api.github.com/graphql"
owner, name = owner_repo.split("/")
query = {"query": self.GQL_QUERY, "variables": {"name": name, "owner": owner}}
query = {
"query": self.GQL_QUERY,
"variables": {"name": name, "owner": owner},
}

while True:
response = await session.post(endpoint, json=query)
Expand All @@ -451,7 +468,9 @@ async def fetch(self, owner_repo: str, session: aiohttp.ClientSession) -> None:
# probably this only happened for linux. Github cannot even properly display it.
# https://kernel.googlesource.com/pub/scm/linux/kernel/git/torvalds/linux/+/refs/tags/v2.6.11
release_date = None
self.cache[owner_repo].add(Version(value=name, release_date=release_date))
self.cache[owner_repo].add(
Version(value=name, release_date=release_date)
)

if not refs["pageInfo"]["hasNextPage"]:
break
Expand All @@ -464,7 +483,9 @@ async def fetch(self, owner_repo: str, session: aiohttp.ClientSession) -> None:
# this method is however not scalable for larger repo and the api is unresponsive
# for repo with > 50 tags
endpoint = f"https://github.com/{owner_repo}"
tags_xml = check_output(["svn", "ls", "--xml", f"{endpoint}/tags"], text=True)
tags_xml = check_output(
["svn", "ls", "--xml", f"{endpoint}/tags"], text=True
)
elements = ET.fromstring(tags_xml)
for entry in elements.iter("entry"):
name = entry.find("name").text
Expand All @@ -489,3 +510,87 @@ async def fetch(self, pkg, session):
pass

self.cache[pkg] = versions


class GoproxyVersionAPI(VersionAPI):
    """
    Collect the versions of Go modules from the Go module proxy at
    https://proxy.golang.org.

    GitHub advisories for Go identify *packages* (for example
    https://github.com/advisories/GHSA-jp4j-47f9-2vc3) while the proxy only
    knows about *modules* (see https://golang.org/ref/mod#goproxy-protocol).
    ``fetch`` therefore resolves the module that provides a package by
    dropping trailing path elements until the proxy answers, and records the
    result in ``module_name_by_package_name``.
    """

    package_type = "golang"
    # Maps every package name successfully handled by ``fetch`` to the module
    # name resolved for it. NOTE: this is a class attribute, so the mapping is
    # shared by all instances of this class.
    module_name_by_package_name = {}

    @staticmethod
    def trim_url_path(url_path: str) -> Optional[str]:
        """
        Return ``url_path`` without its last "/"-separated element, or None
        when there is no parent path left to try.

        A leading "https://pkg.go.dev/" is removed first because some
        advisories include it in the package name, e.g.
        https://github.com/advisories/GHSA-7h6j-2268-fhcm

        For example: trim_url_path("github.com/xx/a/b") returns
        "github.com/xx/a" and trim_url_path("a") returns None.
        """
        prefix = "https://pkg.go.dev/"
        if url_path.startswith(prefix):
            # Plain slicing instead of str.removeprefix(), which only exists
            # on Python 3.9 and later.
            url_path = url_path[len(prefix) :]
        parts = url_path.split("/")
        if len(parts) >= 2:
            return "/".join(parts[:-1])
        return None

    @staticmethod
    def escape_path(path: str) -> str:
        """
        Return ``path`` with every ASCII uppercase letter replaced by "!"
        followed by the matching lowercase letter, which is how the goproxy
        protocol spells module paths and versions in URLs.

        For example: escape_path("github.com/FerretDB/FerretDB") returns
        "github.com/!ferret!d!b/!ferret!d!b"
        """
        # Only A-Z are escaped; any other character is copied unchanged.
        return "".join("!" + c.lower() if "A" <= c <= "Z" else c for c in path)

    @staticmethod
    async def parse_version_info(
        version_info: str, escaped_pkg: str, session: ClientSession
    ) -> Optional[Version]:
        """
        Return a Version built from ``version_info``, one line of a goproxy
        "@v/list" response, or None if the line is blank.

        The first whitespace-separated field is the version. If a second
        field is present it is parsed as the release date; otherwise the date
        is read from the "Time" field of the "@v/<version>.info" document of
        ``escaped_pkg``, requested through ``session``. Failing to obtain a
        date is not an error: the Version is returned with
        ``release_date=None``.
        """
        fields = version_info.split()
        if not fields:
            return None

        value = fields[0]
        if len(fields) > 1:
            try:
                release_date = parse_datetime(fields[1])
            except ValueError:
                # parse_datetime() raises on a well-formed but invalid date.
                release_date = None
        else:
            escaped_ver = GoproxyVersionAPI.escape_path(value)
            try:
                response = await session.request(
                    method="GET",
                    url=f"https://proxy.golang.org/{escaped_pkg}/@v/{escaped_ver}.info",
                )
                resp_json = await response.json()
                release_date = parse_datetime(resp_json.get("Time", ""))
            except Exception:
                # Best effort only. ``Exception`` rather than a bare except so
                # that task cancellation and KeyboardInterrupt still
                # propagate.
                release_date = None
        return Version(value=value, release_date=release_date)

    async def fetch(self, pkg: str, session: ClientSession):
        """
        Resolve the module providing the package ``pkg``, fetch its versions
        from the proxy and store them as a set of Version in
        ``self.cache[pkg]``.

        The resolved module name is stored in
        ``module_name_by_package_name[pkg]``. If no parent path of ``pkg`` is
        known to the proxy, an error is printed and nothing is stored.
        ``ClientResponseError`` with a status other than 404/410 is re-raised.
        """
        # Statuses used by module proxies to report an unknown module, see
        # https://go.dev/ref/mod#goproxy-protocol
        not_found_statuses = (404, 410)

        # Drop the pkg.go.dev prefix up front so that the complete package
        # path is the first candidate tried, instead of losing its last
        # element together with the prefix on the first trim.
        pkg_go_dev = "https://pkg.go.dev/"
        trimmed_pkg = pkg[len(pkg_go_dev) :] if pkg.startswith(pkg_go_dev) else pkg
        # escape uppercase in module path
        escaped_pkg = GoproxyVersionAPI.escape_path(trimmed_pkg)

        resp_text = None
        # resolve module name from package name, longest candidate first, see
        # https://go.dev/ref/mod#resolve-pkg-mod
        while escaped_pkg is not None:
            url = f"https://proxy.golang.org/{escaped_pkg}/@v/list"
            not_found = False
            try:
                response = await session.request(method="GET", url=url)
                if response.status in not_found_statuses:
                    # Session that does not raise on error statuses: do not
                    # mistake the error body for a list of versions.
                    not_found = True
                else:
                    resp_text = await response.text()
            except HTTPGone:
                # Kept for backward compatibility. HTTPGone is an aiohttp.web
                # (server-side) exception, so client sessions are not
                # expected to raise it.
                not_found = True
            except ClientResponseError as error:
                # Raised instead of returning a response when the session was
                # created with raise_for_status enabled.
                if error.status not in not_found_statuses:
                    raise
                not_found = True

            if not not_found:
                break
            # Both names hold the same number of path elements, so they are
            # shortened in lockstep.
            escaped_pkg = GoproxyVersionAPI.trim_url_path(escaped_pkg)
            trimmed_pkg = GoproxyVersionAPI.trim_url_path(trimmed_pkg) or ""

        if resp_text is None or escaped_pkg is None:
            print(f"error while fetching versions for {pkg} from goproxy")
            return

        self.module_name_by_package_name[pkg] = trimmed_pkg
        versions = set()
        for version_info in resp_text.split("\n"):
            version = await GoproxyVersionAPI.parse_version_info(
                version_info, escaped_pkg, session
            )
            if version is not None:
                versions.add(version)
        self.cache[pkg] = versions