Skip to content

Commit d8a0ee3

Browse files
committed
Unit Test and MCP Instrumentor
1 parent b6b7cf8 commit d8a0ee3

File tree

5 files changed

+751
-0
lines changed

5 files changed

+751
-0
lines changed
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# MCP Instrumentor
2+
3+
OpenTelemetry MCP instrumentation package.
4+
5+
## Installation
6+
7+
```bash
8+
pip install mcpinstrumentor
9+
```
10+
11+
## Usage
12+
13+
```python
14+
from mcpinstrumentor import MCPInstrumentor
15+
16+
MCPInstrumentor().instrument()
17+
```

aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/loggertwo.log

Whitespace-only changes.
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
import logging
2+
from typing import Any, AsyncGenerator, Callable, Collection, Tuple, cast
3+
4+
from openinference.instrumentation.mcp.package import _instruments
5+
from wrapt import ObjectProxy, register_post_import_hook, wrap_function_wrapper
6+
7+
from opentelemetry import context, propagate, trace
8+
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
9+
from opentelemetry.instrumentation.utils import unwrap
10+
from opentelemetry.sdk.resources import Resource
11+
12+
13+
def setup_loggertwo():
14+
logger = logging.getLogger("loggertwo")
15+
logger.setLevel(logging.DEBUG)
16+
handler = logging.FileHandler("loggertwo.log", mode="w")
17+
handler.setLevel(logging.DEBUG)
18+
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
19+
handler.setFormatter(formatter)
20+
if not logger.handlers:
21+
logger.addHandler(handler)
22+
return logger
23+
24+
25+
loggertwo = setup_loggertwo()
26+
27+
28+
class MCPInstrumentor(BaseInstrumentor):
29+
"""
30+
An instrumenter for MCP.
31+
"""
32+
33+
def instrumentation_dependencies(self) -> Collection[str]:
34+
return _instruments
35+
36+
def _instrument(self, **kwargs: Any) -> None:
37+
tracer_provider = kwargs.get("tracer_provider") # Move this line up
38+
if tracer_provider:
39+
self.tracer_provider = tracer_provider
40+
else:
41+
self.tracer_provider = None
42+
register_post_import_hook(
43+
lambda _: wrap_function_wrapper(
44+
"mcp.shared.session",
45+
"BaseSession.send_request",
46+
self._send_request_wrapper,
47+
),
48+
"mcp.shared.session",
49+
)
50+
register_post_import_hook(
51+
lambda _: wrap_function_wrapper(
52+
"mcp.server.lowlevel.server",
53+
"Server._handle_request",
54+
self._server_handle_request_wrapper,
55+
),
56+
"mcp.server.lowlevel.server",
57+
)
58+
59+
def _uninstrument(self, **kwargs: Any) -> None:
60+
unwrap("mcp.shared.session", "BaseSession.send_request")
61+
unwrap("mcp.server.lowlevel.server", "Server._handle_request")
62+
63+
def handle_attributes(self, span, request, is_client=True):
64+
import mcp.types as types
65+
66+
operation = "Server Handle Request"
67+
if isinstance(request, types.ListToolsRequest):
68+
operation = "ListTool"
69+
span.set_attribute("mcp.list_tools", True)
70+
elif isinstance(request, types.CallToolRequest):
71+
if hasattr(request, "params") and hasattr(request.params, "name"):
72+
operation = request.params.name
73+
span.set_attribute("mcp.call_tool", True)
74+
if is_client:
75+
self._add_client_attributes(span, operation, request)
76+
else:
77+
self._add_server_attributes(span, operation, request)
78+
79+
80+
81+
def _add_client_attributes(self, span, operation, request):
82+
span.set_attribute("span.kind", "CLIENT")
83+
span.set_attribute("aws.remote.service", "Appsignals MCP Server")
84+
span.set_attribute("aws.remote.operation", operation)
85+
if hasattr(request, "params") and hasattr(request.params, "name"):
86+
span.set_attribute("tool.name", request.params.name)
87+
88+
def _add_server_attributes(self, span, operation, request):
89+
span.set_attribute("server_side", True)
90+
span.set_attribute("aws.span.kind", "SERVER")
91+
if hasattr(request, "params") and hasattr(request.params, "name"):
92+
span.set_attribute("tool.name", request.params.name)
93+
94+
def _inject_trace_context(self, request_data, span_ctx):
95+
if "params" not in request_data:
96+
request_data["params"] = {}
97+
if "_meta" not in request_data["params"]:
98+
request_data["params"]["_meta"] = {}
99+
request_data["params"]["_meta"]["trace_context"] = {"trace_id": span_ctx.trace_id, "span_id": span_ctx.span_id}
100+
101+
# Send Request Wrapper
102+
def _send_request_wrapper(self, wrapped, instance, args, kwargs):
103+
"""
104+
Changes made:
105+
The wrapper intercepts the request before sending, injects distributed tracing context into the
106+
request's params._meta field and creates OpenTelemetry spans. The wrapper does not change anything else from the original function's
107+
behavior because it reconstructs the request object with the same type and calling the original function with identical parameters.
108+
"""
109+
110+
async def async_wrapper():
111+
if self.tracer_provider is None:
112+
tracer = trace.get_tracer("mcp.client")
113+
else:
114+
tracer = self.tracer_provider.get_tracer("mcp.client")
115+
with tracer.start_as_current_span(
116+
"client.send_request", kind=trace.SpanKind.CLIENT
117+
) as span:
118+
span_ctx = span.get_span_context()
119+
request = args[0] if len(args) > 0 else kwargs.get("request")
120+
if request:
121+
req_root = request.root if hasattr(request, "root") else request
122+
123+
self.handle_attributes(span, req_root, True)
124+
request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True)
125+
self._inject_trace_context(request_data, span_ctx)
126+
# Reconstruct request object with injected trace context
127+
modified_request = type(request).model_validate(request_data)
128+
if len(args) > 0:
129+
new_args = (modified_request,) + args[1:]
130+
result = await wrapped(*new_args, **kwargs)
131+
else:
132+
kwargs["request"] = modified_request
133+
result = await wrapped(*args, **kwargs)
134+
else:
135+
result = await wrapped(*args, **kwargs)
136+
return result
137+
138+
return async_wrapper()
139+
140+
def getname(self, req):
141+
span_name = "unknown"
142+
import mcp.types as types
143+
144+
if isinstance(req, types.ListToolsRequest):
145+
span_name = "tools/list"
146+
elif isinstance(req, types.CallToolRequest):
147+
if hasattr(req, "params") and hasattr(req.params, "name"):
148+
span_name = f"tools/{req.params.name}"
149+
else:
150+
span_name = "unknown"
151+
return span_name
152+
153+
# Handle Request Wrapper
154+
async def _server_handle_request_wrapper(self, wrapped, instance, args, kwargs):
155+
"""
156+
Changes made:
157+
This wrapper intercepts requests before processing, extracts distributed tracing context from
158+
the request's params._meta field, and creates server-side OpenTelemetry spans linked to the client spans. The wrapper
159+
also does not change the original function's behavior by calling it with identical parameters
160+
ensuring no breaking changes to the MCP server functionality.
161+
"""
162+
req = args[1] if len(args) > 1 else None
163+
trace_context = None
164+
165+
if req and hasattr(req, "params") and req.params and hasattr(req.params, "meta") and req.params.meta:
166+
trace_context = req.params.meta.trace_context
167+
if trace_context:
168+
169+
if self.tracer_provider is None:
170+
tracer = trace.get_tracer("mcp.server")
171+
else:
172+
tracer = self.tracer_provider.get_tracer("mcp.server")
173+
trace_id = trace_context.get("trace_id")
174+
span_id = trace_context.get("span_id")
175+
span_context = trace.SpanContext(trace_id=trace_id, span_id=span_id, is_remote=True,trace_flags=trace.TraceFlags(trace.TraceFlags.SAMPLED),trace_state=trace.TraceState())
176+
span_name = self.getname(req)
177+
with tracer.start_as_current_span(
178+
span_name,
179+
kind=trace.SpanKind.SERVER,
180+
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
181+
) as span:
182+
self.handle_attributes(span, req, False)
183+
result = await wrapped(*args, **kwargs)
184+
return result
185+
else:
186+
return await wrapped(*args, **kwargs,)
187+
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
[build-system]
2+
requires = ["hatchling"]
3+
build-backend = "hatchling.build"
4+
5+
[project]
6+
name = "amazon-opentelemetry-distro-mcpinstrumentor"
7+
version = "0.1.0"
8+
description = "OpenTelemetry MCP instrumentation for AWS Distro"
9+
readme = "README.md"
10+
license = "Apache-2.0"
11+
requires-python = ">=3.9"
12+
authors = [
13+
{ name = "Johnny Lin", email = "[email protected]" },
14+
]
15+
classifiers = [
16+
"Development Status :: 4 - Beta",
17+
"Intended Audience :: Developers",
18+
"License :: OSI Approved :: Apache Software License",
19+
"Programming Language :: Python",
20+
"Programming Language :: Python :: 3",
21+
"Programming Language :: Python :: 3.9",
22+
"Programming Language :: Python :: 3.10",
23+
"Programming Language :: Python :: 3.11",
24+
"Programming Language :: Python :: 3.12",
25+
"Programming Language :: Python :: 3.13",
26+
]
27+
dependencies = [
28+
"opentelemetry-api",
29+
"opentelemetry-instrumentation",
30+
"opentelemetry-semantic-conventions",
31+
"wrapt",
32+
"opentelemetry-sdk",
33+
]
34+
35+
[project.optional-dependencies]
36+
instruments = ["mcp"]
37+
38+
[project.entry-points.opentelemetry_instrumentor]
39+
mcp = "mcpinstrumentor:MCPInstrumentor"
40+
41+
[tool.hatch.build.targets.sdist]
42+
include = [
43+
"mcpinstrumentor.py",
44+
"README.md"
45+
]
46+
47+
[tool.hatch.build.targets.wheel]
48+
packages = ["."]

0 commit comments

Comments
 (0)