diff --git a/ddtrace/contrib/internal/dd_trace_api/patch.py b/ddtrace/contrib/internal/dd_trace_api/patch.py index c4b82bfa6b1..45a480682d7 100644 --- a/ddtrace/contrib/internal/dd_trace_api/patch.py +++ b/ddtrace/contrib/internal/dd_trace_api/patch.py @@ -4,17 +4,53 @@ from typing import List from typing import Optional from typing import Tuple +from typing import TypeVar import weakref import dd_trace_api +from wrapt.importer import when_imported import ddtrace +from ddtrace.internal.logger import get_logger +from ddtrace.internal.wrapping.context import WrappingContext _DD_HOOK_NAME = "dd.hook" _TRACER_KEY = "Tracer" _STUB_TO_REAL = weakref.WeakKeyDictionary() _STUB_TO_REAL[dd_trace_api.tracer] = ddtrace.tracer +log = get_logger(__name__) +T = TypeVar("T") + + +class DDTraceAPIWrappingContextBase(WrappingContext): + def _handle_return(self) -> None: + method_name = self.__frame__.f_code.co_name + stub_self = self.get_local("self") + api_return_value = self.get_local("retval") + _call_on_real_instance(stub_self, method_name, api_return_value, self.get_local("name")) + + def _handle_enter(self) -> None: + pass + + def __enter__(self) -> "DDTraceAPIWrappingContextBase": + super().__enter__() + + try: + self._handle_enter() + except Exception: # noqa: E722 + log.debug("Error handling dd_trace_api instrumentation enter", exc_info=True) + + return self + + def __return__(self, value: T) -> T: + """Always return the original value no matter what our instrumentation does""" + try: + self._handle_return() + except Exception: # noqa: E722 + log.debug("Error handling instrumentation return", exc_info=True) + + return value def _proxy_span_arguments(args: List, kwargs: Dict) -> Tuple[List, Dict]: @@ -57,12 +93,17 @@ def get_version() -> str: def patch(tracer=None): if getattr(dd_trace_api, "__datadog_patch", False): return - dd_trace_api.__datadog_patch = True _STUB_TO_REAL[dd_trace_api.tracer] = tracer - if not getattr(dd_trace_api, "__dd_has_audit_hook", False): + if False and not getattr(dd_trace_api, "__dd_has_audit_hook", False): addaudithook(_hook) dd_trace_api.__dd_has_audit_hook = True + @when_imported("dd_trace_api") + def _(m): + DDTraceAPIWrappingContextBase(m.Tracer.start_span).wrap() + + dd_trace_api.__datadog_patch = True + def unpatch(): if not getattr(dd_trace_api, "__datadog_patch", False):