Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: Set cachedir and backupcachedir as parameters for parallel instances of the tool, each with its own database #4773

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cve_bin_tool/cve_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ class CVEScanner:
all_cve_version_info: Dict[str, VersionInfo]

RANGE_UNSET: str = ""
dbname: str = str(Path(DISK_LOCATION_DEFAULT) / DBNAME)
CONSOLE: Console = Console(file=sys.stderr, theme=cve_theme)
ALPHA_TO_NUM: Dict[str, int] = dict(zip(ascii_lowercase, range(26)))
CACHEDIR = DISK_LOCATION_DEFAULT

def __init__(
self,
Expand All @@ -46,6 +46,7 @@ def __init__(
check_exploits: bool = False,
exploits_list: List[str] = [],
disabled_sources: List[str] = [],
cachedir: str = None,
):
self.logger = logger or LOGGER.getChild(self.__class__.__name__)
self.error_mode = error_mode
Expand All @@ -61,6 +62,7 @@ def __init__(
self.exploits_list = exploits_list
self.disabled_sources = disabled_sources
self.all_product_data = dict()
self.dbname = str(Path(cachedir) / DBNAME) if cachedir is not None else DISK_LOCATION_DEFAULT

def get_cves(self, product_info: ProductInfo, triage_data: TriageData):
"""Get CVEs against a specific version of a product.
Expand Down
11 changes: 8 additions & 3 deletions cve_bin_tool/data_sources/curl_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,16 @@ class Curl_Source(Data_Source):
LOGGER = LOGGER.getChild("CVEDB")
DATA_SOURCE_LINK = "https://curl.se/docs/vuln.json"

def __init__(self, error_mode=ErrorMode.TruncTrace):
def __init__(
self,
error_mode=ErrorMode.TruncTrace,
cachedir: str = None,
backup_cachedir: str = None
):
"""Initialize a Curl_Source instance. Args: error_mode (ErrorMode): The error mode to be used."""
self.cve_list = None
self.cachedir = self.CACHEDIR
self.backup_cachedir = self.BACKUPCACHEDIR
self.cachedir = Path(cachedir) if cachedir is not None else self.CACHEDIR
self.backup_cachedir = Path(backup_cachedir) if backup_cachedir is not None else self.BACKUPCACHEDIR
self.error_mode = error_mode
self.session = None
self.affected_data = None
Expand Down
11 changes: 8 additions & 3 deletions cve_bin_tool/data_sources/epss_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,16 @@ class Epss_Source:
LOGGER = logging.getLogger().getChild("CVEDB")
DATA_SOURCE_LINK = "https://epss.cyentia.com/epss_scores-current.csv.gz"

def __init__(self, error_mode=ErrorMode.TruncTrace):
def __init__(
self,
error_mode=ErrorMode.TruncTrace,
cachedir: str = None,
backup_cachedir: str = None
):
self.epss_data = None
self.error_mode = error_mode
self.cachedir = self.CACHEDIR
self.backup_cachedir = self.BACKUPCACHEDIR
self.cachedir = Path(cachedir) if cachedir is not None else self.CACHEDIR
self.backup_cachedir = Path(backup_cachedir) if backup_cachedir is not None else self.BACKUPCACHEDIR
self.epss_path = str(Path(self.cachedir) / "epss")
self.file_name = os.path.join(self.epss_path, "epss_scores-current.csv")
self.source_name = self.SOURCE
Expand Down
4 changes: 2 additions & 2 deletions cve_bin_tool/data_sources/gad_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ class GAD_Source(Data_Source):
GAD_API_URL = "https://gitlab.com/api/v4/projects/12006272/repository/tree"

def __init__(
self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False
self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False, cachedir: str = None
):
self.cachedir = self.CACHEDIR
self.cachedir = Path(cachedir) if cachedir is not None else self.CACHEDIR
self.slugs = None
self.gad_path = str(Path(self.cachedir) / "gad")
self.source_name = self.SOURCE
Expand Down
6 changes: 4 additions & 2 deletions cve_bin_tool/data_sources/nvd_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,15 @@ def __init__(
nvd_type: str = "json-mirror",
incremental_update: bool = False,
nvd_api_key: str = "",
cachedir: str = None,
backup_cachedir: str = None,
):
if feed is None:
self.feed = self.FEED_NVD if nvd_type == "json-nvd" else self.FEED_MIRROR
else:
self.feed = feed
self.cachedir = self.CACHEDIR
self.backup_cachedir = self.BACKUPCACHEDIR
self.cachedir = Path(cachedir) if cachedir is not None else self.CACHEDIR
self.backup_cachedir = Path(backup_cachedir) if backup_cachedir is not None else self.BACKUPCACHEDIR
self.error_mode = error_mode
self.source_name = self.SOURCE

Expand Down
4 changes: 2 additions & 2 deletions cve_bin_tool/data_sources/osv_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ class OSV_Source(Data_Source):
OSV_GS_URL = "gs://osv-vulnerabilities/"

def __init__(
self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False
self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False, cachedir: str = None
):
self.cachedir = self.CACHEDIR
self.cachedir = Path(cachedir) if cachedir is not None else self.CACHEDIR
self.ecosystems = None
self.osv_path = str(Path(self.cachedir) / "osv")
self.source_name = self.SOURCE
Expand Down
4 changes: 2 additions & 2 deletions cve_bin_tool/data_sources/purl2cpe_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ class PURL2CPE_Source(Data_Source):
PURL2CPE_URL = "https://github.com/scanoss/purl2cpe/raw/main/purl2cpe.db.zip"

def __init__(
self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False
self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False, cachedir: str | None = None
):
self.cachedir = self.CACHEDIR
self.cachedir = Path(cachedir) if cachedir is not None else self.CACHEDIR
self.purl2cpe_path = str(Path(self.cachedir) / "purl2cpe")
self.source_name = self.SOURCE
self.error_mode = error_mode
Expand Down
4 changes: 2 additions & 2 deletions cve_bin_tool/data_sources/redhat_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ class REDHAT_Source(Data_Source):
CVE_ENDPOINT = "/cve.json"

def __init__(
self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False
self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False, cachedir: str = None,
):
self.cachedir = self.CACHEDIR
self.cachedir = Path(cachedir) if cachedir is not None else self.CACHEDIR
self.redhat_path = str(Path(self.cachedir) / "redhat")
self.source_name = self.SOURCE

Expand Down
4 changes: 2 additions & 2 deletions cve_bin_tool/data_sources/rsd_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ class RSD_Source(Data_Source):
RSD_API_URL = "https://gitlab.com/api/v4/projects/39314828/repository/tree"

def __init__(
self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False
self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False, cachedir: str = None
):
self.cachedir = self.CACHEDIR
self.cachedir = Path(cachedir) if cachedir is not None else self.CACHEDIR
self.rsd_path = str(Path(self.cachedir) / "rsd")
self.source_name = self.SOURCE

Expand Down
3 changes: 2 additions & 1 deletion cve_bin_tool/helper_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(
product_name: str | None = None,
version_number: str | None = None,
string_length: int = 40,
cachedir: str = None,
):
self.extractor: TempDirExtractorContext = Extractor()
self.product_name = product_name
Expand All @@ -45,7 +46,7 @@ def __init__(

# for setting the database
self.connection = None
self.dbpath = str(Path(DISK_LOCATION_DEFAULT) / DBNAME)
self.dbname = str(Path(cachedir) / DBNAME) if cachedir is not None else DISK_LOCATION_DEFAULT

# for extraction
self.walker = DirWalk().walk
Expand Down
5 changes: 4 additions & 1 deletion cve_bin_tool/input_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Any, DefaultDict, Dict, Iterable, Set, Union

from cve_bin_tool.cvedb import CVEDB
from cve_bin_tool.data_sources import DISK_LOCATION_DEFAULT
from cve_bin_tool.error_handler import (
ErrorHandler,
ErrorMode,
Expand Down Expand Up @@ -66,6 +67,7 @@ def __init__(
logger: Logger = None,
error_mode=ErrorMode.TruncTrace,
filetype="autodetect",
cachedir: str = None,
):
"""
Initializes the InputEngine instance.
Expand All @@ -82,8 +84,9 @@ def __init__(
self.error_mode = error_mode
self.filetype = filetype
self.parsed_data = defaultdict(dict)
self.cachedir = Path(cachedir) if cachedir is not None else DISK_LOCATION_DEFAULT
# Connect to the database
self.cvedb = CVEDB(version_check=False)
self.cvedb = CVEDB(version_check=False, cachedir=self.cachedir)

def parse_input(self) -> DefaultDict[ProductInfo, TriageData]:
"""
Expand Down
4 changes: 2 additions & 2 deletions cve_bin_tool/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
merge_files: list[str],
logger: Logger | None = None,
error_mode=ErrorMode.TruncTrace,
cache_dir=DISK_LOCATION_DEFAULT,
cachedir: str = None,
score=0,
filter_tag=[],
):
Expand All @@ -52,7 +52,7 @@ def __init__(
self.total_files = 0
self.products_with_cve = 0
self.products_without_cve = 0
self.cache_dir = cache_dir
self.cache_dir = Path(cachedir) if cachedir is not None else DISK_LOCATION_DEFAULT
self.merged_files = ["tag"]
self.score = score
self.filter_tag = filter_tag
Expand Down
9 changes: 7 additions & 2 deletions cve_bin_tool/output_engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pathlib import Path

from cve_bin_tool.cve_scanner import CVEData
from cve_bin_tool.cvedb import CVEDB
from cve_bin_tool.cvedb import CVEDB, DISK_LOCATION_DEFAULT
from cve_bin_tool.error_handler import ErrorHandler, ErrorMode
from cve_bin_tool.log import LOGGER
from cve_bin_tool.output_engine.console import output_console
Expand Down Expand Up @@ -130,9 +130,10 @@ def output_pdf(
exploits: bool = False,
metrics: bool = False,
all_product_data=None,
cachedir: str = DISK_LOCATION_DEFAULT,
):
"""Output a PDF of CVEs"""
cvedb_data = CVEDB()
cvedb_data = CVEDB(cachedir=cachedir)
db_date = time.strftime(
"%d %B %Y at %H:%M:%S", time.localtime(cvedb_data.get_db_update_date())
)
Expand Down Expand Up @@ -595,6 +596,7 @@ def output_pdf(
affected_versions: int = 0,
exploits: bool = False,
all_product_data=None,
cachedir: str = DISK_LOCATION_DEFAULT
):
"""Output a PDF of CVEs
Required module: Reportlab not found"""
Expand Down Expand Up @@ -673,6 +675,7 @@ def __init__(
vex_product_info: dict[str, str] = {},
offline: bool = False,
organized_arguements: dict = None,
cachedir: str = DISK_LOCATION_DEFAULT
):
"""Constructor for OutputEngine class."""
self.logger = logger or LOGGER.getChild(self.__class__.__name__)
Expand Down Expand Up @@ -705,6 +708,7 @@ def __init__(
self.vex_type = vex_type
self.vex_product_info = vex_product_info
self.vex_filename = vex_filename
self.cachedir = cachedir

def output_cves(self, outfile, output_type="console"):
"""Output a list of CVEs
Expand Down Expand Up @@ -752,6 +756,7 @@ def output_cves(self, outfile, output_type="console"):
self.affected_versions,
self.exploits,
self.metrics,
self.cachedir,
)
elif output_type == "html":
output_html(
Expand Down
5 changes: 4 additions & 1 deletion cve_bin_tool/package_list_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import distro

from cve_bin_tool.cvedb import CVEDB
from cve_bin_tool.data_sources import DISK_LOCATION_DEFAULT
from cve_bin_tool.error_handler import (
EmptyTxtError,
ErrorHandler,
Expand All @@ -35,6 +36,7 @@ def __init__(
input_file: str,
logger: Logger = LOGGER.getChild("PackageListParser"),
error_mode=ErrorMode.TruncTrace,
cachedir: str = None,
) -> None:
"""
Initialize the PackageListParser object.
Expand All @@ -56,6 +58,7 @@ def __init__(
self.parsed_data_with_vendor: Dict[Any, Any] = defaultdict(dict)
self.package_names_with_vendor: List[Any] = []
self.package_names_without_vendor: List[Any] = []
self.cache_dir = Path(cachedir) if cachedir is not None else DISK_LOCATION_DEFAULT,

def parse_list(self):
"""
Expand Down Expand Up @@ -147,7 +150,7 @@ def parse_list(self):
if package_name in txt_package_names:
self.package_names_without_vendor.append(installed_package)

cve_db = CVEDB()
cve_db = CVEDB(cachedir=self.cache_dir)
vendor_package_pairs = cve_db.get_vendor_product_pairs(
self.package_names_without_vendor
)
Expand Down
4 changes: 3 additions & 1 deletion cve_bin_tool/parsers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations

import sqlite3
from pathlib import Path

from packageurl import PackageURL

Expand Down Expand Up @@ -41,14 +42,15 @@ class Parser:
filename (str): The filename of the data to be processed.
"""

def __init__(self, cve_db, logger):
def __init__(self, cve_db, logger, cachedir : str = None):
"""Initializes a Parser object."""
self.cve_db = cve_db
self.logger = logger
self.filename = ""
self.purl_pkg_type = "default"
self.connection: sqlite3.Connection | None = None
self.dbpath = DISK_LOCATION_DEFAULT / DBNAME
self.dbname = str(Path(cachedir) / DBNAME) if cachedir is not None else DISK_LOCATION_DEFAULT

def run_checker(self, filename):
"""
Expand Down
7 changes: 6 additions & 1 deletion cve_bin_tool/sbom_manager/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from packageurl import PackageURL

from cve_bin_tool.cvedb import CVEDB
from cve_bin_tool.data_sources import DISK_LOCATION_DEFAULT
from cve_bin_tool.input_engine import TriageData
from cve_bin_tool.log import LOGGER
from cve_bin_tool.util import (
Expand Down Expand Up @@ -45,12 +46,15 @@ class SBOMParse:

sbom_data: defaultdict[ProductInfo, TriageData]

CACHEDIR = DISK_LOCATION_DEFAULT

def __init__(
self,
filename: str,
sbom_type: str = "spdx",
logger: Logger | None = None,
validate: bool = True,
cachedir: str = None,
):
self.filename = filename
self.sbom_data = defaultdict(dict)
Expand All @@ -60,9 +64,10 @@ def __init__(
self.logger = logger or LOGGER.getChild(self.__class__.__name__)
self.validate = validate
self.serialNumber = ""
self.cachedir = Path(cachedir) if cachedir is not None else self.CACHEDIR

# Connect to the database
self.cvedb = CVEDB(version_check=False)
self.cvedb = CVEDB(version_check=False, cachedir=self.cachedir)

def parse_sbom(self) -> dict[ProductInfo, TriageData]:
"""
Expand Down
5 changes: 4 additions & 1 deletion cve_bin_tool/version_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from cve_bin_tool.checkers import BUILTIN_CHECKERS, Checker
from cve_bin_tool.cvedb import CVEDB
from cve_bin_tool.data_sources import DISK_LOCATION_DEFAULT
from cve_bin_tool.egg_updater import IS_DEVELOP, update_egg
from cve_bin_tool.error_handler import ErrorMode
from cve_bin_tool.extractor import Extractor, TempDirExtractorContext
Expand Down Expand Up @@ -44,6 +45,7 @@ def __init__(
score: int = 0,
validate: bool = True,
sources=None,
cachedir: str = None,
):
self.logger = logger or LOGGER.getChild(self.__class__.__name__)
# Update egg if installed in development mode
Expand All @@ -66,7 +68,8 @@ def __init__(
self.should_extract = should_extract
self.file_stack: list[str] = []
self.error_mode = error_mode
self.cve_db = CVEDB(sources=sources)
self.cache_dir = Path(cachedir) if cachedir is not None else DISK_LOCATION_DEFAULT
self.cve_db = CVEDB(sources=sources,cachedir=self.cache_dir)
self.validate = validate
# self.logger.info("Checkers loaded: %s" % (", ".join(self.checkers.keys())))
self.language_checkers = self.available_language_checkers()
Expand Down
4 changes: 2 additions & 2 deletions cve_bin_tool/version_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class InvalidVersionSignatureTable(ValueError):
class VersionSignatureDb:
"""Methods for version signature data stored in sqlite"""

def __init__(self, table_name, mapping_function, duration) -> None:
def __init__(self, table_name, mapping_function, duration,cachedir: str = None,) -> None:
"""Set location on disk data cache will reside.
Also sets the table name and refresh duration
"""
Expand All @@ -28,7 +28,7 @@ def __init__(self, table_name, mapping_function, duration) -> None:
self.table_name = table_name
self.update_table_name = f"latest_update_{table_name}"
self.mapping_function = mapping_function
self.disk_location = DISK_LOCATION_DEFAULT
self.disk_location = cachedir if cachedir is not None else DISK_LOCATION_DEFAULT
self.duration = duration
self.conn: sqlite3.Connection | None = None
self.cursor: sqlite3.Cursor | None = None
Expand Down
Loading