Skip to content

Commit

Permalink
fix: Handle the streaming of JSON delimited by newlines
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 719423860
  • Loading branch information
yeesian authored and copybara-github committed Jan 28, 2025
1 parent 4531d08 commit a16b4ce
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 33 deletions.
30 changes: 14 additions & 16 deletions tests/unit/vertex_langchain/test_reasoning_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -2234,60 +2234,58 @@ class ToParsedJsonTest(parameterized.TestCase):
obj=httpbody_pb2.HttpBody(
content_type="application/json", data=b'{"a": 1, "b": "hello"}'
),
expected={"a": 1, "b": "hello"},
expected=[{"a": 1, "b": "hello"}],
),
dict(
testcase_name="invalid_json",
obj=httpbody_pb2.HttpBody(
content_type="application/json", data=b'{"a": 1, "b": "hello"'
),
expected=httpbody_pb2.HttpBody(
content_type="application/json", data=b'{"a": 1, "b": "hello"'
),
expected=['{"a": 1, "b": "hello"'], # returns the unparsed string
),
dict(
testcase_name="missing_content_type",
obj=httpbody_pb2.HttpBody(data=b'{"a": 1}'),
expected=httpbody_pb2.HttpBody(data=b'{"a": 1}'),
expected=[httpbody_pb2.HttpBody(data=b'{"a": 1}')],
),
dict(
testcase_name="missing_data",
obj=httpbody_pb2.HttpBody(content_type="application/json"),
expected=None,
expected=[None],
),
dict(
testcase_name="wrong_content_type",
obj=httpbody_pb2.HttpBody(content_type="text/plain", data=b"hello"),
expected=httpbody_pb2.HttpBody(content_type="text/plain", data=b"hello"),
expected=[httpbody_pb2.HttpBody(content_type="text/plain", data=b"hello")],
),
dict(
testcase_name="empty_data",
obj=httpbody_pb2.HttpBody(content_type="application/json", data=b""),
expected=None,
expected=[None],
),
dict(
testcase_name="unicode_data",
obj=httpbody_pb2.HttpBody(
content_type="application/json", data='{"a": "你好"}'.encode("utf-8")
),
expected={"a": "你好"},
expected=[{"a": "你好"}],
),
dict(
testcase_name="nested_json",
obj=httpbody_pb2.HttpBody(
content_type="application/json", data=b'{"a": {"b": 1}}'
),
expected={"a": {"b": 1}},
expected=[{"a": {"b": 1}}],
),
dict(
testcase_name="error_handling",
testcase_name="multiline_json",
obj=httpbody_pb2.HttpBody(
content_type="application/json", data=b'{"a": 1, "b": "hello"'
),
expected=httpbody_pb2.HttpBody(
content_type="application/json", data=b'{"a": 1, "b": "hello"'
content_type="application/json",
data=b'{"a": {"b": 1}}\n{"a": {"b": 2}}',
),
expected=[{"a": {"b": 1}}, {"a": {"b": 2}}],
),
)
def test_to_parsed_json(self, obj, expected):
    """Checks yield_parsed_json against the full expected chunk sequence."""
    # Materialize the generator and compare whole lists: a bare
    # `zip(got, want)` silently truncates to the shorter iterable, so a
    # generator that yields too few (or too many) chunks would pass the
    # element-wise comparison. assertEqual on lists also checks length.
    self.assertEqual(list(_utils.yield_parsed_json(obj)), expected)
6 changes: 3 additions & 3 deletions vertexai/reasoning_engines/_reasoning_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,9 +840,9 @@ def _method(self, **kwargs) -> Iterable[Any]:
),
)
for chunk in response:
parsed_json = _utils.to_parsed_json(chunk)
if parsed_json is not None:
yield parsed_json
for parsed_json in _utils.yield_parsed_json(chunk):
if parsed_json is not None:
yield parsed_json

_method.__name__ = method_name
_method.__doc__ = doc
Expand Down
34 changes: 20 additions & 14 deletions vertexai/reasoning_engines/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import json
import types
import typing
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Union
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Union

import proto

Expand Down Expand Up @@ -90,36 +90,42 @@ def to_dict(message: proto.Message) -> JsonDict:
return result


def yield_parsed_json(body: httpbody_pb2.HttpBody) -> Iterable[Any]:
    """Converts the contents of the httpbody message to JSON format.

    Supports both a single JSON object and a stream of JSON objects
    delimited by newlines (newline-delimited JSON).

    Args:
        body (httpbody_pb2.HttpBody):
            Required. The httpbody body to be converted to JSON.

    Yields:
        Any: The parsed JSON object for each newline-delimited chunk of the
        body's data. A body that is not JSON (missing fields or a
        non-JSON content type) or that cannot be decoded as UTF-8 is
        yielded unchanged; empty JSON data yields ``None``; a line that
        fails to parse is yielded as its raw string.
    """
    content_type = getattr(body, "content_type", None)
    data = getattr(body, "data", None)

    # Pass through anything not declared as JSON (missing fields, or a
    # content type such as text/plain) so the caller can decide what to do.
    if content_type is None or data is None or "application/json" not in content_type:
        yield body
        return

    try:
        utf8_data = data.decode("utf-8")
    except Exception as e:
        _LOGGER.warning(f"Failed to decode data: {data}. Exception: {e}")
        yield body
        return

    if not utf8_data:
        yield None
        return

    # Handle the case of multiple dictionaries delimited by newlines.
    for line in utf8_data.split("\n"):
        if line:
            try:
                line = json.loads(line)
            except Exception as e:
                # Message capitalization kept consistent with the other
                # warnings in this function. On failure, fall through and
                # yield the unparsed string so no data is dropped.
                _LOGGER.warning(f"Failed to parse JSON: {line}. Exception: {e}")
            yield line


def generate_schema(
Expand Down Expand Up @@ -195,9 +201,9 @@ def generate_schema(
# * https://github.com/pydantic/pydantic/issues/1270
# * https://stackoverflow.com/a/58841311
# * https://github.com/pydantic/pydantic/discussions/4872
if typing.get_origin(annotation) is typing.Union and type(
None
) in typing.get_args(annotation):
if typing.get_origin(annotation) is Union and type(None) in typing.get_args(
annotation
):
# for "typing.Optional" arguments, function_arg might be a
# dictionary like
#
Expand Down

0 comments on commit a16b4ce

Please sign in to comment.