Skip to content

Commit

Permalink
Add tests and address review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Tushar Goel <[email protected]>
  • Loading branch information
TG1999 committed Feb 4, 2025
1 parent b34649b commit 85e6e08
Show file tree
Hide file tree
Showing 4 changed files with 271 additions and 12 deletions.
37 changes: 34 additions & 3 deletions vulnerabilities/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
logger = logging.getLogger(__name__)


@dataclasses.dataclass(order=True)
@dataclasses.dataclass(frozen=True)
class VulnerabilitySeverity:
# FIXME: this should be named scoring_system, like in the model
system: ScoringSystem
Expand All @@ -65,6 +65,16 @@ def to_dict(self):
**published_at_dict,
}

def __eq__(self, other):
if not isinstance(other, VulnerabilitySeverity):
return NotImplemented
return str(self.to_dict()) == str(other.to_dict())

def __lt__(self, other):
if not isinstance(other, VulnerabilitySeverity):
return NotImplemented
return str(self.to_dict()) < str(other.to_dict())

@classmethod
def from_dict(cls, severity: dict):
"""
Expand All @@ -79,7 +89,7 @@ def from_dict(cls, severity: dict):
)


@dataclasses.dataclass(order=True)
@dataclasses.dataclass(frozen=True)
class Reference:
reference_id: str = ""
reference_type: str = ""
Expand All @@ -99,6 +109,16 @@ def normalized(self):
reference_type=self.reference_type,
)

def __eq__(self, other):
if not isinstance(other, Reference):
return NotImplemented
return str(self.to_dict()) == str(other.to_dict())

def __lt__(self, other):
if not isinstance(other, Reference):
return NotImplemented
return str(self.to_dict()) < str(other.to_dict())

def to_dict(self):
return {
"reference_id": self.reference_id,
Expand Down Expand Up @@ -140,7 +160,7 @@ class NoAffectedPackages(Exception):
"""


