DBT model instrumentation #11268

Status: Draft (wants to merge 8 commits into base: main)
9 changes: 9 additions & 0 deletions core/dbt/context/base.py
@@ -11,6 +11,7 @@
# These modules are added to the context. Consider alternative
# approaches which will extend well to potentially many modules
import pytz
from opentelemetry import trace

import dbt.flags as flags_module
from dbt import tracking, utils
@@ -86,12 +87,20 @@ def get_itertools_module_context() -> Dict[str, Any]:
return {name: getattr(itertools, name) for name in context_exports}


def get_otel_trace_module_context() -> Dict[str, Dict[str, Any]]:
context_exports = trace.__all__
return {name: getattr(trace, name) for name in context_exports}


def get_context_modules() -> Dict[str, Dict[str, Any]]:
return {
"pytz": get_pytz_module_context(),
"datetime": get_datetime_module_context(),
"re": get_re_module_context(),
"itertools": get_itertools_module_context(),
"opentelemetry": {
"trace": get_otel_trace_module_context(),
},
}


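The base.py change mirrors the existing pytz/datetime/re/itertools helpers, so the public names from `opentelemetry.trace` should become reachable from Jinja under `modules.opentelemetry.trace`. A minimal sketch of that exposure, assuming `opentelemetry-api` is installed and using `jinja2` directly (already a dbt dependency); the query-comment idea and the template are illustrative and not part of the PR:

```python
from jinja2 import Template
from opentelemetry import trace

# Mirror of get_otel_trace_module_context(): export every public name from opentelemetry.trace.
otel_trace_module = {name: getattr(trace, name) for name in trace.__all__}
modules = {"opentelemetry": {"trace": otel_trace_module}}

# Illustrative only: stamp the current trace id into generated SQL as a query comment.
sql = Template(
    "-- otel trace_id: {{ '%032x' | format(get_current_span().get_span_context().trace_id) }}\n"
    "select 1 as id"
).render(get_current_span=modules["opentelemetry"]["trace"]["get_current_span"])
print(sql)
```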
134 changes: 72 additions & 62 deletions core/dbt/task/run.py
@@ -7,6 +7,9 @@
from multiprocessing.pool import ThreadPool
from typing import AbstractSet, Any, Dict, Iterable, List, Optional, Set, Tuple, Type

from opentelemetry import trace
from opentelemetry.trace import StatusCode

from dbt import tracking, utils
from dbt.adapters.base import BaseAdapter, BaseRelation
from dbt.adapters.capability import Capability
@@ -328,7 +331,6 @@ def execute(self, model, manifest):
)

hook_ctx = self.adapter.pre_model_hook(context_config)

return self._execute_model(hook_ctx, context_config, model, context, materialization_macro)


@@ -699,6 +701,7 @@ def __init__(
) -> None:
super().__init__(args, config, manifest)
self.batch_map = batch_map
self._dbt_tracer = trace.get_tracer("com.dbt.runner")

def raise_on_first_error(self) -> bool:
return False
@@ -887,77 +890,83 @@ def safe_run_hooks(
failed = False
num_hooks = len(ordered_hooks)

for idx, hook in enumerate(ordered_hooks, 1):
with log_contextvars(node_info=hook.node_info):
hook.index = idx
hook_name = f"{hook.package_name}.{hook_type}.{hook.index - 1}"
execution_time = 0.0
timing: List[TimingInfo] = []
failures = 1

if not failed:
with collect_timing_info("compile", timing.append):
sql = self.get_hook_sql(
adapter, hook, hook.index, num_hooks, extra_context
if num_hooks == 0:
return status

with self._dbt_tracer.start_as_current_span(hook_type) as hook_span:
for idx, hook in enumerate(ordered_hooks, 1):
with log_contextvars(node_info=hook.node_info):
hook.index = idx
hook_name = f"{hook.package_name}.{hook_type}.{hook.index - 1}"
execution_time = 0.0
timing: List[TimingInfo] = []
failures = 1

if not failed:
with collect_timing_info("compile", timing.append):
sql = self.get_hook_sql(
adapter, hook, hook.index, num_hooks, extra_context
)

started_at = timing[0].started_at or datetime.utcnow()
hook.update_event_status(
started_at=started_at.isoformat(), node_status=RunningStatus.Started
)

started_at = timing[0].started_at or datetime.utcnow()
hook.update_event_status(
started_at=started_at.isoformat(), node_status=RunningStatus.Started
fire_event(
LogHookStartLine(
statement=hook_name,
index=hook.index,
total=num_hooks,
node_info=hook.node_info,
)
)

with collect_timing_info("execute", timing.append):
status, message = get_execution_status(sql, adapter)

finished_at = timing[1].completed_at or datetime.utcnow()
hook.update_event_status(finished_at=finished_at.isoformat())
execution_time = (finished_at - started_at).total_seconds()
failures = 0 if status == RunStatus.Success else 1

if status == RunStatus.Success:
message = f"{hook_name} passed"
else:
message = f"{hook_name} failed, error:\n {message}"
failed = True
hook_span.set_status(StatusCode.ERROR)
else:
status = RunStatus.Skipped
message = f"{hook_name} skipped"

hook.update_event_status(node_status=status)
hook_span.set_attribute("node.status", status.value)

self.node_results.append(
RunResult(
status=status,
thread_id="main",
timing=timing,
message=message,
adapter_response={},
execution_time=execution_time,
failures=failures,
node=hook,
)
)

fire_event(
LogHookStartLine(
LogHookEndLine(
statement=hook_name,
status=status,
index=hook.index,
total=num_hooks,
execution_time=execution_time,
node_info=hook.node_info,
)
)

with collect_timing_info("execute", timing.append):
status, message = get_execution_status(sql, adapter)

finished_at = timing[1].completed_at or datetime.utcnow()
hook.update_event_status(finished_at=finished_at.isoformat())
execution_time = (finished_at - started_at).total_seconds()
failures = 0 if status == RunStatus.Success else 1

if status == RunStatus.Success:
message = f"{hook_name} passed"
else:
message = f"{hook_name} failed, error:\n {message}"
failed = True
else:
status = RunStatus.Skipped
message = f"{hook_name} skipped"

hook.update_event_status(node_status=status)

self.node_results.append(
RunResult(
status=status,
thread_id="main",
timing=timing,
message=message,
adapter_response={},
execution_time=execution_time,
failures=failures,
node=hook,
)
)

fire_event(
LogHookEndLine(
statement=hook_name,
status=status,
index=hook.index,
total=num_hooks,
execution_time=execution_time,
node_info=hook.node_info,
)
)

if hook_type == RunHookType.Start and ordered_hooks:
fire_event(Formatting(""))

@@ -991,8 +1000,9 @@ def before_run(self, adapter: BaseAdapter, selected_uids: AbstractSet[str]) -> RunStatus:
with adapter.connection_named("master"):
self.defer_to_manifest()
required_schemas = self.get_model_schemas(adapter, selected_uids)
self.create_schemas(adapter, required_schemas)
self.populate_adapter_cache(adapter, required_schemas)
with self._dbt_tracer.start_as_current_span("metadata setup") as _:
self.create_schemas(adapter, required_schemas)
self.populate_adapter_cache(adapter, required_schemas)
self.populate_microbatch_batches(selected_uids)
group_lookup.init(self.manifest, selected_uids)
run_hooks_status = self.safe_run_hooks(adapter, RunHookType.Start, {})
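In run.py, safe_run_hooks now returns early when there are no hooks and otherwise wraps the whole on-run-start / on-run-end batch in a single span named after the hook type, recording each hook's final status as a `node.status` attribute and marking the span ERROR as soon as one hook fails. A minimal, self-contained sketch of that pattern (in-memory exporter and a made-up `run_hook` helper; this is not the PR's code):

```python
from opentelemetry import trace
from opentelemetry.trace import StatusCode
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

provider = TracerProvider()
exporter = InMemorySpanExporter()
provider.add_span_processor(SimpleSpanProcessor(exporter))
trace.set_tracer_provider(provider)

tracer = trace.get_tracer("com.dbt.runner")  # same instrumentation scope name the PR uses

def run_hook(name: str) -> str:
    # Stand-in for get_execution_status(); pretend the second hook fails.
    return "success" if name != "bad_hook" else "error"

hooks = ["good_hook", "bad_hook", "never_runs"]
failed = False
with tracer.start_as_current_span("on-run-start") as hook_span:
    for name in hooks:
        status = run_hook(name) if not failed else "skipped"
        if status == "error":
            failed = True
            hook_span.set_status(StatusCode.ERROR)
        # One attribute per batch, overwritten on each iteration, as in the diff.
        hook_span.set_attribute("node.status", status)

print([(s.name, dict(s.attributes)) for s in exporter.get_finished_spans()])
```

Note that `node.status` is overwritten on every iteration, so the exported hook span ends up carrying the last hook's status; that mirrors the diff as written.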
41 changes: 34 additions & 7 deletions core/dbt/task/runnable.py
@@ -7,6 +7,9 @@
from pathlib import Path
from typing import AbstractSet, Dict, Iterable, List, Optional, Set, Tuple, Type, Union

from opentelemetry import context, trace
from opentelemetry.trace import Link, SpanContext, StatusCode

import dbt.exceptions
import dbt.tracking
import dbt.utils
@@ -91,6 +94,8 @@ def __init__(self, args: Flags, config: RuntimeConfig, manifest: Manifest) -> None:
self.previous_defer_state: Optional[PreviousState] = None
self.run_count: int = 0
self.started_at: float = 0
self._node_span_context_mapping: Dict[str, SpanContext] = {}
self._dbt_tracer = trace.get_tracer("com.dbt.runner")

if self.args.state:
self.previous_state = PreviousState(
@@ -222,14 +227,28 @@ def get_runner(self, node) -> BaseRunner:

return cls(self.config, adapter, node, run_count, num_nodes)

def call_runner(self, runner: BaseRunner) -> RunResult:
with log_contextvars(node_info=runner.node.node_info):
def call_runner(self, runner: BaseRunner, parent_context=None) -> RunResult:

Review comment: Why is the context being passed in explicitly?

Author reply: contextvars don't seem to be propagated automatically to thread-pool workers, so the active OpenTelemetry context has to be handed to the runner manually (see the propagation sketch after this file's diff).

node_info = runner.node.node_info
links = []
if hasattr(runner.node.depends_on, "nodes"):
for parent_node in runner.node.depends_on.nodes:
if parent_node in self._node_span_context_mapping:
links.append(
Link(
self._node_span_context_mapping[parent_node],
{"parent_model_fqn": parent_node},
),
)
with log_contextvars(node_info=node_info), self._dbt_tracer.start_as_current_span(
node_info["unique_id"], context=parent_context, links=links
) as node_span:
self._node_span_context_mapping[node_info["unique_id"]] = node_span.get_span_context()
runner.node.update_event_status(
started_at=datetime.utcnow().isoformat(), node_status=RunningStatus.Started
)
fire_event(
NodeStart(
node_info=runner.node.node_info,
node_info=node_info,
)
)
try:
Expand All @@ -242,10 +261,16 @@ def call_runner(self, runner: BaseRunner) -> RunResult:
result = None
thread_exception = e
finally:
if result.status in (NodeStatus.Error, NodeStatus.Fail, NodeStatus.PartialSuccess):
node_span.set_status(StatusCode.ERROR)
node_span.set_attribute("node.status", result.status.value)
node_span.set_attribute("node.materialization", node_info["materialized"])
node_span.set_attribute("node.database", node_info["node_relation"]["database"])
node_span.set_attribute("node.schema", node_info["node_relation"]["schema"])
if result is not None:
fire_event(
NodeFinished(
node_info=runner.node.node_info,
node_info=node_info,
run_result=result.to_msg_dict(),
)
)
@@ -256,7 +281,7 @@ def call_runner(self, runner: BaseRunner) -> RunResult:
GenericExceptionOnRun(
unique_id=runner.node.unique_id,
exc=str(thread_exception),
node_info=runner.node.node_info,
node_info=node_info,
)
)

@@ -304,6 +329,7 @@ def _submit(self, pool, args, callback):

This does still go through the callback path for result collection.
"""
args.append(context.get_current())
if self.config.args.single_threaded:
callback(self.call_runner(*args))
else:
@@ -501,7 +527,8 @@ def populate_adapter_cache(
def before_run(self, adapter: BaseAdapter, selected_uids: AbstractSet[str]) -> RunStatus:
with adapter.connection_named("master"):
self.defer_to_manifest()
self.populate_adapter_cache(adapter)
with self._dbt_tracer.start_as_current_span("metadata setup") as _:
self.populate_adapter_cache(adapter)
return RunStatus.Success

def after_run(self, adapter, results) -> None:
@@ -684,9 +711,9 @@ def create_schema(relation: BaseRelation) -> None:

list_futures = []
create_futures = []

# TODO: following has a mypy issue because profile and project config
# defines threads as int and HasThreadingConfig defines it as Optional[int]

with dbt_common.utils.executor(self.config) as tpe: # type: ignore
for req in required_databases:
if req.database is None:
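In runnable.py, each node run gets its own span named after the node's unique_id, linked back to the spans of its `depends_on` parents, and `_submit` appends `context.get_current()` to the runner arguments because, as the author notes in the review thread above, contextvars are not propagated to ThreadPool workers. A minimal sketch of that propagation problem and the fix, using an illustrative `run_node` helper rather than the real BaseRunner:

```python
from multiprocessing.pool import ThreadPool

from opentelemetry import context, trace
from opentelemetry.sdk.trace import TracerProvider

trace.set_tracer_provider(TracerProvider())
tracer = trace.get_tracer("com.dbt.runner")

def run_node(unique_id: str, parent_context):
    # Without parent_context, this span would start a brand-new trace in the worker thread,
    # because the worker does not see the caller's contextvars.
    with tracer.start_as_current_span(unique_id, context=parent_context) as span:
        return span.get_span_context().trace_id

with tracer.start_as_current_span("invocation") as root:
    parent = context.get_current()  # what _submit appends to the runner args
    with ThreadPool(2) as pool:
        child_trace_id = pool.apply(run_node, ("model.test.models", parent))

assert child_trace_id == root.get_span_context().trace_id  # same trace despite the thread hop
```

The PR does the same thing by capturing `context.get_current()` in `_submit` and accepting `parent_context` in `call_runner`; the links to upstream models come from `_node_span_context_mapping`, which is populated as each node span starts.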
5 changes: 4 additions & 1 deletion dev-requirements.txt
@@ -1,7 +1,8 @@
git+https://github.com/dbt-labs/dbt-adapters.git@main#subdirectory=dbt-adapters
git+https://github.com/sfc-gh-vguttha/dbt-adapters.git@vguttha-add-telemetry#subdirectory=dbt-adapters
git+https://github.com/dbt-labs/dbt-adapters.git@main#subdirectory=dbt-tests-adapter
git+https://github.com/dbt-labs/dbt-common.git@main
git+https://github.com/dbt-labs/dbt-adapters.git@main#subdirectory=dbt-postgres
git+https://github.com/sfc-gh-vguttha/dbt-adapters.git@vguttha-add-telemetry#subdirectory=dbt-snowflake
# black must match what's in .pre-commit-config.yaml to be sure local env matches CI
Author comment (on lines +1 to +5): FYI, these requirement changes will be reverted once my testing is complete.

black==24.3.0
bumpversion
@@ -38,3 +39,5 @@ types-pytz
types-requests
types-setuptools
mocker
opentelemetry-api
opentelemetry-sdk
34 changes: 34 additions & 0 deletions tests/functional/dbt_runner/test_dbt_runner.py
@@ -2,6 +2,11 @@
from unittest import mock

import pytest
from opentelemetry import trace
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

from dbt.adapters.factory import FACTORY, reset_adapters
from dbt.cli.exceptions import DbtUsageException
@@ -153,6 +158,7 @@ class TestDbtRunnerHooks:
def models(self):
return {
"models.sql": "select 1 as id",
"model2.sql": "select * from {{ ref('models') }}",
}

@pytest.fixture(scope="class")
@@ -163,3 +169,31 @@ def test_node_info_non_persistence(self, project):
dbt = dbtRunner()
dbt.invoke(["run", "--select", "models"])
assert get_node_info() == {}

def test_dbt_runner_spans(self, project):
tracer_provider = TracerProvider(resource=Resource.get_empty())
span_exporter = InMemorySpanExporter()
trace.set_tracer_provider(tracer_provider)
trace.get_tracer_provider().add_span_processor(SimpleSpanProcessor(span_exporter))
dbt = dbtRunner()
dbt.invoke(["run", "--select", "models", "model2"])
assert get_node_info() == {}
exported_spans = span_exporter.get_finished_spans()
assert len(exported_spans) == 3
assert exported_spans[0].instrumentation_scope.name == "com.dbt.runner"
span_names = [span.name for span in exported_spans]
span_names.sort()
assert span_names == ["model.test.model2", "model.test.models", "on-run-end"]
model2_span = None
models_span = None
for span in exported_spans:
if span.name == "model.test.model2":
model2_span = span
if span.name == "model.test.models":
models_span = span

# verify span links
assert len(model2_span.links) == 1
assert model2_span.links[0].attributes["parent_model_fqn"] == "model.test.models"
assert model2_span.links[0].context.span_id == models_span.context.span_id
assert model2_span.links[0].context.trace_id == models_span.context.trace_id
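The functional test installs an InMemorySpanExporter on a fresh TracerProvider before invoking dbt, then asserts on the three finished spans and the link from model2 back to models. This draft does not configure an exporter itself, so a host process embedding dbtRunner would wire one up the same way; a minimal sketch using ConsoleSpanExporter (the project path and selection are illustrative assumptions):

```python
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor

from dbt.cli.main import dbtRunner

provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))
# Install the provider before invoking dbt so the "com.dbt.runner" tracer binds to it.
trace.set_tracer_provider(provider)

dbt = dbtRunner()
# Assumes the working directory is a dbt project; node, hook, and "metadata setup"
# spans are printed to stdout as they finish.
dbt.invoke(["run"])
```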