Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 16 additions & 33 deletions faststream/confluent/parser.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,16 @@
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any

from faststream.message import StreamMessage, decode_message

from .message import FAKE_CONSUMER, KafkaMessage

if TYPE_CHECKING:
from collections.abc import Sequence

from confluent_kafka import Message

from faststream._internal.basic_types import DecodedMessage

from .message import ConsumerProtocol

# Type of headers returned by confluent_kafka Message.headers()
_HeadersInput = (
dict[str, str | bytes | None]
| list[tuple[str, str | bytes | None]]
| tuple[tuple[str, str | bytes | None], ...]
)


class AsyncConfluentParser:
"""A class to parse Kafka messages."""
Expand All @@ -31,12 +22,18 @@ def __init__(self, is_manual: bool = False) -> None:
def _setup(self, consumer: "ConsumerProtocol") -> None:
self._consumer = consumer

@staticmethod
def _decode_header(headers: dict[str, str | bytes | None], key: str) -> str | None:
"""Decode a single header value safely, handling non-UTF-8 bytes."""
val = headers.get(key)
return val.decode(errors="replace") if isinstance(val, bytes) and val else val or None

async def parse_message(
self,
message: "Message",
) -> KafkaMessage:
"""Parses a Kafka message."""
headers = _parse_msg_headers(cast("_HeadersInput", message.headers() or ()))
headers = dict(message.headers() or ())

