Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove sdec plotter class from liv plot class. #2929

Draft
wants to merge 30 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
7d914fe
Rename sdecdata class to SimulationPacketData and shift it outside of…
KasukabeDefenceForce Jan 7, 2025
0116ae5
Update simulation packet data imports in sdec plot
KasukabeDefenceForce Jan 7, 2025
7246dad
Fix import from simulation packet data class
KasukabeDefenceForce Jan 7, 2025
41bc461
Function to create a dictionary containing virtual and real packet data
KasukabeDefenceForce Jan 7, 2025
a2be3fe
Update sdec plotter doc string
KasukabeDefenceForce Jan 7, 2025
77e045a
Replace sdecdata with new simulationpacketdata
KasukabeDefenceForce Jan 7, 2025
b0b9398
Reducing duplicate spectrum
KasukabeDefenceForce Jan 7, 2025
0e6eecf
Remove unnecessary else statements
KasukabeDefenceForce Jan 7, 2025
e7a4388
simplify vpacket_tracker access in SimulationPacketData in virtual mode
KasukabeDefenceForce Jan 7, 2025
2fcb4ac
Add create_packet_data_dict_from_hdf utility function
KasukabeDefenceForce Jan 7, 2025
9cf946f
Implement create_packet_data_dict_from_hdf in sdec and liv plot classes
KasukabeDefenceForce Jan 7, 2025
903f3e5
Update util function name for clarity
KasukabeDefenceForce Jan 7, 2025
821748a
Shift create packet data dict functions to plot_util.py
KasukabeDefenceForce Jan 7, 2025
480d9a4
Update imports from plot_util.py
KasukabeDefenceForce Jan 7, 2025
2c583b0
Remove unwanted if statements from from_hdf method
KasukabeDefenceForce Jan 7, 2025
33aa8af
Simplify paths to hdf files of different properties
KasukabeDefenceForce Jan 7, 2025
1a7684a
Add else statement to improve readability
KasukabeDefenceForce Jan 7, 2025
447d557
Rename simulation packet data to VisualizationData
KasukabeDefenceForce Jan 7, 2025
1db357e
Fix import error
KasukabeDefenceForce Jan 7, 2025
9b996fb
Shift time_explosion and velocity to visualization data class
KasukabeDefenceForce Jan 7, 2025
b499edf
Calculating velocity in init method
KasukabeDefenceForce Jan 7, 2025
24fb361
Calculate velocity if not given as input
KasukabeDefenceForce Jan 7, 2025
4d31d07
Calculate velocity if not specified in sim object
KasukabeDefenceForce Jan 7, 2025
b71e734
Update docstring
KasukabeDefenceForce Jan 7, 2025
d666340
Shift parse_species_list to util file
KasukabeDefenceForce Jan 8, 2025
501d246
remove sdec plotter class from liv plot
KasukabeDefenceForce Jan 8, 2025
fbbb82f
Return a dictionary instead of a tuple
KasukabeDefenceForce Jan 10, 2025
459b0ae
add missing full_species_list
KasukabeDefenceForce Jan 10, 2025
b451afd
set full_species_list to none if no species list is given
KasukabeDefenceForce Jan 13, 2025
bc371f8
refactor for loop to create full_species_list
KasukabeDefenceForce Jan 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 148 additions & 0 deletions tardis/visualization/plot_util.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
"""Utility functions to be used in plotting."""

import re

import numpy as np

from tardis.util.base import (
element_symbol2atomic_number,
int_to_roman,
roman_to_int,
species_string_to_tuple,
)
from tardis.visualization.tools.visualization_data import (
VisualizationData,
)


def axis_label_in_latex(label_text, unit, only_text=True):
"""
Expand Down Expand Up @@ -79,3 +90,140 @@
"""
color_tuple_255 = tuple([int(x * 255) for x in color_tuple[:3]])
return f"rgb{color_tuple_255}"

def create_packet_data_dict_from_simulation(sim):
"""
Create a dictionary containing virtual and real packet data based on simulation state.

Parameters
----------
sim : tardis.simulation.Simulation
TARDIS Simulation object produced by running a simulation

Returns
-------
dict
Dictionary containing 'virtual' and 'real' SimulationPacketData instances
"""
packet_data = {
"real": VisualizationData.from_simulation(sim, "real")
}
if sim.transport.transport_state.virt_logging:
packet_data["virtual"] = VisualizationData.from_simulation(sim, "virtual")
else:
packet_data["virtual"] = None

Check warning on line 114 in tardis/visualization/plot_util.py

View check run for this annotation

Codecov / codecov/patch

tardis/visualization/plot_util.py#L114

