| 
1 |  | -import tempfile  | 
2 | 1 | import os  | 
3 |  | -import boto3  | 
4 |  | -import uuid  | 
 | 2 | +import random  | 
5 | 3 | import subprocess  | 
6 |  | -from typing import Dict  | 
 | 4 | +import tempfile  | 
 | 5 | +import uuid  | 
 | 6 | +from typing import Dict, Generator  | 
 | 7 | + | 
 | 8 | +import boto3  | 
 | 9 | +import pytest  | 
7 | 10 | 
 
  | 
8 | 11 | THIS_PATH = os.path.abspath(os.path.dirname(__file__))  | 
9 | 12 | ROOT_PATH = os.path.join(THIS_PATH, "..")  | 
10 | 13 | TFLOCAL_BIN = os.path.join(ROOT_PATH, "bin", "tflocal")  | 
11 | 14 | LOCALSTACK_ENDPOINT = "http://localhost:4566"  | 
12 | 15 | 
 
  | 
13 | 16 | 
 
  | 
 | 17 | +@pytest.mark.parametrize("customize_access_key", [True, False])  | 
 | 18 | +def test_customize_access_key_feature_flag(monkeypatch, customize_access_key: bool):  | 
 | 19 | +    monkeypatch.setenv("CUSTOMIZE_ACCESS_KEY", str(customize_access_key))  | 
 | 20 | + | 
 | 21 | +    # create buckets in multiple accounts  | 
 | 22 | +    access_key = mock_access_key()  | 
 | 23 | +    monkeypatch.setenv("AWS_ACCESS_KEY_ID", access_key)  | 
 | 24 | +    bucket_name = short_uid()  | 
 | 25 | + | 
 | 26 | +    create_test_bucket(bucket_name)  | 
 | 27 | + | 
 | 28 | +    s3_bucket_names_default_account = get_bucket_names()  | 
 | 29 | +    s3_bucket_names_specific_account = get_bucket_names(aws_access_key_id=access_key)  | 
 | 30 | + | 
 | 31 | +    if customize_access_key:  | 
 | 32 | +        # if CUSTOMISE_ACCESS_KEY is enabled, the bucket name is only in the specific account  | 
 | 33 | +        assert bucket_name not in s3_bucket_names_default_account  | 
 | 34 | +        assert bucket_name in s3_bucket_names_specific_account  | 
 | 35 | +    else:  | 
 | 36 | +        # if CUSTOMISE_ACCESS_KEY is disabled, the bucket name is only in the default account  | 
 | 37 | +        assert bucket_name in s3_bucket_names_default_account  | 
 | 38 | +        assert bucket_name not in s3_bucket_names_specific_account  | 
 | 39 | + | 
 | 40 | + | 
 | 41 | +def _profile_names() -> Generator:  | 
 | 42 | +    yield short_uid()  | 
 | 43 | +    yield "default"  | 
 | 44 | + | 
 | 45 | + | 
 | 46 | +def _generate_test_name(param: str) -> str:  | 
 | 47 | +    return "random" if param != "default" else param  | 
 | 48 | + | 
 | 49 | + | 
 | 50 | +@pytest.mark.parametrize("profile_name", _profile_names(), ids=_generate_test_name)  | 
 | 51 | +def test_access_key_override_by_profile(monkeypatch, profile_name: str):  | 
 | 52 | +    monkeypatch.setenv("CUSTOMIZE_ACCESS_KEY", "1")  | 
 | 53 | +    access_key = mock_access_key()  | 
 | 54 | +    bucket_name = short_uid()  | 
 | 55 | +    credentials = """  | 
 | 56 | +    [%s]  | 
 | 57 | +    aws_access_key_id = %s  | 
 | 58 | +    aws_secret_access_key = test  | 
 | 59 | +    region = eu-west-1  | 
 | 60 | +    """ % (profile_name, access_key)  | 
 | 61 | +    with tempfile.TemporaryDirectory() as temp_dir:  | 
 | 62 | +        credentials_file = os.path.join(temp_dir, "credentials")  | 
 | 63 | +        with open(credentials_file, "w") as f:  | 
 | 64 | +            f.write(credentials)  | 
 | 65 | + | 
 | 66 | +        if profile_name != "default":  | 
 | 67 | +            monkeypatch.setenv("AWS_PROFILE", profile_name)  | 
 | 68 | +        monkeypatch.setenv("AWS_SHARED_CREDENTIALS_FILE", credentials_file)  | 
 | 69 | + | 
 | 70 | +        create_test_bucket(bucket_name)  | 
 | 71 | + | 
 | 72 | +        extra_param = {"aws_access_key_id": None, "aws_secret_access_key": None} if profile_name == "default" else {}  | 
 | 73 | +        s3_bucket_names_specific_profile = get_bucket_names(**extra_param)  | 
 | 74 | + | 
 | 75 | +        monkeypatch.delenv("AWS_PROFILE", raising=False)  | 
 | 76 | + | 
 | 77 | +        s3_bucket_names_default_account = get_bucket_names()  | 
 | 78 | + | 
 | 79 | +        assert bucket_name in s3_bucket_names_specific_profile  | 
 | 80 | +        assert bucket_name not in s3_bucket_names_default_account  | 
 | 81 | + | 
 | 82 | + | 
 | 83 | +def test_access_key_override_by_provider(monkeypatch):  | 
 | 84 | +    monkeypatch.setenv("CUSTOMIZE_ACCESS_KEY", "1")  | 
 | 85 | +    access_key = mock_access_key()  | 
 | 86 | +    bucket_name = short_uid()  | 
 | 87 | +    create_test_bucket(bucket_name, access_key)  | 
 | 88 | + | 
 | 89 | +    s3_bucket_names_default_account = get_bucket_names()  | 
 | 90 | +    s3_bucket_names_specific_account = get_bucket_names(aws_access_key_id=access_key)  | 
 | 91 | + | 
 | 92 | +    assert bucket_name not in s3_bucket_names_default_account  | 
 | 93 | +    assert bucket_name in s3_bucket_names_specific_account  | 
 | 94 | + | 
 | 95 | + | 
