diff --git a/vulnerabilities/forms.py b/vulnerabilities/forms.py
index a00885637..103c2a46c 100644
--- a/vulnerabilities/forms.py
+++ b/vulnerabilities/forms.py
@@ -14,14 +14,36 @@
class PackageSearchForm(forms.Form):
-
search = forms.CharField(
required=True,
widget=forms.TextInput(
- attrs={"placeholder": "Package name, purl or purl fragment"},
+ attrs={
+ "placeholder": "Package name, purl or purl fragment",
+ },
+ ),
+ )
+ type = forms.CharField(required=False, max_length=50)
+ vulnerable_only = forms.ChoiceField(
+ required=False,
+ choices=(
+ ("", "All Packages"),
+ ("true", "Vulnerable Only"),
+ ("false", "Non-Vulnerable Only"),
),
)
+ def clean_search(self):
+ """Sanitize the search input which provide extra layer of protection from XSS attacks"""
+ search = self.cleaned_data["search"].strip()
+ if not search:
+ raise forms.ValidationError("Search field cannot be empty")
+ return search
+
+ def clean_type(self):
+ """Sanitize the type input which provide extra layer of protection from XSS attacks"""
+ pkg_type = self.cleaned_data["type"].strip()
+ return pkg_type
+
class VulnerabilitySearchForm(forms.Form):
diff --git a/vulnerabilities/templates/packages.html b/vulnerabilities/templates/packages.html
index 1f7687429..535213e54 100644
--- a/vulnerabilities/templates/packages.html
+++ b/vulnerabilities/templates/packages.html
@@ -18,6 +18,12 @@
{{ page_obj.paginator.count|intcomma }} results
+
+
{% if is_paginated %}
{% include 'includes/pagination.html' with page_obj=page_obj %}
{% endif %}
@@ -81,4 +87,28 @@
{% endif %}
-{% endblock %}
+
+ {% endblock %}
diff --git a/vulnerabilities/tests/test_view.py b/vulnerabilities/tests/test_view.py
index 3b32ee31c..0e0458dee 100644
--- a/vulnerabilities/tests/test_view.py
+++ b/vulnerabilities/tests/test_view.py
@@ -127,6 +127,26 @@ def test_package_view_with_valid_purl_without_version(self):
"pkg:nginx/nginx@1.9.5",
]
+ def test_package_search_vulnerable_only_filter(self):
+ vulnerable_pkg = Package.objects.create(type="npm", name="vulnerable-pkg", version="1.0.0")
+ non_vulnerable_pkg = Package.objects.create(
+ type="npm", name="non-vulnerable-pkg", version="2.0.0"
+ )
+ vuln = Vulnerability.objects.create(
+ vulnerability_id="VCID-123", summary="test vulnerability"
+ )
+ AffectedByPackageRelatedVulnerability.objects.create(
+ package=vulnerable_pkg, vulnerability=vuln
+ )
+ self.assertTrue(
+ AffectedByPackageRelatedVulnerability.objects.filter(package=vulnerable_pkg).exists()
+ )
+ self.assertFalse(
+ AffectedByPackageRelatedVulnerability.objects.filter(
+ package=non_vulnerable_pkg
+ ).exists()
+ )
+
def test_package_view_with_valid_purl_and_incomplete_version(self):
qs = PackageSearch().get_queryset(query="pkg:nginx/nginx@1")
pkgs = list(qs)
diff --git a/vulnerabilities/views.py b/vulnerabilities/views.py
index a2df48634..6a2624fc6 100644
--- a/vulnerabilities/views.py
+++ b/vulnerabilities/views.py
@@ -58,12 +58,19 @@ def get_queryset(self, query=None):
on exact purl, partial purl or just name and namespace.
"""
query = query or self.request.GET.get("search") or ""
- return (
+ queryset = (
self.model.objects.search(query)
.with_vulnerability_counts()
.prefetch_related()
.order_by("package_url")
)
+ if hasattr(self, "request"):
+ vulnerable_only = self.request.GET.get("vulnerable_only", "").lower()
+ if vulnerable_only in ["true", "false"]:
+ queryset = queryset.with_is_vulnerable()
+ queryset = queryset.filter(is_vulnerable=vulnerable_only == "true")
+
+ return queryset
class VulnerabilitySearch(ListView):