Added line #L114 was not covered by tests

return packet_data

def create_packet_data_dict_from_hdf(hdf_fpath, packets_mode=None):
"""
Create a dictionary containing virtual and real packet data from HDF file.

Parameters
----------
hdf_fpath : str
Valid path to the HDF file where simulation is saved
packets_mode : {'virtual', 'real', None}
Mode of packets to be considered. If None, both modes are returned.

Returns
-------
dict
Dictionary containing 'virtual' and 'real' SimulationPacketData instances
"""
if packets_mode not in [None, "virtual", "real"]:
raise ValueError(

Check warning on line 135 in tardis/visualization/plot_util.py

View check run for this annotation

Codecov / codecov/patch

tardis/visualization/plot_util.py#L134-L135

Added lines #L134 - L135 were not covered by tests
"Invalid value passed to packets_mode. Only "
"allowed values are 'virtual', 'real' or None"
)
if packets_mode == "virtual":
return {

Check warning on line 140 in tardis/visualization/plot_util.py

View check run for this annotation

Codecov / codecov/patch

tardis/visualization/plot_util.py#L139-L140

Added lines #L139 - L140 were not covered by tests
"virtual": VisualizationData.from_hdf(hdf_fpath, "virtual"),
"real": None
}
if packets_mode == "real":
return {

Check warning on line 145 in tardis/visualization/plot_util.py

View check run for this annotation

Codecov / codecov/patch

tardis/visualization/plot_util.py#L144-L145

Added lines #L144 - L145 were not covered by tests
"virtual": None,
"real": VisualizationData.from_hdf(hdf_fpath, "real")
}
return {

Check warning on line 149 in tardis/visualization/plot_util.py

View check run for this annotation

Codecov / codecov/patch

tardis/visualization/plot_util.py#L149

Added line #L149 was not covered by tests
"virtual": VisualizationData.from_hdf(hdf_fpath, "virtual"),
"real": VisualizationData.from_hdf(hdf_fpath, "real")
}


def parse_species_list_util(species_list):
"""
Parse user requested species list and create list of species ids to be used.

Parameters
----------
species_list : list of species to plot
List of species (e.g. Si II, Ca II, etc.) that the user wants to show as unique colours.
Species can be given as an ion (e.g. Si II), an element (e.g. Si), a range of ions
(e.g. Si I - V), or any combination of these (e.g. species_list = [Si II, Fe I-V, Ca])

Returns
-------
dict
A dictionary containing:
- full_species_list: List of expanded species (e.g. Si I - V -> [Si I, Si II, ...]).
- species_mapped: Mapping of species ids to species names.
- keep_colour: List of atomic numbers to group elements with consistent colors.
"""
if species_list is None:
return {

Check warning on line 175 in tardis/visualization/plot_util.py

View check run for this annotation

Codecov / codecov/patch

tardis/visualization/plot_util.py#L175

Added line #L175 was not covered by tests
"full_species_list": None,
"species_mapped": None,
"keep_colour": None,
"species_list": None,
}


if any(char.isdigit() for char in " ".join(species_list)):
raise ValueError("All species must be in Roman numeral form, e.g., Si II")

Check warning on line 184 in tardis/visualization/plot_util.py

View check run for this annotation

Codecov / codecov/patch

tardis/visualization/plot_util.py#L184

Added line #L184 was not covered by tests

full_species_list = []
species_mapped = {}
keep_colour = []
requested_species_ids = []

for species in species_list:
if "-" in species:
element, ion_numerals = species.split(" ")
first_ion_roman, second_ion_roman = ion_numerals.split("-")
ion_range = range(roman_to_int(first_ion_roman), roman_to_int(second_ion_roman) + 1)

full_species_list.extend(f"{element} {int_to_roman(ion)}" for ion in ion_range)
else:
full_species_list.append(species)


for species in full_species_list:
if " " in species:
atomic_number, ion_number = species_string_to_tuple(species)
species_id = (
atomic_number * 100
+ ion_number
)
requested_species_ids.append([species_id])
species_mapped[species_id] = [species_id]
else:
atomic_number = element_symbol2atomic_number(species)
species_ids = [
atomic_number * 100 + ion_number for ion_number in range(atomic_number)
]
requested_species_ids.append(species_ids)
species_mapped[atomic_number * 100] = species_ids
keep_colour.append(atomic_number)

requested_species_ids = [
species_id for temp_list in requested_species_ids for species_id in temp_list
]

return {
"full_species_list": full_species_list,
"species_mapped": species_mapped,
"keep_colour": keep_colour,
"species_list": requested_species_ids,
}
67 changes: 21 additions & 46 deletions tardis/visualization/tools/liv_plot.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import logging
import matplotlib.pyplot as plt

