From 3a3f15624c60f3cc19d96f71497b55885517a463 Mon Sep 17 00:00:00 2001 From: Jordan Woods <13803242+jorwoods@users.noreply.github.com> Date: Fri, 28 Jun 2024 21:48:14 -0500 Subject: [PATCH] feat: enable bulk add and remove users --- .../server/endpoint/groups_endpoint.py | 55 ++++++++++++++----- tableauserverclient/server/request_factory.py | 24 +++++++- test/assets/group_add_users.xml | 8 +++ test/test_group.py | 49 +++++++++++++++++ 4 files changed, 121 insertions(+), 15 deletions(-) create mode 100644 test/assets/group_add_users.xml diff --git a/tableauserverclient/server/endpoint/groups_endpoint.py b/tableauserverclient/server/endpoint/groups_endpoint.py index 2ee9fe0ab..8c1fe02a7 100644 --- a/tableauserverclient/server/endpoint/groups_endpoint.py +++ b/tableauserverclient/server/endpoint/groups_endpoint.py @@ -1,17 +1,17 @@ import logging -from .endpoint import QuerysetEndpoint, api -from .exceptions import MissingRequiredFieldError +from tableauserverclient.server.endpoint.endpoint import QuerysetEndpoint, api +from tableauserverclient.server.endpoint.exceptions import MissingRequiredFieldError from tableauserverclient.server import RequestFactory from tableauserverclient.models import GroupItem, UserItem, PaginationItem, JobItem -from ..pager import Pager +from tableauserverclient.server.pager import Pager from tableauserverclient.helpers.logging import logger -from typing import List, Optional, TYPE_CHECKING, Tuple, Union +from typing import Iterable, List, Optional, TYPE_CHECKING, Tuple, Union if TYPE_CHECKING: - from ..request_options import RequestOptions + from tableauserverclient.server.request_options import RequestOptions class Groups(QuerysetEndpoint[GroupItem]): @@ -19,9 +19,9 @@ class Groups(QuerysetEndpoint[GroupItem]): def baseurl(self) -> str: return "{0}/sites/{1}/groups".format(self.parent_srv.baseurl, self.parent_srv.site_id) - # Gets all groups @api(version="2.0") def get(self, req_options: Optional["RequestOptions"] = None) -> Tuple[List[GroupItem], PaginationItem]: + """Gets all groups""" logger.info("Querying all groups on site") url = self.baseurl server_response = self.get_request(url, req_options) @@ -29,9 +29,9 @@ def get(self, req_options: Optional["RequestOptions"] = None) -> Tuple[List[Grou all_group_items = GroupItem.from_response(server_response.content, self.parent_srv.namespace) return all_group_items, pagination_item - # Gets all users in a given group @api(version="2.0") - def populate_users(self, group_item, req_options: Optional["RequestOptions"] = None) -> None: + def populate_users(self, group_item: GroupItem, req_options: Optional["RequestOptions"] = None) -> None: + """Gets all users in a given group""" if not group_item.id: error = "Group item missing ID. Group must be retrieved from server first." raise MissingRequiredFieldError(error) @@ -47,7 +47,7 @@ def user_pager(): group_item._set_users(user_pager) def _get_users_for_group( - self, group_item, req_options: Optional["RequestOptions"] = None + self, group_item: GroupItem, req_options: Optional["RequestOptions"] = None ) -> Tuple[List[UserItem], PaginationItem]: url = "{0}/{1}/users".format(self.baseurl, group_item.id) server_response = self.get_request(url, req_options) @@ -56,9 +56,9 @@ def _get_users_for_group( logger.info("Populated users for group (ID: {0})".format(group_item.id)) return user_item, pagination_item - # Deletes 1 group by id @api(version="2.0") def delete(self, group_id: str) -> None: + """Deletes 1 group by id""" if not group_id: error = "Group ID undefined." raise ValueError(error) @@ -87,17 +87,17 @@ def update(self, group_item: GroupItem, as_job: bool = False) -> Union[GroupItem else: return GroupItem.from_response(server_response.content, self.parent_srv.namespace)[0] - # Create a 'local' Tableau group @api(version="2.0") def create(self, group_item: GroupItem) -> GroupItem: + """Create a 'local' Tableau group""" url = self.baseurl create_req = RequestFactory.Group.create_local_req(group_item) server_response = self.post_request(url, create_req) return GroupItem.from_response(server_response.content, self.parent_srv.namespace)[0] - # Create a group based on Active Directory @api(version="2.0") def create_AD_group(self, group_item: GroupItem, asJob: bool = False) -> Union[GroupItem, JobItem]: + """Create a group based on Active Directory""" asJobparameter = "?asJob=true" if asJob else "" url = self.baseurl + asJobparameter create_req = RequestFactory.Group.create_ad_req(group_item) @@ -107,9 +107,9 @@ def create_AD_group(self, group_item: GroupItem, asJob: bool = False) -> Union[G else: return GroupItem.from_response(server_response.content, self.parent_srv.namespace)[0] - # Removes 1 user from 1 group @api(version="2.0") def remove_user(self, group_item: GroupItem, user_id: str) -> None: + """Removes 1 user from 1 group""" if not group_item.id: error = "Group item missing ID." raise MissingRequiredFieldError(error) @@ -120,9 +120,22 @@ def remove_user(self, group_item: GroupItem, user_id: str) -> None: self.delete_request(url) logger.info("Removed user (id: {0}) from group (ID: {1})".format(user_id, group_item.id)) - # Adds 1 user to 1 group + @api(version="3.21") + def remove_users(self, group_item: GroupItem, users: Iterable[Union[str, UserItem]]) -> None: + """Removes multiple users from 1 group""" + group_id = group_item.id if hasattr(group_item, "id") else group_item + if not isinstance(group_id, str): + raise ValueError(f"Invalid group provided: {group_item}") + + url = f"{self.baseurl}/{group_id}/users/remove" + add_req = RequestFactory.Group.remove_users_req(users) + _ = self.put_request(url, add_req) + logger.info("Removed users to group (ID: {0})".format(group_item.id)) + return None + @api(version="2.0") def add_user(self, group_item: GroupItem, user_id: str) -> UserItem: + """Adds 1 user to 1 group""" if not group_item.id: error = "Group item missing ID." raise MissingRequiredFieldError(error) @@ -135,3 +148,17 @@ def add_user(self, group_item: GroupItem, user_id: str) -> UserItem: user = UserItem.from_response(server_response.content, self.parent_srv.namespace).pop() logger.info("Added user (id: {0}) to group (ID: {1})".format(user_id, group_item.id)) return user + + @api(version="3.21") + def add_users(self, group_item: GroupItem, users: Iterable[Union[str, UserItem]]) -> List[UserItem]: + """Adds multiple users to 1 group""" + group_id = group_item.id if hasattr(group_item, "id") else group_item + if not isinstance(group_id, str): + raise ValueError(f"Invalid group provided: {group_item}") + + url = f"{self.baseurl}/{group_id}/users" + add_req = RequestFactory.Group.add_users_req(users) + server_response = self.post_request(url, add_req) + users = UserItem.from_response(server_response.content, self.parent_srv.namespace) + logger.info("Added users to group (ID: {0})".format(group_item.id)) + return users diff --git a/tableauserverclient/server/request_factory.py b/tableauserverclient/server/request_factory.py index 87438ecde..7bf2118fd 100644 --- a/tableauserverclient/server/request_factory.py +++ b/tableauserverclient/server/request_factory.py @@ -1,5 +1,5 @@ import xml.etree.ElementTree as ET -from typing import Any, Dict, Iterable, List, Optional, Tuple, TYPE_CHECKING +from typing import Any, Dict, Iterable, List, Optional, Tuple, TYPE_CHECKING, Union from requests.packages.urllib3.fields import RequestField from requests.packages.urllib3.filepost import encode_multipart_formdata @@ -387,6 +387,28 @@ def add_user_req(self, user_id: str) -> bytes: user_element.attrib["id"] = user_id return ET.tostring(xml_request) + @_tsrequest_wrapped + def add_users_req(self, xml_request, users: Iterable[Union[str, UserItem]]) -> bytes: + users_element = ET.SubElement(xml_request, "users") + for user in users: + user_element = ET.SubElement(users_element, "user") + if not (user_id := user.id if isinstance(user, UserItem) else user): + raise ValueError("User ID must be populated") + user_element.attrib["id"] = user_id + + return ET.tostring(xml_request) + + @_tsrequest_wrapped + def remove_users_req(self, xml_request, users: Iterable[Union[str, UserItem]]) -> bytes: + users_element = ET.SubElement(xml_request, "users") + for user in users: + user_element = ET.SubElement(users_element, "user") + if not (user_id := user.id if isinstance(user, UserItem) else user): + raise ValueError("User ID must be populated") + user_element.attrib["id"] = user_id + + return ET.tostring(xml_request) + def create_local_req(self, group_item: GroupItem) -> bytes: xml_request = ET.Element("tsRequest") group_element = ET.SubElement(xml_request, "group") diff --git a/test/assets/group_add_users.xml b/test/assets/group_add_users.xml new file mode 100644 index 000000000..23fd7bd9f --- /dev/null +++ b/test/assets/group_add_users.xml @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/test/test_group.py b/test/test_group.py index 1edc50555..fc9c75a6d 100644 --- a/test/test_group.py +++ b/test/test_group.py @@ -14,6 +14,7 @@ POPULATE_USERS = os.path.join(TEST_ASSET_DIR, "group_populate_users.xml") POPULATE_USERS_EMPTY = os.path.join(TEST_ASSET_DIR, "group_populate_users_empty.xml") ADD_USER = os.path.join(TEST_ASSET_DIR, "group_add_user.xml") +ADD_USERS = TEST_ASSET_DIR / "group_add_users.xml" ADD_USER_POPULATE = os.path.join(TEST_ASSET_DIR, "group_users_added.xml") CREATE_GROUP = os.path.join(TEST_ASSET_DIR, "group_create.xml") CREATE_GROUP_AD = os.path.join(TEST_ASSET_DIR, "group_create_ad.xml") @@ -123,6 +124,54 @@ def test_add_user(self) -> None: self.assertEqual("testuser", user.name) self.assertEqual("ServerAdministrator", user.site_role) + def test_add_users(self) -> None: + self.server.version = "3.21" + self.baseurl = self.server.groups.baseurl + + def make_user(id: str, name: str, siteRole: str) -> TSC.UserItem: + user = TSC.UserItem(name, siteRole) + user._id = id + return user + + users = [ + make_user(id="5de011f8-4aa9-4d5b-b991-f464c8dd6bb7", name="Alice", siteRole="ServerAdministrator"), + make_user(id="5de011f8-3aa9-4d5b-b991-f467c8dd6bb8", name="Bob", siteRole="Explorer"), + make_user(id="5de011f8-2aa9-4d5b-b991-f466c8dd6bb8", name="Charlie", siteRole="Viewer"), + ] + group = TSC.GroupItem("test") + group._id = "e7833b48-c6f7-47b5-a2a7-36e7dd232758" + + with requests_mock.mock() as m: + m.post(f"{self.baseurl}/{group.id}/users", text=ADD_USERS.read_text()) + resp_users = self.server.groups.add_users(group, users) + + for user, resp_user in zip(users, resp_users): + with self.subTest(user=user, resp_user=resp_user): + assert user.id == resp_user.id + assert user.name == resp_user.name + assert user.site_role == resp_user.site_role + + def test_remove_users(self) -> None: + self.server.version = "3.21" + self.baseurl = self.server.groups.baseurl + + def make_user(id: str, name: str, siteRole: str) -> TSC.UserItem: + user = TSC.UserItem(name, siteRole) + user._id = id + return user + + users = [ + make_user(id="5de011f8-4aa9-4d5b-b991-f464c8dd6bb7", name="Alice", siteRole="ServerAdministrator"), + make_user(id="5de011f8-3aa9-4d5b-b991-f467c8dd6bb8", name="Bob", siteRole="Explorer"), + make_user(id="5de011f8-2aa9-4d5b-b991-f466c8dd6bb8", name="Charlie", siteRole="Viewer"), + ] + group = TSC.GroupItem("test") + group._id = "e7833b48-c6f7-47b5-a2a7-36e7dd232758" + + with requests_mock.mock() as m: + m.put(f"{self.baseurl}/{group.id}/users/remove") + self.server.groups.remove_users(group, users) + def test_add_user_before_populating(self) -> None: with open(GET_XML, "rb") as f: get_xml_response = f.read().decode("utf-8")