diff --git a/cve_bin_tool/cve_scanner.py b/cve_bin_tool/cve_scanner.py index 6eab229d02..825a8ab89d 100644 --- a/cve_bin_tool/cve_scanner.py +++ b/cve_bin_tool/cve_scanner.py @@ -31,9 +31,9 @@ class CVEScanner: all_cve_version_info: Dict[str, VersionInfo] RANGE_UNSET: str = "" - dbname: str = str(Path(DISK_LOCATION_DEFAULT) / DBNAME) CONSOLE: Console = Console(file=sys.stderr, theme=cve_theme) ALPHA_TO_NUM: Dict[str, int] = dict(zip(ascii_lowercase, range(26))) + CACHEDIR = DISK_LOCATION_DEFAULT def __init__( self, @@ -46,6 +46,7 @@ def __init__( check_exploits: bool = False, exploits_list: List[str] = [], disabled_sources: List[str] = [], + cachedir: str = None, ): self.logger = logger or LOGGER.getChild(self.__class__.__name__) self.error_mode = error_mode @@ -61,6 +62,7 @@ def __init__( self.exploits_list = exploits_list self.disabled_sources = disabled_sources self.all_product_data = dict() + self.dbname = str(Path(cachedir if cachedir is not None else self.CACHEDIR) / DBNAME) def get_cves(self, product_info: ProductInfo, triage_data: TriageData): """Get CVEs against a specific version of a product. diff --git a/cve_bin_tool/data_sources/curl_source.py b/cve_bin_tool/data_sources/curl_source.py index 819dd2f9aa..16451f47e7 100644 --- a/cve_bin_tool/data_sources/curl_source.py +++ b/cve_bin_tool/data_sources/curl_source.py @@ -31,11 +31,16 @@ class Curl_Source(Data_Source): LOGGER = LOGGER.getChild("CVEDB") DATA_SOURCE_LINK = "https://curl.se/docs/vuln.json" - def __init__(self, error_mode=ErrorMode.TruncTrace): + def __init__( + self, + error_mode=ErrorMode.TruncTrace, + cachedir: str = None, + backup_cachedir: str = None + ): """Initialize a Curl_Source instance.
Args: error_mode (ErrorMode): The error mode to be used.""" self.cve_list = None - self.cachedir = self.CACHEDIR - self.backup_cachedir = self.BACKUPCACHEDIR + self.cachedir = Path(cachedir) if cachedir is not None else self.CACHEDIR + self.backup_cachedir = Path(backup_cachedir) if backup_cachedir is not None else self.BACKUPCACHEDIR self.error_mode = error_mode self.session = None self.affected_data = None diff --git a/cve_bin_tool/data_sources/epss_source.py b/cve_bin_tool/data_sources/epss_source.py index 6d7f05b47c..0dc0bbfbf3 100644 --- a/cve_bin_tool/data_sources/epss_source.py +++ b/cve_bin_tool/data_sources/epss_source.py @@ -29,11 +29,16 @@ class Epss_Source: LOGGER = logging.getLogger().getChild("CVEDB") DATA_SOURCE_LINK = "https://epss.cyentia.com/epss_scores-current.csv.gz" - def __init__(self, error_mode=ErrorMode.TruncTrace): + def __init__( + self, + error_mode=ErrorMode.TruncTrace, + cachedir: str = None, + backup_cachedir: str = None + ): self.epss_data = None self.error_mode = error_mode - self.cachedir = self.CACHEDIR - self.backup_cachedir = self.BACKUPCACHEDIR + self.cachedir = Path(cachedir) if cachedir is not None else self.CACHEDIR + self.backup_cachedir = Path(backup_cachedir) if backup_cachedir is not None else self.BACKUPCACHEDIR self.epss_path = str(Path(self.cachedir) / "epss") self.file_name = os.path.join(self.epss_path, "epss_scores-current.csv") self.source_name = self.SOURCE diff --git a/cve_bin_tool/data_sources/gad_source.py b/cve_bin_tool/data_sources/gad_source.py index 191ece779e..2d5f0342a0 100644 --- a/cve_bin_tool/data_sources/gad_source.py +++ b/cve_bin_tool/data_sources/gad_source.py @@ -34,9 +34,9 @@ class GAD_Source(Data_Source): GAD_API_URL = "https://gitlab.com/api/v4/projects/12006272/repository/tree" def __init__( - self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False + self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False, cachedir: str = None ): - self.cachedir = 
self.CACHEDIR + self.cachedir = Path(cachedir) if cachedir is not None else self.CACHEDIR self.slugs = None self.gad_path = str(Path(self.cachedir) / "gad") self.source_name = self.SOURCE diff --git a/cve_bin_tool/data_sources/nvd_source.py b/cve_bin_tool/data_sources/nvd_source.py index 1ee0a4d34b..2949d796e5 100644 --- a/cve_bin_tool/data_sources/nvd_source.py +++ b/cve_bin_tool/data_sources/nvd_source.py @@ -66,13 +66,15 @@ def __init__( nvd_type: str = "json-mirror", incremental_update: bool = False, nvd_api_key: str = "", + cachedir: str = None, + backup_cachedir: str = None, ): if feed is None: self.feed = self.FEED_NVD if nvd_type == "json-nvd" else self.FEED_MIRROR else: self.feed = feed - self.cachedir = self.CACHEDIR - self.backup_cachedir = self.BACKUPCACHEDIR + self.cachedir = Path(cachedir) if cachedir is not None else self.CACHEDIR + self.backup_cachedir = Path(backup_cachedir) if backup_cachedir is not None else self.BACKUPCACHEDIR self.error_mode = error_mode self.source_name = self.SOURCE diff --git a/cve_bin_tool/data_sources/osv_source.py b/cve_bin_tool/data_sources/osv_source.py index 6cfed5ef75..b2460f1025 100644 --- a/cve_bin_tool/data_sources/osv_source.py +++ b/cve_bin_tool/data_sources/osv_source.py @@ -30,9 +30,9 @@ class OSV_Source(Data_Source): OSV_GS_URL = "gs://osv-vulnerabilities/" def __init__( - self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False + self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False, cachedir: str = None ): - self.cachedir = self.CACHEDIR + self.cachedir = Path(cachedir) if cachedir is not None else self.CACHEDIR self.ecosystems = None self.osv_path = str(Path(self.cachedir) / "osv") self.source_name = self.SOURCE diff --git a/cve_bin_tool/data_sources/purl2cpe_source.py b/cve_bin_tool/data_sources/purl2cpe_source.py index 6fa17be6b5..74d8966be9 100644 --- a/cve_bin_tool/data_sources/purl2cpe_source.py +++ b/cve_bin_tool/data_sources/purl2cpe_source.py @@ -21,9 +21,9 
@@ class PURL2CPE_Source(Data_Source): PURL2CPE_URL = "https://github.com/scanoss/purl2cpe/raw/main/purl2cpe.db.zip" def __init__( - self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False + self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False, cachedir: str | None = None ): - self.cachedir = self.CACHEDIR + self.cachedir = Path(cachedir) if cachedir is not None else self.CACHEDIR self.purl2cpe_path = str(Path(self.cachedir) / "purl2cpe") self.source_name = self.SOURCE self.error_mode = error_mode diff --git a/cve_bin_tool/data_sources/redhat_source.py b/cve_bin_tool/data_sources/redhat_source.py index 118e8a77a9..be38281050 100644 --- a/cve_bin_tool/data_sources/redhat_source.py +++ b/cve_bin_tool/data_sources/redhat_source.py @@ -24,9 +24,9 @@ class REDHAT_Source(Data_Source): CVE_ENDPOINT = "/cve.json" def __init__( - self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False + self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False, cachedir: str = None, ): - self.cachedir = self.CACHEDIR + self.cachedir = Path(cachedir) if cachedir is not None else self.CACHEDIR self.redhat_path = str(Path(self.cachedir) / "redhat") self.source_name = self.SOURCE diff --git a/cve_bin_tool/data_sources/rsd_source.py b/cve_bin_tool/data_sources/rsd_source.py index 413ea1486a..ee5dcf4676 100644 --- a/cve_bin_tool/data_sources/rsd_source.py +++ b/cve_bin_tool/data_sources/rsd_source.py @@ -32,9 +32,9 @@ class RSD_Source(Data_Source): RSD_API_URL = "https://gitlab.com/api/v4/projects/39314828/repository/tree" def __init__( - self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False + self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False, cachedir: str = None ): - self.cachedir = self.CACHEDIR + self.cachedir = Path(cachedir) if cachedir is not None else self.CACHEDIR self.rsd_path = str(Path(self.cachedir) / "rsd") self.source_name = self.SOURCE diff --git 
a/cve_bin_tool/helper_script.py b/cve_bin_tool/helper_script.py index 54433c0eb4..c8fdfd7667 100644 --- a/cve_bin_tool/helper_script.py +++ b/cve_bin_tool/helper_script.py @@ -37,6 +37,7 @@ def __init__( product_name: str | None = None, version_number: str | None = None, string_length: int = 40, + cachedir: str = None, ): self.extractor: TempDirExtractorContext = Extractor() self.product_name = product_name @@ -45,7 +46,7 @@ def __init__( # for setting the database self.connection = None - self.dbpath = str(Path(DISK_LOCATION_DEFAULT) / DBNAME) + self.dbpath = str(Path(cachedir if cachedir is not None else DISK_LOCATION_DEFAULT) / DBNAME) # for extraction self.walker = DirWalk().walk diff --git a/cve_bin_tool/input_engine.py b/cve_bin_tool/input_engine.py index 3e23f0fd87..768dad7bf2 100644 --- a/cve_bin_tool/input_engine.py +++ b/cve_bin_tool/input_engine.py @@ -16,6 +16,7 @@ from typing import Any, DefaultDict, Dict, Iterable, Set, Union from cve_bin_tool.cvedb import CVEDB +from cve_bin_tool.data_sources import DISK_LOCATION_DEFAULT from cve_bin_tool.error_handler import ( ErrorHandler, ErrorMode, @@ -66,6 +67,7 @@ def __init__( logger: Logger = None, error_mode=ErrorMode.TruncTrace, filetype="autodetect", + cachedir: str = None, ): """ Initializes the InputEngine instance.
@@ -82,8 +84,9 @@ def __init__( self.error_mode = error_mode self.filetype = filetype self.parsed_data = defaultdict(dict) + self.cachedir = Path(cachedir) if cachedir is not None else DISK_LOCATION_DEFAULT # Connect to the database - self.cvedb = CVEDB(version_check=False) + self.cvedb = CVEDB(version_check=False, cachedir=self.cachedir) def parse_input(self) -> DefaultDict[ProductInfo, TriageData]: """ diff --git a/cve_bin_tool/merge.py b/cve_bin_tool/merge.py index 576bdead33..6141f938af 100644 --- a/cve_bin_tool/merge.py +++ b/cve_bin_tool/merge.py @@ -38,7 +38,7 @@ def __init__( merge_files: list[str], logger: Logger | None = None, error_mode=ErrorMode.TruncTrace, - cache_dir=DISK_LOCATION_DEFAULT, + cachedir: str = None, score=0, filter_tag=[], ): @@ -52,7 +52,7 @@ def __init__( self.total_files = 0 self.products_with_cve = 0 self.products_without_cve = 0 - self.cache_dir = cache_dir + self.cache_dir = Path(cachedir) if cachedir is not None else DISK_LOCATION_DEFAULT self.merged_files = ["tag"] self.score = score self.filter_tag = filter_tag diff --git a/cve_bin_tool/output_engine/__init__.py b/cve_bin_tool/output_engine/__init__.py index d523a3d3df..e94bd1f117 100644 --- a/cve_bin_tool/output_engine/__init__.py +++ b/cve_bin_tool/output_engine/__init__.py @@ -13,7 +13,7 @@ from pathlib import Path from cve_bin_tool.cve_scanner import CVEData -from cve_bin_tool.cvedb import CVEDB +from cve_bin_tool.cvedb import CVEDB, DISK_LOCATION_DEFAULT from cve_bin_tool.error_handler import ErrorHandler, ErrorMode from cve_bin_tool.log import LOGGER from cve_bin_tool.output_engine.console import output_console @@ -130,9 +130,10 @@ def output_pdf( exploits: bool = False, metrics: bool = False, all_product_data=None, + cachedir: str = DISK_LOCATION_DEFAULT, ): """Output a PDF of CVEs""" - cvedb_data = CVEDB() + cvedb_data = CVEDB(cachedir=cachedir) db_date = time.strftime( "%d %B %Y at %H:%M:%S", time.localtime(cvedb_data.get_db_update_date()) ) @@ -595,6 +596,7 @@ def 
output_pdf( affected_versions: int = 0, exploits: bool = False, all_product_data=None, + cachedir: str = DISK_LOCATION_DEFAULT ): """Output a PDF of CVEs Required module: Reportlab not found""" @@ -673,6 +675,7 @@ def __init__( vex_product_info: dict[str, str] = {}, offline: bool = False, organized_arguements: dict = None, + cachedir: str = DISK_LOCATION_DEFAULT ): """Constructor for OutputEngine class.""" self.logger = logger or LOGGER.getChild(self.__class__.__name__) @@ -705,6 +708,7 @@ def __init__( self.vex_type = vex_type self.vex_product_info = vex_product_info self.vex_filename = vex_filename + self.cachedir = cachedir def output_cves(self, outfile, output_type="console"): """Output a list of CVEs @@ -752,6 +756,7 @@ def output_cves(self, outfile, output_type="console"): self.affected_versions, self.exploits, self.metrics, + cachedir=self.cachedir, ) elif output_type == "html": output_html( diff --git a/cve_bin_tool/package_list_parser.py b/cve_bin_tool/package_list_parser.py index 6e00a2a040..1529e0bf3e 100644 --- a/cve_bin_tool/package_list_parser.py +++ b/cve_bin_tool/package_list_parser.py @@ -12,6 +12,7 @@ import distro from cve_bin_tool.cvedb import CVEDB +from cve_bin_tool.data_sources import DISK_LOCATION_DEFAULT from cve_bin_tool.error_handler import ( EmptyTxtError, ErrorHandler, @@ -35,6 +36,7 @@ def __init__( input_file: str, logger: Logger = LOGGER.getChild("PackageListParser"), error_mode=ErrorMode.TruncTrace, + cachedir: str = None, ) -> None: """ Initialize the PackageListParser object.
@@ -56,6 +58,7 @@ def __init__( self.parsed_data_with_vendor: Dict[Any, Any] = defaultdict(dict) self.package_names_with_vendor: List[Any] = [] self.package_names_without_vendor: List[Any] = [] + self.cache_dir = Path(cachedir) if cachedir is not None else DISK_LOCATION_DEFAULT def parse_list(self): """ @@ -147,7 +150,7 @@ def parse_list(self): if package_name in txt_package_names: self.package_names_without_vendor.append(installed_package) - cve_db = CVEDB() + cve_db = CVEDB(cachedir=self.cache_dir) vendor_package_pairs = cve_db.get_vendor_product_pairs( self.package_names_without_vendor ) diff --git a/cve_bin_tool/parsers/__init__.py b/cve_bin_tool/parsers/__init__.py index 8394ccebd0..6fbfe5a897 100644 --- a/cve_bin_tool/parsers/__init__.py +++ b/cve_bin_tool/parsers/__init__.py @@ -4,6 +4,7 @@ from __future__ import annotations import sqlite3 +from pathlib import Path from packageurl import PackageURL @@ -41,7 +42,7 @@ class Parser: filename (str): The filename of the data to be processed.
""" - def __init__(self, cve_db, logger): + def __init__(self, cve_db, logger, cachedir: str | None = None): """Initializes a Parser object.""" self.cve_db = cve_db self.logger = logger @@ -49,6 +50,7 @@ def __init__(self, cve_db, logger): self.purl_pkg_type = "default" self.connection: sqlite3.Connection | None = None self.dbpath = DISK_LOCATION_DEFAULT / DBNAME + self.dbname = str(Path(cachedir if cachedir is not None else DISK_LOCATION_DEFAULT) / DBNAME) def run_checker(self, filename): """ diff --git a/cve_bin_tool/sbom_manager/parse.py b/cve_bin_tool/sbom_manager/parse.py index 2827bae083..33fbf89340 100644 --- a/cve_bin_tool/sbom_manager/parse.py +++ b/cve_bin_tool/sbom_manager/parse.py @@ -14,6 +14,7 @@ from packageurl import PackageURL from cve_bin_tool.cvedb import CVEDB +from cve_bin_tool.data_sources import DISK_LOCATION_DEFAULT from cve_bin_tool.input_engine import TriageData from cve_bin_tool.log import LOGGER from cve_bin_tool.util import ( @@ -45,12 +46,15 @@ class SBOMParse: sbom_data: defaultdict[ProductInfo, TriageData] + CACHEDIR = DISK_LOCATION_DEFAULT + def __init__( self, filename: str, sbom_type: str = "spdx", logger: Logger | None = None, validate: bool = True, + cachedir: str = None, ): self.filename = filename self.sbom_data = defaultdict(dict) @@ -60,9 +64,10 @@ def __init__( self.logger = logger or LOGGER.getChild(self.__class__.__name__) self.validate = validate self.serialNumber = "" + self.cachedir = Path(cachedir) if cachedir is not None else self.CACHEDIR # Connect to the database - self.cvedb = CVEDB(version_check=False) + self.cvedb = CVEDB(version_check=False, cachedir=self.cachedir) def parse_sbom(self) -> dict[ProductInfo, TriageData]: """ diff --git a/cve_bin_tool/version_scanner.py b/cve_bin_tool/version_scanner.py index 2fd1982aef..9fe9108861 100644 --- a/cve_bin_tool/version_scanner.py +++ b/cve_bin_tool/version_scanner.py @@ -10,6 +10,7 @@ from cve_bin_tool.checkers import BUILTIN_CHECKERS, Checker from cve_bin_tool.cvedb import
CVEDB +from cve_bin_tool.data_sources import DISK_LOCATION_DEFAULT from cve_bin_tool.egg_updater import IS_DEVELOP, update_egg from cve_bin_tool.error_handler import ErrorMode from cve_bin_tool.extractor import Extractor, TempDirExtractorContext @@ -44,6 +45,7 @@ def __init__( score: int = 0, validate: bool = True, sources=None, + cachedir: str = None, ): self.logger = logger or LOGGER.getChild(self.__class__.__name__) # Update egg if installed in development mode @@ -66,7 +68,8 @@ def __init__( self.should_extract = should_extract self.file_stack: list[str] = [] self.error_mode = error_mode - self.cve_db = CVEDB(sources=sources) + self.cache_dir = Path(cachedir) if cachedir is not None else DISK_LOCATION_DEFAULT + self.cve_db = CVEDB(sources=sources,cachedir=self.cache_dir) self.validate = validate # self.logger.info("Checkers loaded: %s" % (", ".join(self.checkers.keys()))) self.language_checkers = self.available_language_checkers() diff --git a/cve_bin_tool/version_signature.py b/cve_bin_tool/version_signature.py index ee6c5b1995..448b0009a6 100644 --- a/cve_bin_tool/version_signature.py +++ b/cve_bin_tool/version_signature.py @@ -18,7 +18,7 @@ class InvalidVersionSignatureTable(ValueError): class VersionSignatureDb: """Methods for version signature data stored in sqlite""" - def __init__(self, table_name, mapping_function, duration) -> None: + def __init__(self, table_name, mapping_function, duration,cachedir: str = None,) -> None: """Set location on disk data cache will reside. 
Also sets the table name and refresh duration """ @@ -28,7 +28,7 @@ def __init__(self, table_name, mapping_function, duration) -> None: self.table_name = table_name self.update_table_name = f"latest_update_{table_name}" self.mapping_function = mapping_function - self.disk_location = DISK_LOCATION_DEFAULT + self.disk_location = cachedir if cachedir is not None else DISK_LOCATION_DEFAULT self.duration = duration self.conn: sqlite3.Connection | None = None self.cursor: sqlite3.Cursor | None = None