diff --git a/cve_bin_tool/output_engine/util.py b/cve_bin_tool/output_engine/util.py index 3ded2d86fc..d74dbd5cb4 100644 --- a/cve_bin_tool/output_engine/util.py +++ b/cve_bin_tool/output_engine/util.py @@ -282,8 +282,8 @@ def intermediate_output( def add_extension_if_not(filename: str, output_type: str) -> str: """ - summary: Checks if the filename ends with the extension and if not - adds one. And if the filename ends with a different extension it replaces the extension. + Handles both replacement of invalid extensions (for known types) + and appending for completely unknown extensions. Args: filename (str): filename from OutputEngine @@ -292,18 +292,40 @@ def add_extension_if_not(filename: str, output_type: str) -> str: Returns: str: Filename with extension according to output_type """ - import re + # Map all output types to their valid extensions + extensions = { + "json": ["json"], + "cyclonedx": ["json", "xml"], + "csv": ["csv"], + "html": ["html"], + "pdf": ["pdf"], + "txt": ["txt"], + } + + # Create set of ALL valid extensions for recognition + all_valid_extensions = {ext for exts in extensions.values() for ext in exts} + + # Valid extensions for this output type; an unmapped type falls back to its own name + valid_ext = extensions.get(output_type, [output_type]) - extensions = ["json", "csv", "html", "pdf", "txt"] - for extension in extensions: - if not filename.endswith(f".{extension}"): - continue - if extension == output_type: + # Split filename + if "." in filename: + name, ext = filename.rsplit(".", 1) + # Check if extension is either: + # 1. Valid for current type -> keep + # 2. Valid for another type -> replace + # 3. 
Invalid everywhere -> append + if ext in valid_ext: return filename - filename = re.sub(f".{extension}$", f".{output_type}", filename) - return filename - filename = f"{filename}.{output_type}" - return filename + elif ext in all_valid_extensions: + # Replace with first valid extension for current type + return f"{name}.{valid_ext[0]}" + else: + # Append first valid extension for current type + return f"{filename}.{valid_ext[0]}" + else: + # No extension - append first valid one + return f"{filename}.{valid_ext[0]}" def group_cve_by_remark( @@ -317,7 +339,7 @@ def group_cve_by_remark( { "cve_number": "CVE-XXX-XXX", "severity": "High", - "decription: "Lorem Ipsm", + "description": "Lorem Ipsum", }, {...} ], diff --git a/cve_bin_tool/sbom_manager/generate.py b/cve_bin_tool/sbom_manager/generate.py index a3dad6c9fc..efd7b9a382 100644 --- a/cve_bin_tool/sbom_manager/generate.py +++ b/cve_bin_tool/sbom_manager/generate.py @@ -11,6 +11,7 @@ from lib4sbom.sbom import SBOM from cve_bin_tool.log import LOGGER +from cve_bin_tool.output_engine.util import add_extension_if_not from cve_bin_tool.util import strip_path from cve_bin_tool.version import VERSION @@ -46,6 +47,10 @@ def __init__( def generate_sbom(self) -> None: """Create SBOM package and generate SBOM file.""" + # Force .json extension for CycloneDX only if not already specified + if self.sbom_type == "cyclonedx": + self.filename = add_extension_if_not(self.filename, "json") + # Create SBOM sbom_relationships = [] my_package = SBOMPackage() diff --git a/cve_bin_tool/sbom_manager/parse.py b/cve_bin_tool/sbom_manager/parse.py index 49ad5551cb..9ce8d58372 100644 --- a/cve_bin_tool/sbom_manager/parse.py +++ b/cve_bin_tool/sbom_manager/parse.py @@ -3,6 +3,7 @@ from __future__ import annotations +import json import re import sys from collections import defaultdict @@ -16,13 +17,7 @@ from cve_bin_tool.cvedb import CVEDB from cve_bin_tool.input_engine import TriageData from cve_bin_tool.log import 
LOGGER -from cve_bin_tool.util import ( - ProductInfo, - Remarks, - decode_cpe22, - decode_cpe23, - validate_serialNumber, -) +from cve_bin_tool.util import ProductInfo, Remarks, decode_cpe22, decode_cpe23 from cve_bin_tool.validator import validate_cyclonedx, validate_spdx, validate_swid @@ -77,12 +72,22 @@ def parse_sbom(self) -> dict[ProductInfo, TriageData]: modules = [] try: if Path(self.filename).exists(): + # Validate CycloneDX JSON or XML extension + if self.type == "cyclonedx" and not ( + self.filename.lower().endswith(".json") + or self.filename.lower().endswith(".xml") + ): + self.logger.error( + "CycloneDX SBOMs require .json or .xml extension." + ) + return {} + if self.type == "swid": modules = self.parse_swid(self.filename) else: modules = self.parse_cyclonedx_spdx() except (KeyError, FileNotFoundError, ET.ParseError) as e: - LOGGER.debug(e, exc_info=True) + self.logger.debug(e, exc_info=True) LOGGER.debug( f"The number of modules identified in SBOM - {len(modules)}\n{modules}" @@ -147,7 +152,7 @@ def common_prefix_split(self, product, version) -> list[ProductInfo]: if not found_common_prefix: # if vendor not found after removing common prefix try splitting it LOGGER.debug( - f"No Vendor found for {product}, trying splitted product. " + f"No Vendor found for {product}, trying split product. " "Some results may be inaccurate due to vendor identification limitations." ) splitted_product = product.split("-") @@ -217,31 +222,45 @@ def parse_cyclonedx_spdx(self) -> [(str, str, str)]: Returns: - List[(str, str, str)]: A list of tuples, each containing vendor, product, and version information for a module. - """ + # Validate CycloneDX JSON or XML extension + if self.type == "cyclonedx" and not ( + self.filename.lower().endswith(".json") + or self.filename.lower().endswith(".xml") + ): + self.logger.error( + f"CycloneDX SBOMs require .json or .xml extension. 
Invalid file: {self.filename}" + ) + return [] + + # Validate JSON content for CycloneDX JSON files + if self.type == "cyclonedx" and self.filename.lower().endswith(".json"): + try: + with open(self.filename, encoding="utf-8") as f: + json.load(f) # Basic JSON validation + except (json.JSONDecodeError, UnicodeDecodeError) as e: + self.logger.error(f"Invalid JSON in CycloneDX SBOM: {str(e)}") + return [] # Set up SBOM parser sbom_parser = SBOMParser(sbom_type=self.type) # Load SBOM sbom_parser.parse_file(self.filename) doc = sbom_parser.get_document() - uuid = doc.get("uuid", "") if self.type == "cyclonedx": - parts = uuid.split(":") - if len(parts) == 3 and parts[0] == "urn" and parts[1] == "uuid": - serialNumber = parts[2] - if validate_serialNumber(serialNumber): + # serialNumber is optional; also check "uuid", the key the replaced code read + serialNumber = (doc.get("serialNumber") or doc.get("uuid") or "").lower() + if serialNumber: # Only validate if present + if re.match( + r"^(?:urn:uuid:)?[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", + serialNumber, + ): - self.serialNumber = serialNumber + self.serialNumber = serialNumber.rpartition(":")[2] # bare UUID, as before else: - LOGGER.error( + LOGGER.warning( # Downgrade to warning f"The SBOM file '{self.filename}' has an invalid serial number." ) - return [] - else: - LOGGER.error( - f"The SBOM file '{self.filename}' has an invalid serial number." - ) - return [] + # Do NOT return early; continue parsing components modules = [] if self.validate and self.filename.endswith(".xml"): @@ -281,7 +300,7 @@ def parse_cyclonedx_spdx(self) -> [(str, str, str)]: # Found at least package and version, save the results modules.append([vendor, package_name, version]) - LOGGER.debug(f"Parsed SBOM {self.filename} {modules}") + LOGGER.debug(f"SBOM Data {self.sbom_data}") return modules def parse_swid(self, sbom_file: str) -> list[list[str]]: @@ -372,7 +391,7 @@ def decode_purl(self, purl) -> (str | None, str | None, str | None): - purl (str): Package URL (purl) string. 
Returns: - - Tuple[str | None, str | None, str | None]: A tuple containing the vendor (which is always None for purl), + - Tuple[str | None, str | None, str | None]: A tuple containing the vendor (which is always None for purl), product, and version information extracted from the purl string, or None if the purl is invalid or incomplete. """ diff --git a/cve_bin_tool/sbom_manager/sbom_detection.py b/cve_bin_tool/sbom_manager/sbom_detection.py index 05ef29f63b..908c68a4c9 100644 --- a/cve_bin_tool/sbom_manager/sbom_detection.py +++ b/cve_bin_tool/sbom_manager/sbom_detection.py @@ -2,51 +2,163 @@ # SPDX-License-Identifier: GPL-3.0-or-later import json +from typing import Optional +from urllib.parse import urlparse import defusedxml.ElementTree as ET -from cve_bin_tool.validator import validate_cyclonedx, validate_swid +from cve_bin_tool.log import LOGGER +from cve_bin_tool.validator import validate_cyclonedx -def sbom_detection(file_path: str) -> str: +def sbom_detection(file_path: str) -> Optional[str]: """ - Identifies SBOM type of file based on its format and schema. + Identifies SBOM type with content validation and extension checks. + Returns 'spdx', 'cyclonedx', 'swid', or None if the SBOM type cannot be determined. Args: - file_path (str): The path to the file. + file_path (str): Path to the SBOM file. Returns: - str: The detected SBOM type (spdx, cyclonedx, swid) or None. + Optional[str]: The detected SBOM type or None if detection fails. 
""" try: - with open(file_path) as file: - if ".spdx" in file_path: - return "spdx" + # Check for CycloneDX JSON format + if file_path.lower().endswith(".json"): + try: + with open(file_path, encoding="utf-8") as f: + data = json.load(f) + # Check for CycloneDX-specific structure + if isinstance(data, dict): + if ( + data.get("bomFormat") == "CycloneDX" + and "components" in data + ): + return "cyclonedx" + # Fallback: Check for required fields + if "specVersion" in data.get( + "bom", {} + ) and "version" in data.get("bom", {}): + LOGGER.warning( + f"Possible CycloneDX SBOM with non-standard structure: {file_path}" + ) + return "cyclonedx" + except (json.JSONDecodeError, UnicodeDecodeError) as e: + LOGGER.debug(f"JSON parsing error for {file_path}: {str(e)}") + pass # Not JSON, continue with other checks - elif file_path.endswith(".json"): - data = json.load(file) + # Check XML-based formats with namespace validation (extension match is case-insensitive) + if file_path.lower().endswith(".xml"): + try: + tree = ET.parse(file_path) + root = tree.getroot() + namespace_uri = ( + root.tag.split("}", 1)[0].strip("{") if "}" in root.tag else "" + ) + + # Check CycloneDX namespace + parsed_uri = urlparse(namespace_uri) + domain = parsed_uri.netloc.lower() if ( - "bomFormat" in data - and "specVersion" in data - and data["bomFormat"] == "CycloneDX" - ): + domain == "cyclonedx.org" or domain.endswith(".cyclonedx.org") + ) and validate_cyclonedx(file_path): return "cyclonedx" + # Check SWID by root tag and namespace + elif ( + root.tag.endswith("SoftwareIdentity") + and "iso/19770" in namespace_uri + ): + return "swid" + except ET.ParseError as e: + LOGGER.debug(f"XML parsing error for {file_path}: {str(e)}") + return None + + # SPDX detection (case-insensitive and path check) + if any( + ext in file_path.lower() + for ext in [".spdx", ".spdx.json", ".spdx.xml", ".spdx.yml", ".spdx.yaml"] + ): + return "spdx" + + except Exception as e: + LOGGER.error(f"SBOM detection failed for {file_path}: {str(e)}") + return None + + 
+def detect_sbom_type_from_content(file_path: str) -> Optional[str]: + """ + Detects SBOM type by analyzing file content without relying on file extensions. + This is a fallback method if the primary detection fails. + + Args: + file_path (str): Path to the SBOM file. + + Returns: + Optional[str]: The detected SBOM type or None if detection fails. + """ + try: + with open(file_path, "rb") as f: + content = f.read(1024) # Read first 1KB for analysis - else: - return None + # Check for JSON content + if content.startswith(b"{"): + try: + with open(file_path, encoding="utf-8") as f: + data = json.load(f) + if isinstance(data, dict) and data.get("bomFormat") == "CycloneDX": + return "cyclonedx" + elif "SPDXID" in data or "spdxVersion" in data: + return "spdx" + except json.JSONDecodeError: + pass - elif file_path.endswith(".xml"): + # Check for XML content + if content.startswith(b" Optional[str]: + """ + Detects SBOM type using both file extension and content-based methods. + This is the main function to be used for SBOM detection. + + Args: + file_path (str): Path to the SBOM file. + + Returns: + Optional[str]: The detected SBOM type or None if detection fails. 
+ """ + # First, try detection based on file extension and content + sbom_type = sbom_detection(file_path) + if sbom_type: + return sbom_type - except (json.JSONDecodeError, ET.ParseError): - return None + # If primary detection fails, try content-based detection as a fallback + return detect_sbom_type_from_content(file_path) diff --git a/test/test_sbom.py b/test/test_sbom.py index 527b345f6b..ba75c6b393 100644 --- a/test/test_sbom.py +++ b/test/test_sbom.py @@ -7,6 +7,7 @@ import pytest from cve_bin_tool.input_engine import TriageData +from cve_bin_tool.output_engine.util import add_extension_if_not from cve_bin_tool.sbom_manager.parse import SBOMParse from cve_bin_tool.sbom_manager.sbom_detection import sbom_detection from cve_bin_tool.util import ProductInfo, Remarks @@ -96,16 +97,41 @@ def test_nonexistent_file(self, filepath: str): assert sbom_engine.parse_sbom() == {} @pytest.mark.parametrize( - "filename, sbom_type", + "filename, sbom_type, expected_log", ( - (str(SBOM_PATH / "bad.csv"), "spdx"), - (str(SBOM_PATH / "bad.csv"), "cyclonedx"), - (str(SBOM_PATH / "bad.csv"), "swid"), + ( + str(SBOM_PATH / "bad.csv"), + "spdx", + None, + ), # SPDX doesn't enforce extensions + ( + str(SBOM_PATH / "bad.csv"), + "cyclonedx", + "CycloneDX SBOMs require .json or .xml extension.", + ), # CycloneDX enforces extensions + ( + str(SBOM_PATH / "bad.csv"), + "swid", + None, + ), # SWID doesn't enforce extensions ), ) - def test_invalid_file(self, filename: str, sbom_type: str): - sbom_engine = SBOMParse(filename, sbom_type) - assert sbom_engine.parse_sbom() == {} + def test_invalid_file( + self, filename: str, sbom_type: str, expected_log: str | None, caplog + ): + """Test handling of invalid files with incorrect extensions.""" + with caplog.at_level("ERROR"): # Capture logs at ERROR level + sbom_engine = SBOMParse(filename, sbom_type) + result = sbom_engine.parse_sbom() + + # Ensure the result is an empty dictionary + assert result == {} + + # Check for the expected log 
message + if expected_log: + assert expected_log in caplog.text + else: + assert not caplog.text # No logs expected for other types @pytest.mark.parametrize( "filename, sbom_type", @@ -265,3 +291,57 @@ def test_invalid_xml(self, filename: str, sbom_type: str, validate: bool): ) def test_sbom_detection(self, filename: str, expected_sbom_type: str): assert sbom_detection(filename) == expected_sbom_type + + +class TestFileExtension: + """ + Tests for the add_extension_if_not function ensuring that filenames + have the correct extension for the given SBOM output type. + """ + + def test_add_extension_no_existing(self): + assert add_extension_if_not("testfile", "cyclonedx") == "testfile.json" + assert add_extension_if_not("testfile", "csv") == "testfile.csv" + + def test_add_extension_existing_correct(self): + assert add_extension_if_not("testfile.json", "cyclonedx") == "testfile.json" + assert add_extension_if_not("testfile.xml", "cyclonedx") == "testfile.xml" + assert add_extension_if_not("testfile.csv", "csv") == "testfile.csv" + + def test_add_extension_existing_wrong(self): + assert add_extension_if_not("testfile.txt", "cyclonedx") == "testfile.json" + assert add_extension_if_not("testfile.pdf", "csv") == "testfile.csv" + + def test_extension_not_in_known_extensions(self): + assert add_extension_if_not("testfile.xyz", "cyclonedx") == "testfile.xyz.json" + assert add_extension_if_not("data.unknown", "csv") == "data.unknown.csv" + + +class TestCycloneDXXMLSupport: + """Tests for CycloneDX XML file support.""" + + SBOM_PATH = Path(__file__).parent.resolve() / "sbom" + + @pytest.mark.parametrize( + "filename, expected_valid", + [ + (str(SBOM_PATH / "cyclonedx_test.xml"), True), # Valid CycloneDX XML + ], + ) + def test_cyclonedx_xml_support(self, filename, expected_valid): + """Test CycloneDX XML file parsing.""" + sbom_engine = SBOMParse(filename, sbom_type="cyclonedx") + result = sbom_engine.parse_sbom() + assert (len(result) > 0) == expected_valid + + 
@pytest.mark.parametrize( + "filename, expected_valid", + [ + (str(SBOM_PATH / "cyclonedx_test.xml"), True), # Valid CycloneDX XML + ], + ) + def test_cyclonedx_xml_serial_number(self, filename, expected_valid): + """Test CycloneDX XML serialNumber validation.""" + sbom_engine = SBOMParse(filename, sbom_type="cyclonedx") + result = sbom_engine.parse_sbom() + assert (len(result) > 0) == expected_valid