Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions code_review_graph/event_resolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""Post-build Spring Application Event resolver.

Connects publishEvent() call sites to @EventListener methods by matching
on the event class name.

Resolution chain:
publisher_method →(PUBLISHES)→ event:XxxEvent
listener_method →(HANDLES)→ event:XxxEvent
⟹ emit CALLS edge: publisher_method → listener_method
"""

from __future__ import annotations

import json
import logging
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from .graph import GraphStore

logger = logging.getLogger(__name__)


def resolve_spring_events(store: GraphStore) -> dict:
"""Emit CALLS edges from event publishers to matching @EventListener methods.

Safe to call multiple times — already-resolved edges are skipped via
extra.event_resolved flag.

Returns a dict with resolution counts for telemetry.
"""
conn = store._conn

# Only process Java files
java_files: set[str] = {
row["file_path"]
for row in conn.execute(
"SELECT DISTINCT file_path FROM nodes WHERE language = 'java'"
).fetchall()
}
if not java_files:
return {"files_indexed": 0, "calls_emitted": 0}

# Build: event_type → [listener_method_qualified]
listeners: dict[str, list[str]] = {}
for row in conn.execute(
"SELECT source_qualified, extra FROM edges WHERE kind = 'HANDLES'"
).fetchall():
try:
extra = json.loads(row["extra"] or "{}")
except (json.JSONDecodeError, TypeError):
extra = {}
event_type = extra.get("event_type")
if event_type:
listeners.setdefault(event_type, []).append(row["source_qualified"])

if not listeners:
logger.info("Event resolver: no HANDLES edges found, skipping")
return {"files_indexed": len(java_files), "calls_emitted": 0}

# Collect PUBLISHES edges and emit CALLS for each matching listener
publishes_rows = conn.execute(
"SELECT id, source_qualified, extra, file_path FROM edges WHERE kind = 'PUBLISHES'"
).fetchall()

emitted = 0
new_edges: list[tuple] = []

for row in publishes_rows:
if row["file_path"] not in java_files:
continue
try:
extra = json.loads(row["extra"] or "{}")
except (json.JSONDecodeError, TypeError):
extra = {}

if extra.get("event_resolved"):
continue

event_type = extra.get("event_type")
if not event_type:
continue

matching_listeners = listeners.get(event_type, [])
for listener_qual in matching_listeners:
call_extra = json.dumps({
"event_resolved": True,
"event_type": event_type,
})
new_edges.append((
"CALLS",
row["source_qualified"],
listener_qual,
row["source_qualified"],
listener_qual,
row["file_path"],
call_extra,
))
emitted += 1
logger.debug(
"Event resolved: %s →[%s]→ %s",
row["source_qualified"], event_type, listener_qual,
)

# Mark PUBLISHES edge as processed
extra["event_resolved"] = True
conn.execute(
"UPDATE edges SET extra = ? WHERE id = ?",
(json.dumps(extra), row["id"]),
)

if new_edges:
conn.executemany(
"INSERT OR IGNORE INTO edges "
"(kind, source_qualified, target_qualified, file_path, extra) "
"VALUES (?, ?, ?, ?, ?)",
[(e[0], e[1], e[2], e[5], e[6]) for e in new_edges],
)
conn.commit()

logger.info(
"Spring event resolver: emitted %d CALLS edges in %d Java files",
emitted, len(java_files),
)
return {"files_indexed": len(java_files), "calls_emitted": emitted}
16 changes: 16 additions & 0 deletions code_review_graph/incremental.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,18 @@ def _run_temporal_resolver(store: GraphStore) -> Optional[dict]:
logger.warning("Temporal resolver failed: %s", exc)
return None


def _run_event_resolver(store: GraphStore) -> Optional[dict]:
"""Run the Spring Application Event resolver, swallowing any failure so
build never fails because of it. Returns stats or None on error.
"""
try:
from .event_resolver import resolve_spring_events
return resolve_spring_events(store)
except Exception as exc: # noqa: BLE001
logger.warning("Spring event resolver failed: %s", exc)
return None

# Default ignore patterns (in addition to .gitignore).
#
# `<dir>/**` patterns are matched at any depth by _should_ignore, so
Expand Down Expand Up @@ -904,6 +916,7 @@ def full_build(
rescript_stats = _run_rescript_resolver(store)
spring_stats = _run_spring_resolver(store)
temporal_stats = _run_temporal_resolver(store)
event_stats = _run_event_resolver(store)

return {
"files_parsed": len(files),
Expand All @@ -913,6 +926,7 @@ def full_build(
"rescript_resolution": rescript_stats,
"spring_resolution": spring_stats,
"temporal_resolution": temporal_stats,
"event_resolution": event_stats,
}


Expand Down Expand Up @@ -1042,6 +1056,7 @@ def incremental_update(
spring_changed = any(rp.endswith(".java") for rp in all_files)
spring_stats = _run_spring_resolver(store) if spring_changed else None
temporal_stats = _run_temporal_resolver(store) if spring_changed else None
event_stats = _run_event_resolver(store) if spring_changed else None

return {
"files_updated": len(all_files),
Expand All @@ -1053,6 +1068,7 @@ def incremental_update(
"rescript_resolution": rescript_stats,
"spring_resolution": spring_stats,
"temporal_resolution": temporal_stats,
"event_resolution": event_stats,
}


Expand Down
172 changes: 172 additions & 0 deletions code_review_graph/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,11 @@ def _builtin_language_names() -> frozenset[str]:
"KafkaSender",
})

# Spring Application Event annotations
_SPRING_EVENT_LISTENER_ANNOTATIONS = frozenset({"EventListener"})
# Spring publish methods
_SPRING_PUBLISH_METHODS = frozenset({"publishEvent", "multicastEvent"})


# ---------------------------------------------------------------------------
# ReScript regex patterns and helpers (no tree-sitter grammar bundled)
Expand Down Expand Up @@ -4111,6 +4116,162 @@ def _get_kafka_annotation_topics(annotation_node) -> list[str]:
topics.append(raw)
return topics

def _emit_event_listener_from_method(
self,
method_node,
method_name: str,
class_name: Optional[str],
file_path: str,
edges: list[EdgeInfo],
) -> None:
"""Emit HANDLES edge for @EventListener annotated methods.

