Skip to content

Commit 3a3f156

Browse files
committed
feat: enable bulk add and remove users
1 parent 7822be0 commit 3a3f156

File tree

4 files changed

+121
-15
lines changed

4 files changed

+121
-15
lines changed

Diff for: tableauserverclient/server/endpoint/groups_endpoint.py

+41-14
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,37 @@
11
import logging
22

3-
from .endpoint import QuerysetEndpoint, api
4-
from .exceptions import MissingRequiredFieldError
3+
from tableauserverclient.server.endpoint.endpoint import QuerysetEndpoint, api
4+
from tableauserverclient.server.endpoint.exceptions import MissingRequiredFieldError
55
from tableauserverclient.server import RequestFactory
66
from tableauserverclient.models import GroupItem, UserItem, PaginationItem, JobItem
7-
from ..pager import Pager
7+
from tableauserverclient.server.pager import Pager
88

99
from tableauserverclient.helpers.logging import logger
1010

11-
from typing import List, Optional, TYPE_CHECKING, Tuple, Union
11+
from typing import Iterable, List, Optional, TYPE_CHECKING, Tuple, Union
1212

1313
if TYPE_CHECKING:
14-
from ..request_options import RequestOptions
14+
from tableauserverclient.server.request_options import RequestOptions
1515

1616

1717
class Groups(QuerysetEndpoint[GroupItem]):
1818
@property
1919
def baseurl(self) -> str:
2020
return "{0}/sites/{1}/groups".format(self.parent_srv.baseurl, self.parent_srv.site_id)
2121

22-
# Gets all groups
2322
@api(version="2.0")
2423
def get(self, req_options: Optional["RequestOptions"] = None) -> Tuple[List[GroupItem], PaginationItem]:
24+
"""Gets all groups"""
2525
logger.info("Querying all groups on site")
2626
url = self.baseurl
2727
server_response = self.get_request(url, req_options)
2828
pagination_item = PaginationItem.from_response(server_response.content, self.parent_srv.namespace)
2929
all_group_items = GroupItem.from_response(server_response.content, self.parent_srv.namespace)
3030
return all_group_items, pagination_item
3131

32-
# Gets all users in a given group
3332
@api(version="2.0")
34-
def populate_users(self, group_item, req_options: Optional["RequestOptions"] = None) -> None:
33+
def populate_users(self, group_item: GroupItem, req_options: Optional["RequestOptions"] = None) -> None:
34+
"""Gets all users in a given group"""
3535
if not group_item.id:
3636
error = "Group item missing ID. Group must be retrieved from server first."
3737
raise MissingRequiredFieldError(error)
@@ -47,7 +47,7 @@ def user_pager():
4747
group_item._set_users(user_pager)
4848

