use consolidated source fetching methods
cccs-rs committed Sep 22, 2021
1 parent 1714162 commit 813fca5
Showing 1 changed file with 7 additions and 186 deletions.
193 changes: 7 additions & 186 deletions suricata_/update_server.py
@@ -19,6 +19,7 @@
from assemblyline.common.isotime import iso_to_epoch, epoch_to_iso
from assemblyline.odm.models.service import Service, UpdateSource
from assemblyline_v4_service.updater.updater import ServiceUpdater, temporary_api_key
+from assemblyline_v4_service.updater.helper import git_clone_repo, url_download, SkipSource

from suricata_.suricata_importer import SuricataImporter
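
The helpers deleted below now live in assemblyline_v4_service.updater.helper. Their real signatures are not part of this diff; the stub below is only a sketch inferred from the two call sites in do_source_update further down, so the parameter names are assumptions, not the actual API.

# Sketch only: plausible signatures for the consolidated helpers, inferred
# from the call sites below. Parameter names are assumptions.
def git_clone_repo(source, previous_update, default_pattern, logger, output_dir):
    ...

def url_download(source, previous_update, logger, output_dir):
    ...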

@@ -28,190 +29,10 @@
UPDATE_CONFIGURATION_PATH = os.environ.get('UPDATE_CONFIGURATION_PATH', "/tmp/suricata_updater_config.yaml")
UPDATE_OUTPUT_PATH = os.environ.get('UPDATE_OUTPUT_PATH', "/tmp/suricata_updater_output")
UPDATE_DIR = os.path.join(tempfile.gettempdir(), 'suricata_updates')
-LOGGER = logging.getLogger('assemblyline.updater.suricata')

UI_SERVER = os.getenv('UI_SERVER', 'https://nginx')


-class SkipSource(RuntimeError):
-    pass
-
-
-def add_cacert(cert: str):
-    # Add certificate to requests
-    cafile = certifi.where()
-    with open(cafile, 'a') as ca_editor:
-        ca_editor.write(f"\n{cert}")
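
add_cacert above appends to the shared certifi bundle in place, which changes trust for every later request in the process. A less invasive variation, sketched here as an illustration rather than as what the consolidated helper does, copies the bundle and points a single session's verify at the copy:

import shutil
import tempfile

import certifi
import requests


def session_with_extra_ca(cert_pem: str) -> requests.Session:
    # Copy the stock bundle so the process-wide certifi file stays untouched.
    bundle_copy = tempfile.NamedTemporaryFile(suffix='.pem', delete=False)
    with open(certifi.where(), 'rb') as stock:
        shutil.copyfileobj(stock, bundle_copy)
    bundle_copy.write(f"\n{cert_pem}".encode())
    bundle_copy.close()

    session = requests.Session()
    session.verify = bundle_copy.name  # verify TLS against the augmented copy
    return session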


-def url_download(source: Dict[str, Any], previous_update=None) -> List:
-    """
-    :param source:
-    :param previous_update:
-    :return:
-    """
-    name = source['name']
-    uri = source['uri']
-    pattern = source.get('pattern', None)
-    username = source.get('username', None)
-    password = source.get('password', None)
-    ca_cert = source.get('ca_cert', None)
-    ignore_ssl_errors = source.get('ssl_ignore_errors', False)
-    auth = (username, password) if username and password else None
-
-    proxy = source.get('proxy', None)
-    headers = source.get('headers', None)
-
-    LOGGER.info(f"{name} source is configured to {'ignore SSL errors' if ignore_ssl_errors else 'verify SSL'}.")
-    if ca_cert:
-        LOGGER.info("A CA certificate has been provided with this source.")
-        add_cacert(ca_cert)
-
-    # Create a requests session
-    session = requests.Session()
-    session.verify = not ignore_ssl_errors
-
-    # Let https requests go through proxy
-    if proxy:
-        os.environ['https_proxy'] = proxy
-
-    try:
-        if isinstance(previous_update, str):
-            previous_update = iso_to_epoch(previous_update)
-
-        # Check the response header for the last modified date
-        response = session.head(uri, auth=auth, headers=headers)
-        last_modified = response.headers.get('Last-Modified', None)
-        if last_modified:
-            # Convert the last modified time to epoch
-            last_modified = time.mktime(time.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z"))
-
-            # Compare the last modified time with the last updated time
-            if previous_update and last_modified <= previous_update:
-                # File has not been modified since last update, do nothing
-                raise SkipSource()
-
-        if previous_update:
-            previous_update = time.strftime("%a, %d %b %Y %H:%M:%S %Z", time.gmtime(previous_update))
-            if headers:
-                headers['If-Modified-Since'] = previous_update
-            else:
-                headers = {'If-Modified-Since': previous_update}
-
-        response = session.get(uri, auth=auth, headers=headers)
-
-        # Check the response code
-        if response.status_code == requests.codes['not_modified']:
-            # File has not been modified since last update, do nothing
-            raise SkipSource()
-        elif response.ok:
-            if not os.path.exists(UPDATE_DIR):
-                os.makedirs(UPDATE_DIR)
-
-            file_name = os.path.basename(urlparse(uri).path)
-            file_path = os.path.join(UPDATE_DIR, file_name)
-            with open(file_path, 'wb') as f:
-                f.write(response.content)
-
-            rules_files = None
-            if file_name.endswith('tar.gz'):
-                extract_dir = os.path.join(UPDATE_DIR, name)
-                shutil.unpack_archive(file_path, extract_dir=extract_dir)
-
-                rules_files = set()
-                for path_in_dir, _, files in os.walk(extract_dir):
-                    for filename in files:
-                        filepath = os.path.join(extract_dir, path_in_dir, filename)
-                        if pattern:
-                            if re.match(pattern, filename):
-                                rules_files.add(filepath)
-                        else:
-                            rules_files.add(filepath)
-
-            # Clear proxy setting
-            if proxy:
-                del os.environ['https_proxy']
-
-            return [(f, get_sha256_for_file(f)) for f in rules_files or [file_path]]
-
-    except requests.Timeout:
-        # TODO: should we retry?
-        pass
-    except Exception as e:
-        # Catch all other types of exceptions such as ConnectionError, ProxyError, etc.
-        LOGGER.info(str(e))
-        exit()
-        # TODO: Should we exit even if one file fails to download? Or should we continue downloading other files?
-    finally:
-        # Close the requests session
-        session.close()
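
The function above implements HTTP conditional fetching twice over: a HEAD probe of Last-Modified, then an If-Modified-Since header on the GET. A minimal, self-contained sketch of the same pattern (fetch_if_newer and its arguments are illustrative, not part of the codebase; the hardcoded GMT zone follows RFC 7231):

import time

import requests


def fetch_if_newer(uri: str, previous_update: float = None):
    headers = {}
    if previous_update:
        # HTTP-date per RFC 7231, e.g. "Wed, 22 Sep 2021 00:00:00 GMT"
        headers['If-Modified-Since'] = time.strftime(
            "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(previous_update))

    response = requests.get(uri, headers=headers)
    if response.status_code == requests.codes['not_modified']:
        return None  # unchanged since previous_update; nothing to fetch
    response.raise_for_status()
    return response.content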


-def git_clone_repo(source: Dict[str, Any], previous_update=None) -> List:
-    name = source['name']
-    url = source['uri']
-    pattern = source.get('pattern', None)
-    key = source.get('private_key', None)
-
-    ignore_ssl_errors = source.get("ssl_ignore_errors", False)
-    ca_cert = source.get("ca_cert")
-    proxy = source.get('proxy', None)
-
-    git_config = None
-    git_env = {}
-
-    if ignore_ssl_errors:
-        git_env['GIT_SSL_NO_VERIFY'] = 1
-
-    # Let https requests go through proxy
-    if proxy:
-        os.environ['https_proxy'] = proxy
-
-    if ca_cert:
-        LOGGER.info("A CA certificate has been provided with this source.")
-        add_cacert(ca_cert)
-        git_env['GIT_SSL_CAINFO'] = certifi.where()
-
-    if key:
-        LOGGER.info(f"key found for {url}")
-        # Save the key to a file
-        git_ssh_identity_file = os.path.join(tempfile.gettempdir(), 'id_rsa')
-        with open(git_ssh_identity_file, 'w') as key_fh:
-            key_fh.write(key)
-        os.chmod(git_ssh_identity_file, 0o0400)
-
-        git_ssh_cmd = f"ssh -oStrictHostKeyChecking=no -i {git_ssh_identity_file}"
-        git_env['GIT_SSH_COMMAND'] = git_ssh_cmd
-
-    clone_dir = os.path.join(UPDATE_DIR, name)
-    if os.path.exists(clone_dir):
-        shutil.rmtree(clone_dir)
-
-    repo = Repo.clone_from(url, clone_dir, env=git_env, git_config=git_config)
-
-    # Check repo last commit
-    if previous_update:
-        if isinstance(previous_update, str):
-            previous_update = iso_to_epoch(previous_update)
-        for c in repo.iter_commits():
-            if c.committed_date < previous_update:
-                raise SkipSource()
-            break
-
-    if pattern:
-        files = [(os.path.join(clone_dir, f), get_sha256_for_file(f))
-                 for f in os.listdir(clone_dir) if re.match(pattern, f)]
-    else:
-        files = [(f, get_sha256_for_file(f)) for f in glob.glob(os.path.join(clone_dir, '*.rules*'))]
-
-    # Clear proxy setting
-    if proxy:
-        del os.environ['https_proxy']
-
-    return files
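
One detail in the removed version: git_env['GIT_SSL_NO_VERIFY'] = 1 stores an integer, but values handed to a subprocess environment must be strings, so '1' is the safe spelling. A minimal sketch of an SSH-key clone with GitPython under that correction (clone_with_deploy_key is illustrative and assumes the key is already on disk):

import os

from git import Repo


def clone_with_deploy_key(url: str, clone_dir: str, key_path: str) -> Repo:
    os.chmod(key_path, 0o400)  # ssh refuses private keys with loose permissions
    git_env = {
        'GIT_SSL_NO_VERIFY': '1',  # environment values must be strings
        'GIT_SSH_COMMAND': f"ssh -oStrictHostKeyChecking=no -i {key_path}",
    }
    return Repo.clone_from(url, clone_dir, env=git_env)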


class SuricataUpdateServer(ServiceUpdater):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
@@ -242,13 +63,13 @@ def do_source_update(self, service: Service) -> None:
                                                                        classification.UNRESTRICTED)
                try:
                    if uri.endswith('.git'):
-                        files = git_clone_repo(source, previous_update=old_update_time)
+                        files = git_clone_repo(source, old_update_time, "*.rules", self.log, UPDATE_DIR)
                        for file, sha256 in files:
                            files_sha256.setdefault(source_name, {})
                            if previous_hashes.get(source_name, {}).get(file, None) != sha256:
                                files_sha256[source_name][file] = sha256
                    else:
-                        files = url_download(source, previous_update=old_update_time)
+                        files = url_download(source, old_update_time, self.log, UPDATE_DIR)
                        for file, sha256 in files:
                            files_sha256.setdefault(source_name, {})
                            if previous_hashes.get(source_name, {}).get(file, None) != sha256:
@@ -259,20 +80,20 @@ def do_source_update(self, service: Service) -> None:
                    continue

            if files_sha256:
-                LOGGER.info("Found new Suricata rule files to process!")
+                self.log.info("Found new Suricata rule files to process!")

-                suricata_importer = SuricataImporter(al_client, logger=LOGGER)
+                suricata_importer = SuricataImporter(al_client, logger=self.log)

                for source, source_val in files_sha256.items():
                    total_imported = 0
                    default_classification = source_default_classification[source]
                    for file in source_val.keys():
                        total_imported += suricata_importer.import_file(file, source,
                                                                        default_classification=default_classification)
-                    LOGGER.info(f"{total_imported} signatures were imported for source {source}")
+                    self.log.info(f"{total_imported} signatures were imported for source {source}")

            else:
-                LOGGER.info('No new Suricata rule files to process')
+                self.log.info('No new Suricata rule files to process')

            self.set_source_update_time(run_time)
            self.set_source_extra(files_sha256)
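
Both branches above feed the same dedup rule: a file is re-imported only when its sha256 differs from the hash recorded for it on the previous run. The pattern in isolation, with illustrative names rather than the module's own helpers:

import hashlib


def sha256_of(path: str) -> str:
    # Stream the file so large rule archives aren't loaded into memory at once.
    digest = hashlib.sha256()
    with open(path, 'rb') as fh:
        for chunk in iter(lambda: fh.read(65536), b''):
            digest.update(chunk)
    return digest.hexdigest()


def changed_files(files, previous_hashes, source_name):
    # Keep only (path, sha256) pairs whose hash moved since the last run.
    new_hashes = {}
    for path, sha256 in files:
        if previous_hashes.get(source_name, {}).get(path) != sha256:
            new_hashes[path] = sha256
    return new_hashes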
