diff --git a/src/msgraph_core/requests/batch_request_builder.py b/src/msgraph_core/requests/batch_request_builder.py index 608876f0..de0e7098 100644 --- a/src/msgraph_core/requests/batch_request_builder.py +++ b/src/msgraph_core/requests/batch_request_builder.py @@ -1,5 +1,7 @@ from typing import TypeVar, Type, Dict, Optional, Union import logging +import json +import base64 from kiota_abstractions.request_adapter import RequestAdapter from kiota_abstractions.request_information import RequestInformation @@ -59,12 +61,9 @@ async def post( response_type = BatchResponseContent if isinstance(batch_request_content, BatchRequestContent): + print(f"Batch request content: {batch_request_content.requests}") request_info = await self.to_post_request_information(batch_request_content) - bytes_content = request_info.content - json_content = bytes_content.decode("utf-8") - updated_str = '{"requests":' + json_content + '}' - updated_bytes = updated_str.encode("utf-8") - request_info.content = updated_bytes + request_info.content = self._prepare_request_content(request_info.content) error_map = error_map or self.error_map response = None try: @@ -107,7 +106,12 @@ async def _post_batch_collection( batch_responses = BatchResponseContentCollection() for batch_request_content in batch_request_content_collection.batches: + print(f"Batch request content: {batch_request_content.requests}") + request_info = await self.to_post_request_information(batch_request_content) + print(f"content before processing {request_info.content}") + updated_bytes = self._prepare_request_content(request_info.content) + request_info.content = updated_bytes response = await self._request_adapter.send_async( request_info, BatchResponseContent, error_map or self.error_map ) @@ -115,6 +119,43 @@ async def _post_batch_collection( return batch_responses + def _prepare_request_content(self, content: bytes) -> bytes: + """ + Prepares the request content by updating the JSON structure and converting + the 'body' field from string to a dictionary if necessary. + + Args: + content (bytes): The original request content. + + Returns: + bytes: The updated request content. + """ + json_content = content.decode("utf-8") + print(json_content) + requests_list = json.loads(json_content) + for request in requests_list: + if 'body' in request: + if isinstance(request['body'], dict): + pass + elif isinstance(request['body'], str): + try: + request['body'] = json.loads(request['body']) + except json.JSONDecodeError: + pass + elif isinstance(request['body'], bytes): + request['body'] = base64.b64encode(request['body']).decode('utf-8') + + if isinstance(request['body'], dict): + request['headers'] = {"Content-Type": "application/json"} + else: + request['headers'] = {"Content-Type": "application/octet-stream"} + else: + request['headers'] = {"Content-Type": "application/json"} + + updated_json_content = json.dumps({"requests": requests_list}) + return updated_json_content.encode("utf-8") + # return json.dumps(requests_list).encode("utf-8") + async def to_post_request_information( self, batch_request_content: BatchRequestContent ) -> RequestInformation: @@ -131,6 +172,7 @@ async def to_post_request_information( if batch_request_content is None: raise ValueError("batch_request_content cannot be Null.") batch_request_items = list(batch_request_content.requests.values()) + print(f"Batch request items: {batch_request_items}") request_info = RequestInformation() request_info.http_method = Method.POST diff --git a/src/msgraph_core/requests/batch_request_content.py b/src/msgraph_core/requests/batch_request_content.py index 9a48ef0f..1ba93622 100644 --- a/src/msgraph_core/requests/batch_request_content.py +++ b/src/msgraph_core/requests/batch_request_content.py @@ -52,7 +52,7 @@ def add_request(self, request_id: Optional[str], request: BatchRequestItem) -> N request.id = str(uuid.uuid4()) if hasattr(request, 'depends_on') and request.depends_on: for dependent_id in request.depends_on: - if dependent_id not in [req.id for req in self.requests]: + if dependent_id not in self.requests: dependent_request = self._request_by_id(dependent_id) if dependent_request: self._requests[dependent_id] = dependent_request @@ -137,4 +137,8 @@ def serialize(self, writer: SerializationWriter) -> None: Args: writer: Serialization writer to use to serialize this model """ - writer.write_collection_of_object_values("requests", self.requests) + if not writer: + raise ValueError("writer cannot be None") + writer.write_collection_of_object_values({"requests", list(self.requests.values())}) + # requests_dict = {request_id: request for request_id, request in self.requests.items()} + # writer.write_object_value("requests", requests_dict) diff --git a/src/msgraph_core/requests/batch_request_item.py b/src/msgraph_core/requests/batch_request_item.py index 3bab3453..57e81ba7 100644 --- a/src/msgraph_core/requests/batch_request_item.py +++ b/src/msgraph_core/requests/batch_request_item.py @@ -5,6 +5,8 @@ from typing import List, Optional, Dict, Union, Any from io import BytesIO import base64 +import logging + import urllib.request from urllib.parse import urlparse @@ -238,16 +240,25 @@ def serialize(self, writer: SerializationWriter) -> None: Args: writer (SerializationWriter): The writer to write to. """ + if not writer: + raise ValueError("writer cannot be None") + writer.write_str_value('id', self.id) writer.write_str_value('method', self.method) writer.write_str_value('url', self.url) + writer.write_collection_of_primitive_values('depends_on', self._depends_on) - headers = {key: ", ".join(val) for key, val in self._headers.items()} + + headers = self._headers writer.write_collection_of_object_values('headers', headers) + if self._body: - json_object = json.loads(self._body) - is_json_string = json_object and isinstance(json_object, dict) - writer.write_collection_of_object_values( - 'body', - json_object if is_json_string else base64.b64encode(self._body).decode('utf-8') - ) + if isinstance(self._body, bytes): + body_content = base64.b64encode(self._body).decode('utf-8') + elif isinstance(self._body, str): + body_content = self._body + else: + raise ValueError("Unsupported body type") + writer.write_str_value('body', body_content) + else: + logging.info("Content info: there is no body to serialize") diff --git a/tests/requests/test_batch_request_content.py b/tests/requests/test_batch_request_content.py index 3ee241af..2499b68b 100644 --- a/tests/requests/test_batch_request_content.py +++ b/tests/requests/test_batch_request_content.py @@ -105,7 +105,9 @@ def test_get_field_deserializers(batch_request_content): def test_serialize(batch_request_content): writer = Mock(spec=SerializationWriter) + batch_request_content.serialize(writer) + writer.write_collection_of_object_values.assert_called_once_with( - "requests", batch_request_content.requests + "requests", list(batch_request_content.requests.values()) ) diff --git a/tests/requests/test_batch_request_item.py b/tests/requests/test_batch_request_item.py index 2dd3d863..005bcf44 100644 --- a/tests/requests/test_batch_request_item.py +++ b/tests/requests/test_batch_request_item.py @@ -1,5 +1,8 @@ import pytest from unittest.mock import Mock +import base64 +import json + from urllib.request import Request from kiota_abstractions.request_information import RequestInformation from kiota_abstractions.method import Method @@ -25,6 +28,60 @@ def batch_request_item(request_info): return BatchRequestItem(request_information=request_info) +@pytest.fixture +def request_info_json(): + request_info = RequestInformation() + request_info.http_method = "POST" + request_info.url = "https://graph.microsoft.com/v1.0/me/events" + request_info.headers = RequestHeaders() + request_info.headers.add("Content-Type", "application/json") + request_info.content = json.dumps( + { + "@odata.type": "#microsoft.graph.event", + "end": { + "dateTime": "2024-10-14T17:30:00", + "timeZone": "Pacific Standard Time" + }, + "start": { + "dateTime": "2024-10-14T17:00:00", + "timeZone": "Pacific Standard Time" + }, + "subject": "File end-of-day report" + } + ).encode('utf-8') + return request_info + + +@pytest.fixture +def request_info_bytes(): + request_info = RequestInformation() + request_info.http_method = "POST" + request_info.url = "https://graph.microsoft.com/v1.0/me/events" + request_info.headers = RequestHeaders() + request_info.headers.add("Content-Type", "application/json") + request_info.content = b'{"@odata.type": "#microsoft.graph.event", "end": {"dateTime": "2024-10-14T17:30:00", "timeZone": "Pacific Standard Time"}, "start": {"dateTime": "2024-10-14T17:00:00", "timeZone": "Pacific Standard Time"}, "subject": "File end-of-day report"}' + return request_info + + +@pytest.fixture +def batch_request_item_json(request_info_json): + return BatchRequestItem(request_information=request_info_json) + + +@pytest.fixture +def batch_request_item_bytes(request_info_bytes): + return BatchRequestItem(request_information=request_info_bytes) + + +def encode_body_to_base64(body): + if isinstance(body, bytes): + return base64.b64encode(body).decode('utf-8') + elif isinstance(body, str): + return base64.b64encode(body.encode('utf-8')).decode('utf-8') + else: + raise ValueError("Unsupported body type") + + def test_initialization(batch_request_item, request_info): assert batch_request_item.method == "GET" assert batch_request_item.url == "f{base_url}/me" @@ -124,3 +181,21 @@ def test_batch_request_item_method_enum(): def test_depends_on_property(batch_request_item): batch_request_item.set_depends_on(["request1", "request2"]) assert batch_request_item.depends_on == ["request1", "request2"] + + +def test_serialize_with_json_body(batch_request_item_json): + item = batch_request_item_json + writer = Mock() + processed_body = encode_body_to_base64(item.body) + + item.serialize(writer) + writer.write_str_value.assert_called_with('body', processed_body) + + +def test_serialize_with_bytes_body(batch_request_item_bytes): + item = batch_request_item_bytes + writer = Mock() + processed_body = encode_body_to_base64(item.body) + + item.serialize(writer) + writer.write_str_value.assert_called_with('body', processed_body)