Skip to content

Commit

Permalink
cleanup connect helpers and update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
BerndSchuller committed Feb 23, 2024
1 parent 7b66543 commit 8243110
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 34 deletions.
26 changes: 11 additions & 15 deletions pyunicore/helpers/connection/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,30 @@


def connect_to_registry(
registry_url: str, credentials: credentials.Credential
registry_url: str, credential: credentials.Credential
) -> pyunicore.client.Registry:
"""Connect to a registry.
Args:
registry_url (str): URL to the UNICORE registry.
credentials (pyunicore.credentials.Credential): Authentication method.
credential (pyunicore.credentials.Credential): Authentication method.
Returns:
pyunicore.client.Registry
"""
transport = pyunicore.client.Transport(credentials)
return pyunicore.client.Registry(transport=transport, url=registry_url)
return pyunicore.client.Registry(credential, url=registry_url)


def connect_to_site_from_registry(
registry_url: str, site_name: str, credentials: credentials.Credential
registry_url: str, site_name: str, credential: credentials.Credential
) -> pyunicore.client.Client:
"""Create a connection to a site's UNICORE API from the registry base URL.
Args:
registry_url (str): URL to the UNICORE registry.
site_name (str): Name of the site to connect to.
credentials (pyunicore.credentials.Credential): Authentication method.
credential (pyunicore.credentials.Credential): Authentication method.
Raises:
ValueError: Site not available in the registry.
Expand All @@ -39,25 +38,22 @@ def connect_to_site_from_registry(
pyunicore.client.Client
"""
transport = pyunicore.client.Transport(credentials)
site_api_url = _get_site_api_url(
site=site_name,
registry_url=registry_url,
transport=transport,
site=site_name, credential=credential, registry_url=registry_url
)
client = _site.connect_to_site(
site_api_url=site_api_url,
credentials=credentials,
credential=credential,
)
return client


def _get_site_api_url(
site: str,
transport: pyunicore.client.Transport,
credential: credentials.Credential,
registry_url: str,
) -> str:
api_urls = _get_api_urls(transport=transport, registry_url=registry_url)
api_urls = _get_api_urls(credential, registry_url=registry_url)
try:
api_url = api_urls[site]
except KeyError:
Expand All @@ -70,6 +66,6 @@ def _get_site_api_url(
return api_url


def _get_api_urls(transport: pyunicore.client.Transport, registry_url: str) -> Dict[str, str]:
registry = pyunicore.client.Registry(transport=transport, url=registry_url)
def _get_api_urls(credential: credentials.Credential, registry_url: str) -> Dict[str, str]:
registry = pyunicore.client.Registry(credential, url=registry_url)
return registry.site_urls
13 changes: 6 additions & 7 deletions pyunicore/helpers/connection/site.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


def connect_to_site(
site_api_url: str, credentials: _credentials.Credential
site_api_url: str, credential: _credentials.Credential
) -> pyunicore.client.Client:
"""Create a connection to a site's UNICORE API.
Expand All @@ -20,7 +20,7 @@ def connect_to_site(
"""
client = _connect_to_site(
api_url=site_api_url,
credentials=credentials,
credential=credential,
)
if _authentication_failed(client):
raise _credentials.AuthenticationFailedException(
Expand All @@ -30,14 +30,13 @@ def connect_to_site(
return client


def _connect_to_site(api_url: str, credentials: _credentials.Credential) -> pyunicore.client.Client:
transport = pyunicore.client.Transport(credentials)
client = _create_client(transport=transport, api_url=api_url)
def _connect_to_site(api_url: str, credential: _credentials.Credential) -> pyunicore.client.Client:
client = _create_client(credential=credential, api_url=api_url)
return client


def _create_client(transport: pyunicore.client.Transport, api_url: str) -> pyunicore.client.Client:
return pyunicore.client.Client(transport=transport, site_url=api_url)
def _create_client(credential: _credentials.Credential, api_url: str) -> pyunicore.client.Client:
return pyunicore.client.Client(credential, site_url=api_url)


def _authentication_failed(client: pyunicore.client.Client) -> bool:
Expand Down
12 changes: 6 additions & 6 deletions tests/unit/helpers/connection/test_registry_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@


@pytest.fixture()
def transport():
return testing.FakeTransport()
def credential():
return credentials.UsernamePassword(username="test_user", password="test_password")


def create_fake_registry(contains: Dict[str, str]) -> functools.partial:
Expand Down Expand Up @@ -39,7 +39,7 @@ def test_connect_to_registry(monkeypatch):

result = _registry.connect_to_registry(
registry_url=registry_url,
credentials=creds,
credential=creds,
)

assert isinstance(result, pyunicore.client.Registry)
Expand Down Expand Up @@ -76,7 +76,7 @@ def test_connect_to_site_from_registry(monkeypatch, login_successful, expected):
result = _registry.connect_to_site_from_registry(
registry_url=registry_url,
site_name=site,
credentials=creds,
credential=creds,
)

assert isinstance(result, expected)
Expand All @@ -89,7 +89,7 @@ def test_connect_to_site_from_registry(monkeypatch, login_successful, expected):
("test_unavailable_site", ValueError()),
],
)
def test_get_site_api_url_from_registry(monkeypatch, transport, site, expected):
def test_get_site_api_url_from_registry(monkeypatch, credential, site, expected):
monkeypatch.setattr(
pyunicore.client,
"Registry",
Expand All @@ -99,7 +99,7 @@ def test_get_site_api_url_from_registry(monkeypatch, transport, site, expected):
with testing.expect_raise_if_exception(expected):
result = _registry._get_site_api_url(
site=site,
transport=transport,
credential=credential,
registry_url="test_registry_url",
)

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/helpers/connection/test_site.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_connect_to_site(monkeypatch, login_successful, expected):
with testing.expect_raise_if_exception(expected):
result = _connect.connect_to_site(
site_api_url=api_url,
credentials=creds,
credential=creds,
)

assert isinstance(result, expected)
Expand Down
5 changes: 0 additions & 5 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,6 @@ def test_transport(self):
credential = uc_credentials.UsernamePassword("demouser", "test123")
transport = Transport(credential)
self.assertEqual(header_val, transport._headers({})["Authorization"])
# old style
transport = Transport(token_str, oidc=False)
self.assertEqual(header_val, transport._headers({})["Authorization"])
transport2 = transport._clone()
self.assertEqual(header_val, transport2._headers({})["Authorization"])


class MockRefresh:
Expand Down

0 comments on commit 8243110

Please sign in to comment.