Skip to content

Commit

Permalink
🔧 🚀 allow ModeData to be passed to path integral computations with sm…
Browse files Browse the repository at this point in the history
…all refactor to reduce duplicate code
  • Loading branch information
dmarek-flex committed Feb 19, 2025
1 parent dab6854 commit f656a4c
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 64 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `num_freqs` is now set to 3 by default for the `PlaneWave`, `GaussianBeam`, and `AnalyticGaussianBeam` sources, which makes the injection more accurate in broadband cases.
- Nonlinear models `KerrNonlinearity` and `TwoPhotonAbsorption` now default to using the physical real fields instead of complex fields.
- Added warning when a lumped element is not completely within the simulation bounds, since now lumped elements will only have an effect on the `Simulation` when they are completely within the simulation bounds.
- Allow `ModeData` to be passed to path integral computations in the `microwave` plugin.

### Fixed
- Make gauge selection for non-converged modes more robust.
Expand Down
13 changes: 12 additions & 1 deletion tests/test_plugins/test_microwave.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@
size=(1, 1, 0),
freqs=FS,
mode_spec=td.ModeSpec(num_modes=2),
name="mode_solver",
),
td.ModeMonitor(
center=(0, 0, 0),
size=(1, 1, 0),
freqs=FS,
mode_spec=td.ModeSpec(num_modes=3),
store_fields_direction="+",
name="mode",
),
],
Expand Down Expand Up @@ -280,7 +288,7 @@ def test_time_monitor_voltage_integral():


def test_mode_solver_monitor_voltage_integral():
"""Check VoltageIntegralAxisAligned runs on mode solver data."""
"""Check VoltageIntegralAxisAligned runs on ModeData and ModeSolverData."""
length = 0.5
size = [0, 0, 0]
size[1] = length
Expand All @@ -291,6 +299,7 @@ def test_mode_solver_monitor_voltage_integral():
sign="+",
)

voltage_integral.compute_voltage(SIM_Z_DATA["mode_solver"])
voltage_integral.compute_voltage(SIM_Z_DATA["mode"])


Expand Down Expand Up @@ -495,12 +504,14 @@ def test_time_monitor_custom_current_integral():


def test_mode_solver_custom_current_integral():
"""Test that both ModeData and ModeSolverData are allowed types."""
length = 0.5
size = [0, 0, 0]
size[1] = length
# Make box
vertices = [(0.2, -0.2), (0.2, 0.2), (-0.2, 0.2), (-0.2, -0.2), (0.2, -0.2)]
current_integral = mw.CustomCurrentIntegral2D(axis=2, position=0, vertices=vertices)
current_integral.compute_current(SIM_Z_DATA["mode_solver"])
current_integral.compute_current(SIM_Z_DATA["mode"])


Expand Down
21 changes: 15 additions & 6 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,21 +1024,30 @@ def make_diff_data(monitor: td.DiffractionMonitor) -> td.DiffractionData:
def make_mode_data(monitor: td.ModeMonitor) -> td.ModeData:
"""make a random ModeData from a ModeMonitor."""
_ = np.arange(monitor.mode_spec.num_modes)
coords_ind = {
"f": list(monitor.freqs),
"mode_index": np.arange(monitor.mode_spec.num_modes),
}
index_coords = {}
index_coords["f"] = list(monitor.freqs)
index_coords["mode_index"] = np.arange(monitor.mode_spec.num_modes)
n_complex = make_data(
coords=coords_ind, data_array_type=td.ModeIndexDataArray, is_complex=True
coords=index_coords, data_array_type=td.ModeIndexDataArray, is_complex=True
)
coords_amps = dict(direction=["+", "-"])
coords_amps.update(coords_ind)
coords_amps.update(index_coords)
amps = make_data(coords=coords_amps, data_array_type=td.ModeAmpsDataArray, is_complex=True)
field_cmps = {}
if monitor.store_fields_direction is not None:
for field_name in ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]:
coords = get_spatial_coords_dict(simulation, monitor, field_name)
coords["f"] = list(monitor.freqs)
coords["mode_index"] = index_coords["mode_index"]
field_cmps[field_name] = make_data(
coords=coords, data_array_type=td.ScalarModeFieldDataArray, is_complex=True
)
return td.ModeData(
monitor=monitor,
n_complex=n_complex,
amps=amps,
grid_expanded=simulation.discretize_monitor(monitor),
**field_cmps,
)

def make_flux_data(monitor: td.FluxMonitor) -> td.FluxData:
Expand Down
29 changes: 7 additions & 22 deletions tidy3d/plugins/microwave/custom_path_integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,18 @@
import xarray as xr

