diff --git a/code_review_graph/event_resolver.py b/code_review_graph/event_resolver.py new file mode 100644 index 00000000..489780db --- /dev/null +++ b/code_review_graph/event_resolver.py @@ -0,0 +1,125 @@ +"""Post-build Spring Application Event resolver. + +Connects publishEvent() call sites to @EventListener methods by matching +on the event class name. + +Resolution chain: + publisher_method →(PUBLISHES)→ event:XxxEvent + listener_method →(HANDLES)→ event:XxxEvent + ⟹ emit CALLS edge: publisher_method → listener_method +""" + +from __future__ import annotations + +import json +import logging +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .graph import GraphStore + +logger = logging.getLogger(__name__) + + +def resolve_spring_events(store: GraphStore) -> dict: + """Emit CALLS edges from event publishers to matching @EventListener methods. + + Safe to call multiple times — already-resolved edges are skipped via + extra.event_resolved flag. + + Returns a dict with resolution counts for telemetry. + """ + conn = store._conn + + # Only process Java files + java_files: set[str] = { + row["file_path"] + for row in conn.execute( + "SELECT DISTINCT file_path FROM nodes WHERE language = 'java'" + ).fetchall() + } + if not java_files: + return {"files_indexed": 0, "calls_emitted": 0} + + # Build: event_type → [listener_method_qualified] + listeners: dict[str, list[str]] = {} + for row in conn.execute( + "SELECT source_qualified, extra FROM edges WHERE kind = 'HANDLES'" + ).fetchall(): + try: + extra = json.loads(row["extra"] or "{}") + except (json.JSONDecodeError, TypeError): + extra = {} + event_type = extra.get("event_type") + if event_type: + listeners.setdefault(event_type, []).append(row["source_qualified"]) + + if not listeners: + logger.info("Event resolver: no HANDLES edges found, skipping") + return {"files_indexed": len(java_files), "calls_emitted": 0} + + # Collect PUBLISHES edges and emit CALLS for each matching listener + publishes_rows = conn.execute( + "SELECT id, source_qualified, extra, file_path FROM edges WHERE kind = 'PUBLISHES'" + ).fetchall() + + emitted = 0 + new_edges: list[tuple] = [] + + for row in publishes_rows: + if row["file_path"] not in java_files: + continue + try: + extra = json.loads(row["extra"] or "{}") + except (json.JSONDecodeError, TypeError): + extra = {} + + if extra.get("event_resolved"): + continue + + event_type = extra.get("event_type") + if not event_type: + continue + + matching_listeners = listeners.get(event_type, []) + for listener_qual in matching_listeners: + call_extra = json.dumps({ + "event_resolved": True, + "event_type": event_type, + }) + new_edges.append(( + "CALLS", + row["source_qualified"], + listener_qual, + row["source_qualified"], + listener_qual, + row["file_path"], + call_extra, + )) + emitted += 1 + logger.debug( + "Event resolved: %s →[%s]→ %s", + row["source_qualified"], event_type, listener_qual, + ) + + # Mark PUBLISHES edge as processed + extra["event_resolved"] = True + conn.execute( + "UPDATE edges SET extra = ? WHERE id = ?", + (json.dumps(extra), row["id"]), + ) + + if new_edges: + conn.executemany( + "INSERT OR IGNORE INTO edges " + "(kind, source_qualified, target_qualified, file_path, extra) " + "VALUES (?, ?, ?, ?, ?)", + [(e[0], e[1], e[2], e[5], e[6]) for e in new_edges], + ) + conn.commit() + + logger.info( + "Spring event resolver: emitted %d CALLS edges in %d Java files", + emitted, len(java_files), + ) + return {"files_indexed": len(java_files), "calls_emitted": emitted} diff --git a/code_review_graph/incremental.py b/code_review_graph/incremental.py index 81dc3026..432c0ab2 100644 --- a/code_review_graph/incremental.py +++ b/code_review_graph/incremental.py @@ -93,6 +93,18 @@ def _run_temporal_resolver(store: GraphStore) -> Optional[dict]: logger.warning("Temporal resolver failed: %s", exc) return None + +def _run_event_resolver(store: GraphStore) -> Optional[dict]: + """Run the Spring Application Event resolver, swallowing any failure so + build never fails because of it. Returns stats or None on error. + """ + try: + from .event_resolver import resolve_spring_events + return resolve_spring_events(store) + except Exception as exc: # noqa: BLE001 + logger.warning("Spring event resolver failed: %s", exc) + return None + # Default ignore patterns (in addition to .gitignore). # # `/**` patterns are matched at any depth by _should_ignore, so @@ -904,6 +916,7 @@ def full_build( rescript_stats = _run_rescript_resolver(store) spring_stats = _run_spring_resolver(store) temporal_stats = _run_temporal_resolver(store) + event_stats = _run_event_resolver(store) return { "files_parsed": len(files), @@ -913,6 +926,7 @@ def full_build( "rescript_resolution": rescript_stats, "spring_resolution": spring_stats, "temporal_resolution": temporal_stats, + "event_resolution": event_stats, } @@ -1042,6 +1056,7 @@ def incremental_update( spring_changed = any(rp.endswith(".java") for rp in all_files) spring_stats = _run_spring_resolver(store) if spring_changed else None temporal_stats = _run_temporal_resolver(store) if spring_changed else None + event_stats = _run_event_resolver(store) if spring_changed else None return { "files_updated": len(all_files), @@ -1053,6 +1068,7 @@ def incremental_update( "rescript_resolution": rescript_stats, "spring_resolution": spring_stats, "temporal_resolution": temporal_stats, + "event_resolution": event_stats, } diff --git a/code_review_graph/parser.py b/code_review_graph/parser.py index c55b2e8f..cbeb067d 100644 --- a/code_review_graph/parser.py +++ b/code_review_graph/parser.py @@ -518,6 +518,11 @@ def _builtin_language_names() -> frozenset[str]: "KafkaSender", }) +# Spring Application Event annotations +_SPRING_EVENT_LISTENER_ANNOTATIONS = frozenset({"EventListener"}) +# Spring publish methods +_SPRING_PUBLISH_METHODS = frozenset({"publishEvent", "multicastEvent"}) + # --------------------------------------------------------------------------- # ReScript regex patterns and helpers (no tree-sitter grammar bundled) @@ -4111,6 +4116,162 @@ def _get_kafka_annotation_topics(annotation_node) -> list[str]: topics.append(raw) return topics + def _emit_event_listener_from_method( + self, + method_node, + method_name: str, + class_name: Optional[str], + file_path: str, + edges: list[EdgeInfo], + ) -> None: + """Emit HANDLES edge for @EventListener annotated methods. + + Event type is inferred from: + 1. @EventListener(OrderEvent.class) or @EventListener(classes = {A.class, B.class}) + 2. First method parameter type (when annotation has no args) + """ + for child in method_node.children: + if child.type != "modifiers": + continue + for mod in child.children: + if mod.type not in ("annotation", "marker_annotation"): + continue + ann_name: Optional[str] = None + for sub in mod.children: + if sub.type == "identifier": + ann_name = sub.text.decode("utf-8", errors="replace") + break + if ann_name not in _SPRING_EVENT_LISTENER_ANNOTATIONS: + continue + + event_types: list[str] = [] + + # Try annotation arguments first + for arg_list in mod.children: + if arg_list.type != "annotation_argument_list": + continue + for elem in arg_list.children: + if elem.type == "class_literal": + for typ in elem.children: + if typ.type == "type_identifier": + event_types.append( + typ.text.decode("utf-8", errors="replace") + ) + break + elif elem.type == "element_value_pair": + key_node = next( + (c for c in elem.children if c.type == "identifier"), None + ) + if key_node and key_node.text.decode("utf-8", errors="replace") in ( + "value", "classes" + ): + for val in elem.children: + if val.type == "class_literal": + for typ in val.children: + if typ.type == "type_identifier": + event_types.append( + typ.text.decode("utf-8", errors="replace") + ) + break + elif val.type in ( + "array_initializer", + "element_value_array_initializer", + ): + for item in val.children: + if item.type == "class_literal": + for typ in item.children: + if typ.type == "type_identifier": + event_types.append( + typ.text.decode( + "utf-8", errors="replace" + ) + ) + break + + # Fall back to first method parameter type + if not event_types: + for param_list in method_node.children: + if param_list.type != "formal_parameters": + continue + for param in param_list.children: + if param.type == "formal_parameter": + for typ in param.children: + if typ.type == "type_identifier": + event_types.append( + typ.text.decode("utf-8", errors="replace") + ) + break + break + + qualified_source = self._qualify(method_name, file_path, class_name) + for event_type in event_types: + edges.append(EdgeInfo( + kind="HANDLES", + source=qualified_source, + target=f"event:{event_type}", + file_path=file_path, + line=method_node.start_point[0] + 1, + extra={"event_type": event_type}, + )) + + def _emit_publish_event_edges_from_method( + self, + method_node, + method_name: str, + class_name: Optional[str], + file_path: str, + edges: list[EdgeInfo], + ) -> None: + """Walk a method body and emit PUBLISHES edges for publishEvent() calls.""" + body = next( + (c for c in method_node.children if c.type == "block"), None + ) + if not body: + return + qualified_source = self._qualify(method_name, file_path, class_name) + self._walk_and_emit_publish_events(body, qualified_source, file_path, edges) + + def _walk_and_emit_publish_events( + self, + node, + qualified_source: str, + file_path: str, + edges: list[EdgeInfo], + ) -> None: + """Recursively scan an AST subtree for publishEvent(new XxxEvent()) calls.""" + if node.type == "method_invocation": + # For chained calls like `eventPublisher.publishEvent(...)` the + # children are [receiver_identifier, '.', method_identifier, argument_list]. + # Take the last identifier before the argument_list — that is the + # actual method name, not the receiver. + ident_nodes = [c for c in node.children if c.type == "identifier"] + method_name_node = ident_nodes[-1] if ident_nodes else None + if method_name_node: + call_name = method_name_node.text.decode("utf-8", errors="replace") + if call_name in _SPRING_PUBLISH_METHODS: + args = next( + (c for c in node.children if c.type == "argument_list"), None + ) + if args: + for arg in args.children: + if arg.type == "object_creation_expression": + for typ in arg.children: + if typ.type == "type_identifier": + event_type = typ.text.decode( + "utf-8", errors="replace" + ) + edges.append(EdgeInfo( + kind="PUBLISHES", + source=qualified_source, + target=f"event:{event_type}", + file_path=file_path, + line=node.start_point[0] + 1, + extra={"event_type": event_type}, + )) + break + for child in node.children: + self._walk_and_emit_publish_events(child, qualified_source, file_path, edges) + def _emit_kafka_edges_from_class( self, class_node, @@ -4444,6 +4605,17 @@ def _extract_functions( self._emit_kafka_edges_from_method( child, name, enclosing_class, file_path, edges, ) + if any(a.split("(")[0] in _SPRING_EVENT_LISTENER_ANNOTATIONS for a in deco_list): + method_extra["event_listener"] = True + self._emit_event_listener_from_method( + child, name, enclosing_class, file_path, edges, + ) + + # Detect publishEvent() calls in any Java method body + if language == "java" and enclosing_class: + self._emit_publish_event_edges_from_method( + child, name, enclosing_class, file_path, edges, + ) node = NodeInfo( kind=kind, diff --git a/tests/fixtures/SpringEvents.java b/tests/fixtures/SpringEvents.java new file mode 100644 index 00000000..3dd5d848 --- /dev/null +++ b/tests/fixtures/SpringEvents.java @@ -0,0 +1,46 @@ +package com.example.events; + +import org.springframework.context.ApplicationEventPublisher; +import org.springframework.context.event.EventListener; +import org.springframework.stereotype.Component; +import org.springframework.stereotype.Service; + +// Plain event class — no Spring annotation +class OrderPlacedEvent { + private final Long orderId; + public OrderPlacedEvent(Long orderId) { this.orderId = orderId; } + public Long getOrderId() { return orderId; } +} + +// Publisher service +@Service +class OrderService { + private final ApplicationEventPublisher eventPublisher; + + public OrderService(ApplicationEventPublisher eventPublisher) { + this.eventPublisher = eventPublisher; + } + + public void placeOrder(Long orderId) { + // business logic + eventPublisher.publishEvent(new OrderPlacedEvent(orderId)); + } +} + +// Listener — infers event type from parameter +@Component +class NotificationListener { + @EventListener + public void onOrderPlaced(OrderPlacedEvent event) { + // send notification + } +} + +// Listener — explicit annotation arg +@Component +class AuditListener { + @EventListener(OrderPlacedEvent.class) + public void auditOrder(OrderPlacedEvent event) { + // audit log + } +} diff --git a/tests/test_multilang.py b/tests/test_multilang.py index afda355e..1423234a 100644 --- a/tests/test_multilang.py +++ b/tests/test_multilang.py @@ -2041,6 +2041,46 @@ def test_contains_edges_wire_file_to_top_level_bindings(self): ) +class TestSpringEventListenerParsing: + """Tests for Spring @EventListener / publishEvent() edge detection.""" + + def setup_method(self): + self.parser = CodeParser() + self.nodes, self.edges = self.parser.parse_file(FIXTURES / "SpringEvents.java") + + def test_event_listener_inferred_from_parameter_emits_handles(self): + handles = [e for e in self.edges if e.kind == "HANDLES"] + targets = {e.target for e in handles} + assert "event:OrderPlacedEvent" in targets + + def test_event_listener_explicit_annotation_arg_emits_handles(self): + handles = [e for e in self.edges if e.kind == "HANDLES"] + sources = {e.source for e in handles} + # AuditListener uses explicit @EventListener(OrderPlacedEvent.class) + assert any("AuditListener" in s or "auditOrder" in s for s in sources) + + def test_publish_event_call_emits_publishes_edge(self): + publishes = [e for e in self.edges if e.kind == "PUBLISHES"] + assert publishes, "Expected at least one PUBLISHES edge" + targets = {e.target for e in publishes} + assert "event:OrderPlacedEvent" in targets + + def test_publish_event_source_is_publishing_method(self): + publishes = [e for e in self.edges if e.kind == "PUBLISHES"] + sources = {e.source for e in publishes} + assert any("placeOrder" in s or "OrderService" in s for s in sources) + + def test_handles_edge_stores_event_type_in_extra(self): + handles = [e for e in self.edges if e.kind == "HANDLES"] + event_types = {e.extra.get("event_type") for e in handles} + assert "OrderPlacedEvent" in event_types + + def test_publishes_edge_stores_event_type_in_extra(self): + publishes = [e for e in self.edges if e.kind == "PUBLISHES"] + event_types = {e.extra.get("event_type") for e in publishes} + assert "OrderPlacedEvent" in event_types + + class TestSpringDIParsing: """Tests for Spring DI annotation detection and INJECTS edge generation."""