Skip to content

Commit

Permalink
Merge pull request #40 from CybercentreCanada/feature/whitelist
Browse files Browse the repository at this point in the history
Feature/whitelist
  • Loading branch information
cccs-rs authored Jun 24, 2021
2 parents 196459f + 2ccac8f commit b56c8e6
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 11 deletions.
35 changes: 35 additions & 0 deletions assemblyline_service_server/api/v1/safelist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@

from assemblyline_service_server.api.base import make_subapi_blueprint, make_api_response, api_login
from assemblyline_service_server.config import STORAGE

# Name of this sub-API; handed to the blueprint factory below.
SUB_API = 'safelist'
# Blueprint for version 1 of the safelist sub-API; the route in this module is registered on it.
safelist_api = make_subapi_blueprint(SUB_API, api_version=1)
# Short human-readable description attached to the blueprint
# (presumably consumed by API documentation tooling -- TODO confirm).
safelist_api._doc = "Query safelisted hashes"


@safelist_api.route("/<qhash>/", methods=["GET"])
@api_login()
def exists(qhash, **_):
    """
    Look up a hash in the safelist and return its entry when present.

    Variables:
    qhash       => Hash to look up in the safelist

    Arguments:
    None

    Data Block:
    None

    API call example:
    GET /api/v1/safelist/123456...654321/

    Result example:
    <Safelisting object>
    """
    # Extra keyword arguments supplied to the handler are accepted and ignored via **_.
    entry = STORAGE.safelist.get_if_exists(qhash, as_obj=False)

    # Guard clause: a falsy lookup result is reported as a 404 with an explanatory message.
    if not entry:
        return make_api_response(None, "The hash was not found in the safelist.", 404)

    return make_api_response(entry)
29 changes: 18 additions & 11 deletions assemblyline_service_server/api/v1/task.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import time
from typing import cast, Dict, Any

from assemblyline.common.dict_utils import flatten, unflatten
from assemblyline.common.heuristics import service_heuristic_to_result_heuristic, InvalidHeuristicException

from assemblyline.common.isotime import now_as_iso
from typing import cast, Dict, Any
from flask import request

from assemblyline.common import forge
from assemblyline.common.constants import SERVICE_STATE_HASH, ServiceStatus
from assemblyline.common.dict_utils import flatten, unflatten
from assemblyline.common.forge import CachedObject
from assemblyline.common.heuristics import service_heuristic_to_result_heuristic, InvalidHeuristicException
from assemblyline.common.isotime import now_as_iso
from assemblyline.odm import construct_safe
from assemblyline.odm.messages.service_heartbeat import Metrics
from assemblyline.odm.messages.task import Task as ServiceTask
Expand All @@ -27,7 +26,9 @@
status_table = ExpiringHash(SERVICE_STATE_HASH, ttl=60*30)
dispatch_client = DispatchClient(STORAGE)
heuristics = cast(Dict[str, Heuristic], CachedObject(get_heuristics, refresh=300))
tag_whitelister = forge.get_tag_whitelister(log=LOGGER)
tag_safelister = CachedObject(forge.get_tag_safelister,
kwargs=dict(log=LOGGER, config=config, datastore=STORAGE),
refresh=300)

SUB_API = 'task'
task_api = make_subapi_blueprint(SUB_API, api_version=1)
Expand All @@ -52,11 +53,14 @@ def get_task(client_info):
service_version = client_info['service_version']
service_tool_version = client_info['service_tool_version']
client_id = client_info['client_id']
timeout = int(float(request.headers.get('timeout', 30)))
service_data = dispatch_client.service_data[service_name]
remaining_time = timeout = int(float(request.headers.get('timeout', 30)))

try:
service_data = dispatch_client.service_data[service_name]
except KeyError:
return make_api_response({}, "The service you're asking task for does not exist, try later", 404)

