From 3a584ac8fdbb3081fd2261b9e0de50db7be19e67 Mon Sep 17 00:00:00 2001 From: Lei Lei Date: Wed, 11 Mar 2026 13:42:34 +0800 Subject: [PATCH 1/2] first migration. --- .claude/skills/hta-cli/SKILL.md | 81 + .../skills/hta-cli/references/subcommands.md | 528 +++++ .github/workflows/ci-go.yml | 29 + .gitignore | 0 README.md | 0 cmd/tracepyre/main.go | 930 +++++++++ go.mod | 17 + go.sum | 23 + pkg/analysis/callstack.go | 125 ++ pkg/analysis/criticalpath/critical_path.go | 1726 +++++++++++++++++ .../criticalpath/critical_path_test.go | 158 ++ pkg/analysis/helpers_test.go | 20 + pkg/analysis/intervals.go | 32 + pkg/analysis/intervals_test.go | 60 + pkg/analysis/kernel/annotation.go | 131 ++ pkg/analysis/kernel/annotation_test.go | 85 + pkg/analysis/kernel/aten_delay.go | 308 +++ pkg/analysis/kernel/aten_delay_test.go | 58 + pkg/analysis/kernel/helpers_test.go | 60 + pkg/analysis/kernel/kernel_breakdown.go | 268 +++ pkg/analysis/kernel/kernel_breakdown_test.go | 214 ++ pkg/analysis/kernel/kernel_sequences.go | 500 +++++ pkg/analysis/kernel/kernel_sequences_test.go | 230 +++ pkg/analysis/kernel/launch_stats.go | 105 + pkg/analysis/kernel/launch_stats_test.go | 88 + pkg/analysis/kernel/testmain_test.go | 86 + pkg/analysis/kerneltype.go | 66 + pkg/analysis/kerneltype_test.go | 36 + pkg/analysis/mathutil.go | 30 + pkg/analysis/names.go | 55 + pkg/analysis/names_test.go | 61 + pkg/analysis/profiler_steps.go | 42 + pkg/analysis/profiler_steps_test.go | 88 + pkg/analysis/resource/cupti_counters.go | 201 ++ pkg/analysis/resource/cupti_counters_test.go | 160 ++ pkg/analysis/resource/helpers_test.go | 61 + pkg/analysis/resource/memory_bw.go | 299 +++ pkg/analysis/resource/memory_bw_test.go | 170 ++ pkg/analysis/resource/queue_length.go | 377 ++++ pkg/analysis/resource/queue_length_test.go | 197 ++ pkg/analysis/resource/testmain_test.go | 86 + pkg/analysis/resource/trace_with_counters.go | 284 +++ .../resource/trace_with_counters_test.go | 165 ++ pkg/analysis/straggler/straggler.go | 417 ++++ pkg/analysis/straggler/straggler_test.go | 138 ++ pkg/analysis/temporal/idle_time.go | 226 +++ pkg/analysis/temporal/idle_time_test.go | 128 ++ pkg/analysis/temporal/overlap.go | 126 ++ pkg/analysis/temporal/overlap_test.go | 65 + pkg/analysis/temporal/temporal.go | 103 + pkg/analysis/temporal/temporal_test.go | 105 + pkg/pipeline/preprocess.go | 351 ++++ pkg/store/reader.go | 889 +++++++++ pkg/store/schema.go | 193 ++ pkg/symbol/table.go | 66 + pkg/symbol/table_test.go | 49 + pkg/trace/parser.go | 318 +++ 57 files changed, 11414 insertions(+) create mode 100644 .claude/skills/hta-cli/SKILL.md create mode 100644 .claude/skills/hta-cli/references/subcommands.md create mode 100644 .github/workflows/ci-go.yml create mode 100644 .gitignore create mode 100644 README.md create mode 100644 cmd/tracepyre/main.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 pkg/analysis/callstack.go create mode 100644 pkg/analysis/criticalpath/critical_path.go create mode 100644 pkg/analysis/criticalpath/critical_path_test.go create mode 100644 pkg/analysis/helpers_test.go create mode 100644 pkg/analysis/intervals.go create mode 100644 pkg/analysis/intervals_test.go create mode 100644 pkg/analysis/kernel/annotation.go create mode 100644 pkg/analysis/kernel/annotation_test.go create mode 100644 pkg/analysis/kernel/aten_delay.go create mode 100644 pkg/analysis/kernel/aten_delay_test.go create mode 100644 pkg/analysis/kernel/helpers_test.go create mode 100644 pkg/analysis/kernel/kernel_breakdown.go create mode 100644 
pkg/analysis/kernel/kernel_breakdown_test.go create mode 100644 pkg/analysis/kernel/kernel_sequences.go create mode 100644 pkg/analysis/kernel/kernel_sequences_test.go create mode 100644 pkg/analysis/kernel/launch_stats.go create mode 100644 pkg/analysis/kernel/launch_stats_test.go create mode 100644 pkg/analysis/kernel/testmain_test.go create mode 100644 pkg/analysis/kerneltype.go create mode 100644 pkg/analysis/kerneltype_test.go create mode 100644 pkg/analysis/mathutil.go create mode 100644 pkg/analysis/names.go create mode 100644 pkg/analysis/names_test.go create mode 100644 pkg/analysis/profiler_steps.go create mode 100644 pkg/analysis/profiler_steps_test.go create mode 100644 pkg/analysis/resource/cupti_counters.go create mode 100644 pkg/analysis/resource/cupti_counters_test.go create mode 100644 pkg/analysis/resource/helpers_test.go create mode 100644 pkg/analysis/resource/memory_bw.go create mode 100644 pkg/analysis/resource/memory_bw_test.go create mode 100644 pkg/analysis/resource/queue_length.go create mode 100644 pkg/analysis/resource/queue_length_test.go create mode 100644 pkg/analysis/resource/testmain_test.go create mode 100644 pkg/analysis/resource/trace_with_counters.go create mode 100644 pkg/analysis/resource/trace_with_counters_test.go create mode 100644 pkg/analysis/straggler/straggler.go create mode 100644 pkg/analysis/straggler/straggler_test.go create mode 100644 pkg/analysis/temporal/idle_time.go create mode 100644 pkg/analysis/temporal/idle_time_test.go create mode 100644 pkg/analysis/temporal/overlap.go create mode 100644 pkg/analysis/temporal/overlap_test.go create mode 100644 pkg/analysis/temporal/temporal.go create mode 100644 pkg/analysis/temporal/temporal_test.go create mode 100644 pkg/pipeline/preprocess.go create mode 100644 pkg/store/reader.go create mode 100644 pkg/store/schema.go create mode 100644 pkg/symbol/table.go create mode 100644 pkg/symbol/table_test.go create mode 100644 pkg/trace/parser.go

diff --git a/.claude/skills/hta-cli/SKILL.md b/.claude/skills/hta-cli/SKILL.md
new file mode 100644
index 0000000..cfed73e
--- /dev/null
+++ b/.claude/skills/hta-cli/SKILL.md
@@ -0,0 +1,81 @@
+---
+name: pytorch-profile
+description: >-
+  Holistic trace analysis (HTA) gives insight into distributed training with PyTorch.
+  It should be used when the user asks to "analyse pytorch trace",
+  or mentions any subcommand like temporal-breakdown, comm-comp-overlap,
+  gpu-kernel-breakdown, idle-time-breakdown, critical-path, queue-length, etc.
+---
+
+# PyTorch Profile Data
+
+The hta CLI (`python -m hta`) exposes every major trace analysis as a standalone subcommand. It is designed for CI pipelines, shell scripts, and quick interactive analysis without notebooks.
+
+## Two-Step Workflow
+
+All CLI usage follows a **pre-process then analyze** pattern:
+
+```bash
+# Step 1: Parse raw PyTorch Profiler traces into parquet
+python -m hta pre-process --trace-dir ./raw_traces -o ./preprocessed
+
+# Step 2: Run any analysis subcommand on the preprocessed directory
+python -m hta temporal-breakdown -i ./preprocessed
+python -m hta idle-time-breakdown -i ./preprocessed --ranks 0,1
+```
+
+**Step 1 (`pre-process`)** reads raw JSON traces from `--trace-dir` and writes one `.parquet` file per rank plus a `metadata.json` into `-o`. This only needs to run once per trace set.
+
+**Step 2 (any analysis subcommand)** reads from the pre-processed directory via `-i` / `--input`. Most subcommands print markdown tables to stdout.
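+
+For CI-style scripted runs the two steps chain naturally. A minimal sketch (the directory and report names are placeholders, not part of the CLI):
+
+```bash
+#!/usr/bin/env bash
+set -euo pipefail
+
+# Parse once, then fan out across several analyses.
+python -m hta pre-process --trace-dir ./raw_traces -o ./preprocessed
+
+mkdir -p reports
+for sub in temporal-breakdown comm-comp-overlap gpu-kernel-breakdown; do
+  python -m hta "$sub" -i ./preprocessed > "reports/${sub}.md"
+done
+```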
+
+## Subcommand Quick Reference
+
+| Subcommand | Description | Key Args (besides `-i`) | Category |
+|---|---|---|---|
+| `pre-process` | Parse raw traces to parquet | `--trace-dir`, `-o` (both required) | Preprocessing |
+| `temporal-breakdown` | Time breakdown (compute, comm, idle) per rank | — | Overview |
+| `comm-comp-overlap` | Communication/computation overlap per rank | — | Overview |
+| `profiler-steps` | List profiler step indices | — | Overview |
+| `potential-stragglers` | Identify slow ranks | `--num-candidates`, `--profiler-steps` | Overview |
+| `gpu-kernel-breakdown` | GPU time by kernel type + top kernels | `--num-kernels`, `--duration-ratio`, `--no-memory-kernels` | GPU Kernels |
+| `gpu-kernels-with-annotations` | GPU kernels with user annotation context | `--rank` (required) | GPU Kernels |
+| `gpu-user-annotation-breakdown` | GPU/CPU time by user annotations | `--cpu`, `--duration-ratio`, `--num-kernels`, `--allowlist-patterns` | GPU Kernels |
+| `frequent-cuda-kernel-sequences` | Frequent CUDA kernel patterns per operator | `--operator-name`, `--output-dir` (both required), `--top-k`, `--rank` | GPU Kernels |
+| `aten-op-kernels-and-delay` | ATen op to GPU kernel mapping with launch delay | `--ranks`, `--sort-by` | GPU Kernels |
+| `cuda-kernel-launch-stats` | CUDA kernel launch duration and delay stats | `--ranks`, `--runtime-cutoff`, `--launch-delay-cutoff` | GPU Kernels |
+| `generate-trace-with-counters` | Augmented trace with queue length / memory BW counters | `--ranks`, `--time-series`, `--output-suffix` | Counters |
+| `queue-length-summary` | Queue length summary stats per rank | `--ranks` | Counters |
+| `queue-length-time-series` | Full queue length time series per rank | `--ranks` | Counters |
+| `blocked-on-full-queue` | Time CPU blocked on full GPU queue | `--ranks`, `--max-queue-length` | Counters |
+| `memory-bw-summary` | Memory bandwidth summary stats per rank | `--ranks` | Counters |
+| `memory-bw-time-series` | Full memory bandwidth time series per rank | `--ranks` | Counters |
+| `idle-time-breakdown` | GPU idle time by category per rank/stream | `--ranks`, `--streams`, `--show-idle-interval-stats` | Idle Time |
+| `cupti-counter-data` | CUPTI hardware counter data with operators | `--ranks` | CUPTI |
+| `critical-path` | Critical path analysis with trace overlay | `--rank`, `--annotation`, `--instance-id`, `--output-dir` (all required) | Critical Path |
+
+## Common Patterns
+
+**Filtering by rank:** Most analysis subcommands accept `--ranks` as a comma-separated list (e.g., `--ranks 0,1,3`). If omitted, all ranks are analyzed.
+
+**Output format:** Most subcommands print markdown tables to stdout. Pipe to a file or use in scripts:
+```bash
+python -m hta temporal-breakdown -i ./preprocessed > results.md
+```
+
+**Getting help:** Run `python -m hta --help` to list all subcommands, or `python -m hta <subcommand> --help` for help on a specific one.
+
+**Running via uv:** In this project, prefix commands with `uv run`:
+```bash
+uv run python -m hta pre-process --trace-dir ./traces -o ./preprocessed
+uv run python -m hta temporal-breakdown -i ./preprocessed
+```
+
+## Key Source Files
+
+- `hta/__main__.py` — CLI implementation (argument parsing and subcommand handlers)
+- `hta/trace_analysis.py` — `TraceAnalysis` class that backs every subcommand
+- `docs/cli-guide.md` — Human-facing CLI documentation
+
+## Additional Resources
+
+For full argument tables, types, defaults, and detailed output descriptions for every subcommand, see `references/subcommands.md`.
diff --git a/.claude/skills/hta-cli/references/subcommands.md b/.claude/skills/hta-cli/references/subcommands.md
new file mode 100644
index 0000000..a8b9861
--- /dev/null
+++ b/.claude/skills/hta-cli/references/subcommands.md
@@ -0,0 +1,528 @@
+# HTA CLI Subcommand Reference
+
+Full argument tables, examples, and output descriptions for all 20 HTA CLI subcommands.
+
+Source of truth: `hta/__main__.py` (argument definitions), `docs/cli-guide.md` (human-facing docs).
+
+---
+
+## Preprocessing
+
+### `pre-process`
+
+Parse raw PyTorch Profiler traces and save them as parquet for fast repeated analysis.
+
+```bash
+python -m hta pre-process --trace-dir <trace_dir> -o <output_dir> [--include-last-profiler-step]
+```
+
+| Argument | Type | Required | Default | Description |
+|---|---|---|---|---|
+| `--trace-dir` | str | yes | — | Path to directory containing raw trace JSON files |
+| `-o` / `--output` | str | yes | — | Output directory for parquet files and metadata.json |
+| `--include-last-profiler-step` | flag | no | false | Include the last profiler step (excluded by default) |
+
+**Example:**
+```bash
+python -m hta pre-process --trace-dir ./raw_traces -o ./preprocessed
+```
+
+**Output:** One `.parquet` file per rank and a `metadata.json` in the output directory. Prints a confirmation message.
+
+---
+
+## Overview Analysis
+
+### `temporal-breakdown`
+
+Show how time is spent (compute, communication, idle, etc.) for each rank.
+
+See: `docs/source/features/temporal_breakdown.rst`
+
+```bash
+python -m hta temporal-breakdown -i <input_dir>
+```
+
+| Argument | Type | Required | Default | Description |
+|---|---|---|---|---|
+| `-i` / `--input` | str | yes | — | Path to pre-processed trace directory |
+
+**Example:**
+```bash
+python -m hta temporal-breakdown -i ./preprocessed
+```
+
+**Output:** Markdown table with one row per rank showing time percentages for each category (idle, compute, communication, etc.).
+
+---
+
+### `comm-comp-overlap`
+
+Show the overlap between communication and computation for each rank.
+
+See: `docs/source/features/comm_comp_overlap.rst`
+
+```bash
+python -m hta comm-comp-overlap -i <input_dir>
+```
+
+| Argument | Type | Required | Default | Description |
+|---|---|---|---|---|
+| `-i` / `--input` | str | yes | — | Path to pre-processed trace directory |
+
+**Example:**
+```bash
+python -m hta comm-comp-overlap -i ./preprocessed
+```
+
+**Output:** Markdown table with overlap percentages per rank.
+
+---
+
+### `profiler-steps`
+
+List the profiler step indices found in the trace.
+
+```bash
+python -m hta profiler-steps -i <input_dir>
+```
+
+| Argument | Type | Required | Default | Description |
+|---|---|---|---|---|
+| `-i` / `--input` | str | yes | — | Path to pre-processed trace directory |
+
+**Example:**
+```bash
+python -m hta profiler-steps -i ./preprocessed
+# 2,3,4,5,6
+```
+
+**Output:** Comma-separated list of profiler step integers printed to stdout.
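+
+Because the step list is plain comma-separated text, it pipes straight into flags that accept step indices, such as `--profiler-steps` on `potential-stragglers` (documented below). A minimal sketch (bash, paths are placeholders):
+
+```bash
+# Feed the discovered step indices into another analysis.
+STEPS=$(python -m hta profiler-steps -i ./preprocessed)
+python -m hta potential-stragglers -i ./preprocessed --profiler-steps "$STEPS"
+```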
+
+---
+
+### `potential-stragglers`
+
+Identify ranks that are potential stragglers (slower than peers).
+
+```bash
+python -m hta potential-stragglers -i <input_dir> [--num-candidates N] [--profiler-steps STEPS]
+```
+
+| Argument | Type | Required | Default | Description |
+|---|---|---|---|---|
+| `-i` / `--input` | str | yes | — | Path to pre-processed trace directory |
+| `--num-candidates` | int | no | None | Maximum number of straggler candidates to return |
+| `--profiler-steps` | str | no | None | Comma-separated profiler step indices to analyze |
+
+**Example:**
+```bash
+python -m hta potential-stragglers -i ./preprocessed --num-candidates 2
+# 3,7
+```
+
+**Output:** Comma-separated list of rank IDs that are potential stragglers.
+
+---
+
+## GPU Kernel Analysis
+
+### `gpu-kernel-breakdown`
+
+Break down GPU time by kernel type (computation, communication, memory) and list the top kernels.
+
+See: `docs/source/features/kernel_breakdown.rst`
+
+```bash
+python -m hta gpu-kernel-breakdown -i <input_dir> [--duration-ratio R] [--num-kernels N] [--no-memory-kernels]
+```
+
+| Argument | Type | Required | Default | Description |
+|---|---|---|---|---|
+| `-i` / `--input` | str | yes | — | Path to pre-processed trace directory |
+| `--duration-ratio` | float | no | None | Minimum fraction of total duration for a kernel to be included |
+| `--num-kernels` | int | no | None | Maximum number of top kernels to show |
+| `--no-memory-kernels` | flag | no | false | Exclude memory-related kernels from the breakdown |
+
+**Example:**
+```bash
+python -m hta gpu-kernel-breakdown -i ./preprocessed --num-kernels 10
+```
+
+**Output:** Two markdown tables:
+1. **Kernel Type Breakdown** — time per kernel category (compute, communication, memory)
+2. **Top Kernels** — individual kernel durations and counts
+
+---
+
+### `gpu-kernels-with-annotations`
+
+List GPU kernels annotated with their user-defined annotation context (e.g., forward/backward/optimizer).
+
+See: `docs/source/features/kernel_breakdown.rst` (related)
+
+```bash
+python -m hta gpu-kernels-with-annotations -i <input_dir> --rank R [--no-expand-names] [--no-shorten-names]
+```
+
+| Argument | Type | Required | Default | Description |
+|---|---|---|---|---|
+| `-i` / `--input` | str | yes | — | Path to pre-processed trace directory |
+| `--rank` | int | yes | — | Rank to analyze |
+| `--no-expand-names` | flag | no | false | Do not expand kernel names |
+| `--no-shorten-names` | flag | no | false | Do not shorten kernel names |
+
+**Example:**
+```bash
+python -m hta gpu-kernels-with-annotations -i ./preprocessed --rank 0
+```
+
+**Output:** Markdown table with one row per GPU kernel, including its user annotation context.
+
+---
+
+### `gpu-user-annotation-breakdown`
+
+Break down GPU (or CPU) time by user-defined annotations.
+
+See: `docs/source/features/kernel_breakdown.rst` (related)
+
+```bash
+python -m hta gpu-user-annotation-breakdown -i <input_dir> [--cpu] [--duration-ratio R] [--num-kernels N] [--allowlist-patterns PAT ...]
+```
+
+| Argument | Type | Required | Default | Description |
+|---|---|---|---|---|
+| `-i` / `--input` | str | yes | — | Path to pre-processed trace directory |
+| `--cpu` | flag | no | false | Use CPU time instead of GPU time |
+| `--duration-ratio` | float | no | None | Minimum fraction of total duration for inclusion |
+| `--num-kernels` | int | no | None | Maximum number of entries to show |
+| `--allowlist-patterns` | str (multiple) | no | None | Annotation patterns to keep distinct (space-separated) |
+
+**Example:**
+```bash
+python -m hta gpu-user-annotation-breakdown -i ./preprocessed --duration-ratio 0.05
+```
+
+**Output:** Markdown table with time breakdown by user annotation.
+
+---
+
+### `frequent-cuda-kernel-sequences`
+
+Find frequently occurring sequences of CUDA kernels launched by a given operator.
+
+See: `docs/source/features/frequent_cuda_kernels.rst`
+
+```bash
+python -m hta frequent-cuda-kernel-sequences -i <input_dir> --operator-name NAME --output-dir DIR [--min-pattern-len N] [--rank R] [--top-k K]
+```
+
+| Argument | Type | Required | Default | Description |
+|---|---|---|---|---|
+| `-i` / `--input` | str | yes | — | Path to pre-processed trace directory |
+| `--operator-name` | str | yes | — | Name of the CPU operator to analyze |
+| `--output-dir` | str | yes | — | Directory for output files |
+| `--min-pattern-len` | int | no | None | Minimum length of kernel sequence patterns |
+| `--rank` | int | no | None | Specific rank to analyze |
+| `--top-k` | int | no | None | Number of top frequent patterns to return |
+
+**Example:**
+```bash
+python -m hta frequent-cuda-kernel-sequences -i ./preprocessed \
+    --operator-name aten::linear --output-dir ./freq_out --top-k 5
+```
+
+**Output:** Markdown table of frequent kernel sequence patterns with their counts.
+
+---
+
+### `aten-op-kernels-and-delay`
+
+Map ATen operators to their launched GPU kernels, showing launch delay.
+
+See: `docs/source/features/cuda_kernel_launch_stats.rst` (related)
+
+```bash
+python -m hta aten-op-kernels-and-delay -i <input_dir> [--ranks RANKS] [--sort-by COLS]
+```
+
+| Argument | Type | Required | Default | Description |
+|---|---|---|---|---|
+| `-i` / `--input` | str | yes | — | Path to pre-processed trace directory |
+| `--ranks` | str | no | None | Comma-separated ranks |
+| `--sort-by` | str | no | None | Comma-separated column names to sort by |
+
+**Example:**
+```bash
+python -m hta aten-op-kernels-and-delay -i ./preprocessed --ranks 0 --sort-by "duration"
+```
+
+**Output:** Per-rank markdown tables mapping ATen ops to GPU kernels with delay statistics.
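+
+`--sort-by` accepts multiple comma-separated columns. A sketch; the column names here are illustrative (check the headers of your own output table for valid sort keys):
+
+```bash
+python -m hta aten-op-kernels-and-delay -i ./preprocessed \
+    --ranks 0,1 --sort-by occurrence_count,avg_aten_op_launch_delay
+```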
+
+---
+
+### `cuda-kernel-launch-stats`
+
+Compute statistics about CUDA kernel launches (durations, launch delays, short kernels).
+
+See: `docs/source/features/cuda_kernel_launch_stats.rst`
+
+```bash
+python -m hta cuda-kernel-launch-stats -i <input_dir> [--ranks RANKS] [--runtime-cutoff N] [--launch-delay-cutoff N] [--no-memory-events]
+```
+
+| Argument | Type | Required | Default | Description |
+|---|---|---|---|---|
+| `-i` / `--input` | str | yes | — | Path to pre-processed trace directory |
+| `--ranks` | str | no | None | Comma-separated ranks |
+| `--runtime-cutoff` | int | no | None | Runtime threshold (microseconds) for flagging short kernels |
+| `--launch-delay-cutoff` | int | no | None | Launch delay threshold (microseconds) for flagging slow launches |
+| `--no-memory-events` | flag | no | false | Exclude memory events from the analysis |
+
+**Example:**
+```bash
+python -m hta cuda-kernel-launch-stats -i ./preprocessed --runtime-cutoff 10
+```
+
+**Output:** Per-rank markdown tables with kernel launch statistics.
+
+---
+
+## Augmented Counters (Queue Length & Memory Bandwidth)
+
+### `generate-trace-with-counters`
+
+Generate an augmented trace file with queue length and/or memory bandwidth counter time series embedded.
+
+See: `docs/source/features/augmented_counters.rst`
+
+```bash
+python -m hta generate-trace-with-counters -i <input_dir> [--ranks RANKS] [--time-series TYPE] [--output-suffix SUFFIX]
+```
+
+| Argument | Type | Required | Default | Description |
+|---|---|---|---|---|
+| `-i` / `--input` | str | yes | — | Path to pre-processed trace directory |
+| `--ranks` | str | no | None | Comma-separated ranks |
+| `--time-series` | str | no | None | Which counters: `queue_length`, `memcpy_bandwidth`, or `both` |
+| `--output-suffix` | str | no | None | Suffix appended to the output trace filename |
+
+**Example:**
+```bash
+python -m hta generate-trace-with-counters -i ./preprocessed --time-series both
+```
+
+**Output:** Augmented trace JSON file(s) in the original trace directory, viewable in `chrome://tracing` or Perfetto. Prints a confirmation message.
+
+---
+
+### `queue-length-summary`
+
+Show summary statistics of the CUDA stream queue length (min, max, mean, etc.) per rank.
+
+See: `docs/source/features/augmented_counters.rst`
+
+```bash
+python -m hta queue-length-summary -i <input_dir> [--ranks RANKS]
+```
+
+| Argument | Type | Required | Default | Description |
+|---|---|---|---|---|
+| `-i` / `--input` | str | yes | — | Path to pre-processed trace directory |
+| `--ranks` | str | no | None | Comma-separated ranks |
+
+**Example:**
+```bash
+python -m hta queue-length-summary -i ./preprocessed
+```
+
+**Output:** Markdown table with queue length statistics per rank.
+
+---
+
+### `queue-length-time-series`
+
+Get the full queue length time series (timestamp, queue_length) per rank.
+
+See: `docs/source/features/augmented_counters.rst`
+
+```bash
+python -m hta queue-length-time-series -i <input_dir> [--ranks RANKS]
+```
+
+| Argument | Type | Required | Default | Description |
+|---|---|---|---|---|
+| `-i` / `--input` | str | yes | — | Path to pre-processed trace directory |
+| `--ranks` | str | no | None | Comma-separated ranks |
+
+**Example:**
+```bash
+python -m hta queue-length-time-series -i ./preprocessed --ranks 0
+```
+
+**Output:** Per-rank markdown tables of (timestamp, queue_length) data points.
+
+---
+
+### `blocked-on-full-queue`
+
+Compute the time the CPU spent blocked because the GPU launch queue was full.
+
+See: `docs/source/features/augmented_counters.rst`
+
+```bash
+python -m hta blocked-on-full-queue -i <input_dir> [--ranks RANKS] [--max-queue-length N]
+```
+
+| Argument | Type | Required | Default | Description |
+|---|---|---|---|---|
+| `-i` / `--input` | str | yes | — | Path to pre-processed trace directory |
+| `--ranks` | str | no | None | Comma-separated ranks |
+| `--max-queue-length` | int | no | None | Queue length considered "full" (default: NVIDIA limit of 1024) |
+
+**Example:**
+```bash
+python -m hta blocked-on-full-queue -i ./preprocessed --max-queue-length 1024
+```
+
+**Output:** Markdown table with blocking duration per rank.
+
+---
+
+### `memory-bw-summary`
+
+Show memory bandwidth summary statistics per rank.
+
+See: `docs/source/features/augmented_counters.rst`
+
+```bash
+python -m hta memory-bw-summary -i <input_dir> [--ranks RANKS]
+```
+
+| Argument | Type | Required | Default | Description |
+|---|---|---|---|---|
+| `-i` / `--input` | str | yes | — | Path to pre-processed trace directory |
+| `--ranks` | str | no | None | Comma-separated ranks |
+
+**Example:**
+```bash
+python -m hta memory-bw-summary -i ./preprocessed
+```
+
+**Output:** Markdown table with memory bandwidth statistics per rank.
+
+---
+
+### `memory-bw-time-series`
+
+Get the full memory bandwidth time series per rank.
+
+See: `docs/source/features/augmented_counters.rst`
+
+```bash
+python -m hta memory-bw-time-series -i <input_dir> [--ranks RANKS]
+```
+
+| Argument | Type | Required | Default | Description |
+|---|---|---|---|---|
+| `-i` / `--input` | str | yes | — | Path to pre-processed trace directory |
+| `--ranks` | str | no | None | Comma-separated ranks |
+
+**Example:**
+```bash
+python -m hta memory-bw-time-series -i ./preprocessed --ranks 0,1
+```
+
+**Output:** Per-rank markdown tables of memory bandwidth data points over time.
+
+---
+
+## Idle Time
+
+### `idle-time-breakdown`
+
+Break down GPU idle time by category (host wait, kernel wait, other) per rank and stream.
+
+See: `docs/source/features/idle_time_breakdown.rst`
+
+```bash
+python -m hta idle-time-breakdown -i <input_dir> [--ranks RANKS] [--streams STREAMS] [--show-idle-interval-stats] [--consecutive-kernel-delay N]
+```
+
+| Argument | Type | Required | Default | Description |
+|---|---|---|---|---|
+| `-i` / `--input` | str | yes | — | Path to pre-processed trace directory |
+| `--ranks` | str | no | None | Comma-separated ranks |
+| `--streams` | str | no | None | Comma-separated CUDA stream IDs |
+| `--show-idle-interval-stats` | flag | no | false | Also output statistics about individual idle intervals |
+| `--consecutive-kernel-delay` | int | no | None | Threshold (microseconds) for classifying gaps between consecutive kernels |
+
+**Example:**
+```bash
+python -m hta idle-time-breakdown -i ./preprocessed --show-idle-interval-stats
+```
+
+**Output:** Markdown table "Idle Time Breakdown" with idle time categories per rank/stream. If `--show-idle-interval-stats` is set, a second table "Idle Interval Statistics" is also printed.
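+
+The filters compose with the stats flag. A sketch (the stream ID is a placeholder; use stream IDs from your own trace):
+
+```bash
+python -m hta idle-time-breakdown -i ./preprocessed \
+    --ranks 0 --streams 7 --consecutive-kernel-delay 100 --show-idle-interval-stats
+```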
+
+---
+
+## CUPTI Counters
+
+### `cupti-counter-data`
+
+Extract CUPTI hardware performance counter data joined with operator information.
+
+See: `docs/source/features/cupti_counter_analysis.rst`
+
+```bash
+python -m hta cupti-counter-data -i <input_dir> [--ranks RANKS]
+```
+
+| Argument | Type | Required | Default | Description |
+|---|---|---|---|---|
+| `-i` / `--input` | str | yes | — | Path to pre-processed trace directory |
+| `--ranks` | str | no | None | Comma-separated ranks |
+
+**Example:**
+```bash
+python -m hta cupti-counter-data -i ./preprocessed --ranks 0
+```
+
+**Output:** Indexed markdown tables of CUPTI counter data with associated operator information.
+
+---
+
+## Critical Path
+
+### `critical-path`
+
+Run critical path analysis on a specific annotation instance and overlay the result onto a trace file.
+
+See: `docs/source/features/lightweight_critical_path_analysis.rst`
+
+```bash
+python -m hta critical-path -i <input_dir> --rank R --annotation ANN --instance-id ID --output-dir DIR [--data-load-events EVT ...] [--show-all-edges]
+```
+
+| Argument | Type | Required | Default | Description |
+|---|---|---|---|---|
+| `-i` / `--input` | str | yes | — | Path to pre-processed trace directory |
+| `--rank` | int | yes | — | Rank to analyze |
+| `--annotation` | str | yes | — | User annotation name (e.g., `ProfilerStep`) |
+| `--instance-id` | str | yes | — | Single int (e.g., `3`) or `start,end` range (e.g., `3,5`) |
+| `--output-dir` | str | yes | — | Directory for the overlaid trace output |
+| `--data-load-events` | str (multiple) | no | None | Names of data loading events (space-separated) |
+| `--show-all-edges` | flag | no | false | Show all edges in the overlaid trace, not just the critical path |
+
+**Example:**
+```bash
+python -m hta critical-path -i ./preprocessed \
+    --rank 0 --annotation ProfilerStep --instance-id 3 \
+    --output-dir ./cp_output
+```
+
+**Output:** Three sections printed to stdout:
+1. **Critical Path Summary** — high-level statistics (total time, breakdown percentages)
+2. **Critical Path Breakdown** — per-category time on the critical path
+3. **Overlaid trace path** — file path to the generated trace JSON with the critical path overlaid, viewable in `chrome://tracing` or Perfetto
diff --git a/.github/workflows/ci-go.yml b/.github/workflows/ci-go.yml
new file mode 100644
index 0000000..bc8c435
--- /dev/null
+++ b/.github/workflows/ci-go.yml
@@ -0,0 +1,29 @@
+name: Go CI
+
+on:
+  pull_request:
+    paths:
+      - '**/*.go'
+      - 'go.mod'
+      - 'go.sum'
+      - '.github/workflows/ci-go.yml'
+  push:
+    branches: ["main"]
+    paths:
+      - '**/*.go'
+      - 'go.mod'
+      - 'go.sum'
+      - '.github/workflows/ci-go.yml'
+
+jobs:
+  golang-ci:
+    runs-on: ubuntu-latest-m
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          submodules: recursive
+      - uses: actions/setup-go@v5
+        with:
+          go-version-file: go.mod
+      - run: go vet ./...
+      - run: go test -timeout 20m ./...
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e69de29 diff --git a/README.md b/README.md new file mode 100644 index 0000000..e69de29 diff --git a/cmd/tracepyre/main.go b/cmd/tracepyre/main.go new file mode 100644 index 0000000..27d75e0 --- /dev/null +++ b/cmd/tracepyre/main.go @@ -0,0 +1,930 @@ +package main + +import ( + "flag" + "fmt" + "log" + "os" + "sort" + "strconv" + "strings" + + "hta/pkg/analysis" + "hta/pkg/analysis/criticalpath" + "hta/pkg/analysis/kernel" + "hta/pkg/analysis/resource" + "hta/pkg/analysis/straggler" + "hta/pkg/analysis/temporal" + "hta/pkg/pipeline" + "hta/pkg/store" +) + +func main() { + log.SetFlags(log.Ltime) + + if len(os.Args) < 2 { + usage() + os.Exit(1) + } + + switch os.Args[1] { + case "pre-process": + cmdPreProcess(os.Args[2:]) + case "temporal-breakdown": + cmdTemporalBreakdown(os.Args[2:]) + case "gpu-kernel-breakdown": + cmdGPUKernelBreakdown(os.Args[2:]) + case "comm-comp-overlap": + cmdCommCompOverlap(os.Args[2:]) + case "profiler-steps": + cmdProfilerSteps(os.Args[2:]) + case "queue-length-summary": + cmdQueueLengthSummary(os.Args[2:]) + case "potential-stragglers": + cmdPotentialStragglers(os.Args[2:]) + case "cuda-kernel-launch-stats": + cmdCUDAKernelLaunchStats(os.Args[2:]) + case "queue-length-time-series": + cmdQueueLengthTimeSeries(os.Args[2:]) + case "idle-time-breakdown": + cmdIdleTimeBreakdown(os.Args[2:]) + case "blocked-on-full-queue": + cmdBlockedOnFullQueue(os.Args[2:]) + case "memory-bw-summary": + cmdMemoryBWSummary(os.Args[2:]) + case "memory-bw-time-series": + cmdMemoryBWTimeSeries(os.Args[2:]) + case "gpu-kernels-with-annotations": + cmdGPUKernelsWithAnnotations(os.Args[2:]) + case "generate-trace-with-counters": + cmdGenerateTraceWithCounters(os.Args[2:]) + case "aten-op-kernels-and-delay": + cmdAtenOpKernelsAndDelay(os.Args[2:]) + case "frequent-cuda-kernel-sequences": + cmdFrequentCUDAKernelSequences(os.Args[2:]) + case "critical-path": + cmdCriticalPath(os.Args[2:]) + case "cupti-counter-data": + cmdCUPTICounterData(os.Args[2:]) + default: + fmt.Fprintf(os.Stderr, "unknown subcommand: %s\n", os.Args[1]) + usage() + os.Exit(1) + } +} + +func usage() { + fmt.Fprintln(os.Stderr, "Usage: hta <subcommand> [flags]") + fmt.Fprintln(os.Stderr, " pre-process Parse traces → SQLite DB") + fmt.Fprintln(os.Stderr, " temporal-breakdown GPU temporal breakdown from DB") + fmt.Fprintln(os.Stderr, " gpu-kernel-breakdown GPU kernel breakdown from DB") + fmt.Fprintln(os.Stderr, " comm-comp-overlap Comm-comp overlap per rank from DB") + fmt.Fprintln(os.Stderr, " profiler-steps List profiler step numbers from DB") + fmt.Fprintln(os.Stderr, " queue-length-summary Queue length statistics per stream from DB") + fmt.Fprintln(os.Stderr, " potential-stragglers Identify potential straggler ranks from DB") + fmt.Fprintln(os.Stderr, " cuda-kernel-launch-stats CUDA kernel launch statistics from DB") + fmt.Fprintln(os.Stderr, " queue-length-time-series CUDA stream queue depth over time") + fmt.Fprintln(os.Stderr, " idle-time-breakdown GPU idle time classification from DB") + fmt.Fprintln(os.Stderr, " blocked-on-full-queue Time spent blocked on full CUDA launch queue") + fmt.Fprintln(os.Stderr, " memory-bw-summary Memory bandwidth summary from DB") + fmt.Fprintln(os.Stderr, " memory-bw-time-series Memory bandwidth time series from DB") + fmt.Fprintln(os.Stderr, " gpu-kernels-with-annotations GPU kernels with user annotations from DB") + fmt.Fprintln(os.Stderr, " generate-trace-with-counters Generate enriched trace with counter events") +
fmt.Fprintln(os.Stderr, " aten-op-kernels-and-delay ATen op to GPU kernel mapping with delay metrics") + fmt.Fprintln(os.Stderr, " frequent-cuda-kernel-sequences Find frequent GPU kernel launch patterns") + fmt.Fprintln(os.Stderr, " critical-path Critical path analysis for a single rank") + fmt.Fprintln(os.Stderr, " cupti-counter-data CUPTI profiler counter data with operator stacks") +} + +func cmdPreProcess(args []string) { + fs := flag.NewFlagSet("pre-process", flag.ExitOnError) + traceDir := fs.String("trace-dir", "", "Directory containing trace JSON/GZ files") + output := fs.String("output", "trace.db", "Output SQLite database path") + fs.Parse(args) + + if *traceDir == "" { + fmt.Fprintln(os.Stderr, "error: --trace-dir is required") + fs.Usage() + os.Exit(1) + } + + if err := pipeline.Run(*traceDir, *output); err != nil { + log.Fatalf("pre-process failed: %v", err) + } +} + +func cmdGPUKernelBreakdown(args []string) { + fs := flag.NewFlagSet("gpu-kernel-breakdown", flag.ExitOnError) + dbPath := fs.String("db", "trace.db", "SQLite database path") + durationRatio := fs.Float64("duration-ratio", 0.8, "Cumulative duration ratio cutoff") + numKernels := fs.Int("num-kernels", 10, "Max kernels per type per rank") + noMemory := fs.Bool("no-memory-kernels", false, "Exclude MEMORY kernel type") + fs.Parse(args) + + db, err := store.Create(*dbPath) + if err != nil { + log.Fatalf("open db: %v", err) + } + defer db.Close() + + opts := kernel.KernelBreakdownOpts{ + DurationRatio: *durationRatio, + NumKernels: *numKernels, + IncludeMemory: !*noMemory, + } + + result, err := kernel.GPUKernelBreakdown(db, opts) + if err != nil { + log.Fatalf("gpu kernel breakdown: %v", err) + } + + fmt.Println("## Kernel Type Breakdown") + fmt.Println() + fmt.Println("| kernel_type | sum(us) | percentage |") + fmt.Println("|-------------|---------|------------|") + for _, r := range result.TypeBreakdown { + fmt.Printf("| %s | %d | %.1f |\n", r.KernelType, r.SumUs, r.Percentage) + } + + fmt.Println() + fmt.Println("## Top Kernels") + fmt.Println() + fmt.Println("| name | sum(us) | max(us) | min(us) | mean(us) | stddev | kernel_type | rank |") + fmt.Println("|------|---------|---------|---------|----------|--------|-------------|------|") + for _, r := range result.TopKernels { + fmt.Printf("| %s | %d | %d | %d | %d | %.2f | %s | %d |\n", + r.Name, r.SumUs, r.MaxUs, r.MinUs, r.MeanUs, r.Stddev, r.KernelType, r.Rank) + } +} + +func cmdTemporalBreakdown(args []string) { + fs := flag.NewFlagSet("temporal-breakdown", flag.ExitOnError) + dbPath := fs.String("db", "trace.db", "SQLite database path") + fs.Parse(args) + + db, err := store.Create(*dbPath) + if err != nil { + log.Fatalf("open db: %v", err) + } + defer db.Close() + + results, err := temporal.TemporalBreakdown(db) + if err != nil { + log.Fatalf("temporal breakdown: %v", err) + } + + // Markdown table output + fmt.Println("| rank | idle_time(us) | compute_time(us) | non_compute_time(us) | kernel_time(us) | idle_time_pctg | compute_time_pctg | non_compute_time_pctg |") + fmt.Println("|------|---------------|------------------|----------------------|-----------------|----------------|-------------------|-----------------------|") + for _, r := range results { + fmt.Printf("| %d | %d | %d | %d | %d | %.2f | %.2f | %.2f |\n", + r.Rank, r.IdleTimeUs, r.ComputeTimeUs, r.NonComputeTimeUs, r.KernelTimeUs, + r.IdleTimePctg, r.ComputeTimePctg, r.NonComputeTimePctg) + } +} + +func cmdCommCompOverlap(args []string) { + fs := flag.NewFlagSet("comm-comp-overlap", 
flag.ExitOnError) + dbPath := fs.String("db", "trace.db", "SQLite database path") + fs.Parse(args) + + db, err := store.Create(*dbPath) + if err != nil { + log.Fatalf("open db: %v", err) + } + defer db.Close() + + results, err := temporal.CommCompOverlap(db) + if err != nil { + log.Fatalf("comm-comp overlap: %v", err) + } + + // Markdown table output + fmt.Println("| rank | overlap_pctg |") + fmt.Println("|------|--------------|") + for _, r := range results { + fmt.Printf("| %d | %.2f |\n", r.Rank, r.OverlapPctg) + } +} + +func cmdProfilerSteps(args []string) { + fs := flag.NewFlagSet("profiler-steps", flag.ExitOnError) + dbPath := fs.String("db", "trace.db", "SQLite database path") + fs.Parse(args) + + db, err := store.Create(*dbPath) + if err != nil { + log.Fatalf("open db: %v", err) + } + defer db.Close() + + steps, err := analysis.ProfilerSteps(db) + if err != nil { + log.Fatalf("profiler steps: %v", err) + } + + parts := make([]string, len(steps)) + for i, s := range steps { + parts[i] = strconv.Itoa(s) + } + fmt.Println(strings.Join(parts, ",")) +} + +func cmdQueueLengthSummary(args []string) { + fs := flag.NewFlagSet("queue-length-summary", flag.ExitOnError) + dbPath := fs.String("db", "trace.db", "SQLite database path") + ranksFlag := fs.String("ranks", "", "Comma-separated rank list (default: all)") + fs.Parse(args) + + db, err := store.Create(*dbPath) + if err != nil { + log.Fatalf("open db: %v", err) + } + defer db.Close() + + ranks := parseRanks(*ranksFlag) + + results, err := resource.QueueLengthSummary(db, ranks) + if err != nil { + log.Fatalf("queue length summary: %v", err) + } + + fmt.Println("| rank | stream | count | min | max | std | 25% | 50% | 75% |") + fmt.Println("|------|--------|-------|-----|-----|-----|-----|-----|-----|") + for _, r := range results { + fmt.Printf("| %d | %d | %d | %d | %d | %.2f | %.2f | %.2f | %.2f |\n", + r.Rank, r.Stream, r.Count, r.Min, r.Max, r.Std, r.P25, r.P50, r.P75) + } +} + +// parseRanks parses a comma-separated list of rank numbers. +// Returns nil (meaning "all ranks") if the string is empty. 
+func parseRanks(s string) []int { + s = strings.TrimSpace(s) + if s == "" { + return nil + } + parts := strings.Split(s, ",") + ranks := make([]int, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p == "" { + continue + } + r, err := strconv.Atoi(p) + if err != nil { + log.Fatalf("invalid rank %q: %v", p, err) + } + ranks = append(ranks, r) + } + return ranks +} + +func cmdPotentialStragglers(args []string) { + fs := flag.NewFlagSet("potential-stragglers", flag.ExitOnError) + dbPath := fs.String("db", "trace.db", "SQLite database path") + numCandidates := fs.Int("num-candidates", 2, "Top K straggler candidates") + profilerSteps := fs.String("profiler-steps", "", "Comma-separated step numbers (default: all)") + fs.Parse(args) + + db, err := store.Create(*dbPath) + if err != nil { + log.Fatalf("open db: %v", err) + } + defer db.Close() + + opts := straggler.StragglerOpts{ + NumCandidates: *numCandidates, + } + + if *profilerSteps != "" { + for _, s := range strings.Split(*profilerSteps, ",") { + s = strings.TrimSpace(s) + if s == "" { + continue + } + n, err := strconv.Atoi(s) + if err != nil { + log.Fatalf("invalid profiler step %q: %v", s, err) + } + opts.ProfilerSteps = append(opts.ProfilerSteps, n) + } + } + + results, err := straggler.PotentialStragglers(db, opts) + if err != nil { + log.Fatalf("potential stragglers: %v", err) + } + + if len(results) == 0 { + fmt.Println("No potential stragglers detected (no qualifying communication kernels found).") + return + } + + parts := make([]string, len(results)) + for i, r := range results { + parts[i] = strconv.Itoa(r.Rank) + } + fmt.Println(strings.Join(parts, ",")) +} + +func cmdQueueLengthTimeSeries(args []string) { + fs := flag.NewFlagSet("queue-length-time-series", flag.ExitOnError) + dbPath := fs.String("db", "trace.db", "SQLite database path") + ranksFlag := fs.String("ranks", "", "Comma-separated ranks (default: all)") + fs.Parse(args) + + db, err := store.Create(*dbPath) + if err != nil { + log.Fatalf("open db: %v", err) + } + defer db.Close() + + ranks := parseRanks(*ranksFlag) + + result, err := resource.QueueLengthTimeSeries(db, ranks) + if err != nil { + log.Fatalf("queue-length-time-series: %v", err) + } + + // Collect and sort rank keys for deterministic output. 
+ sortedRanks := make([]int, 0, len(result)) + for r := range result { + sortedRanks = append(sortedRanks, r) + } + sort.Ints(sortedRanks) + + for _, r := range sortedRanks { + points := result[r] + fmt.Printf("## Rank %d\n\n", r) + fmt.Println("| ts | stream | queue_length |") + fmt.Println("|----|--------|--------------|") + for _, p := range points { + fmt.Printf("| %d | %d | %d |\n", p.Timestamp, p.Stream, p.QueueLength) + } + fmt.Println() + } +} + +func cmdCUDAKernelLaunchStats(args []string) { + fs := flag.NewFlagSet("cuda-kernel-launch-stats", flag.ExitOnError) + dbPath := fs.String("db", "trace.db", "SQLite database path") + ranksStr := fs.String("ranks", "", "Comma-separated rank list (default: all)") + runtimeCutoff := fs.Int("runtime-cutoff", 50, "Runtime duration cutoff in µs") + launchDelayCutoff := fs.Int("launch-delay-cutoff", 100, "Launch delay cutoff in µs") + noMemory := fs.Bool("no-memory-events", false, "Exclude cudaMemcpyAsync/cudaMemsetAsync") + fs.Parse(args) + + db, err := store.Create(*dbPath) + if err != nil { + log.Fatalf("open db: %v", err) + } + defer db.Close() + + opts := kernel.LaunchStatsOpts{ + RuntimeCutoff: *runtimeCutoff, + LaunchDelayCutoff: *launchDelayCutoff, + IncludeMemory: !*noMemory, + } + if *ranksStr != "" { + for _, s := range strings.Split(*ranksStr, ",") { + r, err := strconv.Atoi(strings.TrimSpace(s)) + if err != nil { + log.Fatalf("invalid rank %q: %v", s, err) + } + opts.Ranks = append(opts.Ranks, r) + } + } + + result, err := kernel.CUDAKernelLaunchStats(db, opts) + if err != nil { + log.Fatalf("cuda kernel launch stats: %v", err) + } + + // Collect and sort rank keys for deterministic output. + ranks := make([]int, 0, len(result)) + for r := range result { + ranks = append(ranks, r) + } + sort.Ints(ranks) + + for _, rank := range ranks { + rows := result[rank] + fmt.Printf("## Rank %d\n\n", rank) + fmt.Println("| correlation | cpu_duration | gpu_duration | launch_delay |") + fmt.Println("|-------------|--------------|--------------|--------------|") + for _, r := range rows { + fmt.Printf("| %d | %d | %d | %d |\n", + r.Correlation, r.CPUDuration, r.GPUDuration, r.LaunchDelay) + } + fmt.Println() + } +} + +func cmdIdleTimeBreakdown(args []string) { + fs := flag.NewFlagSet("idle-time-breakdown", flag.ExitOnError) + dbPath := fs.String("db", "trace.db", "SQLite database path") + ranksStr := fs.String("ranks", "", "Comma-separated ranks") + streamsStr := fs.String("streams", "", "Comma-separated CUDA streams") + showStats := fs.Bool("show-idle-interval-stats", false, "Show interval statistics") + kernelDelay := fs.Int64("consecutive-kernel-delay", 30, "Threshold (us)") + fs.Parse(args) + + opts := temporal.IdleTimeOpts{ + ConsecutiveKernelDelay: *kernelDelay, + ShowIdleIntervalStats: *showStats, + Ranks: parseCSVInts(*ranksStr), + Streams: parseCSVInts(*streamsStr), + } + + db, err := store.Create(*dbPath) + if err != nil { + log.Fatalf("open db: %v", err) + } + defer db.Close() + + results, stats, err := temporal.IdleTimeBreakdown(db, opts) + if err != nil { + log.Fatalf("idle time breakdown: %v", err) + } + + fmt.Println("| rank | stream | idle_category | idle_time(us) | idle_time_ratio |") + fmt.Println("|------|--------|---------------|---------------|-----------------|") + for _, r := range results { + fmt.Printf("| %d | %d | %s | %d | %.2f |\n", + r.Rank, r.Stream, r.IdleCategory, r.IdleTimeUs, r.IdleTimeRatio) + } + + if *showStats && len(stats) > 0 { + fmt.Println() + fmt.Println("| rank | stream | idle_category | count | mean | 
std | min | 25% | 50% | 75% | max |") + fmt.Println("|------|--------|---------------|-------|------|-----|-----|-----|-----|-----|-----|") + for _, s := range stats { + fmt.Printf("| %d | %d | %s | %d | %.2f | %.2f | %d | %.2f | %.2f | %.2f | %d |\n", + s.Rank, s.Stream, s.IdleCategory, s.Count, s.Mean, s.Std, s.Min, s.Pct25, s.Pct50, s.Pct75, s.Max) + } + } +} + +func cmdBlockedOnFullQueue(args []string) { + fs := flag.NewFlagSet("blocked-on-full-queue", flag.ExitOnError) + dbPath := fs.String("db", "trace.db", "SQLite database path") + ranksFlag := fs.String("ranks", "", "Comma-separated rank list (default: all)") + maxQL := fs.Int("max-queue-length", 1024, "Max CUDA launch queue length per stream") + fs.Parse(args) + + db, err := store.Create(*dbPath) + if err != nil { + log.Fatalf("open db: %v", err) + } + defer db.Close() + + opts := resource.BlockedQueueOpts{ + Ranks: parseRanks(*ranksFlag), + MaxQueueLength: *maxQL, + } + + results, err := resource.BlockedOnFullQueue(db, opts) + if err != nil { + log.Fatalf("blocked-on-full-queue: %v", err) + } + + if len(results) == 0 { + fmt.Println("No streams reached maximum queue length.") + return + } + + fmt.Println("| rank | stream | duration_at_max_queue_length | relative_duration |") + fmt.Println("|------|--------|-----------------------------|-------------------|") + for _, r := range results { + fmt.Printf("| %d | %d | %d | %.6f |\n", + r.Rank, r.Stream, r.Duration, r.RelativeDuration) + } +} + +func cmdMemoryBWSummary(args []string) { + fs := flag.NewFlagSet("memory-bw-summary", flag.ExitOnError) + dbPath := fs.String("db", "trace.db", "SQLite database path") + ranksFlag := fs.String("ranks", "", "Comma-separated rank list (default: all)") + fs.Parse(args) + + db, err := store.Create(*dbPath) + if err != nil { + log.Fatalf("open db: %v", err) + } + defer db.Close() + + opts := resource.MemoryBWOpts{ + Ranks: parseRanks(*ranksFlag), + } + + results, err := resource.MemoryBWSummary(db, opts) + if err != nil { + log.Fatalf("memory bw summary: %v", err) + } + + if len(results) == 0 { + fmt.Println("(no data)") + return + } + + fmt.Println("| rank | name | count | mean | std | min | 25% | 50% | 75% | max |") + fmt.Println("|------|------|-------|------|-----|-----|-----|-----|-----|-----|") + for _, r := range results { + fmt.Printf("| %d | %s | %d | %.2f | %.2f | %.2f | %.2f | %.2f | %.2f | %.2f |\n", + r.Rank, r.Name, r.Count, r.Mean, r.Std, r.Min, r.P25, r.P50, r.P75, r.Max) + } +} + +func cmdMemoryBWTimeSeries(args []string) { + fs := flag.NewFlagSet("memory-bw-time-series", flag.ExitOnError) + dbPath := fs.String("db", "trace.db", "SQLite database path") + ranksFlag := fs.String("ranks", "", "Comma-separated ranks (default: all)") + fs.Parse(args) + + db, err := store.Create(*dbPath) + if err != nil { + log.Fatalf("open db: %v", err) + } + defer db.Close() + + ranks := parseRanks(*ranksFlag) + + result, err := resource.MemoryBWTimeSeries(db, ranks) + if err != nil { + log.Fatalf("memory-bw-time-series: %v", err) + } + + sortedRanks := make([]int, 0, len(result)) + for r := range result { + sortedRanks = append(sortedRanks, r) + } + sort.Ints(sortedRanks) + + for _, r := range sortedRanks { + points := result[r] + fmt.Printf("## Rank %d\n\n", r) + fmt.Println("| ts | pid | name | memory_bw_gbps |") + fmt.Println("|----|-----|------|----------------|") + for _, p := range points { + fmt.Printf("| %d | %d | %s | %.4f |\n", p.Timestamp, p.ProcessID, p.Name, p.MemoryBWGbps) + } + fmt.Println() + } +} + +func 
cmdGPUKernelsWithAnnotations(args []string) { + fs := flag.NewFlagSet("gpu-kernels-with-annotations", flag.ExitOnError) + dbPath := fs.String("db", "trace.db", "SQLite database path") + rank := fs.Int("rank", -1, "Rank to analyze (required)") + noExpandNames := fs.Bool("no-expand-names", false, "Skip expanding symbol IDs to names") + noShortenNames := fs.Bool("no-shorten-names", false, "Skip shortening kernel names") + fs.Parse(args) + + if *rank < 0 { + fmt.Fprintln(os.Stderr, "error: --rank is required") + fs.Usage() + os.Exit(1) + } + + db, err := store.Create(*dbPath) + if err != nil { + log.Fatalf("open db: %v", err) + } + defer db.Close() + + opts := kernel.AnnotationOpts{ + Rank: *rank, + ExpandNames: !*noExpandNames, + ShortenNames: !*noShortenNames, + } + + results, err := kernel.GPUKernelsWithAnnotations(db, opts) + if err != nil { + log.Fatalf("gpu kernels with annotations: %v", err) + } + + if len(results) == 0 { + fmt.Println("(no data)") + return + } + + fmt.Println("| started_at | ended_at | kernel | annotation |") + fmt.Println("|------------|----------|--------|------------|") + for _, r := range results { + fmt.Printf("| %d | %d | %s | %s |\n", + r.StartedAt, r.EndedAt, r.KernelName, r.UserAnnotation) + } +} + +func cmdGenerateTraceWithCounters(args []string) { + fs := flag.NewFlagSet("generate-trace-with-counters", flag.ExitOnError) + dbPath := fs.String("db", "trace.db", "SQLite database path") + ranksFlag := fs.String("ranks", "", "Comma-separated rank list (default: all)") + timeSeries := fs.String("time-series", "both", "Time series to embed: queue_length, memcpy_bandwidth, or both") + outputSuffix := fs.String("output-suffix", "_with_counters", "Suffix for output file names") + fs.Parse(args) + + db, err := store.Create(*dbPath) + if err != nil { + log.Fatalf("open db: %v", err) + } + defer db.Close() + + var counters resource.CounterType + switch *timeSeries { + case "queue_length": + counters = resource.CounterQueueLength + case "memcpy_bandwidth": + counters = resource.CounterMemoryBW + case "both": + counters = resource.CounterAll + default: + log.Fatalf("invalid --time-series value %q: must be queue_length, memcpy_bandwidth, or both", *timeSeries) + } + + opts := resource.GenerateCountersOpts{ + Ranks: parseRanks(*ranksFlag), + Counters: counters, + OutputSuffix: *outputSuffix, + } + + outputFiles, err := resource.GenerateTraceWithCounters(db, opts) + if err != nil { + log.Fatalf("generate-trace-with-counters: %v", err) + } + + for _, f := range outputFiles { + fmt.Println(f) + } +} + +func cmdAtenOpKernelsAndDelay(args []string) { + fs := flag.NewFlagSet("aten-op-kernels-and-delay", flag.ExitOnError) + dbPath := fs.String("db", "trace.db", "SQLite database path") + ranksFlag := fs.String("ranks", "", "Comma-separated rank list (default: all)") + sortByFlag := fs.String("sort-by", "occurrence_count", "Comma-separated sort columns") + fs.Parse(args) + + db, err := store.Create(*dbPath) + if err != nil { + log.Fatalf("open db: %v", err) + } + defer db.Close() + + opts := kernel.AtenDelayOpts{ + Ranks: parseRanks(*ranksFlag), + } + if *sortByFlag != "" { + for _, s := range strings.Split(*sortByFlag, ",") { + s = strings.TrimSpace(s) + if s != "" { + opts.SortBy = append(opts.SortBy, s) + } + } + } + + result, err := kernel.AtenOpKernelsAndDelay(db, opts) + if err != nil { + log.Fatalf("aten-op-kernels-and-delay: %v", err) + } + + sortedRanks := make([]int, 0, len(result)) + for r := range result { + sortedRanks = append(sortedRanks, r) + } + 
sort.Ints(sortedRanks) + + for _, r := range sortedRanks { + rows := result[r] + fmt.Printf("## Rank %d\n\n", r) + fmt.Println("| aten_op_name | kernel_sequence | occurrence_count | avg_aten_op_launch_delay | avg_runtime_delay |") + fmt.Println("|--------------|-----------------|------------------|--------------------------|-------------------|") + for _, row := range rows { + fmt.Printf("| %s | %s | %d | %.3f | %.3f |\n", + row.AtenOpName, row.KernelSequence, row.OccurrenceCount, + row.AvgAtenOpLaunchDelay, row.AvgRuntimeDelay) + } + fmt.Println() + } +} + +func cmdFrequentCUDAKernelSequences(args []string) { + fs := flag.NewFlagSet("frequent-cuda-kernel-sequences", flag.ExitOnError) + dbPath := fs.String("db", "trace.db", "SQLite database path") + operatorName := fs.String("operator-name", "", "CPU operator name substring to match (required)") + outputDir := fs.String("output-dir", "", "Directory for overlay trace output") + minPatternLen := fs.Int("min-pattern-len", 3, "Minimum pattern length (operator + kernels)") + rank := fs.Int("rank", 0, "Rank to analyze") + topK := fs.Int("top-k", 5, "Number of top patterns to return") + fs.Parse(args) + + if *operatorName == "" { + fmt.Fprintln(os.Stderr, "error: --operator-name is required") + fs.Usage() + os.Exit(1) + } + + db, err := store.Create(*dbPath) + if err != nil { + log.Fatalf("open db: %v", err) + } + defer db.Close() + + opts := kernel.KernelSeqOpts{ + OperatorName: *operatorName, + OutputDir: *outputDir, + MinPatternLen: *minPatternLen, + Rank: *rank, + TopK: *topK, + } + + results, err := kernel.FrequentCUDAKernelSequences(db, opts) + if err != nil { + log.Fatalf("frequent cuda kernel sequences: %v", err) + } + + if len(results) == 0 { + fmt.Println("No frequent CUDA kernel sequences found.") + return + } + + fmt.Println("| pattern | count | GPU kernel duration (us) | CPU op duration (us) |") + fmt.Println("|---------|-------|--------------------------|----------------------|") + for _, r := range results { + fmt.Printf("| %s | %d | %d | %d |\n", + r.Pattern, r.Count, r.GPUKernelDurUs, r.CPUOpDurUs) + } +} + +func cmdCriticalPath(args []string) { + fs := flag.NewFlagSet("critical-path", flag.ExitOnError) + dbPath := fs.String("db", "trace.db", "SQLite database path") + rank := fs.Int("rank", -1, "Rank to analyze (required)") + annotation := fs.String("annotation", "", "Annotation name to match (required)") + instanceID := fs.String("instance-id", "", "Instance ID: single int or 'start,end' (required)") + outputDir := fs.String("output-dir", "", "Directory for overlay trace output (required)") + dataLoadEvents := fs.String("data-load-events", "", "Comma-separated regex patterns for data loading ops") + showAllEdges := fs.Bool("show-all-edges", false, "Show all edges in overlay (not just critical path)") + fs.Parse(args) + + if *rank < 0 { + fmt.Fprintln(os.Stderr, "error: --rank is required") + fs.Usage() + os.Exit(1) + } + if *annotation == "" { + fmt.Fprintln(os.Stderr, "error: --annotation is required") + fs.Usage() + os.Exit(1) + } + if *instanceID == "" { + fmt.Fprintln(os.Stderr, "error: --instance-id is required") + fs.Usage() + os.Exit(1) + } + if *outputDir == "" { + fmt.Fprintln(os.Stderr, "error: --output-dir is required") + fs.Usage() + os.Exit(1) + } + + db, err := store.Create(*dbPath) + if err != nil { + log.Fatalf("open db: %v", err) + } + defer db.Close() + + opts := criticalpath.CriticalPathOpts{ + Rank: *rank, + Annotation: *annotation, + InstanceID: *instanceID, + OutputDir: *outputDir, + ShowAllEdges: 
*showAllEdges, + } + + if *dataLoadEvents != "" { + for _, s := range strings.Split(*dataLoadEvents, ",") { + s = strings.TrimSpace(s) + if s != "" { + opts.DataLoadEvents = append(opts.DataLoadEvents, s) + } + } + } + + result, err := criticalpath.CriticalPath(db, opts) + if err != nil { + log.Fatalf("critical-path: %v", err) + } + + fmt.Println("## Critical Path Summary") + fmt.Println() + fmt.Printf("Nodes: %d, Edges: %d, Path length: %d us\n\n", result.NumNodes, result.NumEdges, result.PathLength) + + fmt.Println("## Breakdown by Bound Type") + fmt.Println() + fmt.Println("| bound_type | duration(us) | percentage |") + fmt.Println("|------------|-------------|------------|") + for _, r := range result.Summary { + fmt.Printf("| %s | %d | %.2f |\n", r.BoundType, r.Duration, r.Percentage) + } + + if result.OverlayFile != "" { + fmt.Printf("\nOverlay trace: %s\n", result.OverlayFile) + } +} + +func cmdCUPTICounterData(args []string) { + fs := flag.NewFlagSet("cupti-counter-data", flag.ExitOnError) + dbPath := fs.String("db", "trace.db", "SQLite database path") + ranksFlag := fs.String("ranks", "", "Comma-separated rank list (default: all)") + fs.Parse(args) + + db, err := store.Create(*dbPath) + if err != nil { + log.Fatalf("open db: %v", err) + } + defer db.Close() + + opts := resource.CUPTICounterOpts{ + Ranks: parseRanks(*ranksFlag), + } + + result, err := resource.CUPTICounterData(db, opts) + if err != nil { + log.Fatalf("cupti-counter-data: %v", err) + } + + if len(result) == 0 { + fmt.Println("(no data)") + return + } + + sortedRanks := make([]int, 0, len(result)) + for r := range result { + sortedRanks = append(sortedRanks, r) + } + sort.Ints(sortedRanks) + + for _, r := range sortedRanks { + rows := result[r] + fmt.Printf("## Rank %d\n\n", r) + + // Discover counter column names from data. + counterNames := make(map[string]bool) + for _, row := range rows { + for k := range row.Counters { + counterNames[k] = true + } + } + sortedCounters := make([]string, 0, len(counterNames)) + for k := range counterNames { + sortedCounters = append(sortedCounters, k) + } + sort.Strings(sortedCounters) + + // Header. + header := "| name | top_level_op | bottom_level_op | op_stack" + sep := "|------|--------------|-----------------|----------" + for _, c := range sortedCounters { + header += " | " + c + sep += "|---" + } + header += " |" + sep += "|" + fmt.Println(header) + fmt.Println(sep) + + // Rows. 
+ for _, row := range rows { + opStack := strings.Join(row.OpStack, " -> ") + fmt.Printf("| %s | %s | %s | %s", + row.KernelName, row.TopLevelOp, row.BottomLevelOp, opStack) + for _, c := range sortedCounters { + fmt.Printf(" | %.0f", row.Counters[c]) + } + fmt.Println(" |") + } + fmt.Println() + } +} + +func parseCSVInts(s string) []int { + if s == "" { + return nil + } + parts := strings.Split(s, ",") + ints := make([]int, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p == "" { + continue + } + v, err := strconv.Atoi(p) + if err != nil { + log.Fatalf("invalid integer %q: %v", p, err) + } + ints = append(ints, v) + } + return ints +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..b85aed1 --- /dev/null +++ b/go.mod @@ -0,0 +1,17 @@ +module hta + +go 1.24.11 + +require ( + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/ncruces/go-strftime v1.0.0 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect + golang.org/x/sys v0.37.0 // indirect + modernc.org/libc v1.67.6 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect + modernc.org/sqlite v1.46.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..5cf3e13 --- /dev/null +++ b/go.sum @@ -0,0 +1,23 @@ +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= +golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI= +modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/sqlite v1.46.1 h1:eFJ2ShBLIEnUWlLy12raN0Z1plqmFX9Qe3rjQTKt6sU= +modernc.org/sqlite v1.46.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA= diff --git a/pkg/analysis/callstack.go 
b/pkg/analysis/callstack.go new file mode 100644 index 0000000..59e80d2 --- /dev/null +++ b/pkg/analysis/callstack.go @@ -0,0 +1,125 @@ +package analysis + +import "sort" + +// CallStackNode represents a node in the CPU call stack tree. +type CallStackNode struct { + EvIdx int // index into the event slice + Parent int // parent EvIdx (-1 for root-level nodes) + Children []int // child EvIdx values + Depth int +} + +// CallStack holds the call tree for one CPU thread. +type CallStack struct { + ThreadID int64 + Nodes map[int]*CallStackNode + Roots []int // top-level event indices +} + +// CSEvent is used for call stack construction. +type CSEvent struct { + Idx int + ThreadID int64 + Start int64 + End int64 +} + +// BuildCallStacks builds call stacks from CPU events sorted by +// (thread_id, started_at, -duration). Returns one CallStack per thread. +func BuildCallStacks(events []CSEvent) []*CallStack { + if len(events) == 0 { + return nil + } + + // Sort: (ThreadID, Start ASC, duration DESC i.e. End DESC) + sorted := make([]CSEvent, len(events)) + copy(sorted, events) + sort.Slice(sorted, func(i, j int) bool { + a, b := sorted[i], sorted[j] + if a.ThreadID != b.ThreadID { + return a.ThreadID < b.ThreadID + } + if a.Start != b.Start { + return a.Start < b.Start + } + // Longer duration (later end) first at the same start + return a.End > b.End + }) + + var result []*CallStack + var cs *CallStack + var curThread int64 = -1 << 62 // impossible sentinel + type stackEntry struct { + evIdx int + end int64 + } + var stack []stackEntry + + for _, ev := range sorted { + if ev.ThreadID != curThread { + // Flush previous thread + if cs != nil { + result = append(result, cs) + } + cs = &CallStack{ + ThreadID: ev.ThreadID, + Nodes: make(map[int]*CallStackNode), + } + curThread = ev.ThreadID + stack = stack[:0] + } + + // Pop stack entries whose end <= event start (siblings, not parents) + for len(stack) > 0 && stack[len(stack)-1].end <= ev.Start { + stack = stack[:len(stack)-1] + } + + parentIdx := -1 + depth := 0 + if len(stack) > 0 { + parentIdx = stack[len(stack)-1].evIdx + depth = cs.Nodes[parentIdx].Depth + 1 + } + + node := &CallStackNode{ + EvIdx: ev.Idx, + Parent: parentIdx, + Depth: depth, + } + cs.Nodes[ev.Idx] = node + + if parentIdx == -1 { + cs.Roots = append(cs.Roots, ev.Idx) + } else { + cs.Nodes[parentIdx].Children = append(cs.Nodes[parentIdx].Children, ev.Idx) + } + + stack = append(stack, stackEntry{evIdx: ev.Idx, end: ev.End}) + } + + if cs != nil { + result = append(result, cs) + } + return result +} + +// DFSTraverse performs a depth-first traversal of the call stack, +// calling enter before processing children and exit after. 
+func (cs *CallStack) DFSTraverse(enter, exit func(evIdx int, node *CallStackNode)) { + for _, root := range cs.Roots { + cs.dfsVisit(root, enter, exit) + } +} + +func (cs *CallStack) dfsVisit(evIdx int, enter, exit func(evIdx int, node *CallStackNode)) { + node := cs.Nodes[evIdx] + if node == nil { + return + } + enter(evIdx, node) + for _, child := range node.Children { + cs.dfsVisit(child, enter, exit) + } + exit(evIdx, node) +} diff --git a/pkg/analysis/criticalpath/critical_path.go b/pkg/analysis/criticalpath/critical_path.go new file mode 100644 index 0000000..79b61ea --- /dev/null +++ b/pkg/analysis/criticalpath/critical_path.go @@ -0,0 +1,1726 @@ +package criticalpath + +import ( + "database/sql" + "encoding/json" + "fmt" + "log" + "math" + "os" + "path/filepath" + "regexp" + "sort" + "strings" + + "hta/pkg/analysis" + "hta/pkg/analysis/kernel" + "hta/pkg/analysis/resource" + "hta/pkg/store" + "hta/pkg/symbol" +) + +// --------------------------------------------------------------------------- +// Edge and bound type enums +// --------------------------------------------------------------------------- + +// CPEdgeType classifies edges in the critical path DAG. +type CPEdgeType int + +const ( + EdgeOperatorKernel CPEdgeType = iota // span edge: start→end of same event + EdgeDependency // zero-weight sequential dependency + EdgeKernelLaunchDelay // CPU launch → GPU kernel start + EdgeKernelKernelDelay // prev GPU kernel end → next GPU kernel start + EdgeSyncDependency // GPU→CPU or GPU→GPU sync edge +) + +func (t CPEdgeType) String() string { + switch t { + case EdgeOperatorKernel: + return "operator_kernel" + case EdgeDependency: + return "dependency" + case EdgeKernelLaunchDelay: + return "kernel_launch_delay" + case EdgeKernelKernelDelay: + return "kernel_kernel_delay" + case EdgeSyncDependency: + return "sync_dependency" + default: + return "unknown" + } +} + +// BoundType classifies what the critical path is "bound by" for each segment. +type BoundType int + +const ( + BoundCPU BoundType = iota + BoundDataLoading // CPU-side data loading + BoundGPUCompute // GPU compute kernel + BoundGPUCommunication // GPU communication kernel (NCCL etc.) 
+ BoundGPUKernelKernelOverhead // gap between consecutive GPU kernels + BoundGPUKernelLaunchOverhead // CPU→GPU launch delay +) + +func (b BoundType) String() string { + switch b { + case BoundCPU: + return "cpu_bound" + case BoundDataLoading: + return "data_loading" + case BoundGPUCompute: + return "gpu_compute_bound" + case BoundGPUCommunication: + return "gpu_communication_bound" + case BoundGPUKernelKernelOverhead: + return "gpu_kernel_kernel_overhead" + case BoundGPUKernelLaunchOverhead: + return "gpu_kernel_launch_overhead" + default: + return "unknown" + } +} + +// --------------------------------------------------------------------------- +// DAG data structures +// --------------------------------------------------------------------------- + +type cpNode struct { + Idx int + EvIdx int // index into cpEvent slice + Ts int64 // timestamp (start or end of the event) + IsStart bool + IsBlocking bool // blocking CUDA sync call +} + +type cpEdge struct { + Begin int // node index + End int // node index + Type CPEdgeType + Weight int64 +} + +type cpDAG struct { + nodes []cpNode + adj [][]int // node → outgoing edge indices + revAdj [][]int // node → incoming edge indices + edges []cpEdge + evToStart map[int]int // event index → start node index + evToEnd map[int]int // event index → end node index + edgeToEv map[[2]int]int // (begin,end) node pair → attributed event index +} + +func newCPDAG() *cpDAG { + return &cpDAG{ + evToStart: make(map[int]int), + evToEnd: make(map[int]int), + edgeToEv: make(map[[2]int]int), + } +} + +func (g *cpDAG) addNode(n cpNode) int { + idx := len(g.nodes) + n.Idx = idx + g.nodes = append(g.nodes, n) + g.adj = append(g.adj, nil) + g.revAdj = append(g.revAdj, nil) + return idx +} + +func (g *cpDAG) addEdge(begin, end int, typ CPEdgeType, weight int64) int { + idx := len(g.edges) + g.edges = append(g.edges, cpEdge{Begin: begin, End: end, Type: typ, Weight: weight}) + g.adj[begin] = append(g.adj[begin], idx) + g.revAdj[end] = append(g.revAdj[end], idx) + return idx +} + +// addEdgeHelper adds an edge with weight derived from timestamps, following the +// Python convention: DEPENDENCY and SYNC_DEPENDENCY always get weight=0, +// others get dest.ts - src.ts unless zeroWeight is true. 
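+// +// For example, an OPERATOR_KERNEL edge from a start node at ts=100 to an end +// node at ts=250 gets weight 150 (us), while DEPENDENCY and SYNC_DEPENDENCY +// edges between the same two nodes always get weight 0.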
+func (g *cpDAG) addEdgeHelper(srcIdx, dstIdx int, typ CPEdgeType, zeroWeight bool) { + src := g.nodes[srcIdx] + dst := g.nodes[dstIdx] + var w int64 + switch typ { + case EdgeDependency, EdgeSyncDependency: + w = 0 + default: + if zeroWeight { + w = 0 + } else { + w = dst.Ts - src.Ts + } + } + g.addEdge(srcIdx, dstIdx, typ, w) +} + +// --------------------------------------------------------------------------- +// cpEvent — the working event struct for critical path +// --------------------------------------------------------------------------- + +type cpEvent struct { + Idx int // position in the cpEvent slice + DBID int64 // trace_event.id + StartedAt int64 + Duration int64 + EndedAt int64 + NameID int + CategoryID int + DeviceType string // "cpu" or "gpu" + ThreadID int64 + ProcessID int64 + Stream int + Correlation int + + WaitOnCudaEventRecordCorrID int + WaitOnStream int +} + +func (e *cpEvent) isCPU() bool { return e.DeviceType == "cpu" } +func (e *cpEvent) isGPU() bool { return e.DeviceType == "gpu" } + +// --------------------------------------------------------------------------- +// Blocking sync call detection +// --------------------------------------------------------------------------- + +var blockingSyncNames = map[string]bool{ + "cudaDeviceSynchronize": true, + "cudaStreamSynchronize": true, + "cudaEventQuery": true, + "cudaEventSynchronize": true, + "cudaMemcpy": true, + "cudaMemcpyAsync": true, +} + +// CUDA runtime lock names — events that acquire the CUDA runtime lock. +var cudaRuntimeLockNames = map[string]bool{ + "cudaHostAlloc": true, + "cudaLaunchKernel": true, + "cudaLaunchKernelExC": true, + "cudaMemcpyAsync": true, + "cudaMemsetAsync": true, +} + +// CUDA sync event names for Phase 4. +var cudaSyncNames = map[string]bool{ + "cudaDeviceSynchronize": true, + "cudaStreamSynchronize": true, + "cudaEventSynchronize": true, + "cudaStreamWaitEvent": true, +} + +const kernelKernelDelayThresholdUs int64 = 1500 + +// --------------------------------------------------------------------------- +// CriticalPathOpts controls the critical path analysis. +// --------------------------------------------------------------------------- + +// CriticalPathOpts holds parameters for the critical path analysis. +type CriticalPathOpts struct { + Rank int + Annotation string // annotation name to match + InstanceID string // "N" or "start,end" + OutputDir string // for overlay trace output + DataLoadEvents []string // regex patterns for data loading ops + ShowAllEdges bool +} + +// CriticalPathSummaryRow holds one row of the bound-type breakdown. +type CriticalPathSummaryRow struct { + BoundType string + Duration int64 + Percentage float64 +} + +// CriticalPathResult holds the complete critical path analysis output. +type CriticalPathResult struct { + Summary []CriticalPathSummaryRow + NumNodes int + NumEdges int + PathLength int64 + OverlayFile string +} + +// --------------------------------------------------------------------------- +// Main entry point +// --------------------------------------------------------------------------- + +// CriticalPath runs the critical path analysis for a single rank. +func CriticalPath(db *sql.DB, opts CriticalPathOpts) (*CriticalPathResult, error) { + sym, err := store.LoadSymbolTable(db) + if err != nil { + return nil, fmt.Errorf("loading symbol table: %w", err) + } + + // Load all events for the rank. 
+ allRows, err := store.LoadAllEventsForRank(db, opts.Rank) + if err != nil { + return nil, fmt.Errorf("loading events: %w", err) + } + if len(allRows) == 0 { + return nil, fmt.Errorf("no events found for rank %d", opts.Rank) + } + + // Convert to cpEvent slice. + allEvents := make([]cpEvent, len(allRows)) + for i, r := range allRows { + allEvents[i] = cpEvent{ + Idx: i, + DBID: r.ID, + StartedAt: r.StartedAt, + Duration: r.Duration, + EndedAt: r.EndedAt, + NameID: r.NameID, + CategoryID: r.CategoryID, + DeviceType: r.DeviceType, + ThreadID: r.ThreadID, + ProcessID: r.ProcessID, + Stream: r.Stream, + Correlation: r.Correlation, + WaitOnCudaEventRecordCorrID: r.WaitOnCudaEventRecordCorrID, + WaitOnStream: r.WaitOnStream, + } + } + + // Apply annotation-based region of interest filtering. + events, err := clipToAnnotation(allEvents, sym, opts) + if err != nil { + return nil, fmt.Errorf("annotation clipping: %w", err) + } + + // Re-index events. + for i := range events { + events[i].Idx = i + } + + // Build name lookup for quick classification. + nameLookup := make(map[int]string) + for _, s := range sym.All() { + nameLookup[s.ID] = s.Name + } + + // Build data loading regexps. + var dataLoadRegexps []*regexp.Regexp + for _, pat := range opts.DataLoadEvents { + re, err := regexp.Compile(pat) + if err != nil { + return nil, fmt.Errorf("invalid data-load-events pattern %q: %w", pat, err) + } + dataLoadRegexps = append(dataLoadRegexps, re) + } + + // Build the DAG. + dag := newCPDAG() + + // Phase 1: Create event nodes. + createEventNodes(dag, events, nameLookup) + + // Phase 2: CUDA runtime lock edges. + constructRuntimeLockEdges(dag, events, nameLookup) + + // Phase 3: Call stack edges. + constructCallStackEdges(dag, events, nameLookup, sym) + + // Phase 4: GPU kernel edges. + err = constructKernelEdges(dag, events, nameLookup, db, sym, opts.Rank, allEvents) + if err != nil { + return nil, fmt.Errorf("constructing kernel edges: %w", err) + } + + // Validate DAG. + validateDAG(dag) + + // Find longest path. + path, totalWeight := dag.longestPath() + + // Build critical path edge and event sets. + cpEdgesSet := make(map[int]bool) + cpEventsSet := make(map[int]bool) + for _, nIdx := range path { + cpEventsSet[dag.nodes[nIdx].EvIdx] = true + } + for i := 0; i < len(path)-1; i++ { + u, v := path[i], path[i+1] + for _, eIdx := range dag.adj[u] { + if dag.edges[eIdx].End == v { + cpEdgesSet[eIdx] = true + break + } + } + } + + // Breakdown by bound type. + summary := computeBreakdown(dag, events, cpEdgesSet, nameLookup, dataLoadRegexps) + + result := &CriticalPathResult{ + Summary: summary, + NumNodes: len(dag.nodes), + NumEdges: len(dag.edges), + PathLength: totalWeight, + } + + // Trace overlay. + if opts.OutputDir != "" { + overlayFile, err := overlayCriticalPath( + db, opts.Rank, dag, events, cpEdgesSet, cpEventsSet, + opts.OutputDir, opts.ShowAllEdges, + ) + if err != nil { + return result, fmt.Errorf("overlay: %w", err) + } + result.OverlayFile = overlayFile + } + + return result, nil +} + +// --------------------------------------------------------------------------- +// Phase 1: Create event nodes +// --------------------------------------------------------------------------- + +func createEventNodes(dag *cpDAG, events []cpEvent, names map[int]string) { + // Determine which events get nodes. 
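+ // Each included event contributes exactly two nodes below, one at its + // start timestamp and one at its end timestamp, so the later edge + // construction phases can attach to either endpoint independently.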
+ type nodeCandidate struct { + evIdx int + startTs int64 + endTs int64 + } + var candidates []nodeCandidate + + for i := range events { + ev := &events[i] + include := false + + if ev.isCPU() { + // Include CPU operators and runtime events: everything except + // pure metadata events, which carry a negative duration. + if ev.Duration >= 0 { + include = true + } + } else if ev.isGPU() { + // Include GPU events with valid stream and correlation. + if ev.Stream >= 0 && ev.Correlation >= 0 { + include = true + } + } + + if include { + candidates = append(candidates, nodeCandidate{ + evIdx: i, + startTs: ev.StartedAt, + endTs: ev.EndedAt, + }) + } + } + + // Create start and end nodes for each candidate. + type rawNode struct { + cpNode + sortKey [2]int64 // (ts, evIdx) for sorting + } + var rawNodes []rawNode + + for _, c := range candidates { + isBlocking := false + name := names[events[c.evIdx].NameID] + if blockingSyncNames[name] { + isBlocking = true + } + + rawNodes = append(rawNodes, rawNode{ + cpNode: cpNode{ + EvIdx: c.evIdx, + Ts: c.startTs, + IsStart: true, + IsBlocking: isBlocking, + }, + sortKey: [2]int64{c.startTs, int64(c.evIdx)}, + }) + rawNodes = append(rawNodes, rawNode{ + cpNode: cpNode{ + EvIdx: c.evIdx, + Ts: c.endTs, + IsStart: false, + IsBlocking: isBlocking, + }, + sortKey: [2]int64{c.endTs, int64(c.evIdx)}, + }) + } + + // Sort by (ts, evIdx). + sort.Slice(rawNodes, func(i, j int) bool { + if rawNodes[i].sortKey[0] != rawNodes[j].sortKey[0] { + return rawNodes[i].sortKey[0] < rawNodes[j].sortKey[0] + } + return rawNodes[i].sortKey[1] < rawNodes[j].sortKey[1] + }) + + // Add sorted nodes to DAG and build maps. + for _, rn := range rawNodes { + idx := dag.addNode(rn.cpNode) + if rn.IsStart { + dag.evToStart[rn.EvIdx] = idx + } else { + dag.evToEnd[rn.EvIdx] = idx + } + } +} + +// --------------------------------------------------------------------------- +// Phase 2: CUDA runtime lock edges +// --------------------------------------------------------------------------- + +func constructRuntimeLockEdges(dag *cpDAG, events []cpEvent, names map[int]string) { + // Build set of event indices that are CUDA runtime lock events. + runtimeEvSet := make(map[int]bool) + for i := range events { + if events[i].isCPU() { + name := names[events[i].NameID] + if cudaRuntimeLockNames[name] { + runtimeEvSet[i] = true + } + } + } + if len(runtimeEvSet) == 0 { + return + } + + // Iterate through all nodes in time-sorted order (they are already sorted). + type stackEntry struct { + nodeIdx int + evIdx int + } + var startNodes []stackEntry + lastNodeIdx := -1 + + for _, node := range dag.nodes { + evIdx := node.EvIdx + if !runtimeEvSet[evIdx] { + continue + } + + if node.IsStart { + startNodes = append(startNodes, stackEntry{nodeIdx: node.Idx, evIdx: evIdx}) + } else { + // End node for a runtime event. + if len(startNodes) > 0 { + top := startNodes[len(startNodes)-1] + startNodes = startNodes[:len(startNodes)-1] + + if top.evIdx == evIdx { + // Same event: OPERATOR_KERNEL span edge. + dag.addEdgeHelper(top.nodeIdx, node.Idx, EdgeOperatorKernel, node.IsBlocking) + } else { + // Different event: DEPENDENCY. + dag.addEdgeHelper(top.nodeIdx, node.Idx, EdgeDependency, false) + } + } + + if lastNodeIdx >= 0 && !dag.nodes[lastNodeIdx].IsStart { + // Chain end→end for lock release sequence. + dag.addEdgeHelper(lastNodeIdx, node.Idx, EdgeDependency, false) + } + } + + lastNodeIdx = node.Idx + } +} + +// runtimeLockEventSet builds the set of event indices sharing the CUDA runtime lock.
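+// For example, given the call sequence [cudaLaunchKernel, aten::add, +// cudaMemcpyAsync], only the indices of the two runtime calls end up in the +// set; aten::add never takes the runtime lock.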
+func runtimeLockEventSet(events []cpEvent, names map[int]string) map[int]bool { + set := make(map[int]bool) + for i := range events { + if events[i].isCPU() { + name := names[events[i].NameID] + if cudaRuntimeLockNames[name] { + set[i] = true + } + } + } + return set +} + +// --------------------------------------------------------------------------- +// Phase 3: Call stack edges +// --------------------------------------------------------------------------- + +func constructCallStackEdges(dag *cpDAG, events []cpEvent, names map[int]string, sym *symbol.Table) { + runtimeSet := runtimeLockEventSet(events, names) + + // Detect forward/backward thread merging. + // Forward thread: has events with "forward" in name. + // Backward thread: has events with "autograd" in name. + type threadInfo struct { + hasForward bool + hasAutograd bool + threadID int64 + } + threadInfos := make(map[int64]*threadInfo) + for i := range events { + ev := &events[i] + if !ev.isCPU() { + continue + } + ti, ok := threadInfos[ev.ThreadID] + if !ok { + ti = &threadInfo{threadID: ev.ThreadID} + threadInfos[ev.ThreadID] = ti + } + name := names[ev.NameID] + if strings.Contains(name, "forward") { + ti.hasForward = true + } + if strings.Contains(name, "autograd") { + ti.hasAutograd = true + } + } + + // Find forward and backward threads. + var forwardThread, backwardThread int64 = -1, -1 + for tid, ti := range threadInfos { + if ti.hasForward && forwardThread == -1 { + forwardThread = tid + } + if ti.hasAutograd && !ti.hasForward && backwardThread == -1 { + backwardThread = tid + } + } + + // Remap: if forward and backward are different threads, merge backward into forward. + remapTID := func(tid int64) int64 { + if forwardThread >= 0 && backwardThread >= 0 && forwardThread != backwardThread { + if tid == backwardThread { + return forwardThread + } + } + return tid + } + + // Collect CPU events for call stack building. + var csEvents []analysis.CSEvent + for i := range events { + ev := &events[i] + if !ev.isCPU() { + continue + } + // Only include events that have nodes in the DAG. + if _, ok := dag.evToStart[i]; !ok { + continue + } + csEvents = append(csEvents, analysis.CSEvent{ + Idx: i, + ThreadID: remapTID(ev.ThreadID), + Start: ev.StartedAt, + End: ev.EndedAt, + }) + } + + stacks := analysis.BuildCallStacks(csEvents) + + for _, cs := range stacks { + lastNodeIdx := -1 + opDepth := 0 + + cs.DFSTraverse( + func(evIdx int, node *analysis.CallStackNode) { + startNIdx, startOK := dag.evToStart[evIdx] + _, endOK := dag.evToEnd[evIdx] + if !startOK || !endOK { + return + } + + if lastNodeIdx >= 0 && !sharesCudaRuntimeLock(dag, lastNodeIdx, startNIdx, events, runtimeSet) { + if opDepth == 0 { + dag.addEdgeHelper(lastNodeIdx, startNIdx, EdgeDependency, false) + } else { + dag.addEdgeHelper(lastNodeIdx, startNIdx, EdgeOperatorKernel, false) + } + } + + lastNodeIdx = startNIdx + opDepth++ + }, + func(evIdx int, node *analysis.CallStackNode) { + _, startOK := dag.evToStart[evIdx] + endNIdx, endOK := dag.evToEnd[evIdx] + if !startOK || !endOK { + return + } + + if lastNodeIdx >= 0 && !sharesCudaRuntimeLock(dag, lastNodeIdx, endNIdx, events, runtimeSet) { + startNIdx := dag.evToStart[evIdx] + zeroWeight := dag.nodes[startNIdx].IsBlocking + dag.addEdgeHelper(lastNodeIdx, endNIdx, EdgeOperatorKernel, zeroWeight) + } + + lastNodeIdx = endNIdx + opDepth-- + }, + ) + } +} + +// sharesCudaRuntimeLock checks if two nodes belong to events that both hold the +// CUDA runtime lock (already handled by Phase 2). 
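+// Phase 3 consults this so it does not duplicate edges between runtime calls +// that Phase 2 already chained together.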
+func sharesCudaRuntimeLock(dag *cpDAG, nodeA, nodeB int, events []cpEvent, runtimeSet map[int]bool) bool { + evA := dag.nodes[nodeA].EvIdx + evB := dag.nodes[nodeB].EvIdx + return runtimeSet[evA] && runtimeSet[evB] +} + +// --------------------------------------------------------------------------- +// Phase 4: GPU kernel edges +// --------------------------------------------------------------------------- + +func constructKernelEdges(dag *cpDAG, events []cpEvent, names map[int]string, + db *sql.DB, sym *symbol.Table, rank int, allEvents []cpEvent) error { + + // Build correlation map: correlation → cpEvent index (for CPU runtime events). + cpuCorrToIdx := make(map[int]int) + for i := range events { + if events[i].isCPU() && events[i].Correlation >= 0 { + cpuCorrToIdx[events[i].Correlation] = i + } + } + + // Build GPU event correlation → index map. + gpuCorrToIdx := make(map[int]int) + for i := range events { + if events[i].isGPU() && events[i].Correlation >= 0 { + gpuCorrToIdx[events[i].Correlation] = i + } + } + + // Get queue length time series for the rank. + qlPoints, err := resource.QueueLengthTimeSeriesForRank(db, sym, rank) + if err != nil { + return fmt.Errorf("queue length: %w", err) + } + + // Build queue length lookups by (stream, correlation). + // For each CPU launch event: the queue length at that timestamp. + // For each GPU kernel: the queue length at that timestamp. + type qlKey struct { + stream int + correlation int + } + qlAtRuntime := make(map[qlKey]int) // queue length when CPU launched + qlAtKernel := make(map[qlKey]int) // queue length when GPU kernel started + + if len(qlPoints) > 0 { + // Build per-stream sorted time series. + type streamQL struct { + ts int64 + ql int + } + streamQLMap := make(map[int][]streamQL) + for _, p := range qlPoints { + streamQLMap[p.Stream] = append(streamQLMap[p.Stream], streamQL{ts: p.Timestamp, ql: p.QueueLength}) + } + + // For matched CPU-GPU pairs, find queue length. + for i := range events { + ev := &events[i] + if !ev.isGPU() || ev.Stream < 0 || ev.Correlation < 0 { + continue + } + cpuIdx, hasCPU := cpuCorrToIdx[ev.Correlation] + if !hasCPU { + continue + } + cpuEv := &events[cpuIdx] + key := qlKey{stream: ev.Stream, correlation: ev.Correlation} + + series := streamQLMap[ev.Stream] + if len(series) > 0 { + // Queue length at CPU launch time. + idx := sort.Search(len(series), func(j int) bool { + return series[j].ts > cpuEv.StartedAt + }) + if idx > 0 { + qlAtRuntime[key] = series[idx-1].ql + } + // Queue length at GPU kernel start time. + idx = sort.Search(len(series), func(j int) bool { + return series[j].ts > ev.StartedAt + }) + if idx > 0 { + qlAtKernel[key] = series[idx-1].ql + } + } + } + } + + // Build CUDA event record and stream wait event maps. + cudaEventRecordMap := buildCudaEventRecordMap(events, names, cpuCorrToIdx, gpuCorrToIdx) + cudaStreamWaitMap := buildCudaStreamWaitEventMap(events, names, cpuCorrToIdx, gpuCorrToIdx) + + // Collect GPU events, sorted by start time for regular kernels + // and by end time for sync events. + type gpuEntry struct { + evIdx int + sortKey int64 + isSync bool + } + var gpuEntries []gpuEntry + for i := range events { + ev := &events[i] + if !ev.isGPU() { + continue + } + if _, ok := dag.evToStart[i]; !ok { + continue + } + gpuEntries = append(gpuEntries, gpuEntry{ + evIdx: i, + sortKey: ev.StartedAt, + isSync: false, + }) + } + + // Also collect CPU sync events that target GPU. 
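+ // Sync events are keyed by their end timestamp (sortKey below), so by the + // time a sync is processed every kernel it could be waiting on has already + // updated lastNodePerStream.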
+ for i := range events { + ev := &events[i] + if !ev.isCPU() { + continue + } + name := names[ev.NameID] + if cudaSyncNames[name] { + if _, ok := dag.evToStart[i]; !ok { + continue + } + gpuEntries = append(gpuEntries, gpuEntry{ + evIdx: i, + sortKey: ev.EndedAt, // sync events sorted by end time + isSync: true, + }) + } + } + + sort.Slice(gpuEntries, func(i, j int) bool { + if gpuEntries[i].sortKey != gpuEntries[j].sortKey { + return gpuEntries[i].sortKey < gpuEntries[j].sortKey + } + return gpuEntries[i].evIdx < gpuEntries[j].evIdx + }) + + // Track last node per stream. + lastNodePerStream := make(map[int]int) // stream → node index + + // Deferred kernel sync map: dest GPU event idx → source GPU event idx. + kernelSync := make(map[int]int) + + for _, entry := range gpuEntries { + evIdx := entry.evIdx + ev := &events[evIdx] + + if entry.isSync { + // Handle CPU sync event. + handleCUDASync(dag, events, names, evIdx, lastNodePerStream, + cudaEventRecordMap, cpuCorrToIdx, gpuCorrToIdx) + continue + } + + // Regular GPU kernel. + startNIdx, startOK := dag.evToStart[evIdx] + endNIdx, endOK := dag.evToEnd[evIdx] + if !startOK || !endOK { + continue + } + + // 1. OPERATOR_KERNEL span edge (start→end). + dag.addEdgeHelper(startNIdx, endNIdx, EdgeOperatorKernel, false) + + // 2. Check for deferred kernel sync. + if srcEvIdx, ok := kernelSync[evIdx]; ok { + if srcEndNIdx, ok := dag.evToEnd[srcEvIdx]; ok { + dag.addEdgeHelper(srcEndNIdx, startNIdx, EdgeSyncDependency, false) + } + delete(kernelSync, evIdx) + } + + // Inter-stream syncs recorded via cudaStreamWaitEvent are resolved after + // this loop, once both the source and destination kernels have DAG nodes. + + // 3. Determine edge type: kernel launch delay vs kernel-kernel delay. + key := qlKey{stream: ev.Stream, correlation: ev.Correlation} + qlRuntime := qlAtRuntime[key] + qlKernel := qlAtKernel[key] + + cpuRuntimeIdx, hasCPURuntime := cpuCorrToIdx[ev.Correlation] + addedLaunchDelay := false + + if hasCPURuntime { + prevOnStream, hasPrev := lastNodePerStream[ev.Stream] + // Check for a pending sync dependency. Note: kernelSync is reserved + // for deferred inter-stream syncs and nothing populates it yet, so + // hasSyncDep is currently always false. + _, hasSyncDep := kernelSync[evIdx] + + isLaunchDelayBound := qlRuntime == 1 && qlKernel == 0 + + if isLaunchDelayBound { + // Additional checks: no previous kernel on stream blocking, + // or previous kernel ended before runtime call. + runtimeStart := events[cpuRuntimeIdx].StartedAt + prevClear := !hasPrev || dag.nodes[prevOnStream].Ts <= runtimeStart + syncClear := !hasSyncDep + + if prevClear && syncClear { + if runtimeStartNIdx, ok := dag.evToStart[cpuRuntimeIdx]; ok { + dag.addEdgeHelper(runtimeStartNIdx, startNIdx, EdgeKernelLaunchDelay, false) + addedLaunchDelay = true + } + } + } + } + + if !addedLaunchDelay { + if prevNIdx, ok := lastNodePerStream[ev.Stream]; ok { + gap := dag.nodes[startNIdx].Ts - dag.nodes[prevNIdx].Ts + if gap >= 0 && gap < kernelKernelDelayThresholdUs { + dag.addEdgeHelper(prevNIdx, startNIdx, EdgeKernelKernelDelay, false) + } + } + } + + // 4. Update last node per stream. + lastNodePerStream[ev.Stream] = endNIdx + } + + // Process cudaStreamWaitEvent entries → deferred syncs. + for _, syncInfo := range cudaStreamWaitMap { + srcEvIdx := syncInfo.srcGPUEvIdx + destEvIdx := syncInfo.destGPUEvIdx + if srcEvIdx >= 0 && destEvIdx >= 0 { + // Both kernels have nodes in the DAG: add the inter-stream sync edge.
+ srcEndNIdx, srcOK := dag.evToEnd[srcEvIdx] + destStartNIdx, destOK := dag.evToStart[destEvIdx] + if srcOK && destOK { + dag.addEdgeHelper(srcEndNIdx, destStartNIdx, EdgeSyncDependency, false) + } + } + } + + return nil +} + +// --------------------------------------------------------------------------- +// CUDA event record / stream wait event maps +// --------------------------------------------------------------------------- + +type syncInfo struct { + srcGPUEvIdx int // source GPU kernel event index + destGPUEvIdx int // destination GPU kernel event index +} + +// buildCudaEventRecordMap maps cudaEventRecord correlation → previous GPU kernel ev index. +func buildCudaEventRecordMap(events []cpEvent, names map[int]string, + cpuCorrToIdx, gpuCorrToIdx map[int]int) map[int]int { + + // For each cudaEventRecord: find the previous kernel launch on the same stream/pid. + // Returns: cudaEventRecord corrID → GPU kernel event index. + + // Collect CUDA runtime launches sorted by (pid, stream, startedAt). + type launch struct { + evIdx int + pid int64 + stream int + startAt int64 + corr int + } + var launches []launch + for i := range events { + ev := &events[i] + if ev.isCPU() && ev.Correlation >= 0 { + name := names[ev.NameID] + if isRuntimeLaunchName(name) { + // Find the GPU kernel to get its stream. + if gpuIdx, ok := gpuCorrToIdx[ev.Correlation]; ok { + launches = append(launches, launch{ + evIdx: i, + pid: ev.ProcessID, + stream: events[gpuIdx].Stream, + startAt: ev.StartedAt, + corr: ev.Correlation, + }) + } + } + } + } + + sort.Slice(launches, func(i, j int) bool { + a, b := launches[i], launches[j] + if a.pid != b.pid { + return a.pid < b.pid + } + if a.stream != b.stream { + return a.stream < b.stream + } + return a.startAt < b.startAt + }) + + // cudaEventRecord calls carry no stream of their own; the stream is + // recovered from the cudaStreamWaitEvent that references the record via + // wait_on_cuda_event_record_corr_id. + result := make(map[int]int) + + // Build map: cudaEventRecord event → associated stream via cudaStreamWaitEvent. + // First find all cudaStreamWaitEvent events and their wait_on_cuda_event_record_corr_id. + eventRecordToStream := make(map[int]int) // cudaEventRecord correlation → stream + for i := range events { + ev := &events[i] + name := names[ev.NameID] + if name == "cudaStreamWaitEvent" && ev.WaitOnCudaEventRecordCorrID >= 0 && ev.WaitOnStream >= 0 { + eventRecordToStream[ev.WaitOnCudaEventRecordCorrID] = ev.WaitOnStream + } + } + + // For each cudaEventRecord, find the previous GPU kernel launch on the same (pid, stream). + for i := range events { + ev := &events[i] + name := names[ev.NameID] + if name != "cudaEventRecord" { + continue + } + stream, ok := eventRecordToStream[ev.Correlation] + if !ok { + continue + } + + // Linear scan for the latest launch on (pid, stream) at or before this + // event's timestamp; launches are sorted, so the last match wins. + bestIdx := -1 + for _, l := range launches { + if l.pid == ev.ProcessID && l.stream == stream && l.startAt <= ev.StartedAt { + bestIdx = l.evIdx + } + } + if bestIdx >= 0 { + // Map to the GPU kernel via correlation. + if gpuIdx, ok := gpuCorrToIdx[events[bestIdx].Correlation]; ok { + result[ev.Correlation] = gpuIdx + } + } + } + + return result +} + +func isRuntimeLaunchName(name string) bool { + for _, n := range resource.RuntimeLaunchNames { + if name == n { + return true + } + } + return false +} + +// buildCudaStreamWaitEventMap maps relevant events for inter-stream sync. +// Returns a map from cudaStreamWaitEvent CPU evIdx → syncInfo with src/dest GPU kernels.
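+// +// The pattern being recovered (K1, K2, E, and the streams are illustrative): +// cudaEventRecord(E) is issued after the launch of kernel K1 on stream A, and +// cudaStreamWaitEvent(E) precedes the launch of kernel K2 on stream B, so +// constructKernelEdges adds a SYNC_DEPENDENCY edge from the end of K1 to the +// start of K2.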
+func buildCudaStreamWaitEventMap(events []cpEvent, names map[int]string, + cpuCorrToIdx, gpuCorrToIdx map[int]int) map[int]syncInfo { + + result := make(map[int]syncInfo) + + // Collect CUDA runtime launches sorted by (pid, tid, stream, startedAt). + type launch struct { + evIdx int + pid int64 + tid int64 + stream int + startAt int64 + corr int + } + var launches []launch + for i := range events { + ev := &events[i] + if ev.isCPU() && ev.Correlation >= 0 { + name := names[ev.NameID] + if isRuntimeLaunchName(name) { + if gpuIdx, ok := gpuCorrToIdx[ev.Correlation]; ok { + launches = append(launches, launch{ + evIdx: i, + pid: ev.ProcessID, + tid: ev.ThreadID, + stream: events[gpuIdx].Stream, + startAt: ev.StartedAt, + corr: ev.Correlation, + }) + } + } + } + } + + sort.Slice(launches, func(i, j int) bool { + a, b := launches[i], launches[j] + if a.pid != b.pid { + return a.pid < b.pid + } + if a.tid != b.tid { + return a.tid < b.tid + } + if a.stream != b.stream { + return a.stream < b.stream + } + return a.startAt < b.startAt + }) + + // Build cudaEventRecord correlation → source GPU kernel map. + eventRecordMap := buildCudaEventRecordMap(events, names, cpuCorrToIdx, gpuCorrToIdx) + + for i := range events { + ev := &events[i] + name := names[ev.NameID] + if name != "cudaStreamWaitEvent" { + continue + } + if ev.WaitOnCudaEventRecordCorrID < 0 { + continue + } + + // Source GPU kernel: the one referenced by the event record. + srcGPUIdx, hasSrc := eventRecordMap[ev.WaitOnCudaEventRecordCorrID] + + // Destination GPU kernel: the next launch on this (pid, tid, stream) after the wait event. + destGPUIdx := -1 + // Find GPU stream for the GPU kernel that this wait will sync to. + // The stream is the stream of the cudaStreamWaitEvent's correlated GPU event. + if gpuIdx, ok := gpuCorrToIdx[ev.Correlation]; ok { + destStream := events[gpuIdx].Stream + // Find next launch after this wait event on (pid, tid, destStream). + for _, l := range launches { + if l.pid == ev.ProcessID && l.tid == ev.ThreadID && l.stream == destStream && l.startAt > ev.StartedAt { + if gpuNext, ok := gpuCorrToIdx[l.corr]; ok { + destGPUIdx = gpuNext + } + break + } + } + } + + if hasSrc { + si := syncInfo{srcGPUEvIdx: srcGPUIdx, destGPUEvIdx: destGPUIdx} + result[i] = si + } + } + + return result +} + +// handleCUDASync processes a CPU sync event and adds appropriate sync edges. +func handleCUDASync(dag *cpDAG, events []cpEvent, names map[int]string, + evIdx int, lastNodePerStream map[int]int, + cudaEventRecordMap map[int]int, cpuCorrToIdx, gpuCorrToIdx map[int]int) { + + ev := &events[evIdx] + name := names[ev.NameID] + endNIdx, endOK := dag.evToEnd[evIdx] + if !endOK { + return + } + + switch name { + case "cudaDeviceSynchronize": + // Context sync: all streams → this CPU event. + for _, nodeIdx := range lastNodePerStream { + dag.addEdgeHelper(nodeIdx, endNIdx, EdgeSyncDependency, false) + } + + case "cudaStreamSynchronize": + // Stream sync: specific stream → this CPU event. + // The stream is determined by the correlated GPU event. + if gpuIdx, ok := gpuCorrToIdx[ev.Correlation]; ok { + stream := events[gpuIdx].Stream + if lastN, ok := lastNodePerStream[stream]; ok { + dag.addEdgeHelper(lastN, endNIdx, EdgeSyncDependency, false) + } + } + + case "cudaEventSynchronize": + // Event sync: GPU kernel end → CPU event end. 
+ if srcGPUIdx, ok := cudaEventRecordMap[ev.Correlation]; ok { + if srcEndNIdx, endOK := dag.evToEnd[srcGPUIdx]; endOK { + dag.addEdgeHelper(srcEndNIdx, endNIdx, EdgeSyncDependency, false) + } + } + + case "cudaStreamWaitEvent": + // Stream wait: handled via deferred kernel sync in the main loop. + // No immediate edge needed here. + } +} + +// --------------------------------------------------------------------------- +// DAG validation +// --------------------------------------------------------------------------- + +func validateDAG(dag *cpDAG) { + // Clamp -1 weights to 0 (nanosecond precision issues). + for i := range dag.edges { + if dag.edges[i].Weight == -1 { + dag.edges[i].Weight = 0 + } else if dag.edges[i].Weight < -1 { + log.Printf("WARNING: edge %d has negative weight %d (nodes %d→%d)", + i, dag.edges[i].Weight, dag.edges[i].Begin, dag.edges[i].End) + dag.edges[i].Weight = 0 + } + } +} + +// --------------------------------------------------------------------------- +// Longest path (topological sort + DP) +// --------------------------------------------------------------------------- + +func (g *cpDAG) longestPath() (path []int, totalWeight int64) { + n := len(g.nodes) + if n == 0 { + return nil, 0 + } + + // Kahn's algorithm for topological ordering. + inDeg := make([]int, n) + for _, e := range g.edges { + inDeg[e.End]++ + } + + var queue []int + for i := 0; i < n; i++ { + if inDeg[i] == 0 { + queue = append(queue, i) + } + } + + topoOrder := make([]int, 0, n) + for len(queue) > 0 { + u := queue[0] + queue = queue[1:] + topoOrder = append(topoOrder, u) + for _, eIdx := range g.adj[u] { + v := g.edges[eIdx].End + inDeg[v]-- + if inDeg[v] == 0 { + queue = append(queue, v) + } + } + } + + if len(topoOrder) != n { + log.Printf("WARNING: DAG has cycles (%d/%d nodes in topological order)", len(topoOrder), n) + } + + // Forward DP: dist[v] = max over all incoming edges of (dist[u] + weight(u,v)). + dist := make([]int64, n) + pred := make([]int, n) + for i := range dist { + dist[i] = 0 + pred[i] = -1 + } + + for _, u := range topoOrder { + for _, eIdx := range g.adj[u] { + e := g.edges[eIdx] + v := e.End + if dist[u]+e.Weight > dist[v] { + dist[v] = dist[u] + e.Weight + pred[v] = u + } + } + } + + // Find the node with maximum distance. + maxDist := int64(math.MinInt64) + maxNode := 0 + for i := 0; i < n; i++ { + if dist[i] > maxDist { + maxDist = dist[i] + maxNode = i + } + } + + // Backtrack to reconstruct the path. + var reversePath []int + cur := maxNode + for cur >= 0 { + reversePath = append(reversePath, cur) + cur = pred[cur] + } + + // Reverse. 
+ path = make([]int, len(reversePath)) + for i, v := range reversePath { + path[len(reversePath)-1-i] = v + } + + return path, maxDist +} + +// --------------------------------------------------------------------------- +// Breakdown by bound type +// --------------------------------------------------------------------------- + +func computeBreakdown(dag *cpDAG, events []cpEvent, cpEdgesSet map[int]bool, + names map[int]string, dataLoadRegexps []*regexp.Regexp) []CriticalPathSummaryRow { + + boundDurations := make(map[BoundType]int64) + + for eIdx := range cpEdgesSet { + e := dag.edges[eIdx] + + if e.Weight <= 0 { + continue + } + + var bt BoundType + switch e.Type { + case EdgeKernelKernelDelay: + bt = BoundGPUKernelKernelOverhead + case EdgeKernelLaunchDelay: + bt = BoundGPUKernelLaunchOverhead + case EdgeDependency, EdgeSyncDependency: + continue // zero weight, skip + case EdgeOperatorKernel: + // Attribute to the event this edge corresponds to. + evIdx := attributeEdge(dag, e) + if evIdx < 0 { + continue + } + ev := &events[evIdx] + if ev.isCPU() { + name := names[ev.NameID] + isDataLoad := false + for _, re := range dataLoadRegexps { + if re.MatchString(name) { + isDataLoad = true + break + } + } + if isDataLoad { + bt = BoundDataLoading + } else { + bt = BoundCPU + } + } else { + // GPU event. + name, _ := getName(names, ev.NameID) + ktype := analysis.ClassifyKernel(name) + if ktype == analysis.KernelCommunication { + bt = BoundGPUCommunication + } else { + bt = BoundGPUCompute + } + } + default: + continue + } + + boundDurations[bt] += e.Weight + } + + // Compute total and percentages. + var total int64 + for _, d := range boundDurations { + total += d + } + + // Build result in a consistent order. + allBounds := []BoundType{ + BoundCPU, + BoundDataLoading, + BoundGPUCompute, + BoundGPUCommunication, + BoundGPUKernelKernelOverhead, + BoundGPUKernelLaunchOverhead, + } + + var result []CriticalPathSummaryRow + for _, bt := range allBounds { + dur := boundDurations[bt] + if dur == 0 { + continue + } + pct := 0.0 + if total > 0 { + pct = float64(dur) / float64(total) * 100.0 + } + result = append(result, CriticalPathSummaryRow{ + BoundType: bt.String(), + Duration: dur, + Percentage: math.Round(pct*100) / 100, + }) + } + + return result +} + +func getName(names map[int]string, nameID int) (string, bool) { + n, ok := names[nameID] + return n, ok +} + +// attributeEdge determines which event "owns" a critical path edge. +func attributeEdge(dag *cpDAG, e cpEdge) int { + src := dag.nodes[e.Begin] + dst := dag.nodes[e.End] + + switch e.Type { + case EdgeKernelKernelDelay: + return src.EvIdx // previous kernel + case EdgeOperatorKernel: + if src.IsStart && dst.IsStart { + return src.EvIdx + } + if src.IsStart && !dst.IsStart { + return src.EvIdx // same event (leaf span) + } + if !src.IsStart && !dst.IsStart { + return dst.EvIdx + } + if !src.IsStart && dst.IsStart { + // Gap between siblings — attribute to parent. + // Use src event's parent if available. + return src.EvIdx + } + } + return -1 +} + +// --------------------------------------------------------------------------- +// Annotation-based region of interest filtering +// --------------------------------------------------------------------------- + +func clipToAnnotation(events []cpEvent, sym *symbol.Table, opts CriticalPathOpts) ([]cpEvent, error) { + if opts.Annotation == "" { + // No annotation: use all events. + return events, nil + } + + // Find annotation events. 
+ var annotationEvIdxs []int + for i := range events { + ev := &events[i] + if !ev.isCPU() { + continue + } + name, err := sym.GetName(ev.NameID) + if err != nil { + continue + } + if strings.Contains(name, opts.Annotation) { + annotationEvIdxs = append(annotationEvIdxs, i) + } + } + + if len(annotationEvIdxs) == 0 { + return nil, fmt.Errorf("no annotation events matching %q found", opts.Annotation) + } + + // Sort annotation events by start time. + sort.Slice(annotationEvIdxs, func(i, j int) bool { + return events[annotationEvIdxs[i]].StartedAt < events[annotationEvIdxs[j]].StartedAt + }) + + // Parse instance ID. + startInstance, endInstance, err := parseInstanceID(opts.InstanceID, len(annotationEvIdxs)) + if err != nil { + return nil, err + } + + // Compute time window. + var startTs, endTs int64 + startTs = math.MaxInt64 + endTs = math.MinInt64 + for i := startInstance; i <= endInstance && i < len(annotationEvIdxs); i++ { + ev := &events[annotationEvIdxs[i]] + if ev.StartedAt < startTs { + startTs = ev.StartedAt + } + if ev.EndedAt > endTs { + endTs = ev.EndedAt + } + } + + // Build correlation set: correlations from CPU events in the time window. + cpuCorrelations := make(map[int]bool) + for i := range events { + ev := &events[i] + if ev.isCPU() && ev.StartedAt >= startTs && ev.StartedAt <= endTs && ev.Duration > 0 { + if ev.Correlation >= 0 { + cpuCorrelations[ev.Correlation] = true + } + } + } + + // Filter events. + var filtered []cpEvent + for i := range events { + ev := &events[i] + if ev.isCPU() { + if ev.StartedAt >= startTs && ev.StartedAt <= endTs && ev.Duration > 0 { + filtered = append(filtered, *ev) + } + } else if ev.isGPU() { + // Include GPU events whose CPU runtime was in window. + if ev.Correlation >= 0 && cpuCorrelations[ev.Correlation] { + filtered = append(filtered, *ev) + } + } + } + + // Always include Stream Wait Event events regardless of timing + // (needed for inter-stream sync). + sweName := sym.GetID("Stream Wait Event") + if sweName >= 0 { + included := make(map[int64]bool) + for _, ev := range filtered { + included[ev.DBID] = true + } + for i := range events { + ev := &events[i] + if ev.NameID == sweName && !included[ev.DBID] { + filtered = append(filtered, *ev) + } + } + } + + return filtered, nil +} + +func parseInstanceID(s string, numAnnotations int) (start, end int, err error) { + s = strings.TrimSpace(s) + if s == "" { + return 0, 0, nil + } + + if strings.Contains(s, ",") { + parts := strings.SplitN(s, ",", 2) + start, err = parseInt(strings.TrimSpace(parts[0])) + if err != nil { + return 0, 0, fmt.Errorf("invalid instance-id start: %w", err) + } + end, err = parseInt(strings.TrimSpace(parts[1])) + if err != nil { + return 0, 0, fmt.Errorf("invalid instance-id end: %w", err) + } + return start, end, nil + } + + val, err := parseInt(s) + if err != nil { + return 0, 0, fmt.Errorf("invalid instance-id: %w", err) + } + return val, val, nil +} + +func parseInt(s string) (int, error) { + var v int + _, err := fmt.Sscanf(s, "%d", &v) + return v, err +} + +// --------------------------------------------------------------------------- +// Trace overlay +// --------------------------------------------------------------------------- + +func overlayCriticalPath(db *sql.DB, rank int, dag *cpDAG, + events []cpEvent, cpEdgesSet, cpEventsSet map[int]bool, + outputDir string, showAllEdges bool) (string, error) { + + // Get trace file path. 
+ tracePath, err := store.TraceFile(db, rank) + if err != nil { + return "", fmt.Errorf("loading trace file path: %w", err) + } + + // Read raw trace. + rawData, err := kernel.ReadTraceFile(tracePath) + if err != nil { + return "", fmt.Errorf("reading trace file %s: %w", tracePath, err) + } + + // Parse top-level JSON. + var traceDoc map[string]json.RawMessage + if err := json.Unmarshal(rawData, &traceDoc); err != nil { + return "", fmt.Errorf("parsing trace JSON: %w", err) + } + + eventsRaw, ok := traceDoc["traceEvents"] + if !ok { + return "", fmt.Errorf("traceEvents key not found in trace file") + } + + var rawEvents []map[string]any + if err := json.Unmarshal(eventsRaw, &rawEvents); err != nil { + return "", fmt.Errorf("parsing traceEvents: %w", err) + } + + // Compute DB-to-raw timestamp offset. + dbMinTs, err := store.MinStartedAt(db, rank) + if err != nil { + return "", fmt.Errorf("loading min started_at: %w", err) + } + + var rawMinTs float64 + first := true + for _, ev := range rawEvents { + ts, ok := ev["ts"] + if !ok { + continue + } + var tsFloat float64 + switch v := ts.(type) { + case float64: + tsFloat = v + case json.Number: + tsFloat, _ = v.Float64() + default: + continue + } + if first || tsFloat < rawMinTs { + rawMinTs = tsFloat + first = false + } + } + offset := rawMinTs - float64(dbMinTs) + + // Build DB ID → cpEvent index map. + dbIDToEvIdx := make(map[int64]int, len(events)) + for i := range events { + dbIDToEvIdx[events[i].DBID] = i + } + + // Build adjusted timestamp lookup for matching raw events to DB events. + type tsMatch struct { + adjustedTs int64 + dbID int64 + evIdx int + } + var cpEventMatches []tsMatch + for evIdx := range cpEventsSet { + ev := &events[evIdx] + cpEventMatches = append(cpEventMatches, tsMatch{ + adjustedTs: int64(float64(ev.StartedAt) + offset + 0.5), + dbID: ev.DBID, + evIdx: evIdx, + }) + } + matchSet := make(map[int64]bool, len(cpEventMatches)) + for _, m := range cpEventMatches { + matchSet[m.adjustedTs] = true + } + + // Mark critical path events in raw trace. + for i, ev := range rawEvents { + ts, ok := ev["ts"] + if !ok { + continue + } + var tsFloat float64 + switch v := ts.(type) { + case float64: + tsFloat = v + case json.Number: + tsFloat, _ = v.Float64() + default: + continue + } + tsKey := int64(tsFloat + 0.5) + if matchSet[tsKey] { + args, _ := ev["args"].(map[string]any) + if args == nil { + args = make(map[string]any) + } + args["critical"] = 1 + rawEvents[i]["args"] = args + } + } + + // Create flow events for critical path edges. + var flowEvents []map[string]any + flowID := 0 + + edgeSet := cpEdgesSet + if showAllEdges { + edgeSet = make(map[int]bool) + for i := range dag.edges { + edgeSet[i] = true + } + } + + for eIdx := range edgeSet { + e := dag.edges[eIdx] + srcNode := dag.nodes[e.Begin] + dstNode := dag.nodes[e.End] + + srcEv := &events[srcNode.EvIdx] + dstEv := &events[dstNode.EvIdx] + + srcTs := float64(srcNode.Ts) + offset + dstTs := float64(dstNode.Ts) + offset + + // Adjust end-node timestamps slightly for visibility. + if !srcNode.IsStart && srcEv.Duration > 0 { + srcTs = float64(srcEv.StartedAt) + offset + float64(srcEv.Duration) - math.Min(1, float64(srcEv.Duration)) + } + if !dstNode.IsStart && dstEv.Duration > 0 { + dstTs = float64(dstEv.StartedAt) + offset + float64(dstEv.Duration) - math.Min(1, float64(dstEv.Duration)) + } + + cat := e.Type.String() + + // Flow start event. 
+ flowEvents = append(flowEvents, map[string]any{ + "ph": "s", + "id": flowID, + "pid": srcEv.ProcessID, + "tid": srcEv.ThreadID, + "ts": srcTs, + "cat": cat, + "name": cat, + }) + // Flow end event. + flowEvents = append(flowEvents, map[string]any{ + "ph": "f", + "id": flowID, + "pid": dstEv.ProcessID, + "tid": dstEv.ThreadID, + "ts": dstTs, + "cat": cat, + "name": cat, + "bp": "e", + }) + flowID++ + } + + // Append flow events. + rawEvents = append(rawEvents, flowEvents...) + + // Marshal back. + annotatedEvents, err := json.Marshal(rawEvents) + if err != nil { + return "", fmt.Errorf("marshaling annotated events: %w", err) + } + traceDoc["traceEvents"] = annotatedEvents + outputData, err := json.Marshal(traceDoc) + if err != nil { + return "", fmt.Errorf("marshaling output trace: %w", err) + } + + // Write output file. + baseName := filepath.Base(tracePath) + baseName, _ = strings.CutSuffix(baseName, ".gz") + outputFile := filepath.Join(outputDir, "overlaid_critical_path_"+baseName) + if err := os.MkdirAll(outputDir, 0755); err != nil { + return "", fmt.Errorf("creating output dir: %w", err) + } + if err := os.WriteFile(outputFile, outputData, 0644); err != nil { + return "", fmt.Errorf("writing overlay file: %w", err) + } + + return outputFile, nil +} diff --git a/pkg/analysis/criticalpath/critical_path_test.go b/pkg/analysis/criticalpath/critical_path_test.go new file mode 100644 index 0000000..fdd7d29 --- /dev/null +++ b/pkg/analysis/criticalpath/critical_path_test.go @@ -0,0 +1,158 @@ +package criticalpath + +import ( + "os" + "path/filepath" + "testing" + + "hta/pkg/analysis" + "hta/pkg/pipeline" + "hta/pkg/store" +) + +func TestCriticalPathAlexnet(t *testing.T) { + traceDir := filepath.Join("..", "..", "..", "tests", "data", "critical_path", "alexnet") + if _, err := os.Stat(traceDir); os.IsNotExist(err) { + t.Skipf("test data not found: %s", traceDir) + } + + // Set up in-memory database. + db, err := store.Create(":memory:") + if err != nil { + t.Fatalf("creating db: %v", err) + } + defer db.Close() + + if err := pipeline.RunWithDB(traceDir, db); err != nil { + t.Fatalf("pre-processing: %v", err) + } + + // Create temp output dir for overlay. + outputDir := t.TempDir() + + opts := CriticalPathOpts{ + Rank: 0, + Annotation: "[param|pytorch.model.alex_net|0|0|0|measure|forward]", + InstanceID: "1", + OutputDir: outputDir, + } + + result, err := CriticalPath(db, opts) + if err != nil { + t.Fatalf("CriticalPath: %v", err) + } + + // Basic sanity checks. + if result.NumNodes == 0 { + t.Error("expected non-zero node count") + } + if result.NumEdges == 0 { + t.Error("expected non-zero edge count") + } + if result.PathLength <= 0 { + t.Error("expected positive path length") + } + if len(result.Summary) == 0 { + t.Error("expected non-empty summary") + } + + // Check that percentages sum to approximately 100%. + var totalPct float64 + for _, r := range result.Summary { + totalPct += r.Percentage + t.Logf(" %s: %d us (%.2f%%)", r.BoundType, r.Duration, r.Percentage) + } + t.Logf("Total percentage: %.2f%%", totalPct) + t.Logf("Nodes: %d, Edges: %d, Path length: %d us", result.NumNodes, result.NumEdges, result.PathLength) + + if totalPct < 90 || totalPct > 110 { + t.Errorf("percentages sum to %.2f%%, expected ~100%%", totalPct) + } + + // Check overlay file was created. 
+ if result.OverlayFile == "" { + t.Error("expected overlay file path") + } else { + if _, err := os.Stat(result.OverlayFile); os.IsNotExist(err) { + t.Errorf("overlay file not found: %s", result.OverlayFile) + } else { + t.Logf("Overlay file: %s", result.OverlayFile) + } + } +} + +func TestCallStackBuild(t *testing.T) { + t.Parallel() + events := []analysis.CSEvent{ + {Idx: 0, ThreadID: 1, Start: 0, End: 100}, + {Idx: 1, ThreadID: 1, Start: 10, End: 50}, + {Idx: 2, ThreadID: 1, Start: 20, End: 30}, + {Idx: 3, ThreadID: 1, Start: 60, End: 90}, + } + + stacks := analysis.BuildCallStacks(events) + if len(stacks) != 1 { + t.Fatalf("expected 1 stack, got %d", len(stacks)) + } + + cs := stacks[0] + if len(cs.Roots) != 1 { + t.Fatalf("expected 1 root, got %d", len(cs.Roots)) + } + if cs.Roots[0] != 0 { + t.Errorf("expected root evIdx 0, got %d", cs.Roots[0]) + } + + // Node 0 (root) should have children [1, 3]. + root := cs.Nodes[0] + if root.Depth != 0 { + t.Errorf("root depth = %d, want 0", root.Depth) + } + if len(root.Children) != 2 { + t.Fatalf("root children = %v, want [1, 3]", root.Children) + } + + // Node 1 should have child [2]. + n1 := cs.Nodes[1] + if n1.Depth != 1 { + t.Errorf("node 1 depth = %d, want 1", n1.Depth) + } + if len(n1.Children) != 1 || n1.Children[0] != 2 { + t.Errorf("node 1 children = %v, want [2]", n1.Children) + } + + // Node 2 should be a leaf at depth 2. + n2 := cs.Nodes[2] + if n2.Depth != 2 { + t.Errorf("node 2 depth = %d, want 2", n2.Depth) + } + if len(n2.Children) != 0 { + t.Errorf("node 2 children = %v, want []", n2.Children) + } + + // DFS traversal order. + var enterOrder, exitOrder []int + cs.DFSTraverse( + func(evIdx int, node *analysis.CallStackNode) { + enterOrder = append(enterOrder, evIdx) + }, + func(evIdx int, node *analysis.CallStackNode) { + exitOrder = append(exitOrder, evIdx) + }, + ) + + expectedEnter := []int{0, 1, 2, 3} + expectedExit := []int{2, 1, 3, 0} + for i, v := range expectedEnter { + if i >= len(enterOrder) || enterOrder[i] != v { + t.Errorf("enter order mismatch at %d: got %v, want %v", i, enterOrder, expectedEnter) + break + } + } + for i, v := range expectedExit { + if i >= len(exitOrder) || exitOrder[i] != v { + t.Errorf("exit order mismatch at %d: got %v, want %v", i, exitOrder, expectedExit) + break + } + } +} diff --git a/pkg/analysis/helpers_test.go b/pkg/analysis/helpers_test.go new file mode 100644 index 0000000..f4886d4 --- /dev/null +++ b/pkg/analysis/helpers_test.go @@ -0,0 +1,20 @@ +package analysis + +import ( + "os" + "path/filepath" + "runtime" + "testing" +) + +func testDataDir(t *testing.T) string { + t.Helper() + _, thisFile, _, _ := runtime.Caller(0) + // pkg/analysis/helpers_test.go → project root + root := filepath.Join(filepath.Dir(thisFile), "..", "..") + dir := filepath.Join(root, "tests", "data", "vision_transformer") + if _, err := os.Stat(dir); err != nil { + t.Skipf("test data not found: %s", dir) + } + return dir +} diff --git a/pkg/analysis/intervals.go b/pkg/analysis/intervals.go new file mode 100644 index 0000000..785edc1 --- /dev/null +++ b/pkg/analysis/intervals.go @@ -0,0 +1,32 @@ +package analysis + +// Interval represents a time range [Start, End) in microseconds. +type Interval struct { + Start int64 + End int64 +} + +// MergeIntervals merges overlapping intervals from a sorted slice. +// Input must be sorted by Start (ascending). +// Returns merged intervals with no overlaps. 
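+// +// For example (mirroring the cases in intervals_test.go): +// MergeIntervals([]Interval{{0, 5}, {3, 8}, {10, 15}}) // → [{0, 8}, {10, 15}] +// MergeIntervals([]Interval{{0, 5}, {5, 10}}) // → [{0, 10}]; touching intervals merge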
+func MergeIntervals(sorted []Interval) []Interval { + if len(sorted) == 0 { + return nil + } + merged := make([]Interval, 0, len(sorted)/2+1) + cur := sorted[0] + for i := 1; i < len(sorted); i++ { + if sorted[i].Start <= cur.End { + // Overlap: extend current interval + if sorted[i].End > cur.End { + cur.End = sorted[i].End + } + } else { + // No overlap: emit current, start new + merged = append(merged, cur) + cur = sorted[i] + } + } + merged = append(merged, cur) + return merged +} diff --git a/pkg/analysis/intervals_test.go b/pkg/analysis/intervals_test.go new file mode 100644 index 0000000..1d7d6e8 --- /dev/null +++ b/pkg/analysis/intervals_test.go @@ -0,0 +1,60 @@ +package analysis + +import ( + "reflect" + "testing" +) + +func TestMergeIntervals(t *testing.T) { + t.Parallel() + tests := []struct { + name string + input []Interval + expect []Interval + }{ + { + name: "empty", + input: nil, + expect: nil, + }, + { + name: "single", + input: []Interval{{0, 10}}, + expect: []Interval{{0, 10}}, + }, + { + name: "no overlap", + input: []Interval{{0, 5}, {10, 15}, {20, 25}}, + expect: []Interval{{0, 5}, {10, 15}, {20, 25}}, + }, + { + name: "full overlap", + input: []Interval{{0, 10}, {2, 8}}, + expect: []Interval{{0, 10}}, + }, + { + name: "partial overlap", + input: []Interval{{0, 5}, {3, 8}, {10, 15}}, + expect: []Interval{{0, 8}, {10, 15}}, + }, + { + name: "adjacent (touching)", + input: []Interval{{0, 5}, {5, 10}}, + expect: []Interval{{0, 10}}, + }, + { + name: "chain overlap", + input: []Interval{{0, 3}, {1, 5}, {4, 8}}, + expect: []Interval{{0, 8}}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := MergeIntervals(tt.input) + if !reflect.DeepEqual(got, tt.expect) { + t.Errorf("MergeIntervals(%v) = %v; want %v", tt.input, got, tt.expect) + } + }) + } +} diff --git a/pkg/analysis/kernel/annotation.go b/pkg/analysis/kernel/annotation.go new file mode 100644 index 0000000..e3fbba2 --- /dev/null +++ b/pkg/analysis/kernel/annotation.go @@ -0,0 +1,131 @@ +package kernel + +import ( + "database/sql" + "strconv" + + "hta/pkg/analysis" + "hta/pkg/store" +) + +// AnnotationOpts configures the GPUKernelsWithAnnotations analysis. +type AnnotationOpts struct { + Rank int + ExpandNames bool // resolve symbol IDs to human-readable strings + ShortenNames bool // remove template/function params from names +} + +// AnnotatedKernelRow is a single GPU kernel with its associated user annotation. +type AnnotatedKernelRow struct { + StartedAt int64 + EndedAt int64 + KernelName string // resolved name (or numeric ID as string) + UserAnnotation string // annotation name, or "" if unmatched +} + +// GPUKernelsWithAnnotations associates each GPU kernel with its innermost +// (leaf) user annotation by interval overlap. Annotations are processed +// longest-first so the shortest (leaf) always wins. +func GPUKernelsWithAnnotations(db *sql.DB, opts AnnotationOpts) ([]AnnotatedKernelRow, error) { + symTable, err := store.LoadSymbolTable(db) + if err != nil { + return nil, err + } + + // Check if gpu_user_annotation exists in the symbol table. 
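+ // GetID returns -1 for symbols absent from the table, so this is a cheap + // way to bail out before loading any events for traces that carry no GPU + // user annotations.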
+ if symTable.GetID("gpu_user_annotation") == -1 { + return nil, nil // no annotations in this trace + } + + annotations, err := store.LoadUserAnnotations(db, opts.Rank) + if err != nil { + return nil, err + } + if len(annotations) == 0 { + return nil, nil + } + + kernels, err := store.LoadGPUKernelsForAnnotation(db, opts.Rank) + if err != nil { + return nil, err + } + if len(kernels) == 0 { + return nil, nil + } + + // annotationID[i] holds the name ID of the annotation assigned to kernel i. + // -1 means unmatched. + annotationID := make([]int, len(kernels)) + for i := range annotationID { + annotationID[i] = -1 + } + + // Group annotations by (pid, tid). + type pidTid struct { + pid, tid int + } + annoGroups := make(map[pidTid][]store.UserAnnotationRow) + for _, a := range annotations { + key := pidTid{a.ProcessID, a.ThreadID} + annoGroups[key] = append(annoGroups[key], a) + } + + // For each (pid, tid) group, match kernels to annotations. + for key, annos := range annoGroups { + // Pre-filter kernel indices with matching (pid, tid). + var matchIdx []int + for i, k := range kernels { + if k.ProcessID == key.pid && k.ThreadID == key.tid { + matchIdx = append(matchIdx, i) + } + } + if len(matchIdx) == 0 { + continue + } + + // Annotations are already sorted by duration DESC (longest first). + // Processing in this order means the leaf (shortest) annotation + // always overwrites, giving us the innermost match. + for _, a := range annos { + for _, idx := range matchIdx { + k := kernels[idx] + // Overlap check: k.Start < a.End && k.End > a.Start + if k.StartedAt < a.EndedAt && k.EndedAt > a.StartedAt { + annotationID[idx] = a.NameID + } + } + } + } + + // Build result rows. + result := make([]AnnotatedKernelRow, len(kernels)) + for i, k := range kernels { + row := AnnotatedKernelRow{ + StartedAt: k.StartedAt, + EndedAt: k.EndedAt, + } + + if opts.ExpandNames { + name, _ := symTable.GetName(k.NameID) + if opts.ShortenNames { + name = analysis.ShortenName(name) + } + row.KernelName = name + + if annotationID[i] != -1 { + annoName, _ := symTable.GetName(annotationID[i]) + row.UserAnnotation = annoName + } + } else { + row.KernelName = strconv.Itoa(k.NameID) + if annotationID[i] != -1 { + row.UserAnnotation = strconv.Itoa(annotationID[i]) + } + } + + result[i] = row + } + + return result, nil +} + diff --git a/pkg/analysis/kernel/annotation_test.go b/pkg/analysis/kernel/annotation_test.go new file mode 100644 index 0000000..b26330f --- /dev/null +++ b/pkg/analysis/kernel/annotation_test.go @@ -0,0 +1,85 @@ +package kernel + +import ( + "testing" +) + +func TestGPUKernelsWithAnnotationsIntegration(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + db := openSharedNSDB(t) + defer db.Close() + + opts := AnnotationOpts{ + Rank: 0, + ExpandNames: true, + ShortenNames: true, + } + results, err := GPUKernelsWithAnnotations(db, opts) + if err != nil { + t.Fatalf("GPUKernelsWithAnnotations: %v", err) + } + + // Python reference: len(gpu_kernels_df) == 4876 + if len(results) != 4876 { + t.Errorf("expected 4876 kernels, got %d", len(results)) + } + + // Count unique annotations (Python: 3 unique + 1 for unmatched = 4) + annoSet := make(map[string]bool) + for _, r := range results { + annoSet[r.UserAnnotation] = true + } + // Python has 4 unique values including -1 (which maps to ""). + // In Go, unmatched is "" so we expect 4 unique values: "", plus 3 annotations. 
+ if len(annoSet) != 4 { + t.Errorf("expected 4 unique annotations (including unmatched), got %d: %v", len(annoSet), annoSet) + } + + // Count kernels with "Optimizer.step#SGD.step" annotation (Python: 27) + sgdCount := 0 + for _, r := range results { + if r.UserAnnotation == "Optimizer.step#SGD.step" { + sgdCount++ + } + } + if sgdCount != 27 { + t.Errorf("expected 27 kernels with Optimizer.step#SGD.step, got %d", sgdCount) + } + + // Verify at least some kernels have non-empty annotations + annotatedCount := 0 + for _, r := range results { + if r.UserAnnotation != "" { + annotatedCount++ + } + } + if annotatedCount == 0 { + t.Error("expected some kernels to have annotations") + } +} + +func TestGPUKernelsWithAnnotationsNoAnnotations(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + // vision_transformer trace does not have gpu_user_annotations + db := openSharedVTDB(t) + defer db.Close() + + opts := AnnotationOpts{ + Rank: 0, + ExpandNames: true, + } + results, err := GPUKernelsWithAnnotations(db, opts) + if err != nil { + t.Fatalf("GPUKernelsWithAnnotations: %v", err) + } + + if results != nil { + t.Errorf("expected nil for trace without annotations, got %d results", len(results)) + } +} diff --git a/pkg/analysis/kernel/aten_delay.go b/pkg/analysis/kernel/aten_delay.go new file mode 100644 index 0000000..94b4514 --- /dev/null +++ b/pkg/analysis/kernel/aten_delay.go @@ -0,0 +1,308 @@ +package kernel + +import ( + "database/sql" + "fmt" + "math" + "sort" + "strings" + + "hta/pkg/store" + "hta/pkg/symbol" +) + +// AtenDelayOpts controls the ATen op kernels and delay analysis. +type AtenDelayOpts struct { + Ranks []int // nil = all ranks + SortBy []string // default: ["occurrence_count"] +} + +// AtenDelayRow holds aggregated delay metrics for one (aten_op, kernel_sequence) pair. +type AtenDelayRow struct { + AtenOpName string + KernelSequence string + OccurrenceCount int + AvgAtenOpLaunchDelay float64 // microseconds, rounded to 3 decimals + AvgRuntimeDelay float64 // microseconds, rounded to 3 decimals +} + +// AtenOpKernelsAndDelay maps ATen operators to their launched GPU kernels and +// computes delay metrics: time from ATen op start to CUDA runtime launch, and +// time from runtime launch end to GPU kernel start. +func AtenOpKernelsAndDelay(db *sql.DB, opts AtenDelayOpts) (map[int][]AtenDelayRow, error) { + sym, err := store.LoadSymbolTable(db) + if err != nil { + return nil, fmt.Errorf("loading symbol table: %w", err) + } + + // Collect ATen operator name IDs. + var atenNameIDs []int + for _, entry := range sym.All() { + if strings.HasPrefix(entry.Name, "aten::") { + atenNameIDs = append(atenNameIDs, entry.ID) + } + } + if len(atenNameIDs) == 0 { + return nil, fmt.Errorf("no aten:: symbols found in symbol table") + } + + // Collect runtime launch name IDs. + var launchNameIDs []int + for _, name := range []string{"cudaLaunchKernel", "cudaLaunchKernelExC"} { + id := sym.GetID(name) + if id >= 0 { + launchNameIDs = append(launchNameIDs, id) + } + } + if len(launchNameIDs) == 0 { + return nil, fmt.Errorf("no cuda launch symbols found in symbol table") + } + + // Determine ranks. 
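+ // A nil or empty opts.Ranks means every rank present in the store is analyzed.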
+ ranks := opts.Ranks + if len(ranks) == 0 { + ranks, err = store.Ranks(db) + if err != nil { + return nil, fmt.Errorf("loading ranks: %w", err) + } + } + + sortBy := opts.SortBy + if len(sortBy) == 0 { + sortBy = []string{"occurrence_count"} + } + + result := make(map[int][]AtenDelayRow, len(ranks)) + for _, rank := range ranks { + rows, err := atenDelayPerRank(db, rank, atenNameIDs, launchNameIDs, sym, sortBy) + if err != nil { + return nil, fmt.Errorf("rank %d: %w", rank, err) + } + result[rank] = rows + } + return result, nil +} + +// rawMatch holds one matched runtime launch → ATen op → GPU kernel triple. +type rawMatch struct { + threadID int64 + atenOpName string + atenOpStart int64 + kernelName string + atenOpDelay int64 // runtime.StartedAt - atenOp.StartedAt + runtimeDelay int64 // gpuKernel.StartedAt - runtime.EndedAt +} + +func atenDelayPerRank( + db *sql.DB, + rank int, + atenNameIDs, launchNameIDs []int, + sym *symbol.Table, + sortBy []string, +) ([]AtenDelayRow, error) { + // Load ATen ops and group by thread. + atenOps, err := store.LoadATenOperators(db, rank, atenNameIDs) + if err != nil { + return nil, fmt.Errorf("loading ATen ops: %w", err) + } + atenByThread := make(map[int64][]store.ATenOpRow) + for _, op := range atenOps { + atenByThread[op.ThreadID] = append(atenByThread[op.ThreadID], op) + } + + // Load runtime launch events (with thread info). + runtimeEvents, err := store.LoadCPURuntimeEventsWithThread(db, rank, launchNameIDs) + if err != nil { + return nil, fmt.Errorf("loading runtime events: %w", err) + } + + // Load GPU kernels and build correlation map. + gpuKernels, err := store.LoadGPUKernelsWithCorrelation(db, rank) + if err != nil { + return nil, fmt.Errorf("loading GPU kernels: %w", err) + } + gpuByCorr := make(map[int]store.GPUKernelCorrelationRow, len(gpuKernels)) + for _, k := range gpuKernels { + gpuByCorr[k.Correlation] = k + } + + // Match: for each runtime launch, find the deepest wrapping ATen op + // and the correlated GPU kernel. + var matches []rawMatch + for _, rt := range runtimeEvents { + ops := atenByThread[rt.ThreadID] + if len(ops) == 0 { + continue + } + + gpu, ok := gpuByCorr[rt.Correlation] + if !ok { + continue + } + + // Find the deepest ATen op that wraps this runtime event. + // ops are sorted by started_at; binary search for insertion point. + atenOp, found := findDeepestWrapper(ops, rt.StartedAt, rt.EndedAt) + if !found { + continue + } + + atenName, err := sym.GetName(atenOp.NameID) + if err != nil { + continue + } + kernelName, err := sym.GetName(gpu.NameID) + if err != nil { + continue + } + + matches = append(matches, rawMatch{ + threadID: rt.ThreadID, + atenOpName: atenName, + atenOpStart: atenOp.StartedAt, + kernelName: kernelName, + atenOpDelay: rt.StartedAt - atenOp.StartedAt, + runtimeDelay: gpu.StartedAt - rt.EndedAt, + }) + } + + // First grouping: by (threadID, atenOpName, atenOpStart) to build kernel sequences. + type groupKey struct { + threadID int64 + atenOpName string + atenOpStart int64 + } + type perOpGroup struct { + kernelNames []string + firstLaunchDelay int64 + firstRuntimeDelay int64 + } + + // Use a slice to preserve insertion order within each group. 
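+ // (Go map iteration order is randomized, so the slice records first-seen keys
+ // and lets the groups be processed in their original order.)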
+ groupOrder := make([]groupKey, 0) + groups := make(map[groupKey]*perOpGroup) + for _, m := range matches { + key := groupKey{m.threadID, m.atenOpName, m.atenOpStart} + g, exists := groups[key] + if !exists { + g = &perOpGroup{ + firstLaunchDelay: m.atenOpDelay, + firstRuntimeDelay: m.runtimeDelay, + } + groups[key] = g + groupOrder = append(groupOrder, key) + } + g.kernelNames = append(g.kernelNames, m.kernelName) + } + + // Final aggregation: by (atenOpName, kernelSequence). + type aggKey struct { + atenOpName string + kernelSequence string + } + type aggValue struct { + count int + launchDelays []float64 + runtimeDelays []float64 + } + + aggs := make(map[aggKey]*aggValue) + for _, key := range groupOrder { + g := groups[key] + kernelSeq := strings.Join(g.kernelNames, " -> ") + ak := aggKey{key.atenOpName, kernelSeq} + av, exists := aggs[ak] + if !exists { + av = &aggValue{} + aggs[ak] = av + } + av.count++ + av.launchDelays = append(av.launchDelays, float64(g.firstLaunchDelay)) + av.runtimeDelays = append(av.runtimeDelays, float64(g.firstRuntimeDelay)) + } + + // Build result rows. + rows := make([]AtenDelayRow, 0, len(aggs)) + for ak, av := range aggs { + rows = append(rows, AtenDelayRow{ + AtenOpName: ak.atenOpName, + KernelSequence: ak.kernelSequence, + OccurrenceCount: av.count, + AvgAtenOpLaunchDelay: round3(mean(av.launchDelays)), + AvgRuntimeDelay: round3(mean(av.runtimeDelays)), + }) + } + + // Sort. + sortAtenDelayRows(rows, sortBy) + + return rows, nil +} + +// findDeepestWrapper finds the deepest ATen op that wraps the interval +// [startedAt, endedAt]. ops must be sorted by StartedAt. +func findDeepestWrapper(ops []store.ATenOpRow, startedAt, endedAt int64) (store.ATenOpRow, bool) { + // Binary search: find the rightmost op where StartedAt <= startedAt. + idx := sort.Search(len(ops), func(i int) bool { + return ops[i].StartedAt > startedAt + }) + // idx is the first op with StartedAt > startedAt; scan backwards for wrappers. + var best store.ATenOpRow + found := false + for i := idx - 1; i >= 0; i-- { + if ops[i].StartedAt > startedAt { + continue + } + if ops[i].EndedAt >= endedAt { + // This wraps the event. The first one found scanning backwards + // from the insertion point is the deepest (latest start, still wrapping). 
+ if !found { + best = ops[i] + found = true + break + } + } + } + return best, found +} + +func mean(vals []float64) float64 { + if len(vals) == 0 { + return 0 + } + var sum float64 + for _, v := range vals { + sum += v + } + return sum / float64(len(vals)) +} + +func round3(v float64) float64 { + return math.Round(v*1000) / 1000 +} + +func sortAtenDelayRows(rows []AtenDelayRow, sortBy []string) { + sort.SliceStable(rows, func(i, j int) bool { + for _, col := range sortBy { + switch col { + case "occurrence_count": + if rows[i].OccurrenceCount != rows[j].OccurrenceCount { + return rows[i].OccurrenceCount > rows[j].OccurrenceCount + } + case "avg_aten_op_launch_delay": + if rows[i].AvgAtenOpLaunchDelay != rows[j].AvgAtenOpLaunchDelay { + return rows[i].AvgAtenOpLaunchDelay > rows[j].AvgAtenOpLaunchDelay + } + case "avg_runtime_delay": + if rows[i].AvgRuntimeDelay != rows[j].AvgRuntimeDelay { + return rows[i].AvgRuntimeDelay > rows[j].AvgRuntimeDelay + } + case "aten_op_name": + if rows[i].AtenOpName != rows[j].AtenOpName { + return rows[i].AtenOpName < rows[j].AtenOpName + } + } + } + return false + }) +} diff --git a/pkg/analysis/kernel/aten_delay_test.go b/pkg/analysis/kernel/aten_delay_test.go new file mode 100644 index 0000000..9c4f008 --- /dev/null +++ b/pkg/analysis/kernel/aten_delay_test.go @@ -0,0 +1,58 @@ +package kernel + +import ( + "strings" + "testing" +) + +func TestAtenOpKernelsAndDelayIntegration(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + db := openSharedVTDB(t) + defer db.Close() + + result, err := AtenOpKernelsAndDelay(db, AtenDelayOpts{}) + if err != nil { + t.Fatalf("aten op kernels and delay: %v", err) + } + + if len(result) == 0 { + t.Fatal("expected non-empty results") + } + + for rank, rows := range result { + if len(rows) == 0 { + t.Errorf("rank %d: expected non-empty rows", rank) + continue + } + + for i, r := range rows { + if !strings.HasPrefix(r.AtenOpName, "aten::") { + t.Errorf("rank %d row %d: AtenOpName %q does not start with aten::", + rank, i, r.AtenOpName) + } + if r.KernelSequence == "" { + t.Errorf("rank %d row %d: KernelSequence is empty", rank, i) + } + if r.OccurrenceCount <= 0 { + t.Errorf("rank %d row %d: OccurrenceCount %d is not positive", + rank, i, r.OccurrenceCount) + } + if r.AvgAtenOpLaunchDelay < 0 { + t.Errorf("rank %d row %d: AvgAtenOpLaunchDelay %.3f is negative", + rank, i, r.AvgAtenOpLaunchDelay) + } + } + + // Verify default sort: occurrence_count descending. 
+ for i := 1; i < len(rows); i++ { + if rows[i].OccurrenceCount > rows[i-1].OccurrenceCount { + t.Errorf("rank %d: rows not sorted by occurrence_count descending at index %d (%d > %d)", + rank, i, rows[i].OccurrenceCount, rows[i-1].OccurrenceCount) + break + } + } + } +} diff --git a/pkg/analysis/kernel/helpers_test.go b/pkg/analysis/kernel/helpers_test.go new file mode 100644 index 0000000..3ef3105 --- /dev/null +++ b/pkg/analysis/kernel/helpers_test.go @@ -0,0 +1,60 @@ +package kernel + +import ( + "database/sql" + "os" + "path/filepath" + "runtime" + "testing" + + "hta/pkg/store" +) + +func testDataDir(t *testing.T) string { + t.Helper() + _, thisFile, _, _ := runtime.Caller(0) + root := filepath.Join(filepath.Dir(thisFile), "..", "..", "..") + dir := filepath.Join(root, "tests", "data", "vision_transformer") + if _, err := os.Stat(dir); err != nil { + t.Skipf("test data not found: %s", dir) + } + return dir +} + +func nsResolutionTraceDir(t *testing.T) string { + t.Helper() + _, thisFile, _, _ := runtime.Caller(0) + root := filepath.Join(filepath.Dir(thisFile), "..", "..", "..") + dir := filepath.Join(root, "tests", "data", "ns_resolution_trace") + if _, err := os.Stat(dir); err != nil { + t.Skipf("test data not found: %s", dir) + } + return dir +} + +// openSharedVTDB opens a read handle to the pre-built vision_transformer DB. +// The DB was preprocessed once in TestMain. Each caller gets its own *sql.DB. +func openSharedVTDB(t *testing.T) *sql.DB { + t.Helper() + if sharedVTDBPath == "" { + t.Skip("vision_transformer test data not available") + } + db, err := store.Create(sharedVTDBPath) + if err != nil { + t.Fatalf("open shared VT db: %v", err) + } + return db +} + +// openSharedNSDB opens a read handle to the pre-built ns_resolution_trace DB. +func openSharedNSDB(t *testing.T) *sql.DB { + t.Helper() + if sharedNSDBPath == "" { + t.Skip("ns_resolution_trace test data not available") + } + db, err := store.Create(sharedNSDBPath) + if err != nil { + t.Fatalf("open shared NS db: %v", err) + } + return db +} diff --git a/pkg/analysis/kernel/kernel_breakdown.go b/pkg/analysis/kernel/kernel_breakdown.go new file mode 100644 index 0000000..25db548 --- /dev/null +++ b/pkg/analysis/kernel/kernel_breakdown.go @@ -0,0 +1,268 @@ +package kernel + +import ( + "database/sql" + "fmt" + "math" + "sort" + + "hta/pkg/analysis" + "hta/pkg/store" +) + +// KernelBreakdownOpts configures the GPU kernel breakdown analysis. +type KernelBreakdownOpts struct { + DurationRatio float64 // cumulative % cutoff (default 0.8) + NumKernels int // max kernels per type per rank (default 10) + IncludeMemory bool // include MEMORY type (default false) +} + +// KernelTypeRow holds the GPU time distribution for one kernel type. +type KernelTypeRow struct { + KernelType string + SumUs int64 + Percentage float64 +} + +// TopKernelRow holds statistics for an individual kernel (or "others" aggregate). +type TopKernelRow struct { + Name string + SumUs int64 + MaxUs int64 + MinUs int64 + MeanUs int64 + Stddev float64 + KernelType string + Rank int +} + +// GPUKernelBreakdownResult holds both the type-level breakdown and top-N kernel details. +type GPUKernelBreakdownResult struct { + TypeBreakdown []KernelTypeRow + TopKernels []TopKernelRow +} + +type kernelStats struct { + Name string + SumUs int64 + MaxUs int64 + MinUs int64 + MeanUs int64 + Stddev float64 +} + +// GPUKernelBreakdown computes the GPU kernel breakdown for all ranks in the DB. 
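+// It returns two views: a per-type GPU time distribution (TypeBreakdown) and
+// per-rank top-N kernel statistics per type, with the long tail folded into an
+// "others" row (TopKernels).
+//
+// A minimal usage sketch (db is assumed to be a handle to a preprocessed store):
+//
+//	res, err := GPUKernelBreakdown(db, KernelBreakdownOpts{DurationRatio: 0.8, NumKernels: 10})
+//	if err != nil {
+//		return err
+//	}
+//	for _, row := range res.TypeBreakdown {
+//		fmt.Printf("%s: %d us (%.1f%%)\n", row.KernelType, row.SumUs, row.Percentage)
+//	}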
+func GPUKernelBreakdown(db *sql.DB, opts KernelBreakdownOpts) (*GPUKernelBreakdownResult, error) { + sym, err := store.LoadSymbolTable(db) + if err != nil { + return nil, fmt.Errorf("loading symbol table: %w", err) + } + + ranks, err := store.Ranks(db) + if err != nil { + return nil, fmt.Errorf("loading ranks: %w", err) + } + + typesToAnalyze := []analysis.KernelType{analysis.KernelComputation, analysis.KernelCommunication} + if opts.IncludeMemory { + typesToAnalyze = append(typesToAnalyze, analysis.KernelMemory) + } + + typeSumMap := make(map[string]int64) + var allTopKernels []TopKernelRow + + for _, rank := range ranks { + kernels, err := store.LoadGPUKernels(db, rank) + if err != nil { + return nil, fmt.Errorf("rank %d: loading GPU kernels: %w", rank, err) + } + + // Classify kernels and group durations by type and name. + typeNameDurs := make(map[analysis.KernelType]map[string][]int64) + typeDurSum := make(map[analysis.KernelType]int64) + + for _, k := range kernels { + name, err := sym.GetName(k.NameID) + if err != nil { + return nil, fmt.Errorf("symbol lookup: %w", err) + } + kt := analysis.ClassifyKernel(name) + dur := k.EndedAt - k.StartedAt + if typeNameDurs[kt] == nil { + typeNameDurs[kt] = make(map[string][]int64) + } + typeNameDurs[kt][name] = append(typeNameDurs[kt][name], dur) + typeDurSum[kt] += dur + } + + // Accumulate type sums across ranks. + for _, kt := range typesToAnalyze { + typeSumMap[kt.String()] += typeDurSum[kt] + } + + // For each type, compute top kernels with aggregation. + for _, kt := range typesToAnalyze { + nameDurs := typeNameDurs[kt] + if len(nameDurs) == 0 { + continue + } + + stats := aggrAndComputeStats(nameDurs, opts.NumKernels, opts.DurationRatio) + for _, s := range stats { + allTopKernels = append(allTopKernels, TopKernelRow{ + Name: s.Name, + SumUs: s.SumUs, + MaxUs: s.MaxUs, + MinUs: s.MinUs, + MeanUs: s.MeanUs, + Stddev: s.Stddev, + KernelType: kt.String(), + Rank: rank, + }) + } + } + } + + // Build type breakdown with percentages. + var totalSum int64 + for _, kt := range typesToAnalyze { + totalSum += typeSumMap[kt.String()] + } + + var typeBreakdown []KernelTypeRow + for _, kt := range typesToAnalyze { + s := typeSumMap[kt.String()] + pctg := 0.0 + if totalSum > 0 { + pctg = analysis.RoundTo(100*float64(s)/float64(totalSum), 1) + } + typeBreakdown = append(typeBreakdown, KernelTypeRow{ + KernelType: kt.String(), + SumUs: s, + Percentage: pctg, + }) + } + + // Sort by sum descending. + sort.Slice(typeBreakdown, func(i, j int) bool { + return typeBreakdown[i].SumUs > typeBreakdown[j].SumUs + }) + + // Sort top kernels by (kernel_type, name, rank). + sort.Slice(allTopKernels, func(i, j int) bool { + if allTopKernels[i].KernelType != allTopKernels[j].KernelType { + return allTopKernels[i].KernelType < allTopKernels[j].KernelType + } + if allTopKernels[i].Name != allTopKernels[j].Name { + return allTopKernels[i].Name < allTopKernels[j].Name + } + return allTopKernels[i].Rank < allTopKernels[j].Rank + }) + + return &GPUKernelBreakdownResult{ + TypeBreakdown: typeBreakdown, + TopKernels: allTopKernels, + }, nil +} + +// computeStats computes sum/max/min/mean/stddev over a slice of durations. +// Uses ddof=1 for stddev (matching pandas default). Single element yields stddev=0. 
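+// MeanUs uses integer division and therefore truncates. For example,
+// durations {10, 20, 30, 40} yield sum=100, mean=25, and
+// stddev=sqrt(500/3) ≈ 12.91.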
+func computeStats(name string, durations []int64) kernelStats { + if len(durations) == 0 { + return kernelStats{Name: name} + } + + sum := durations[0] + minVal := durations[0] + maxVal := durations[0] + for _, d := range durations[1:] { + sum += d + if d > maxVal { + maxVal = d + } + if d < minVal { + minVal = d + } + } + + mean := sum / int64(len(durations)) + + var stddev float64 + if len(durations) > 1 { + fMean := float64(sum) / float64(len(durations)) + var sumSqDiff float64 + for _, d := range durations { + diff := float64(d) - fMean + sumSqDiff += diff * diff + } + stddev = math.Sqrt(sumSqDiff / float64(len(durations)-1)) + } + + return kernelStats{ + Name: name, + SumUs: sum, + MaxUs: maxVal, + MinUs: minVal, + MeanUs: mean, + Stddev: stddev, + } +} + +// aggrAndComputeStats groups kernels by name, sorts by sum descending, and +// aggregates tail kernels into "others" when there are more than numKernels +// unique names. The durationRatio quantile threshold is applied to cumulative sums. +func aggrAndComputeStats(nameDurations map[string][]int64, numKernels int, durationRatio float64) []kernelStats { + type nameSum struct { + name string + sum int64 + } + names := make([]nameSum, 0, len(nameDurations)) + for name, durs := range nameDurations { + var s int64 + for _, d := range durs { + s += d + } + names = append(names, nameSum{name, s}) + } + + // Sort by sum descending. + sort.Slice(names, func(i, j int) bool { + return names[i].sum > names[j].sum + }) + + // No aggregation needed. + if len(names) <= numKernels { + result := make([]kernelStats, len(names)) + for i, ns := range names { + result[i] = computeStats(ns.name, nameDurations[ns.name]) + } + return result + } + + // Compute cumulative sums for quantile calculation. + cumSums := make([]float64, len(names)) + cumSums[0] = float64(names[0].sum) + for i := 1; i < len(names); i++ { + cumSums[i] = cumSums[i-1] + float64(names[i].sum) + } + quantile := analysis.QuantileLinear(cumSums, durationRatio) + + // Partition into kept names and "others". + var otherDurs []int64 + var result []kernelStats + for i, ns := range names { + isOther := cumSums[i] > quantile || i >= numKernels + if isOther { + otherDurs = append(otherDurs, nameDurations[ns.name]...) 
+ } else { + result = append(result, computeStats(ns.name, nameDurations[ns.name])) + } + } + + if len(otherDurs) > 0 { + result = append(result, computeStats("others", otherDurs)) + } + + return result +} + diff --git a/pkg/analysis/kernel/kernel_breakdown_test.go b/pkg/analysis/kernel/kernel_breakdown_test.go new file mode 100644 index 0000000..b5c6a83 --- /dev/null +++ b/pkg/analysis/kernel/kernel_breakdown_test.go @@ -0,0 +1,214 @@ +package kernel + +import ( + "math" + "testing" + + "hta/pkg/analysis" +) + +func TestQuantileLinear(t *testing.T) { + t.Parallel() + tests := []struct { + name string + sorted []float64 + q float64 + want float64 + }{ + {"empty", nil, 0.5, 0}, + {"single", []float64{42}, 0.5, 42}, + {"single q=0", []float64{42}, 0.0, 42}, + {"single q=1", []float64{42}, 1.0, 42}, + {"two elements q=0", []float64{10, 20}, 0.0, 10}, + {"two elements q=1", []float64{10, 20}, 1.0, 20}, + {"two elements q=0.5", []float64{10, 20}, 0.5, 15}, + {"three elements q=0.5", []float64{10, 20, 30}, 0.5, 20}, + {"three elements q=0.8", []float64{10, 20, 30}, 0.8, 26}, + {"five elements q=0.8", []float64{1, 2, 3, 4, 5}, 0.8, 4.2}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := analysis.QuantileLinear(tt.sorted, tt.q) + if math.Abs(got-tt.want) > 1e-9 { + t.Errorf("QuantileLinear(%v, %v) = %v, want %v", tt.sorted, tt.q, got, tt.want) + } + }) + } +} + +func TestComputeStats(t *testing.T) { + t.Parallel() + t.Run("single element", func(t *testing.T) { + s := computeStats("kern", []int64{100}) + if s.SumUs != 100 || s.MaxUs != 100 || s.MinUs != 100 || s.MeanUs != 100 { + t.Errorf("unexpected stats for single element: %+v", s) + } + if s.Stddev != 0 { + t.Errorf("stddev for single element should be 0, got %v", s.Stddev) + } + }) + + t.Run("multiple elements", func(t *testing.T) { + // durations: 10, 20, 30, 40 + // sum=100, max=40, min=10, mean=25 + // variance (ddof=1) = ((10-25)^2 + (20-25)^2 + (30-25)^2 + (40-25)^2) / 3 + // = (225 + 25 + 25 + 225) / 3 = 500/3 + // stddev = sqrt(500/3) ≈ 12.9099 + s := computeStats("kern", []int64{10, 20, 30, 40}) + if s.SumUs != 100 { + t.Errorf("sum: got %d, want 100", s.SumUs) + } + if s.MaxUs != 40 { + t.Errorf("max: got %d, want 40", s.MaxUs) + } + if s.MinUs != 10 { + t.Errorf("min: got %d, want 10", s.MinUs) + } + if s.MeanUs != 25 { + t.Errorf("mean: got %d, want 25", s.MeanUs) + } + wantStd := math.Sqrt(500.0 / 3.0) + if math.Abs(s.Stddev-wantStd) > 1e-6 { + t.Errorf("stddev: got %v, want %v", s.Stddev, wantStd) + } + }) + + t.Run("empty", func(t *testing.T) { + s := computeStats("empty", nil) + if s.SumUs != 0 || s.MaxUs != 0 || s.MinUs != 0 || s.MeanUs != 0 || s.Stddev != 0 { + t.Errorf("unexpected stats for empty: %+v", s) + } + }) +} + +func TestAggrAndComputeStats(t *testing.T) { + t.Parallel() + t.Run("no aggregation needed", func(t *testing.T) { + nameDurs := map[string][]int64{ + "kernA": {100, 200}, + "kernB": {50}, + } + stats := aggrAndComputeStats(nameDurs, 10, 0.8) + if len(stats) != 2 { + t.Fatalf("expected 2 stats, got %d", len(stats)) + } + // Should be sorted by sum descending: kernA (300) before kernB (50) + if stats[0].Name != "kernA" || stats[0].SumUs != 300 { + t.Errorf("first should be kernA with sum 300, got %s with sum %d", stats[0].Name, stats[0].SumUs) + } + if stats[1].Name != "kernB" || stats[1].SumUs != 50 { + t.Errorf("second should be kernB with sum 50, got %s with sum %d", stats[1].Name, stats[1].SumUs) + } + }) + + t.Run("aggregation into others", func(t *testing.T) { + // 15 
kernels, numKernels=3 — everything past top 3 becomes "others" + nameDurs := make(map[string][]int64) + for i := 0; i < 15; i++ { + name := string(rune('A' + i)) + nameDurs[name] = []int64{int64(100 - i*5)} + } + stats := aggrAndComputeStats(nameDurs, 3, 0.8) + + // Should have at most 4 entries (3 kept + "others") + if len(stats) > 4 { + t.Errorf("expected at most 4 entries, got %d", len(stats)) + } + + // Check "others" exists + hasOthers := false + for _, s := range stats { + if s.Name == "others" { + hasOthers = true + } + } + if !hasOthers { + t.Error("expected 'others' entry") + } + }) +} + +func TestGPUKernelBreakdownIntegration(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + db := openSharedVTDB(t) + defer db.Close() + + opts := KernelBreakdownOpts{ + DurationRatio: 0.8, + NumKernels: 10, + IncludeMemory: false, + } + result, err := GPUKernelBreakdown(db, opts) + if err != nil { + t.Fatalf("gpu kernel breakdown: %v", err) + } + + // Verify type breakdown. + if len(result.TypeBreakdown) == 0 { + t.Fatal("expected non-empty type breakdown") + } + + // Check that COMPUTATION and COMMUNICATION are present. + typeMap := make(map[string]KernelTypeRow) + for _, r := range result.TypeBreakdown { + typeMap[r.KernelType] = r + } + if _, ok := typeMap["COMPUTATION"]; !ok { + t.Error("COMPUTATION type not found") + } + if _, ok := typeMap["COMMUNICATION"]; !ok { + t.Error("COMMUNICATION type not found") + } + + // Percentages should sum to approximately 100%. + var pctgSum float64 + for _, r := range result.TypeBreakdown { + pctgSum += r.Percentage + } + if math.Abs(pctgSum-100.0) > 1.0 { + t.Errorf("percentages sum to %.1f, want ~100", pctgSum) + } + + // Type breakdown should be sorted by sum descending. + for i := 1; i < len(result.TypeBreakdown); i++ { + if result.TypeBreakdown[i].SumUs > result.TypeBreakdown[i-1].SumUs { + t.Errorf("type breakdown not sorted by sum desc at index %d", i) + } + } + + // Verify top kernels. + if len(result.TopKernels) == 0 { + t.Fatal("expected non-empty top kernels") + } + + // Top kernels should be sorted by (kernel_type, name, rank). + for i := 1; i < len(result.TopKernels); i++ { + a, b := result.TopKernels[i-1], result.TopKernels[i] + if a.KernelType > b.KernelType { + t.Errorf("top kernels not sorted by kernel_type at index %d", i) + } + if a.KernelType == b.KernelType && a.Name > b.Name { + t.Errorf("top kernels not sorted by name at index %d", i) + } + if a.KernelType == b.KernelType && a.Name == b.Name && a.Rank > b.Rank { + t.Errorf("top kernels not sorted by rank at index %d", i) + } + } + + // Stats consistency checks. + for _, tk := range result.TopKernels { + if tk.SumUs < tk.MinUs { + t.Errorf("kernel %s: sum (%d) < min (%d)", tk.Name, tk.SumUs, tk.MinUs) + } + if tk.MaxUs < tk.MinUs { + t.Errorf("kernel %s: max (%d) < min (%d)", tk.Name, tk.MaxUs, tk.MinUs) + } + if tk.Stddev < 0 { + t.Errorf("kernel %s: negative stddev (%.2f)", tk.Name, tk.Stddev) + } + } +} diff --git a/pkg/analysis/kernel/kernel_sequences.go b/pkg/analysis/kernel/kernel_sequences.go new file mode 100644 index 0000000..ad7dd9a --- /dev/null +++ b/pkg/analysis/kernel/kernel_sequences.go @@ -0,0 +1,500 @@ +package kernel + +import ( + "compress/gzip" + "database/sql" + "encoding/json" + "fmt" + "os" + "path/filepath" + "sort" + "strings" + + "hta/pkg/store" +) + +// KernelSeqOpts controls the frequent CUDA kernel sequences analysis. 
+type KernelSeqOpts struct { + OperatorName string // required: substring to match CPU operator names + OutputDir string // optional: directory for overlay trace output + MinPatternLen int // minimum pattern length (operator + kernels); default 3 + Rank int // rank to analyze; default 0 + TopK int // top-K patterns to return; default 5 +} + +// PatternResult holds one frequent kernel sequence pattern. +type PatternResult struct { + Pattern string // "op_name|kernel1|kernel2|..." + Count int + GPUKernelDurUs int64 + CPUOpDurUs int64 +} + +// patternKey is a sortable tuple of name IDs used internally. +type patternKey struct { + ids string // comma-joined name IDs for map keying +} + +// patternData tracks accumulated data for one pattern. +type patternData struct { + nameIDs []int + count int + gpuDurSum int64 + cpuDurSum int64 + eventIDs map[int64]struct{} // trace_event IDs for overlay +} + +// FrequentCUDAKernelSequences finds frequent GPU kernel launch patterns +// under CPU operators matching opts.OperatorName. +func FrequentCUDAKernelSequences(db *sql.DB, opts KernelSeqOpts) ([]PatternResult, error) { + if opts.MinPatternLen == 0 { + opts.MinPatternLen = 3 + } + if opts.TopK == 0 { + opts.TopK = 5 + } + + // 1. Load symbol table. + sym, err := store.LoadSymbolTable(db) + if err != nil { + return nil, fmt.Errorf("loading symbol table: %w", err) + } + + // 2. Find name IDs matching the operator name substring. + var matchingNameIDs []int + for _, entry := range sym.All() { + if strings.Contains(entry.Name, opts.OperatorName) { + matchingNameIDs = append(matchingNameIDs, entry.ID) + } + } + if len(matchingNameIDs) == 0 { + return nil, nil + } + + // 3. Load CPU operators. + cpuOps, err := store.LoadCPUOperators(db, opts.Rank, matchingNameIDs) + if err != nil { + return nil, fmt.Errorf("loading CPU operators: %w", err) + } + if len(cpuOps) == 0 { + return nil, nil + } + + // 4. Find root operators (not nested in another matching operator on same thread). + roots := findRootOperators(cpuOps) + + // 5. Resolve CUDA launch event symbol IDs. + launchNames := []string{"cudaLaunchKernel", "cudaLaunchKernelExC"} + var launchNameIDs []int + for _, name := range launchNames { + id := sym.GetID(name) + if id >= 0 { + launchNameIDs = append(launchNameIDs, id) + } + } + if len(launchNameIDs) == 0 { + return nil, nil + } + + // 6. Load CPU launch events with thread info. + cpuLaunches, err := store.LoadCPURuntimeLaunchesWithThread(db, opts.Rank, launchNameIDs) + if err != nil { + return nil, fmt.Errorf("loading CPU launch events: %w", err) + } + + // 7. Load GPU kernels with correlation. + gpuKernels, err := store.LoadGPUKernelsWithCorrelation(db, opts.Rank) + if err != nil { + return nil, fmt.Errorf("loading GPU kernels: %w", err) + } + + // 8. Build gpuByCorr map: correlation → GPU kernel. + gpuByCorr := make(map[int]store.GPUKernelCorrelationRow, len(gpuKernels)) + for _, k := range gpuKernels { + gpuByCorr[k.Correlation] = k + } + + // 9. Group CPU launch events by thread. + launchesByThread := make(map[int64][]store.CPURuntimeLaunchRow) + for _, l := range cpuLaunches { + launchesByThread[l.ThreadID] = append(launchesByThread[l.ThreadID], l) + } + + // 10. For each root operator, find GPU kernels and build patterns. + patterns := make(map[string]*patternData) + for _, op := range roots { + threadLaunches := launchesByThread[op.ThreadID] + if len(threadLaunches) == 0 { + continue + } + + // Binary search for the first launch event >= op.StartedAt on this thread. 
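+ // (This assumes threadLaunches is sorted by StartedAt, which sort.Search
+ // relies on for a monotonic predicate.)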
+ startIdx := sort.Search(len(threadLaunches), func(i int) bool { + return threadLaunches[i].StartedAt >= op.StartedAt + }) + + // Collect GPU kernels launched within [op.StartedAt, op.EndedAt]. + type gpuInfo struct { + startedAt int64 + nameID int + duration int64 + } + var gpuKernelsInOp []gpuInfo + for i := startIdx; i < len(threadLaunches); i++ { + l := threadLaunches[i] + if l.StartedAt > op.EndedAt { + break + } + gpu, ok := gpuByCorr[l.Correlation] + if !ok { + continue + } + gpuKernelsInOp = append(gpuKernelsInOp, gpuInfo{ + startedAt: gpu.StartedAt, + nameID: gpu.NameID, + duration: gpu.Duration, + }) + } + + // Sort GPU kernels by start time. + sort.Slice(gpuKernelsInOp, func(i, j int) bool { + return gpuKernelsInOp[i].startedAt < gpuKernelsInOp[j].startedAt + }) + + // Pattern = [operator_name_id, kernel1_id, kernel2_id, ...]. + // Skip if too short. + if len(gpuKernelsInOp)+1 < opts.MinPatternLen { + continue + } + + // Build pattern key. + nameIDs := make([]int, 0, len(gpuKernelsInOp)+1) + nameIDs = append(nameIDs, op.NameID) + var gpuDurTotal int64 + for _, g := range gpuKernelsInOp { + nameIDs = append(nameIDs, g.nameID) + gpuDurTotal += g.duration + } + + key := patternKeyString(nameIDs) + pd, exists := patterns[key] + if !exists { + pd = &patternData{ + nameIDs: nameIDs, + eventIDs: make(map[int64]struct{}), + } + patterns[key] = pd + } + pd.count++ + pd.gpuDurSum += gpuDurTotal + pd.cpuDurSum += op.Duration + pd.eventIDs[op.ID] = struct{}{} + } + + if len(patterns) == 0 { + return nil, nil + } + + // 11. Sort patterns: count DESC, then pattern string ASC. + type sortEntry struct { + key string + data *patternData + } + sorted := make([]sortEntry, 0, len(patterns)) + for k, d := range patterns { + sorted = append(sorted, sortEntry{key: k, data: d}) + } + + // Pre-resolve pattern strings for sorting. + patternStrings := make(map[string]string, len(sorted)) + for _, e := range sorted { + patternStrings[e.key] = resolvePatternString(sym, e.data.nameIDs) + } + + sort.Slice(sorted, func(i, j int) bool { + if sorted[i].data.count != sorted[j].data.count { + return sorted[i].data.count > sorted[j].data.count + } + return patternStrings[sorted[i].key] < patternStrings[sorted[j].key] + }) + + // 12. Select top-K. + topK := min(opts.TopK, len(sorted)) + sorted = sorted[:topK] + + // 13. Build results. + results := make([]PatternResult, len(sorted)) + for i, e := range sorted { + results[i] = PatternResult{ + Pattern: patternStrings[e.key], + Count: e.data.count, + GPUKernelDurUs: e.data.gpuDurSum, + CPUOpDurUs: e.data.cpuDurSum, + } + } + + // 14. Write overlay trace if output dir specified. + if opts.OutputDir != "" { + topPatterns := make([]*patternData, len(sorted)) + topPatternStrings := make([]string, len(sorted)) + for i, e := range sorted { + topPatterns[i] = e.data + topPatternStrings[i] = patternStrings[e.key] + } + if err := overlayFrequentPatterns(db, opts.Rank, opts.OutputDir, topPatterns, topPatternStrings); err != nil { + return results, fmt.Errorf("writing overlay trace: %w", err) + } + } + + return results, nil +} + +// findRootOperators selects root (non-nested) operators from a list sorted by (thread_id, started_at). 
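+// An operator is a root when it is the first op seen on its thread or starts
+// after the current root has ended; everything else is treated as nested.
+// For example, on one thread [100,500] followed by [200,400] keeps only
+// [100,500] as a root.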
+func findRootOperators(ops []store.CPUOperatorRow) []store.CPUOperatorRow { + var roots []store.CPUOperatorRow + var curThread int64 = -1 + var curEnd int64 = -1 + + for _, op := range ops { + if op.ThreadID != curThread || op.StartedAt > curEnd { + roots = append(roots, op) + curThread = op.ThreadID + curEnd = op.EndedAt + } + } + return roots +} + +// patternKeyString builds a map key from a slice of name IDs. +func patternKeyString(nameIDs []int) string { + parts := make([]string, len(nameIDs)) + for i, id := range nameIDs { + parts[i] = fmt.Sprintf("%d", id) + } + return strings.Join(parts, ",") +} + +// resolvePatternString converts name IDs to human-readable "op|kernel1|kernel2|..." string. +func resolvePatternString(sym interface{ GetName(int) (string, error) }, nameIDs []int) string { + parts := make([]string, len(nameIDs)) + for i, id := range nameIDs { + name, err := sym.GetName(id) + if err != nil { + name = fmt.Sprintf("unknown_%d", id) + } + parts[i] = name + } + return strings.Join(parts, "|") +} + +// overlayFrequentPatterns reads the raw trace file and annotates events belonging +// to the top-K patterns, writing the result to output_dir/overlaid_. +func overlayFrequentPatterns(db *sql.DB, rank int, outputDir string, topPatterns []*patternData, patternStrings []string) error { + // Get trace file path. + tracePath, err := store.TraceFile(db, rank) + if err != nil { + return fmt.Errorf("loading trace file path: %w", err) + } + + // Read raw trace. + rawData, err := ReadTraceFile(tracePath) + if err != nil { + return fmt.Errorf("reading trace file %s: %w", tracePath, err) + } + + // Parse top-level JSON. + var traceDoc map[string]json.RawMessage + if err := json.Unmarshal(rawData, &traceDoc); err != nil { + return fmt.Errorf("parsing trace JSON: %w", err) + } + + eventsRaw, ok := traceDoc["traceEvents"] + if !ok { + return fmt.Errorf("traceEvents key not found in trace file") + } + + var events []map[string]any + if err := json.Unmarshal(eventsRaw, &events); err != nil { + return fmt.Errorf("parsing traceEvents: %w", err) + } + + // Get DB min timestamp to compute offset. + dbMinTs, err := store.MinStartedAt(db, rank) + if err != nil { + return fmt.Errorf("loading min started_at: %w", err) + } + + // Find min timestamp in raw trace events. + var rawMinTs float64 + first := true + for _, ev := range events { + ts, ok := ev["ts"] + if !ok { + continue + } + var tsFloat float64 + switch v := ts.(type) { + case float64: + tsFloat = v + case json.Number: + tsFloat, _ = v.Float64() + default: + continue + } + if first || tsFloat < rawMinTs { + rawMinTs = tsFloat + first = false + } + } + + offset := rawMinTs - float64(dbMinTs) + + // Build event ID → pattern info mapping from top patterns. + type patternInfo struct { + patterns map[string]int // pattern string → count + } + eventPatterns := make(map[int64]*patternInfo) + for i, pd := range topPatterns { + for eid := range pd.eventIDs { + info, ok := eventPatterns[eid] + if !ok { + info = &patternInfo{patterns: make(map[string]int)} + eventPatterns[eid] = info + } + info.patterns[patternStrings[i]] = pd.count + } + } + + // Annotate matching events by matching (ts, name) in raw trace. + // Build lookup by adjusted timestamp. + type dbEvent struct { + adjustedTs float64 + nameID int + info *patternInfo + } + + // Load all operator events for the rank to get their IDs, timestamps, name IDs. + // We match raw trace events to DB events via adjusted timestamp. 
+ matchRows, err := db.Query( + "SELECT id, started_at, name FROM trace_event WHERE gpu_rank = ?", + rank, + ) + if err != nil { + return fmt.Errorf("querying events for overlay: %w", err) + } + defer matchRows.Close() + + type idTsName struct { + id int64 + ts float64 // adjusted (DB ts + offset) + nameID int + } + var dbEvents []idTsName + for matchRows.Next() { + var e idTsName + var rawStartedAt int64 + if err := matchRows.Scan(&e.id, &rawStartedAt, &e.nameID); err != nil { + return fmt.Errorf("scanning event: %w", err) + } + e.ts = float64(rawStartedAt) + offset + if _, ok := eventPatterns[e.id]; ok { + dbEvents = append(dbEvents, e) + } + } + if err := matchRows.Err(); err != nil { + return fmt.Errorf("iterating events: %w", err) + } + + // Build a lookup by rounded timestamp for matching. + type tsKey struct { + ts int64 // rounded to nearest int + } + tsLookup := make(map[tsKey]*patternInfo) + for _, e := range dbEvents { + info := eventPatterns[e.id] + key := tsKey{ts: int64(e.ts + 0.5)} // round to nearest + tsLookup[key] = info + } + + // Annotate raw events. + for i, ev := range events { + ts, ok := ev["ts"] + if !ok { + continue + } + var tsFloat float64 + switch v := ts.(type) { + case float64: + tsFloat = v + case json.Number: + tsFloat, _ = v.Float64() + default: + continue + } + + key := tsKey{ts: int64(tsFloat + 0.5)} + info, ok := tsLookup[key] + if !ok { + continue + } + + args, _ := ev["args"].(map[string]any) + if args == nil { + args = make(map[string]any) + } + args["Patterns"] = info.patterns + events[i]["args"] = args + } + + // Marshal back. + annotatedEvents, err := json.Marshal(events) + if err != nil { + return fmt.Errorf("marshaling annotated events: %w", err) + } + traceDoc["traceEvents"] = annotatedEvents + outputData, err := json.Marshal(traceDoc) + if err != nil { + return fmt.Errorf("marshaling output trace: %w", err) + } + + // Write output file. + baseName := filepath.Base(tracePath) + // Strip .gz if present. + baseName, _ = strings.CutSuffix(baseName, ".gz") + outputFile := filepath.Join(outputDir, "overlaid_"+baseName) + if err := os.WriteFile(outputFile, outputData, 0644); err != nil { + return fmt.Errorf("writing overlay file: %w", err) + } + + return nil +} + +// ReadTraceFile reads a trace file, handling .json.gz compression. 
+func ReadTraceFile(path string) ([]byte, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + if strings.HasSuffix(path, ".gz") { + gr, err := gzip.NewReader(f) + if err != nil { + return nil, err + } + defer gr.Close() + var buf strings.Builder + b := make([]byte, 32*1024) + for { + n, readErr := gr.Read(b) + if n > 0 { + buf.Write(b[:n]) + } + if readErr != nil { + break + } + } + return []byte(buf.String()), nil + } + return os.ReadFile(path) +} diff --git a/pkg/analysis/kernel/kernel_sequences_test.go b/pkg/analysis/kernel/kernel_sequences_test.go new file mode 100644 index 0000000..89ca1ad --- /dev/null +++ b/pkg/analysis/kernel/kernel_sequences_test.go @@ -0,0 +1,230 @@ +package kernel + +import ( + "os" + "testing" + + "hta/pkg/store" +) + +func TestFindRootOperators(t *testing.T) { + t.Parallel() + tests := []struct { + name string + ops []store.CPUOperatorRow + want []store.CPUOperatorRow + }{ + { + name: "nested ops on same thread", + ops: []store.CPUOperatorRow{ + {ID: 1, StartedAt: 100, EndedAt: 500, ThreadID: 1}, + {ID: 2, StartedAt: 200, EndedAt: 400, ThreadID: 1}, // nested in 1 + {ID: 3, StartedAt: 300, EndedAt: 350, ThreadID: 1}, // nested in 1 + }, + want: []store.CPUOperatorRow{ + {ID: 1, StartedAt: 100, EndedAt: 500, ThreadID: 1}, + }, + }, + { + name: "non-overlapping ops on same thread", + ops: []store.CPUOperatorRow{ + {ID: 1, StartedAt: 100, EndedAt: 200, ThreadID: 1}, + {ID: 2, StartedAt: 300, EndedAt: 400, ThreadID: 1}, + {ID: 3, StartedAt: 500, EndedAt: 600, ThreadID: 1}, + }, + want: []store.CPUOperatorRow{ + {ID: 1, StartedAt: 100, EndedAt: 200, ThreadID: 1}, + {ID: 2, StartedAt: 300, EndedAt: 400, ThreadID: 1}, + {ID: 3, StartedAt: 500, EndedAt: 600, ThreadID: 1}, + }, + }, + { + name: "multiple threads independent", + ops: []store.CPUOperatorRow{ + {ID: 1, StartedAt: 100, EndedAt: 500, ThreadID: 1}, + {ID: 2, StartedAt: 200, EndedAt: 400, ThreadID: 1}, // nested in 1 + {ID: 3, StartedAt: 100, EndedAt: 500, ThreadID: 2}, // different thread, root + {ID: 4, StartedAt: 200, EndedAt: 300, ThreadID: 2}, // nested in 3 + }, + want: []store.CPUOperatorRow{ + {ID: 1, StartedAt: 100, EndedAt: 500, ThreadID: 1}, + {ID: 3, StartedAt: 100, EndedAt: 500, ThreadID: 2}, + }, + }, + { + name: "empty input", + ops: nil, + want: nil, + }, + { + name: "single op", + ops: []store.CPUOperatorRow{ + {ID: 1, StartedAt: 100, EndedAt: 200, ThreadID: 1}, + }, + want: []store.CPUOperatorRow{ + {ID: 1, StartedAt: 100, EndedAt: 200, ThreadID: 1}, + }, + }, + { + name: "boundary: starts exactly at curEnd", + ops: []store.CPUOperatorRow{ + {ID: 1, StartedAt: 100, EndedAt: 200, ThreadID: 1}, + {ID: 2, StartedAt: 200, EndedAt: 300, ThreadID: 1}, // starts at curEnd, nested + }, + want: []store.CPUOperatorRow{ + {ID: 1, StartedAt: 100, EndedAt: 200, ThreadID: 1}, + }, + }, + { + name: "boundary: starts one after curEnd", + ops: []store.CPUOperatorRow{ + {ID: 1, StartedAt: 100, EndedAt: 200, ThreadID: 1}, + {ID: 2, StartedAt: 201, EndedAt: 300, ThreadID: 1}, // starts after curEnd, new root + }, + want: []store.CPUOperatorRow{ + {ID: 1, StartedAt: 100, EndedAt: 200, ThreadID: 1}, + {ID: 2, StartedAt: 201, EndedAt: 300, ThreadID: 1}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := findRootOperators(tt.ops) + if len(got) != len(tt.want) { + t.Fatalf("got %d roots, want %d", len(got), len(tt.want)) + } + for i := range got { + if got[i].ID != tt.want[i].ID { + t.Errorf("root[%d].ID = %d, want %d", i, 
got[i].ID, tt.want[i].ID) + } + } + }) + } +} + +func TestFrequentCUDAKernelSequences(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + db := openSharedVTDB(t) + defer db.Close() + + opts := KernelSeqOpts{ + OperatorName: "aten::", + MinPatternLen: 1, + Rank: 0, + TopK: 5, + } + results, err := FrequentCUDAKernelSequences(db, opts) + if err != nil { + t.Fatalf("frequent cuda kernel sequences: %v", err) + } + + if len(results) == 0 { + t.Fatal("expected non-empty results") + } + + // Verify sorted by count DESC, then pattern ASC. + for i := 1; i < len(results); i++ { + prev := results[i-1] + cur := results[i] + if prev.Count < cur.Count { + t.Errorf("results not sorted by count DESC: [%d].Count=%d < [%d].Count=%d", + i-1, prev.Count, i, cur.Count) + } + if prev.Count == cur.Count && prev.Pattern > cur.Pattern { + t.Errorf("results with same count not sorted by pattern ASC: [%d].Pattern=%q > [%d].Pattern=%q", + i-1, prev.Pattern, i, cur.Pattern) + } + } + + // All counts should be positive. + for i, r := range results { + if r.Count <= 0 { + t.Errorf("result[%d].Count = %d, want > 0", i, r.Count) + } + if r.Pattern == "" { + t.Errorf("result[%d].Pattern is empty", i) + } + } + + // Should not exceed TopK. + if len(results) > opts.TopK { + t.Errorf("got %d results, want <= %d (TopK)", len(results), opts.TopK) + } +} + +func TestFrequentCUDAKernelSequencesNoMatch(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + db := openSharedVTDB(t) + defer db.Close() + + opts := KernelSeqOpts{ + OperatorName: "nonexistent_operator_xyz_12345", + MinPatternLen: 1, + Rank: 0, + TopK: 5, + } + results, err := FrequentCUDAKernelSequences(db, opts) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(results) != 0 { + t.Errorf("expected empty results for nonexistent operator, got %d", len(results)) + } +} + +func TestFrequentCUDAKernelSequencesOverlay(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + db := openSharedVTDB(t) + defer db.Close() + + outputDir := t.TempDir() + opts := KernelSeqOpts{ + OperatorName: "aten::", + OutputDir: outputDir, + MinPatternLen: 1, + Rank: 0, + TopK: 5, + } + results, err := FrequentCUDAKernelSequences(db, opts) + if err != nil { + t.Fatalf("frequent cuda kernel sequences: %v", err) + } + + if len(results) == 0 { + t.Fatal("expected non-empty results to test overlay") + } + + // Check that an overlay file was created. + entries, err := os.ReadDir(outputDir) + if err != nil { + t.Fatalf("reading output dir: %v", err) + } + found := false + for _, e := range entries { + if len(e.Name()) > 9 && e.Name()[:9] == "overlaid_" { + found = true + // Verify file is non-empty. + info, err := e.Info() + if err != nil { + t.Fatalf("stat overlay file: %v", err) + } + if info.Size() == 0 { + t.Error("overlay file is empty") + } + break + } + } + if !found { + t.Error("no overlaid_* file found in output directory") + } +} diff --git a/pkg/analysis/kernel/launch_stats.go b/pkg/analysis/kernel/launch_stats.go new file mode 100644 index 0000000..8a30a7b --- /dev/null +++ b/pkg/analysis/kernel/launch_stats.go @@ -0,0 +1,105 @@ +package kernel + +import ( + "database/sql" + "fmt" + + "hta/pkg/store" +) + +// LaunchStatsOpts controls the CUDA kernel launch statistics analysis. 
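+// The analysis joins each CPU runtime launch (cudaLaunchKernel and, optionally,
+// async memory ops) with its correlated GPU kernel and reports the launch
+// delay, defined as gpu start - (cpu start + cpu duration), floored at zero.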
+type LaunchStatsOpts struct { + Ranks []int // nil = all ranks + RuntimeCutoff int // µs, default 50 (informational; not used in filtering here) + LaunchDelayCutoff int // µs, default 100 (informational; not used in filtering here) + IncludeMemory bool // include cudaMemcpyAsync/cudaMemsetAsync events +} + +// LaunchStatRow holds a single joined CPU-GPU kernel launch pair. +type LaunchStatRow struct { + Correlation int + CPUDuration int64 + GPUDuration int64 + LaunchDelay int64 +} + +// CUDAKernelLaunchStats computes per-rank CPU-to-GPU kernel launch overhead +// by joining CPU runtime events with their correlated GPU kernels. +func CUDAKernelLaunchStats(db *sql.DB, opts LaunchStatsOpts) (map[int][]LaunchStatRow, error) { + sym, err := store.LoadSymbolTable(db) + if err != nil { + return nil, fmt.Errorf("loading symbol table: %w", err) + } + + // Resolve runtime event name IDs. + runtimeNames := []string{"cudaLaunchKernel", "cudaLaunchKernelExC"} + if opts.IncludeMemory { + runtimeNames = append(runtimeNames, "cudaMemcpyAsync", "cudaMemsetAsync") + } + var nameIDs []int + for _, name := range runtimeNames { + id := sym.GetID(name) + if id >= 0 { + nameIDs = append(nameIDs, id) + } + } + if len(nameIDs) == 0 { + return nil, fmt.Errorf("no runtime symbols found in symbol table") + } + + // Determine ranks. + ranks := opts.Ranks + if len(ranks) == 0 { + ranks, err = store.Ranks(db) + if err != nil { + return nil, fmt.Errorf("loading ranks: %w", err) + } + } + + result := make(map[int][]LaunchStatRow, len(ranks)) + for _, rank := range ranks { + rows, err := launchStatsPerRank(db, rank, nameIDs) + if err != nil { + return nil, fmt.Errorf("rank %d: %w", rank, err) + } + result[rank] = rows + } + return result, nil +} + +func launchStatsPerRank(db *sql.DB, rank int, nameIDs []int) ([]LaunchStatRow, error) { + // Load CPU runtime events and build a correlation map. + cpuEvents, err := store.LoadCPURuntimeEvents(db, rank, nameIDs) + if err != nil { + return nil, fmt.Errorf("loading CPU runtime events: %w", err) + } + cpuByCorr := make(map[int]store.CPURuntimeRow, len(cpuEvents)) + for _, e := range cpuEvents { + cpuByCorr[e.Correlation] = e + } + + // Load GPU kernels with correlation and join. 
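+ // The join key is the CUDA correlation ID shared by a runtime launch and the
+ // kernel it enqueues.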
+ gpuEvents, err := store.LoadGPUKernelsWithCorrelation(db, rank) + if err != nil { + return nil, fmt.Errorf("loading GPU kernels: %w", err) + } + + var rows []LaunchStatRow + for _, gpu := range gpuEvents { + cpu, ok := cpuByCorr[gpu.Correlation] + if !ok { + continue + } + delay := gpu.StartedAt - (cpu.StartedAt + cpu.Duration) + if delay < 0 { + delay = 0 + } + rows = append(rows, LaunchStatRow{ + Correlation: gpu.Correlation, + CPUDuration: cpu.Duration, + GPUDuration: gpu.Duration, + LaunchDelay: delay, + }) + } + return rows, nil +} diff --git a/pkg/analysis/kernel/launch_stats_test.go b/pkg/analysis/kernel/launch_stats_test.go new file mode 100644 index 0000000..9fef7a2 --- /dev/null +++ b/pkg/analysis/kernel/launch_stats_test.go @@ -0,0 +1,88 @@ +package kernel + +import ( + "testing" +) + +func TestCUDAKernelLaunchStatsIntegration(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + db := openSharedVTDB(t) + defer db.Close() + + opts := LaunchStatsOpts{ + RuntimeCutoff: 50, + LaunchDelayCutoff: 100, + IncludeMemory: true, + } + result, err := CUDAKernelLaunchStats(db, opts) + if err != nil { + t.Fatalf("cuda kernel launch stats: %v", err) + } + + if len(result) == 0 { + t.Fatal("expected non-empty results") + } + + for rank, rows := range result { + if len(rows) == 0 { + t.Errorf("rank %d: expected non-empty rows", rank) + continue + } + + seen := make(map[int]bool) + for _, r := range rows { + // Launch delay must be non-negative (clipped to 0). + if r.LaunchDelay < 0 { + t.Errorf("rank %d: correlation %d has negative launch_delay %d", + rank, r.Correlation, r.LaunchDelay) + } + // CPU and GPU durations must be positive. + if r.CPUDuration <= 0 { + t.Errorf("rank %d: correlation %d has non-positive cpu_duration %d", + rank, r.Correlation, r.CPUDuration) + } + if r.GPUDuration <= 0 { + t.Errorf("rank %d: correlation %d has non-positive gpu_duration %d", + rank, r.Correlation, r.GPUDuration) + } + // Correlation IDs should be unique per rank. + if seen[r.Correlation] { + t.Errorf("rank %d: duplicate correlation %d", rank, r.Correlation) + } + seen[r.Correlation] = true + } + } +} + +func TestCUDAKernelLaunchStatsNoMemory(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + db := openSharedVTDB(t) + defer db.Close() + + withMem := LaunchStatsOpts{IncludeMemory: true} + withMemResult, err := CUDAKernelLaunchStats(db, withMem) + if err != nil { + t.Fatalf("with memory: %v", err) + } + + noMem := LaunchStatsOpts{IncludeMemory: false} + noMemResult, err := CUDAKernelLaunchStats(db, noMem) + if err != nil { + t.Fatalf("no memory: %v", err) + } + + // Without memory events we should get fewer or equal results. + for rank, withRows := range withMemResult { + noRows := noMemResult[rank] + if len(noRows) > len(withRows) { + t.Errorf("rank %d: no-memory (%d) > with-memory (%d)", + rank, len(noRows), len(withRows)) + } + } +} diff --git a/pkg/analysis/kernel/testmain_test.go b/pkg/analysis/kernel/testmain_test.go new file mode 100644 index 0000000..4a758c2 --- /dev/null +++ b/pkg/analysis/kernel/testmain_test.go @@ -0,0 +1,86 @@ +package kernel + +import ( + "flag" + "fmt" + "os" + "path/filepath" + "runtime" + "testing" + + "hta/pkg/pipeline" + "hta/pkg/store" +) + +// sharedVTDBPath and sharedNSDBPath hold paths to pre-built SQLite DBs +// created once by TestMain. 
Each integration test opens its own read handle +// via openSharedVTDB / openSharedNSDB instead of re-running the expensive +// pipeline.RunWithDB preprocessing. +var ( + sharedVTDBPath string + sharedNSDBPath string + sharedTmpDir string +) + +func TestMain(m *testing.M) { + flag.Parse() + + if !testing.Short() { + if err := setupSharedDBs(); err != nil { + fmt.Fprintf(os.Stderr, "kernel test setup: %v\n", err) + os.Exit(1) + } + } + + code := m.Run() + + if sharedTmpDir != "" { + os.RemoveAll(sharedTmpDir) + } + os.Exit(code) +} + +func setupSharedDBs() error { + _, thisFile, _, _ := runtime.Caller(0) + root := filepath.Join(filepath.Dir(thisFile), "..", "..", "..") + + tmp, err := os.MkdirTemp("", "kernel-test-*") + if err != nil { + return fmt.Errorf("create temp dir: %w", err) + } + sharedTmpDir = tmp + + // Preprocess vision_transformer trace data once. + vtDir := filepath.Join(root, "tests", "data", "vision_transformer") + if _, err := os.Stat(vtDir); err == nil { + dbPath := filepath.Join(tmp, "vt.db") + db, err := store.Create(dbPath) + if err != nil { + return fmt.Errorf("create vt db: %w", err) + } + if err := pipeline.RunWithDB(vtDir, db); err != nil { + db.Close() + return fmt.Errorf("preprocess vt: %w", err) + } + db.Close() + sharedVTDBPath = dbPath + } + + // Preprocess ns_resolution_trace data once. + nsDir := filepath.Join(root, "tests", "data", "ns_resolution_trace") + if _, err := os.Stat(nsDir); err == nil { + dbPath := filepath.Join(tmp, "ns.db") + db, err := store.Create(dbPath) + if err != nil { + return fmt.Errorf("create ns db: %w", err) + } + if err := pipeline.RunWithDB(nsDir, db); err != nil { + db.Close() + return fmt.Errorf("preprocess ns: %w", err) + } + db.Close() + sharedNSDBPath = dbPath + } + + return nil +} diff --git a/pkg/analysis/kerneltype.go b/pkg/analysis/kerneltype.go new file mode 100644 index 0000000..c787cea --- /dev/null +++ b/pkg/analysis/kerneltype.go @@ -0,0 +1,66 @@ +package analysis + +import "regexp" + +// KernelType classifies a GPU kernel. +type KernelType int + +const ( + KernelComputation KernelType = iota + KernelCommunication + KernelMemory +) + +func (k KernelType) String() string { + switch k { + case KernelComputation: + return "COMPUTATION" + case KernelCommunication: + return "COMMUNICATION" + case KernelMemory: + return "MEMORY" + default: + return "UNKNOWN" + } +} + +// Python re.match anchors at start, so alternatives without ^ must use .* prefix. +// Go MatchString finds a match anywhere, so we anchor with ^ where Python's re.match +// would implicitly anchor at position 0. +var ( + communicationRe = regexp.MustCompile(`^(?:nccl.*|hip|hccl|.*ncclKernel.*)`) + memoryRe = regexp.MustCompile(`^(?:(?:hip)?Memcpy|(?:hip)?Memset|dma)`) + // Compute is the inverse: if nonComputeRe matches → NOT compute. + nonComputeRe = regexp.MustCompile(`^(?:nccl|.*Memcpy|.*Memset|.*dma|.*Sync|Stream Wait Event|cuda.*LaunchKernel|runFunction|hip|job_exe|packet|pe_exe|cpum_exe_subgraph|hccl|collective_subgraph_execution)`) +) + +// MemoryKernelType returns the memory operation subtype for a kernel name. +// Returns "Memset" for Memset operations, "Memcpy DtoH"/"Memcpy HtoD" etc. +// for Memcpy operations, or "Memcpy Unknown" as fallback. 
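+//
+// For example, "Memcpy DtoH (Device -> Host)" maps to "Memcpy DtoH" (the first
+// 11 characters) and "Memset (Device)" maps to "Memset".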
+func MemoryKernelType(name string) string { + if len(name) >= 6 && name[:6] == "Memset" { + return "Memset" + } + if len(name) >= 6 && name[:6] == "Memcpy" { + if len(name) >= 11 { + return name[:11] + } + return name + } + return "Memcpy Unknown" +} + +// ClassifyKernel returns the kernel type for a given kernel name string. +// Priority: communication > memory > compute (inverse of nonComputeRe). +func ClassifyKernel(name string) KernelType { + if communicationRe.MatchString(name) { + return KernelCommunication + } + if memoryRe.MatchString(name) { + return KernelMemory + } + if nonComputeRe.MatchString(name) { + return KernelCommunication // non-compute, non-memory, non-comm → treated as non-compute + } + return KernelComputation +} diff --git a/pkg/analysis/kerneltype_test.go b/pkg/analysis/kerneltype_test.go new file mode 100644 index 0000000..23eec35 --- /dev/null +++ b/pkg/analysis/kerneltype_test.go @@ -0,0 +1,36 @@ +package analysis + +import "testing" + +func TestClassifyKernel(t *testing.T) { + t.Parallel() + tests := []struct { + name string + want KernelType + }{ + // Communication kernels + {"ncclAllReduceRingLLKernel_sum_f32", KernelCommunication}, + {"ncclKernel_SendRecv", KernelCommunication}, + {"nccl_anything", KernelCommunication}, + + // Memory kernels + {"Memcpy DtoH (Device -> Host)", KernelMemory}, + {"Memcpy HtoD (Host -> Device)", KernelMemory}, + {"Memset (Device)", KernelMemory}, + {"hipMemcpyAsync", KernelCommunication}, // ^hip matches comm before memory + + // Computation kernels (do NOT match nonComputeRe) + {"void at::native::vectorized_elementwise_kernel<4, ...>", KernelComputation}, + {"ampere_sgemm_128x32_tn", KernelComputation}, + {"volta_fp16_s884gemm_fp16_128x128_ldg8_f2f_tn", KernelComputation}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ClassifyKernel(tt.name) + if got != tt.want { + t.Errorf("ClassifyKernel(%q) = %v, want %v", tt.name, got, tt.want) + } + }) + } +} diff --git a/pkg/analysis/mathutil.go b/pkg/analysis/mathutil.go new file mode 100644 index 0000000..2fcda28 --- /dev/null +++ b/pkg/analysis/mathutil.go @@ -0,0 +1,30 @@ +package analysis + +import "math" + +// RoundTo rounds val to the specified number of decimal places. +func RoundTo(val float64, places int) float64 { + pow := math.Pow(10, float64(places)) + return math.Round(val*pow) / pow +} + +// QuantileLinear computes a quantile using linear interpolation, matching +// pandas default: pos = q*(n-1), interpolate between floor(pos) and ceil(pos). +func QuantileLinear(sorted []float64, q float64) float64 { + n := len(sorted) + if n == 0 { + return 0 + } + if n == 1 { + return sorted[0] + } + + pos := q * float64(n-1) + lo := int(math.Floor(pos)) + hi := int(math.Ceil(pos)) + if lo == hi { + return sorted[lo] + } + frac := pos - float64(lo) + return sorted[lo]*(1-frac) + sorted[hi]*frac +} diff --git a/pkg/analysis/names.go b/pkg/analysis/names.go new file mode 100644 index 0000000..ee7ded7 --- /dev/null +++ b/pkg/analysis/names.go @@ -0,0 +1,55 @@ +package analysis + +import "strings" + +// ShortenName removes template arguments (<...>), function parameters ((...)), +// and return types from a CUDA kernel or CPU operator name. +// Memory kernel names (Memcpy/Memset) are returned unchanged. +// +// Ported from hta/utils/utils.py:shorten_name. 
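+//
+// Illustrative examples (drawn from the unit tests):
+//
+//	ShortenName("void func(int a, float b)")      // "func"
+//	ShortenName("Memcpy DtoD (Device -> Device)") // unchanged (memory kernel)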
+func ShortenName(name string) string { + if isMemoryName(name) { + return name + } + + if !strings.Contains(name, "<") && !strings.Contains(name, "(") { + return name + } + + s := strings.ReplaceAll(name, "->", "") + var stack []byte + for i := 0; i < len(s); i++ { + c := s[i] + switch c { + case '>': + // Pop back to matching '<' + for len(stack) > 0 && stack[len(stack)-1] != '<' { + stack = stack[:len(stack)-1] + } + if len(stack) > 0 && stack[len(stack)-1] == '<' { + stack = stack[:len(stack)-1] + } + case ')': + // Pop back to matching '(' + for len(stack) > 0 && stack[len(stack)-1] != '(' { + stack = stack[:len(stack)-1] + } + if len(stack) > 0 && stack[len(stack)-1] == '(' { + stack = stack[:len(stack)-1] + } + default: + stack = append(stack, c) + } + } + + result := string(stack) + if idx := strings.LastIndex(result, " "); idx >= 0 { + result = result[idx+1:] + } + return result +} + +// isMemoryName returns true for Memcpy/Memset kernel names. +func isMemoryName(name string) bool { + return strings.HasPrefix(name, "Memcpy") || strings.HasPrefix(name, "Memset") +} diff --git a/pkg/analysis/names_test.go b/pkg/analysis/names_test.go new file mode 100644 index 0000000..7c0bb67 --- /dev/null +++ b/pkg/analysis/names_test.go @@ -0,0 +1,61 @@ +package analysis + +import "testing" + +func TestShortenName(t *testing.T) { + t.Parallel() + tests := []struct { + input string + want string + }{ + // Template removal (keeps namespace prefix, removes return type via space split) + { + "void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor>", + "at::native::vectorized_elementwise_kernel", + }, + // Function param removal + { + "void func(int a, float b)", + "func", + }, + // Nested templates + { + "void foo>", + "foo", + }, + // Memcpy passthrough (memory kernels are returned unchanged) + { + "Memcpy DtoD (Device -> Device)", + "Memcpy DtoD (Device -> Device)", + }, + // Memset passthrough + { + "Memset (Device)", + "Memset (Device)", + }, + // Simple name unchanged (no templates or params) + { + "ncclKernel_AllReduce", + "ncclKernel_AllReduce", + }, + // Complex real-world name from Python test reference + { + "at::native::::multi_tensor_apply_kernel<4, at::native::(anonymous namespace)::CudaMultiTensorSGDFunctor>(int, at::native::TensorListMetadata<4>, at::native::(anonymous namespace)::CudaMultiTensorSGDFunctor)", + "at::native::::multi_tensor_apply_kernel", + }, + // Name without return type prefix + { + "cudaLaunchKernel", + "cudaLaunchKernel", + }, + // Empty string + {"", ""}, + } + + for _, tc := range tests { + got := ShortenName(tc.input) + if got != tc.want { + t.Errorf("ShortenName(%q) = %q, want %q", tc.input, got, tc.want) + } + } +} diff --git a/pkg/analysis/profiler_steps.go b/pkg/analysis/profiler_steps.go new file mode 100644 index 0000000..f17186b --- /dev/null +++ b/pkg/analysis/profiler_steps.go @@ -0,0 +1,42 @@ +package analysis + +import ( + "database/sql" + "fmt" + "regexp" + "sort" + "strconv" + + "hta/pkg/store" +) + +var ProfilerStepRe = regexp.MustCompile(`ProfilerStep\s*#\s*(\d+)`) + +// ProfilerSteps extracts profiler step numbers from the symbol table. +// It returns a sorted, deduplicated slice of step numbers. 
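+//
+// For example, a symbol table containing "ProfilerStep#17", "ProfilerStep#15",
+// and a duplicate "ProfilerStep#15" on another thread yields [15, 17].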
+func ProfilerSteps(db *sql.DB) ([]int, error) { + sym, err := store.LoadSymbolTable(db) + if err != nil { + return nil, fmt.Errorf("loading symbol table: %w", err) + } + + seen := make(map[int]struct{}) + for _, entry := range sym.All() { + m := ProfilerStepRe.FindStringSubmatch(entry.Name) + if m == nil { + continue + } + n, err := strconv.Atoi(m[1]) + if err != nil { + continue + } + seen[n] = struct{}{} + } + + steps := make([]int, 0, len(seen)) + for n := range seen { + steps = append(steps, n) + } + sort.Ints(steps) + return steps, nil +} diff --git a/pkg/analysis/profiler_steps_test.go b/pkg/analysis/profiler_steps_test.go new file mode 100644 index 0000000..dc12b21 --- /dev/null +++ b/pkg/analysis/profiler_steps_test.go @@ -0,0 +1,88 @@ +package analysis + +import ( + "path/filepath" + "testing" + + "hta/pkg/pipeline" + "hta/pkg/store" +) + +func TestProfilerStepsRegex(t *testing.T) { + t.Parallel() + tests := []struct { + name string + input string + wantN int + wantOK bool + }{ + {"basic", "ProfilerStep#0", 0, true}, + {"space before hash", "ProfilerStep #5", 5, true}, + {"extra spaces", "ProfilerStep # 10", 10, true}, + {"large number", "ProfilerStep#123", 123, true}, + {"no match kernel", "SomeOtherKernel", 0, false}, + {"no match empty", "", 0, false}, + {"partial match", "ProfilerStep", 0, false}, + {"no digit", "ProfilerStep#", 0, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := ProfilerStepRe.FindStringSubmatch(tt.input) + if tt.wantOK { + if m == nil { + t.Fatalf("expected match for %q, got nil", tt.input) + } + if m[1] != "" { + // Already validated by regex to be digits + var got int + for _, ch := range m[1] { + got = got*10 + int(ch-'0') + } + if got != tt.wantN { + t.Errorf("got step %d, want %d", got, tt.wantN) + } + } + } else { + if m != nil { + t.Errorf("expected no match for %q, got %v", tt.input, m) + } + } + }) + } +} + +func TestProfilerStepsIntegration(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + traceDir := testDataDir(t) + dbPath := filepath.Join(t.TempDir(), "test.db") + + db, err := store.Create(dbPath) + if err != nil { + t.Fatalf("create db: %v", err) + } + defer db.Close() + + if err := pipeline.RunWithDB(traceDir, db); err != nil { + t.Fatalf("preprocess: %v", err) + } + + steps, err := ProfilerSteps(db) + if err != nil { + t.Fatalf("profiler steps: %v", err) + } + + // vision_transformer test data has steps 15-19 + want := []int{15, 16, 17, 18, 19} + if len(steps) != len(want) { + t.Fatalf("got %d steps %v, want %d steps %v", len(steps), steps, len(want), want) + } + for i := range want { + if steps[i] != want[i] { + t.Errorf("step[%d] = %d, want %d", i, steps[i], want[i]) + } + } +} diff --git a/pkg/analysis/resource/cupti_counters.go b/pkg/analysis/resource/cupti_counters.go new file mode 100644 index 0000000..0a0caa1 --- /dev/null +++ b/pkg/analysis/resource/cupti_counters.go @@ -0,0 +1,201 @@ +package resource + +import ( + "database/sql" + "fmt" + "log" + "sort" + + "hta/pkg/store" +) + +// CUPTICounterOpts controls the CUPTI counter data analysis. +type CUPTICounterOpts struct { + Ranks []int // nil = all ranks +} + +// CUPTICounterRow holds one CUPTI kernel with its operator stack and counter values. 
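+//
+// For a hypothetical GEMM kernel launched under aten::linear, the row might
+// look like this (all names and values illustrative):
+//
+//	KernelName:    "ampere_sgemm_128x32_tn"
+//	OpStack:       []string{"aten::linear", "aten::addmm"}
+//	TopLevelOp:    "aten::linear"
+//	BottomLevelOp: "aten::addmm"
+//	Counters:      map[string]float64{"smsp__warps_launched.sum": 512}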
+type CUPTICounterRow struct { + KernelName string + OpStack []string // outermost first + TopLevelOp string + BottomLevelOp string + Counters map[string]float64 +} + +// CUPTICounterData extracts CUPTI performance profiler counter data for GPU +// kernels and attributes each kernel to its CPU operator stack. +func CUPTICounterData(db *sql.DB, opts CUPTICounterOpts) (map[int][]CUPTICounterRow, error) { + sym, err := store.LoadSymbolTable(db) + if err != nil { + return nil, fmt.Errorf("loading symbol table: %w", err) + } + + // Check for cuda_profiler_range category. + cuptiCatID := sym.GetID("cuda_profiler_range") + if cuptiCatID < 0 { + return nil, nil // no CUPTI data + } + + // Check for cpu_op category. + cpuOpCatID := sym.GetID("cpu_op") + + // Check for cuda_runtime category. + cudaRuntimeCatID := sym.GetID("cuda_runtime") + + // Collect cudaLaunchKernel name IDs. + var launchNameIDs []int + for _, name := range []string{"cudaLaunchKernel", "cudaLaunchKernelExC"} { + id := sym.GetID(name) + if id >= 0 { + launchNameIDs = append(launchNameIDs, id) + } + } + + // Determine ranks. + ranks := opts.Ranks + if len(ranks) == 0 { + ranks, err = store.Ranks(db) + if err != nil { + return nil, fmt.Errorf("loading ranks: %w", err) + } + } + + result := make(map[int][]CUPTICounterRow, len(ranks)) + for _, rank := range ranks { + rows, err := cuptiCounterPerRank(db, rank, cuptiCatID, cpuOpCatID, cudaRuntimeCatID, launchNameIDs, sym) + if err != nil { + return nil, fmt.Errorf("rank %d: %w", rank, err) + } + if len(rows) > 0 { + result[rank] = rows + } + } + return result, nil +} + +func cuptiCounterPerRank( + db *sql.DB, + rank int, + cuptiCatID, cpuOpCatID, cudaRuntimeCatID int, + launchNameIDs []int, + sym interface{ GetName(int) (string, error) }, +) ([]CUPTICounterRow, error) { + // 1. Load CUPTI kernels sorted by started_at. + cuptiKernels, err := store.LoadCUPTIKernels(db, rank, cuptiCatID) + if err != nil { + return nil, fmt.Errorf("loading CUPTI kernels: %w", err) + } + if len(cuptiKernels) == 0 { + return nil, nil + } + + // 2. Load cudaLaunchKernel events sorted by started_at. + if cudaRuntimeCatID < 0 || len(launchNameIDs) == 0 { + return nil, nil + } + launches, err := store.LoadCUPTILaunches(db, rank, cudaRuntimeCatID, launchNameIDs) + if err != nil { + return nil, fmt.Errorf("loading CUPTI launches: %w", err) + } + + // 3. Validate 1:1 count match. + if len(cuptiKernels) != len(launches) { + log.Printf("rank %d: CUPTI kernel count (%d) != launch count (%d), skipping", + rank, len(cuptiKernels), len(launches)) + return nil, nil + } + + // 4. Load cpu_op events grouped by thread. + var cpuOps []store.CPUOpEventRow + if cpuOpCatID >= 0 { + cpuOps, err = store.LoadCPUOpEvents(db, rank, cpuOpCatID) + if err != nil { + return nil, fmt.Errorf("loading cpu_op events: %w", err) + } + } + opsByThread := make(map[int64][]store.CPUOpEventRow) + for _, op := range cpuOps { + opsByThread[op.ThreadID] = append(opsByThread[op.ThreadID], op) + } + + // 5. Load CUPTI counter args in batch. + eventIDs := make([]int64, len(cuptiKernels)) + for i, k := range cuptiKernels { + eventIDs[i] = k.EventID + } + argsMap, err := store.LoadEventArgsBatch(db, eventIDs) + if err != nil { + return nil, fmt.Errorf("loading event args: %w", err) + } + + // 6. For each (kernel, launch) pair matched by position, build result row. 
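+	// Positional pairing is safe only because of the 1:1 count check in step 3:
+	// both slices are sorted by started_at, so kernel i corresponds to launch i.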
+ rows := make([]CUPTICounterRow, 0, len(cuptiKernels)) + for i, kernel := range cuptiKernels { + launch := launches[i] + + kernelName, err := sym.GetName(kernel.NameID) + if err != nil { + kernelName = fmt.Sprintf("unknown_%d", kernel.NameID) + } + + // Find containing cpu_op events on the launch's thread. + launchStart := launch.StartedAt + launchEnd := launch.StartedAt + launch.Duration + threadOps := opsByThread[launch.ThreadID] + + containingOps := findContainingOps(threadOps, launchStart, launchEnd) + + // Resolve op names and build stack (outermost first — already sorted by started_at). + var opStack []string + for _, op := range containingOps { + name, err := sym.GetName(op.NameID) + if err != nil { + name = fmt.Sprintf("unknown_%d", op.NameID) + } + opStack = append(opStack, name) + } + + var topLevelOp, bottomLevelOp string + if len(opStack) > 0 { + topLevelOp = opStack[0] + bottomLevelOp = opStack[len(opStack)-1] + } + + counters := argsMap[kernel.EventID] + + rows = append(rows, CUPTICounterRow{ + KernelName: kernelName, + OpStack: opStack, + TopLevelOp: topLevelOp, + BottomLevelOp: bottomLevelOp, + Counters: counters, + }) + } + return rows, nil +} + +// findContainingOps returns all ops that fully contain the interval [startedAt, endedAt]. +// ops must be sorted by StartedAt. Results are in order of StartedAt (outermost first). +func findContainingOps(ops []store.CPUOpEventRow, startedAt, endedAt int64) []store.CPUOpEventRow { + if len(ops) == 0 { + return nil + } + // Binary search: find the first op where StartedAt > startedAt. + idx := sort.Search(len(ops), func(i int) bool { + return ops[i].StartedAt > startedAt + }) + // All candidates have index < idx (StartedAt <= startedAt). + // Scan backwards to collect all containing ops. + var result []store.CPUOpEventRow + for i := idx - 1; i >= 0; i-- { + if ops[i].EndedAt >= endedAt { + result = append(result, ops[i]) + } + } + // Reverse to get outermost (earliest start) first. + for i, j := 0, len(result)-1; i < j; i, j = i+1, j-1 { + result[i], result[j] = result[j], result[i] + } + return result +} diff --git a/pkg/analysis/resource/cupti_counters_test.go b/pkg/analysis/resource/cupti_counters_test.go new file mode 100644 index 0000000..be3ee76 --- /dev/null +++ b/pkg/analysis/resource/cupti_counters_test.go @@ -0,0 +1,160 @@ +package resource + +import ( + "testing" + + "hta/pkg/store" +) + +func TestCUPTICounterDataIntegration(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + db := openSharedCUPTIDB(t) + defer db.Close() + + result, err := CUPTICounterData(db, CUPTICounterOpts{}) + if err != nil { + t.Fatalf("cupti counter data: %v", err) + } + + if len(result) == 0 { + t.Fatal("expected non-empty results") + } + + for rank, rows := range result { + if len(rows) == 0 { + t.Errorf("rank %d: expected non-empty rows", rank) + continue + } + + for i, r := range rows { + if r.KernelName == "" { + t.Errorf("rank %d row %d: KernelName is empty", rank, i) + } + if len(r.OpStack) == 0 { + t.Errorf("rank %d row %d: OpStack is empty for kernel %q", rank, i, r.KernelName) + } + if r.TopLevelOp == "" { + t.Errorf("rank %d row %d: TopLevelOp is empty", rank, i) + } + if r.BottomLevelOp == "" { + t.Errorf("rank %d row %d: BottomLevelOp is empty", rank, i) + } + if len(r.Counters) == 0 { + t.Errorf("rank %d row %d: Counters is empty for kernel %q", rank, i, r.KernelName) + } + } + + // Verify counter keys are consistent across rows. 
+ if len(rows) > 1 { + firstKeys := make(map[string]bool) + for k := range rows[0].Counters { + firstKeys[k] = true + } + for i := 1; i < len(rows); i++ { + for k := range rows[i].Counters { + if !firstKeys[k] { + t.Errorf("rank %d row %d: unexpected counter key %q", rank, i, k) + } + } + } + } + } +} + +func TestCUPTICounterDataNoData(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + db := openSharedVTDB(t) // vision_transformer — no CUPTI data + defer db.Close() + + result, err := CUPTICounterData(db, CUPTICounterOpts{}) + if err != nil { + t.Fatalf("cupti counter data: %v", err) + } + + if len(result) != 0 { + t.Errorf("expected empty results for trace without CUPTI data, got %d ranks", len(result)) + } +} + +func TestFindContainingOps(t *testing.T) { + t.Parallel() + ops := []store.CPUOpEventRow{ + {StartedAt: 100, EndedAt: 500, NameID: 1, ThreadID: 1}, // outermost + {StartedAt: 150, EndedAt: 450, NameID: 2, ThreadID: 1}, // middle + {StartedAt: 200, EndedAt: 400, NameID: 3, ThreadID: 1}, // innermost + {StartedAt: 600, EndedAt: 800, NameID: 4, ThreadID: 1}, // non-overlapping + } + + tests := []struct { + name string + startedAt int64 + endedAt int64 + wantIDs []int // expected NameIDs of containing ops + }{ + { + name: "fully nested, all three contain", + startedAt: 200, + endedAt: 400, + wantIDs: []int{1, 2, 3}, + }, + { + name: "only outermost contains", + startedAt: 100, + endedAt: 500, + wantIDs: []int{1}, + }, + { + name: "two contain", + startedAt: 150, + endedAt: 400, + wantIDs: []int{1, 2}, + }, + { + name: "no ops contain", + startedAt: 50, + endedAt: 600, + wantIDs: nil, + }, + { + name: "non-overlapping range", + startedAt: 600, + endedAt: 800, + wantIDs: []int{4}, + }, + { + name: "empty ops", + startedAt: 100, + endedAt: 200, + wantIDs: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + input := ops + if tt.name == "empty ops" { + input = nil + } + got := findContainingOps(input, tt.startedAt, tt.endedAt) + var gotIDs []int + for _, op := range got { + gotIDs = append(gotIDs, op.NameID) + } + if len(gotIDs) != len(tt.wantIDs) { + t.Errorf("got %d ops %v, want %d ops %v", len(gotIDs), gotIDs, len(tt.wantIDs), tt.wantIDs) + return + } + for i := range gotIDs { + if gotIDs[i] != tt.wantIDs[i] { + t.Errorf("op[%d] nameID = %d, want %d", i, gotIDs[i], tt.wantIDs[i]) + } + } + }) + } +} diff --git a/pkg/analysis/resource/helpers_test.go b/pkg/analysis/resource/helpers_test.go new file mode 100644 index 0000000..d29d1a1 --- /dev/null +++ b/pkg/analysis/resource/helpers_test.go @@ -0,0 +1,61 @@ +package resource + +import ( + "database/sql" + "os" + "path/filepath" + "runtime" + "testing" + + "hta/pkg/store" +) + +func testDataDir(t *testing.T) string { + t.Helper() + _, thisFile, _, _ := runtime.Caller(0) + // pkg/analysis/resource/helpers_test.go → project root + root := filepath.Join(filepath.Dir(thisFile), "..", "..", "..") + dir := filepath.Join(root, "tests", "data", "vision_transformer") + if _, err := os.Stat(dir); err != nil { + t.Skipf("test data not found: %s", dir) + } + return dir +} + +// openSharedVTDB opens a read handle to the pre-built vision_transformer DB. +// The DB was preprocessed once in TestMain. Each caller gets its own *sql.DB. 
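+// Callers own the returned handle and are responsible for closing it
+// (typically with defer db.Close()).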
+func openSharedVTDB(t *testing.T) *sql.DB { + t.Helper() + if sharedVTDBPath == "" { + t.Skip("vision_transformer test data not available") + } + db, err := store.Create(sharedVTDBPath) + if err != nil { + t.Fatalf("open shared VT db: %v", err) + } + return db +} + +// openSharedCUPTIDB opens a read handle to the pre-built cupti_profiler DB. +func openSharedCUPTIDB(t *testing.T) *sql.DB { + t.Helper() + if sharedCUPTIDBPath == "" { + t.Skip("CUPTI test data not available") + } + db, err := store.Create(sharedCUPTIDBPath) + if err != nil { + t.Fatalf("open shared CUPTI db: %v", err) + } + return db +} + +func cuptiTestDataDir(t *testing.T) string { + t.Helper() + _, thisFile, _, _ := runtime.Caller(0) + root := filepath.Join(filepath.Dir(thisFile), "..", "..", "..") + dir := filepath.Join(root, "tests", "data", "cupti_profiler") + if _, err := os.Stat(dir); err != nil { + t.Skipf("CUPTI test data not found: %s", dir) + } + return dir +} diff --git a/pkg/analysis/resource/memory_bw.go b/pkg/analysis/resource/memory_bw.go new file mode 100644 index 0000000..9069e4c --- /dev/null +++ b/pkg/analysis/resource/memory_bw.go @@ -0,0 +1,299 @@ +package resource + +import ( + "database/sql" + "fmt" + "math" + "sort" + + "hta/pkg/analysis" + "hta/pkg/store" + "hta/pkg/symbol" +) + +// MemoryBWPoint is a single point in the memory bandwidth time series. +type MemoryBWPoint struct { + Timestamp int64 + ProcessID int64 + Name string + MemoryBWGbps float64 +} + +// MemoryBWOpts configures memory bandwidth summary analysis. +type MemoryBWOpts struct { + Ranks []int // nil = all ranks +} + +// MemoryBWSummaryRow holds summary statistics for one (rank, memory op type) pair. +type MemoryBWSummaryRow struct { + Rank int + Name string + Count int + Mean float64 + Std float64 + Min float64 + P25 float64 + P50 float64 + P75 float64 + Max float64 +} + +// MemoryBWSummary computes summary statistics of memory bandwidth utilization +// per operation type (Memcpy DtoH/HtoD/DtoD, Memset) for each rank. +func MemoryBWSummary(db *sql.DB, opts MemoryBWOpts) ([]MemoryBWSummaryRow, error) { + sym, err := store.LoadSymbolTable(db) + if err != nil { + return nil, err + } + + ranks := opts.Ranks + if len(ranks) == 0 { + ranks, err = store.Ranks(db) + if err != nil { + return nil, err + } + } + + var results []MemoryBWSummaryRow + for _, rank := range ranks { + rows, err := memoryBWSummaryForRank(db, sym, rank) + if err != nil { + return nil, err + } + results = append(results, rows...) + } + return results, nil +} + +// bwEvent represents a bandwidth change event at a point in time. +type bwEvent struct { + ts int64 + bw float64 + key string // memory op type +} + +func memoryBWSummaryForRank(db *sql.DB, sym *symbol.Table, rank int) ([]MemoryBWSummaryRow, error) { + kernels, err := store.LoadMemoryBWKernels(db, rank) + if err != nil { + return nil, err + } + if len(kernels) == 0 { + return nil, nil + } + + // Group events by memory kernel subtype, creating paired +bw/-bw events. 
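+	// This is a sweep over interval endpoints: each kernel adds +bw at its
+	// start and -bw at its end, so a running sum over time-ordered events
+	// gives the total bandwidth in flight at any instant.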
+ eventsByType := make(map[string][]bwEvent) + for _, k := range kernels { + name, _ := sym.GetName(k.NameID) + memType := analysis.MemoryKernelType(name) + bw := k.MemoryBWGbps + + // Round up 0-duration events to 1 us to avoid negative values + dur := k.Duration + if dur == 0 { + dur = 1 + } + + eventsByType[memType] = append(eventsByType[memType], + bwEvent{ts: k.StartedAt, bw: bw, key: memType}, + bwEvent{ts: k.StartedAt + dur, bw: -bw, key: memType}, + ) + } + + // For each subtype: sort by timestamp, cumsum, collect positive values, compute stats. + // Sort type names for deterministic output. + typeNames := make([]string, 0, len(eventsByType)) + for name := range eventsByType { + typeNames = append(typeNames, name) + } + sort.Strings(typeNames) + + var results []MemoryBWSummaryRow + for _, typeName := range typeNames { + events := eventsByType[typeName] + sort.Slice(events, func(i, j int) bool { + return events[i].ts < events[j].ts + }) + + // Cumulative sum + var cumsum float64 + var positiveValues []float64 + for _, ev := range events { + cumsum += ev.bw + if cumsum > 0 { + positiveValues = append(positiveValues, cumsum) + } + } + + if len(positiveValues) == 0 { + continue + } + + results = append(results, computeMemBWStats(rank, typeName, positiveValues)) + } + return results, nil +} + +func computeMemBWStats(rank int, name string, values []float64) MemoryBWSummaryRow { + sort.Float64s(values) + n := len(values) + + var sum float64 + for _, v := range values { + sum += v + } + mean := sum / float64(n) + + var variance float64 + for _, v := range values { + d := v - mean + variance += d * d + } + // Sample std (ddof=1) matching pandas default + std := 0.0 + if n > 1 { + std = math.Sqrt(variance / float64(n-1)) + } + + return MemoryBWSummaryRow{ + Rank: rank, + Name: name, + Count: n, + Mean: analysis.RoundTo(mean, 2), + Std: analysis.RoundTo(std, 2), + Min: analysis.RoundTo(values[0], 2), + P25: analysis.RoundTo(percentileFloat(values, 0.25), 2), + P50: analysis.RoundTo(percentileFloat(values, 0.50), 2), + P75: analysis.RoundTo(percentileFloat(values, 0.75), 2), + Max: analysis.RoundTo(values[n-1], 2), + } +} + +// memoryBWTimeSeriesForRank computes the per-operation-type memory bandwidth +// time series for a single rank. +func memoryBWTimeSeriesForRank(db *sql.DB, sym *symbol.Table, rank int) ([]MemoryBWPoint, error) { + kernels, err := store.LoadMemoryBWKernels(db, rank) + if err != nil { + return nil, fmt.Errorf("loading memory BW kernels: %w", err) + } + if len(kernels) == 0 { + return nil, nil + } + + type tsEvent struct { + timestamp int64 + processID int64 + name string + memoryBWGbps float64 + } + var tsEvents []tsEvent + + for _, k := range kernels { + name, err := sym.GetName(k.NameID) + if err != nil { + continue + } + opType := analysis.MemoryKernelType(name) + + dur := k.Duration + // Clamp zero-duration to 1 to avoid negative values + // see https://github.com/facebookresearch/HolisticTraceAnalysis/issues/20 + if dur == 0 { + dur = 1 + } + + // Start event: +bw at timestamp + tsEvents = append(tsEvents, tsEvent{ + timestamp: k.StartedAt, + processID: k.ProcessID, + name: opType, + memoryBWGbps: k.MemoryBWGbps, + }) + // End event: -bw at timestamp + duration + tsEvents = append(tsEvents, tsEvent{ + timestamp: k.StartedAt + dur, + processID: k.ProcessID, + name: opType, + memoryBWGbps: -k.MemoryBWGbps, + }) + } + + if len(tsEvents) == 0 { + return nil, nil + } + + // Sort by (name, timestamp) for cumulative sum per operation type. 
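+	// Sorting by name first groups each op type's events together, letting a
+	// single pass keep one running sum that resets whenever the name changes.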
+ sort.Slice(tsEvents, func(i, j int) bool { + if tsEvents[i].name != tsEvents[j].name { + return tsEvents[i].name < tsEvents[j].name + } + return tsEvents[i].timestamp < tsEvents[j].timestamp + }) + + // Compute cumulative bandwidth per operation type. + var points []MemoryBWPoint + var cumSum float64 + prevName := tsEvents[0].name + for _, e := range tsEvents { + if e.name != prevName { + cumSum = 0 + prevName = e.name + } + cumSum += e.memoryBWGbps + points = append(points, MemoryBWPoint{ + Timestamp: e.timestamp, + ProcessID: e.processID, + Name: e.name, + MemoryBWGbps: cumSum, + }) + } + + return points, nil +} + +// MemoryBWTimeSeries computes the memory bandwidth time series for all (or +// selected) ranks in the DB. +func MemoryBWTimeSeries(db *sql.DB, ranks []int) (map[int][]MemoryBWPoint, error) { + sym, err := store.LoadSymbolTable(db) + if err != nil { + return nil, fmt.Errorf("loading symbol table: %w", err) + } + + if len(ranks) == 0 { + ranks, err = store.Ranks(db) + if err != nil { + return nil, fmt.Errorf("loading ranks: %w", err) + } + } + + result := make(map[int][]MemoryBWPoint, len(ranks)) + for _, rank := range ranks { + points, err := memoryBWTimeSeriesForRank(db, sym, rank) + if err != nil { + return nil, fmt.Errorf("rank %d: %w", rank, err) + } + if len(points) > 0 { + result[rank] = points + } + } + return result, nil +} + +// percentileFloat computes the p-th percentile using linear interpolation on a sorted []float64. +func percentileFloat(sorted []float64, p float64) float64 { + n := len(sorted) + if n == 0 { + return 0 + } + if n == 1 { + return sorted[0] + } + idx := p * float64(n-1) + lo := int(idx) + hi := lo + 1 + if hi >= n { + return sorted[n-1] + } + frac := idx - float64(lo) + return sorted[lo]*(1-frac) + sorted[hi]*frac +} diff --git a/pkg/analysis/resource/memory_bw_test.go b/pkg/analysis/resource/memory_bw_test.go new file mode 100644 index 0000000..5585195 --- /dev/null +++ b/pkg/analysis/resource/memory_bw_test.go @@ -0,0 +1,170 @@ +package resource + +import ( + "testing" + + "hta/pkg/analysis" +) + +func TestMemoryBWSummary(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + db := openSharedVTDB(t) + defer db.Close() + + results, err := MemoryBWSummary(db, MemoryBWOpts{}) + if err != nil { + t.Fatalf("memory bw summary: %v", err) + } + + if len(results) == 0 { + t.Fatal("expected non-empty results") + } + + // Verify we have results for both ranks + ranksSeen := make(map[int]bool) + for _, r := range results { + ranksSeen[r.Rank] = true + } + if len(ranksSeen) < 2 { + t.Errorf("expected results for at least 2 ranks, got %d", len(ranksSeen)) + } + + // Verify each result has valid stats + for _, r := range results { + if r.Count <= 0 { + t.Errorf("rank=%d name=%s: expected positive count, got %d", r.Rank, r.Name, r.Count) + } + if r.Min > r.Max { + t.Errorf("rank=%d name=%s: min (%.2f) > max (%.2f)", r.Rank, r.Name, r.Min, r.Max) + } + if r.Mean < r.Min || r.Mean > r.Max { + t.Errorf("rank=%d name=%s: mean (%.2f) outside [min, max] = [%.2f, %.2f]", + r.Rank, r.Name, r.Mean, r.Min, r.Max) + } + if r.P25 < r.Min || r.P75 > r.Max { + t.Errorf("rank=%d name=%s: percentiles out of range", r.Rank, r.Name) + } + if r.Std < 0 { + t.Errorf("rank=%d name=%s: negative std %.2f", r.Rank, r.Name, r.Std) + } + } + + // Verify we see expected memory op types (at least Memcpy subtypes) + namesSeen := make(map[string]bool) + for _, r := range results { + namesSeen[r.Name] = true + } + // The vision_transformer trace 
should have at least one Memcpy type + hasMemcpy := false + for name := range namesSeen { + if len(name) >= 6 && name[:6] == "Memcpy" { + hasMemcpy = true + break + } + } + if !hasMemcpy { + t.Errorf("expected at least one Memcpy subtype, got types: %v", namesSeen) + } +} + +func TestMemoryBWSummaryWithRanks(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + db := openSharedVTDB(t) + defer db.Close() + + results, err := MemoryBWSummary(db, MemoryBWOpts{Ranks: []int{0}}) + if err != nil { + t.Fatalf("memory bw summary: %v", err) + } + + for _, r := range results { + if r.Rank != 0 { + t.Errorf("expected only rank 0, got rank %d", r.Rank) + } + } +} + +func TestMemoryBWTimeSeries(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + db := openSharedVTDB(t) + defer db.Close() + + // Test with all ranks. + result, err := MemoryBWTimeSeries(db, nil) + if err != nil { + t.Fatalf("memory bw time series: %v", err) + } + + if len(result) < 2 { + t.Fatalf("expected at least 2 ranks, got %d", len(result)) + } + + // Collect all operation types seen across all ranks. + opTypes := make(map[string]bool) + for rank, points := range result { + if len(points) == 0 { + t.Errorf("rank %d: expected non-empty points", rank) + continue + } + + for _, p := range points { + opTypes[p.Name] = true + if p.Timestamp < 0 { + t.Errorf("rank %d: negative timestamp %d", rank, p.Timestamp) + } + } + } + + // Verify at least some expected operation types appear. + expectedTypes := []string{"Memcpy DtoH", "Memcpy HtoD"} + for _, et := range expectedTypes { + if !opTypes[et] { + t.Errorf("expected operation type %q not found in results; got types: %v", et, opTypes) + } + } + + // Test with specific rank. + result1, err := MemoryBWTimeSeries(db, []int{0}) + if err != nil { + t.Fatalf("memory bw for rank 0: %v", err) + } + if _, ok := result1[0]; !ok { + t.Error("expected rank 0 in results") + } + if len(result1) != 1 { + t.Errorf("expected 1 rank, got %d", len(result1)) + } +} + +func TestMemoryKernelType(t *testing.T) { + t.Parallel() + tests := []struct { + name string + want string + }{ + {"Memset (Device)", "Memset"}, + {"Memcpy DtoH (Device -> Pinned)", "Memcpy DtoH"}, + {"Memcpy HtoD (Pinned -> Device)", "Memcpy HtoD"}, + {"Memcpy DtoD (Device -> Device)", "Memcpy DtoD"}, + {"Memcpy", "Memcpy"}, + {"Unknown kernel", "Memcpy Unknown"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := analysis.MemoryKernelType(tt.name) + if got != tt.want { + t.Errorf("analysis.MemoryKernelType(%q) = %q, want %q", tt.name, got, tt.want) + } + }) + } +} diff --git a/pkg/analysis/resource/queue_length.go b/pkg/analysis/resource/queue_length.go new file mode 100644 index 0000000..941ed0f --- /dev/null +++ b/pkg/analysis/resource/queue_length.go @@ -0,0 +1,377 @@ +package resource + +import ( + "database/sql" + "fmt" + "math" + "sort" + + "hta/pkg/analysis" + "hta/pkg/store" + "hta/pkg/symbol" +) + +// QueueLengthPoint is a single point in the queue-length time series. +type QueueLengthPoint struct { + Timestamp int64 + Stream int + QueueLength int +} + +// QueueLengthSummaryRow holds descriptive statistics for one (rank, stream). +type QueueLengthSummaryRow struct { + Rank int + Stream int + Count int + Min int + Max int + Std float64 + P25 float64 + P50 float64 + P75 float64 +} + +// RuntimeLaunchNames are the CUDA/ROCm/MTIA runtime function names that +// enqueue work onto a GPU stream. 
Matches the Python symbol table filter. +var RuntimeLaunchNames = []string{ + "cudaLaunchKernel", + "cudaLaunchKernelExC", + "cuLaunchKernel", + "cuLaunchKernelEx", + "cudaMemcpyAsync", + "cudaMemsetAsync", + "hipLaunchKernel", + "hipExtModuleLaunchKernel", + "hipMemsetAsync", + "hipMemcpyAsync", + "hipMemcpyWithStream", + "runFunction - job_prep_and_submit_for_execution", +} + +// queueEvent is an internal +1/-1 change for a stream. +type queueEvent struct { + timestamp int64 + stream int + delta int // +1 for enqueue (CPU launch), -1 for dequeue (GPU kernel start) +} + +// QueueLengthTimeSeriesForRank computes the per-stream queue depth time series +// for a single rank. Shared with the future queue-length-time-series command. +func QueueLengthTimeSeriesForRank(db *sql.DB, sym *symbol.Table, rank int) ([]QueueLengthPoint, error) { + // Resolve name IDs for runtime launch functions. + var nameIDs []int + for _, name := range RuntimeLaunchNames { + id := sym.GetID(name) + if id >= 0 { + nameIDs = append(nameIDs, id) + } + } + if len(nameIDs) == 0 { + return nil, nil + } + + // Load CPU runtime launch events and GPU kernels with stream info. + cpuEvents, err := store.LoadCPURuntimeEvents(db, rank, nameIDs) + if err != nil { + return nil, fmt.Errorf("loading runtime launch events: %w", err) + } + gpuEvents, err := store.LoadGPUKernelsWithStream(db, rank) + if err != nil { + return nil, fmt.Errorf("loading GPU kernels with stream: %w", err) + } + + // Index GPU events by correlation ID. + gpuByCorr := make(map[int]store.GPUKernelStreamRow, len(gpuEvents)) + for _, g := range gpuEvents { + gpuByCorr[g.Correlation] = g + } + + // Build +1/-1 events from matched pairs only. + var events []queueEvent + for _, c := range cpuEvents { + g, ok := gpuByCorr[c.Correlation] + if !ok { + continue + } + // +1 at CPU launch time (use GPU kernel's stream) + events = append(events, queueEvent{ + timestamp: c.StartedAt, + stream: g.CUDAStream, + delta: 1, + }) + // -1 at GPU kernel start time + events = append(events, queueEvent{ + timestamp: g.StartedAt, + stream: g.CUDAStream, + delta: -1, + }) + } + + if len(events) == 0 { + return nil, nil + } + + // Sort by (stream, timestamp) for cumulative sum per stream. + sort.Slice(events, func(i, j int) bool { + if events[i].stream != events[j].stream { + return events[i].stream < events[j].stream + } + return events[i].timestamp < events[j].timestamp + }) + + // Compute cumulative queue length per stream. + var points []QueueLengthPoint + cumSum := 0 + prevStream := events[0].stream + for _, e := range events { + if e.stream != prevStream { + cumSum = 0 + prevStream = e.stream + } + cumSum += e.delta + points = append(points, QueueLengthPoint{ + Timestamp: e.timestamp, + Stream: e.stream, + QueueLength: cumSum, + }) + } + + return points, nil +} + +// QueueLengthTimeSeries computes the CUDA stream queue depth time series +// for all (or selected) ranks in the DB. 
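+//
+// A reading of N at time t means N kernels had been launched on that stream
+// but had not yet started executing: the series gains +1 at each matched
+// runtime launch and loses 1 when the corresponding GPU kernel begins.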
+func QueueLengthTimeSeries(db *sql.DB, ranks []int) (map[int][]QueueLengthPoint, error) { + sym, err := store.LoadSymbolTable(db) + if err != nil { + return nil, fmt.Errorf("loading symbol table: %w", err) + } + + if len(ranks) == 0 { + ranks, err = store.Ranks(db) + if err != nil { + return nil, fmt.Errorf("loading ranks: %w", err) + } + } + + result := make(map[int][]QueueLengthPoint, len(ranks)) + for _, rank := range ranks { + points, err := QueueLengthTimeSeriesForRank(db, sym, rank) + if err != nil { + return nil, fmt.Errorf("rank %d: %w", rank, err) + } + if len(points) > 0 { + result[rank] = points + } + } + return result, nil +} + +// QueueLengthSummary computes descriptive statistics of the CUDA stream queue +// depth per (rank, stream). If ranks is nil, all ranks in the DB are used. +func QueueLengthSummary(db *sql.DB, ranks []int) ([]QueueLengthSummaryRow, error) { + sym, err := store.LoadSymbolTable(db) + if err != nil { + return nil, fmt.Errorf("loading symbol table: %w", err) + } + + if len(ranks) == 0 { + ranks, err = store.Ranks(db) + if err != nil { + return nil, fmt.Errorf("loading ranks: %w", err) + } + } + + var results []QueueLengthSummaryRow + + for _, rank := range ranks { + points, err := QueueLengthTimeSeriesForRank(db, sym, rank) + if err != nil { + return nil, fmt.Errorf("rank %d: %w", rank, err) + } + if len(points) == 0 { + continue + } + + // Group queue_length values by stream. + streamVals := make(map[int][]float64) + for _, p := range points { + streamVals[p.Stream] = append(streamVals[p.Stream], float64(p.QueueLength)) + } + + // Compute summary stats per stream. + streams := make([]int, 0, len(streamVals)) + for s := range streamVals { + streams = append(streams, s) + } + sort.Ints(streams) + + for _, stream := range streams { + vals := streamVals[stream] + n := len(vals) + if n == 0 { + continue + } + + minVal := vals[0] + maxVal := vals[0] + for _, v := range vals[1:] { + if v < minVal { + minVal = v + } + if v > maxVal { + maxVal = v + } + } + + sorted := make([]float64, n) + copy(sorted, vals) + sort.Float64s(sorted) + + results = append(results, QueueLengthSummaryRow{ + Rank: rank, + Stream: stream, + Count: n, + Min: int(minVal), + Max: int(maxVal), + Std: stddevDdof1(vals), + P25: analysis.QuantileLinear(sorted, 0.25), + P50: analysis.QuantileLinear(sorted, 0.50), + P75: analysis.QuantileLinear(sorted, 0.75), + }) + } + } + + // Sort by (rank, stream). + sort.Slice(results, func(i, j int) bool { + if results[i].Rank != results[j].Rank { + return results[i].Rank < results[j].Rank + } + return results[i].Stream < results[j].Stream + }) + + return results, nil +} + +// BlockedQueueOpts controls the blocked-on-full-queue analysis. +type BlockedQueueOpts struct { + Ranks []int + MaxQueueLength int // default 1024 +} + +// BlockedQueueRow holds the time a single (rank, stream) spent at max queue depth. +type BlockedQueueRow struct { + Rank int + Stream int + Duration int64 // total time at max queue length (same units as trace timestamps) + RelativeDuration float64 // duration / trace_duration +} + +// BlockedOnFullQueue calculates how much time the GPU launch queue was at +// maximum capacity per (rank, stream). This indicates CPU-side bottlenecks +// that prevent the GPU from being fully utilised. 
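+//
+// Illustrative example: with the default MaxQueueLength of 1024, a stream
+// that sits at depth >= 1024 for spans totalling 5000 trace-time units in a
+// 1,000,000-unit trace yields Duration=5000 and RelativeDuration=0.005.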
+func BlockedOnFullQueue(db *sql.DB, opts BlockedQueueOpts) ([]BlockedQueueRow, error) { + maxQL := opts.MaxQueueLength + if maxQL <= 0 { + maxQL = 1024 // CUDA_MAX_LAUNCH_QUEUE_PER_STREAM + } + + tsMap, err := QueueLengthTimeSeries(db, opts.Ranks) + if err != nil { + return nil, fmt.Errorf("queue length time series: %w", err) + } + + var results []BlockedQueueRow + + for rank, points := range tsMap { + if len(points) == 0 { + continue + } + + traceDur, err := store.TraceDuration(db, rank) + if err != nil { + return nil, fmt.Errorf("trace duration for rank %d: %w", rank, err) + } + + // Group points by stream (they're already sorted by (stream, ts) + // from QueueLengthTimeSeriesForRank). + type streamSpan struct { + start int + end int + } + streamSpans := make(map[int]*streamSpan) + var streamOrder []int + for i, p := range points { + sp, ok := streamSpans[p.Stream] + if !ok { + sp = &streamSpan{start: i} + streamSpans[p.Stream] = sp + streamOrder = append(streamOrder, p.Stream) + } + sp.end = i + 1 + } + + for _, stream := range streamOrder { + sp := streamSpans[stream] + streamPoints := points[sp.start:sp.end] + + var blockedTime int64 + for i := 0; i < len(streamPoints)-1; i++ { + if streamPoints[i].QueueLength >= maxQL { + dur := streamPoints[i+1].Timestamp - streamPoints[i].Timestamp + if dur > 0 { + blockedTime += dur + } + } + } + + if blockedTime == 0 { + continue + } + + var relDur float64 + if traceDur > 0 { + relDur = float64(blockedTime) / float64(traceDur) + } + + results = append(results, BlockedQueueRow{ + Rank: rank, + Stream: stream, + Duration: blockedTime, + RelativeDuration: relDur, + }) + } + } + + // Sort by (rank, stream). + sort.Slice(results, func(i, j int) bool { + if results[i].Rank != results[j].Rank { + return results[i].Rank < results[j].Rank + } + return results[i].Stream < results[j].Stream + }) + + return results, nil +} + +// stddevDdof1 computes the sample standard deviation with ddof=1, matching +// pandas default. Returns 0 for slices with fewer than 2 elements. 
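+//
+//	std = sqrt( sum((x_i - mean)^2) / (n - 1) )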
+func stddevDdof1(vals []float64) float64 { + n := len(vals) + if n < 2 { + return 0 + } + + var sum float64 + for _, v := range vals { + sum += v + } + mean := sum / float64(n) + + var sumSqDiff float64 + for _, v := range vals { + d := v - mean + sumSqDiff += d * d + } + return math.Sqrt(sumSqDiff / float64(n-1)) +} diff --git a/pkg/analysis/resource/queue_length_test.go b/pkg/analysis/resource/queue_length_test.go new file mode 100644 index 0000000..e8b92a8 --- /dev/null +++ b/pkg/analysis/resource/queue_length_test.go @@ -0,0 +1,197 @@ +package resource + +import ( + "math" + "testing" +) + +func TestStddevDdof1(t *testing.T) { + t.Parallel() + tests := []struct { + name string + vals []float64 + want float64 + }{ + {"empty", nil, 0}, + {"single", []float64{5}, 0}, + {"two elements", []float64{2, 4}, math.Sqrt(2)}, + { + // vals: 10, 20, 30, 40 → mean=25 + // var(ddof=1) = (225+25+25+225)/3 = 500/3 + "four elements", + []float64{10, 20, 30, 40}, + math.Sqrt(500.0 / 3.0), + }, + { + // vals: 1, 1, 1, 1 → stddev = 0 + "identical", + []float64{1, 1, 1, 1}, + 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := stddevDdof1(tt.vals) + if math.Abs(got-tt.want) > 1e-9 { + t.Errorf("stddevDdof1(%v) = %v, want %v", tt.vals, got, tt.want) + } + }) + } +} + +func TestQueueLengthSummary(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + db := openSharedVTDB(t) + defer db.Close() + + results, err := QueueLengthSummary(db, nil) + if err != nil { + t.Fatalf("queue length summary: %v", err) + } + + if len(results) == 0 { + t.Fatal("expected non-empty results") + } + + for _, r := range results { + if r.Count <= 0 { + t.Errorf("stream %d: count should be > 0, got %d", r.Stream, r.Count) + } + if r.Min < 0 { + t.Errorf("stream %d: min should be >= 0, got %d", r.Stream, r.Min) + } + if r.Max < r.Min { + t.Errorf("stream %d: max (%d) < min (%d)", r.Stream, r.Max, r.Min) + } + if r.Std < 0 { + t.Errorf("stream %d: std should be >= 0, got %f", r.Stream, r.Std) + } + if r.P25 > r.P50 { + t.Errorf("stream %d: P25 (%.2f) > P50 (%.2f)", r.Stream, r.P25, r.P50) + } + if r.P50 > r.P75 { + t.Errorf("stream %d: P50 (%.2f) > P75 (%.2f)", r.Stream, r.P50, r.P75) + } + if r.P25 < float64(r.Min) { + t.Errorf("stream %d: P25 (%.2f) < min (%d)", r.Stream, r.P25, r.Min) + } + if r.P75 > float64(r.Max) { + t.Errorf("stream %d: P75 (%.2f) > max (%d)", r.Stream, r.P75, r.Max) + } + } + + // Results should be sorted by (rank, stream). + for i := 1; i < len(results); i++ { + a, b := results[i-1], results[i] + if a.Rank > b.Rank || (a.Rank == b.Rank && a.Stream >= b.Stream) { + t.Errorf("results not sorted at index %d: (%d,%d) before (%d,%d)", + i, a.Rank, a.Stream, b.Rank, b.Stream) + } + } +} + +func TestBlockedOnFullQueue(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + db := openSharedVTDB(t) + defer db.Close() + + // Use default opts (max queue length 1024, all ranks). + results, err := BlockedOnFullQueue(db, BlockedQueueOpts{}) + if err != nil { + t.Fatalf("blocked on full queue: %v", err) + } + + // Results may be empty if no stream reaches 1024 — that's valid. 
+ for _, r := range results { + if r.Duration < 0 { + t.Errorf("rank %d stream %d: duration should be >= 0, got %d", + r.Rank, r.Stream, r.Duration) + } + if r.RelativeDuration < 0 || r.RelativeDuration > 1 { + t.Errorf("rank %d stream %d: relative_duration should be in [0,1], got %f", + r.Rank, r.Stream, r.RelativeDuration) + } + } + + // Results should be sorted by (rank, stream). + for i := 1; i < len(results); i++ { + a, b := results[i-1], results[i] + if a.Rank > b.Rank || (a.Rank == b.Rank && a.Stream >= b.Stream) { + t.Errorf("results not sorted at index %d: (%d,%d) before (%d,%d)", + i, a.Rank, a.Stream, b.Rank, b.Stream) + } + } + + // Also test with a very low threshold to ensure we get some results. + lowResults, err := BlockedOnFullQueue(db, BlockedQueueOpts{MaxQueueLength: 1}) + if err != nil { + t.Fatalf("blocked on full queue (low threshold): %v", err) + } + + if len(lowResults) == 0 { + t.Log("warning: even with MaxQueueLength=1, no blocked time found") + } + for _, r := range lowResults { + if r.Duration <= 0 { + t.Errorf("rank %d stream %d: with threshold 1, expected positive duration, got %d", + r.Rank, r.Stream, r.Duration) + } + if r.RelativeDuration < 0 || r.RelativeDuration > 1 { + t.Errorf("rank %d stream %d: relative_duration should be in [0,1], got %f", + r.Rank, r.Stream, r.RelativeDuration) + } + } +} + +func TestQueueLengthTimeSeries(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + db := openSharedVTDB(t) + defer db.Close() + + // Test with all ranks (default). + result, err := QueueLengthTimeSeries(db, nil) + if err != nil { + t.Fatalf("queue length time series: %v", err) + } + + if len(result) < 2 { + t.Fatalf("expected at least 2 ranks, got %d", len(result)) + } + + for rank, points := range result { + if len(points) == 0 { + t.Errorf("rank %d: expected non-empty points", rank) + continue + } + + // Verify queue_length values are non-negative. + for i, p := range points { + if p.QueueLength < 0 { + t.Errorf("rank %d point %d: negative queue_length %d at ts=%d stream=%d", + rank, i, p.QueueLength, p.Timestamp, p.Stream) + } + } + } + + // Test with specific rank. + result1, err := QueueLengthTimeSeries(db, []int{0}) + if err != nil { + t.Fatalf("queue length for rank 0: %v", err) + } + if _, ok := result1[0]; !ok { + t.Error("expected rank 0 in results") + } + if len(result1) != 1 { + t.Errorf("expected 1 rank, got %d", len(result1)) + } +} diff --git a/pkg/analysis/resource/testmain_test.go b/pkg/analysis/resource/testmain_test.go new file mode 100644 index 0000000..4d70ce9 --- /dev/null +++ b/pkg/analysis/resource/testmain_test.go @@ -0,0 +1,86 @@ +package resource + +import ( + "flag" + "fmt" + "os" + "path/filepath" + "runtime" + "testing" + + "hta/pkg/pipeline" + "hta/pkg/store" +) + +// sharedVTDBPath and sharedCUPTIDBPath hold paths to pre-built SQLite DBs +// created once by TestMain. Each integration test opens its own read handle +// via openSharedVTDB / openSharedCUPTIDB instead of re-running the expensive +// pipeline.RunWithDB preprocessing. 
+var ( + sharedVTDBPath string + sharedCUPTIDBPath string + sharedTmpDir string +) + +func TestMain(m *testing.M) { + flag.Parse() + + if !testing.Short() { + if err := setupSharedDBs(); err != nil { + fmt.Fprintf(os.Stderr, "resource test setup: %v\n", err) + os.Exit(1) + } + } + + code := m.Run() + + if sharedTmpDir != "" { + os.RemoveAll(sharedTmpDir) + } + os.Exit(code) +} + +func setupSharedDBs() error { + _, thisFile, _, _ := runtime.Caller(0) + root := filepath.Join(filepath.Dir(thisFile), "..", "..", "..") + + tmp, err := os.MkdirTemp("", "resource-test-*") + if err != nil { + return fmt.Errorf("create temp dir: %w", err) + } + sharedTmpDir = tmp + + // Preprocess vision_transformer trace data once. + vtDir := filepath.Join(root, "tests", "data", "vision_transformer") + if _, err := os.Stat(vtDir); err == nil { + dbPath := filepath.Join(tmp, "vt.db") + db, err := store.Create(dbPath) + if err != nil { + return fmt.Errorf("create vt db: %w", err) + } + if err := pipeline.RunWithDB(vtDir, db); err != nil { + db.Close() + return fmt.Errorf("preprocess vt: %w", err) + } + db.Close() + sharedVTDBPath = dbPath + } + + // Preprocess cupti_profiler trace data once. + cuptiDir := filepath.Join(root, "tests", "data", "cupti_profiler") + if _, err := os.Stat(cuptiDir); err == nil { + dbPath := filepath.Join(tmp, "cupti.db") + db, err := store.Create(dbPath) + if err != nil { + return fmt.Errorf("create cupti db: %w", err) + } + if err := pipeline.RunWithDB(cuptiDir, db); err != nil { + db.Close() + return fmt.Errorf("preprocess cupti: %w", err) + } + db.Close() + sharedCUPTIDBPath = dbPath + } + + return nil +} diff --git a/pkg/analysis/resource/trace_with_counters.go b/pkg/analysis/resource/trace_with_counters.go new file mode 100644 index 0000000..cd94fb7 --- /dev/null +++ b/pkg/analysis/resource/trace_with_counters.go @@ -0,0 +1,284 @@ +package resource + +import ( + "compress/gzip" + "database/sql" + "encoding/json" + "fmt" + "math" + "os" + "path/filepath" + "sort" + "strings" + + "hta/pkg/store" +) + +// CounterType is a bitmask selecting which counter time series to embed. +type CounterType int + +const ( + CounterQueueLength CounterType = 1 << iota + CounterMemoryBW + CounterAll = CounterQueueLength | CounterMemoryBW +) + +// GenerateCountersOpts configures trace-with-counters generation. +type GenerateCountersOpts struct { + Ranks []int // nil = all ranks + Counters CounterType // 0 = CounterAll + OutputSuffix string // default "_with_counters" + OutputDir string // if set, write output files to this directory instead of next to source +} + +// counterEvent is a Chrome Trace counter event (ph: "C"). +type counterEvent struct { + Ph string `json:"ph"` + Ts int64 `json:"ts"` + Pid int64 `json:"pid"` + Name string `json:"name"` + Args map[string]any `json:"args"` + ID int `json:"id,omitempty"` +} + +// GenerateTraceWithCounters produces enriched trace JSON files with embedded +// time series counter events (queue length and/or memory bandwidth) viewable +// in Perfetto/Chrome trace viewer. +// +// It returns the list of output file paths written. 
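+//
+// Each embedded counter event uses the Chrome Trace Format "C" phase, e.g.
+// (illustrative values):
+//
+//	{"ph":"C","ts":1698765,"pid":9,"name":"Queue Length","args":{"Queue Length":42},"id":7}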
+func GenerateTraceWithCounters(db *sql.DB, opts GenerateCountersOpts) ([]string, error) { + counters := opts.Counters + if counters == 0 { + counters = CounterAll + } + suffix := opts.OutputSuffix + if suffix == "" { + suffix = "_with_counters" + } + + ranks := opts.Ranks + if len(ranks) == 0 { + var err error + ranks, err = store.Ranks(db) + if err != nil { + return nil, fmt.Errorf("loading ranks: %w", err) + } + } + + // Collect time series data. + var qlMap map[int][]QueueLengthPoint + var bwMap map[int][]MemoryBWPoint + + if counters&CounterQueueLength != 0 { + var err error + qlMap, err = QueueLengthTimeSeries(db, ranks) + if err != nil { + return nil, fmt.Errorf("queue length time series: %w", err) + } + } + if counters&CounterMemoryBW != 0 { + var err error + bwMap, err = MemoryBWTimeSeries(db, ranks) + if err != nil { + return nil, fmt.Errorf("memory bw time series: %w", err) + } + } + + // Process each rank that has data. + sortedRanks := make([]int, len(ranks)) + copy(sortedRanks, ranks) + sort.Ints(sortedRanks) + + var outputFiles []string + + for _, rank := range sortedRanks { + qlPoints := qlMap[rank] + bwPoints := bwMap[rank] + if len(qlPoints) == 0 && len(bwPoints) == 0 { + continue + } + + // Load original trace file path. + traceFile, err := store.TraceFile(db, rank) + if err != nil { + return nil, fmt.Errorf("rank %d: loading trace file path: %w", rank, err) + } + + // Load GPU process ID for counter events. + gpuPid, err := store.GPUProcessID(db, rank) + if err != nil { + return nil, fmt.Errorf("rank %d: loading GPU process ID: %w", rank, err) + } + + // Load raw trace JSON. + rawTrace, err := loadRawTrace(traceFile) + if err != nil { + return nil, fmt.Errorf("rank %d: loading raw trace: %w", rank, err) + } + + // Compute timestamp offset. + // DB stores timestamps relative to global minTs (startedAt = rawTs - globalMinTs). + // We need to convert DB timestamps back to raw trace timestamps. + rawMinTs, err := findMinTimestamp(rawTrace["traceEvents"]) + if err != nil { + return nil, fmt.Errorf("rank %d: finding min timestamp: %w", rank, err) + } + dbMinTs, err := store.MinStartedAt(db, rank) + if err != nil { + return nil, fmt.Errorf("rank %d: loading min started_at: %w", rank, err) + } + offset := rawMinTs - dbMinTs + + // Build counter events. + var events []counterEvent + + for _, p := range qlPoints { + events = append(events, counterEvent{ + Ph: "C", + Ts: p.Timestamp + offset, + Pid: gpuPid, + Name: "Queue Length", + Args: map[string]any{"Queue Length": p.QueueLength}, + ID: p.Stream, + }) + } + + for _, p := range bwPoints { + events = append(events, counterEvent{ + Ph: "C", + Ts: p.Timestamp + offset, + Pid: gpuPid, + Name: p.Name, + Args: map[string]any{p.Name + " BW": p.MemoryBWGbps}, + }) + } + + // Append counter events to traceEvents array. + existingEvents, ok := rawTrace["traceEvents"] + if !ok { + return nil, fmt.Errorf("rank %d: traceEvents not found in raw trace", rank) + } + + counterJSON, err := json.Marshal(events) + if err != nil { + return nil, fmt.Errorf("rank %d: marshalling counter events: %w", rank, err) + } + + // Merge: strip outer brackets from both, concatenate with comma. + existingTrimmed := strings.TrimSpace(string(existingEvents)) + counterTrimmed := strings.TrimSpace(string(counterJSON)) + + // Remove [ and ] from both sides. 
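+	// This splice assumes both arrays are non-empty: counterJSON always is
+	// (ranks without points were skipped above), and traceEvents is expected
+	// to be non-empty in any real trace.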
+ existingTrimmed = existingTrimmed[1 : len(existingTrimmed)-1] + counterTrimmed = counterTrimmed[1 : len(counterTrimmed)-1] + + merged := "[" + existingTrimmed + "," + counterTrimmed + "]" + rawTrace["traceEvents"] = json.RawMessage(merged) + + // Write output file. + outPath := outputPath(traceFile, suffix) + if opts.OutputDir != "" { + outPath = filepath.Join(opts.OutputDir, filepath.Base(outPath)) + } + if err := writeRawTrace(outPath, rawTrace); err != nil { + return nil, fmt.Errorf("rank %d: writing output: %w", rank, err) + } + outputFiles = append(outputFiles, outPath) + } + + return outputFiles, nil +} + +// loadRawTrace reads a JSON or gzipped JSON trace file into a raw map. +func loadRawTrace(path string) (map[string]json.RawMessage, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + var dec *json.Decoder + if strings.HasSuffix(path, ".gz") { + gz, err := gzip.NewReader(f) + if err != nil { + return nil, fmt.Errorf("gzip open: %w", err) + } + defer gz.Close() + dec = json.NewDecoder(gz) + } else { + dec = json.NewDecoder(f) + } + dec.UseNumber() + + var raw map[string]json.RawMessage + if err := dec.Decode(&raw); err != nil { + return nil, fmt.Errorf("json decode: %w", err) + } + return raw, nil +} + +// findMinTimestamp scans a raw traceEvents JSON array for the minimum "ts" value. +func findMinTimestamp(eventsRaw json.RawMessage) (int64, error) { + if eventsRaw == nil { + return 0, fmt.Errorf("no traceEvents data") + } + + var events []json.RawMessage + if err := json.Unmarshal(eventsRaw, &events); err != nil { + return 0, fmt.Errorf("parsing traceEvents: %w", err) + } + + minTs := int64(math.MaxInt64) + for _, raw := range events { + var obj map[string]json.RawMessage + if err := json.Unmarshal(raw, &obj); err != nil { + continue + } + tsRaw, ok := obj["ts"] + if !ok { + continue + } + var n json.Number + if err := json.Unmarshal(tsRaw, &n); err != nil { + continue + } + if i, err := n.Int64(); err == nil { + if i < minTs { + minTs = i + } + } else if f, err := n.Float64(); err == nil { + iv := int64(math.Round(f)) + if iv < minTs { + minTs = iv + } + } + } + + if minTs == math.MaxInt64 { + return 0, fmt.Errorf("no timestamps found in traceEvents") + } + return minTs, nil +} + +// outputPath computes the output file path by inserting the suffix before .json/.json.gz +// and always producing an uncompressed .json file. +func outputPath(traceFile, suffix string) string { + if strings.HasSuffix(traceFile, ".json.gz") { + base := traceFile[:len(traceFile)-len(".json.gz")] + return base + suffix + ".json" + } + if strings.HasSuffix(traceFile, ".json") { + base := traceFile[:len(traceFile)-len(".json")] + return base + suffix + ".json" + } + return traceFile + suffix + ".json" +} + +// writeRawTrace writes the trace map as indented JSON. 
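+// The enriched trace is always written as plain JSON (never gzipped), so it
+// can be opened directly in Perfetto or chrome://tracing.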
+func writeRawTrace(path string, data map[string]json.RawMessage) error { + out, err := json.MarshalIndent(data, "", " ") + if err != nil { + return err + } + return os.WriteFile(path, out, 0644) +} diff --git a/pkg/analysis/resource/trace_with_counters_test.go b/pkg/analysis/resource/trace_with_counters_test.go new file mode 100644 index 0000000..b3fb636 --- /dev/null +++ b/pkg/analysis/resource/trace_with_counters_test.go @@ -0,0 +1,165 @@ +package resource + +import ( + "encoding/json" + "os" + "testing" +) + +func TestGenerateTraceWithCounters(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + db := openSharedVTDB(t) + defer db.Close() + + outDir := t.TempDir() + opts := GenerateCountersOpts{ + Counters: CounterAll, + OutputSuffix: "_test_counters", + OutputDir: outDir, + } + + outputFiles, err := GenerateTraceWithCounters(db, opts) + if err != nil { + t.Fatalf("generate trace with counters: %v", err) + } + + if len(outputFiles) == 0 { + t.Fatal("expected at least one output file") + } + + for _, outFile := range outputFiles { + // Verify the file exists. + info, err := os.Stat(outFile) + if err != nil { + t.Fatalf("output file %s: %v", outFile, err) + } + if info.Size() == 0 { + t.Fatalf("output file %s is empty", outFile) + } + + // Verify it's valid JSON with traceEvents. + data, err := os.ReadFile(outFile) + if err != nil { + t.Fatalf("reading output file %s: %v", outFile, err) + } + + var raw map[string]json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + t.Fatalf("output file %s is not valid JSON: %v", outFile, err) + } + + eventsRaw, ok := raw["traceEvents"] + if !ok { + t.Fatalf("output file %s missing traceEvents", outFile) + } + + // Parse events and count counter events (ph == "C"). + var events []map[string]json.RawMessage + if err := json.Unmarshal(eventsRaw, &events); err != nil { + t.Fatalf("parsing traceEvents: %v", err) + } + + counterCount := 0 + for _, ev := range events { + phRaw, ok := ev["ph"] + if !ok { + continue + } + var ph string + if err := json.Unmarshal(phRaw, &ph); err != nil { + continue + } + if ph == "C" { + counterCount++ + } + } + + if counterCount == 0 { + t.Errorf("output file %s: expected counter events (ph=C), found none", outFile) + } + t.Logf("output file %s: %d total events, %d counter events", outFile, len(events), counterCount) + } +} + +func TestGenerateTraceWithCountersQueueOnly(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + db := openSharedVTDB(t) + defer db.Close() + + outDir := t.TempDir() + opts := GenerateCountersOpts{ + Counters: CounterQueueLength, + OutputSuffix: "_test_ql", + OutputDir: outDir, + } + + outputFiles, err := GenerateTraceWithCounters(db, opts) + if err != nil { + t.Fatalf("generate trace with counters (queue only): %v", err) + } + + for _, outFile := range outputFiles { + data, err := os.ReadFile(outFile) + if err != nil { + t.Fatalf("reading %s: %v", outFile, err) + } + + var raw map[string]json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + t.Fatalf("%s not valid JSON: %v", outFile, err) + } + + var events []map[string]json.RawMessage + if err := json.Unmarshal(raw["traceEvents"], &events); err != nil { + t.Fatalf("parsing traceEvents: %v", err) + } + + // Verify only queue length counters exist. 
+ for _, ev := range events { + phRaw, ok := ev["ph"] + if !ok { + continue + } + var ph string + if json.Unmarshal(phRaw, &ph) != nil || ph != "C" { + continue + } + var name string + if err := json.Unmarshal(ev["name"], &name); err != nil { + t.Errorf("counter event missing name") + continue + } + if name != "Queue Length" { + t.Errorf("expected only Queue Length counters, got %q", name) + } + } + } +} + +func TestOutputPath(t *testing.T) { + t.Parallel() + tests := []struct { + traceFile string + suffix string + want string + }{ + {"/data/rank_0.json.gz", "_with_counters", "/data/rank_0_with_counters.json"}, + {"/data/rank_0.json", "_with_counters", "/data/rank_0_with_counters.json"}, + {"/data/trace.json.gz", "_ql", "/data/trace_ql.json"}, + {"/data/trace.json", "_ql", "/data/trace_ql.json"}, + {"/data/trace", "_suffix", "/data/trace_suffix.json"}, + } + + for _, tt := range tests { + got := outputPath(tt.traceFile, tt.suffix) + if got != tt.want { + t.Errorf("outputPath(%q, %q) = %q, want %q", tt.traceFile, tt.suffix, got, tt.want) + } + } +} diff --git a/pkg/analysis/straggler/straggler.go b/pkg/analysis/straggler/straggler.go new file mode 100644 index 0000000..5232983 --- /dev/null +++ b/pkg/analysis/straggler/straggler.go @@ -0,0 +1,417 @@ +package straggler + +import ( + "database/sql" + "fmt" + "math" + "sort" + "strconv" + "strings" + + "hta/pkg/analysis" + "hta/pkg/store" +) + +// StragglerOpts configures the potential stragglers analysis. +type StragglerOpts struct { + NumCandidates int // top K straggler candidates (default 2; values < 1 are clamped to 1) + ProfilerSteps []int // nil = all available profiler steps +} + +// StragglerResult holds a rank and how many times it was detected as a straggler. +type StragglerResult struct { + Rank int + Count int +} + +type stepBoundary struct { + step int + start int64 + end int64 +} + +// PotentialStragglers identifies ranks that may be stragglers by analyzing +// cross-rank communication kernel timing. Ranks with late-starting critical +// comm kernels are flagged. +func PotentialStragglers(db *sql.DB, opts StragglerOpts) ([]StragglerResult, error) { + numCandidates := max(opts.NumCandidates, 1) + + // 1. Load symbol table and get profiler steps. + sym, err := store.LoadSymbolTable(db) + if err != nil { + return nil, fmt.Errorf("loading symbol table: %w", err) + } + + allSteps, err := analysis.ProfilerSteps(db) + if err != nil { + return nil, fmt.Errorf("loading profiler steps: %w", err) + } + + // 2. Validate/select profiler steps. + selectedSteps := opts.ProfilerSteps + if selectedSteps == nil { + selectedSteps = allSteps + } else { + allStepsSet := make(map[int]bool, len(allSteps)) + for _, s := range allSteps { + allStepsSet[s] = true + } + var valid []int + for _, s := range selectedSteps { + if allStepsSet[s] { + valid = append(valid, s) + } + } + if len(valid) == 0 { + return nil, fmt.Errorf("no valid profiler steps found in %v; available: %v", opts.ProfilerSteps, allSteps) + } + selectedSteps = valid + } + selectedStepsSet := make(map[int]bool, len(selectedSteps)) + for _, s := range selectedSteps { + selectedStepsSet[s] = true + } + + // 3. Find profiler step name IDs from symbol table. + var profilerStepNameIDs []int + for _, entry := range sym.All() { + if analysis.ProfilerStepRe.MatchString(entry.Name) { + profilerStepNameIDs = append(profilerStepNameIDs, entry.ID) + } + } + if len(profilerStepNameIDs) == 0 { + return nil, nil + } + + // 4. Load profiler step events and build boundaries per rank. 
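+	// Each matching event (e.g. a "ProfilerStep#17" annotation) contributes one
+	// {step, start, end} boundary to its rank's list; the list arrives sorted by
+	// start time because LoadProfilerStepEvents orders by (rank, started_at).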
+ stepEvents, err := store.LoadProfilerStepEvents(db, profilerStepNameIDs) + if err != nil { + return nil, fmt.Errorf("loading profiler step events: %w", err) + } + + boundaries := make(map[int][]stepBoundary) // rank → sorted boundaries + for _, e := range stepEvents { + name, err := sym.GetName(e.NameID) + if err != nil { + continue + } + m := analysis.ProfilerStepRe.FindStringSubmatch(name) + if m == nil { + continue + } + stepNum, err := strconv.Atoi(m[1]) + if err != nil { + continue + } + boundaries[e.Rank] = append(boundaries[e.Rank], stepBoundary{ + step: stepNum, + start: e.StartedAt, + end: e.EndedAt, + }) + } + + // 5. Find ncclKernel name IDs (matching Python's exact filter). + commNameIDs := make(map[int]bool) + for _, entry := range sym.All() { + if strings.HasPrefix(entry.Name, "ncclKernel") { + commNameIDs[entry.ID] = true + } + } + + // 6. Load all GPU kernels across ranks. + gpuKernels, err := store.LoadAllGPUKernelsFull(db) + if err != nil { + return nil, fmt.Errorf("loading GPU kernels: %w", err) + } + if len(gpuKernels) == 0 { + return nil, nil + } + + // 7. Assign iterations and compute global time bounds. + type annotatedKernel struct { + rank int + startedAt int64 + endedAt int64 + nameID int + streamID int + iteration int + } + + nIters := len(selectedSteps) + var globalMin, globalMax int64 + globalMin = math.MaxInt64 + + var allAnnotated []annotatedKernel + for _, k := range gpuKernels { + bounds, ok := boundaries[k.Rank] + if !ok { + continue + } + iter := assignIteration(k.StartedAt, bounds) + if iter < 0 { + continue + } + if !selectedStepsSet[iter] { + continue + } + allAnnotated = append(allAnnotated, annotatedKernel{ + rank: k.Rank, + startedAt: k.StartedAt, + endedAt: k.EndedAt, + nameID: k.NameID, + streamID: k.StreamID, + iteration: iter, + }) + if k.StartedAt < globalMin { + globalMin = k.StartedAt + } + if k.EndedAt > globalMax { + globalMax = k.EndedAt + } + } + + if len(allAnnotated) == 0 || nIters == 0 { + return nil, nil + } + + meanIterTime := float64(globalMax-globalMin) / float64(nIters) + if meanIterTime <= 0 { + return nil, nil + } + minDuration := meanIterTime * 0.01 + + // 8. Filter to comm kernels with stream > 0, iteration > 0, duration >= threshold. + var commKernels []annotatedKernel + for _, k := range allAnnotated { + if !commNameIDs[k.nameID] { + continue + } + if k.streamID <= 0 { + continue + } + if k.iteration <= 0 { + continue + } + dur := float64(k.endedAt - k.startedAt) + if dur < minDuration { + continue + } + commKernels = append(commKernels, k) + } + + if len(commKernels) == 0 { + return nil, nil + } + + // 9. Group by (rank, stream, iteration, name) → take last kernel per group. + // Data is already sorted by (rank, started_at), so last seen per group is the latest. + type groupKey struct { + rank int + stream int + iteration int + nameID int + } + lastKernels := make(map[groupKey]annotatedKernel) + for _, k := range commKernels { + key := groupKey{k.rank, k.streamID, k.iteration, k.nameID} + lastKernels[key] = k // overwrites with later entry (sorted by started_at) + } + + // 10. Compute normalized values. 
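+	// Start offsets and durations are divided by meanIterTime so they are
+	// expressed in (approximate) iterations rather than microseconds:
+	//   normalizedStartTime = (started_at - globalMin) / meanIterTime
+	//   normalizedDuration  = (ended_at - started_at) / meanIterTime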
+ type metricEntry struct { + rank int + stream int + iteration int + nameID int + normalizedStartTime float64 + normalizedDuration float64 + } + entries := make([]metricEntry, 0, len(lastKernels)) + for _, k := range lastKernels { + entries = append(entries, metricEntry{ + rank: k.rank, + stream: k.streamID, + iteration: k.iteration, + nameID: k.nameID, + normalizedStartTime: float64(k.startedAt-globalMin) / meanIterTime, + normalizedDuration: float64(k.endedAt-k.startedAt) / meanIterTime, + }) + } + + // 11. Find best indicator (stream, name): + // Group by (stream, iteration, name) → std dev of normalizedDuration across ranks. + type streamIterName struct { + stream int + iteration int + nameID int + } + durationsByGroup := make(map[streamIterName][]float64) + for _, e := range entries { + key := streamIterName{e.stream, e.iteration, e.nameID} + durationsByGroup[key] = append(durationsByGroup[key], e.normalizedDuration) + } + + // Compute std dev per (stream, iteration, name) group. + type streamName struct { + stream int + nameID int + } + stdDevsByStreamName := make(map[streamName][]float64) + for key, vals := range durationsByGroup { + sd := stddev(vals) + if math.IsNaN(sd) { + continue // skip single-element groups (NaN), matching pandas .mean() skipping NaN + } + sn := streamName{key.stream, key.nameID} + stdDevsByStreamName[sn] = append(stdDevsByStreamName[sn], sd) + } + + if len(stdDevsByStreamName) == 0 { + return nil, nil + } + + // Compute mean std dev per (stream, name) and find the best indicator. + var bestSN streamName + bestMean := math.Inf(-1) + for sn, sds := range stdDevsByStreamName { + var sum float64 + for _, v := range sds { + sum += v + } + mean := sum / float64(len(sds)) + if mean > bestMean { + bestMean = mean + bestSN = sn + } + } + + // 12. Filter to candidate kernels matching best (stream, name). + var candidates []metricEntry + for _, e := range entries { + if e.stream == bestSN.stream && e.nameID == bestSN.nameID { + candidates = append(candidates, e) + } + } + + if len(candidates) == 0 { + return nil, nil + } + + // 13. Top-K straggler selection. + iterations := make(map[int]bool) + for _, c := range candidates { + iterations[c.iteration] = true + } + + nIterations := len(iterations) + + if nIterations <= 1 { + // Single iteration: sort by normalizedStartTime desc, threshold = k-th value. + sort.Slice(candidates, func(i, j int) bool { + return candidates[i].normalizedStartTime > candidates[j].normalizedStartTime + }) + threshIdx := numCandidates - 1 + if threshIdx >= len(candidates) { + threshIdx = len(candidates) - 1 + } + threshold := candidates[threshIdx].normalizedStartTime + + stragglerSet := make(map[int]int) // rank → count + for _, c := range candidates { + if c.normalizedStartTime >= threshold { + stragglerSet[c.rank] = 1 + } + } + return buildStragglerResults(stragglerSet, numCandidates), nil + } + + // Multiple iterations: per-iteration thresholding, sum straggler counts. 
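+	// Illustrative example (hypothetical numbers): with numCandidates=2 and 8
+	// ranks in an iteration, the threshold is the 2nd-largest
+	// normalizedStartTime in that iteration; every rank at or above the
+	// threshold has its straggler count incremented, and counts accumulate
+	// across iterations.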
+ stragglerCounts := make(map[int]int) // rank → count + iterCandidates := make(map[int][]metricEntry) + for _, c := range candidates { + iterCandidates[c.iteration] = append(iterCandidates[c.iteration], c) + } + for iter, group := range iterCandidates { + if iter < 0 { + continue + } + sort.Slice(group, func(i, j int) bool { + return group[i].normalizedStartTime > group[j].normalizedStartTime + }) + threshIdx := numCandidates - 1 + if threshIdx >= len(group) { + threshIdx = len(group) - 1 + } + threshold := group[threshIdx].normalizedStartTime + for _, c := range group { + if c.normalizedStartTime >= threshold { + stragglerCounts[c.rank]++ + } + } + } + + return buildStragglerResults(stragglerCounts, numCandidates), nil +} + +// buildStragglerResults converts the rank→count map to a sorted slice of top N results. +func buildStragglerResults(counts map[int]int, numCandidates int) []StragglerResult { + results := make([]StragglerResult, 0, len(counts)) + for rank, count := range counts { + if count > 0 { + results = append(results, StragglerResult{Rank: rank, Count: count}) + } + } + // Sort by count descending, then rank ascending for ties. + sort.Slice(results, func(i, j int) bool { + if results[i].Count != results[j].Count { + return results[i].Count > results[j].Count + } + return results[i].Rank < results[j].Rank + }) + if len(results) > numCandidates { + results = results[:numCandidates] + } + return results +} + +// assignIteration finds which profiler step a GPU kernel belongs to using binary search. +// Returns the step number, or -1 if the kernel doesn't fall within any step boundary. +func assignIteration(startedAt int64, boundaries []stepBoundary) int { + // boundaries are sorted by start time. + // Find the last boundary where start <= startedAt. + n := len(boundaries) + lo, hi := 0, n-1 + for lo <= hi { + mid := (lo + hi) / 2 + if boundaries[mid].start <= startedAt { + lo = mid + 1 + } else { + hi = mid - 1 + } + } + // hi is the index of the last boundary with start <= startedAt. + if hi >= 0 && startedAt <= boundaries[hi].end { + return boundaries[hi].step + } + return -1 +} + +// stddev computes sample standard deviation (ddof=1), matching pandas default. +// Returns NaN for slices with fewer than 2 elements. 
+func stddev(vals []float64) float64 { + n := len(vals) + if n < 2 { + return math.NaN() + } + var sum float64 + for _, v := range vals { + sum += v + } + mean := sum / float64(n) + var sumSqDiff float64 + for _, v := range vals { + diff := v - mean + sumSqDiff += diff * diff + } + return math.Sqrt(sumSqDiff / float64(n-1)) +} diff --git a/pkg/analysis/straggler/straggler_test.go b/pkg/analysis/straggler/straggler_test.go new file mode 100644 index 0000000..fcf8b94 --- /dev/null +++ b/pkg/analysis/straggler/straggler_test.go @@ -0,0 +1,138 @@ +package straggler + +import ( + "os" + "path/filepath" + "runtime" + "sort" + "testing" + + "hta/pkg/analysis" + "hta/pkg/pipeline" + "hta/pkg/store" +) + +func testDataDir(t *testing.T) string { + t.Helper() + _, thisFile, _, _ := runtime.Caller(0) + // pkg/analysis/straggler/straggler_test.go → project root + root := filepath.Join(filepath.Dir(thisFile), "..", "..", "..") + dir := filepath.Join(root, "tests", "data", "vision_transformer") + if _, err := os.Stat(dir); err != nil { + t.Skipf("test data not found: %s", dir) + } + return dir +} + +func TestPotentialStragglers(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + traceDir := testDataDir(t) + dbPath := filepath.Join(t.TempDir(), "test.db") + + db, err := store.Create(dbPath) + if err != nil { + t.Fatalf("create db: %v", err) + } + defer db.Close() + + if err := pipeline.RunWithDB(traceDir, db); err != nil { + t.Fatalf("preprocess: %v", err) + } + + // Get available profiler steps (expected: 15-19 for vision_transformer). + allSteps, err := analysis.ProfilerSteps(db) + if err != nil { + t.Fatalf("profiler steps: %v", err) + } + + tests := []struct { + name string + steps []int + numCandidates int + wantRanks []int + }{ + { + name: "first step only, num_candidates=-1 (clamped to 1)", + steps: allSteps[:1], + numCandidates: -1, + wantRanks: []int{7}, + }, + { + name: "first step only, num_candidates=2", + steps: allSteps[:1], + numCandidates: 2, + wantRanks: []int{6, 7}, + }, + { + name: "all steps, num_candidates=2", + steps: allSteps, + numCandidates: 2, + wantRanks: []int{0, 1}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + results, err := PotentialStragglers(db, StragglerOpts{ + NumCandidates: tc.numCandidates, + ProfilerSteps: tc.steps, + }) + if err != nil { + t.Fatalf("PotentialStragglers: %v", err) + } + + gotRanks := make([]int, len(results)) + for i, r := range results { + gotRanks[i] = r.Rank + } + sort.Ints(gotRanks) + + if len(gotRanks) != len(tc.wantRanks) { + t.Fatalf("got ranks %v, want %v", gotRanks, tc.wantRanks) + } + for i := range gotRanks { + if gotRanks[i] != tc.wantRanks[i] { + t.Errorf("rank[%d]: got %d, want %d", i, gotRanks[i], tc.wantRanks[i]) + } + } + }) + } +} + +func TestPotentialStragglersSingleRank(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + _, thisFile, _, _ := runtime.Caller(0) + root := filepath.Join(filepath.Dir(thisFile), "..", "..", "..") + traceDir := filepath.Join(root, "tests", "data", "mtia_trace_single_rank") + + if _, err := os.Stat(traceDir); err != nil { + t.Skipf("test data not found: %s", traceDir) + } + + dbPath := filepath.Join(t.TempDir(), "test.db") + db, err := store.Create(dbPath) + if err != nil { + t.Fatalf("create db: %v", err) + } + defer db.Close() + + if err := pipeline.RunWithDB(traceDir, db); err != nil { + t.Skipf("preprocess not supported for this trace: %v", err) + } + + results, err := 
PotentialStragglers(db, StragglerOpts{NumCandidates: 2}) + if err != nil { + t.Fatalf("PotentialStragglers: %v", err) + } + + // Single rank: no meaningful straggler detection possible. + if len(results) > 1 { + t.Errorf("expected at most 1 result for single rank, got %d", len(results)) + } +} diff --git a/pkg/analysis/temporal/idle_time.go b/pkg/analysis/temporal/idle_time.go new file mode 100644 index 0000000..3515cd6 --- /dev/null +++ b/pkg/analysis/temporal/idle_time.go @@ -0,0 +1,226 @@ +package temporal + +import ( + "database/sql" + "fmt" + "math" + "slices" + + "hta/pkg/analysis" + "hta/pkg/store" +) + +// IdleTimeOpts configures idle-time breakdown analysis. +type IdleTimeOpts struct { + Ranks []int // nil = all ranks + Streams []int // nil = all streams + ConsecutiveKernelDelay int64 // threshold in microseconds, default 30 + ShowIdleIntervalStats bool +} + +// IdleTimeResult holds idle-time breakdown for one (rank, stream, category) triple. +type IdleTimeResult struct { + Rank int + Stream int + IdleCategory string // "host_wait", "kernel_wait", "other" + IdleTimeUs int64 // total idle time in microseconds + IdleTimeRatio float64 // fraction of total idle time on this stream +} + +// IdleIntervalStats holds descriptive statistics for idle intervals. +type IdleIntervalStats struct { + Rank int + Stream int + IdleCategory string + Count int + Mean float64 + Std float64 + Min int64 + Pct25 float64 + Pct50 float64 + Pct75 float64 + Max int64 +} + +// IdleTimeBreakdown computes the idle-time breakdown for all (or selected) ranks. +func IdleTimeBreakdown(db *sql.DB, opts IdleTimeOpts) ([]IdleTimeResult, []IdleIntervalStats, error) { + if opts.ConsecutiveKernelDelay == 0 { + opts.ConsecutiveKernelDelay = 30 + } + + ranks := opts.Ranks + if len(ranks) == 0 { + var err error + ranks, err = store.Ranks(db) + if err != nil { + return nil, nil, fmt.Errorf("loading ranks: %w", err) + } + } + + var allResults []IdleTimeResult + var allStats []IdleIntervalStats + for _, rank := range ranks { + results, stats, err := idleTimePerRank(db, rank, opts) + if err != nil { + return nil, nil, fmt.Errorf("rank %d: %w", rank, err) + } + allResults = append(allResults, results...) + allStats = append(allStats, stats...) + } + return allResults, allStats, nil +} + +func idleTimePerRank(db *sql.DB, rank int, opts IdleTimeOpts) ([]IdleTimeResult, []IdleIntervalStats, error) { + kernels, err := store.LoadGPUKernelsForIdleTime(db, rank, opts.Streams) + if err != nil { + return nil, nil, err + } + if len(kernels) == 0 { + return nil, nil, nil + } + + // Group kernels by stream (they come sorted by stream, then started_at) + streamGroups := make(map[int][]store.GPUKernelIdleTimeRow) + var streamOrder []int + for _, k := range kernels { + if _, exists := streamGroups[k.CUDAStreamID]; !exists { + streamOrder = append(streamOrder, k.CUDAStreamID) + } + streamGroups[k.CUDAStreamID] = append(streamGroups[k.CUDAStreamID], k) + } + + var allResults []IdleTimeResult + var allStats []IdleIntervalStats + for _, stream := range streamOrder { + results, stats := classifyIdleTimeForStream(rank, stream, streamGroups[stream], opts) + allResults = append(allResults, results...) + allStats = append(allStats, stats...) 
+ } + return allResults, allStats, nil +} + +const ( + categoryHostWait = "host_wait" + categoryKernelWait = "kernel_wait" + categoryOther = "other" +) + +func classifyIdleTimeForStream( + rank, stream int, + kernels []store.GPUKernelIdleTimeRow, + opts IdleTimeOpts, +) ([]IdleTimeResult, []IdleIntervalStats) { + // Accumulate idle time per category + var hostWaitSum, kernelWaitSum, otherSum int64 + var hostWaitGaps, kernelWaitGaps, otherGaps []int64 + + for i := 1; i < len(kernels); i++ { + gap := kernels[i].StartedAt - kernels[i-1].EndedAt + if gap <= 0 { + continue // overlapping kernels + } + + // Classify the gap + if kernels[i].RuntimeTs > kernels[i-1].EndedAt { + hostWaitSum += gap + if opts.ShowIdleIntervalStats { + hostWaitGaps = append(hostWaitGaps, gap) + } + } else if gap < opts.ConsecutiveKernelDelay { + kernelWaitSum += gap + if opts.ShowIdleIntervalStats { + kernelWaitGaps = append(kernelWaitGaps, gap) + } + } else { + otherSum += gap + if opts.ShowIdleIntervalStats { + otherGaps = append(otherGaps, gap) + } + } + } + + totalIdle := hostWaitSum + kernelWaitSum + otherSum + + ratioOf := func(v int64) float64 { + if totalIdle == 0 { + return 0 + } + return analysis.RoundTo(float64(v)/float64(totalIdle), 2) + } + + results := []IdleTimeResult{ + {Rank: rank, Stream: stream, IdleCategory: categoryHostWait, IdleTimeUs: hostWaitSum, IdleTimeRatio: ratioOf(hostWaitSum)}, + {Rank: rank, Stream: stream, IdleCategory: categoryKernelWait, IdleTimeUs: kernelWaitSum, IdleTimeRatio: ratioOf(kernelWaitSum)}, + {Rank: rank, Stream: stream, IdleCategory: categoryOther, IdleTimeUs: otherSum, IdleTimeRatio: ratioOf(otherSum)}, + } + + var stats []IdleIntervalStats + if opts.ShowIdleIntervalStats { + stats = []IdleIntervalStats{ + computeIntervalStats(rank, stream, categoryHostWait, hostWaitGaps), + computeIntervalStats(rank, stream, categoryKernelWait, kernelWaitGaps), + computeIntervalStats(rank, stream, categoryOther, otherGaps), + } + } + + return results, stats +} + +func computeIntervalStats(rank, stream int, category string, gaps []int64) IdleIntervalStats { + s := IdleIntervalStats{ + Rank: rank, + Stream: stream, + IdleCategory: category, + Count: len(gaps), + } + if len(gaps) == 0 { + return s + } + + slices.Sort(gaps) + + s.Min = gaps[0] + s.Max = gaps[len(gaps)-1] + + var sum int64 + for _, g := range gaps { + sum += g + } + mean := float64(sum) / float64(len(gaps)) + s.Mean = analysis.RoundTo(mean, 2) + + var variance float64 + for _, g := range gaps { + d := float64(g) - mean + variance += d * d + } + // Use sample std (ddof=1) matching pandas default + if len(gaps) > 1 { + s.Std = analysis.RoundTo(math.Sqrt(variance/float64(len(gaps)-1)), 2) + } + + s.Pct25 = analysis.RoundTo(percentile(gaps, 0.25), 2) + s.Pct50 = analysis.RoundTo(percentile(gaps, 0.50), 2) + s.Pct75 = analysis.RoundTo(percentile(gaps, 0.75), 2) + + return s +} + +// percentile computes the p-th percentile using linear interpolation (matching numpy/pandas). 
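+// For example, percentile([10, 20, 30, 40], 0.25) interpolates between the
+// first two elements: idx = 0.75, so the result is 10*0.25 + 20*0.75 = 17.5.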
+func percentile(sorted []int64, p float64) float64 { + n := len(sorted) + if n == 0 { + return 0 + } + if n == 1 { + return float64(sorted[0]) + } + idx := p * float64(n-1) + lo := int(idx) + hi := lo + 1 + if hi >= n { + return float64(sorted[n-1]) + } + frac := idx - float64(lo) + return float64(sorted[lo])*(1-frac) + float64(sorted[hi])*frac +} diff --git a/pkg/analysis/temporal/idle_time_test.go b/pkg/analysis/temporal/idle_time_test.go new file mode 100644 index 0000000..187fdc7 --- /dev/null +++ b/pkg/analysis/temporal/idle_time_test.go @@ -0,0 +1,128 @@ +package temporal + +import ( + "math" + "path/filepath" + "testing" + + "hta/pkg/pipeline" + "hta/pkg/store" +) + +func TestIdleTimeBreakdownIntegration(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + traceDir := testDataDir(t) + dbPath := filepath.Join(t.TempDir(), "test.db") + + db, err := store.Create(dbPath) + if err != nil { + t.Fatalf("create db: %v", err) + } + defer db.Close() + + if err := pipeline.RunWithDB(traceDir, db); err != nil { + t.Fatalf("preprocess: %v", err) + } + + opts := IdleTimeOpts{ + ConsecutiveKernelDelay: 30, + ShowIdleIntervalStats: true, + } + results, stats, err := IdleTimeBreakdown(db, opts) + if err != nil { + t.Fatalf("idle time breakdown: %v", err) + } + + if len(results) == 0 { + t.Fatal("expected non-empty results") + } + + // Python reference values for rank 0 (from running Python idle-time-breakdown) + type expected struct { + rank int + stream int + category string + idleTime int64 + } + expectations := []expected{ + {0, 7, "host_wait", 1000581}, + {0, 7, "kernel_wait", 30984}, + {0, 7, "other", 385099}, + {0, 20, "host_wait", 1694651}, + {0, 24, "host_wait", 1382259}, + {0, 24, "other", 420773}, + {1, 7, "host_wait", 942050}, + {1, 7, "kernel_wait", 36343}, + {1, 7, "other", 437053}, + } + + findResult := func(rank, stream int, category string) *IdleTimeResult { + for i := range results { + if results[i].Rank == rank && results[i].Stream == stream && results[i].IdleCategory == category { + return &results[i] + } + } + return nil + } + + for _, exp := range expectations { + r := findResult(exp.rank, exp.stream, exp.category) + if r == nil { + t.Errorf("result not found: rank=%d stream=%d category=%s", exp.rank, exp.stream, exp.category) + continue + } + if r.IdleTimeUs != exp.idleTime { + t.Errorf("rank=%d stream=%d category=%s: idle_time got %d, want %d", + exp.rank, exp.stream, exp.category, r.IdleTimeUs, exp.idleTime) + } + } + + // Verify ratios sum to ~1.0 per stream for each rank + type streamKey struct { + rank, stream int + } + ratioSums := make(map[streamKey]float64) + for _, r := range results { + ratioSums[streamKey{r.Rank, r.Stream}] += r.IdleTimeRatio + } + for key, sum := range ratioSums { + if math.Abs(sum-1.0) > 0.02 { + t.Errorf("rank=%d stream=%d: ratio sum = %.4f, want ~1.0", key.rank, key.stream, sum) + } + } + + // Verify we have 6 unique streams per rank + streamsPerRank := make(map[int]map[int]bool) + for _, r := range results { + if streamsPerRank[r.Rank] == nil { + streamsPerRank[r.Rank] = make(map[int]bool) + } + streamsPerRank[r.Rank][r.Stream] = true + } + for rank, streams := range streamsPerRank { + if len(streams) != 6 { + t.Errorf("rank %d: expected 6 streams, got %d", rank, len(streams)) + } + } + + // Verify interval stats are non-empty when requested + if len(stats) == 0 { + t.Error("expected non-empty interval stats when ShowIdleIntervalStats=true") + } + + // Check a specific stat: rank 0, stream 7, 
host_wait + for _, s := range stats { + if s.Rank == 0 && s.Stream == 7 && s.IdleCategory == "host_wait" { + if s.Count == 0 { + t.Error("rank=0 stream=7 host_wait: expected non-zero count") + } + if s.Mean <= 0 { + t.Error("rank=0 stream=7 host_wait: expected positive mean") + } + break + } + } +} diff --git a/pkg/analysis/temporal/overlap.go b/pkg/analysis/temporal/overlap.go new file mode 100644 index 0000000..15fa294 --- /dev/null +++ b/pkg/analysis/temporal/overlap.go @@ -0,0 +1,126 @@ +package temporal + +import ( + "database/sql" + "fmt" + "sort" + + "hta/pkg/analysis" + "hta/pkg/store" +) + +// OverlapResult holds the comm-comp overlap percentage for a single rank. +type OverlapResult struct { + Rank int + OverlapPctg float64 +} + +// CommCompOverlap computes the overlap between communication and computation +// GPU kernels for all ranks using a sweep-line algorithm. +func CommCompOverlap(db *sql.DB) ([]OverlapResult, error) { + sym, err := store.LoadSymbolTable(db) + if err != nil { + return nil, fmt.Errorf("loading symbol table: %w", err) + } + + ranks, err := store.Ranks(db) + if err != nil { + return nil, fmt.Errorf("loading ranks: %w", err) + } + + results := make([]OverlapResult, 0, len(ranks)) + for _, rank := range ranks { + r, err := commCompOverlapPerRank(db, sym, rank) + if err != nil { + return nil, fmt.Errorf("rank %d: %w", rank, err) + } + results = append(results, r) + } + return results, nil +} + +// statusEvent represents a point event for the sweep-line algorithm. +// delta is +1/-1 for comm start/end, +2/-2 for comp start/end. +type statusEvent struct { + time int64 + delta int +} + +func commCompOverlapPerRank(db *sql.DB, sym interface{ GetName(int) (string, error) }, rank int) (OverlapResult, error) { + kernels, err := store.LoadGPUKernels(db, rank) + if err != nil { + return OverlapResult{}, err + } + + // Classify kernels into comm and comp intervals. + var commIntervals, compIntervals []analysis.Interval + for _, k := range kernels { + name, err := sym.GetName(k.NameID) + if err != nil { + return OverlapResult{}, fmt.Errorf("symbol lookup: %w", err) + } + iv := analysis.Interval{Start: k.StartedAt, End: k.EndedAt} + switch analysis.ClassifyKernel(name) { + case analysis.KernelCommunication: + commIntervals = append(commIntervals, iv) + case analysis.KernelComputation: + compIntervals = append(compIntervals, iv) + } + } + + // Already sorted by started_at from SQL ORDER BY. + mergedComm := analysis.MergeIntervals(commIntervals) + if len(mergedComm) == 0 { + return OverlapResult{Rank: rank, OverlapPctg: 0}, nil + } + + sort.Slice(compIntervals, func(i, j int) bool { + return compIntervals[i].Start < compIntervals[j].Start + }) + mergedComp := analysis.MergeIntervals(compIntervals) + if len(mergedComp) == 0 { + return OverlapResult{Rank: rank, OverlapPctg: 0}, nil + } + + // Total comm time (denominator). + var totalCommTime int64 + for _, iv := range mergedComm { + totalCommTime += iv.End - iv.Start + } + if totalCommTime == 0 { + return OverlapResult{Rank: rank, OverlapPctg: 0}, nil + } + + // Build sweep-line events. 
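+	// Worked example (hypothetical intervals): merged comm [0,10) and merged
+	// comp [5,15) produce events (0,+1) (5,+2) (10,-1) (15,-2); the running
+	// status equals 3 only on [5,10), so overlapTime = 5us, i.e. 50% of the
+	// 10us of communication time.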
+ events := make([]statusEvent, 0, 2*(len(mergedComm)+len(mergedComp))) + for _, iv := range mergedComm { + events = append(events, statusEvent{time: iv.Start, delta: 1}) + events = append(events, statusEvent{time: iv.End, delta: -1}) + } + for _, iv := range mergedComp { + events = append(events, statusEvent{time: iv.Start, delta: 2}) + events = append(events, statusEvent{time: iv.End, delta: -2}) + } + + // Sort by time; tie-break: end-events (negative delta) before start-events + // to avoid false overlap at boundaries. + sort.SliceStable(events, func(i, j int) bool { + if events[i].time != events[j].time { + return events[i].time < events[j].time + } + return events[i].delta < events[j].delta + }) + + // Sweep: running status == 3 means both comm and comp are active. + var running int + var overlapTime int64 + for i, ev := range events { + running += ev.delta + if running == 3 && i+1 < len(events) { + overlapTime += events[i+1].time - ev.time + } + } + + pctg := analysis.RoundTo(100*float64(overlapTime)/float64(totalCommTime), 2) + return OverlapResult{Rank: rank, OverlapPctg: pctg}, nil +} diff --git a/pkg/analysis/temporal/overlap_test.go b/pkg/analysis/temporal/overlap_test.go new file mode 100644 index 0000000..4fcdec2 --- /dev/null +++ b/pkg/analysis/temporal/overlap_test.go @@ -0,0 +1,65 @@ +package temporal + +import ( + "math" + "path/filepath" + "testing" + + "hta/pkg/pipeline" + "hta/pkg/store" +) + +func TestCommCompOverlapIntegration(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + traceDir := testDataDir(t) + dbPath := filepath.Join(t.TempDir(), "test.db") + + db, err := store.Create(dbPath) + if err != nil { + t.Fatalf("create db: %v", err) + } + defer db.Close() + + if err := pipeline.RunWithDB(traceDir, db); err != nil { + t.Fatalf("preprocess: %v", err) + } + + results, err := CommCompOverlap(db) + if err != nil { + t.Fatalf("comm comp overlap: %v", err) + } + + if len(results) < 2 { + t.Fatalf("expected at least 2 ranks, got %d", len(results)) + } + + // Expected values from Python reference. + expectations := []struct { + rank int + overlapPctg float64 + }{ + {0, 22.01}, + {1, 21.28}, + } + + for _, exp := range expectations { + var r *OverlapResult + for i := range results { + if results[i].Rank == exp.rank { + r = &results[i] + break + } + } + if r == nil { + t.Errorf("rank %d not found in results", exp.rank) + continue + } + + if math.Abs(r.OverlapPctg-exp.overlapPctg) > 0.05 { + t.Errorf("rank %d overlap_pctg: got %.2f, want %.2f", exp.rank, r.OverlapPctg, exp.overlapPctg) + } + } +} diff --git a/pkg/analysis/temporal/temporal.go b/pkg/analysis/temporal/temporal.go new file mode 100644 index 0000000..09b64bb --- /dev/null +++ b/pkg/analysis/temporal/temporal.go @@ -0,0 +1,103 @@ +package temporal + +import ( + "database/sql" + "fmt" + "sort" + + "hta/pkg/analysis" + "hta/pkg/store" + "hta/pkg/symbol" +) + +// TemporalResult holds the temporal breakdown for a single rank. +type TemporalResult struct { + Rank int + IdleTimeUs int64 + ComputeTimeUs int64 + NonComputeTimeUs int64 + KernelTimeUs int64 + IdleTimePctg float64 + ComputeTimePctg float64 + NonComputeTimePctg float64 +} + +// TemporalBreakdown computes the temporal breakdown for all ranks in the DB. 
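+// For each rank, kernel time is the span from the first GPU kernel start to the
+// last GPU kernel end; idle time is that span minus the union of kernel
+// intervals; compute time is the union of compute-kernel intervals; and
+// non-compute time is the remainder (kernel - compute - idle).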
+func TemporalBreakdown(db *sql.DB) ([]TemporalResult, error) { + sym, err := store.LoadSymbolTable(db) + if err != nil { + return nil, fmt.Errorf("loading symbol table: %w", err) + } + + ranks, err := store.Ranks(db) + if err != nil { + return nil, fmt.Errorf("loading ranks: %w", err) + } + + results := make([]TemporalResult, 0, len(ranks)) + for _, rank := range ranks { + r, err := temporalBreakdownPerRank(db, sym, rank) + if err != nil { + return nil, fmt.Errorf("rank %d: %w", rank, err) + } + results = append(results, r) + } + return results, nil +} + +func temporalBreakdownPerRank(db *sql.DB, sym *symbol.Table, rank int) (TemporalResult, error) { + kernels, err := store.LoadGPUKernels(db, rank) + if err != nil { + return TemporalResult{}, err + } + if len(kernels) == 0 { + return TemporalResult{Rank: rank}, nil + } + + // Build intervals for all GPU kernels and classify each one + allIntervals := make([]analysis.Interval, len(kernels)) + var computeIntervals []analysis.Interval + + for i, k := range kernels { + allIntervals[i] = analysis.Interval{Start: k.StartedAt, End: k.EndedAt} + + name, err := sym.GetName(k.NameID) + if err != nil { + return TemporalResult{}, fmt.Errorf("symbol lookup: %w", err) + } + if analysis.ClassifyKernel(name) == analysis.KernelComputation { + computeIntervals = append(computeIntervals, analysis.Interval{Start: k.StartedAt, End: k.EndedAt}) + } + } + + // Already sorted by started_at from SQL ORDER BY + mergedAll := analysis.MergeIntervals(allIntervals) + kernelTime := mergedAll[len(mergedAll)-1].End - mergedAll[0].Start + var kernelRunTime int64 + for _, iv := range mergedAll { + kernelRunTime += iv.End - iv.Start + } + idleTime := kernelTime - kernelRunTime + + // Merge compute intervals + sort.Slice(computeIntervals, func(i, j int) bool { + return computeIntervals[i].Start < computeIntervals[j].Start + }) + mergedCompute := analysis.MergeIntervals(computeIntervals) + var computeTime int64 + for _, iv := range mergedCompute { + computeTime += iv.End - iv.Start + } + nonComputeTime := kernelTime - computeTime - idleTime + + return TemporalResult{ + Rank: rank, + IdleTimeUs: idleTime, + ComputeTimeUs: computeTime, + NonComputeTimeUs: nonComputeTime, + KernelTimeUs: kernelTime, + IdleTimePctg: analysis.RoundTo(100*float64(idleTime)/float64(kernelTime), 2), + ComputeTimePctg: analysis.RoundTo(100*float64(computeTime)/float64(kernelTime), 2), + NonComputeTimePctg: analysis.RoundTo(100*float64(nonComputeTime)/float64(kernelTime), 2), + }, nil +} diff --git a/pkg/analysis/temporal/temporal_test.go b/pkg/analysis/temporal/temporal_test.go new file mode 100644 index 0000000..2536a80 --- /dev/null +++ b/pkg/analysis/temporal/temporal_test.go @@ -0,0 +1,105 @@ +package temporal + +import ( + "math" + "os" + "path/filepath" + "runtime" + "testing" + + "hta/pkg/pipeline" + "hta/pkg/store" +) + +func testDataDir(t *testing.T) string { + t.Helper() + _, thisFile, _, _ := runtime.Caller(0) + // pkg/analysis/temporal/temporal_test.go → project root + root := filepath.Join(filepath.Dir(thisFile), "..", "..", "..") + dir := filepath.Join(root, "tests", "data", "vision_transformer") + if _, err := os.Stat(dir); err != nil { + t.Skipf("test data not found: %s", dir) + } + return dir +} + +func TestTemporalBreakdownIntegration(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + traceDir := testDataDir(t) + dbPath := filepath.Join(t.TempDir(), "test.db") + + db, err := store.Create(dbPath) + if err != nil { + t.Fatalf("create db: 
%v", err) + } + defer db.Close() + + if err := pipeline.RunWithDB(traceDir, db); err != nil { + t.Fatalf("preprocess: %v", err) + } + + results, err := TemporalBreakdown(db) + if err != nil { + t.Fatalf("temporal breakdown: %v", err) + } + + if len(results) < 2 { + t.Fatalf("expected at least 2 ranks, got %d", len(results)) + } + + // Expected values from the plan (Python reference) + expectations := []struct { + rank int + idleTimeUs int64 + computeTimeUs int64 + nonComputeTimeUs int64 + kernelTimeUs int64 + }{ + {0, 552069, 596651, 884850, 2033570}, + {1, 431771, 596759, 1004227, 2032757}, + } + + for _, exp := range expectations { + var r *TemporalResult + for i := range results { + if results[i].Rank == exp.rank { + r = &results[i] + break + } + } + if r == nil { + t.Errorf("rank %d not found in results", exp.rank) + continue + } + + if r.IdleTimeUs != exp.idleTimeUs { + t.Errorf("rank %d idle_time: got %d, want %d", exp.rank, r.IdleTimeUs, exp.idleTimeUs) + } + if r.ComputeTimeUs != exp.computeTimeUs { + t.Errorf("rank %d compute_time: got %d, want %d", exp.rank, r.ComputeTimeUs, exp.computeTimeUs) + } + if r.NonComputeTimeUs != exp.nonComputeTimeUs { + t.Errorf("rank %d non_compute_time: got %d, want %d", exp.rank, r.NonComputeTimeUs, exp.nonComputeTimeUs) + } + if r.KernelTimeUs != exp.kernelTimeUs { + t.Errorf("rank %d kernel_time: got %d, want %d", exp.rank, r.KernelTimeUs, exp.kernelTimeUs) + } + + // Check percentages within tolerance + expIdlePctg := math.Round(float64(exp.idleTimeUs)*10000/float64(exp.kernelTimeUs)) / 100 + if math.Abs(r.IdleTimePctg-expIdlePctg) > 0.01 { + t.Errorf("rank %d idle_time_pctg: got %.2f, want %.2f", exp.rank, r.IdleTimePctg, expIdlePctg) + } + expComputePctg := math.Round(float64(exp.computeTimeUs)*10000/float64(exp.kernelTimeUs)) / 100 + if math.Abs(r.ComputeTimePctg-expComputePctg) > 0.01 { + t.Errorf("rank %d compute_time_pctg: got %.2f, want %.2f", exp.rank, r.ComputeTimePctg, expComputePctg) + } + expNonComputePctg := math.Round(float64(exp.nonComputeTimeUs)*10000/float64(exp.kernelTimeUs)) / 100 + if math.Abs(r.NonComputeTimePctg-expNonComputePctg) > 0.01 { + t.Errorf("rank %d non_compute_time_pctg: got %.2f, want %.2f", exp.rank, r.NonComputeTimePctg, expNonComputePctg) + } + } +} diff --git a/pkg/pipeline/preprocess.go b/pkg/pipeline/preprocess.go new file mode 100644 index 0000000..7fb1040 --- /dev/null +++ b/pkg/pipeline/preprocess.go @@ -0,0 +1,351 @@ +package pipeline + +import ( + "database/sql" + "fmt" + "log" + "math" + "regexp" + + "hta/pkg/store" + "hta/pkg/symbol" + "hta/pkg/trace" +) + +var profilerStepRe = regexp.MustCompile(`^ProfilerStep#\d+`) + +// Run executes the full pre-process pipeline: +// discover files → parse → build symbols → align → filter → write DB. 
+func Run(traceDir, dbPath string) error { + // Step 1: Discover trace files + files, err := trace.DiscoverFiles(traceDir) + if err != nil { + return err + } + if len(files) == 0 { + return fmt.Errorf("no trace files found in %s", traceDir) + } + log.Printf("Found %d trace file(s)", len(files)) + + // Step 2: Parse each file + rankTraces := make([]*trace.RankTrace, 0, len(files)) + for _, f := range files { + rt, err := trace.ParseFile(f) + if err != nil { + return fmt.Errorf("parsing %s: %w", f, err) + } + log.Printf("Rank %d: %d events from %s", rt.Meta.Rank, len(rt.Events), f) + rankTraces = append(rankTraces, rt) + } + + // Step 3: Build global symbol table + sym := symbol.NewTable() + for _, rt := range rankTraces { + for i := range rt.Events { + sym.Add(rt.Events[i].Cat) + sym.Add(rt.Events[i].Name) + } + } + log.Printf("Symbol table: %d symbols", sym.Len()) + + // Step 4: Align timestamps — find global min_ts + var minTs int64 = math.MaxInt64 + for _, rt := range rankTraces { + for _, e := range rt.Events { + if e.Ts < minTs { + minTs = e.Ts + } + } + } + log.Printf("Global min_ts: %d", minTs) + + // Step 5: Filter irrelevant GPU kernels + // Find profiler step name IDs + var profilerStepNameIDs []int + for _, s := range sym.All() { + if profilerStepRe.MatchString(s.Name) { + profilerStepNameIDs = append(profilerStepNameIDs, s.ID) + } + } + + needFiltering := len(profilerStepNameIDs) >= 2 + if !needFiltering { + log.Printf("Skipping GPU kernel filtering (%d profiler step name(s))", len(profilerStepNameIDs)) + } + + cuptiCatID := sym.GetID("cuda_profiler_range") // -1 if not present + + // Build store events, applying filter if needed + var allEvents []store.Event + for _, rt := range rankTraces { + events := buildStoreEvents(rt, sym, minTs) + if needFiltering { + events = filterGPUKernels(events, profilerStepNameIDs, cuptiCatID) + } + allEvents = append(allEvents, events...) + } + log.Printf("Total events after filtering: %d", len(allEvents)) + + // Step 6: Write to SQLite + db, err := store.Create(dbPath) + if err != nil { + return fmt.Errorf("creating db: %w", err) + } + defer db.Close() + + if err := store.CreateTables(db); err != nil { + return fmt.Errorf("creating tables: %w", err) + } + + tx, err := db.Begin() + if err != nil { + return err + } + + // Insert symbols + symRows := make([]struct{ ID int; Name string }, sym.Len()) + for _, s := range sym.All() { + symRows[s.ID] = struct{ ID int; Name string }{s.ID, s.Name} + } + if err := store.InsertSymbols(tx, symRows); err != nil { + tx.Rollback() + return fmt.Errorf("inserting symbols: %w", err) + } + + // Insert metadata + for _, rt := range rankTraces { + m := rt.Meta + if err := store.InsertMetadata(tx, m.Rank, m.TraceFile, m.Backend, m.SchemaVersion, m.WorldSize, m.DeviceName, m.DeviceTotalMem, m.DeviceCompCap); err != nil { + tx.Rollback() + return fmt.Errorf("inserting metadata rank %d: %w", m.Rank, err) + } + } + + // Insert events + firstID, err := store.InsertEvents(tx, allEvents) + if err != nil { + tx.Rollback() + return fmt.Errorf("inserting events: %w", err) + } + + // Persist extra event args (e.g. CUPTI counters). 
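+	// This relies on store.InsertEvents assigning consecutive row IDs in slice
+	// order starting at firstID, so the i-th event maps to firstID + i.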
+ var eventArgs []store.EventArg + for i, e := range allEvents { + if len(e.ExtraArgs) == 0 { + continue + } + eventID := firstID + int64(i) + for k, v := range e.ExtraArgs { + eventArgs = append(eventArgs, store.EventArg{EventID: eventID, Key: k, Value: v}) + } + } + if err := store.InsertEventArgs(tx, eventArgs); err != nil { + tx.Rollback() + return fmt.Errorf("inserting event args: %w", err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("commit: %w", err) + } + + if err := store.CreateIndexes(db); err != nil { + return fmt.Errorf("creating indexes: %w", err) + } + + log.Printf("Wrote %s", dbPath) + return nil +} + +// buildStoreEvents converts raw events to store events with aligned timestamps. +func buildStoreEvents(rt *trace.RankTrace, sym *symbol.Table, minTs int64) []store.Event { + events := make([]store.Event, 0, len(rt.Events)) + for _, e := range rt.Events { + startedAt := e.Ts - minTs + deviceType := "cpu" + if e.Stream >= 0 { + deviceType = "gpu" + } else if e.Cat == "cuda_profiler_range" { + deviceType = "gpu" + } + events = append(events, store.Event{ + Rank: rt.Meta.Rank, + DeviceType: deviceType, + StartedAt: startedAt, + Duration: e.Dur, + EndedAt: startedAt + e.Dur, + CategoryID: sym.GetID(e.Cat), + NameID: sym.GetID(e.Name), + ProcessID: e.Pid, + ThreadID: e.Tid, + CUDAStream: e.Stream, + Correlation: e.Correlation, + MemoryBWGbps: e.MemoryBWGbps, + WaitOnCudaEventRecordCorrID: e.WaitOnCudaEventRecordCorrID, + WaitOnStream: e.WaitOnStream, + ExtraArgs: e.ExtraArgs, + }) + } + return events +} + +// filterGPUKernels filters out GPU events whose correlation does not match +// a CPU event that started before the last profiler step's start time. +// CUPTI profiler range events (which lack correlation) are always retained. +func filterGPUKernels(events []store.Event, profilerStepNameIDs []int, cuptiCatID int) []store.Event { + profilerStepSet := make(map[int]bool, len(profilerStepNameIDs)) + for _, id := range profilerStepNameIDs { + profilerStepSet[id] = true + } + + // Identify CPU events (stream < 0 or correlation < 0, essentially device_type == "cpu") + // Find last profiler step start time among CPU events + var lastProfilerStart int64 = -1 + for _, e := range events { + if e.DeviceType == "cpu" && profilerStepSet[e.NameID] { + if e.StartedAt > lastProfilerStart { + lastProfilerStart = e.StartedAt + } + } + } + + if lastProfilerStart < 0 { + // No profiler steps found in CPU events; skip filtering + return events + } + + // Collect correlation IDs from CPU events that started before the last profiler step + cpuCorrelations := make(map[int]bool) + for _, e := range events { + if e.DeviceType == "cpu" && e.StartedAt < lastProfilerStart && e.Correlation >= 0 { + cpuCorrelations[e.Correlation] = true + } + } + + // Keep: all CPU events with ts < lastProfilerStart, plus GPU events with matching correlation. + // Always keep CUPTI profiler range events (they have no correlation). + filtered := make([]store.Event, 0, len(events)) + for _, e := range events { + if e.DeviceType == "cpu" { + if e.StartedAt < lastProfilerStart { + filtered = append(filtered, e) + } + } else { + // GPU event: keep if correlation matches, or if it's a CUPTI profiler range event + if cuptiCatID >= 0 && e.CategoryID == cuptiCatID { + filtered = append(filtered, e) + } else if e.Correlation >= 0 && cpuCorrelations[e.Correlation] { + filtered = append(filtered, e) + } + } + } + return filtered +} + +// RunWithDB is like Run but accepts an already-opened DB (useful for tests). 
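+// It performs the same discover → parse → symbolize → align → filter → insert
+// steps as Run, but without the progress logging.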
+func RunWithDB(traceDir string, db *sql.DB) error { + files, err := trace.DiscoverFiles(traceDir) + if err != nil { + return err + } + if len(files) == 0 { + return fmt.Errorf("no trace files found in %s", traceDir) + } + + rankTraces := make([]*trace.RankTrace, 0, len(files)) + for _, f := range files { + rt, err := trace.ParseFile(f) + if err != nil { + return fmt.Errorf("parsing %s: %w", f, err) + } + rankTraces = append(rankTraces, rt) + } + + sym := symbol.NewTable() + for _, rt := range rankTraces { + for i := range rt.Events { + sym.Add(rt.Events[i].Cat) + sym.Add(rt.Events[i].Name) + } + } + + var minTs int64 = math.MaxInt64 + for _, rt := range rankTraces { + for _, e := range rt.Events { + if e.Ts < minTs { + minTs = e.Ts + } + } + } + + var profilerStepNameIDs []int + for _, s := range sym.All() { + if profilerStepRe.MatchString(s.Name) { + profilerStepNameIDs = append(profilerStepNameIDs, s.ID) + } + } + + needFiltering := len(profilerStepNameIDs) >= 2 + cuptiCatID := sym.GetID("cuda_profiler_range") + + var allEvents []store.Event + for _, rt := range rankTraces { + events := buildStoreEvents(rt, sym, minTs) + if needFiltering { + events = filterGPUKernels(events, profilerStepNameIDs, cuptiCatID) + } + allEvents = append(allEvents, events...) + } + + if err := store.CreateTables(db); err != nil { + return err + } + + tx, err := db.Begin() + if err != nil { + return err + } + + symRows := make([]struct{ ID int; Name string }, sym.Len()) + for _, s := range sym.All() { + symRows[s.ID] = struct{ ID int; Name string }{s.ID, s.Name} + } + if err := store.InsertSymbols(tx, symRows); err != nil { + tx.Rollback() + return err + } + + for _, rt := range rankTraces { + m := rt.Meta + if err := store.InsertMetadata(tx, m.Rank, m.TraceFile, m.Backend, m.SchemaVersion, m.WorldSize, m.DeviceName, m.DeviceTotalMem, m.DeviceCompCap); err != nil { + tx.Rollback() + return err + } + } + + firstID, err := store.InsertEvents(tx, allEvents) + if err != nil { + tx.Rollback() + return err + } + + var eventArgs []store.EventArg + for i, e := range allEvents { + if len(e.ExtraArgs) == 0 { + continue + } + eventID := firstID + int64(i) + for k, v := range e.ExtraArgs { + eventArgs = append(eventArgs, store.EventArg{EventID: eventID, Key: k, Value: v}) + } + } + if err := store.InsertEventArgs(tx, eventArgs); err != nil { + tx.Rollback() + return err + } + + if err := tx.Commit(); err != nil { + return err + } + + return store.CreateIndexes(db) +} diff --git a/pkg/store/reader.go b/pkg/store/reader.go new file mode 100644 index 0000000..92e4e31 --- /dev/null +++ b/pkg/store/reader.go @@ -0,0 +1,889 @@ +package store + +import ( + "database/sql" + "fmt" + "strings" + + "hta/pkg/symbol" +) + +// LoadSymbolTable reads all symbols from the DB into a SymbolTable. 
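+// Symbol IDs in the DB are assumed to be dense and zero-based in insertion
+// order; the table is rebuilt by re-adding names in ascending ID order and an
+// error is returned if any ID would not match.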
+func LoadSymbolTable(db *sql.DB) (*symbol.Table, error) { + rows, err := db.Query("SELECT id, name FROM symbol_name ORDER BY id") + if err != nil { + return nil, err + } + defer rows.Close() + + // Collect all rows first to ensure order + type entry struct { + id int + name string + } + var entries []entry + for rows.Next() { + var e entry + if err := rows.Scan(&e.id, &e.name); err != nil { + return nil, err + } + entries = append(entries, e) + } + if err := rows.Err(); err != nil { + return nil, err + } + + tbl := symbol.NewTable() + for _, e := range entries { + got := tbl.Add(e.name) + if got != e.id { + return nil, fmt.Errorf("symbol table ID mismatch: expected %d, got %d for %q", e.id, got, e.name) + } + } + return tbl, nil +} + +// GPUKernelRow holds a single GPU event for analysis. +type GPUKernelRow struct { + StartedAt int64 + EndedAt int64 + NameID int +} + +// LoadGPUKernels returns all GPU events for a rank, sorted by started_at. +func LoadGPUKernels(db *sql.DB, rank int) ([]GPUKernelRow, error) { + rows, err := db.Query( + `SELECT started_at, ended_at, name FROM trace_event + WHERE gpu_rank = ? AND device_type = 'gpu' + ORDER BY started_at`, + rank, + ) + if err != nil { + return nil, err + } + defer rows.Close() + + var kernels []GPUKernelRow + for rows.Next() { + var k GPUKernelRow + if err := rows.Scan(&k.StartedAt, &k.EndedAt, &k.NameID); err != nil { + return nil, err + } + kernels = append(kernels, k) + } + return kernels, rows.Err() +} + +// GPUKernelFullRow holds a GPU event with rank and stream information. +type GPUKernelFullRow struct { + Rank int + StartedAt int64 + EndedAt int64 + NameID int + StreamID int +} + +// LoadAllGPUKernelsFull returns all GPU events across all ranks, sorted by (rank, started_at). +func LoadAllGPUKernelsFull(db *sql.DB) ([]GPUKernelFullRow, error) { + rows, err := db.Query( + `SELECT gpu_rank, started_at, ended_at, name, cuda_stream_id + FROM trace_event WHERE device_type = 'gpu' + ORDER BY gpu_rank, started_at`, + ) + if err != nil { + return nil, err + } + defer rows.Close() + + var kernels []GPUKernelFullRow + for rows.Next() { + var k GPUKernelFullRow + if err := rows.Scan(&k.Rank, &k.StartedAt, &k.EndedAt, &k.NameID, &k.StreamID); err != nil { + return nil, err + } + kernels = append(kernels, k) + } + return kernels, rows.Err() +} + +// ProfilerStepEvent holds a CPU profiler step event with rank information. +type ProfilerStepEvent struct { + Rank int + StartedAt int64 + EndedAt int64 + NameID int +} + +// LoadProfilerStepEvents returns CPU profiler step events for the given name IDs, +// sorted by (rank, started_at). +func LoadProfilerStepEvents(db *sql.DB, profilerStepNameIDs []int) ([]ProfilerStepEvent, error) { + if len(profilerStepNameIDs) == 0 { + return nil, nil + } + + // Build IN clause with placeholders. + placeholders := make([]string, len(profilerStepNameIDs)) + args := make([]any, len(profilerStepNameIDs)) + for i, id := range profilerStepNameIDs { + placeholders[i] = "?" + args[i] = id + } + + query := `SELECT gpu_rank, started_at, ended_at, name FROM trace_event + WHERE device_type = 'cpu' AND name IN (` + strings.Join(placeholders, ",") + `) + ORDER BY gpu_rank, started_at` + + rows, err := db.Query(query, args...) 
+ if err != nil { + return nil, err + } + defer rows.Close() + + var events []ProfilerStepEvent + for rows.Next() { + var e ProfilerStepEvent + if err := rows.Scan(&e.Rank, &e.StartedAt, &e.EndedAt, &e.NameID); err != nil { + return nil, err + } + events = append(events, e) + } + return events, rows.Err() +} + +// CPURuntimeRow holds a CPU runtime launch event with its correlation ID. +type CPURuntimeRow struct { + Correlation int + StartedAt int64 + Duration int64 + NameID int +} + +// LoadCPURuntimeEvents returns CPU events for a rank filtered by the given +// name IDs (e.g. cudaLaunchKernel symbol IDs) and having a valid correlation. +func LoadCPURuntimeEvents(db *sql.DB, rank int, nameIDs []int) ([]CPURuntimeRow, error) { + if len(nameIDs) == 0 { + return nil, nil + } + placeholders := make([]string, len(nameIDs)) + args := make([]any, 0, len(nameIDs)+1) + args = append(args, rank) + for i, id := range nameIDs { + placeholders[i] = "?" + args = append(args, id) + } + query := `SELECT correlation, started_at, duration, name FROM trace_event + WHERE gpu_rank = ? AND device_type = 'cpu' + AND correlation != -1 + AND name IN (` + strings.Join(placeholders, ",") + `) + ORDER BY started_at` + + rows, err := db.Query(query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var out []CPURuntimeRow + for rows.Next() { + var r CPURuntimeRow + if err := rows.Scan(&r.Correlation, &r.StartedAt, &r.Duration, &r.NameID); err != nil { + return nil, err + } + out = append(out, r) + } + return out, rows.Err() +} + +// GPUKernelCorrelationRow holds a GPU kernel event with its correlation ID. +type GPUKernelCorrelationRow struct { + Correlation int + StartedAt int64 + Duration int64 + NameID int +} + +// LoadGPUKernelsWithCorrelation returns all GPU events for a rank that have +// a valid correlation ID, sorted by started_at. +func LoadGPUKernelsWithCorrelation(db *sql.DB, rank int) ([]GPUKernelCorrelationRow, error) { + rows, err := db.Query( + `SELECT correlation, started_at, duration, name FROM trace_event + WHERE gpu_rank = ? AND device_type = 'gpu' + AND correlation != -1 + ORDER BY started_at`, + rank, + ) + if err != nil { + return nil, err + } + defer rows.Close() + + var out []GPUKernelCorrelationRow + for rows.Next() { + var k GPUKernelCorrelationRow + if err := rows.Scan(&k.Correlation, &k.StartedAt, &k.Duration, &k.NameID); err != nil { + return nil, err + } + out = append(out, k) + } + return out, rows.Err() +} + +// GPUKernelStreamRow holds a GPU event with its stream and correlation ID, +// used by queue-length analysis. +type GPUKernelStreamRow struct { + StartedAt int64 + CUDAStream int + Correlation int +} + +// LoadGPUKernelsWithStream returns GPU events for a rank that have a valid +// cuda_stream_id and correlation, ordered by started_at. +func LoadGPUKernelsWithStream(db *sql.DB, rank int) ([]GPUKernelStreamRow, error) { + rows, err := db.Query( + `SELECT started_at, cuda_stream_id, correlation FROM trace_event + WHERE gpu_rank = ? AND device_type = 'gpu' + AND cuda_stream_id != -1 AND correlation != -1 + ORDER BY started_at`, + rank, + ) + if err != nil { + return nil, err + } + defer rows.Close() + + var result []GPUKernelStreamRow + for rows.Next() { + var r GPUKernelStreamRow + if err := rows.Scan(&r.StartedAt, &r.CUDAStream, &r.Correlation); err != nil { + return nil, err + } + result = append(result, r) + } + return result, rows.Err() +} + +// GPUKernelIdleTimeRow holds a GPU event with stream and correlated CPU runtime timestamp. 
+type GPUKernelIdleTimeRow struct { + StartedAt int64 + EndedAt int64 + NameID int + CUDAStreamID int + Correlation int + RuntimeTs int64 // started_at of the correlated CPU launch event (-1 if none) +} + +// LoadGPUKernelsForIdleTime returns GPU kernels filtered to kernel/memset/memcpy +// categories and valid streams, with correlated CPU runtime timestamps resolved +// in Go. Results are ordered by cuda_stream_id, then started_at. +func LoadGPUKernelsForIdleTime(db *sql.DB, rank int, streams []int) ([]GPUKernelIdleTimeRow, error) { + // Step 1: load CPU correlation map (correlation → started_at) + cpuMap, err := loadCPUCorrelationMap(db, rank) + if err != nil { + return nil, fmt.Errorf("loading CPU correlations: %w", err) + } + + // Step 2: load GPU kernels + query := ` +SELECT started_at, ended_at, name, cuda_stream_id, correlation +FROM trace_event +WHERE gpu_rank = ? AND device_type = 'gpu' AND cuda_stream_id != -1 + AND category IN (SELECT id FROM symbol_name + WHERE name IN ('kernel','Kernel','gpu_memset','Memset','gpu_memcpy','Memcpy','mtia_ccp_events'))` + + args := []any{rank} + + if len(streams) > 0 { + var b strings.Builder + for i, s := range streams { + if i > 0 { + b.WriteByte(',') + } + b.WriteByte('?') + args = append(args, s) + } + query += "\n AND cuda_stream_id IN (" + b.String() + ")" + } + + query += "\nORDER BY cuda_stream_id, started_at" + + rows, err := db.Query(query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var kernels []GPUKernelIdleTimeRow + for rows.Next() { + var k GPUKernelIdleTimeRow + if err := rows.Scan(&k.StartedAt, &k.EndedAt, &k.NameID, &k.CUDAStreamID, &k.Correlation); err != nil { + return nil, err + } + // Resolve runtime timestamp from CPU correlation map + if ts, ok := cpuMap[k.Correlation]; ok && k.Correlation != -1 { + k.RuntimeTs = ts + } else { + k.RuntimeTs = -1 + } + kernels = append(kernels, k) + } + return kernels, rows.Err() +} + +// loadCPUCorrelationMap builds a map from correlation ID → CPU started_at for a rank. +func loadCPUCorrelationMap(db *sql.DB, rank int) (map[int]int64, error) { + rows, err := db.Query( + `SELECT correlation, started_at FROM trace_event + WHERE gpu_rank = ? AND device_type = 'cpu' AND correlation != -1`, + rank, + ) + if err != nil { + return nil, err + } + defer rows.Close() + + m := make(map[int]int64) + for rows.Next() { + var corr int + var ts int64 + if err := rows.Scan(&corr, &ts); err != nil { + return nil, err + } + m[corr] = ts + } + return m, rows.Err() +} + +// TraceDuration returns MAX(started_at) - MIN(started_at) for a given rank. +func TraceDuration(db *sql.DB, rank int) (int64, error) { + var dur int64 + err := db.QueryRow( + `SELECT MAX(started_at) - MIN(started_at) FROM trace_event WHERE gpu_rank = ?`, + rank, + ).Scan(&dur) + return dur, err +} + +// MemoryBWKernelRow holds a GPU event with memory bandwidth data. +type MemoryBWKernelRow struct { + Rank int + StartedAt int64 + Duration int64 + NameID int + ProcessID int64 + MemoryBWGbps float64 +} + +// LoadMemoryBWKernels returns GPU events with memory_bw_gbps > 0 for a rank, +// sorted by started_at. +func LoadMemoryBWKernels(db *sql.DB, rank int) ([]MemoryBWKernelRow, error) { + rows, err := db.Query( + `SELECT gpu_rank, started_at, duration, name, process_id, memory_bw_gbps FROM trace_event + WHERE gpu_rank = ? 
AND device_type = 'gpu' AND memory_bw_gbps > 0 + ORDER BY started_at`, + rank, + ) + if err != nil { + return nil, err + } + defer rows.Close() + + var kernels []MemoryBWKernelRow + for rows.Next() { + var k MemoryBWKernelRow + if err := rows.Scan(&k.Rank, &k.StartedAt, &k.Duration, &k.NameID, &k.ProcessID, &k.MemoryBWGbps); err != nil { + return nil, err + } + kernels = append(kernels, k) + } + return kernels, rows.Err() +} + +// UserAnnotationRow holds a GPU user annotation event for overlap matching. +type UserAnnotationRow struct { + StartedAt int64 + EndedAt int64 + NameID int + Duration int64 + ProcessID int + ThreadID int +} + +// LoadUserAnnotations returns CPU events whose category is "gpu_user_annotation" +// for a rank, sorted by duration descending (longest/outermost first). +func LoadUserAnnotations(db *sql.DB, rank int) ([]UserAnnotationRow, error) { + rows, err := db.Query( + `SELECT started_at, ended_at, name, duration, process_id, thread_id + FROM trace_event + WHERE gpu_rank = ? AND device_type = 'cpu' + AND category IN (SELECT id FROM symbol_name WHERE name = 'gpu_user_annotation') + ORDER BY duration DESC`, + rank, + ) + if err != nil { + return nil, err + } + defer rows.Close() + + var out []UserAnnotationRow + for rows.Next() { + var r UserAnnotationRow + if err := rows.Scan(&r.StartedAt, &r.EndedAt, &r.NameID, &r.Duration, &r.ProcessID, &r.ThreadID); err != nil { + return nil, err + } + out = append(out, r) + } + return out, rows.Err() +} + +// GPUKernelAnnotationRow holds a GPU kernel with pid/tid for annotation matching. +type GPUKernelAnnotationRow struct { + StartedAt int64 + EndedAt int64 + NameID int + ProcessID int + ThreadID int +} + +// LoadGPUKernelsForAnnotation returns GPU events for a rank with process_id and +// thread_id, sorted by started_at. This is like LoadGPUKernels but includes the +// fields needed for annotation overlap matching. +func LoadGPUKernelsForAnnotation(db *sql.DB, rank int) ([]GPUKernelAnnotationRow, error) { + rows, err := db.Query( + `SELECT started_at, ended_at, name, process_id, thread_id + FROM trace_event + WHERE gpu_rank = ? AND device_type = 'gpu' + ORDER BY started_at`, + rank, + ) + if err != nil { + return nil, err + } + defer rows.Close() + + var out []GPUKernelAnnotationRow + for rows.Next() { + var k GPUKernelAnnotationRow + if err := rows.Scan(&k.StartedAt, &k.EndedAt, &k.NameID, &k.ProcessID, &k.ThreadID); err != nil { + return nil, err + } + out = append(out, k) + } + return out, rows.Err() +} + +// TraceFile returns the original trace file path for a rank. +func TraceFile(db *sql.DB, rank int) (string, error) { + var path string + err := db.QueryRow("SELECT trace_file FROM meta_data WHERE gpu_rank = ?", rank).Scan(&path) + return path, err +} + +// GPUProcessID returns the process_id used by GPU events for a rank. +func GPUProcessID(db *sql.DB, rank int) (int64, error) { + var pid int64 + err := db.QueryRow( + `SELECT DISTINCT process_id FROM trace_event + WHERE gpu_rank = ? AND device_type = 'gpu' LIMIT 1`, + rank, + ).Scan(&pid) + return pid, err +} + +// MinStartedAt returns the minimum started_at timestamp for a rank. +func MinStartedAt(db *sql.DB, rank int) (int64, error) { + var ts int64 + err := db.QueryRow( + "SELECT MIN(started_at) FROM trace_event WHERE gpu_rank = ?", + rank, + ).Scan(&ts) + return ts, err +} + +// ATenOpRow holds a single CPU ATen operator event. 
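+//
+// NameID indexes into the table produced by LoadSymbolTable. A hedged usage
+// sketch (rank and nameIDs are placeholders, error handling elided):
+//
+//	tbl, _ := LoadSymbolTable(db)
+//	ops, _ := LoadATenOperators(db, rank, nameIDs)
+//	for _, op := range ops {
+//		name, _ := tbl.GetName(op.NameID)
+//		fmt.Printf("%s tid=%d dur=%dus\n", name, op.ThreadID, op.EndedAt-op.StartedAt)
+//	}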
+type ATenOpRow struct { + StartedAt int64 + EndedAt int64 + NameID int + ThreadID int64 +} + +// LoadATenOperators returns CPU ATen operator events for a rank filtered by +// the given name IDs, ordered by (thread_id, started_at). +func LoadATenOperators(db *sql.DB, rank int, nameIDs []int) ([]ATenOpRow, error) { + if len(nameIDs) == 0 { + return nil, nil + } + placeholders := make([]string, len(nameIDs)) + args := make([]any, 0, len(nameIDs)+1) + args = append(args, rank) + for i, id := range nameIDs { + placeholders[i] = "?" + args = append(args, id) + } + query := `SELECT started_at, ended_at, name, thread_id FROM trace_event + WHERE gpu_rank = ? AND device_type = 'cpu' + AND name IN (` + strings.Join(placeholders, ",") + `) + ORDER BY thread_id, started_at` + + rows, err := db.Query(query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var out []ATenOpRow + for rows.Next() { + var r ATenOpRow + if err := rows.Scan(&r.StartedAt, &r.EndedAt, &r.NameID, &r.ThreadID); err != nil { + return nil, err + } + out = append(out, r) + } + return out, rows.Err() +} + +// CPURuntimeThreadRow holds a CPU runtime launch event with thread and end time info. +type CPURuntimeThreadRow struct { + Correlation int + StartedAt int64 + EndedAt int64 + ThreadID int64 +} + +// LoadCPURuntimeEventsWithThread returns CPU runtime events for a rank filtered +// by the given name IDs, with thread and end time information, ordered by +// (thread_id, started_at). +func LoadCPURuntimeEventsWithThread(db *sql.DB, rank int, nameIDs []int) ([]CPURuntimeThreadRow, error) { + if len(nameIDs) == 0 { + return nil, nil + } + placeholders := make([]string, len(nameIDs)) + args := make([]any, 0, len(nameIDs)+1) + args = append(args, rank) + for i, id := range nameIDs { + placeholders[i] = "?" + args = append(args, id) + } + query := `SELECT correlation, started_at, ended_at, thread_id FROM trace_event + WHERE gpu_rank = ? AND device_type = 'cpu' + AND correlation != -1 + AND name IN (` + strings.Join(placeholders, ",") + `) + ORDER BY thread_id, started_at` + + rows, err := db.Query(query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var out []CPURuntimeThreadRow + for rows.Next() { + var r CPURuntimeThreadRow + if err := rows.Scan(&r.Correlation, &r.StartedAt, &r.EndedAt, &r.ThreadID); err != nil { + return nil, err + } + out = append(out, r) + } + return out, rows.Err() +} + +// CriticalPathEventRow holds a trace event for critical path analysis. +type CriticalPathEventRow struct { + ID int64 + StartedAt int64 + Duration int64 + EndedAt int64 + NameID int + CategoryID int + DeviceType string + ThreadID int64 + ProcessID int64 + Stream int + Correlation int + + WaitOnCudaEventRecordCorrID int + WaitOnStream int +} + +// LoadAllEventsForRank returns all trace events for a given rank, +// ordered by (started_at, id). Used by critical path analysis. +func LoadAllEventsForRank(db *sql.DB, rank int) ([]CriticalPathEventRow, error) { + rows, err := db.Query( + `SELECT id, started_at, duration, ended_at, name, category, device_type, + thread_id, process_id, cuda_stream_id, correlation, + wait_on_cuda_event_record_corr_id, wait_on_stream + FROM trace_event + WHERE gpu_rank = ? 
+ ORDER BY started_at, id`, + rank, + ) + if err != nil { + return nil, err + } + defer rows.Close() + + var out []CriticalPathEventRow + for rows.Next() { + var r CriticalPathEventRow + if err := rows.Scan( + &r.ID, &r.StartedAt, &r.Duration, &r.EndedAt, + &r.NameID, &r.CategoryID, &r.DeviceType, + &r.ThreadID, &r.ProcessID, &r.Stream, &r.Correlation, + &r.WaitOnCudaEventRecordCorrID, &r.WaitOnStream, + ); err != nil { + return nil, err + } + out = append(out, r) + } + return out, rows.Err() +} + +// CPUOperatorRow holds a CPU operator event for kernel sequence analysis. +type CPUOperatorRow struct { + ID int64 + StartedAt int64 + EndedAt int64 + Duration int64 + NameID int + ThreadID int64 +} + +// LoadCPUOperators returns CPU operator events for a rank filtered by the given +// name IDs, sorted by (thread_id, started_at). +func LoadCPUOperators(db *sql.DB, rank int, nameIDs []int) ([]CPUOperatorRow, error) { + if len(nameIDs) == 0 { + return nil, nil + } + placeholders := make([]string, len(nameIDs)) + args := make([]any, 0, len(nameIDs)+1) + args = append(args, rank) + for i, id := range nameIDs { + placeholders[i] = "?" + args = append(args, id) + } + query := `SELECT id, started_at, ended_at, duration, name, thread_id FROM trace_event + WHERE gpu_rank = ? AND device_type = 'cpu' + AND name IN (` + strings.Join(placeholders, ",") + `) + ORDER BY thread_id, started_at` + + rows, err := db.Query(query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var out []CPUOperatorRow + for rows.Next() { + var r CPUOperatorRow + if err := rows.Scan(&r.ID, &r.StartedAt, &r.EndedAt, &r.Duration, &r.NameID, &r.ThreadID); err != nil { + return nil, err + } + out = append(out, r) + } + return out, rows.Err() +} + +// CPURuntimeLaunchRow holds a CPU runtime launch event with duration, name, and thread info. +type CPURuntimeLaunchRow struct { + Correlation int + StartedAt int64 + Duration int64 + NameID int + ThreadID int64 +} + +// LoadCPURuntimeLaunchesWithThread returns CPU launch events for a rank filtered +// by the given name IDs with valid correlation, including duration and name, +// sorted by (thread_id, started_at). +func LoadCPURuntimeLaunchesWithThread(db *sql.DB, rank int, nameIDs []int) ([]CPURuntimeLaunchRow, error) { + if len(nameIDs) == 0 { + return nil, nil + } + placeholders := make([]string, len(nameIDs)) + args := make([]any, 0, len(nameIDs)+1) + args = append(args, rank) + for i, id := range nameIDs { + placeholders[i] = "?" + args = append(args, id) + } + query := `SELECT correlation, started_at, duration, name, thread_id FROM trace_event + WHERE gpu_rank = ? AND device_type = 'cpu' + AND correlation != -1 + AND name IN (` + strings.Join(placeholders, ",") + `) + ORDER BY thread_id, started_at` + + rows, err := db.Query(query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var out []CPURuntimeLaunchRow + for rows.Next() { + var r CPURuntimeLaunchRow + if err := rows.Scan(&r.Correlation, &r.StartedAt, &r.Duration, &r.NameID, &r.ThreadID); err != nil { + return nil, err + } + out = append(out, r) + } + return out, rows.Err() +} + +// CUPTIKernelRow holds a CUPTI profiler range event. +type CUPTIKernelRow struct { + EventID int64 + StartedAt int64 + Duration int64 + NameID int +} + +// LoadCUPTIKernels returns GPU events with the given category ID (cuda_profiler_range) +// for a rank, sorted by started_at. 
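+// The category ID is expected to come from the symbol table; an illustrative
+// call sequence (rank is a placeholder, error handling elided):
+//
+//	tbl, _ := LoadSymbolTable(db)
+//	kernels, _ := LoadCUPTIKernels(db, rank, tbl.GetID("cuda_profiler_range"))
+//	ids := make([]int64, len(kernels))
+//	for i, k := range kernels {
+//		ids[i] = k.EventID
+//	}
+//	counters, _ := LoadEventArgsBatch(db, ids) // per-kernel CUPTI counter values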
+func LoadCUPTIKernels(db *sql.DB, rank int, categoryID int) ([]CUPTIKernelRow, error) { + rows, err := db.Query( + `SELECT id, started_at, duration, name FROM trace_event + WHERE gpu_rank = ? AND category = ? + ORDER BY started_at`, + rank, categoryID, + ) + if err != nil { + return nil, err + } + defer rows.Close() + + var out []CUPTIKernelRow + for rows.Next() { + var k CUPTIKernelRow + if err := rows.Scan(&k.EventID, &k.StartedAt, &k.Duration, &k.NameID); err != nil { + return nil, err + } + out = append(out, k) + } + return out, rows.Err() +} + +// CPUOpEventRow holds a CPU operator event for operator stack building. +type CPUOpEventRow struct { + StartedAt int64 + EndedAt int64 + NameID int + ThreadID int64 +} + +// LoadCPUOpEvents returns CPU events with the given category ID (cpu_op) for a rank, +// sorted by (thread_id, started_at). +func LoadCPUOpEvents(db *sql.DB, rank int, categoryID int) ([]CPUOpEventRow, error) { + rows, err := db.Query( + `SELECT started_at, ended_at, name, thread_id FROM trace_event + WHERE gpu_rank = ? AND device_type = 'cpu' AND category = ? + ORDER BY thread_id, started_at`, + rank, categoryID, + ) + if err != nil { + return nil, err + } + defer rows.Close() + + var out []CPUOpEventRow + for rows.Next() { + var r CPUOpEventRow + if err := rows.Scan(&r.StartedAt, &r.EndedAt, &r.NameID, &r.ThreadID); err != nil { + return nil, err + } + out = append(out, r) + } + return out, rows.Err() +} + +// LoadEventArgsBatch returns all event_args for the given event IDs as a map +// from event_id to {key: value}. +func LoadEventArgsBatch(db *sql.DB, eventIDs []int64) (map[int64]map[string]float64, error) { + if len(eventIDs) == 0 { + return nil, nil + } + placeholders := make([]string, len(eventIDs)) + args := make([]any, len(eventIDs)) + for i, id := range eventIDs { + placeholders[i] = "?" + args[i] = id + } + query := `SELECT event_id, key, value FROM event_args + WHERE event_id IN (` + strings.Join(placeholders, ",") + `)` + + rows, err := db.Query(query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + result := make(map[int64]map[string]float64) + for rows.Next() { + var eventID int64 + var key string + var value float64 + if err := rows.Scan(&eventID, &key, &value); err != nil { + return nil, err + } + if result[eventID] == nil { + result[eventID] = make(map[string]float64) + } + result[eventID][key] = value + } + return result, rows.Err() +} + +// LoadCUPTILaunches returns CPU cudaLaunchKernel events for a rank filtered by +// the given category ID (cuda_runtime), sorted by started_at. +// Unlike LoadCPURuntimeLaunchesWithThread, this does NOT require valid correlation +// (CUPTI traces match launches to kernels by position, not correlation). +func LoadCUPTILaunches(db *sql.DB, rank int, categoryID int, nameIDs []int) ([]CPURuntimeLaunchRow, error) { + if len(nameIDs) == 0 { + return nil, nil + } + placeholders := make([]string, len(nameIDs)) + args := make([]any, 0, len(nameIDs)+2) + args = append(args, rank, categoryID) + for i, id := range nameIDs { + placeholders[i] = "?" + args = append(args, id) + } + query := `SELECT correlation, started_at, duration, name, thread_id FROM trace_event + WHERE gpu_rank = ? AND device_type = 'cpu' AND category = ? + AND name IN (` + strings.Join(placeholders, ",") + `) + ORDER BY started_at` + + rows, err := db.Query(query, args...) 
+ if err != nil { + return nil, err + } + defer rows.Close() + + var out []CPURuntimeLaunchRow + for rows.Next() { + var r CPURuntimeLaunchRow + if err := rows.Scan(&r.Correlation, &r.StartedAt, &r.Duration, &r.NameID, &r.ThreadID); err != nil { + return nil, err + } + out = append(out, r) + } + return out, rows.Err() +} + +// Ranks returns all gpu_rank values in the DB. +func Ranks(db *sql.DB) ([]int, error) { + rows, err := db.Query("SELECT gpu_rank FROM meta_data ORDER BY gpu_rank") + if err != nil { + return nil, err + } + defer rows.Close() + + var ranks []int + for rows.Next() { + var r int + if err := rows.Scan(&r); err != nil { + return nil, err + } + ranks = append(ranks, r) + } + return ranks, rows.Err() +} diff --git a/pkg/store/schema.go b/pkg/store/schema.go new file mode 100644 index 0000000..9c203ee --- /dev/null +++ b/pkg/store/schema.go @@ -0,0 +1,193 @@ +package store + +import ( + "database/sql" + "fmt" + + _ "modernc.org/sqlite" +) + +const ddl = ` +CREATE TABLE IF NOT EXISTS symbol_name ( + id INTEGER PRIMARY KEY, + name TEXT UNIQUE NOT NULL +); + +CREATE TABLE IF NOT EXISTS meta_data ( + gpu_rank INTEGER PRIMARY KEY, + trace_file TEXT NOT NULL, + schema_version INTEGER NOT NULL DEFAULT 0, + backend TEXT NOT NULL DEFAULT '', + world_size INTEGER NOT NULL DEFAULT 0, + device_name TEXT NOT NULL DEFAULT '', + device_total_mem INTEGER NOT NULL DEFAULT 0, + device_compute_cap TEXT NOT NULL DEFAULT '' +); + +CREATE TABLE IF NOT EXISTS trace_event ( + id INTEGER PRIMARY KEY, + gpu_rank INTEGER NOT NULL, + device_type TEXT NOT NULL, + started_at INTEGER NOT NULL, + duration INTEGER NOT NULL, + ended_at INTEGER NOT NULL, + category INTEGER NOT NULL, + name INTEGER NOT NULL, + process_id INTEGER NOT NULL, + thread_id INTEGER NOT NULL, + cuda_stream_id INTEGER NOT NULL DEFAULT -1, + correlation INTEGER NOT NULL DEFAULT -1, + memory_bw_gbps REAL NOT NULL DEFAULT 0.0, + wait_on_cuda_event_record_corr_id INTEGER NOT NULL DEFAULT -1, + wait_on_stream INTEGER NOT NULL DEFAULT -1 +); + +CREATE TABLE IF NOT EXISTS event_args ( + event_id INTEGER NOT NULL, + key TEXT NOT NULL, + value REAL NOT NULL, + PRIMARY KEY (event_id, key) +); +` + +const indexDDL = ` +CREATE INDEX IF NOT EXISTS idx_te_rank_device ON trace_event(gpu_rank, device_type); +CREATE INDEX IF NOT EXISTS idx_te_rank_stream ON trace_event(gpu_rank, cuda_stream_id); +CREATE INDEX IF NOT EXISTS idx_te_correlation ON trace_event(correlation) WHERE correlation != -1; +CREATE INDEX IF NOT EXISTS idx_ea_event ON event_args(event_id); +` + +// Create opens (or creates) a SQLite DB and returns a handle. +func Create(path string) (*sql.DB, error) { + db, err := sql.Open("sqlite", path) + if err != nil { + return nil, err + } + // Performance pragmas + for _, pragma := range []string{ + "PRAGMA journal_mode=WAL", + "PRAGMA synchronous=NORMAL", + "PRAGMA cache_size=-64000", // 64MB + } { + if _, err := db.Exec(pragma); err != nil { + db.Close() + return nil, fmt.Errorf("pragma %q: %w", pragma, err) + } + } + return db, nil +} + +// CreateTables creates the schema tables (without indexes). +func CreateTables(db *sql.DB) error { + _, err := db.Exec(ddl) + return err +} + +// CreateIndexes creates the indexes (call after bulk insert). +func CreateIndexes(db *sql.DB) error { + _, err := db.Exec(indexDDL) + return err +} + +// InsertSymbols bulk-inserts symbol_name rows. 
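+// The element type matches symbol.Table.All(), so a caller can pass its result
+// directly; an illustrative call (tbl is a *symbol.Table, errors elided):
+//
+//	tx, _ := db.Begin()
+//	_ = InsertSymbols(tx, tbl.All())
+//	_ = tx.Commit()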
+func InsertSymbols(tx *sql.Tx, symbols []struct{ ID int; Name string }) error { + stmt, err := tx.Prepare("INSERT INTO symbol_name (id, name) VALUES (?, ?)") + if err != nil { + return err + } + defer stmt.Close() + for _, s := range symbols { + if _, err := stmt.Exec(s.ID, s.Name); err != nil { + return fmt.Errorf("insert symbol %d %q: %w", s.ID, s.Name, err) + } + } + return nil +} + +// InsertMetadata inserts a metadata row. +func InsertMetadata(tx *sql.Tx, rank int, traceFile, backend string, schemaVersion, worldSize int, deviceName string, deviceTotalMem int64, deviceCompCap string) error { + _, err := tx.Exec( + `INSERT INTO meta_data (gpu_rank, trace_file, schema_version, backend, world_size, device_name, device_total_mem, device_compute_cap) + VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, + rank, traceFile, schemaVersion, backend, worldSize, deviceName, deviceTotalMem, deviceCompCap, + ) + return err +} + +// Event holds the data for a single trace_event row. +type Event struct { + Rank int + DeviceType string + StartedAt int64 + Duration int64 + EndedAt int64 + CategoryID int + NameID int + ProcessID int64 + ThreadID int64 + CUDAStream int + Correlation int + MemoryBWGbps float64 + + WaitOnCudaEventRecordCorrID int + WaitOnStream int + + ExtraArgs map[string]float64 // additional numeric args (e.g. CUPTI counters) +} + +// InsertEvents bulk-inserts trace_event rows and returns the row ID of the +// first inserted event (useful for mapping events to their event_args). +func InsertEvents(tx *sql.Tx, events []Event) (int64, error) { + if len(events) == 0 { + return 0, nil + } + stmt, err := tx.Prepare( + `INSERT INTO trace_event (gpu_rank, device_type, started_at, duration, ended_at, category, name, process_id, thread_id, cuda_stream_id, correlation, memory_bw_gbps, wait_on_cuda_event_record_corr_id, wait_on_stream) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`) + if err != nil { + return 0, err + } + defer stmt.Close() + for _, e := range events { + if _, err := stmt.Exec( + e.Rank, e.DeviceType, e.StartedAt, e.Duration, e.EndedAt, + e.CategoryID, e.NameID, e.ProcessID, e.ThreadID, e.CUDAStream, e.Correlation, e.MemoryBWGbps, + e.WaitOnCudaEventRecordCorrID, e.WaitOnStream, + ); err != nil { + return 0, fmt.Errorf("insert event: %w", err) + } + } + // Compute first row ID from last_insert_rowid. + var lastID int64 + err = tx.QueryRow("SELECT last_insert_rowid()").Scan(&lastID) + if err != nil { + return 0, fmt.Errorf("last_insert_rowid: %w", err) + } + firstID := lastID - int64(len(events)) + 1 + return firstID, nil +} + +// EventArg holds a single extra argument for an event. +type EventArg struct { + EventID int64 + Key string + Value float64 +} + +// InsertEventArgs bulk-inserts event_args rows. +func InsertEventArgs(tx *sql.Tx, args []EventArg) error { + if len(args) == 0 { + return nil + } + stmt, err := tx.Prepare("INSERT INTO event_args (event_id, key, value) VALUES (?, ?, ?)") + if err != nil { + return err + } + defer stmt.Close() + for _, a := range args { + if _, err := stmt.Exec(a.EventID, a.Key, a.Value); err != nil { + return fmt.Errorf("insert event_arg: %w", err) + } + } + return nil +} diff --git a/pkg/symbol/table.go b/pkg/symbol/table.go new file mode 100644 index 0000000..e2a9575 --- /dev/null +++ b/pkg/symbol/table.go @@ -0,0 +1,66 @@ +package symbol + +import "fmt" + +// Table maps strings to sequential integer IDs and back. +type Table struct { + nameToID map[string]int + idToName []string +} + +// NewTable creates an empty symbol table. 
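+//
+// A short illustrative sequence (the operator names are just examples):
+//
+//	tbl := NewTable()
+//	_ = tbl.Add("aten::mm")  // returns 0
+//	_ = tbl.Add("aten::add") // returns 1
+//	_ = tbl.Add("aten::mm")  // returns 0 again; duplicates keep their first ID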
+func NewTable() *Table { + return &Table{ + nameToID: make(map[string]int), + } +} + +// Add inserts a name into the table if it doesn't already exist and returns its ID. +func (t *Table) Add(name string) int { + if id, ok := t.nameToID[name]; ok { + return id + } + id := len(t.idToName) + t.nameToID[name] = id + t.idToName = append(t.idToName, name) + return id +} + +// GetID returns the ID for a name. Returns -1 if not found. +func (t *Table) GetID(name string) int { + if id, ok := t.nameToID[name]; ok { + return id + } + return -1 +} + +// GetName returns the name for an ID. Returns error if out of range. +func (t *Table) GetName(id int) (string, error) { + if id < 0 || id >= len(t.idToName) { + return "", fmt.Errorf("symbol ID %d out of range [0, %d)", id, len(t.idToName)) + } + return t.idToName[id], nil +} + +// Len returns the number of symbols. +func (t *Table) Len() int { + return len(t.idToName) +} + +// All returns a copy of all (id, name) pairs for iteration. +func (t *Table) All() []struct { + ID int + Name string +} { + out := make([]struct { + ID int + Name string + }, len(t.idToName)) + for i, name := range t.idToName { + out[i] = struct { + ID int + Name string + }{ID: i, Name: name} + } + return out +} diff --git a/pkg/symbol/table_test.go b/pkg/symbol/table_test.go new file mode 100644 index 0000000..1307c83 --- /dev/null +++ b/pkg/symbol/table_test.go @@ -0,0 +1,49 @@ +package symbol + +import "testing" + +func TestAddAndGetID(t *testing.T) { + tbl := NewTable() + id0 := tbl.Add("foo") + id1 := tbl.Add("bar") + id2 := tbl.Add("foo") // duplicate + + if id0 != 0 { + t.Errorf("first symbol should be 0, got %d", id0) + } + if id1 != 1 { + t.Errorf("second symbol should be 1, got %d", id1) + } + if id2 != id0 { + t.Errorf("duplicate should return same ID: got %d, want %d", id2, id0) + } + if tbl.Len() != 2 { + t.Errorf("len should be 2, got %d", tbl.Len()) + } +} + +func TestGetName(t *testing.T) { + tbl := NewTable() + tbl.Add("hello") + tbl.Add("world") + + name, err := tbl.GetName(0) + if err != nil || name != "hello" { + t.Errorf("GetName(0) = %q, %v; want 'hello', nil", name, err) + } + name, err = tbl.GetName(1) + if err != nil || name != "world" { + t.Errorf("GetName(1) = %q, %v; want 'world', nil", name, err) + } + _, err = tbl.GetName(99) + if err == nil { + t.Error("GetName(99) should return error") + } +} + +func TestGetIDNotFound(t *testing.T) { + tbl := NewTable() + if got := tbl.GetID("missing"); got != -1 { + t.Errorf("GetID('missing') = %d; want -1", got) + } +} diff --git a/pkg/trace/parser.go b/pkg/trace/parser.go new file mode 100644 index 0000000..415ccd8 --- /dev/null +++ b/pkg/trace/parser.go @@ -0,0 +1,318 @@ +package trace + +import ( + "compress/gzip" + "encoding/json" + "fmt" + "math" + "os" + "path/filepath" + "regexp" + "sort" + "strconv" + "strings" +) + +// RawEvent represents one trace event extracted from the JSON. +type RawEvent struct { + Cat string + Name string + Ts int64 // microseconds + Dur int64 // microseconds + Pid int64 + Tid int64 + Stream int + Correlation int + MemoryBWGbps float64 + + // CUDA sync fields (used by critical path analysis) + WaitOnCudaEventRecordCorrID int // -1 if not present + WaitOnStream int // -1 if not present + + ExtraArgs map[string]float64 // additional numeric args (e.g. CUPTI counters) +} + +// DeviceProperty holds GPU device information from the trace metadata. 
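+//
+// An illustrative deviceProperties JSON entry this struct is meant to decode
+// (values are made up):
+//
+//	{"id": 0, "name": "NVIDIA A100-SXM4-40GB", "totalGlobalMem": 42505273344,
+//	 "computeMajor": 8, "computeMinor": 0}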
+type DeviceProperty struct { + ID int `json:"id"` + Name string `json:"name"` + TotalMem int64 `json:"totalGlobalMem"` + ComputeMajor int `json:"computeMajor"` + ComputeMinor int `json:"computeMinor"` +} + +// Metadata holds the non-event fields from a trace file. +type Metadata struct { + Rank int + TraceFile string + SchemaVersion int + Backend string + WorldSize int + DeviceName string + DeviceTotalMem int64 + DeviceCompCap string +} + +// RankTrace holds parsed data for a single rank. +type RankTrace struct { + Meta Metadata + Events []RawEvent +} + +// DiscoverFiles finds *.json and *.json.gz files in dir. +func DiscoverFiles(dir string) ([]string, error) { + var files []string + entries, err := os.ReadDir(dir) + if err != nil { + return nil, fmt.Errorf("reading trace dir: %w", err) + } + for _, e := range entries { + if e.IsDir() { + continue + } + name := e.Name() + if strings.HasSuffix(name, ".json") || strings.HasSuffix(name, ".json.gz") { + files = append(files, filepath.Join(dir, name)) + } + } + sort.Strings(files) + return files, nil +} + +var rankFromFilename = regexp.MustCompile(`rank[_-](\d+)`) + +// ParseFile reads a single JSON/GZ trace file and returns parsed events + metadata. +func ParseFile(path string) (*RankTrace, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + var dec *json.Decoder + if strings.HasSuffix(path, ".gz") { + gz, err := gzip.NewReader(f) + if err != nil { + return nil, fmt.Errorf("gzip open: %w", err) + } + defer gz.Close() + dec = json.NewDecoder(gz) + } else { + dec = json.NewDecoder(f) + } + dec.UseNumber() // preserve int64 precision for timestamps + + // Decode to a raw map + var raw map[string]json.RawMessage + if err := dec.Decode(&raw); err != nil { + return nil, fmt.Errorf("json decode: %w", err) + } + + // Parse metadata + meta := Metadata{ + Rank: -1, + TraceFile: path, + } + + if v, ok := raw["schemaVersion"]; ok { + var sv int + if err := json.Unmarshal(v, &sv); err == nil { + meta.SchemaVersion = sv + } + } + + if v, ok := raw["distributedInfo"]; ok { + var di struct { + Backend string `json:"backend"` + Rank int `json:"rank"` + WorldSize int `json:"world_size"` + } + if err := json.Unmarshal(v, &di); err == nil { + meta.Backend = di.Backend + meta.Rank = di.Rank + meta.WorldSize = di.WorldSize + } + } + + // Fallback rank from filename + if meta.Rank < 0 { + if m := rankFromFilename.FindStringSubmatch(filepath.Base(path)); m != nil { + if r, err := strconv.Atoi(m[1]); err == nil { + meta.Rank = r + } + } + } + if meta.Rank < 0 { + meta.Rank = 0 + } + + if v, ok := raw["deviceProperties"]; ok { + var props []DeviceProperty + if err := json.Unmarshal(v, &props); err == nil && len(props) > 0 { + meta.DeviceName = props[0].Name + meta.DeviceTotalMem = props[0].TotalMem + meta.DeviceCompCap = fmt.Sprintf("%d.%d", props[0].ComputeMajor, props[0].ComputeMinor) + } + } + + // Parse trace events + eventsRaw, ok := raw["traceEvents"] + if !ok { + return &RankTrace{Meta: meta}, nil + } + + var rawEvents []json.RawMessage + if err := json.Unmarshal(eventsRaw, &rawEvents); err != nil { + return nil, fmt.Errorf("parsing traceEvents array: %w", err) + } + + events := make([]RawEvent, 0, len(rawEvents)/2) // rough capacity guess + + for _, re := range rawEvents { + var obj map[string]json.RawMessage + if err := json.Unmarshal(re, &obj); err != nil { + continue + } + + // Must have both dur and cat + durRaw, hasDur := obj["dur"] + catRaw, hasCat := obj["cat"] + if !hasDur || !hasCat { + continue + } + + 
var cat string + if err := json.Unmarshal(catRaw, &cat); err != nil { + continue + } + + // Skip python_function and Trace categories + if cat == "python_function" || cat == "Trace" { + continue + } + + // Parse duration (can be int or float) + dur, err := parseNumber(durRaw) + if err != nil { + continue + } + + // Parse timestamp with full precision + tsRaw, hasTs := obj["ts"] + if !hasTs { + continue + } + ts, err := parseNumber(tsRaw) + if err != nil { + continue + } + + var name string + if v, ok := obj["name"]; ok { + json.Unmarshal(v, &name) + } + + var pid, tid int64 + if v, ok := obj["pid"]; ok { + pid, _ = parseNumber(v) + } + if v, ok := obj["tid"]; ok { + tid, _ = parseNumber(v) + } + + // Extract stream, correlation, memory bandwidth, and CUDA sync fields from args + stream := -1 + correlation := -1 + waitOnCudaEventRecordCorrID := -1 + waitOnStream := -1 + var memoryBWGbps float64 + var extraArgs map[string]float64 + if argsRaw, ok := obj["args"]; ok { + var args map[string]json.RawMessage + if json.Unmarshal(argsRaw, &args) == nil { + if v, ok := args["stream"]; ok { + if s, err := parseNumber(v); err == nil { + stream = int(s) + } + } + if v, ok := args["correlation"]; ok { + if c, err := parseNumber(v); err == nil { + correlation = int(c) + } + } + if v, ok := args["memory bandwidth (GB/s)"]; ok { + var n json.Number + if json.Unmarshal(v, &n) == nil { + if f, err := n.Float64(); err == nil { + memoryBWGbps = f + } + } + } + if v, ok := args["wait_on_cuda_event_record_corr_id"]; ok { + if c, err := parseNumber(v); err == nil { + waitOnCudaEventRecordCorrID = int(c) + } + } + if v, ok := args["wait_on_stream"]; ok { + if s, err := parseNumber(v); err == nil { + waitOnStream = int(s) + } + } + // Collect remaining numeric args (e.g. CUPTI counters). + knownArgs := map[string]bool{ + "stream": true, "correlation": true, "memory bandwidth (GB/s)": true, + "wait_on_cuda_event_record_corr_id": true, "wait_on_stream": true, + } + for k, v := range args { + if knownArgs[k] { + continue + } + var n json.Number + if json.Unmarshal(v, &n) == nil { + if f, err := n.Float64(); err == nil { + if extraArgs == nil { + extraArgs = make(map[string]float64) + } + extraArgs[k] = f + } + } + } + } + } + + events = append(events, RawEvent{ + Cat: cat, + Name: name, + Ts: ts, + Dur: dur, + Pid: pid, + Tid: tid, + Stream: stream, + Correlation: correlation, + MemoryBWGbps: memoryBWGbps, + WaitOnCudaEventRecordCorrID: waitOnCudaEventRecordCorrID, + WaitOnStream: waitOnStream, + ExtraArgs: extraArgs, + }) + } + + return &RankTrace{Meta: meta, Events: events}, nil +} + +// parseNumber unmarshals a JSON number (int or float) to int64. +func parseNumber(raw json.RawMessage) (int64, error) { + var n json.Number + if err := json.Unmarshal(raw, &n); err != nil { + return 0, err + } + // Try int first + if i, err := n.Int64(); err == nil { + return i, nil + } + // Fall back to float + f, err := n.Float64() + if err != nil { + return 0, err + } + return int64(math.Round(f)), nil +} From bbd16fd0bba1430abe4c5bf8bfecdce519afa2f5 Mon Sep 17 00:00:00 2001 From: Lei Lei Date: Wed, 11 Mar 2026 13:54:57 +0800 Subject: [PATCH 2/2] Fetch test data via sparse checkout from upstream HTA repo in CI The tests/ symlink points to a local path unavailable in CI. Instead, sparse-clone only tests/data/ from the public facebookresearch/HolisticTraceAnalysis repo. 
Co-Authored-By: Claude Opus 4.6 --- .github/workflows/ci-go.yml | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci-go.yml b/.github/workflows/ci-go.yml index bc8c435..5e2adfb 100644 --- a/.github/workflows/ci-go.yml +++ b/.github/workflows/ci-go.yml @@ -20,8 +20,14 @@ jobs: runs-on: ubuntu-latest-m steps: - uses: actions/checkout@v4 - with: - submodules: recursive + - name: Fetch test data + run: | + git clone --depth 1 --filter=blob:none --sparse \ + https://github.com/facebookresearch/HolisticTraceAnalysis.git /tmp/hta + cd /tmp/hta + git sparse-checkout set tests/data + rm -f "$GITHUB_WORKSPACE/tests" + ln -s /tmp/hta/tests "$GITHUB_WORKSPACE/tests" - uses: actions/setup-go@v5 with: go-version-file: go.mod