diff --git a/.gitignore b/.gitignore index cdfa6ebf..b735760c 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,4 @@ flake8.report junit*.xml doc/build .cache +.idea/ diff --git a/test/inventory.yml b/test/inventory.yml new file mode 100644 index 00000000..3dc844ae --- /dev/null +++ b/test/inventory.yml @@ -0,0 +1,33 @@ +all: + hosts: + s1: + ansible_host: 192.168.1.1 + s2: + ansible_host: 192.168.1.2 + s3: + ansible_host: 192.168.1.3 + s4: + ansible_host: 192.168.1.4 + s5: + ansible_host: 192.168.1.5 + s6: + ansible_host: 192.168.1.6 + s7: + ansible_host: 192.168.1.7 + s8: + ansible_host: 192.168.1.8 + s9: + ansible_host: 192.168.1.9 + + children: + servers: + hosts: + s1: + s2: + s3: + s4: + s5: + s6: + s7: + s8: + s9: diff --git a/test/test_backends.py b/test/test_backends.py index c3b717d5..30325ab5 100644 --- a/test/test_backends.py +++ b/test/test_backends.py @@ -523,6 +523,14 @@ def test_docker_encoding(host): def test_parse_hostspec(hostspec, expected): assert BaseBackend.parse_hostspec(hostspec) == expected +@pytest.mark.parametrize( + "hostspec,expected", + [ + ("ansible://host1", ('host1', {'connection': 'ansible'})), + ], +) +def test_init_parse_hostspec(hostspec, expected): + assert testinfra.backend.parse_hostspec(hostspec) == expected @pytest.mark.parametrize( "hostspec,pod,container,namespace,kubeconfig,context", @@ -642,6 +650,28 @@ def test_get_hosts(): ] +def test_get_hosts_ansible_limit(): + # Hosts returned by get_host must be deduplicated (by name & kwargs) and in + # same order as asked + hosts = testinfra.backend.get_backends( + [ + "ansible://s%5B1-4%5D%2A?ansible_inventory=inventory.yml" # s%5B1-4%5D%2A == s[1-4]* + ] + ) + assert [h.hostname for h in hosts] == ["s1", "s2", "s3", "s4"] + +def test_get_hosts_ansible_limit_from_kwargs(): + # Hosts returned by get_host must be deduplicated (by name & kwargs) and in + # same order as asked + hosts = testinfra.backend.get_backends( + [ + "ansible://all" + ], + ansible_inventory="inventory.yml", + ansible_limit="s[5-8]*" + ) + assert [h.hostname for h in hosts] == ["s5", "s6", "s7", "s8"] + @pytest.mark.testinfra_hosts(*HOSTS) def test_command_deadlock(host): # Test for deadlock when exceeding Paramiko transport buffer (2MB) diff --git a/testinfra/backend/__init__.py b/testinfra/backend/__init__.py index e8bdf91f..625ed26a 100644 --- a/testinfra/backend/__init__.py +++ b/testinfra/backend/__init__.py @@ -96,6 +96,7 @@ def get_backends( backends = {} for hostspec in hosts: host, kw = parse_hostspec(hostspec) + host = urllib.parse.unquote(host) for k, v in kwargs.items(): kw.setdefault(k, v) connection = kw.get("connection") diff --git a/testinfra/backend/ansible.py b/testinfra/backend/ansible.py index 61bf25f8..5c95685b 100644 --- a/testinfra/backend/ansible.py +++ b/testinfra/backend/ansible.py @@ -17,6 +17,9 @@ from testinfra.backend import base from testinfra.utils.ansible_runner import AnsibleRunner +from ansible.parsing.dataloader import DataLoader +from ansible.inventory.manager import InventoryManager + logger = logging.getLogger("testinfra") @@ -87,4 +90,11 @@ def get_variables(self) -> dict[str, Any]: @classmethod def get_hosts(cls, host: str, **kwargs: Any) -> list[str]: inventory = kwargs.get("ansible_inventory") - return AnsibleRunner.get_runner(inventory).get_hosts(host or "all") + hosts = AnsibleRunner.get_runner(inventory).get_hosts(host or "all") + limit = kwargs.get("ansible_limit") + if limit: + loader = DataLoader() + inventory_manager = InventoryManager(loader=loader, sources=inventory) + return list(map(lambda h: h.address, inventory_manager.get_hosts(pattern=limit))) + else: + return hosts diff --git a/testinfra/plugin.py b/testinfra/plugin.py index db3541b0..c43157d3 100644 --- a/testinfra/plugin.py +++ b/testinfra/plugin.py @@ -107,6 +107,12 @@ def pytest_addoption(parser: pytest.Parser) -> None: dest="nagios", help="Nagios plugin", ) + group.addoption( + "--ansible-limit", + action="store_true", + dest="ansible_limit", + help="Limit to specific hosts using the same syntax as Ansible's --limit option.", + ) def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: @@ -126,6 +132,7 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: sudo_user=metafunc.config.option.sudo_user, ansible_inventory=metafunc.config.option.ansible_inventory, force_ansible=metafunc.config.option.force_ansible, + ansible_limit=metafunc.config.option.ansible_limit, ) params = sorted(params, key=lambda x: x.backend.get_pytest_id()) ids = [e.backend.get_pytest_id() for e in params]