Skip to content

Commit 56ad27b

Browse files
author
XiaoHongbo
authored
[python] Make mock REST server behave closer to real REST server (#7575)
1 parent c08827b commit 56ad27b

File tree

2 files changed

+59
-6
lines changed

2 files changed

+59
-6
lines changed

paimon-python/pypaimon/tests/rest/rest_base_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,29 @@ def test_list_partitions_paged(self):
339339
)
340340
self.assertEqual(len(result.elements), 3)
341341

342+
def test_alter_database(self):
343+
"""Test alter_database sets and removes properties."""
344+
from pypaimon.catalog.rest.property_change import PropertyChange
345+
db_name = "alter_db_test"
346+
self.rest_catalog.create_database(db_name, True)
347+
348+
# set property
349+
self.rest_catalog.alter_database(
350+
db_name,
351+
[PropertyChange.set_property("key1", "value1"),
352+
PropertyChange.set_property("key2", "value2")])
353+
db = self.rest_catalog.get_database(db_name)
354+
self.assertEqual(db.options.get("key1"), "value1")
355+
self.assertEqual(db.options.get("key2"), "value2")
356+
357+
# remove property
358+
self.rest_catalog.alter_database(
359+
db_name,
360+
[PropertyChange.remove_property("key1")])
361+
db = self.rest_catalog.get_database(db_name)
362+
self.assertNotIn("key1", db.options)
363+
self.assertEqual(db.options.get("key2"), "value2")
364+
342365
def test_list_partitions_paged_empty(self):
343366
"""Test list_partitions_paged returns empty when no partitions."""
344367
identifier = Identifier.from_string('default.test_reader_iterator')

paimon-python/pypaimon/tests/rest/rest_server.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929
if TYPE_CHECKING:
3030
from pypaimon.catalog.rest.rest_token import RESTToken
3131

32-
from pypaimon.api.api_request import (AlterTableRequest, CreateDatabaseRequest,
32+
from pypaimon.api.api_request import (AlterDatabaseRequest, AlterTableRequest,
33+
CreateDatabaseRequest,
3334
CreateTableRequest, RenameTableRequest)
3435
from pypaimon.api.api_response import (ConfigResponse, GetDatabaseResponse,
3536
GetTableResponse, ListDatabasesResponse,
@@ -45,6 +46,7 @@
4546
TableAlreadyExistException)
4647
from pypaimon.catalog.rest.table_metadata import TableMetadata
4748
from pypaimon.common.identifier import Identifier
49+
from pypaimon.api.typedef import RESTAuthParameter
4850
from pypaimon.common.json_util import JSON
4951
from pypaimon import Schema
5052
from pypaimon.schema.schema_change import Actions, SchemaChange
@@ -258,11 +260,11 @@ def _handle_request(self, method: str):
258260
content_length = int(self.headers.get('Content-Length', 0))
259261
data = self.rfile.read(content_length).decode('utf-8') if content_length > 0 else ""
260262

261-
# Get headers
263+
# Get headers (case-insensitive from HTTPMessage)
264+
auth_token = self.headers.get(AUTHORIZATION_HEADER_KEY)
262265
headers = dict(self.headers)
263266

264267
# Handle authentication
265-
auth_token = headers.get(AUTHORIZATION_HEADER_KEY.lower())
266268
if not self._authenticate(auth_token, resource_path, parameters, method, data):
267269
self._send_response(401, "Unauthorized")
268270
return
@@ -292,9 +294,25 @@ def _parse_query_params(self, query: str) -> Dict[str, str]:
292294

293295
def _authenticate(self, token: str, path: str, params: Dict[str, str],
294296
method: str, data: str) -> bool:
295-
"""Authenticate request"""
296-
# Simplified authentication - always return True for mock
297-
return True
297+
"""Authenticate request by verifying Authorization header."""
298+
if server_instance.auth_provider is None:
299+
return True
300+
if path.startswith("/ram/security-credential"):
301+
return True
302+
if not token:
303+
return False
304+
rest_auth_parameter = RESTAuthParameter(
305+
method=method,
306+
path=path,
307+
data=data or "",
308+
parameters=params or {},
309+
)
310+
from pypaimon.api.auth.base import RESTAuthFunction
311+
auth_fn = RESTAuthFunction({}, server_instance.auth_provider)
312+
expected_headers = auth_fn(rest_auth_parameter)
313+
expected_token = expected_headers.get(
314+
AUTHORIZATION_HEADER_KEY, "")
315+
return token == expected_token
298316

299317
def _send_response(self, status_code: int, body: str):
300318
"""Send HTTP response"""
@@ -514,6 +532,18 @@ def _database_handle(self, method: str, data: str, database_name: str) -> Tuple[
514532
response = database
515533
return self._mock_response(response, 200)
516534

535+
elif method == "POST":
536+
request_body = JSON.from_json(data, AlterDatabaseRequest)
537+
removals = request_body.removals or []
538+
updates = request_body.updates or {}
539+
options = dict(database.options) if database.options else {}
540+
options.update(updates)
541+
for key in removals:
542+
options.pop(key, None)
543+
self.database_store[database_name] = self.mock_database(
544+
database_name, options)
545+
return self._mock_response("", 200)
546+
517547
elif method == "DELETE":
518548
del self.database_store[database_name]
519549
return self._mock_response("", 200)

0 commit comments

Comments
 (0)