start_time = time.time()
remaining_time = timeout
stats = {
"execute": 0,
"cache_miss": 0,
Expand Down Expand Up @@ -210,6 +214,7 @@ def handle_task_result(exec_time: int, task: ServiceTask, result: Dict[str, Any]
# Add scores to the heuristics, if any section set a heuristic
total_score = 0
for section in result['result']['sections']:
section['tags'] = flatten(section['tags'])
if section.get('heuristic'):
heur_id = f"{client_info['service_name'].upper()}.{str(section['heuristic']['heur_id'])}"
section['heuristic']['heur_id'] = heur_id
Expand All @@ -236,8 +241,10 @@ def handle_task_result(exec_time: int, task: ServiceTask, result: Dict[str, Any]

# Process the tag values
for section in result['result']['sections']:
# Perform tag whitelisting
section['tags'] = unflatten(tag_whitelister.get_validated_tag_map(flatten(section['tags'])))
# Perform tag safelisting
tags, safelisted_tags = tag_safelister.get_validated_tag_map(section['tags'])
section['tags'] = unflatten(tags)
section['safelisted_tags'] = safelisted_tags

section['tags'], dropped = construct_safe(Tagging, section.get('tags', {}))

Expand Down
2 changes: 2 additions & 0 deletions assemblyline_service_server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from assemblyline_service_server.api.v1.file import file_api
from assemblyline_service_server.api.v1.service import service_api
from assemblyline_service_server.api.v1.task import task_api
from assemblyline_service_server.api.v1.safelist import safelist_api
from assemblyline_service_server.healthz import healthz

config = forge.get_config()
Expand All @@ -25,6 +26,7 @@
app.register_blueprint(file_api)
app.register_blueprint(service_api)
app.register_blueprint(task_api)
app.register_blueprint(safelist_api)

# Setup logging
app.logger.setLevel(LOGGER.getEffectiveLevel())
Expand Down
54 changes: 54 additions & 0 deletions test/test_srv_safelist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@

import pytest

from unittest.mock import MagicMock, patch

from assemblyline.odm import randomizer
from assemblyline.odm.models.safelist import Safelist
from assemblyline_service_server import app
from assemblyline_service_server.config import AUTH_KEY

# Request headers shared by every test in this module; sent with each client.get() call.
# Most values are randomly generated once, at import time.
headers = {
'Container-Id': randomizer.get_random_hash(12),
# API key taken from the service server configuration.
'X-APIKey': AUTH_KEY,
'Service-Name': 'Safelist',
'Service-Version': randomizer.get_random_service_version(),
'Service-Tool-Version': randomizer.get_random_hash(64),
'Timeout': 1,
'X-Forwarded-For': '127.0.0.1',
}


@pytest.fixture(scope='function')
def storage():
    """Swap the safelist API module's STORAGE for a MagicMock and yield that mock.

    The patch is active only while the requesting test runs; it is undone
    when the ``with`` block exits after the test finishes.
    """
    mocked_datastore = MagicMock()
    patch_target = 'assemblyline_service_server.api.v1.safelist.STORAGE'
    with patch(patch_target, mocked_datastore):
        yield mocked_datastore


@pytest.fixture()
def client():
    """Yield a test client created from the service server application."""
    test_client = app.app.test_client()
    yield test_client


# noinspection PyUnusedLocal
def test_safelist_exist(client, storage):
    """A hash known to the (mocked) datastore yields HTTP 200 and the stored entry."""
    # Arrange: a random safelist entry whose sha256 matches the hash we will query.
    known_hash = randomizer.get_random_hash(64)
    expected_entry = randomizer.random_model_obj(Safelist, as_json=True)
    expected_entry['hashes']['sha256'] = known_hash
    storage.safelist.get_if_exists.return_value = expected_entry

    # Act
    response = client.get(f'/api/v1/safelist/{known_hash}/', headers=headers)

    # Assert: the entry is returned unchanged in the API response payload.
    assert response.status_code == 200
    assert response.json['api_response'] == expected_entry


# noinspection PyUnusedLocal
def test_safelist_missing(client, storage):
    """A hash the (mocked) datastore does not know yields HTTP 404 and a null payload."""
    # Arrange: the datastore lookup reports that nothing was found.
    unknown_hash = randomizer.get_random_hash(64)
    storage.safelist.get_if_exists.return_value = None

    # Act
    response = client.get(f'/api/v1/safelist/{unknown_hash}/', headers=headers)

    # Assert
    assert response.status_code == 404
    assert response.json['api_response'] is None

0 comments on commit b56c8e6

Please sign in to comment.