from ...components.base import cached_property
from ...components.data.data_array import FreqDataArray, FreqModeDataArray, TimeDataArray
from ...components.data.monitor_data import FieldData, FieldTimeData, ModeSolverData
from ...components.geometry.base import Geometry
from ...components.types import ArrayFloat2D, Ax, Axis, Bound, Coordinate, Direction
from ...components.viz import add_ax_if_none
from ...constants import MICROMETER, fp_eps
from ...exceptions import DataError, SetupError
from ...exceptions import SetupError
from .path_integrals import (
AbstractAxesRH,
AxisAlignedPathIntegral,
CurrentIntegralAxisAligned,
IntegralResultTypes,
MonitorDataTypes,
VoltageIntegralAxisAligned,
_check_em_field_supported,
)
from .viz import (
ARROW_CURRENT,
Expand Down Expand Up @@ -94,11 +92,9 @@ def compute_integral(

h_field_name = f"{field}{dim1}"
v_field_name = f"{field}{dim2}"

# Validate that fields are present
if h_field_name not in em_field.field_components:
raise DataError(f"'field_name' '{h_field_name}' not found.")
if v_field_name not in em_field.field_components:
raise DataError(f"'field_name' '{v_field_name}' not found.")
em_field._check_fields_stored([h_field_name, v_field_name])

# Select fields lying on the plane
plane_indexer = {dim3: self.position}
Expand Down Expand Up @@ -133,18 +129,7 @@ def compute_integral(
# Integrate along the path
result = integrand.integrate(coord="s")
result = result.reset_coords(drop=True)

if isinstance(em_field, FieldData):
return FreqDataArray(data=result.data, coords=result.coords)
elif isinstance(em_field, FieldTimeData):
return TimeDataArray(data=result.data, coords=result.coords)
else:
if not isinstance(em_field, ModeSolverData):
raise TypeError(
f"Unsupported 'em_field' type: {type(em_field)}. "
"Expected one of 'FieldData', 'FieldTimeData', 'ModeSolverData'."
)
return FreqModeDataArray(data=result.data, coords=result.coords)
return AxisAlignedPathIntegral._make_result_data_array(result)

@staticmethod
def _compute_dl_component(coord_array: xr.DataArray, closed_contour=False) -> np.array:
Expand Down Expand Up @@ -270,7 +255,7 @@ def compute_voltage(self, em_field: MonitorDataTypes) -> IntegralResultTypes:
:class:`.IntegralResultTypes`
Result of voltage computation over remaining dimensions (frequency, time, mode indices).
"""
_check_em_field_supported(em_field=em_field)
AxisAlignedPathIntegral._check_monitor_data_supported(em_field=em_field)
voltage = -1.0 * self.compute_integral(field="E", em_field=em_field)
voltage = VoltageIntegralAxisAligned._set_data_array_attributes(voltage)
return voltage
Expand Down Expand Up @@ -343,7 +328,7 @@ def compute_current(self, em_field: MonitorDataTypes) -> IntegralResultTypes:
:class:`.IntegralResultTypes`
Result of current computation over remaining dimensions (frequency, time, mode indices).
"""
_check_em_field_supported(em_field=em_field)
AxisAlignedPathIntegral._check_monitor_data_supported(em_field=em_field)
current = self.compute_integral(field="H", em_field=em_field)
current = CurrentIntegralAxisAligned._set_data_array_attributes(current)
return current
Expand Down
4 changes: 2 additions & 2 deletions tidy3d/plugins/microwave/impedance_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
from ...exceptions import ValidationError
from .custom_path_integrals import CustomCurrentIntegral2D, CustomVoltageIntegral2D
from .path_integrals import (
AxisAlignedPathIntegral,
CurrentIntegralAxisAligned,
IntegralResultTypes,
MonitorDataTypes,
VoltageIntegralAxisAligned,
_check_em_field_supported,
)

VoltageIntegralTypes = Union[VoltageIntegralAxisAligned, CustomVoltageIntegral2D]
Expand Down Expand Up @@ -55,7 +55,7 @@ def compute_impedance(self, em_field: MonitorDataTypes) -> IntegralResultTypes:
:class:`.IntegralResultTypes`
Result of impedance computation over remaining dimensions (frequency, time, mode indices).
"""
_check_em_field_supported(em_field=em_field)
AxisAlignedPathIntegral._check_monitor_data_supported(em_field=em_field)

# If both voltage and current integrals have been defined then impedance is computed directly
if self.voltage_integral:
Expand Down
63 changes: 30 additions & 33 deletions tidy3d/plugins/microwave/path_integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Union
from typing import Union

import numpy as np
import pydantic.v1 as pd
import shapely as shapely
import xarray as xr

from ...components.base import Tidy3dBaseModel, cached_property
from ...components.data.data_array import (
Expand All @@ -18,7 +19,7 @@
ScalarModeFieldDataArray,
TimeDataArray,
)
from ...components.data.monitor_data import FieldData, FieldTimeData, ModeSolverData
from ...components.data.monitor_data import FieldData, FieldTimeData, ModeData, ModeSolverData
from ...components.geometry.base import Box, Geometry
from ...components.types import Ax, Axis, Coordinate2D, Direction
from ...components.validators import assert_line, assert_plane
Expand All @@ -33,20 +34,11 @@
plot_params_voltage_plus,
)

MonitorDataTypes = Union[FieldData, FieldTimeData, ModeSolverData]
MonitorDataTypes = Union[FieldData, FieldTimeData, ModeData, ModeSolverData]
EMScalarFieldType = Union[ScalarFieldDataArray, ScalarFieldTimeDataArray, ScalarModeFieldDataArray]
IntegralResultTypes = Union[FreqDataArray, FreqModeDataArray, TimeDataArray]


def _check_em_field_supported(em_field: Any):
"""Function for validating correct data arrays."""
if not isinstance(em_field, (FieldData, FieldTimeData, ModeSolverData)):
raise DataError(
"'em_field' type not supported. Supported types are "
"'FieldData', 'FieldTimeData', 'ModeSolverData'."
)


class AbstractAxesRH(Tidy3dBaseModel, ABC):
"""Represents an axis-aligned right-handed coordinate system with one axis preferred.
Typically `main_axis` would refer to the normal axis of a plane.
Expand Down Expand Up @@ -137,18 +129,7 @@ def compute_integral(self, scalar_field: EMScalarFieldType) -> IntegralResultTyp
coords_interp, method=method, kwargs={"fill_value": "extrapolate"}
)
result = scalar_field.integrate(coord=coord)
if isinstance(scalar_field, ScalarFieldDataArray):
return FreqDataArray(data=result.data, coords=result.coords)
elif isinstance(scalar_field, ScalarFieldTimeDataArray):
return TimeDataArray(data=result.data, coords=result.coords)
else:
if not isinstance(scalar_field, ScalarModeFieldDataArray):
raise TypeError(
f"Unsupported 'scalar_field' type: {type(scalar_field)}. "
"Expected one of 'ScalarFieldDataArray', 'ScalarFieldTimeDataArray', "
"'ScalarModeFieldDataArray'."
)
return FreqModeDataArray(data=result.data, coords=result.coords)
return self._make_result_data_array(result)

def _get_field_along_path(self, scalar_field: EMScalarFieldType) -> EMScalarFieldType:
"""Returns a selection of the input ``scalar_field`` ready for integration."""
Expand Down Expand Up @@ -206,6 +187,26 @@ def _vertices_2D(self, axis: Axis) -> tuple[Coordinate2D, Coordinate2D]:
v = [min[1], max[1]]
return (u, v)

@staticmethod
def _check_monitor_data_supported(em_field: MonitorDataTypes):
"""Helper for validating that monitor data is supported."""
if not isinstance(em_field, (FieldData, FieldTimeData, ModeData, ModeSolverData)):
supported_types = list(MonitorDataTypes.__args__)
raise DataError(
f"'em_field' type {type(em_field)} not supported. Supported types are "
f"{supported_types}"
)

@staticmethod
def _make_result_data_array(result: xr.DataArray) -> IntegralResultTypes:
"""Helper for creating the proper result type."""
if "t" in result.coords:
return TimeDataArray(data=result.data, coords=result.coords)
elif "f" in result.coords and "mode_index" in result.coords:
return FreqModeDataArray(data=result.data, coords=result.coords)
else:
return FreqDataArray(data=result.data, coords=result.coords)


class VoltageIntegralAxisAligned(AxisAlignedPathIntegral):
"""Class for computing the voltage between two points defined by an axis-aligned line."""
Expand All @@ -218,12 +219,11 @@ class VoltageIntegralAxisAligned(AxisAlignedPathIntegral):

def compute_voltage(self, em_field: MonitorDataTypes) -> IntegralResultTypes:
"""Compute voltage along path defined by a line."""
_check_em_field_supported(em_field=em_field)
self._check_monitor_data_supported(em_field=em_field)
e_component = "xyz"[self.main_axis]
field_name = f"E{e_component}"
# Validate that the field is present
if field_name not in em_field.field_components:
raise DataError(f"'field_name' '{field_name}' not found.")
# Validate that fields are present
em_field._check_fields_stored([field_name])
e_field = em_field.field_components[field_name]

voltage = self.compute_integral(e_field)
Expand Down Expand Up @@ -376,18 +376,15 @@ class CurrentIntegralAxisAligned(AbstractAxesRH, Box):

def compute_current(self, em_field: MonitorDataTypes) -> IntegralResultTypes:
"""Compute current flowing in loop defined by the outer edge of a rectangle."""
_check_em_field_supported(em_field=em_field)
AxisAlignedPathIntegral._check_monitor_data_supported(em_field=em_field)
ax1 = self.remaining_axes[0]
ax2 = self.remaining_axes[1]
h_component = "xyz"[ax1]
v_component = "xyz"[ax2]
h_field_name = f"H{h_component}"
v_field_name = f"H{v_component}"
# Validate that fields are present
if h_field_name not in em_field.field_components:
raise DataError(f"'field_name' '{h_field_name}' not found.")
if v_field_name not in em_field.field_components:
raise DataError(f"'field_name' '{v_field_name}' not found.")
em_field._check_fields_stored([h_field_name, v_field_name])
h_horizontal = em_field.field_components[h_field_name]
h_vertical = em_field.field_components[v_field_name]

Expand Down

0 comments on commit f656a4c

Please sign in to comment.