import astropy.units as u
import matplotlib.cm as cm
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import astropy.units as u
import plotly.graph_objects as go

import tardis.visualization.tools.sdec_plot as sdec
from tardis.util.base import (
atomic_number2element_symbol,
int_to_roman,
)
import tardis.visualization.tools.sdec_plot as sdec
from tardis.visualization import plot_util as pu

logger = logging.getLogger(__name__)
Expand All @@ -21,7 +22,7 @@
Plotting interface for the last interaction velocity plot.
"""

def __init__(self, data, time_explosion, velocity):
def __init__(self, data):
"""
Initialize the plotter with required data from the simulation.

Expand All @@ -31,17 +32,8 @@
Dictionary to store data required for last interaction velocity plot,
for both packet modes (real, virtual).

time_explosion : astropy.units.Quantity
Time of the explosion.

velocity : astropy.units.Quantity
Velocity array from the simulation.
"""

self.data = data
self.time_explosion = time_explosion
self.velocity = velocity
self.sdec_plotter = sdec.SDECPlotter(data)

@classmethod
def from_simulation(cls, sim):
Expand All @@ -57,15 +49,7 @@
-------
LIVPlotter
"""

return cls(
dict(
virtual=sdec.SDECData.from_simulation(sim, "virtual"),
real=sdec.SDECData.from_simulation(sim, "real"),
),
sim.plasma.time_explosion,
sim.simulation_state.velocity,
)
return cls(pu.create_packet_data_dict_from_simulation(sim))

@classmethod
def from_hdf(cls, hdf_fpath):
Expand All @@ -81,23 +65,7 @@
-------
LIVPlotter
"""
with pd.HDFStore(hdf_fpath, "r") as hdf:
time_explosion = (
hdf["/simulation/plasma/scalars"]["time_explosion"] * u.s
)
v_inner = hdf["/simulation/simulation_state/v_inner"] * (u.cm / u.s)
v_outer = hdf["/simulation/simulation_state/v_outer"] * (u.cm / u.s)
velocity = pd.concat(
[v_inner, pd.Series([v_outer.iloc[-1]])], ignore_index=True
).tolist() * (u.cm / u.s)
return cls(
dict(
virtual=sdec.SDECData.from_hdf(hdf_fpath, "virtual"),
real=sdec.SDECData.from_hdf(hdf_fpath, "real"),
),
time_explosion,
velocity,
)
return cls(pu.create_packet_data_dict_from_hdf(hdf_fpath))

Check warning on line 68 in tardis/visualization/tools/liv_plot.py

View check run for this annotation

Codecov / codecov/patch

tardis/visualization/tools/liv_plot.py#L68

Added line #L68 was not covered by tests

def _parse_species_list(self, species_list, packets_mode, nelements=None):
"""
Expand All @@ -120,10 +88,14 @@
If species list contains invalid entries.

"""
self.sdec_plotter._parse_species_list(species_list)
self._species_list = self.sdec_plotter._species_list
self._species_mapped = self.sdec_plotter._species_mapped
self._keep_colour = self.sdec_plotter._keep_colour
parsed_species_data = pu.parse_species_list_util(species_list)
if parsed_species_data is None:
self._species_list = None

Check warning on line 93 in tardis/visualization/tools/liv_plot.py

View check run for this annotation

Codecov / codecov/patch

tardis/visualization/tools/liv_plot.py#L93

Added line #L93 was not covered by tests
else:
self._full_species_list = parsed_species_data["full_species_list"]
self._species_mapped = parsed_species_data["species_mapped"]
self._keep_colour = parsed_species_data["keep_colour"]
self._species_list = parsed_species_data["species_list"]

if nelements:
interaction_counts = (
Expand Down Expand Up @@ -211,6 +183,8 @@
species_not_wvl_range = []
species_counter = 0

time_explosion = self.data[packets_mode].time_explosion

for specie_list in self._species_mapped.values():
full_v_last = []
for specie in specie_list:
Expand All @@ -227,7 +201,7 @@
g_df["last_interaction_in_r"].values * u.cm
)
v_last_interaction = (
r_last_interaction / self.time_explosion
r_last_interaction / time_explosion
).to("km/s")
full_v_last.extend(v_last_interaction)
if full_v_last:
Expand Down Expand Up @@ -333,7 +307,8 @@
)

self._generate_plot_data(packets_mode)
bin_edges = (self.velocity).to("km/s")
velocity = self.data[packets_mode].velocity
bin_edges = (velocity).to("km/s")

if num_bins:
if num_bins < 1:
Expand Down
Loading
Loading