4949
def _get_users_for_group(
50-
self, group_item, req_options: Optional["RequestOptions"] = None
50+
self, group_item: GroupItem, req_options: Optional["RequestOptions"] = None
5151
) -> Tuple[List[UserItem], PaginationItem]:
5252
url = "{0}/{1}/users".format(self.baseurl, group_item.id)
5353
server_response = self.get_request(url, req_options)
@@ -56,9 +56,9 @@ def _get_users_for_group(
5656
logger.info("Populated users for group (ID: {0})".format(group_item.id))
5757
return user_item, pagination_item
5858

59-
# Deletes 1 group by id
6059
@api(version="2.0")
6160
def delete(self, group_id: str) -> None:
61+
"""Deletes 1 group by id"""
6262
if not group_id:
6363
error = "Group ID undefined."
6464
raise ValueError(error)
@@ -87,17 +87,17 @@ def update(self, group_item: GroupItem, as_job: bool = False) -> Union[GroupItem
8787
else:
8888
return GroupItem.from_response(server_response.content, self.parent_srv.namespace)[0]
8989

90-
# Create a 'local' Tableau group
9190
@api(version="2.0")
9291
def create(self, group_item: GroupItem) -> GroupItem:
92+
"""Create a 'local' Tableau group"""
9393
url = self.baseurl
9494
create_req = RequestFactory.Group.create_local_req(group_item)
9595
server_response = self.post_request(url, create_req)
9696
return GroupItem.from_response(server_response.content, self.parent_srv.namespace)[0]
9797

98-
# Create a group based on Active Directory
9998
@api(version="2.0")
10099
def create_AD_group(self, group_item: GroupItem, asJob: bool = False) -> Union[GroupItem, JobItem]:
100+
"""Create a group based on Active Directory"""
101101
asJobparameter = "?asJob=true" if asJob else ""
102102
url = self.baseurl + asJobparameter
103103
create_req = RequestFactory.Group.create_ad_req(group_item)
@@ -107,9 +107,9 @@ def create_AD_group(self, group_item: GroupItem, asJob: bool = False) -> Union[G
107107
else:
108108
return GroupItem.from_response(server_response.content, self.parent_srv.namespace)[0]
109109

110-
# Removes 1 user from 1 group
111110
@api(version="2.0")
112111
def remove_user(self, group_item: GroupItem, user_id: str) -> None:
112+
"""Removes 1 user from 1 group"""
113113
if not group_item.id:
114114
error = "Group item missing ID."
115115
raise MissingRequiredFieldError(error)
@@ -120,9 +120,22 @@ def remove_user(self, group_item: GroupItem, user_id: str) -> None:
120120
self.delete_request(url)
121121
logger.info("Removed user (id: {0}) from group (ID: {1})".format(user_id, group_item.id))
122122

123-
# Adds 1 user to 1 group
123+
@api(version="3.21")
124+
def remove_users(self, group_item: GroupItem, users: Iterable[Union[str, UserItem]]) -> None:
125+
"""Removes multiple users from 1 group"""
126+
group_id = group_item.id if hasattr(group_item, "id") else group_item
127+
if not isinstance(group_id, str):
128+
raise ValueError(f"Invalid group provided: {group_item}")
129+
130+
url = f"{self.baseurl}/{group_id}/users/remove"
131+
add_req = RequestFactory.Group.remove_users_req(users)
132+
_ = self.put_request(url, add_req)
133+
logger.info("Removed users to group (ID: {0})".format(group_item.id))
134+
return None
135+
124136
@api(version="2.0")
125137
def add_user(self, group_item: GroupItem, user_id: str) -> UserItem:
138+
"""Adds 1 user to 1 group"""
126139
if not group_item.id:
127140
error = "Group item missing ID."
128141
raise MissingRequiredFieldError(error)
@@ -135,3 +148,17 @@ def add_user(self, group_item: GroupItem, user_id: str) -> UserItem:
135148
user = UserItem.from_response(server_response.content, self.parent_srv.namespace).pop()
136149
logger.info("Added user (id: {0}) to group (ID: {1})".format(user_id, group_item.id))
137150
return user
151+
152+
@api(version="3.21")
153+
def add_users(self, group_item: GroupItem, users: Iterable[Union[str, UserItem]]) -> List[UserItem]:
154+
"""Adds multiple users to 1 group"""
155+
group_id = group_item.id if hasattr(group_item, "id") else group_item
156+
if not isinstance(group_id, str):
157+
raise ValueError(f"Invalid group provided: {group_item}")
158+
159+
url = f"{self.baseurl}/{group_id}/users"
160+
add_req = RequestFactory.Group.add_users_req(users)
161+
server_response = self.post_request(url, add_req)
162+
users = UserItem.from_response(server_response.content, self.parent_srv.namespace)
163+
logger.info("Added users to group (ID: {0})".format(group_item.id))
164+
return users

Diff for: tableauserverclient/server/request_factory.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import xml.etree.ElementTree as ET
2-
from typing import Any, Dict, Iterable, List, Optional, Tuple, TYPE_CHECKING
2+
from typing import Any, Dict, Iterable, List, Optional, Tuple, TYPE_CHECKING, Union
33

44
from requests.packages.urllib3.fields import RequestField
55
from requests.packages.urllib3.filepost import encode_multipart_formdata
@@ -387,6 +387,28 @@ def add_user_req(self, user_id: str) -> bytes:
387387
user_element.attrib["id"] = user_id
388388
return ET.tostring(xml_request)
389389

390+
@_tsrequest_wrapped
391+
def add_users_req(self, xml_request, users: Iterable[Union[str, UserItem]]) -> bytes:
392+
users_element = ET.SubElement(xml_request, "users")
393+
for user in users:
394+
user_element = ET.SubElement(users_element, "user")
395+
if not (user_id := user.id if isinstance(user, UserItem) else user):
396+
raise ValueError("User ID must be populated")
397+
user_element.attrib["id"] = user_id
398+
399+
return ET.tostring(xml_request)
400+
401+
@_tsrequest_wrapped
402+
def remove_users_req(self, xml_request, users: Iterable[Union[str, UserItem]]) -> bytes:
403+
users_element = ET.SubElement(xml_request, "users")
404+
for user in users:
405+
user_element = ET.SubElement(users_element, "user")
406+
if not (user_id := user.id if isinstance(user, UserItem) else user):
407+
raise ValueError("User ID must be populated")
408+
user_element.attrib["id"] = user_id
409+
410+
return ET.tostring(xml_request)
411+
390412
def create_local_req(self, group_item: GroupItem) -> bytes:
391413
xml_request = ET.Element("tsRequest")
392414
group_element = ET.SubElement(xml_request, "group")

Diff for: test/assets/group_add_users.xml

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
<?xml version='1.0' encoding='UTF-8'?>
2+
<tsResponse xmlns="http://tableau.com/api" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://tableau.com/api http://tableau.com/api/ts-api-2.3.xsd">
3+
<users>
4+
<user id="5de011f8-4aa9-4d5b-b991-f464c8dd6bb7" name="Alice" siteRole="ServerAdministrator" />
5+
<user id="5de011f8-3aa9-4d5b-b991-f467c8dd6bb8" name="Bob" siteRole="Explorer" />
6+
<user id="5de011f8-2aa9-4d5b-b991-f466c8dd6bb8" name="Charlie" siteRole="Viewer" />
7+
</users>
8+
</tsResponse>

Diff for: test/test_group.py

+49
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
POPULATE_USERS = os.path.join(TEST_ASSET_DIR, "group_populate_users.xml")
1515
POPULATE_USERS_EMPTY = os.path.join(TEST_ASSET_DIR, "group_populate_users_empty.xml")
1616
ADD_USER = os.path.join(TEST_ASSET_DIR, "group_add_user.xml")
17+
ADD_USERS = TEST_ASSET_DIR / "group_add_users.xml"
1718
ADD_USER_POPULATE = os.path.join(TEST_ASSET_DIR, "group_users_added.xml")
1819
CREATE_GROUP = os.path.join(TEST_ASSET_DIR, "group_create.xml")
1920
CREATE_GROUP_AD = os.path.join(TEST_ASSET_DIR, "group_create_ad.xml")
@@ -123,6 +124,54 @@ def test_add_user(self) -> None:
123124
self.assertEqual("testuser", user.name)
124125
self.assertEqual("ServerAdministrator", user.site_role)
125126

127+
def test_add_users(self) -> None:
128+
self.server.version = "3.21"
129+
self.baseurl = self.server.groups.baseurl
130+
131+
def make_user(id: str, name: str, siteRole: str) -> TSC.UserItem:
132+
user = TSC.UserItem(name, siteRole)
133+
user._id = id
134+
return user
135+
136+
users = [
137+
make_user(id="5de011f8-4aa9-4d5b-b991-f464c8dd6bb7", name="Alice", siteRole="ServerAdministrator"),
138+
make_user(id="5de011f8-3aa9-4d5b-b991-f467c8dd6bb8", name="Bob", siteRole="Explorer"),
139+
make_user(id="5de011f8-2aa9-4d5b-b991-f466c8dd6bb8", name="Charlie", siteRole="Viewer"),
140+
]
141+
group = TSC.GroupItem("test")
142+
group._id = "e7833b48-c6f7-47b5-a2a7-36e7dd232758"
143+
144+
with requests_mock.mock() as m:
145+
m.post(f"{self.baseurl}/{group.id}/users", text=ADD_USERS.read_text())
146+
resp_users = self.server.groups.add_users(group, users)
147+
148+
for user, resp_user in zip(users, resp_users):
149+
with self.subTest(user=user, resp_user=resp_user):
150+
assert user.id == resp_user.id
151+
assert user.name == resp_user.name
152+
assert user.site_role == resp_user.site_role
153+
154+
def test_remove_users(self) -> None:
155+
self.server.version = "3.21"
156+
self.baseurl = self.server.groups.baseurl
157+
158+
def make_user(id: str, name: str, siteRole: str) -> TSC.UserItem:
159+
user = TSC.UserItem(name, siteRole)
160+
user._id = id
161+
return user
162+
163+
users = [
164+
make_user(id="5de011f8-4aa9-4d5b-b991-f464c8dd6bb7", name="Alice", siteRole="ServerAdministrator"),
165+
make_user(id="5de011f8-3aa9-4d5b-b991-f467c8dd6bb8", name="Bob", siteRole="Explorer"),
166+
make_user(id="5de011f8-2aa9-4d5b-b991-f466c8dd6bb8", name="Charlie", siteRole="Viewer"),
167+
]
168+
group = TSC.GroupItem("test")
169+
group._id = "e7833b48-c6f7-47b5-a2a7-36e7dd232758"
170+
171+
with requests_mock.mock() as m:
172+
m.put(f"{self.baseurl}/{group.id}/users/remove")
173+
self.server.groups.remove_users(group, users)
174+
126175
def test_add_user_before_populating(self) -> None:
127176
with open(GET_XML, "rb") as f:
128177
get_xml_response = f.read().decode("utf-8")

0 commit comments

Comments
 (0)