Event type is inferred from:
1. @EventListener(OrderEvent.class) or @EventListener(classes = {A.class, B.class})
2. First method parameter type (when annotation has no args)
"""
for child in method_node.children:
if child.type != "modifiers":
continue
for mod in child.children:
if mod.type not in ("annotation", "marker_annotation"):
continue
ann_name: Optional[str] = None
for sub in mod.children:
if sub.type == "identifier":
ann_name = sub.text.decode("utf-8", errors="replace")
break
if ann_name not in _SPRING_EVENT_LISTENER_ANNOTATIONS:
continue

event_types: list[str] = []

# Try annotation arguments first
for arg_list in mod.children:
if arg_list.type != "annotation_argument_list":
continue
for elem in arg_list.children:
if elem.type == "class_literal":
for typ in elem.children:
if typ.type == "type_identifier":
event_types.append(
typ.text.decode("utf-8", errors="replace")
)
break
elif elem.type == "element_value_pair":
key_node = next(
(c for c in elem.children if c.type == "identifier"), None
)
if key_node and key_node.text.decode("utf-8", errors="replace") in (
"value", "classes"
):
for val in elem.children:
if val.type == "class_literal":
for typ in val.children:
if typ.type == "type_identifier":
event_types.append(
typ.text.decode("utf-8", errors="replace")
)
break
elif val.type in (
"array_initializer",
"element_value_array_initializer",
):
for item in val.children:
if item.type == "class_literal":
for typ in item.children:
if typ.type == "type_identifier":
event_types.append(
typ.text.decode(
"utf-8", errors="replace"
)
)
break

# Fall back to first method parameter type
if not event_types:
for param_list in method_node.children:
if param_list.type != "formal_parameters":
continue
for param in param_list.children:
if param.type == "formal_parameter":
for typ in param.children:
if typ.type == "type_identifier":
event_types.append(
typ.text.decode("utf-8", errors="replace")
)
break
break

qualified_source = self._qualify(method_name, file_path, class_name)
for event_type in event_types:
edges.append(EdgeInfo(
kind="HANDLES",
source=qualified_source,
target=f"event:{event_type}",
file_path=file_path,
line=method_node.start_point[0] + 1,
extra={"event_type": event_type},
))

def _emit_publish_event_edges_from_method(
self,
method_node,
method_name: str,
class_name: Optional[str],
file_path: str,
edges: list[EdgeInfo],
) -> None:
"""Walk a method body and emit PUBLISHES edges for publishEvent() calls."""
body = next(
(c for c in method_node.children if c.type == "block"), None
)
if not body:
return
qualified_source = self._qualify(method_name, file_path, class_name)
self._walk_and_emit_publish_events(body, qualified_source, file_path, edges)

def _walk_and_emit_publish_events(
self,
node,
qualified_source: str,
file_path: str,
edges: list[EdgeInfo],
) -> None:
"""Recursively scan an AST subtree for publishEvent(new XxxEvent()) calls."""
if node.type == "method_invocation":
# For chained calls like `eventPublisher.publishEvent(...)` the
# children are [receiver_identifier, '.', method_identifier, argument_list].
# Take the last identifier before the argument_list — that is the
# actual method name, not the receiver.
ident_nodes = [c for c in node.children if c.type == "identifier"]
method_name_node = ident_nodes[-1] if ident_nodes else None
if method_name_node:
call_name = method_name_node.text.decode("utf-8", errors="replace")
if call_name in _SPRING_PUBLISH_METHODS:
args = next(
(c for c in node.children if c.type == "argument_list"), None
)
if args:
for arg in args.children:
if arg.type == "object_creation_expression":
for typ in arg.children:
if typ.type == "type_identifier":
event_type = typ.text.decode(
"utf-8", errors="replace"
)
edges.append(EdgeInfo(
kind="PUBLISHES",
source=qualified_source,
target=f"event:{event_type}",
file_path=file_path,
line=node.start_point[0] + 1,
extra={"event_type": event_type},
))
break
for child in node.children:
self._walk_and_emit_publish_events(child, qualified_source, file_path, edges)

def _emit_kafka_edges_from_class(
self,
class_node,
Expand Down Expand Up @@ -4444,6 +4605,17 @@ def _extract_functions(
self._emit_kafka_edges_from_method(
child, name, enclosing_class, file_path, edges,
)
if any(a.split("(")[0] in _SPRING_EVENT_LISTENER_ANNOTATIONS for a in deco_list):
method_extra["event_listener"] = True
self._emit_event_listener_from_method(
child, name, enclosing_class, file_path, edges,
)

# Detect publishEvent() calls in any Java method body
if language == "java" and enclosing_class:
self._emit_publish_event_edges_from_method(
child, name, enclosing_class, file_path, edges,
)

node = NodeInfo(
kind=kind,
Expand Down
Loading