Skip to content

Commit def5b96

Browse files
committed
Add UML image layout rendering
1 parent 7082cec commit def5b96

29 files changed

+1997
-24
lines changed

modules/modeling/module_modeling_llm/module_modeling_llm/core/generate_suggestions.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
1+
import base64
12
from athena.schemas.grading_criterion import StructuredGradingCriterion
23
from langchain_core.output_parsers import PydanticOutputParser
34
from langchain_core.prompts import ChatPromptTemplate
45

56
from athena import emit_meta
67
from module_modeling_llm.config import BasicApproachConfig
8+
from module_modeling_llm.helius_render.api import render_diagram
79
from module_modeling_llm.models.assessment_model import AssessmentModel
810
from module_modeling_llm.prompts.apollon_format_description import apollon_format_description
911
from llm_core.utils.predict_and_parse import predict_and_parse
1012
from module_modeling_llm.prompts.graded_feedback_prompt import GradedFeedbackInputs
1113
from module_modeling_llm.models.exercise_model import ExerciseModel
14+
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
1215

1316
async def generate_suggestions(
1417
exercise_model: ExerciseModel,
@@ -37,6 +40,19 @@ async def generate_suggestions(
3740
feedback_output_format=PydanticOutputParser(pydantic_object=AssessmentModel).get_format_instructions()
3841
)
3942

43+
diagram_json = exercise_model.model
44+
png_data = render_diagram(diagram_json, {value: key for key, value in exercise_model.element_id_mapping.items()})
45+
base64_image = base64.b64encode(png_data).decode("utf-8")
46+
47+
chat_prompt = ChatPromptTemplate.from_messages([
48+
("system", config.generate_suggestions_prompt.graded_feedback_system_message),
49+
("human", config.generate_suggestions_prompt.graded_feedback_human_message),
50+
HumanMessagePromptTemplate.from_template(
51+
[{'image_url': {'url': f'data:image/jpeg;base64,{base64_image}', 'detail': 'high'}}]
52+
)
53+
])
54+
55+
4056
chat_prompt = ChatPromptTemplate.from_messages([
4157
("system", config.generate_suggestions_prompt.graded_feedback_system_message),
4258
("human", config.generate_suggestions_prompt.graded_feedback_human_message)])
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import json
2+
from typing import Dict, Optional, cast
3+
from module_modeling_llm.helius_render.models.diagram import UMLDiagram
4+
from module_modeling_llm.helius_render.utils.config_loader import load_all_configs
5+
from module_modeling_llm.helius_render.renderers.uml_renderer import UMLRenderer
6+
from module_modeling_llm.helius_render.utils.css_loader import load_css
7+
8+
# Global initialization
9+
# Load configs and css once
10+
_CONFIGS = load_all_configs()
11+
_CSS = load_css()
12+
_RENDERER = UMLRenderer(_CONFIGS, _CSS)
13+
14+
def render_diagram(json_data: str, name_map: Optional[Dict[str, str]] = None) -> bytes:
15+
16+
# Parse diagram
17+
diagram_data = json.loads(json_data)
18+
diagram = cast(UMLDiagram, diagram_data)
19+
20+
if name_map:
21+
for elem in diagram['elements'].values():
22+
elem_id = elem['id']
23+
if elem_id in name_map:
24+
elem['name'] = name_map[elem_id]
25+
26+
for rel in diagram['relationships'].values():
27+
rel_id = rel['id']
28+
if rel_id in name_map:
29+
rel['name'] = name_map[rel_id]
30+
31+
# Render using the pre-initialized renderer
32+
png_data = _RENDERER.render_to_bytes(diagram)
33+
34+
return png_data
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"default": {
3+
"shape": "rectangle",
4+
"class_name": "uml-element",
5+
"text_class": "uml-element-name"
6+
}
7+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"default": {}
3+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
{
2+
"arrow": {
3+
"path": "M 0 0 L 10 5 L 0 10 z",
4+
"viewBox": "0 0 10 10",
5+
"refX": "9",
6+
"refY": "5",
7+
"fill": "black"
8+
}
9+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from dataclasses import dataclass
2+
3+
@dataclass
4+
class Bounds:
5+
x: float
6+
y: float
7+
width: float
8+
height: float
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from typing import TypedDict, Optional, Dict
2+
3+
class ElementConfigEntry(TypedDict):
4+
shape: str
5+
class_name: str
6+
text_class: str
7+
8+
class RelationshipConfigEntry(TypedDict):
9+
marker_end: Optional[str]
10+
stroke_dasharray: Optional[str]
11+
12+
class MarkerConfigEntry(TypedDict):
13+
path: str
14+
viewBox: str
15+
refX: str
16+
refY: str
17+
fill: str
18+
19+
ElementConfig = Dict[str, ElementConfigEntry]
20+
RelationshipConfig = Dict[str, RelationshipConfigEntry]
21+
MarkerConfig = Dict[str, MarkerConfigEntry]
22+
23+
class AllConfigs(TypedDict):
24+
elements: ElementConfig
25+
relationships: RelationshipConfig
26+
markers: MarkerConfig
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from typing import Dict, Any, TypedDict
2+
from .element import Element
3+
from .relationship import Relationship
4+
5+
class UMLDiagram(TypedDict):
6+
id: str
7+
title: str
8+
elements: Dict[str, Element]
9+
relationships: Dict[str, Relationship]
10+
version: str
11+
type: str
12+
size: Dict[str, int]
13+
interactive: Dict[str, Any]
14+
assessments: Dict[str, Any]
15+
lastUpdate: str
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from typing import Dict, List, Optional, TypedDict, Any
2+
3+
class Element(TypedDict):
4+
id: str
5+
type: str
6+
name: str
7+
owner: Optional[str]
8+
bounds: Dict[str, float]
9+
attributes: Optional[List[str]]
10+
methods: Optional[List[str]]
11+
properties: Optional[Dict[str, Any]]
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from typing import Dict, List, Optional, TypedDict
2+
3+
class Message(TypedDict):
4+
type: str
5+
name: str
6+
direction: str # 'target' or 'source'
7+
8+
class EndpointData(TypedDict):
9+
element: str
10+
multiplicity: Optional[str]
11+
role: Optional[str]
12+
direction: Optional[str]
13+
14+
class Relationship(TypedDict):
15+
id: str
16+
type: str
17+
name: str
18+
owner: Optional[str]
19+
source: EndpointData
20+
target: EndpointData
21+
path: List[Dict[str, float]]
22+
bounds: Dict[str, float]
23+
isManuallyLayouted: Optional[bool]
24+
stroke_dasharray: Optional[str]
25+
marker_start: Optional[str]
26+
marker_end: Optional[str]
27+
messages: Optional[List[Message]]
28+
_source_point: Optional[Dict[str, float]]
29+
_target_point: Optional[Dict[str, float]]
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from typing import Dict
2+
import xml.etree.ElementTree as ET
3+
4+
from jinja2 import Template
5+
from module_modeling_llm.helius_render.models.bounds import Bounds
6+
from module_modeling_llm.helius_render.models.config_types import ElementConfig, ElementConfigEntry
7+
from module_modeling_llm.helius_render.models.element import Element
8+
from module_modeling_llm.helius_render.utils.template_manager import TemplateManager
9+
10+
class ElementRenderer:
11+
"""
12+
Renders UML elements (like classes) into an SVG <g> element using a Jinja2 template.
13+
"""
14+
15+
def __init__(self, element_config: ElementConfig, template_manager: TemplateManager):
16+
self.element_config = element_config
17+
self.template_manager = template_manager
18+
19+
def render(self, element: Element, svg: ET.Element) -> None:
20+
"""
21+
Render a single UML element into the given SVG root.
22+
23+
Args:
24+
element (Element): The UML element to render.
25+
svg (ET.Element): The SVG root element to append to.
26+
elements_by_id (Dict[str, Element]): All elements keyed by ID (not always needed here).
27+
"""
28+
29+
elem_type = element.get('type', 'default')
30+
config: ElementConfigEntry = self.element_config.get(elem_type, self.element_config['default'])
31+
bounds = Bounds(**element['bounds'])
32+
33+
template: Template = self.template_manager.get_template('element.svg.jinja')
34+
svg_content: str = template.render(
35+
element=element,
36+
bounds=bounds,
37+
element_shape=config['shape'],
38+
element_class=config['class_name'],
39+
element_text_class=config['text_class']
40+
)
41+
group: ET.Element = ET.fromstring(svg_content)
42+
svg.append(group)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from typing import Dict, List, Tuple
2+
import xml.etree.ElementTree as ET
3+
from module_modeling_llm.helius_render.services.path_service import compute_relationship_path
4+
from module_modeling_llm.helius_render.models.element import Element
5+
from module_modeling_llm.helius_render.models.relationship import Relationship
6+
from module_modeling_llm.helius_render.models.config_types import RelationshipConfig
7+
from module_modeling_llm.helius_render.utils.template_manager import TemplateManager
8+
9+
class RelationshipRenderer:
10+
"""
11+
Renders UML relationships into SVG <path> elements
12+
"""
13+
def __init__(self, relationship_config: RelationshipConfig, template_manager: TemplateManager) -> None:
14+
self.rel_config = relationship_config
15+
self.template_manager = template_manager
16+
17+
def render_relationship(self, rel: Relationship, svg: ET.Element, elements_by_id: Dict[str, Element]) -> None:
18+
"""
19+
Render a UML relationship as an SVG path.
20+
Args:
21+
rel (Relationship): The relationship data.
22+
svg (ET.Element): The SVG parent element to append the path to.
23+
elements_by_id (Dict[str, Element]): Map of element IDs to element objects.
24+
Raises:
25+
ValueError: If source or target elements are missing.
26+
"""
27+
source_element = elements_by_id.get(rel['source']['element'])
28+
target_element = elements_by_id.get(rel['target']['element'])
29+
30+
if not source_element or not target_element:
31+
raise ValueError(f"Invalid relationship {rel['id']}, missing source or target.")
32+
33+
# Compute the path for the relationship
34+
rel['path'] = compute_relationship_path(source_element, target_element, rel)
35+
36+
# Compute a true midpoint along the entire polyline
37+
mid_x, mid_y = self._compute_midpoint_along_path(rel['path'])
38+
39+
template = self.template_manager.get_template('relationship_path.svg.jinja')
40+
svg_content = template.render(
41+
rel=rel,
42+
path_d=self._create_path_string(rel['path']),
43+
mid_x=mid_x,
44+
mid_y=mid_y
45+
)
46+
element = ET.fromstring(svg_content)
47+
svg.append(element)
48+
49+
def _create_path_string(self, points: List[Dict[str, float]]) -> str:
50+
if not points:
51+
return ""
52+
path = f"M {points[0]['x']} {points[0]['y']}"
53+
for p in points[1:]:
54+
path += f" L {p['x']} {p['y']}"
55+
return path
56+
57+
def _compute_midpoint_along_path(self, path_points: List[Dict[str, float]]) -> Tuple[float, float]:
58+
if not path_points:
59+
return (0,0)
60+
61+
# Compute total length of the polyline and store segments
62+
total_length = 0.0
63+
segments = []
64+
for i in range(len(path_points)-1):
65+
p1 = path_points[i]
66+
p2 = path_points[i+1]
67+
dx = p2['x'] - p1['x']
68+
dy = p2['y'] - p1['y']
69+
seg_length = (dx**2 + dy**2)**0.5
70+
segments.append((p1, p2, seg_length))
71+
total_length += seg_length
72+
73+
# Target distance is half of total length
74+
half_length = total_length / 2.0
75+
76+
# Walk along segments until we find the segment containing the midpoint
77+
distance_covered = 0.0
78+
for (start, end, seg_length) in segments:
79+
if distance_covered + seg_length == half_length:
80+
# Midpoint lies exactly at the end of this segment
81+
return (end['x'], end['y'])
82+
elif distance_covered + seg_length > half_length:
83+
# Midpoint lies within this segment
84+
remaining = half_length - distance_covered
85+
ratio = remaining / seg_length
86+
mid_x = start['x'] + ratio * (end['x'] - start['x'])
87+
mid_y = start['y'] + ratio * (end['y'] - start['y'])
88+
return (mid_x, mid_y)
89+
distance_covered += seg_length
90+
91+
# Fallback: if something went wrong, return last point
92+
return (path_points[-1]['x'], path_points[-1]['y'])

0 commit comments

Comments
 (0)