@dataclasses.dataclass(order=True, frozen=True)
@dataclasses.dataclass(frozen=True)
class AffectedPackage:
"""
Relate a Package URL with a range of affected versions and a fixed version.
Expand Down Expand Up @@ -170,6 +190,16 @@ def get_fixed_purl(self):
raise ValueError(f"Affected Package {self.package!r} does not have a fixed version")
return update_purl_version(purl=self.package, version=str(self.fixed_version))

def __eq__(self, other):
if not isinstance(other, AffectedPackage):
return NotImplemented
return str(self.to_dict()) == str(other.to_dict())

def __lt__(self, other):
if not isinstance(other, AffectedPackage):
return NotImplemented
return str(self.to_dict()) < str(other.to_dict())

@classmethod
def merge(
cls, affected_packages: Iterable
Expand Down Expand Up @@ -274,6 +304,7 @@ class AdvisoryData:
date_published: Optional[datetime.datetime] = None
weaknesses: List[int] = dataclasses.field(default_factory=list)
url: Optional[str] = None
created_by: Optional[str] = None

def __post_init__(self):
if self.date_published and not self.date_published.tzinfo:
Expand Down
1 change: 1 addition & 0 deletions vulnerabilities/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1251,6 +1251,7 @@ def to_advisory_data(self) -> "AdvisoryData":
date_published=self.date_published,
weaknesses=self.weaknesses,
url=self.url,
created_by=self.created_by
)


Expand Down
224 changes: 224 additions & 0 deletions vulnerabilities/tests/test_compute_content_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
import datetime
import pytz
from unittest import TestCase

from packageurl import PackageURL
from univers.version_range import VersionRange

from vulnerabilities.importer import AdvisoryData, AffectedPackage, Reference, VulnerabilitySeverity
from vulnerabilities.severity_systems import SCORING_SYSTEMS
from vulnerabilities.utils import compute_content_id


class TestComputeContentId(TestCase):
def setUp(self):
self.maxDiff = None
self.base_advisory = AdvisoryData(
summary="Test summary",
affected_packages=[
AffectedPackage(
package=PackageURL(
type="npm",
name="package1",
qualifiers={},
),
affected_version_range=VersionRange.from_string("vers:npm/>=1.0.0|<2.0.0"),
)
],
references=[
Reference(
url="https://example.com/vuln1",
reference_id="GHSA-1234-5678-9012",
severities=[
VulnerabilitySeverity(
system=SCORING_SYSTEMS["cvssv3.1"],
value="7.5",
)
],
)
],
date_published=datetime.datetime(2024, 1, 1, tzinfo=pytz.UTC),
)

def test_same_content_different_order_same_id(self):
"""
Test that advisories with same content but different ordering have same content ID
"""
advisory1 = self.base_advisory

# Same content but different order of references and affected packages
advisory2 = AdvisoryData(
summary="Test summary",
affected_packages=list(reversed(self.base_advisory.affected_packages)),
references=list(reversed(self.base_advisory.references)),
date_published=datetime.datetime(2024, 1, 1, tzinfo=pytz.UTC),
)

self.assertEqual(
compute_content_id(advisory1),
compute_content_id(advisory2),
)

def test_different_metadata_same_content_same_id(self):
"""
Test that advisories with same content but different metadata have same content ID
when include_metadata=False
"""
advisory1 = self.base_advisory

advisory2 = AdvisoryData(
summary="Test summary",
affected_packages=self.base_advisory.affected_packages,
references=self.base_advisory.references,
date_published=datetime.datetime(2024, 1, 1, tzinfo=pytz.UTC),
created_by="different_importer",
url="https://different.url",
)

self.assertEqual(
compute_content_id(advisory1),
compute_content_id(advisory2),
)

def test_different_metadata_different_id_when_included(self):
"""
Test that advisories with same content but different metadata have different content IDs
when include_metadata=True
"""
advisory1 = self.base_advisory

advisory2 = AdvisoryData(
summary="Test summary",
affected_packages=self.base_advisory.affected_packages,
references=self.base_advisory.references,
date_published=datetime.datetime(2024, 1, 1, tzinfo=pytz.UTC),
created_by="different_importer",
url="https://different.url",
)

self.assertNotEqual(
compute_content_id(advisory1, include_metadata=True),
compute_content_id(advisory2, include_metadata=True),
)

def test_different_summary_different_id(self):
"""
Test that advisories with different summaries have different content IDs
"""
advisory1 = self.base_advisory

advisory2 = AdvisoryData(
summary="Different summary",
affected_packages=self.base_advisory.affected_packages,
references=self.base_advisory.references,
date_published=datetime.datetime(2024, 1, 1, tzinfo=pytz.UTC),
)

self.assertNotEqual(
compute_content_id(advisory1),
compute_content_id(advisory2),
)

def test_different_affected_packages_different_id(self):
"""
Test that advisories with different affected packages have different content IDs
"""
advisory1 = self.base_advisory

advisory2 = AdvisoryData(
summary="Test summary",
affected_packages=[
AffectedPackage(
package=PackageURL(
type="npm",
name="different-package",
qualifiers={},
),
affected_version_range=VersionRange.from_string("vers:npm/>=1.0.0|<2.0.0"),
)
],
references=self.base_advisory.references,
date_published=datetime.datetime(2024, 1, 1, tzinfo=pytz.UTC),
)

self.assertNotEqual(
compute_content_id(advisory1),
compute_content_id(advisory2),
)

def test_different_references_different_id(self):
"""
Test that advisories with different references have different content IDs
"""
advisory1 = self.base_advisory

advisory2 = AdvisoryData(
summary="Test summary",
affected_packages=self.base_advisory.affected_packages,
references=[
Reference(
url="https://example.com/different-vuln",
reference_id="GHSA-9999-9999-9999",
severities=[
VulnerabilitySeverity(
system=SCORING_SYSTEMS["cvssv3.1"],
value="8.5",
)
],
)
],
date_published=datetime.datetime(2024, 1, 1, tzinfo=pytz.UTC),
)

self.assertNotEqual(
compute_content_id(advisory1),
compute_content_id(advisory2),
)

def test_different_weaknesses_different_id(self):
"""
Test that advisories with different weaknesses have different content IDs
"""
advisory1 = AdvisoryData(
summary="Test summary",
affected_packages=self.base_advisory.affected_packages,
references=self.base_advisory.references,
weaknesses=[1, 2, 3],
date_published=datetime.datetime(2024, 1, 1, tzinfo=pytz.UTC),
)

advisory2 = AdvisoryData(
summary="Test summary",
affected_packages=self.base_advisory.affected_packages,
references=self.base_advisory.references,
weaknesses=[4, 5, 6],
date_published=datetime.datetime(2024, 1, 1, tzinfo=pytz.UTC),
)

self.assertNotEqual(
compute_content_id(advisory1),
compute_content_id(advisory2),
)

def test_empty_fields_same_id(self):
"""
Test that advisories with empty optional fields still generate same content ID
"""
advisory1 = AdvisoryData(
summary="",
affected_packages=self.base_advisory.affected_packages,
references=self.base_advisory.references,
date_published=None,
)

advisory2 = AdvisoryData(
summary="",
affected_packages=self.base_advisory.affected_packages,
references=self.base_advisory.references,
date_published=None,
)

self.assertEqual(
compute_content_id(advisory1),
compute_content_id(advisory2),
)
21 changes: 12 additions & 9 deletions vulnerabilities/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,6 @@ def normalize_purl(purl: Union[PackageURL, str]):
return PackageURL.from_string(purl)



def compute_content_id(advisory_data, include_metadata=False):
"""
Computes a unique content_id for an advisory by normalizing its data and hashing it.
Expand All @@ -548,24 +547,28 @@ def compute_content_id(advisory_data, include_metadata=False):
:param include_metadata: Boolean indicating whether to include `created_by` and `url`
:return: SHA-256 hash digest as content_id
"""

def normalize_text(text):
"""Normalize text by removing spaces and converting to lowercase."""
return text.replace(" ", "").lower() if text else ""

def normalize_affected_packages(packages):
"""Normalize a list of AffectedPackage objects"""
if not packages:
return []
return sorted([pkg.to_dict() for pkg in packages])

def normalize_list(lst):
"""Sort a list to ensure consistent ordering."""
return sorted(lst) if lst else []

def normalize_dict(obj):
"""Ensure dictionary keys are ordered."""
return json.loads(json.dumps(obj, sort_keys=True)) if obj else {}


def normalize_references(references):
"""Normalize a list of references"""
return sorted([ref.to_dict() for ref in references])
# Normalize fields
normalized_data = {
"summary": normalize_text(advisory_data.summary),
"affected_packages": normalize_list(advisory_data.affected_packages),
"references": normalize_list(advisory_data.references),
"affected_packages": normalize_affected_packages(advisory_data.affected_packages),
"references": normalize_references(advisory_data.references),
"weaknesses": normalize_list(advisory_data.weaknesses),
}

Expand Down

0 comments on commit 85e6e08

Please sign in to comment.