Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: cvedb metric refactoring #4955

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 38 additions & 16 deletions cve_bin_tool/cvedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,12 @@ class CVEDB:
VALUES (?, ?)
""",
}
METRICS = [
(UNKNOWN_METRIC_ID, "UNKNOWN"),
(EPSS_METRIC_ID, "EPSS"),
(CVSS_2_METRIC_ID, "CVSS-2"),
(CVSS_3_METRIC_ID, "CVSS-3"),
]

def __init__(
self,
Expand Down Expand Up @@ -310,6 +316,15 @@ def refresh_cache_and_update_db(self) -> None:
self.create_exploit_db()
self.update_exploits()

# Check if metrics need to be updated
cursor = self.db_open_and_get_cursor()
if not self.latest_schema("metrics", self.TABLE_SCHEMAS["metrics"], cursor):
self.LOGGER.info("Updating metrics data.")
self.populate_metrics()
if self.connection is not None:
self.connection.commit()
self.db_close()

def get_cvelist_if_stale(self) -> None:
"""Update if the local db is more than one day old.
This avoids the full slow update with every execution.
Expand All @@ -333,6 +348,7 @@ def get_cvelist_if_stale(self) -> None:
or not self.latest_schema(
"cve_exploited", self.TABLE_SCHEMAS["cve_exploited"]
)
or not self.latest_schema("metrics", self.TABLE_SCHEMAS["metrics"])
):
self.refresh_cache_and_update_db()
self.time_of_last_update = datetime.datetime.today()
Expand Down Expand Up @@ -360,7 +376,6 @@ def latest_schema(

# getting schema from command
lines = table_schema.split("(")[1].split(",")

table_schema = [x.split("\n")[1].strip().split(" ")[0] for x in lines]
table_schema.pop()

Expand All @@ -370,7 +385,16 @@ def latest_schema(
if table_schema == current_schema:
schema_latest = True

# check for cve_
# Check for metrics table data integrity
if table_name == "metrics":
for metric_id, metric_name in self.METRICS:
result = cursor.execute(
"SELECT * FROM metrics WHERE metrics_id=? AND metrics_name=?",
(metric_id, metric_name),
)
if not result.fetchone():
schema_latest = False
break # Early exit if any metric is missing

return schema_latest

Expand Down Expand Up @@ -626,18 +650,12 @@ def populate_affected(self, affected_data, cursor, data_source):
def populate_metrics(self):
"""Adding data to metric table."""
cursor = self.db_open_and_get_cursor()
# Insert a row without specifying cve_metrics_id
insert_metrics = self.INSERT_QUERIES["insert_metrics"]
data = [
(UNKNOWN_METRIC_ID, "UNKNOWN"),
(EPSS_METRIC_ID, "EPSS"),
(CVSS_2_METRIC_ID, "CVSS-2"),
(CVSS_3_METRIC_ID, "CVSS-3"),
]
# Execute the insert query for each row
for row in data:
# Use the METRICS constant to populate the table
for row in self.METRICS:
cursor.execute(insert_metrics, row)
self.connection.commit()
if self.connection is not None:
self.connection.commit()
self.db_close()

def metric_finder(self, cursor, cve):
Expand Down Expand Up @@ -855,23 +873,26 @@ def create_exploit_db(self):
create_exploit_table = self.TABLE_SCHEMAS["cve_exploited"]
cursor = self.db_open_and_get_cursor()
cursor.execute(create_exploit_table)
self.connection.commit()
if self.connection is not None:
self.connection.commit()
self.db_close()

def populate_exploit_db(self, exploits):
"""Add exploits to the exploits database table."""
insert_exploit = self.INSERT_QUERIES["insert_exploit"]
cursor = self.db_open_and_get_cursor()
cursor.executemany(insert_exploit, exploits)
self.connection.commit()
if self.connection is not None:
self.connection.commit()
self.db_close()

def store_epss_data(self, epss_data):
"""Insert Exploit Prediction Scoring System (EPSS) data into database."""
insert_cve_metrics = self.INSERT_QUERIES["insert_cve_metrics"]
cursor = self.db_open_and_get_cursor()
cursor.executemany(insert_cve_metrics, epss_data)
self.connection.commit()
if self.connection is not None:
self.connection.commit()
self.db_close()

def dict_factory(self, cursor, row):
Expand Down Expand Up @@ -1143,7 +1164,8 @@ def json_to_db_wrapper(self, path, pubkey, ignore_signature, log_signature_error
shutil.rmtree(temp_gnupg_home)
return ERROR_CODES[SigningError]
self.json_to_db(cursor, dir, json.loads(data))
self.connection.commit()
if self.connection is not None:
self.connection.commit()

if is_signed and not ignore_signature and temp_gnupg_home.exists():
shutil.rmtree(temp_gnupg_home)
Expand Down
Loading