Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor viz #1174

Draft
wants to merge 3 commits into
base: robynpy_release
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions python/src/robyn/modeling/convergence/convergence.py
Original file line number Diff line number Diff line change
@@ -68,19 +68,22 @@ def calculate_convergence(self, trials: List[Trial]) -> Dict[str, Any]:

# Create visualization plots
self.logger.info("Creating visualization plots")
moo_distrb_plot = self.visualizer.create_moo_distrb_plot(
dt_objfunc_cvg, conv_msg
)
moo_cloud_plot = self.visualizer.create_moo_cloud_plot(
df, conv_msg, calibrated
)
ts_validation_plot = None # self.visualizer.create_ts_validation_plot(trials) #Disabled for testing. #Sandeep
plots_dict = {
"moo_distrb_plot": self.visualizer.create_moo_distrb_plot(
dt_objfunc_cvg, conv_msg
),
"moo_cloud_plot": self.visualizer.create_moo_cloud_plot(
df, conv_msg, calibrated
),
"ts_validation_plot": None # Disabled for testing
}

# Display the plots
self.visualizer.display_convergence_plots(plots_dict)

self.logger.info("Convergence calculation completed successfully")
return {
"moo_distrb_plot": moo_distrb_plot,
"moo_cloud_plot": moo_cloud_plot,
"ts_validation_plot": ts_validation_plot,
**plots_dict,
"errors": errors,
"conv_msg": conv_msg,
}
2 changes: 1 addition & 1 deletion python/src/robyn/reporting/onepager_reporting.py
Original file line number Diff line number Diff line change
@@ -39,9 +39,9 @@ def __init__(

# Default plots using PlotType enum directly
self.default_plots = [
PlotType.SPEND_EFFECT,
PlotType.WATERFALL,
PlotType.FITTED_VS_ACTUAL,
PlotType.SPEND_EFFECT,
PlotType.BOOTSTRAP,
PlotType.ADSTOCK,
PlotType.IMMEDIATE_CARRYOVER,
265 changes: 234 additions & 31 deletions python/src/robyn/visualization/base_visualizer.py
Original file line number Diff line number Diff line change
@@ -1,93 +1,185 @@
# pyre-strict

import logging

from abc import ABC, abstractmethod
from typing import Dict, Optional, Tuple, Union, List
from pathlib import Path
from IPython.display import Image, display

import matplotlib.pyplot as plt
import numpy as np
import base64
import io
from IPython.display import Image, display

# Configure logger
logger = logging.getLogger(__name__)


class BaseVisualizer(ABC):
"""
Base class for all Robyn visualization components.
Provides common plotting functionality and styling.
Enhanced base class for all Robyn visualization components.
Provides standardized plotting functionality and styling.
"""

def __init__(self, style: str = "bmh"):
"""
Initialize BaseVisualizer with common plot settings.
Initialize BaseVisualizer with standardized plot settings.

Args:
style: matplotlib style to use (default: "bmh")
"""
logger.info("Initializing BaseVisualizer with style: %s", style)

# Store style settings
self.style = style
self.default_figsize = (12, 8)
# Standard figure sizes
self.figure_sizes = {
"default": (12, 8),
"wide": (16, 8),
"square": (10, 10),
"tall": (8, 12),
"small": (8, 6),
"large": (15, 10),
"medium": (10, 6)
}

# Enhanced color schemes
# Standardized color schemes
self.colors = {
# Primary colors for main data series
"primary": "#4688C7", # Steel blue
"secondary": "#FF9F1C", # Orange
"tertiary": "#37B067", # Green
# Status colors
"positive": "#2ECC71", # Green
"negative": "#E74C3C", # Red
"neutral": "#95A5A6", # Gray
"current": "lightgray", # For current values
"optimal": "#4688C7", # For optimal values
"grid": "#E0E0E0", # For grid lines
# Chart elements
"grid": "#E0E0E0", # Light gray for grid lines
"baseline": "#CCCCCC", # Medium gray for baseline/reference lines
"annotation": "#666666", # Dark gray for annotations
# Channel-specific colors (for consistency across plots)
"channels": {
"facebook": "#3B5998",
"search": "#4285F4",
"display": "#34A853",
"youtube": "#FF0000",
"twitter": "#1DA1F2",
"email": "#DB4437",
"print": "#9C27B0",
"tv": "#E91E63",
"radio": "#795548",
"ooh": "#607D8B",
},
}

# Standard line styles
self.line_styles = {
"solid": "-",
"dashed": "--",
"dotted": ":",
"dashdot": "-.",
}
logger.debug("Color scheme initialized: %s", self.colors)

# Plot settings
self.font_sizes = {
"title": 14,
"subtitle": 12,
"label": 12,
"tick": 10,
"annotation": 9,
"legend": 10,

# Standard markers
self.markers = {
"circle": "o",
"square": "s",
"triangle": "^",
"diamond": "D",
"plus": "+",
"cross": "x",
"star": "*",
}
logger.debug("Font sizes configured: %s", self.font_sizes)

# Default alpha values
self.alpha = {"primary": 0.7, "secondary": 0.5, "grid": 0.3, "annotation": 0.7}
# Font configurations
self.fonts = {
"family": "sans-serif",
"sizes": {
"title": 14,
"subtitle": 12,
"label": 11,
"tick": 10,
"annotation": 9,
"legend": 10,
"small": 8,
},
}

# Default spacing
self.spacing = {"tight_layout_pad": 1.05, "subplot_adjust_hspace": 0.4}
# Common alpha values
self.alpha = {
"primary": 0.8,
"secondary": 0.6,
"grid": 0.3,
"annotation": 0.7,
"highlight": 0.9,
"background": 0.2,
}

# Standard spacing
self.spacing = {
"tight_layout_pad": 1.05,
"subplot_adjust_hspace": 0.4,
"label_pad": 10,
"title_pad": 20,
}

# Initialize plot tracking
self.current_figure: Optional[plt.Figure] = None
self.current_axes: Optional[Union[plt.Axes, np.ndarray]] = None

# Apply default style
# Apply default style and settings
self._setup_plot_style()
logger.info("BaseVisualizer initialization completed")

def format_number(self, x: float, pos: Optional[int] = None) -> str:
    """Format a number with K/M/B abbreviations and one decimal place.

    The signature matches what ``matplotlib.ticker.FuncFormatter`` passes
    to its callback (value, tick position).

    Args:
        x: Number to format.
        pos: Tick position; accepted for formatter compatibility, unused.

    Returns:
        The abbreviated string (e.g. ``"1.5K"``, ``"-2.5M"``, ``"3.0B"``),
        or ``str(x)`` when ``x`` cannot be formatted as a number.
    """
    scales = ((1e3, "K"), (1e6, "M"), (1e9, "B"))
    try:
        scaled, suffix = x, ""
        for divisor, unit in scales:
            # Decide on the value as it will be *displayed* (one decimal):
            # 999_960 would otherwise print as "1000.0K" instead of "1.0M".
            # Written as `not >=` so NaN stops here and prints as "nan".
            if not abs(round(scaled, 1)) >= 1000:
                break
            scaled, suffix = x / divisor, unit
        return f"{scaled:.1f}{suffix}"
    except (TypeError, ValueError):
        # Non-numeric input: fall back to its plain string form.
        return str(x)

def _setup_plot_style(self) -> None:
"""Configure default plotting style."""
logger.debug("Setting up plot style with style: %s", self.style)
logger.debug("Setting up plot style")
try:
plt.style.use(self.style)

plt.rcParams.update(
{
"figure.figsize": self.default_figsize,
# Figure settings
"figure.figsize": self.figure_sizes["default"],
"figure.facecolor": "white",
# Font settings
"font.family": self.fonts["family"],
"font.size": self.fonts["sizes"]["label"],
# Axes settings
"axes.grid": True,
"axes.spines.top": False,
"axes.spines.right": False,
"font.size": self.font_sizes["label"],
"axes.labelsize": self.fonts["sizes"]["label"],
"axes.titlesize": self.fonts["sizes"]["title"],
# Grid settings
"grid.alpha": self.alpha["grid"],
"grid.color": self.colors["grid"],
# Legend settings
"legend.fontsize": self.fonts["sizes"]["legend"],
"legend.framealpha": self.alpha["annotation"],
# Tick settings
"xtick.labelsize": self.fonts["sizes"]["tick"],
"ytick.labelsize": self.fonts["sizes"]["tick"],
}
)
logger.debug("Plot style parameters updated successfully")
@@ -191,6 +283,117 @@ def setup_axis(
logger.error("Failed to setup axis: %s", str(e))
raise

def _add_standardized_grid(
    self,
    ax: plt.Axes,
    axis: str = "both",
    alpha: Optional[float] = None,
    color: Optional[str] = None,
    linestyle: Optional[str] = None,
) -> None:
    """Add a standardized grid to a plot and keep it behind the data.

    Args:
        ax: Axes to draw the grid on.
        axis: Which axis to draw grid lines for ("both", "x" or "y").
        alpha: Grid transparency; ``None`` uses ``self.alpha["grid"]``.
        color: Grid color; ``None`` uses ``self.colors["grid"]``.
        linestyle: Grid line style; ``None`` uses
            ``self.line_styles["solid"]``.
    """
    # Compare against None rather than truthiness so that explicit falsy
    # values (alpha=0.0, linestyle="") are honoured instead of replaced.
    grid_alpha = self.alpha["grid"] if alpha is None else alpha
    grid_color = self.colors["grid"] if color is None else color
    grid_style = self.line_styles["solid"] if linestyle is None else linestyle

    ax.grid(
        True,
        axis=axis,
        alpha=grid_alpha,
        color=grid_color,
        linestyle=grid_style,
        zorder=0,
    )
    ax.set_axisbelow(True)

def _add_standardized_legend(
    self,
    ax: plt.Axes,
    title: Optional[str] = None,
    loc: str = "lower right",
    ncol: int = 1,
    handles: Optional[List] = None,
    labels: Optional[List[str]] = None,
) -> None:
    """Add a standardized legend to a plot.

    Handles and labels that are not supplied are taken from the labelled
    artists already on ``ax``. When there is nothing to show, no legend
    is created.

    Args:
        ax: Matplotlib axes to add the legend to.
        title: Optional legend title.
        loc: Legend location.
        ncol: Number of columns in the legend.
        handles: Optional list of legend handles.
        labels: Optional list of legend labels.
    """
    legend_handles, legend_labels = handles, labels
    if legend_handles is None or legend_labels is None:
        # Query the axes once and fill in only what the caller omitted.
        found_handles, found_labels = ax.get_legend_handles_labels()
        if legend_handles is None:
            legend_handles = found_handles
        if legend_labels is None:
            legend_labels = found_labels

    if len(legend_handles) == 0:
        # Nothing is labelled on this axes: skip instead of creating an
        # entry-less legend.
        return

    legend = ax.legend(
        handles=legend_handles,
        labels=legend_labels,
        title=title,
        loc=loc,
        ncol=ncol,
        fontsize=self.fonts["sizes"]["legend"],
        framealpha=self.alpha["annotation"],
        title_fontsize=self.fonts["sizes"]["subtitle"],
    )
    frame = legend.get_frame()
    frame.set_linewidth(0.5)
    frame.set_edgecolor(self.colors["grid"])

def _set_standardized_labels(
    self,
    ax: plt.Axes,
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
    title: Optional[str] = None,
) -> None:
    """Set axis labels and title with standardized font sizes and padding.

    An argument left as ``None`` leaves the corresponding text untouched.
    An empty string is applied as given, so a caller can blank a label
    that a plotting library filled in automatically.

    Args:
        ax: Axes to label.
        xlabel: Text for the x-axis label, or ``None`` to leave it as is.
        ylabel: Text for the y-axis label, or ``None`` to leave it as is.
        title: Text for the axes title, or ``None`` to leave it as is.
    """
    label_size = self.fonts["sizes"]["label"]
    label_pad = self.spacing["label_pad"]

    # `is not None` (not truthiness): "" is a legitimate request to clear.
    if xlabel is not None:
        ax.set_xlabel(xlabel, fontsize=label_size, labelpad=label_pad)
    if ylabel is not None:
        ax.set_ylabel(ylabel, fontsize=label_size, labelpad=label_pad)
    if title is not None:
        ax.set_title(
            title,
            fontsize=self.fonts["sizes"]["title"],
            pad=self.spacing["title_pad"],
        )

def _format_standardized_ticks(
    self,
    ax: plt.Axes,
    x_rotation: int = 0,
    y_rotation: int = 0,
) -> None:
    """Format tick labels with standardized size, rotation and alignment.

    Args:
        ax: Axes whose tick labels are styled.
        x_rotation: Rotation in degrees for the x tick labels.
        y_rotation: Rotation in degrees for the y tick labels.
    """
    ax.tick_params(axis="both", labelsize=self.fonts["sizes"]["tick"])

    # Anchor a rotated x label at the end nearest its tick: the right end
    # for positive (counter-clockwise) angles, the left end for negative.
    if x_rotation > 0:
        x_align = "right"
    elif x_rotation < 0:
        x_align = "left"
    else:
        x_align = "center"

    # Set properties on the label objects directly rather than through
    # the pyplot helper.
    for tick_label in ax.get_xticklabels():
        tick_label.set_rotation(x_rotation)
        tick_label.set_horizontalalignment(x_align)
    for tick_label in ax.get_yticklabels():
        tick_label.set_rotation(y_rotation)
        tick_label.set_verticalalignment("center")

def _set_standardized_spines(
    self, ax: plt.Axes, spines: Optional[List[str]] = None
) -> None:
    """Hide the given spines of a plot.

    Args:
        ax: Axes whose spines are adjusted.
        spines: Names of the spines to hide; ``None`` hides the top and
            right spines. An empty list hides nothing.
    """
    # Explicit Optional[...] in the signature: a bare `List[str] = None`
    # is an implicit Optional, which strict type checking rejects.
    hidden = ["top", "right"] if spines is None else spines
    for spine_name in hidden:
        ax.spines[spine_name].set_visible(False)

def add_percentage_annotation(
self,
ax: plt.Axes,
190 changes: 73 additions & 117 deletions python/src/robyn/visualization/feature_visualization.py
Original file line number Diff line number Diff line change
@@ -92,29 +92,6 @@ def plot_spend_exposure(
) -> Dict[str, plt.Figure]:
"""
Generates a spend-exposure plot for a specified channel.
Parameters:
-----------
channel : str
The name of the channel for which the spend-exposure plot is to be generated.
Returns:
--------
plt.Figure
The matplotlib Figure object containing the spend-exposure plot.
Raises:
-------
ValueError
If no spend-exposure data or plot data is available for the specified channel.
Exception
If any other error occurs during the plot generation process.
Notes:
------
The function retrieves the model results and plot data for the specified channel from the featurized_mmmdata attribute.
It creates a scatter plot of the actual data and a fitted line plot. The plot includes model information such as
model type, R-squared value, and model-specific parameters (e.g., Vmax and Km for Michaelis-Menten model or coefficient for linear model).
"""
logger.info("Generating spend-exposure plot for channel: %s", channel)

@@ -128,67 +105,83 @@ def plot_spend_exposure(
),
None,
)
logger.info("Found result for channel %s", channel)
if res is None:
logger.error("Channel %s not found in featurized data results", channel)
raise ValueError(
f"No spend-exposure data available for channel: {channel}"
)
raise ValueError(f"No spend-exposure data available for channel: {channel}")

plot_data = self.featurized_mmmdata.modNLS["plots"].get(channel)
if plot_data is None:
logger.error("Plot data for channel %s not found", channel)
raise ValueError(f"No plot data available for channel: {channel}")
fig, ax = plt.subplots(figsize=(10, 6))

# Create figure using base visualizer methods
fig, ax = self.create_figure(figsize=self.figure_sizes["medium"])

# Plot scatter of actual data
sns.scatterplot(
x="spend",
y="exposure",
data=plot_data,
ax=ax,
alpha=0.6,
alpha=self.alpha["primary"],
label="Actual",
color=self.colors["primary"]
)
logger.debug("Created scatter plot for actual data")

# Plot fitted line
sns.lineplot(
x="spend", y="yhat", data=plot_data, ax=ax, color="red", label="Fitted"
x="spend",
y="yhat",
data=plot_data,
ax=ax,
color=self.colors["secondary"],
label="Fitted"
)
logger.debug("Added fitted line to plot")
ax.set_xlabel(f"Spend [{channel}]")
ax.set_ylabel(f"Exposure [{channel}]")
ax.set_title(f"Spend vs Exposure for {channel}")
# Add model information to the plot

# Set labels and title using base visualizer methods
self._set_standardized_labels(
ax,
xlabel=f"Spend [{channel}]",
ylabel=f"Exposure [{channel}]",
title=f"Spend vs Exposure for {channel}"
)

# Add model information
model_type = res["model_type"]
rsq = res["rsq"]
logger.debug("Model type: %s, R-squared: %f", model_type, rsq)

if model_type == "nls":
Vmax, Km = res["Vmax"], res["Km"]
ax.text(
0.05,
0.95,
f"Model: Michaelis-Menten\nR² = {rsq:.4f}\nVmax = {Vmax:.2f}\nKm = {Km:.2f}",
transform=ax.transAxes,
verticalalignment="top",
bbox=dict(boxstyle="round", facecolor="white", alpha=0.7),
)
logger.debug("Added NLS model parameters: Vmax=%f, Km=%f", Vmax, Km)
text = f"Model: Michaelis-Menten\nR² = {rsq:.4f}\nVmax = {Vmax:.2f}\nKm = {Km:.2f}"
else:
coef = res["coef_lm"]
ax.text(
0.05,
0.95,
f"Model: Linear\nR² = {rsq:.4f}\nCoefficient = {coef:.4f}",
transform=ax.transAxes,
verticalalignment="top",
bbox=dict(boxstyle="round", facecolor="white", alpha=0.7),
)
logger.debug("Added linear model parameters: coefficient=%f", coef)
plt.legend()
plt.tight_layout()
plt.close()
logger.info(
"Successfully generated spend-exposure plot for channel %s", channel
text = f"Model: Linear\nR² = {rsq:.4f}\nCoefficient = {coef:.4f}"

# Add text box with model information
ax.text(
0.05,
0.95,
text,
transform=ax.transAxes,
verticalalignment="top",
bbox=dict(
boxstyle="round",
facecolor="white",
alpha=self.alpha["annotation"],
edgecolor=self.colors["grid"]
),
fontsize=self.fonts["sizes"]["annotation"]
)

# Add grid and style using base visualizer methods
self._add_standardized_grid(ax)
self._set_standardized_spines(ax)
self._add_standardized_legend(ax, loc='lower right')

# Finalize the figure
self.finalize_figure(tight_layout=True)

logger.info("Successfully generated spend-exposure plot for channel %s", channel)

self.cleanup()
return {"spend-exposure": fig}
except Exception as e:
logger.error(
@@ -199,71 +192,34 @@ def plot_spend_exposure(
)
raise

def plot_feature_importance(
self, feature_importance: Dict[str, float], display: bool = True
) -> Dict[str, plt.Figure]:
"""
Plot the importance of different features in the model.
Args:
feature_importance (Dict[str, float]): Dictionary of feature importances.
Returns:
plt.Figure: A matplotlib Figure object containing the feature importance plot.
"""
logger.info("Generating feature importance plot")
logger.debug("Feature importance data: %s", feature_importance)
try:
# Implementation placeholder
logger.warning("plot_feature_importance method not implemented yet")

except Exception as e:
logger.error("Failed to generate feature importance plot: %s", str(e))
raise

def plot_response_curves(self, display: bool = True) -> Dict[str, plt.Figure]:
"""
Plot response curves for different channels.
Args:
self.featurized_mmmdata (FeaturizedMMMData): The featurized data after feature engineering.
Returns:
Dict[str, plt.Figure]: Dictionary mapping channel names to their respective response curve plots.
"""
logger.info("Generating response curves")
logger.debug("Processing featurized data: %s", self.featurized_mmmdata)
try:
dt_mod = self.featurized_mmmdata.dt_mod
logger.debug("Modified data: %s", dt_mod)
# Rest of the method implementation
logger.warning("plot_response_curves method not fully implemented yet")

except Exception as e:
logger.error("Failed to generate response curves: %s", str(e))
raise

def plot_all(
self, display_plots: bool = True, export_location: Union[str, Path] = None
) -> Dict[str, plt.Figure]:
"""
Override the abstract method plot_all from BaseVisualizer.
Generate all plots available in the feature plotter.
"""
logger.info("Generating all plots")
plot_collect: Dict[str, plt.Figure] = {}

try:
for item in self.featurized_mmmdata.modNLS["results"]:
channel = item["channel"]
self.plot_adstock(channel, display_plots)
# plot_collect.update(self.plot_adstock(channel, display))
# plot_collect.update(self.plot_saturation(channel, display))
plot_collect[channel] = self.plot_spend_exposure(
channel, display_plots
)["spend-exposure"]
# Create plots for each channel only once
channels = {item["channel"] for item in self.featurized_mmmdata.modNLS["results"]}

for channel in channels:
spend_exposure_plot = self.plot_spend_exposure(channel, display=False)
plot_collect[f"{channel}_spend_exposure"] = spend_exposure_plot["spend-exposure"]

# plot_collect.update(self.plot_feature_importance({}, display))
if display_plots:
self.display_plots(plot_collect)

super().display_plots(plot_collect)
if export_location:
self.export_plots_fig(export_location, plot_collect)

return plot_collect
except Exception as e:
logger.error("Failed to generate all plots: %s", str(e))
raise

def __del__(self):
"""Cleanup when the plotter is destroyed."""
self.cleanup()
306 changes: 186 additions & 120 deletions python/src/robyn/visualization/model_convergence_visualizer.py
Original file line number Diff line number Diff line change
@@ -2,22 +2,16 @@
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
from IPython.display import Image, display

matplotlib.use("Agg")
import seaborn as sns
from typing import List, Optional, Union
import io
import base64
from typing import Dict, List, Optional, Union, Any
import logging
from robyn.modeling.entities.modeloutputs import Trial
from robyn.visualization.base_visualizer import BaseVisualizer

# Initialize logger for this module
logger = logging.getLogger(__name__)


class ModelConvergenceVisualizer:
class ModelConvergenceVisualizer(BaseVisualizer):
def __init__(
self,
n_cuts: Optional[int] = None,
@@ -26,6 +20,7 @@ def __init__(
moo_cloud_plot: Optional[pd.DataFrame] = None,
moo_distrb_plot: Optional[pd.DataFrame] = None,
):
super().__init__() # Initialize BaseVisualizer
self.n_cuts = n_cuts
self.nrmse_win = nrmse_win
self.ts_validation_plot = ts_validation_plot
@@ -35,7 +30,7 @@ def __init__(

def create_moo_distrb_plot(
self, dt_objfunc_cvg: pd.DataFrame, conv_msg: List[str]
) -> str:
) -> Dict[str, plt.Figure]:
logger.debug(
"Starting moo distribution plot creation with data shape: %s",
dt_objfunc_cvg.shape,
@@ -47,10 +42,8 @@ def create_moo_distrb_plot(
dt_objfunc_cvg["cuts"],
categories=sorted(dt_objfunc_cvg["cuts"].unique(), reverse=True),
)

# Clip values based on quantiles
logger.debug(
"Processing error types: %s", dt_objfunc_cvg["error_type"].unique()
)
for error_type in dt_objfunc_cvg["error_type"].unique():
mask = dt_objfunc_cvg["error_type"] == error_type
original_values = dt_objfunc_cvg.loc[mask, "value"]
@@ -62,11 +55,11 @@ def create_moo_distrb_plot(
quantiles[0],
quantiles[1],
)
# Set the style and color palette
sns.set_style("whitegrid")
sns.set_palette("Set2")
# Create the violin plot with a larger figure size
fig, ax = plt.subplots(figsize=(14, 10))

# Create figure using base visualizer methods
fig, ax = self.create_figure(figsize=self.figure_sizes["default"])

# Create the violin plot
sns.violinplot(
data=dt_objfunc_cvg,
x="value",
@@ -76,29 +69,33 @@ def create_moo_distrb_plot(
inner="quartile",
ax=ax,
)
ax.set_xlabel("Objective functions", fontsize=12, ha="left", x=0)
ax.set_ylabel("Iterations [#]", fontsize=12)
ax.set_title(
"Objective convergence by iterations quantiles",
fontsize=14,
fontweight="bold",

# Set labels and styling using base visualizer methods
self._set_standardized_labels(
ax,
xlabel="Objective functions",
ylabel="Iterations [#]",
title="Objective convergence by iterations quantiles"
)
ax.grid(True, linestyle="--", linewidth=0.5)
# Adjust layout to make room for figtext on the bottom right
plt.subplots_adjust(right=0.75, bottom=0.15)
# Add text annotations on the bottom right
self._add_standardized_grid(ax)
self._set_standardized_spines(ax)
self._add_standardized_legend(ax, loc='lower right')

# Add convergence messages
plt.figtext(
0.98,
0,
0.02,
"\n".join(conv_msg),
ha="right",
va="bottom",
fontsize=8,
fontsize=self.fonts["sizes"]["small"],
wrap=True,
)
plt.tight_layout()

self.finalize_figure(tight_layout=True)
logger.info("Successfully created moo distribution plot")
return self._convert_plot_to_base64(fig)
return {"moo_distribution": fig}

except Exception as e:
logger.error(
"Failed to create moo distribution plot: %s", str(e), exc_info=True
@@ -107,7 +104,7 @@ def create_moo_distrb_plot(

def create_moo_cloud_plot(
self, df: pd.DataFrame, conv_msg: List[str], calibrated: bool
) -> str:
) -> Dict[str, plt.Figure]:
logger.debug(
"Starting moo cloud plot creation with data shape: %s, calibrated=%s",
df.shape,
@@ -119,21 +116,19 @@ def create_moo_cloud_plot(
original_nrmse = df["nrmse"]
quantiles = np.quantile(original_nrmse, self.nrmse_win)
df["nrmse"] = np.clip(original_nrmse, *quantiles)
logger.debug(
"Clipped NRMSE values: min=%f, max=%f", quantiles[0], quantiles[1]
)
# Set the style and color palette
sns.set_style("whitegrid")
sns.set_palette("Set2")
# Create the scatter plot
fig, ax = plt.subplots(figsize=(12, 10))

# Create figure using base visualizer methods
fig, ax = self.create_figure(figsize=self.figure_sizes["default"])

# Create scatter plot
scatter = ax.scatter(
df["nrmse"],
df["decomp.rssd"],
c=df["ElapsedAccum"],
cmap="viridis",
alpha=0.7,
alpha=self.alpha["primary"]
)

if calibrated and "mape" in df.columns:
logger.debug("Adding calibrated MAPE visualization")
sizes = (df["mape"] - df["mape"].min()) / (
@@ -144,154 +139,177 @@ def create_moo_cloud_plot(
df["nrmse"],
df["decomp.rssd"],
s=sizes,
alpha=0.5,
alpha=self.alpha["secondary"],
edgecolor="w",
linewidth=0.5,
)

# Add colorbar
plt.colorbar(scatter, label="Time [s]")
ax.set_xlabel("NRMSE", fontsize=12, ha="left", x=0)
ax.set_ylabel("DECOMP.RSSD", fontsize=12)
ax.set_title(
"Multi-objective evolutionary performance",
fontsize=14,
fontweight="bold",

# Set labels and styling using base visualizer methods
self._set_standardized_labels(
ax,
xlabel="NRMSE",
ylabel="DECOMP.RSSD",
title="Multi-objective evolutionary performance"
)
# Add text annotations on the bottom right
self._add_standardized_grid(ax)
self._set_standardized_spines(ax)

# Add convergence messages
plt.figtext(
0.98,
0,
0.02,
"\n".join(conv_msg),
ha="right",
va="bottom",
fontsize=8,
fontsize=self.fonts["sizes"]["small"],
wrap=True,
)
plt.tight_layout()

self.finalize_figure(tight_layout=True)
logger.info("Successfully created moo cloud plot")
return self._convert_plot_to_base64(fig)
return {"moo_cloud": fig}

except Exception as e:
logger.error("Failed to create moo cloud plot: %s", str(e), exc_info=True)
raise

def create_ts_validation_plot(self, trials: List[Trial]) -> str:
def create_ts_validation_plot(self, trials: List[Trial]) -> Dict[str, plt.Figure]:
logger.debug(
"Starting time-series validation plot creation with %d trials", len(trials)
)
try:
# Concatenate trial data
# Prepare data
result_hyp_param = pd.concat(
[trial.result_hyp_param for trial in trials], ignore_index=True
)
result_hyp_param["trial"] = (
result_hyp_param.groupby("sol_id").cumcount() + 1
)
result_hyp_param["iteration"] = result_hyp_param.index + 1
logger.debug("Processing metrics for validation plot")

# Process metrics
result_hyp_param_long = result_hyp_param.melt(
id_vars=["sol_id", "trial", "train_size", "iteration"],
value_vars=[
"rsq_train",
"rsq_val",
"rsq_test",
"nrmse_train",
"nrmse_val",
"nrmse_test",
"rsq_train", "rsq_val", "rsq_test",
"nrmse_train", "nrmse_val", "nrmse_test"
],
var_name="metric",
value_name="value",
value_name="value"
)

# Extract dataset and metric type
result_hyp_param_long["dataset"] = (
result_hyp_param_long["metric"].str.split("_").str[-1]
)
result_hyp_param_long["metric_type"] = (
result_hyp_param_long["metric"].str.split("_").str[0]
)
result_hyp_param_long["dataset"] = result_hyp_param_long["metric"].str.split("_").str[-1]
result_hyp_param_long["metric_type"] = result_hyp_param_long["metric"].str.split("_").str[0]

# Winsorize the data
logger.debug("Winsorizing metric values")
result_hyp_param_long["value"] = result_hyp_param_long.groupby(
"metric_type"
)["value"].transform(
result_hyp_param_long["value"] = result_hyp_param_long.groupby("metric_type")["value"].transform(
lambda x: np.clip(
x,
np.percentile(x, self.nrmse_win[0] * 100),
np.percentile(x, self.nrmse_win[1] * 100),
)
)
# Set the style and color palette
sns.set_style("whitegrid")
sns.set_palette("Set2")
# Determine the number of trials

# Create figure using base visualizer methods
num_trials = result_hyp_param["trial"].nunique()
# Create subplots
fig, axes = plt.subplots(
num_trials + 1,
1,
figsize=(12, 5 * (num_trials + 1)),
gridspec_kw={"height_ratios": [3] * num_trials + [1]},
)
fig = plt.figure(figsize=self.figure_sizes["default"])

# Create grid for subplots
gs = fig.add_gridspec(num_trials + 1, 1, height_ratios=[3] * num_trials + [1])

# NRMSE plots for each trial
for i, (trial, ax) in enumerate(
zip(result_hyp_param["trial"].unique(), axes[:-1])
):
for i, trial in enumerate(result_hyp_param["trial"].unique()):
ax = fig.add_subplot(gs[i])
nrmse_data = result_hyp_param_long[
(result_hyp_param_long["metric_type"] == "nrmse")
& (result_hyp_param_long["trial"] == trial)
]

# Create plots
sns.scatterplot(
data=nrmse_data,
x="iteration",
y="value",
hue="dataset",
style="dataset",
markers=["o", "s", "D"], # Different markers for train, val, test
markers=["o", "s", "D"],
ax=ax,
alpha=0.6,
alpha=self.alpha["primary"]
)

sns.lineplot(
data=nrmse_data,
x="iteration",
y="value",
hue="dataset",
ax=ax,
legend=False,
linewidth=1,
linewidth=1
)
ax.set_ylabel(f"NRMSE [Trial {trial}]", fontsize=12, fontweight="bold")
ax.set_xlabel("Iteration", fontsize=12, fontweight="bold")
ax.legend(title="Dataset", loc="upper right")

# Style the subplot
self._set_standardized_labels(
ax,
ylabel=f"NRMSE [Trial {trial}]",
xlabel="Iteration" if i == num_trials - 1 else ""
)
self._add_standardized_grid(ax)
self._set_standardized_spines(ax)
self._add_standardized_legend(ax, loc='lower right')

# Only show x-label on bottom plot
if i < num_trials - 1:
ax.set_xlabel("")

# Train Size plot
ax = fig.add_subplot(gs[-1])
sns.scatterplot(
data=result_hyp_param,
x="iteration",
y="train_size",
hue="trial",
ax=axes[-1],
ax=ax,
legend=False,
)
axes[-1].set_ylabel("Train Size", fontsize=12, fontweight="bold")
axes[-1].set_xlabel("Iteration", fontsize=12, fontweight="bold")
axes[-1].set_ylim(0, 1)
axes[-1].yaxis.set_major_formatter(

# Style the train size plot
self._set_standardized_labels(
ax,
xlabel="Iteration",
ylabel="Train Size"
)
self._add_standardized_grid(ax)
self._set_standardized_spines(ax)

ax.set_ylim(0, 1)
ax.yaxis.set_major_formatter(
plt.FuncFormatter(lambda y, _: "{:.0%}".format(y))
)
# Set the overall title
plt.suptitle(
"Time-series validation & Convergence", fontsize=14, fontweight="bold"

# Set overall title
fig.suptitle(
"Time-series validation & Convergence",
fontsize=self.fonts["sizes"]["title"],
fontweight="bold",
y=1.02
)
plt.tight_layout()

self.finalize_figure(tight_layout=True)
logger.info("Successfully created time-series validation plot")
return self._convert_plot_to_base64(fig)
return {"ts_validation": fig}

except Exception as e:
logger.error(
"Failed to create time-series validation plot: %s",
str(e),
exc_info=True,
)
raise

def _convert_plot_to_base64(self, fig: plt.Figure) -> str:
logger.debug("Converting plot to base64")
try:
@@ -307,24 +325,72 @@ def _convert_plot_to_base64(self, fig: plt.Figure) -> str:
logger.error("Failed to convert plot to base64: %s", str(e), exc_info=True)
raise

def display_moo_distrb_plot(self):
"""Display the MOO Distribution Plot."""
self._display_base64_image(self.moo_distrb_plot)

def display_moo_cloud_plot(self):
"""Display the MOO Cloud Plot."""
self._display_base64_image(self.moo_cloud_plot)
def display_convergence_plots(self, plots_dict: Dict[str, Any]) -> None:
"""
Display all convergence plots from a dictionary.
"""
logger.info("Displaying convergence plots")
try:
if 'moo_distrb_plot' in plots_dict and plots_dict['moo_distrb_plot']:
logger.info("Displaying MOO distribution plot")
for name, fig in plots_dict['moo_distrb_plot'].items():
plt.figure(fig.number)
plt.show()

def display_ts_validation_plot(self):
"""Display the Time-Series Validation Plot."""
self._display_base64_image(self.ts_validation_plot)
if 'moo_cloud_plot' in plots_dict and plots_dict['moo_cloud_plot']:
logger.info("Displaying MOO cloud plot")
for name, fig in plots_dict['moo_cloud_plot'].items():
plt.figure(fig.number)
plt.show()

def _display_base64_image(self, base64_image: str):
"""Helper method to display a base64-encoded image."""
display(Image(data=base64.b64decode(base64_image)))
if 'ts_validation_plot' in plots_dict and plots_dict['ts_validation_plot']:
logger.info("Displaying time series validation plot")
for name, fig in plots_dict['ts_validation_plot'].items():
plt.figure(fig.number)
plt.show()
except Exception as e:
logger.error(f"Error displaying plots: {str(e)}")
raise

def plot_all(
self, display_plots: bool = True, export_location: Union[str, Path] = None
) -> None:
) -> Dict[str, plt.Figure]:
"""
Generate all available plots.
"""
logger.info("Generating all plots")
plot_collect: Dict[str, plt.Figure] = {}

try:
# Generate plots if data is available
if self.moo_distrb_plot is not None:
logger.info("Creating moo distribution plot")
plot_collect.update(self.create_moo_distrb_plot(self.moo_distrb_plot, []))

if self.moo_cloud_plot is not None:
logger.info("Creating moo cloud plot")
plot_collect.update(self.create_moo_cloud_plot(self.moo_cloud_plot, [], False))

if self.ts_validation_plot is not None:
logger.info("Creating time series validation plot")
plot_collect.update(self.create_ts_validation_plot(self.ts_validation_plot))

if display_plots:
logger.info(f"Displaying plots: {list(plot_collect.keys())}")
for plot_name, fig in plot_collect.items():
plt.figure(fig.number)
plt.show()

if export_location:
logger.info(f"Exporting plots to: {export_location}")
self.export_plots_fig(export_location, plot_collect)

return plot_collect

except Exception as e:
logger.error("Failed to generate all plots: %s", str(e))
raise

logger.warning("this method is not yet implemented")
def __del__(self):
"""Cleanup when the visualizer is destroyed."""
self.cleanup()
2,068 changes: 1,247 additions & 821 deletions python/src/robyn/visualization/pareto_visualizer.py

Large diffs are not rendered by default.