14 | 96 | def test_s3_path_addressing():  | 
15 | 97 |     bucket_name = f"bucket.{short_uid()}"  | 
16 | 98 |     config = """  | 
@@ -45,7 +127,7 @@ def test_use_s3_path_style(monkeypatch):  | 
45 | 127 |     assert not use_s3_path_style()  # noqa  | 
46 | 128 | 
 
  | 
47 | 129 | 
 
  | 
48 |  | -def test_provider_aliases(monkeypatch):  | 
 | 130 | +def test_provider_aliases():  | 
49 | 131 |     queue_name1 = f"q{short_uid()}"  | 
50 | 132 |     queue_name2 = f"q{short_uid()}"  | 
51 | 133 |     config = """  | 
@@ -127,15 +209,43 @@ def deploy_tf_script(script: str, env_vars: Dict[str, str] = None):  | 
127 | 209 |         return out  | 
128 | 210 | 
 
  | 
129 | 211 | 
 
  | 
 | 212 | +def get_bucket_names(**kwargs: dict) -> list:  | 
 | 213 | +    s3 = client("s3", region_name="eu-west-1", **kwargs)  | 
 | 214 | +    s3_buckets = s3.list_buckets().get("Buckets")  | 
 | 215 | +    return [s["Name"] for s in s3_buckets]  | 
 | 216 | + | 
 | 217 | + | 
 | 218 | +def create_test_bucket(bucket_name: str, access_key: str = None) -> None:  | 
 | 219 | +    access_key_section = f'access_key = "{access_key}"' if access_key else ""  | 
 | 220 | +    config = """  | 
 | 221 | +    provider "aws" {  | 
 | 222 | +      %s  | 
 | 223 | +      region = "eu-west-1"  | 
 | 224 | +    }  | 
 | 225 | +    resource "aws_s3_bucket" "test_bucket" {  | 
 | 226 | +      bucket = "%s"  | 
 | 227 | +    }""" % (access_key_section, bucket_name)  | 
 | 228 | +    deploy_tf_script(config)  | 
 | 229 | + | 
 | 230 | + | 
130 | 231 | def short_uid() -> str:  | 
131 | 232 |     return str(uuid.uuid4())[0:8]  | 
132 | 233 | 
 
  | 
133 | 234 | 
 
  | 
 | 235 | +def mock_access_key() -> str:  | 
 | 236 | +    return str(random.randrange(999999999999)).zfill(12)  | 
 | 237 | + | 
 | 238 | + | 
134 | 239 | def client(service: str, **kwargs):  | 
 | 240 | +    # if aws access key is not set AND no profile is in the environment,  | 
 | 241 | +    # we want to set the accesss key and the secret key to test  | 
 | 242 | +    if "aws_access_key_id" not in kwargs and "AWS_PROFILE" not in os.environ:  | 
 | 243 | +        kwargs["aws_access_key_id"] = "test"  | 
 | 244 | +    if "aws_access_key_id" in kwargs and "aws_secret_access_key" not in kwargs:  | 
 | 245 | +        kwargs["aws_secret_access_key"] = "test"  | 
 | 246 | +    boto3.setup_default_session()  | 
135 | 247 |     return boto3.client(  | 
136 | 248 |         service,  | 
137 |  | -        aws_access_key_id="test",  | 
138 |  | -        aws_secret_access_key="test",  | 
139 | 249 |         endpoint_url=LOCALSTACK_ENDPOINT,  | 
140 | 250 |         **kwargs,  | 
141 | 251 |     )  | 
 | 
0 commit comments