diff --git a/cve_bin_tool/cvedb.py b/cve_bin_tool/cvedb.py
index df8de1aefa..017ff2e776 100644
--- a/cve_bin_tool/cvedb.py
+++ b/cve_bin_tool/cvedb.py
@@ -218,6 +218,12 @@ class CVEDB:
         VALUES (?, ?)
         """,
     }
+    METRICS = [
+        (UNKNOWN_METRIC_ID, "UNKNOWN"),
+        (EPSS_METRIC_ID, "EPSS"),
+        (CVSS_2_METRIC_ID, "CVSS-2"),
+        (CVSS_3_METRIC_ID, "CVSS-3"),
+    ]
 
     def __init__(
         self,
@@ -310,6 +316,15 @@ def refresh_cache_and_update_db(self) -> None:
         self.create_exploit_db()
         self.update_exploits()
 
+        # Check if metrics need to be updated
+        cursor = self.db_open_and_get_cursor()
+        if not self.latest_schema("metrics", self.TABLE_SCHEMAS["metrics"], cursor):
+            self.LOGGER.info("Updating metrics data.")
+            self.populate_metrics()
+        if self.connection is not None:
+            self.connection.commit()
+        self.db_close()
+
     def get_cvelist_if_stale(self) -> None:
         """Update if the local db is more than one day old.
         This avoids the full slow update with every execution.
@@ -333,6 +348,7 @@ def get_cvelist_if_stale(self) -> None:
             or not self.latest_schema(
                 "cve_exploited", self.TABLE_SCHEMAS["cve_exploited"]
             )
+            or not self.latest_schema("metrics", self.TABLE_SCHEMAS["metrics"])
         ):
             self.refresh_cache_and_update_db()
             self.time_of_last_update = datetime.datetime.today()
@@ -360,7 +376,6 @@ def latest_schema(
         # getting schema from command
         lines = table_schema.split("(")[1].split(",")
-
         table_schema = [x.split("\n")[1].strip().split(" ")[0] for x in lines]
         table_schema.pop()
@@ -370,7 +385,16 @@ def latest_schema(
         if table_schema == current_schema:
             schema_latest = True
 
-        # check for cve_
+        # Check for metrics table data integrity
+        if table_name == "metrics":
+            for metric_id, metric_name in self.METRICS:
+                result = cursor.execute(
+                    "SELECT * FROM metrics WHERE metrics_id=? AND metrics_name=?",
+                    (metric_id, metric_name),
+                )
+                if not result.fetchone():
+                    schema_latest = False
+                    break  # Early exit if any metric is missing
 
         return schema_latest
@@ -626,18 +650,12 @@ def populate_affected(self, affected_data, cursor, data_source):
     def populate_metrics(self):
         """Adding data to metric table."""
         cursor = self.db_open_and_get_cursor()
-        # Insert a row without specifying cve_metrics_id
         insert_metrics = self.INSERT_QUERIES["insert_metrics"]
-        data = [
-            (UNKNOWN_METRIC_ID, "UNKNOWN"),
-            (EPSS_METRIC_ID, "EPSS"),
-            (CVSS_2_METRIC_ID, "CVSS-2"),
-            (CVSS_3_METRIC_ID, "CVSS-3"),
-        ]
-        # Execute the insert query for each row
-        for row in data:
+        # Use the METRICS constant to populate the table
+        for row in self.METRICS:
             cursor.execute(insert_metrics, row)
-        self.connection.commit()
+        if self.connection is not None:
+            self.connection.commit()
         self.db_close()
 
     def metric_finder(self, cursor, cve):
@@ -855,7 +873,8 @@ def create_exploit_db(self):
         create_exploit_table = self.TABLE_SCHEMAS["cve_exploited"]
         cursor = self.db_open_and_get_cursor()
         cursor.execute(create_exploit_table)
-        self.connection.commit()
+        if self.connection is not None:
+            self.connection.commit()
         self.db_close()
 
     def populate_exploit_db(self, exploits):
@@ -863,7 +882,8 @@ def populate_exploit_db(self, exploits):
         insert_exploit = self.INSERT_QUERIES["insert_exploit"]
         cursor = self.db_open_and_get_cursor()
         cursor.executemany(insert_exploit, exploits)
-        self.connection.commit()
+        if self.connection is not None:
+            self.connection.commit()
         self.db_close()
 
     def store_epss_data(self, epss_data):
@@ -871,7 +891,8 @@ def store_epss_data(self, epss_data):
         insert_cve_metrics = self.INSERT_QUERIES["insert_cve_metrics"]
         cursor = self.db_open_and_get_cursor()
         cursor.executemany(insert_cve_metrics, epss_data)
-        self.connection.commit()
+        if self.connection is not None:
+            self.connection.commit()
         self.db_close()
 
     def dict_factory(self, cursor, row):
@@ -1143,7 +1164,8 @@ def json_to_db_wrapper(self, path, pubkey, ignore_signature, log_signature_error
                 shutil.rmtree(temp_gnupg_home)
                 return ERROR_CODES[SigningError]
             self.json_to_db(cursor, dir, json.loads(data))
-            self.connection.commit()
+            if self.connection is not None:
+                self.connection.commit()
 
         if is_signed and not ignore_signature and temp_gnupg_home.exists():
             shutil.rmtree(temp_gnupg_home)
diff --git a/test/test_cvedb.py b/test/test_cvedb.py
index 80742ebbb4..c206341707 100644
--- a/test/test_cvedb.py
+++ b/test/test_cvedb.py
@@ -10,6 +10,7 @@
 from cve_bin_tool import cvedb
 from cve_bin_tool.cli import main
+from cve_bin_tool.cvedb import UNKNOWN_METRIC_ID
 from cve_bin_tool.data_sources import nvd_source
@@ -91,3 +92,25 @@ def test_new_database_schema(self):
             assert all(column in column_names for column in required_columns[table])
 
         self.cvedb.db_close()
+
+    @pytest.mark.skipif(not EXTERNAL_SYSTEM(), reason="Skipping NVD calls")
+    def test_missing_unknown_metric_after_update(self):
+        self.cvedb.init_database()
+
+        with self.cvedb.with_cursor() as cursor:
+            cursor.execute(
+                "DELETE FROM metrics WHERE metrics_id = ?", (UNKNOWN_METRIC_ID,)
+            )
+            self.cvedb.connection.commit()
+
+        # Trigger schema repair
+        self.cvedb.get_cvelist_if_stale()
+
+        # Verify that the UNKNOWN_METRIC exists after the update
+        with self.cvedb.with_cursor() as cursor:
+            cursor.execute(
+                "SELECT * FROM metrics WHERE metrics_id = ?", (UNKNOWN_METRIC_ID,)
+            )
+            result = cursor.fetchone()
+            assert result is not None, "UNKNOWN_METRIC was not restored after update"
+            assert result[1] == "UNKNOWN", "Incorrect metric name for UNKNOWN_METRIC"