@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
+import contextlib
import json
from logging import getLogger
from typing import (
@@ -14,6 +15,7 @@
    Protocol,
    Sequence,
    Tuple,
+    cast,
)

import aiokafka
@@ -56,19 +58,18 @@ async def __call__(
        headers: HeadersT | None = None,
    ) -> asyncio.Future[RecordMetadata]: ...

-ProduceHookT = Optional[
-    Callable[[Span, Tuple[Any, ...], Dict[str, Any]], Awaitable[None]]
+ProduceHookT = Callable[
+    [Span, Tuple[Any, ...], Dict[str, Any]], Awaitable[None]
]
-ConsumeHookT = Optional[
-    Callable[
-        [
-            Span,
-            aiokafka.ConsumerRecord[object, object],
-            Tuple[aiokafka.TopicPartition, ...],
-            Dict[str, Any],
-        ],
-        Awaitable[None],
-    ]
+
+ConsumeHookT = Callable[
+    [
+        Span,
+        aiokafka.ConsumerRecord[object, object],
+        Tuple[aiokafka.TopicPartition, ...],
+        Dict[str, Any],
+    ],
+    Awaitable[None],
]

HeadersT = Sequence[Tuple[str, Optional[bytes]]]
@@ -89,7 +90,7 @@ def _extract_client_id(client: aiokafka.AIOKafkaClient) -> str:
def _extract_consumer_group(
    consumer: aiokafka.AIOKafkaConsumer,
) -> str | None:
-    return consumer._group_id
+    return consumer._group_id  # type: ignore[reportUnknownVariableType]


def _extract_argument(
@@ -139,6 +140,17 @@ def _move_headers_to_kwargs(
    return args[:5], kwargs


+def _deserialize_key(key: object | None) -> str | None:
+    if key is None:
+        return None
+
+    if isinstance(key, bytes):
+        with contextlib.suppress(UnicodeDecodeError):
+            return key.decode()
+
+    return str(key)
+
+
async def _extract_send_partition(
    instance: aiokafka.AIOKafkaProducer,
    args: tuple[Any, ...],
@@ -150,17 +162,20 @@ async def _extract_send_partition(
        key = _extract_send_key(args, kwargs)
        value = _extract_send_value(args, kwargs)
        partition = _extract_argument("partition", 3, None, args, kwargs)
-        key_bytes, value_bytes = instance._serialize(topic, key, value)
+        key_bytes, value_bytes = cast(
+            "tuple[bytes | None, bytes | None]",
+            instance._serialize(topic, key, value),  # type: ignore[reportUnknownMemberType]
+        )
        valid_types = (bytes, bytearray, memoryview, type(None))
        if (
            type(key_bytes) not in valid_types
            or type(value_bytes) not in valid_types
        ):
            return None

-        await instance.client._wait_on_metadata(topic)
+        await instance.client._wait_on_metadata(topic)  # type: ignore[reportUnknownMemberType]

-        return instance._partition(
+        return instance._partition(  # type: ignore[reportUnknownMemberType]
            topic, partition, key, value, key_bytes, value_bytes
        )
    except Exception as exception:  # pylint: disable=W0703
@@ -170,26 +185,21 @@ async def _extract_send_partition(

class AIOKafkaContextGetter(textmap.Getter["HeadersT"]):
    def get(self, carrier: HeadersT, key: str) -> list[str] | None:
-        if carrier is None:
-            return None
-
        for item_key, value in carrier:
            if item_key == key:
                if value is not None:
                    return [value.decode()]
        return None

    def keys(self, carrier: HeadersT) -> list[str]:
-        if carrier is None:
-            return []
-        return [key for (key, value) in carrier]
+        return [key for (key, _) in carrier]


class AIOKafkaContextSetter(textmap.Setter["HeadersT"]):
    def set(
        self, carrier: HeadersT, key: str | None, value: str | None
    ) -> None:
-        if carrier is None or key is None:
+        if key is None:
            return

        if not isinstance(carrier, MutableSequence):
@@ -215,7 +225,7 @@ def _enrich_base_span(
    client_id: str,
    topic: str,
    partition: int | None,
-    key: object | None,
+    key: str | None,
) -> None:
    span.set_attribute(
        messaging_attributes.MESSAGING_SYSTEM,
@@ -235,8 +245,7 @@ def _enrich_base_span(

    if key is not None:
        span.set_attribute(
-            messaging_attributes.MESSAGING_KAFKA_MESSAGE_KEY,
-            key,  # FIXME: serialize key to str?
+            messaging_attributes.MESSAGING_KAFKA_MESSAGE_KEY, key
        )


@@ -247,7 +256,7 @@ def _enrich_send_span(
    client_id: str,
    topic: str,
    partition: int | None,
-    key: object | None,
+    key: str | None,
) -> None:
    if not span.is_recording():
        return
@@ -276,7 +285,7 @@ def _enrich_getone_span(
    consumer_group: str | None,
    topic: str,
    partition: int | None,
-    key: object | None,
+    key: str | None,
    offset: int,
) -> None:
    if not span.is_recording():
@@ -399,8 +408,8 @@ def _get_span_name(operation: str, topic: str):
    return f"{topic} {operation}"


-def _wrap_send(
-    tracer: Tracer, async_produce_hook: ProduceHookT
+def _wrap_send(  # type: ignore[reportUnusedFunction]
+    tracer: Tracer, async_produce_hook: ProduceHookT | None
) -> Callable[..., Awaitable[asyncio.Future[RecordMetadata]]]:
    async def _traced_send(
        func: AIOKafkaSendProto,
@@ -417,7 +426,7 @@ async def _traced_send(
        topic = _extract_send_topic(args, kwargs)
        bootstrap_servers = _extract_bootstrap_servers(instance.client)
        client_id = _extract_client_id(instance.client)
-        key = _extract_send_key(args, kwargs)
+        key = _deserialize_key(_extract_send_key(args, kwargs))
        partition = await _extract_send_partition(instance, args, kwargs)
        span_name = _get_span_name("send", topic)
        with tracer.start_as_current_span(
@@ -449,7 +458,7 @@ async def _traced_send(

async def _create_consumer_span(
    tracer: Tracer,
-    async_consume_hook: ConsumeHookT,
+    async_consume_hook: ConsumeHookT | None,
    record: aiokafka.ConsumerRecord[object, object],
    extracted_context: Context,
    bootstrap_servers: str | list[str],
@@ -473,7 +482,7 @@ async def _create_consumer_span(
            consumer_group=consumer_group,
            topic=record.topic,
            partition=record.partition,
-            key=record.key,
+            key=_deserialize_key(record.key),
            offset=record.offset,
        )
        try:
@@ -486,8 +495,8 @@ async def _create_consumer_span(
    return span


-def _wrap_getone(
-    tracer: Tracer, async_consume_hook: ConsumeHookT
+def _wrap_getone(  # type: ignore[reportUnusedFunction]
+    tracer: Tracer, async_consume_hook: ConsumeHookT | None
) -> Callable[..., Awaitable[aiokafka.ConsumerRecord[object, object]]]:
    async def _traced_getone(
        func: AIOKafkaGetOneProto,
@@ -521,8 +530,8 @@ async def _traced_getone(
    return _traced_getone


-def _wrap_getmany(
-    tracer: Tracer, async_consume_hook: ConsumeHookT
+def _wrap_getmany(  # type: ignore[reportUnusedFunction]
+    tracer: Tracer, async_consume_hook: ConsumeHookT | None
) -> Callable[
    ...,
    Awaitable[