diff --git a/.coveragerc b/.coveragerc index d315b87a..ec233bb2 100644 --- a/.coveragerc +++ b/.coveragerc @@ -2,7 +2,7 @@ # https://coverage.readthedocs.io/en/latest/config.html [run] -source = src/example +source = src/guacscanner omit = branch = true diff --git a/README.md b/README.md index 8b1647a3..220769cc 100644 --- a/README.md +++ b/README.md @@ -6,20 +6,17 @@ [![Language grade: Python](https://img.shields.io/lgtm/grade/python/g/cisagov/guacscanner.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/cisagov/guacscanner/context:python) [![Known Vulnerabilities](https://snyk.io/test/github/cisagov/guacscanner/develop/badge.svg)](https://snyk.io/test/github/cisagov/guacscanner) -This is a generic skeleton project that can be used to quickly get a -new [cisagov](https://github.com/cisagov) Python library GitHub -project started. This skeleton project contains [licensing -information](LICENSE), as well as -[pre-commit hooks](https://pre-commit.com) and -[GitHub Actions](https://github.com/features/actions) configurations -appropriate for a Python library project. - -## New Repositories from a Skeleton ## - -Please see our [Project Setup guide](https://github.com/cisagov/development-guide/tree/develop/project_setup) -for step-by-step instructions on how to start a new repository from -a skeleton. This will save you time and effort when configuring a -new repository! +This project is a Python utility that continually scans the EC2 instances +in an AWS VPC and adds/removes Guacamole connections in the underlying +PostgreSQL database accordingly. 
+ +This utility is [Dockerized](https://docker.com) in +[cisagov/guacscanner-docker](https://github.com/cisagov/guacscanner-docker), +and the resulting Docker container is intended to run as part of +[cisagov/guacamole-composition](https://github.com/cisagov/guacamole-composition), +although it could—probably uselessly—run in a [Docker +composition](https://docs.docker.com/compose/) alongside only the +[official PostgreSQL Docker image](https://hub.docker.com/_/postgres). ## Contributing ## diff --git a/bump_version.sh b/bump_version.sh index e1324b84..fe8d89ad 100755 --- a/bump_version.sh +++ b/bump_version.sh @@ -6,7 +6,7 @@ set -o nounset set -o errexit set -o pipefail -VERSION_FILE=src/example/_version.py +VERSION_FILE=src/guacscanner/_version.py HELP_INFORMATION="bump_version.sh (show|major|minor|patch|prerelease|build|finalize)" diff --git a/setup.py b/setup.py index 3291c4dd..6252bd6f 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,5 @@ """ -This is the setup module for the example project. +This is the setup module for the guacscanner project. Based on: @@ -42,10 +42,10 @@ def get_version(version_file): setup( - name="example", + name="guacscanner", # Versions should comply with PEP440 - version=get_version("src/example/_version.py"), - description="Example Python library", + version=get_version("src/guacscanner/_version.py"), + description="Scan for EC2 instances added (removed) from a VPC and create (destroy) the corresponding Guacamole connections.", long_description=readme(), long_description_content_type="text/markdown", # Landing page for CISA's cybersecurity mission @@ -81,13 +81,21 @@ def get_version(version_file): ], python_requires=">=3.6", # What does your project relate to? 
- keywords="skeleton", + keywords="aws, guacamole, vpc", packages=find_packages(where="src"), package_dir={"": "src"}, - package_data={"example": ["data/*.txt"]}, py_modules=[splitext(basename(path))[0] for path in glob("src/*.py")], include_package_data=True, - install_requires=["docopt", "schema", "setuptools >= 24.2.0"], + # TODO: Loosen these requirements. See cisagov/guacscanner#9 for + # more details. + install_requires=[ + "boto3 == 1.19.6", + "docopt == 0.6.2", + "ec2-metadata == 2.5.0", + "psycopg == 3.0.1", + "schema == 0.7.4", + "setuptools >= 24.2.0", + ], extras_require={ "test": [ "coverage", @@ -98,11 +106,17 @@ def get_version(version_file): # 1.11.1 fixed this issue, but to ensure expected behavior we'll pin # to never grab the regression version. "coveralls != 1.11.0", + "moto", "pre-commit", "pytest-cov", "pytest", ] }, - # Conveniently allows one to run the CLI tool as `example` - entry_points={"console_scripts": ["example = example.example:main"]}, + # Conveniently allows one to run the CLI tool as + # `guacscanner` + entry_points={ + "console_scripts": [ + "guacscanner = guacscanner.guacscanner:main", + ], + }, ) diff --git a/src/example/__init__.py b/src/example/__init__.py deleted file mode 100644 index 98b5e041..00000000 --- a/src/example/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -"""The example library.""" -# We disable a Flake8 check for "Module imported but unused (F401)" here because -# although this import is not directly used, it populates the value -# package_name.__version__, which is used to get version information about this -# Python package. -from ._version import __version__ # noqa: F401 -from .example import example_div - -__all__ = ["example_div"] diff --git a/src/example/data/secret.txt b/src/example/data/secret.txt deleted file mode 100644 index c40a49b5..00000000 --- a/src/example/data/secret.txt +++ /dev/null @@ -1 +0,0 @@ -Three may keep a secret, if two of them are dead. 
diff --git a/src/example/example.py b/src/example/example.py deleted file mode 100644 index d3eda196..00000000 --- a/src/example/example.py +++ /dev/null @@ -1,103 +0,0 @@ -"""example is an example Python library and tool. - -Divide one integer by another and log the result. Also log some information -from an environment variable and a package resource. - -EXIT STATUS - This utility exits with one of the following values: - 0 Calculation completed successfully. - >0 An error occurred. - -Usage: - example [--log-level=LEVEL] - example (-h | --help) - -Options: - -h --help Show this message. - --log-level=LEVEL If specified, then the log level will be set to - the specified value. Valid values are "debug", "info", - "warning", "error", and "critical". [default: info] -""" - -# Standard Python Libraries -import logging -import os -import sys -from typing import Any, Dict - -# Third-Party Libraries -import docopt -import pkg_resources -from schema import And, Schema, SchemaError, Use - -from ._version import __version__ - -DEFAULT_ECHO_MESSAGE: str = "Hello World from the example default!" 
- - -def example_div(dividend: int, divisor: int) -> float: - """Print some logging messages.""" - logging.debug("This is a debug message") - logging.info("This is an info message") - logging.warning("This is a warning message") - logging.error("This is an error message") - logging.critical("This is a critical message") - return dividend / divisor - - -def main() -> None: - """Set up logging and call the example function.""" - args: Dict[str, str] = docopt.docopt(__doc__, version=__version__) - # Validate and convert arguments as needed - schema: Schema = Schema( - { - "--log-level": And( - str, - Use(str.lower), - lambda n: n in ("debug", "info", "warning", "error", "critical"), - error="Possible values for --log-level are " - + "debug, info, warning, error, and critical.", - ), - "": Use(int, error=" must be an integer."), - "": And( - Use(int), - lambda n: n != 0, - error=" must be an integer that is not 0.", - ), - str: object, # Don't care about other keys, if any - } - ) - - try: - validated_args: Dict[str, Any] = schema.validate(args) - except SchemaError as err: - # Exit because one or more of the arguments were invalid - print(err, file=sys.stderr) - sys.exit(1) - - # Assign validated arguments to variables - dividend: int = validated_args[""] - divisor: int = validated_args[""] - log_level: str = validated_args["--log-level"] - - # Set up logging - logging.basicConfig( - format="%(asctime)-15s %(levelname)s %(message)s", level=log_level.upper() - ) - - logging.info("%d / %d == %f", dividend, divisor, example_div(dividend, divisor)) - - # Access some data from an environment variable - message: str = os.getenv("ECHO_MESSAGE", DEFAULT_ECHO_MESSAGE) - logging.info('ECHO_MESSAGE="%s"', message) - - # Access some data from our package data (see the setup.py) - secret_message: str = ( - pkg_resources.resource_string("example", "data/secret.txt") - .decode("utf-8") - .strip() - ) - logging.info('Secret="%s"', secret_message) - - # Stop logging and clean up - 
logging.shutdown() diff --git a/src/guacscanner/ConnectionParameters.py b/src/guacscanner/ConnectionParameters.py new file mode 100644 index 00000000..01b81b36 --- /dev/null +++ b/src/guacscanner/ConnectionParameters.py @@ -0,0 +1,34 @@ +"""A dataclass container for Guacamole connection parameters.""" + + +# Standard Python Libraries +from dataclasses import dataclass + + +@dataclass +class ConnectionParameters: + """A dataclass container for Guacamole connection parameters.""" + + """The slots for this dataclass.""" + __slots__ = ( + "private_ssh_key", + "rdp_password", + "rdp_username", + "vnc_password", + "vnc_username", + ) + + """The private SSH key to use when transferring data via VNC.""" + private_ssh_key: str + + """The password to use when Guacamole establishes an RDP connection.""" + rdp_password: str + + """The user name to use when Guacamole establishes an RDP connection.""" + rdp_username: str + + """The password to use when Guacamole establishes a VNC connection.""" + vnc_password: str + + """The user name to use when Guacamole establishes a VNC connection.""" + vnc_username: str diff --git a/src/guacscanner/__init__.py b/src/guacscanner/__init__.py new file mode 100644 index 00000000..3f9d58ac --- /dev/null +++ b/src/guacscanner/__init__.py @@ -0,0 +1,35 @@ +"""The guacscanner library.""" +# We disable a Flake8 check for "Module imported but unused (F401)" +# here because, although this import is not directly used, it +# populates the value package_name.__version__, which is used to get +# version information about this Python package. 
+from ._version import __version__ # noqa: F401 +from .guacscanner import ( + ConnectionParameters, + add_instance_connection, + add_user, + check_for_ghost_instances, + entity_exists, + get_connection_name, + get_entity_id, + instance_connection_exists, + main, + process_instance, + remove_connection, + remove_instance_connections, +) + +__all__ = [ + "ConnectionParameters", + "add_instance_connection", + "add_user", + "check_for_ghost_instances", + "entity_exists", + "get_connection_name", + "get_entity_id", + "instance_connection_exists", + "main", + "process_instance", + "remove_connection", + "remove_instance_connections", +] diff --git a/src/example/__main__.py b/src/guacscanner/__main__.py similarity index 70% rename from src/example/__main__.py rename to src/guacscanner/__main__.py index 11a3238f..c438269b 100644 --- a/src/example/__main__.py +++ b/src/guacscanner/__main__.py @@ -1,5 +1,5 @@ """Code to run if this package is used as a Python module.""" -from .example import main +from .guacscanner import main main() diff --git a/src/example/_version.py b/src/guacscanner/_version.py similarity index 70% rename from src/example/_version.py rename to src/guacscanner/_version.py index 33cee844..de155d77 100644 --- a/src/example/_version.py +++ b/src/guacscanner/_version.py @@ -1,2 +1,2 @@ """This file defines the version of this module.""" -__version__ = "0.0.1" +__version__ = "1.0.0" diff --git a/src/guacscanner/guacscanner.py b/src/guacscanner/guacscanner.py new file mode 100644 index 00000000..a85cfe5a --- /dev/null +++ b/src/guacscanner/guacscanner.py @@ -0,0 +1,807 @@ +"""Query AWS for new (destroyed) instances and add (remove) Guacamole connections for them. + +Also check for instances that have been destroyed and remove their +corresponding connections. + +EXIT STATUS + 0 Update was successful. + >0 An error occurred. 
+ +Usage: + guacscanner [--log-level=LEVEL] [--oneshot] [--sleep=SECONDS] [--postgres-password=PASSWORD|--postgres-password-file=FILENAME] [--postgres-username=USERNAME|--postgres-username-file=FILENAME] [--private-ssh-key=KEY|--private-ssh-key-file=FILENAME] [--rdp-password=PASSWORD|--rdp-password-file=FILENAME] [--rdp-username=USERNAME|--rdp-username-file=FILENAME] [--region=REGION] [--vnc-password=PASSWORD|--vnc-password-file=FILENAME] [--vnc-username=USERNAME|--vnc-username-file=FILENAME] [--vpc-id=VPC_ID] + guacscanner (-h | --help) + +Options: + -h --help Show this message. + --log-level=LEVEL If specified, then the log level will be set to + the specified value. Valid values are "debug", "info", + "warning", "error", and "critical". [default: info] + --oneshot If present then the loop that adds (removes) connections for new (terminated) instances will only be run once. + --postgres-password=PASSWORD If specified then the specified value will be used as the password when connecting to the PostgreSQL database. Otherwise, the password will be read from a local file. + --postgres-password-file=FILENAME The file from which the PostgreSQL password will be read. [default: /run/secrets/postgres-password] + --postgres-username=USERNAME If specified then the specified value will be used when connecting to the PostgreSQL database. Otherwise, the username will be read from a local file. + --postgres-username-file=FILENAME The file from which the PostgreSQL username will be read. [default: /run/secrets/postgres-username] + --private-ssh-key=KEY If specified then the specified value will be used for the private SSH key. Otherwise, the SSH key will be read from a local file. + --private-ssh-key-file=FILENAME The file from which the private SSH key will be read. [default: /run/secrets/private-ssh-key] + --rdp-password=PASSWORD If specified then the specified value will be used for the RDP password. Otherwise, the password will be read from a local file. 
+ --rdp-password-file=FILENAME The file from which the RDP password will be read. [default: /run/secrets/rdp-password] + --rdp-username=USERNAME If specified then the specified value will be used for the RDP username. Otherwise, the username will be read from a local file. + --rdp-username-file=FILENAME The file from which the RDP username will be read. [default: /run/secrets/rdp-username] + --region=REGION The AWS region in which the VPC specified by --vpc-id exists. Unused if --vpc-id is not specified. [default: us-east-1] + --sleep=SECONDS Sleep for the specified number of seconds between executions of the Guacamole connection update loop. [default: 60] + --vnc-password=PASSWORD If specified then the specified value will be used for the VNC password. Otherwise, the password will be read from a local file. + --vnc-password-file=FILENAME The file from which the VNC password will be read. [default: /run/secrets/vnc-password] + --vnc-username=USERNAME If specified then the specified value will be used for the VNC username. Otherwise, the username will be read from a local file. + --vnc-username-file=FILENAME The file from which the VNC username will be read. [default: /run/secrets/vnc-username] + --vpc-id=VPC_ID If specified then query for EC2 instances created + or destroyed in the specified VPC ID. If not + specified then the ID of the VPC in which the host + resides will be used. +""" + + +# Standard Python Libraries +import datetime +import hashlib +import logging +import re +import secrets +import string +import sys +import time + +# Third-Party Libraries +import boto3 +import docopt +from ec2_metadata import ec2_metadata +import psycopg +from schema import And, Optional, Or, Schema, SchemaError, Use + +from .ConnectionParameters import ConnectionParameters +from ._version import __version__ + +# TODO: Add exception handling for all the database accesses and +# wherever else it is appropriate. 
guacscanner currently just bombs +# out if an exception is thrown, but it would probably make more sense +# to print an error message and keep looping, keepin' the train +# a-chooglin'. See cisagov/guacscanner#5 for more details. + +# TODO: Create command line options with defaults for these variables. +# See cisagov/guacscanner#2 for more details. +DEFAULT_ADD_INSTANCE_STATES = [ + "running", +] +DEFAULT_PASSWORD_LENGTH = 32 +DEFAULT_PASSWORD_SALT_LENGTH = 32 +DEFAULT_POSTGRES_DB_NAME = "guacamole_db" +DEFAULT_POSTGRES_HOSTNAME = "postgres" +DEFAULT_POSTGRES_PORT = 5432 +DEFAULT_REMOVE_INSTANCE_STATES = [ + "terminated", +] +DEFAULT_AMI_SKIP_REGEXES = [ + re.compile(r"^guacamole-.*$"), + re.compile(r"^nessus-.*$"), + re.compile(r"^samba-.*$"), +] + +# Some precompiled regexes +# +# Note the use of a named capture group here via the (?P...) +# syntax, as described here: +# https://docs.python.org/3/library/re.html#regular-expression-syntax +INSTANCE_ID_REGEX = re.compile(r"^.* \((?Pi-[0-9a-f]{17})\)$") +VPC_ID_REGEX = re.compile(r"^vpc-([0-9a-f]{8}|[0-9a-f]{17})$") + +# TODO: Determine if we can use f-strings instead of .format() for +# these queries. Also define the psycopg.sql.Identifier() variables +# separately so that they can be reused where that is possible. See +# cisagov/guacscanner#3 for more details. 
+ +# The PostgreSQL queries used for adding and removing connections +COUNT_QUERY = psycopg.sql.SQL( + "SELECT COUNT({id_field}) FROM {table} WHERE {name_field} = %s" +).format( + id_field=psycopg.sql.Identifier("connection_id"), + table=psycopg.sql.Identifier("guacamole_connection"), + name_field=psycopg.sql.Identifier("connection_name"), +) +IDS_QUERY = psycopg.sql.SQL( + "SELECT {id_field} FROM {table} WHERE {name_field} = %s" +).format( + id_field=psycopg.sql.Identifier("connection_id"), + table=psycopg.sql.Identifier("guacamole_connection"), + name_field=psycopg.sql.Identifier("connection_name"), +) +NAMES_QUERY = psycopg.sql.SQL("SELECT {id_field}, {name_field} FROM {table}").format( + id_field=psycopg.sql.Identifier("connection_id"), + name_field=psycopg.sql.Identifier("connection_name"), + table=psycopg.sql.Identifier("guacamole_connection"), +) +INSERT_CONNECTION_QUERY = psycopg.sql.SQL( + """INSERT INTO {table} ( + {name_field}, {protocol_field}, {max_connections_field}, + {max_connections_per_user_field}, {proxy_port_field}, {proxy_hostname_field}, + {proxy_encryption_method_field}) + VALUES (%s, %s, %s, %s, %s, %s, %s) RETURNING {id_field};""" +).format( + table=psycopg.sql.Identifier("guacamole_connection"), + name_field=psycopg.sql.Identifier("connection_name"), + protocol_field=psycopg.sql.Identifier("protocol"), + max_connections_field=psycopg.sql.Identifier("max_connections"), + max_connections_per_user_field=psycopg.sql.Identifier("max_connections_per_user"), + proxy_port_field=psycopg.sql.Identifier("proxy_port"), + proxy_hostname_field=psycopg.sql.Identifier("proxy_hostname"), + proxy_encryption_method_field=psycopg.sql.Identifier("proxy_encryption_method"), + id_field=psycopg.sql.Identifier("connection_id"), +) +INSERT_CONNECTION_PARAMETER_QUERY = psycopg.sql.SQL( + """INSERT INTO {table} + ({id_field}, {parameter_name_field}, {parameter_value_field}) + VALUES (%s, %s, %s);""" +).format( + 
table=psycopg.sql.Identifier("guacamole_connection_parameter"), + id_field=psycopg.sql.Identifier("connection_id"), + parameter_name_field=psycopg.sql.Identifier("parameter_name"), + parameter_value_field=psycopg.sql.Identifier("parameter_value"), +) +DELETE_CONNECTIONS_QUERY = psycopg.sql.SQL( + """DELETE FROM {table} WHERE {id_field} = %s;""" +).format( + table=psycopg.sql.Identifier("guacamole_connection"), + id_field=psycopg.sql.Identifier("connection_id"), +) +DELETE_CONNECTION_PARAMETERS_QUERY = psycopg.sql.SQL( + """DELETE FROM {table} WHERE {id_field} = %s;""" +).format( + table=psycopg.sql.Identifier("guacamole_connection_parameter"), + id_field=psycopg.sql.Identifier("connection_id"), +) + +# The PostgreSQL queries used for adding and removing users +ENTITY_COUNT_QUERY = psycopg.sql.SQL( + "SELECT COUNT({id_field}) FROM {table} WHERE {name_field} = %s AND {type_field} = %s" +).format( + id_field=psycopg.sql.Identifier("entity_id"), + table=psycopg.sql.Identifier("guacamole_entity"), + name_field=psycopg.sql.Identifier("name"), + type_field=psycopg.sql.Identifier("type"), +) +ENTITY_ID_QUERY = psycopg.sql.SQL( + "SELECT {id_field} FROM {table} WHERE {name_field} = %s AND {type_field} = %s" +).format( + id_field=psycopg.sql.Identifier("entity_id"), + table=psycopg.sql.Identifier("guacamole_entity"), + name_field=psycopg.sql.Identifier("name"), + type_field=psycopg.sql.Identifier("type"), +) +INSERT_ENTITY_QUERY = psycopg.sql.SQL( + """INSERT INTO {table} ( + {name_field}, {type_field}) + VALUES (%s, %s) RETURNING {id_field};""" +).format( + table=psycopg.sql.Identifier("guacamole_entity"), + name_field=psycopg.sql.Identifier("name"), + type_field=psycopg.sql.Identifier("type"), + id_field=psycopg.sql.Identifier("entity_id"), +) +INSERT_USER_QUERY = psycopg.sql.SQL( + """INSERT INTO {table} ( + {id_field}, {hash_field}, {salt_field}, {date_field}) + VALUES (%s, %s, %s, %s);""" +).format( + table=psycopg.sql.Identifier("guacamole_user"), + 
id_field=psycopg.sql.Identifier("entity_id"), + hash_field=psycopg.sql.Identifier("password_hash"), + salt_field=psycopg.sql.Identifier("password_salt"), + date_field=psycopg.sql.Identifier("password_date"), +) +# The PostgreSQL queries used to add and remove connection +# permissions +INSERT_CONNECTION_PERMISSION_QUERY = psycopg.sql.SQL( + """INSERT INTO {table} ( + {entity_id_field}, {connection_id_field}, {permission_field}) + VALUES (%s, %s, %s);""" +).format( + table=psycopg.sql.Identifier("guacamole_connection_permission"), + entity_id_field=psycopg.sql.Identifier("entity_id"), + connection_id_field=psycopg.sql.Identifier("connection_id"), + permission_field=psycopg.sql.Identifier("permission"), +) +DELETE_CONNECTION_PERMISSIONS_QUERY = psycopg.sql.SQL( + """DELETE FROM {table} WHERE {connection_id_field} = %s;""" +).format( + table=psycopg.sql.Identifier("guacamole_connection_permission"), + connection_id_field=psycopg.sql.Identifier("connection_id"), +) + + +def entity_exists(db_connection, entity_name, entity_type): + """Return a boolean indicating whether an entity with the specified name and type exists.""" + with db_connection.cursor() as cursor: + logging.debug( + "Checking to see if an entity named %s of type %s exists in the database.", + entity_name, + entity_type, + ) + cursor.execute(ENTITY_COUNT_QUERY, (entity_name, entity_type)) + count = cursor.fetchone()["count"] + if count != 0: + logging.debug( + "An entity named %s of type %s exists in the database.", + entity_name, + entity_type, + ) + else: + logging.debug( + "No entity named %s of type %s exists in the database.", + entity_name, + entity_type, + ) + + return count != 0 + + +def get_entity_id(db_connection, entity_name, entity_type): + """Return the ID corresponding to the entity with the specified name and type.""" + logging.debug("Looking for entity ID for %s of type %s.", entity_name, entity_type) + with db_connection.cursor() as cursor: + logging.debug( + "Checking to see if any 
entity named %s of type %s exists in the database.", + entity_name, + entity_type, + ) + cursor.execute(ENTITY_ID_QUERY, (entity_name, entity_type)) + + # Note that we are assuming there is only a single match. + return cursor.fetchone()["entity_id"] + + +def add_user( + db_connection: psycopg.Connection, + username: str, + password: str = None, + salt: bytes = None, +) -> int: + """Add a user, returning its corresponding entity ID. + + If password (salt) is None (the default) then a random password + (salt) will be generated for the user. + + Note that the salt should be an array of bytes, while the password + should be an ASCII string. + + """ + logging.debug("Adding user entry.") + + if password is None: + # Generate a random password consisting of ASCII letters and + # digits + alphabet = string.ascii_letters + string.digits + password = "".join( + secrets.choice(alphabet) for i in range(DEFAULT_PASSWORD_LENGTH) + ) + if salt is None: + # Generate a random byte array + salt = secrets.token_bytes(DEFAULT_PASSWORD_SALT_LENGTH) + + # Compute the salted password hash that is to be saved to the + # database. + # + # Note that we convert the hexed salt and the salted password hash + # to uppercase, since that must be done to match the corresponding + # values in the database that are generated for the default + # guacadmin password by the database initialization script. + hexed_salt = salt.hex().upper() + hasher = hashlib.sha256() + # We must use the same password hashing algorithm as is used in + # the Guacamole source code, so we cannot avoid the LGTM warning + # here. 
+ hasher.update(password.encode()) # lgtm[py/weak-sensitive-data-hashing] + hasher.update(hexed_salt.encode()) + salted_password_hash = hasher.hexdigest().upper() + + entity_id = None + with db_connection.cursor() as cursor: + cursor.execute( + INSERT_ENTITY_QUERY, + ( + username, + "USER", + ), + ) + entity_id = cursor.fetchone()["entity_id"] + cursor.execute( + INSERT_USER_QUERY, + ( + entity_id, + bytes.fromhex(salted_password_hash), + salt, + datetime.datetime.now(), + ), + ) + + # Commit all pending transactions to the database + db_connection.commit() + + return entity_id + + +def instance_connection_exists(db_connection, connection_name): + """Return a boolean indicating whether a connection with the specified name exists.""" + with db_connection.cursor() as cursor: + logging.debug( + "Checking to see if a connection named %s exists in the database.", + connection_name, + ) + cursor.execute(COUNT_QUERY, (connection_name,)) + count = cursor.fetchone()["count"] + if count != 0: + logging.debug( + "A connection named %s exists in the database.", connection_name + ) + else: + logging.debug( + "No connection named %s exists in the database.", connection_name + ) + + return count != 0 + + +def add_instance_connection( + db_connection, + instance, + connection_parameters: ConnectionParameters, + entity_id, +): + """Add a connection for the EC2 instance.""" + logging.debug("Adding connection entry for %s.", instance.id) + hostname = instance.private_dns_name + connection_name = get_connection_name(instance) + is_windows = False + connection_protocol = "vnc" + connection_port = 5901 + if instance.platform and instance.platform.lower() == "windows": + logging.debug("Instance %s is Windows and therefore uses RDP.", instance.id) + is_windows = True + connection_protocol = "rdp" + connection_port = 3389 + + with db_connection.cursor() as cursor: + cursor.execute( + INSERT_CONNECTION_QUERY, + ( + connection_name, + connection_protocol, + 10, + 10, + 4822, + "guacd", + 
"NONE", + ), + ) + connection_id = cursor.fetchone()["connection_id"] + + guac_conn_params = ( + ( + connection_id, + "cursor", + "local", + ), + ( + connection_id, + "sftp-directory", + f"/home/{connection_parameters.vnc_username}/Documents", + ), + ( + connection_id, + "sftp-username", + connection_parameters.vnc_username, + ), + ( + connection_id, + "sftp-private-key", + connection_parameters.private_ssh_key, + ), + ( + connection_id, + "sftp-server-alive-interval", + 60, + ), + ( + connection_id, + "sftp-root-directory", + f"/home/{connection_parameters.vnc_username}/", + ), + ( + connection_id, + "enable-sftp", + True, + ), + ( + connection_id, + "color-depth", + 24, + ), + ( + connection_id, + "hostname", + hostname, + ), + ( + connection_id, + "password", + connection_parameters.vnc_password, + ), + ( + connection_id, + "port", + connection_port, + ), + ) + if is_windows: + # mypy gives a warning on this line because we are + # re-assigning the variable with a tuple of a different + # length, but we know this is safe to do here. 
+ guac_conn_params = ( # type: ignore + ( + connection_id, + "ignore-cert", + True, + ), + ( + connection_id, + "hostname", + hostname, + ), + ( + connection_id, + "password", + connection_parameters.rdp_password, + ), + ( + connection_id, + "port", + connection_port, + ), + ( + connection_id, + "username", + connection_parameters.rdp_username, + ), + ) + + logging.debug( + "Adding connection parameter entries for connection named %s.", + connection_name, + ) + cursor.executemany(INSERT_CONNECTION_PARAMETER_QUERY, guac_conn_params) + + logging.debug( + "Adding connection permission entries for connection named %s.", + connection_name, + ) + cursor.execute( + INSERT_CONNECTION_PERMISSION_QUERY, + ( + entity_id, + connection_id, + "READ", + ), + ) + + # Commit all pending transactions to the database + db_connection.commit() + + +def remove_connection(db_connection, connection_id): + """Remove all connections corresponding to the specified ID.""" + logging.debug("Removing connection entries for %s.", connection_id) + with db_connection.cursor() as cursor: + cursor.execute(DELETE_CONNECTIONS_QUERY, (connection_id,)) + + logging.debug("Removing connection parameter entries for %s.", connection_id) + cursor.execute(DELETE_CONNECTION_PARAMETERS_QUERY, (connection_id,)) + + logging.debug("Removing connection permission entries for %s.", connection_id) + cursor.execute(DELETE_CONNECTION_PERMISSIONS_QUERY, (connection_id,)) + + +def remove_instance_connections(db_connection, instance): + """Remove all connections corresponding to the EC2 instance.""" + logging.debug("Removing connections for %s.", instance.id) + connection_name = get_connection_name(instance) + with db_connection.cursor() as cursor: + logging.debug( + "Checking to see if any connections named %s exist in the database.", + connection_name, + ) + cursor.execute(IDS_QUERY, (connection_name,)) + for record in cursor: + logging.info("Removing entries for connections named %s.", connection_name) + connection_id = 
record["connection_id"] + remove_connection(db_connection, connection_id) + + # Commit all pending transactions to the database + db_connection.commit() + + +def get_connection_name(instance): + """Return the unique connection name for an EC2 instance.""" + name = [tag["Value"] for tag in instance.tags if tag["Key"] == "Name"][0] + return f"{name} ({instance.id})" + + +def process_instance( + db_connection, + instance, + add_instance_states, + remove_instance_states, + connection_parameters: ConnectionParameters, + entity_id, +): + """Add/remove connections for the specified EC2 instance.""" + logging.debug("Examining instance %s.", instance.id) + state = instance.state["Name"] + connection_name = get_connection_name(instance) + logging.debug("Connection name is %s.", connection_name) + if state in add_instance_states: + logging.info( + "Instance %s is in state %s and will be added if not already present.", + instance.id, + state, + ) + if not instance_connection_exists(db_connection, connection_name): + logging.info("Adding a connection for %s.", instance.id) + add_instance_connection( + db_connection, + instance, + connection_parameters, + entity_id, + ) + else: + logging.debug( + "Connection for %s already exists in the database.", instance.id + ) + elif state in remove_instance_states: + logging.info( + "Instance %s is in state %s and will be removed if present.", + instance.id, + state, + ) + remove_instance_connections(db_connection, instance) + else: + logging.debug( + "Instance %s is in state %s and WILL NOT be added or removed.", + instance.id, + state, + ) + + +def check_for_ghost_instances(db_connection, instances): + """Check to see if any connections belonging to nonexistent instances are in the database.""" + instance_ids = [instance.id for instance in instances] + with db_connection.cursor() as cursor: + cursor.execute(NAMES_QUERY) + for record in cursor: + connection_id = record["connection_id"] + connection_name = record["connection_name"] + m = 
INSTANCE_ID_REGEX.match(connection_name) + instance_id = None + if m: + instance_id = m.group("id") + else: + logging.error( + 'Connection name "%s" does not contain a valid instance ID', + connection_name, + ) + + if instance_id not in instance_ids: + logging.info( + "Connection for %s being removed since that instance no longer exists.", + instance_id, + ) + remove_connection(db_connection, connection_id) + + db_connection.commit() + + +def main() -> None: + """Add/remove connections to Guacamole DB as necessary.""" + # Parse command line arguments + args = docopt.docopt(__doc__, version=__version__) + # Validate and convert arguments as needed + schema = Schema( + { + "--log-level": And( + str, + Use(str.lower), + lambda n: n in ("debug", "info", "warning", "error", "critical"), + error="Possible values for --log-level are " + + "debug, info, warning, error, and critical.", + ), + "--sleep": And( + Use(float), + error="Value for --sleep must be parseable as a floating point number.", + ), + Optional("--vpc-id"): Or( + None, + And( + str, + Use(str.lower), + lambda x: VPC_ID_REGEX.match(x) is not None, + error="Possible values for --vpc-id are the characters vpc- followed by either 8 or 17 hexadecimal digits.", + ), + ), + str: object, # Don't care about other keys, if any + } + ) + try: + validated_args = schema.validate(args) + except SchemaError as err: + # Exit because one or more of the arguments were invalid + print(err, file=sys.stderr) + sys.exit(1) + + # Set up logging + log_level = validated_args["--log-level"] + logging.basicConfig( + format="%(asctime)-15s %(levelname)s %(message)s", level=log_level.upper() + ) + + add_instance_states = DEFAULT_ADD_INSTANCE_STATES + postgres_db_name = DEFAULT_POSTGRES_DB_NAME + postgres_hostname = DEFAULT_POSTGRES_HOSTNAME + postgres_port = DEFAULT_POSTGRES_PORT + remove_instance_states = DEFAULT_REMOVE_INSTANCE_STATES + + oneshot = validated_args["--oneshot"] + logging.debug("oneshot is %s.", oneshot) + + 
postgres_password = validated_args["--postgres-password"] + if postgres_password is None: + with open(validated_args["--postgres-password-file"], "r") as file: + postgres_password = file.read() + + postgres_username = validated_args["--postgres-username"] + if postgres_username is None: + with open(validated_args["--postgres-username-file"], "r") as file: + postgres_username = file.read() + + rdp_password = validated_args["--rdp-password"] + if rdp_password is None: + with open(validated_args["--rdp-password-file"], "r") as file: + rdp_password = file.read() + + rdp_username = validated_args["--rdp-username"] + if rdp_username is None: + with open(validated_args["--rdp-username-file"], "r") as file: + rdp_username = file.read() + + vnc_password = validated_args["--vnc-password"] + if vnc_password is None: + with open(validated_args["--vnc-password-file"], "r") as file: + vnc_password = file.read() + + vnc_username = validated_args["--vnc-username"] + if vnc_username is None: + with open(validated_args["--vnc-username-file"], "r") as file: + vnc_username = file.read() + + private_ssh_key = validated_args["--private-ssh-key"] + if private_ssh_key is None: + with open(validated_args["--private-ssh-key-file"], "r") as file: + private_ssh_key = file.read() + + db_connection_string = f"user={postgres_username} password={postgres_password} host={postgres_hostname} port={postgres_port} dbname={postgres_db_name}" + + vpc_id = validated_args["--vpc-id"] + # TODO: Verify that the region specified is indeed a valid AWS + # region. See cisagov/guacscanner#6 for more details. + region = validated_args["--region"] + + # If no VPC ID was specified on the command line then grab the VPC + # ID where this instance resides and use that. 
+ ec2 = None + if vpc_id is None: + instance_id = ec2_metadata.instance_id + region = ec2_metadata.region + ec2 = boto3.resource("ec2", region_name=region) + instance = ec2.Instance(instance_id) + vpc_id = instance.vpc_id + else: + ec2 = boto3.resource("ec2", region_name=region) + + logging.info("Examining instances in VPC %s.", vpc_id) + + instances = ec2.Vpc(vpc_id).instances.all() + keep_looping = True + guacuser_id = None + while keep_looping: + time.sleep(validated_args["--sleep"]) + + try: + db_connection = psycopg.connect( + db_connection_string, row_factory=psycopg.rows.dict_row + ) + except psycopg.OperationalError: + logging.exception( + "Unable to connect to the PostgreSQL database backending Guacamole." + ) + continue + + # Create guacuser if it doesn't already exist + # + # TODO: Figure out a way to make this cleaner. We don't want + # to hardcode the guacuser name, and we want to allow the user + # to specify a list of users that should be created if they + # don't exist and given access to use the connections created + # by guacscanner. See cisagov/guacscanner#4 for more details. + if guacuser_id is None: + # We haven't initialized guacuser_id yet, so let's do it + # now. + if not entity_exists(db_connection, "guacuser", "USER"): + guacuser_id = add_user(db_connection, "guacuser") + else: + guacuser_id = get_entity_id(db_connection, "guacuser", "USER") + + for instance in instances: + ami = ec2.Image(instance.image_id) + # Early exit if this instance is running an AMI that we + # want to avoid adding to Guacamole. + try: + ami_matches = [ + regex.match(ami.name) for regex in DEFAULT_AMI_SKIP_REGEXES + ] + except AttributeError: + # This exception can be thrown when an instance is + # running an AMI to which the account no longer has + # access; for example, between the time when a new AMI + # of the same type is built and terraform-post-packer + # is run and the new AMI is applied to the account. 
+ # In this situation we can't take any action because + # we can't access the AMI's name and hence can't know + # if the instance AMI is of a type whose Guacamole + # connections are being controlled by guacscanner. + # + # In any event, this continue statement should keep + # things moving when this exception does occur. + logging.exception( + "Unable to determine if instance is running an AMI that would cause it to be skipped." + ) + continue + if any(ami_matches): + continue + + process_instance( + db_connection, + instance, + add_instance_states, + remove_instance_states, + ConnectionParameters( + private_ssh_key=private_ssh_key, + rdp_password=rdp_password, + rdp_username=rdp_username, + vnc_password=vnc_password, + vnc_username=vnc_username, + ), + guacuser_id, + ) + + logging.info( + "Checking to see if any connections belonging to nonexistent instances are in the database." + ) + check_for_ghost_instances(db_connection, instances) + + if oneshot: + logging.debug( + "Stopping Guacamole connection update loop because --oneshot is present." + ) + keep_looping = False + + # psycopg.connect() can act as a context manager, but the + # connection is not closed when you leave the context; + # therefore, we still have to close the connection manually. 
+ db_connection.close() + + logging.shutdown() diff --git a/tests/test_example.py b/tests/test_example.py deleted file mode 100644 index f8dea673..00000000 --- a/tests/test_example.py +++ /dev/null @@ -1,144 +0,0 @@ -#!/usr/bin/env pytest -vs -"""Tests for example.""" - -# Standard Python Libraries -import logging -import os -import sys -from unittest.mock import patch - -# Third-Party Libraries -import pytest - -# cisagov Libraries -import example - -div_params = [ - (1, 1, 1), - (2, 2, 1), - (0, 1, 0), - (8, 2, 4), -] - -log_levels = ( - "debug", - "info", - "warning", - "error", - "critical", -) - -# define sources of version strings -RELEASE_TAG = os.getenv("RELEASE_TAG") -PROJECT_VERSION = example.__version__ - - -def test_stdout_version(capsys): - """Verify that version string sent to stdout agrees with the module version.""" - with pytest.raises(SystemExit): - with patch.object(sys, "argv", ["bogus", "--version"]): - example.example.main() - captured = capsys.readouterr() - assert ( - captured.out == f"{PROJECT_VERSION}\n" - ), "standard output by '--version' should agree with module.__version__" - - -def test_running_as_module(capsys): - """Verify that the __main__.py file loads correctly.""" - with pytest.raises(SystemExit): - with patch.object(sys, "argv", ["bogus", "--version"]): - # F401 is a "Module imported but unused" warning. This import - # emulates how this project would be run as a module. The only thing - # being done by __main__ is importing the main entrypoint of the - # package and running it, so there is nothing to use from this - # import. As a result, we can safely ignore this warning. 
- # cisagov Libraries - import example.__main__ # noqa: F401 - captured = capsys.readouterr() - assert ( - captured.out == f"{PROJECT_VERSION}\n" - ), "standard output by '--version' should agree with module.__version__" - - -@pytest.mark.skipif( - RELEASE_TAG in [None, ""], reason="this is not a release (RELEASE_TAG not set)" -) -def test_release_version(): - """Verify that release tag version agrees with the module version.""" - assert ( - RELEASE_TAG == f"v{PROJECT_VERSION}" - ), "RELEASE_TAG does not match the project version" - - -@pytest.mark.parametrize("level", log_levels) -def test_log_levels(level): - """Validate commandline log-level arguments.""" - with patch.object(sys, "argv", ["bogus", f"--log-level={level}", "1", "1"]): - with patch.object(logging.root, "handlers", []): - assert ( - logging.root.hasHandlers() is False - ), "root logger should not have handlers yet" - return_code = None - try: - example.example.main() - except SystemExit as sys_exit: - return_code = sys_exit.code - assert return_code is None, "main() should return success" - assert ( - logging.root.hasHandlers() is True - ), "root logger should now have a handler" - assert ( - logging.getLevelName(logging.root.getEffectiveLevel()) == level.upper() - ), f"root logger level should be set to {level.upper()}" - assert return_code is None, "main() should return success" - - -def test_bad_log_level(): - """Validate bad log-level argument returns error.""" - with patch.object(sys, "argv", ["bogus", "--log-level=emergency", "1", "1"]): - return_code = None - try: - example.example.main() - except SystemExit as sys_exit: - return_code = sys_exit.code - assert return_code == 1, "main() should exit with error" - - -@pytest.mark.parametrize("dividend, divisor, quotient", div_params) -def test_division(dividend, divisor, quotient): - """Verify division results.""" - result = example.example_div(dividend, divisor) - assert result == quotient, "result should equal quotient" - - -@pytest.mark.slow 
-def test_slow_division(): - """Example of using a custom marker. - - This test will only be run if --runslow is passed to pytest. - Look in conftest.py to see how this is implemented. - """ - # Standard Python Libraries - import time - - result = example.example_div(256, 16) - time.sleep(4) - assert result == 16, "result should equal be 16" - - -def test_zero_division(): - """Verify that division by zero throws the correct exception.""" - with pytest.raises(ZeroDivisionError): - example.example_div(1, 0) - - -def test_zero_divisor_argument(): - """Verify that a divisor of zero is handled as expected.""" - with patch.object(sys, "argv", ["bogus", "1", "0"]): - return_code = None - try: - example.example.main() - except SystemExit as sys_exit: - return_code = sys_exit.code - assert return_code == 1, "main() should exit with error" diff --git a/tests/test_guacscanner.py b/tests/test_guacscanner.py new file mode 100644 index 00000000..da1b41a8 --- /dev/null +++ b/tests/test_guacscanner.py @@ -0,0 +1,501 @@ +#!/usr/bin/env pytest -vs +"""Tests for guacscanner.""" + +# Standard Python Libraries +import logging +import os +import sys +from unittest.mock import MagicMock, patch + +# Third-Party Libraries +import boto3 +from moto import mock_ec2 +import psycopg +import pytest + +# cisagov Libraries +import guacscanner + +log_levels = ( + "debug", + "info", + "warning", + "error", + "critical", +) + +# define sources of version strings +RELEASE_TAG = os.getenv("RELEASE_TAG") +PROJECT_VERSION = guacscanner.__version__ + +DUMMY_VPC_ID = "vpc-0123456789abcdef0" + + +def test_stdout_version(capsys): + """Verify that version string sent to stdout agrees with the module version.""" + with pytest.raises(SystemExit): + with patch.object(sys, "argv", ["bogus", "--version"]): + guacscanner.guacscanner.main() + captured = capsys.readouterr() + assert ( + captured.out == f"{PROJECT_VERSION}\n" + ), "standard output by '--version' should agree with module.__version__" + + +def 
test_running_as_module(capsys): + """Verify that the __main__.py file loads correctly.""" + with pytest.raises(SystemExit): + with patch.object(sys, "argv", ["bogus", "--version"]): + # F401 is a "Module imported but unused" warning. This import + # emulates how this project would be run as a module. The only thing + # being done by __main__ is importing the main entrypoint of the + # package and running it, so there is nothing to use from this + # import. As a result, we can safely ignore this warning. + # cisagov Libraries + import guacscanner.__main__ # noqa: F401 + captured = capsys.readouterr() + assert ( + captured.out == f"{PROJECT_VERSION}\n" + ), "standard output by '--version' should agree with module.__version__" + + +@pytest.mark.skipif( + RELEASE_TAG in [None, ""], reason="this is not a release (RELEASE_TAG not set)" +) +def test_release_version(): + """Verify that release tag version agrees with the module version.""" + assert ( + RELEASE_TAG == f"v{PROJECT_VERSION}" + ), "RELEASE_TAG does not match the project version" + + +@mock_ec2 +@pytest.mark.parametrize("level", log_levels) +def test_log_levels(level): + """Validate commandline log-level arguments.""" + with patch.object( + sys, + "argv", + [ + f"--log-level={level}", + "--oneshot", + "--postgres-password=dummy_db_password", + "--postgres-username=dummy_db_username", + "--private-ssh-key=dummy_key", + "--rdp-password=dummy_rdp_password", + "--rdp-username=dummy_rdp_username", + "--vnc-password=dummy_vnc_password", + "--vnc-username=dummy_vnc_username", + f"--vpc-id={DUMMY_VPC_ID}", + ], + ): + with patch.object(logging.root, "handlers", []): + with patch.object(psycopg, "connect", return_value=MagicMock()): + assert ( + logging.root.hasHandlers() is False + ), "root logger should not have handlers yet" + return_code = None + try: + guacscanner.guacscanner.main() + except SystemExit as sys_exit: + return_code = sys_exit.code + assert return_code is None, "main() should return success" + assert ( 
+ logging.root.hasHandlers() is True + ), "root logger should now have a handler" + assert ( + logging.getLevelName(logging.root.getEffectiveLevel()) + == level.upper() + ), f"root logger level should be set to {level.upper()}" + assert return_code is None, "main() should return success" + + +def test_bad_log_level(): + """Validate bad log-level argument returns error.""" + with patch.object(sys, "argv", ["bogus", "--log-level=emergency"]): + return_code = None + try: + guacscanner.guacscanner.main() + except SystemExit as sys_exit: + return_code = sys_exit.code + assert return_code == 1, "main() should exit with error" + + +@mock_ec2 +def test_addition_of_guacuser(): + """Verify that adding the guacuser works as expected.""" + # Create a VPC + ec2 = boto3.client("ec2", "us-east-1") + vpc = ec2.create_vpc(CidrBlock="10.19.74.0/24") + vpc_id = vpc["Vpc"]["VpcId"] + + # Mock the PostgreSQL database connection + mock_connection = MagicMock( + name="Mock PostgreSQL connection", spec_set=psycopg.Connection + ) + mock_cursor = MagicMock(name="Mock PostgreSQL cursor", spec_set=psycopg.Cursor) + mock_cursor.__enter__.return_value = mock_cursor + mock_cursor.fetchone.side_effect = [ + # Checking to see if guacuser exists and then adding it + {"count": 0}, + {"entity_id": 1}, + ] + mock_connection.__enter__.return_value = mock_connection + mock_connection.cursor.return_value = mock_cursor + + with patch.object( + sys, + "argv", + [ + "--log-level=debug", + "--oneshot", + "--postgres-password=dummy_db_password", + "--postgres-username=dummy_db_username", + "--private-ssh-key=dummy_key", + "--rdp-password=dummy_rdp_password", + "--rdp-username=dummy_rdp_username", + "--vnc-password=dummy_vnc_password", + "--vnc-username=dummy_vnc_username", + f"--vpc-id={vpc_id}", + ], + ): + with patch.object( + psycopg, "connect", return_value=mock_connection + ) as mock_connect: + guacscanner.guacscanner.main() + mock_connect.assert_called_once() + mock_connection.cursor.assert_called() + 
mock_connection.commit.assert_called() + mock_cursor.fetchone.assert_called() + mock_cursor.execute.assert_called() + + +@mock_ec2 +def test_guacuser_already_exists(): + """Verify that the case where the guacuser already exists works as expected.""" + # Create a VPC + ec2 = boto3.client("ec2", "us-east-1") + vpc = ec2.create_vpc(CidrBlock="10.19.74.0/24") + vpc_id = vpc["Vpc"]["VpcId"] + + # Mock the PostgreSQL database connection + mock_connection = MagicMock( + name="Mock PostgreSQL connection", spec_set=psycopg.Connection + ) + mock_cursor = MagicMock(name="Mock PostgreSQL cursor", spec_set=psycopg.Cursor) + mock_cursor.__enter__.return_value = mock_cursor + mock_cursor.fetchone.side_effect = [ + # Checking to see if guacuser exists and then fetching its ID + {"count": 1}, + {"entity_id": 1}, + ] + mock_connection.__enter__.return_value = mock_connection + mock_connection.cursor.return_value = mock_cursor + + with patch.object( + sys, + "argv", + [ + "--log-level=debug", + "--oneshot", + "--postgres-password=dummy_db_password", + "--postgres-username=dummy_db_username", + "--private-ssh-key=dummy_key", + "--rdp-password=dummy_rdp_password", + "--rdp-username=dummy_rdp_username", + "--vnc-password=dummy_vnc_password", + "--vnc-username=dummy_vnc_username", + f"--vpc-id={vpc_id}", + ], + ): + with patch.object( + psycopg, "connect", return_value=mock_connection + ) as mock_connect: + guacscanner.guacscanner.main() + mock_connect.assert_called_once() + mock_connection.cursor.assert_called() + mock_connection.commit.assert_called() + mock_cursor.fetchone.assert_called() + mock_cursor.execute.assert_called() + + +@mock_ec2 +def test_new_linux_instance(): + """Verify that adding a new Linux instance works as expected.""" + # Create and populate a VPC with an EC2 instance + # + # TODO: Create a test fixture to reduce duplication of this EC2 + # setup code. See cisagov/guacscanner#7 for more details. 
+ ec2 = boto3.client("ec2", "us-east-1") + vpc = ec2.create_vpc(CidrBlock="10.19.74.0/24") + vpc_id = vpc["Vpc"]["VpcId"] + subnet = ec2.create_subnet(CidrBlock="10.19.74.0/24", VpcId=vpc_id) + subnet_id = subnet["Subnet"]["SubnetId"] + amis = ec2.describe_images( + Filters=[ + {"Name": "Name", "Values": ["amzn-ami-hvm-2017.09.1.20171103-x86_64-gp2"]} + ] + ) + ami = amis["Images"][0] + ami_id = ami["ImageId"] + ec2.run_instances( + ImageId=ami_id, + SubnetId=subnet_id, + MaxCount=1, + MinCount=1, + TagSpecifications=[ + {"ResourceType": "instance", "Tags": [{"Key": "Name", "Value": "Linux"}]} + ], + ) + + # Mock the PostgreSQL database connection + mock_connection = MagicMock( + name="Mock PostgreSQL connection", spec_set=psycopg.Connection + ) + mock_cursor = MagicMock(name="Mock PostgreSQL cursor", spec_set=psycopg.Cursor) + mock_cursor.__enter__.return_value = mock_cursor + mock_cursor.fetchone.side_effect = [ + # Checking to see if guacuser exists and then adding it + {"count": 0}, + {"entity_id": 1}, + # Checking to see if the connection exists and then adding it + {"count": 0}, + {"connection_id": 1}, + ] + mock_connection.__enter__.return_value = mock_connection + mock_connection.cursor.return_value = mock_cursor + + with patch.object( + sys, + "argv", + [ + "--log-level=debug", + "--oneshot", + "--postgres-password=dummy_db_password", + "--postgres-username=dummy_db_username", + "--private-ssh-key=dummy_key", + "--rdp-password=dummy_rdp_password", + "--rdp-username=dummy_rdp_username", + "--vnc-password=dummy_vnc_password", + "--vnc-username=dummy_vnc_username", + f"--vpc-id={vpc_id}", + ], + ): + with patch.object( + psycopg, "connect", return_value=mock_connection + ) as mock_connect: + guacscanner.guacscanner.main() + mock_connect.assert_called_once() + mock_connection.cursor.assert_called() + mock_connection.commit.assert_called() + mock_cursor.fetchone.assert_called() + mock_cursor.execute.assert_called() + mock_cursor.executemany.assert_called() + + 
+@mock_ec2 +def test_terminated_instance(): + """Verify that adding a terminated instance works as expected.""" + # Create and populate a VPC with a terminated EC2 instance + ec2 = boto3.client("ec2", "us-east-1") + vpc = ec2.create_vpc(CidrBlock="10.19.74.0/24") + vpc_id = vpc["Vpc"]["VpcId"] + subnet = ec2.create_subnet(CidrBlock="10.19.74.0/24", VpcId=vpc_id) + subnet_id = subnet["Subnet"]["SubnetId"] + amis = ec2.describe_images( + Filters=[ + {"Name": "Name", "Values": ["amzn-ami-hvm-2017.09.1.20171103-x86_64-gp2"]} + ] + ) + ami = amis["Images"][0] + ami_id = ami["ImageId"] + instances = ec2.run_instances( + ImageId=ami_id, + SubnetId=subnet_id, + MaxCount=1, + MinCount=1, + TagSpecifications=[ + {"ResourceType": "instance", "Tags": [{"Key": "Name", "Value": "Linux"}]} + ], + ) + instance_id = instances["Instances"][0]["InstanceId"] + ec2.terminate_instances(InstanceIds=[instance_id]) + + # Mock the PostgreSQL database connection + mock_connection = MagicMock( + name="Mock PostgreSQL connection", spec_set=psycopg.Connection + ) + mock_cursor = MagicMock(name="Mock PostgreSQL cursor", spec_set=psycopg.Cursor) + mock_cursor.__enter__.return_value = mock_cursor + mock_cursor.fetchone.side_effect = [ + # Checking to see if guacuser exists and then adding it + {"count": 0}, + {"entity_id": 1}, + ] + mock_connection.__enter__.return_value = mock_connection + mock_connection.cursor.return_value = mock_cursor + + with patch.object( + sys, + "argv", + [ + "--log-level=debug", + "--oneshot", + "--postgres-password=dummy_db_password", + "--postgres-username=dummy_db_username", + "--private-ssh-key=dummy_key", + "--rdp-password=dummy_rdp_password", + "--rdp-username=dummy_rdp_username", + "--vnc-password=dummy_vnc_password", + "--vnc-username=dummy_vnc_username", + f"--vpc-id={vpc_id}", + ], + ): + with patch.object( + psycopg, "connect", return_value=mock_connection + ) as mock_connect: + guacscanner.guacscanner.main() + mock_connect.assert_called_once() + 
mock_connection.cursor.assert_called() + mock_connection.commit.assert_called() + mock_cursor.fetchone.assert_called() + mock_cursor.execute.assert_called() + mock_cursor.executemany.assert_not_called() + + +@mock_ec2 +def test_stopped_instance(): + """Verify that adding a stopped instance works as expected.""" + # Create and populate a VPC with a stopped EC2 instance + ec2 = boto3.client("ec2", "us-east-1") + vpc = ec2.create_vpc(CidrBlock="10.19.74.0/24") + vpc_id = vpc["Vpc"]["VpcId"] + subnet = ec2.create_subnet(CidrBlock="10.19.74.0/24", VpcId=vpc_id) + subnet_id = subnet["Subnet"]["SubnetId"] + amis = ec2.describe_images( + Filters=[ + {"Name": "Name", "Values": ["amzn-ami-hvm-2017.09.1.20171103-x86_64-gp2"]} + ] + ) + ami = amis["Images"][0] + ami_id = ami["ImageId"] + instances = ec2.run_instances( + ImageId=ami_id, + SubnetId=subnet_id, + MaxCount=1, + MinCount=1, + TagSpecifications=[ + {"ResourceType": "instance", "Tags": [{"Key": "Name", "Value": "Linux"}]} + ], + ) + instance_id = instances["Instances"][0]["InstanceId"] + ec2.stop_instances(InstanceIds=[instance_id]) + + # Mock the PostgreSQL database connection + mock_connection = MagicMock( + name="Mock PostgreSQL connection", spec_set=psycopg.Connection + ) + mock_cursor = MagicMock(name="Mock PostgreSQL cursor", spec_set=psycopg.Cursor) + mock_cursor.__enter__.return_value = mock_cursor + mock_connection.__enter__.return_value = mock_connection + mock_connection.cursor.return_value = mock_cursor + + with patch.object( + sys, + "argv", + [ + "--log-level=debug", + "--oneshot", + "--postgres-password=dummy_db_password", + "--postgres-username=dummy_db_username", + "--private-ssh-key=dummy_key", + "--rdp-password=dummy_rdp_password", + "--rdp-username=dummy_rdp_username", + "--vnc-password=dummy_vnc_password", + "--vnc-username=dummy_vnc_username", + f"--vpc-id={vpc_id}", + ], + ): + with patch.object( + psycopg, "connect", return_value=mock_connection + ) as mock_connect: + 
guacscanner.guacscanner.main() + mock_connect.assert_called_once() + mock_connection.cursor.assert_called() + mock_connection.commit.assert_called() + + +@mock_ec2 +def test_new_windows_instance(): + """Verify that adding a new Windows instance works as expected.""" + # Create and populate a VPC with an EC2 instance + ec2 = boto3.client("ec2", "us-east-1") + vpc = ec2.create_vpc(CidrBlock="10.19.74.0/24") + vpc_id = vpc["Vpc"]["VpcId"] + subnet = ec2.create_subnet(CidrBlock="10.19.74.0/24", VpcId=vpc_id) + subnet_id = subnet["Subnet"]["SubnetId"] + amis = ec2.describe_images( + Filters=[ + { + "Name": "Name", + "Values": [ + "Windows_Server-2016-English-Full-SQL_2017_Enterprise-2017.10.13" + ], + } + ] + ) + ami = amis["Images"][0] + ami_id = ami["ImageId"] + ec2.run_instances( + ImageId=ami_id, + SubnetId=subnet_id, + MaxCount=1, + MinCount=1, + TagSpecifications=[ + {"ResourceType": "instance", "Tags": [{"Key": "Name", "Value": "Windows"}]} + ], + ) + + # Mock the PostgreSQL database connection + mock_connection = MagicMock( + name="Mock PostgreSQL connection", spec_set=psycopg.Connection + ) + mock_cursor = MagicMock(name="Mock PostgreSQL cursor", spec_set=psycopg.Cursor) + mock_cursor.__enter__.return_value = mock_cursor + mock_cursor.fetchone.side_effect = [ + # Checking to see if guacuser exists and then adding it + {"count": 0}, + {"entity_id": 1}, + # Checking to see if the connection exists and then adding it + {"count": 0}, + {"connection_id": 1}, + ] + mock_connection.__enter__.return_value = mock_connection + mock_connection.cursor.return_value = mock_cursor + + with patch.object( + sys, + "argv", + [ + "--log-level=debug", + "--oneshot", + "--postgres-password=dummy_db_password", + "--postgres-username=dummy_db_username", + "--private-ssh-key=dummy_key", + "--rdp-password=dummy_rdp_password", + "--rdp-username=dummy_rdp_username", + "--vnc-password=dummy_vnc_password", + "--vnc-username=dummy_vnc_username", + f"--vpc-id={vpc_id}", + ], + ): + with 
patch.object( + psycopg, "connect", return_value=mock_connection + ) as mock_connect: + guacscanner.guacscanner.main() + mock_connect.assert_called_once() + mock_connection.cursor.assert_called() + mock_connection.commit.assert_called() + mock_cursor.fetchone.assert_called() + mock_cursor.execute.assert_called() + mock_cursor.executemany.assert_called()