Skip to content

Commit d242ef2

Browse files
committed
initial commit for llo handler unit tests
1 parent ea81f26 commit d242ef2

File tree

1 file changed

+163
-0
lines changed

1 file changed

+163
-0
lines changed
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
from unittest import TestCase
2+
from unittest.mock import MagicMock, patch, call
3+
4+
from amazon.opentelemetry.distro.llo_handler import LLOHandler
5+
from opentelemetry._events import Event
6+
from opentelemetry.sdk._logs import LoggerProvider
7+
from opentelemetry.sdk.trace import ReadableSpan, SpanContext
8+
from opentelemetry.trace import SpanKind, TraceFlags, TraceState
9+
10+
11+
class TestLLOHandler(TestCase):
12+
def setUp(self):
13+
self.logger_provider_mock = MagicMock(spec=LoggerProvider)
14+
self.event_logger_mock = MagicMock()
15+
self.event_logger_provider_mock = MagicMock()
16+
self.event_logger_provider_mock.get_event_logger.return_value = self.event_logger_mock
17+
18+
with patch(
19+
"amazon.opentelemetry.distro.llo_handler.EventLoggerProvider", return_value=self.event_logger_provider_mock
20+
):
21+
self.llo_handler = LLOHandler(self.logger_provider_mock)
22+
23+
def _create_mock_span(self, attributes=None, kind=SpanKind.INTERNAL):
24+
"""
25+
Helper method to create a mock span with given attributes
26+
"""
27+
if attributes is None:
28+
attributes = {}
29+
30+
span_context = SpanContext(
31+
trace_id=0x123456789ABCDEF0123456789ABCDEF0,
32+
span_id=0x123456789ABCDEF0,
33+
is_remote=False,
34+
trace_flags=TraceFlags.SAMPLED,
35+
trace_state=TraceState.get_default(),
36+
)
37+
38+
mock_span = MagicMock(spec=ReadableSpan)
39+
mock_span.context = span_context
40+
mock_span.attributes = attributes
41+
mock_span.kind = kind
42+
mock_span.start_time = 1234567890
43+
44+
return mock_span
45+
46+
def test_init(self):
47+
"""
48+
Test initialization of LLOHandler
49+
"""
50+
self.assertEqual(self.llo_handler._logger_provider, self.logger_provider_mock)
51+
self.assertEqual(self.llo_handler._event_logger_provider, self.event_logger_provider_mock)
52+
self.event_logger_provider_mock.get_event_logger.assert_called_once_with("gen_ai.events")
53+
54+
def test_is_llo_attribute_match(self):
55+
"""
56+
Test _is_llo_attribute method with matching patterns
57+
"""
58+
self.assertTrue(self.llo_handler._is_llo_attribute("gen_ai.prompt.0.content"))
59+
self.assertTrue(self.llo_handler._is_llo_attribute("gen_ai.prompt.123.content"))
60+
61+
def test_is_llo_attribute_no_match(self):
62+
"""
63+
Test _is_llo_attribute method with non-matching patterns
64+
"""
65+
self.assertFalse(self.llo_handler._is_llo_attribute("gen_ai.prompt.content"))
66+
self.assertFalse(self.llo_handler._is_llo_attribute("gen_ai.prompt.abc.content"))
67+
self.assertFalse(self.llo_handler._is_llo_attribute("some.other.attribute"))
68+
69+
def test_filter_attributes(self):
70+
"""
71+
Test _filter_attributes method
72+
"""
73+
attributes = {
74+
"gen_ai.prompt.0.content": "test content",
75+
"gen_ai.prompt.0.role": "user",
76+
"normal.attribute": "value",
77+
"another.normal.attribute": 123
78+
}
79+
80+
filtered = self.llo_handler._filter_attributes(attributes)
81+
82+
self.assertNotIn("gen_ai.prompt.0.content", filtered)
83+
self.assertIn("gen_ai.prompt.0.role", filtered)
84+
self.assertIn("normal.attribute", filtered)
85+
self.assertIn("another.normal.attribute", filtered)
86+
87+
def test_extract_gen_ai_prompt_events_system_role(self):
88+
"""
89+
Test _extract_gen_ai_prompt_events with system role
90+
"""
91+
attributes = {
92+
"gen_ai.prompt.0.content": "system instruction",
93+
"gen_ai.prompt.0.role": "system",
94+
"gen_ai.system": "openai"
95+
}
96+
97+
span = self._create_mock_span(attributes)
98+
99+
events = self.llo_handler._extract_gen_ai_prompt_events(span, attributes)
100+
101+
self.assertEqual(len(events), 1)
102+
event = events[0]
103+
self.assertEqual(event.name, "gen_ai.system.message")
104+
self.assertEqual(event.body["content"], "system instruction")
105+
self.assertEqual(event.body["role"], "system")
106+
self.assertEqual(event.attributes["gen_ai.system"], "openai")
107+
self.assertEqual(event.attributes["original_attribute"], "gen_ai.prompt.0.content")
108+
109+
def test_extract_gen_ai_prompt_events_user_role(self):
110+
"""
111+
Test _extract_gen_ai_prompt_events with user role
112+
"""
113+
attributes = {
114+
"gen_ai.prompt.0.content": "user question",
115+
"gen_ai.prompt.0.role": "user",
116+
"gen_ai.system": "anthropic"
117+
}
118+
119+
span = self._create_mock_span(attributes)
120+
121+
events = self.llo_handler._extract_gen_ai_prompt_events(span, attributes)
122+
123+
self.assertEqual(len(events), 1)
124+
event = events[0]
125+
self.assertEqual(event.name, "gen_ai.user.message")
126+
self.assertEqual(event.body["content"], "user question")
127+
self.assertEqual(event.body["role"], "user")
128+
self.assertEqual(event.attributes["gen_ai.system"], "anthropic")
129+
self.assertEqual(event.attributes["original_attribute"], "gen_ai.prompt.0.content")
130+
131+
def test_extract_gen_ai_prompt_events_assistant_role(self):
132+
"""
133+
Test _extract_gen_ai_prompt_events with assistant role
134+
"""
135+
attributes = {
136+
"gen_ai.prompt.1.content": "assistant response",
137+
"gen_ai.prompt.1.role": "assistant",
138+
"gen_ai.system": "anthropic"
139+
}
140+
141+
span = self._create_mock_span(attributes)
142+
143+
events = self.llo_handler._extract_gen_ai_prompt_events(span, attributes)
144+
145+
self.assertEqual(len(events), 1)
146+
event = events[0]
147+
self.assertEqual(event.name, "gen_ai.assistant.message")
148+
self.assertEqual(event.body["content"], "assistant response")
149+
self.assertEqual(event.body["role"], "assistant")
150+
self.assertEqual(event.attributes["gen_ai.system"], "anthropic")
151+
self.assertEqual(event.attributes["original_attribute"], "gen_ai.prompt.1.content")
152+
153+
def test_extract_gen_ai_prompt_events_function_role(self):
154+
"""
155+
Test _extract_gen_ai_prompt_events with function role
156+
"""
157+
pass
158+
159+
def test_emit_llo_attributes(self):
160+
pass
161+
162+
def test_process_spans(self):
163+
pass

0 commit comments

Comments
 (0)