Skip to content

Commit

Permalink
Add UML image layout rendering
Browse files Browse the repository at this point in the history
  • Loading branch information
LeonWehrhahn committed Jan 4, 2025
1 parent 7082cec commit def5b96
Show file tree
Hide file tree
Showing 29 changed files with 1,997 additions and 24 deletions.
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import base64
from athena.schemas.grading_criterion import StructuredGradingCriterion
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import ChatPromptTemplate

from athena import emit_meta
from module_modeling_llm.config import BasicApproachConfig
from module_modeling_llm.helius_render.api import render_diagram
from module_modeling_llm.models.assessment_model import AssessmentModel
from module_modeling_llm.prompts.apollon_format_description import apollon_format_description
from llm_core.utils.predict_and_parse import predict_and_parse
from module_modeling_llm.prompts.graded_feedback_prompt import GradedFeedbackInputs
from module_modeling_llm.models.exercise_model import ExerciseModel
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate

async def generate_suggestions(
exercise_model: ExerciseModel,
Expand Down Expand Up @@ -37,6 +40,19 @@ async def generate_suggestions(
feedback_output_format=PydanticOutputParser(pydantic_object=AssessmentModel).get_format_instructions()
)

diagram_json = exercise_model.model
png_data = render_diagram(diagram_json, {value: key for key, value in exercise_model.element_id_mapping.items()})
base64_image = base64.b64encode(png_data).decode("utf-8")

chat_prompt = ChatPromptTemplate.from_messages([
("system", config.generate_suggestions_prompt.graded_feedback_system_message),
("human", config.generate_suggestions_prompt.graded_feedback_human_message),
HumanMessagePromptTemplate.from_template(
[{'image_url': {'url': f'data:image/jpeg;base64,{base64_image}', 'detail': 'high'}}]
)
])


chat_prompt = ChatPromptTemplate.from_messages([
("system", config.generate_suggestions_prompt.graded_feedback_system_message),
("human", config.generate_suggestions_prompt.graded_feedback_human_message)])
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import json
from typing import Dict, Optional, cast
from module_modeling_llm.helius_render.models.diagram import UMLDiagram
from module_modeling_llm.helius_render.utils.config_loader import load_all_configs
from module_modeling_llm.helius_render.renderers.uml_renderer import UMLRenderer
from module_modeling_llm.helius_render.utils.css_loader import load_css

# Global initialization
# Load configs and css once
_CONFIGS = load_all_configs()
_CSS = load_css()
_RENDERER = UMLRenderer(_CONFIGS, _CSS)

def render_diagram(json_data: str, name_map: Optional[Dict[str, str]] = None) -> bytes:

# Parse diagram
diagram_data = json.loads(json_data)
diagram = cast(UMLDiagram, diagram_data)

if name_map:
for elem in diagram['elements'].values():
elem_id = elem['id']
if elem_id in name_map:
elem['name'] = name_map[elem_id]

for rel in diagram['relationships'].values():
rel_id = rel['id']
if rel_id in name_map:
rel['name'] = name_map[rel_id]

# Render using the pre-initialized renderer
png_data = _RENDERER.render_to_bytes(diagram)

return png_data
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"default": {
"shape": "rectangle",
"class_name": "uml-element",
"text_class": "uml-element-name"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"default": {}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"arrow": {
"path": "M 0 0 L 10 5 L 0 10 z",
"viewBox": "0 0 10 10",
"refX": "9",
"refY": "5",
"fill": "black"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from dataclasses import dataclass

@dataclass
class Bounds:
x: float
y: float
width: float
height: float
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import TypedDict, Optional, Dict

class ElementConfigEntry(TypedDict):
shape: str
class_name: str
text_class: str

class RelationshipConfigEntry(TypedDict):
marker_end: Optional[str]
stroke_dasharray: Optional[str]

class MarkerConfigEntry(TypedDict):
path: str
viewBox: str
refX: str
refY: str
fill: str

ElementConfig = Dict[str, ElementConfigEntry]
RelationshipConfig = Dict[str, RelationshipConfigEntry]
MarkerConfig = Dict[str, MarkerConfigEntry]

class AllConfigs(TypedDict):
elements: ElementConfig
relationships: RelationshipConfig
markers: MarkerConfig
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import Dict, Any, TypedDict
from .element import Element
from .relationship import Relationship

class UMLDiagram(TypedDict):
id: str
title: str
elements: Dict[str, Element]
relationships: Dict[str, Relationship]
version: str
type: str
size: Dict[str, int]
interactive: Dict[str, Any]
assessments: Dict[str, Any]
lastUpdate: str
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from typing import Dict, List, Optional, TypedDict, Any

class Element(TypedDict):
id: str
type: str
name: str
owner: Optional[str]
bounds: Dict[str, float]
attributes: Optional[List[str]]
methods: Optional[List[str]]
properties: Optional[Dict[str, Any]]
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Dict, List, Optional, TypedDict

class Message(TypedDict):
type: str
name: str
direction: str # 'target' or 'source'

class EndpointData(TypedDict):
element: str
multiplicity: Optional[str]
role: Optional[str]
direction: Optional[str]

class Relationship(TypedDict):
id: str
type: str
name: str
owner: Optional[str]
source: EndpointData
target: EndpointData
path: List[Dict[str, float]]
bounds: Dict[str, float]
isManuallyLayouted: Optional[bool]
stroke_dasharray: Optional[str]
marker_start: Optional[str]
marker_end: Optional[str]
messages: Optional[List[Message]]
_source_point: Optional[Dict[str, float]]
_target_point: Optional[Dict[str, float]]
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from typing import Dict
import xml.etree.ElementTree as ET

from jinja2 import Template
from module_modeling_llm.helius_render.models.bounds import Bounds
from module_modeling_llm.helius_render.models.config_types import ElementConfig, ElementConfigEntry
from module_modeling_llm.helius_render.models.element import Element
from module_modeling_llm.helius_render.utils.template_manager import TemplateManager

class ElementRenderer:
"""
Renders UML elements (like classes) into an SVG <g> element using a Jinja2 template.
"""

def __init__(self, element_config: ElementConfig, template_manager: TemplateManager):
self.element_config = element_config
self.template_manager = template_manager

def render(self, element: Element, svg: ET.Element) -> None:
"""
Render a single UML element into the given SVG root.
Args:
element (Element): The UML element to render.
svg (ET.Element): The SVG root element to append to.
elements_by_id (Dict[str, Element]): All elements keyed by ID (not always needed here).
"""

elem_type = element.get('type', 'default')
config: ElementConfigEntry = self.element_config.get(elem_type, self.element_config['default'])
bounds = Bounds(**element['bounds'])

template: Template = self.template_manager.get_template('element.svg.jinja')
svg_content: str = template.render(
element=element,
bounds=bounds,
element_shape=config['shape'],
element_class=config['class_name'],
element_text_class=config['text_class']
)
group: ET.Element = ET.fromstring(svg_content)
svg.append(group)
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from typing import Dict, List, Tuple
import xml.etree.ElementTree as ET
from module_modeling_llm.helius_render.services.path_service import compute_relationship_path
from module_modeling_llm.helius_render.models.element import Element
from module_modeling_llm.helius_render.models.relationship import Relationship
from module_modeling_llm.helius_render.models.config_types import RelationshipConfig
from module_modeling_llm.helius_render.utils.template_manager import TemplateManager

class RelationshipRenderer:
"""
Renders UML relationships into SVG <path> elements
"""
def __init__(self, relationship_config: RelationshipConfig, template_manager: TemplateManager) -> None:
self.rel_config = relationship_config
self.template_manager = template_manager

def render_relationship(self, rel: Relationship, svg: ET.Element, elements_by_id: Dict[str, Element]) -> None:
"""
Render a UML relationship as an SVG path.
Args:
rel (Relationship): The relationship data.
svg (ET.Element): The SVG parent element to append the path to.
elements_by_id (Dict[str, Element]): Map of element IDs to element objects.
Raises:
ValueError: If source or target elements are missing.
"""
source_element = elements_by_id.get(rel['source']['element'])
target_element = elements_by_id.get(rel['target']['element'])

if not source_element or not target_element:
raise ValueError(f"Invalid relationship {rel['id']}, missing source or target.")

# Compute the path for the relationship
rel['path'] = compute_relationship_path(source_element, target_element, rel)

# Compute a true midpoint along the entire polyline
mid_x, mid_y = self._compute_midpoint_along_path(rel['path'])

template = self.template_manager.get_template('relationship_path.svg.jinja')
svg_content = template.render(
rel=rel,
path_d=self._create_path_string(rel['path']),
mid_x=mid_x,
mid_y=mid_y
)
element = ET.fromstring(svg_content)
svg.append(element)

def _create_path_string(self, points: List[Dict[str, float]]) -> str:
if not points:
return ""
path = f"M {points[0]['x']} {points[0]['y']}"
for p in points[1:]:
path += f" L {p['x']} {p['y']}"
return path

def _compute_midpoint_along_path(self, path_points: List[Dict[str, float]]) -> Tuple[float, float]:
if not path_points:
return (0,0)

# Compute total length of the polyline and store segments
total_length = 0.0
segments = []
for i in range(len(path_points)-1):
p1 = path_points[i]
p2 = path_points[i+1]
dx = p2['x'] - p1['x']
dy = p2['y'] - p1['y']
seg_length = (dx**2 + dy**2)**0.5
segments.append((p1, p2, seg_length))
total_length += seg_length

# Target distance is half of total length
half_length = total_length / 2.0

# Walk along segments until we find the segment containing the midpoint
distance_covered = 0.0
for (start, end, seg_length) in segments:
if distance_covered + seg_length == half_length:
# Midpoint lies exactly at the end of this segment
return (end['x'], end['y'])
elif distance_covered + seg_length > half_length:
# Midpoint lies within this segment
remaining = half_length - distance_covered
ratio = remaining / seg_length
mid_x = start['x'] + ratio * (end['x'] - start['x'])
mid_y = start['y'] + ratio * (end['y'] - start['y'])
return (mid_x, mid_y)
distance_covered += seg_length

# Fallback: if something went wrong, return last point
return (path_points[-1]['x'], path_points[-1]['y'])
Loading

0 comments on commit def5b96

Please sign in to comment.