body = message.value() or b""
offset = message.offset()
Expand All @@ -45,10 +42,10 @@ async def parse_message(
return KafkaMessage(
body=body,
headers=headers,
reply_to=headers.get("reply_to", ""),
content_type=headers.get("content-type"),
reply_to=self._decode_header(headers, "reply_to") or "",
content_type=self._decode_header(headers, "content-type"),
message_id=f"{offset}-{timestamp}",
correlation_id=headers.get("correlation_id"),
correlation_id=self._decode_header(headers, "correlation_id"),
raw_message=message,
consumer=self._consumer,
is_manual=self.is_manual,
Expand All @@ -60,16 +57,14 @@ async def parse_batch(
) -> KafkaMessage:
"""Parses a batch of messages from a Kafka consumer."""
body: list[Any] = []
batch_headers: list[dict[str, str]] = []
batch_headers: list[dict[str, str | bytes | None]] = []

first = message[0]
last = message[-1]

for m in message:
body.append(m.value() or b"")
batch_headers.append(
_parse_msg_headers(cast("_HeadersInput", m.headers() or ()))
)
batch_headers.append(dict(m.headers() or ()))

headers = next(iter(batch_headers), {})

Expand All @@ -79,10 +74,10 @@ async def parse_batch(
body=body,
headers=headers,
batch_headers=batch_headers,
reply_to=headers.get("reply_to", ""),
content_type=headers.get("content-type"),
reply_to=self._decode_header(headers, "reply_to") or "",
content_type=self._decode_header(headers, "content-type"),
message_id=f"{first.offset()}-{last.offset()}-{first_timestamp}",
correlation_id=headers.get("correlation_id"),
correlation_id=self._decode_header(headers, "correlation_id"),
raw_message=message,
consumer=self._consumer,
is_manual=self.is_manual,
Expand All @@ -101,15 +96,3 @@ async def decode_batch(
) -> "DecodedMessage":
"""Decode a batch of messages."""
return [decode_message(await self.parse_message(m)) for m in msg.raw_message]


def _parse_msg_headers(headers: "_HeadersInput") -> dict[str, str]:
if isinstance(headers, dict):
seq: Sequence[tuple[str, bytes | str | None]] = list(headers.items())
else:
seq = headers
return {
i: (j if isinstance(j, str) else (j.decode() if j is not None else ""))
for i, j in seq
if j is not None
}
4 changes: 2 additions & 2 deletions faststream/kafka/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ async def parse_message(
message: Union["ConsumerRecord", "KafkaRawMessage"],
) -> "StreamMessage[ConsumerRecord]":
"""Parses a Kafka message."""
headers = {i: j.decode() for i, j in message.headers}
headers = {i: j.decode(errors="replace") for i, j in message.headers}

return self.msg_class(
body=message.value or b"",
Expand Down Expand Up @@ -75,7 +75,7 @@ async def parse_batch(

for m in message:
body.append(m.value or b"")
batch_headers.append({i: j.decode() for i, j in m.headers})
batch_headers.append({i: j.decode(errors="replace") for i, j in m.headers})

headers = next(iter(batch_headers), {})

Expand Down
133 changes: 133 additions & 0 deletions tests/brokers/confluent/test_parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pytest
from unittest.mock import MagicMock

from faststream.confluent.parser import AsyncConfluentParser
from tests.brokers.base.parser import CustomParserTestcase

from .basic import ConfluentTestcaseConfig
Expand All @@ -9,3 +11,134 @@
@pytest.mark.confluent()
class TestCustomParser(ConfluentTestcaseConfig, CustomParserTestcase):
pass


class TestNonUtf8ConfluentHeaders:
    """Tests for non-UTF-8 header handling in AsyncConfluentParser.

    Covers issues:
    - https://github.com/ag2ai/faststream/issues/2458 (UnicodeDecodeError on non-UTF-8 headers)
    - https://github.com/ag2ai/faststream/issues/2214 (support bytes headers)
    """

    def _make_message(self, headers):
        """Build a mock message exposing the accessors these tests stub."""
        msg = MagicMock()
        msg.headers.return_value = headers
        msg.value.return_value = b"body"
        msg.offset.return_value = 0
        msg.timestamp.return_value = (0, 0)
        return msg

    def _make_batch(self, headers_list):
        """Build a tuple of mock messages, one per entry of ``headers_list``."""
        return tuple(self._make_message(headers) for headers in headers_list)

    @pytest.mark.asyncio()
    async def test_non_utf8_header_does_not_raise(self):
        """parse_message must not raise UnicodeDecodeError on invalid UTF-8 bytes."""
        parser = AsyncConfluentParser()
        msg = self._make_message([("trash_header", b"\xc3\x28")])
        result = await parser.parse_message(msg)
        assert "trash_header" in result.headers

    @pytest.mark.asyncio()
    async def test_non_utf8_reply_to_decoded_with_replace(self):
        """reply_to with non-UTF-8 bytes should be decoded with errors='replace'."""
        parser = AsyncConfluentParser()
        msg = self._make_message([("reply_to", b"\xc3\x28")])
        result = await parser.parse_message(msg)
        # The invalid lead byte 0xC3 becomes U+FFFD; 0x28 survives as "(".
        assert result.reply_to == "\ufffd("

    @pytest.mark.asyncio()
    async def test_reply_to_empty_string_when_missing(self):
        """reply_to must default to '' when header is absent."""
        parser = AsyncConfluentParser()
        msg = self._make_message([])
        result = await parser.parse_message(msg)
        assert result.reply_to == ""

    @pytest.mark.asyncio()
    async def test_content_type_none_when_missing(self):
        """content_type must be None when header is absent."""
        parser = AsyncConfluentParser()
        msg = self._make_message([])
        result = await parser.parse_message(msg)
        assert result.content_type is None

    @pytest.mark.asyncio()
    async def test_bytes_header_value_preserved_in_headers(self):
        """Issue #2214: raw bytes header values should be preserved in msg.headers."""
        import uuid

        parser = AsyncConfluentParser()
        event_id = uuid.uuid4().bytes
        msg = self._make_message([("event_id", event_id)])
        result = await parser.parse_message(msg)
        assert result.headers.get("event_id") == event_id

    @pytest.mark.asyncio()
    async def test_valid_utf8_header_decoded_correctly(self):
        """Valid UTF-8 headers should be decoded correctly."""
        parser = AsyncConfluentParser()
        msg = self._make_message([
            ("reply_to", b"some-topic"),
            ("content-type", b"application/json"),
            ("correlation_id", b"abc-123"),
        ])
        result = await parser.parse_message(msg)
        assert result.reply_to == "some-topic"
        assert result.content_type == "application/json"
        assert result.correlation_id == "abc-123"

    @pytest.mark.asyncio()
    async def test_none_header_value_handled(self):
        """Header with None value should not raise."""
        parser = AsyncConfluentParser()
        msg = self._make_message([("nullable_header", None)])
        result = await parser.parse_message(msg)
        assert result.headers.get("nullable_header") is None

    @pytest.mark.asyncio()
    async def test_batch_non_utf8_header_does_not_raise(self):
        """parse_batch must not raise on non-UTF-8 headers."""
        parser = AsyncConfluentParser()
        messages = self._make_batch([
            [("trash_header", b"\xc3\x28")],
            [("other", b"valid")],
        ])
        result = await parser.parse_batch(messages)
        # Neither message carries a reply_to header, so the default applies.
        assert result.reply_to == ""

    @pytest.mark.asyncio()
    async def test_batch_bytes_headers_preserved(self):
        """Issue #2214: bytes header values preserved in batch_headers."""
        import uuid

        parser = AsyncConfluentParser()
        event_id = uuid.uuid4().bytes
        messages = self._make_batch([
            [("event_id", event_id)],
        ])
        result = await parser.parse_batch(messages)
        assert result.batch_headers[0].get("event_id") == event_id

    def test_decode_header_returns_none_for_missing_key(self):
        """A key absent from the mapping yields None."""
        assert AsyncConfluentParser._decode_header({}, "missing") is None

    def test_decode_header_decodes_valid_bytes(self):
        """Valid UTF-8 bytes are decoded to the matching str."""
        assert AsyncConfluentParser._decode_header({"key": b"value"}, "key") == "value"

    def test_decode_header_replaces_invalid_bytes(self):
        """Invalid UTF-8 bytes are replaced with U+FFFD rather than raising."""
        result = AsyncConfluentParser._decode_header({"key": b"\xc3\x28"}, "key")
        assert result == "\ufffd("

    def test_decode_header_passthrough_str(self):
        """A value that is already str is returned unchanged."""
        assert AsyncConfluentParser._decode_header({"key": "already-str"}, "key") == "already-str"

    def test_decode_header_returns_none_for_none_value(self):
        """A header explicitly set to None yields None."""
        assert AsyncConfluentParser._decode_header({"key": None}, "key") is None
86 changes: 86 additions & 0 deletions tests/brokers/kafka/test_parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import pytest
from unittest.mock import MagicMock

from faststream.kafka.parser import AioKafkaBatchParser, AioKafkaParser
from faststream.kafka.message import KafkaMessage
from tests.brokers.base.parser import CustomParserTestcase

from .basic import KafkaTestcaseConfig
Expand All @@ -9,3 +12,86 @@
@pytest.mark.connected()
class TestCustomParser(KafkaTestcaseConfig, CustomParserTestcase):
pass


class TestNonUtf8KafkaHeaders:
    """Tests for non-UTF-8 header handling in AioKafkaParser / AioKafkaBatchParser.

    Covers issue:
    - https://github.com/ag2ai/faststream/issues/2458 (UnicodeDecodeError on non-UTF-8 headers)
    """

    def _make_record(self, headers):
        """Build a mock consumer record carrying ``headers``."""
        record = MagicMock()
        record.headers = headers
        record.value = b"body"
        record.offset = 0
        record.timestamp = 0
        record.topic = "test-topic"
        record.consumer = MagicMock()
        return record

    def _make_parser(self):
        """Return a single-message parser configured without a topic regex."""
        return AioKafkaParser(msg_class=KafkaMessage, regex=None)

    def _make_batch_parser(self):
        """Return a batch parser configured without a topic regex."""
        return AioKafkaBatchParser(msg_class=KafkaMessage, regex=None)

    @pytest.mark.asyncio()
    async def test_non_utf8_header_does_not_raise(self):
        """parse_message must not raise UnicodeDecodeError on invalid UTF-8 bytes."""
        parser = self._make_parser()
        record = self._make_record([("trash_header", b"\xc3\x28")])
        result = await parser.parse_message(record)
        assert "trash_header" in result.headers

    @pytest.mark.asyncio()
    async def test_non_utf8_header_decoded_with_replace(self):
        """Non-UTF-8 header bytes should be decoded with errors='replace'."""
        parser = self._make_parser()
        record = self._make_record([("key", b"\xc3\x28")])
        result = await parser.parse_message(record)
        # The invalid lead byte 0xC3 becomes U+FFFD; 0x28 survives as "(".
        assert result.headers["key"] == "\ufffd("

    @pytest.mark.asyncio()
    async def test_valid_utf8_decoded_correctly(self):
        """Valid UTF-8 headers should decode without modification."""
        parser = self._make_parser()
        record = self._make_record([
            ("reply_to", b"my-topic"),
            ("content-type", b"application/json"),
            ("correlation_id", b"uuid-123"),
        ])
        result = await parser.parse_message(record)
        assert result.reply_to == "my-topic"
        assert result.content_type == "application/json"
        assert result.correlation_id == "uuid-123"

    @pytest.mark.asyncio()
    async def test_reply_to_defaults_to_empty_string(self):
        """reply_to must default to '' when header is absent."""
        parser = self._make_parser()
        record = self._make_record([])
        result = await parser.parse_message(record)
        assert result.reply_to == ""

    @pytest.mark.asyncio()
    async def test_batch_non_utf8_header_does_not_raise(self):
        """parse_batch must not raise on invalid UTF-8."""
        parser = self._make_batch_parser()
        records = (
            self._make_record([("trash_header", b"\xc3\x28")]),
            self._make_record([("other", b"valid")]),
        )
        result = await parser.parse_batch(records)
        assert isinstance(result.reply_to, str)

    @pytest.mark.asyncio()
    async def test_batch_headers_decoded_with_replace(self):
        """All batch_headers should be decoded with errors='replace'."""
        parser = self._make_batch_parser()
        records = (self._make_record([("key", b"\xc3\x28")]),)
        result = await parser.parse_batch(records)
        assert result.batch_headers[0]["key"] == "\ufffd("