Skip to content

Commit f656a4c

Browse files
committed
🔧 🚀 allow ModeData to be passed to path integral computations with small refactor to reduce duplicate code
1 parent dab6854 commit f656a4c

File tree

6 files changed

+67
-64
lines changed

6 files changed

+67
-64
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2323
- `num_freqs` is now set to 3 by default for the `PlaneWave`, `GaussianBeam`, and `AnalyticGaussianBeam` sources, which makes the injection more accurate in broadband cases.
2424
- Nonlinear models `KerrNonlinearity` and `TwoPhotonAbsorption` now default to using the physical real fields instead of complex fields.
2525
- Added warning when a lumped element is not completely within the simulation bounds, since now lumped elements will only have an effect on the `Simulation` when they are completely within the simulation bounds.
26+
- Allow `ModeData` to be passed to path integral computations in the `microwave` plugin.
2627

2728
### Fixed
2829
- Make gauge selection for non-converged modes more robust.

tests/test_plugins/test_microwave.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@
4343
size=(1, 1, 0),
4444
freqs=FS,
4545
mode_spec=td.ModeSpec(num_modes=2),
46+
name="mode_solver",
47+
),
48+
td.ModeMonitor(
49+
center=(0, 0, 0),
50+
size=(1, 1, 0),
51+
freqs=FS,
52+
mode_spec=td.ModeSpec(num_modes=3),
53+
store_fields_direction="+",
4654
name="mode",
4755
),
4856
],
@@ -280,7 +288,7 @@ def test_time_monitor_voltage_integral():
280288

281289

282290
def test_mode_solver_monitor_voltage_integral():
283-
"""Check VoltageIntegralAxisAligned runs on mode solver data."""
291+
"""Check VoltageIntegralAxisAligned runs on ModeData and ModeSolverData."""
284292
length = 0.5
285293
size = [0, 0, 0]
286294
size[1] = length
@@ -291,6 +299,7 @@ def test_mode_solver_monitor_voltage_integral():
291299
sign="+",
292300
)
293301

302+
voltage_integral.compute_voltage(SIM_Z_DATA["mode_solver"])
294303
voltage_integral.compute_voltage(SIM_Z_DATA["mode"])
295304

296305

@@ -495,12 +504,14 @@ def test_time_monitor_custom_current_integral():
495504

496505

497506
def test_mode_solver_custom_current_integral():
507+
"""Test that both ModeData and ModeSolverData are allowed types."""
498508
length = 0.5
499509
size = [0, 0, 0]
500510
size[1] = length
501511
# Make box
502512
vertices = [(0.2, -0.2), (0.2, 0.2), (-0.2, 0.2), (-0.2, -0.2), (0.2, -0.2)]
503513
current_integral = mw.CustomCurrentIntegral2D(axis=2, position=0, vertices=vertices)
514+
current_integral.compute_current(SIM_Z_DATA["mode_solver"])
504515
current_integral.compute_current(SIM_Z_DATA["mode"])
505516

506517

tests/utils.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,21 +1024,30 @@ def make_diff_data(monitor: td.DiffractionMonitor) -> td.DiffractionData:
10241024
def make_mode_data(monitor: td.ModeMonitor) -> td.ModeData:
10251025
"""make a random ModeData from a ModeMonitor."""
10261026
_ = np.arange(monitor.mode_spec.num_modes)
1027-
coords_ind = {
1028-
"f": list(monitor.freqs),
1029-
"mode_index": np.arange(monitor.mode_spec.num_modes),
1030-
}
1027+
index_coords = {}
1028+
index_coords["f"] = list(monitor.freqs)
1029+
index_coords["mode_index"] = np.arange(monitor.mode_spec.num_modes)
10311030
n_complex = make_data(
1032-
coords=coords_ind, data_array_type=td.ModeIndexDataArray, is_complex=True
1031+
coords=index_coords, data_array_type=td.ModeIndexDataArray, is_complex=True
10331032
)
10341033
coords_amps = dict(direction=["+", "-"])
1035-
coords_amps.update(coords_ind)
1034+
coords_amps.update(index_coords)
10361035
amps = make_data(coords=coords_amps, data_array_type=td.ModeAmpsDataArray, is_complex=True)
1036+
field_cmps = {}
1037+
if monitor.store_fields_direction is not None:
1038+
for field_name in ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]:
1039+
coords = get_spatial_coords_dict(simulation, monitor, field_name)
1040+
coords["f"] = list(monitor.freqs)
1041+
coords["mode_index"] = index_coords["mode_index"]
1042+
field_cmps[field_name] = make_data(
1043+
coords=coords, data_array_type=td.ScalarModeFieldDataArray, is_complex=True
1044+
)
10371045
return td.ModeData(
10381046
monitor=monitor,
10391047
n_complex=n_complex,
10401048
amps=amps,
10411049
grid_expanded=simulation.discretize_monitor(monitor),
1050+
**field_cmps,
10421051
)
10431052

10441053
def make_flux_data(monitor: td.FluxMonitor) -> td.FluxData:

tidy3d/plugins/microwave/custom_path_integrals.py

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,18 @@
1010
import xarray as xr
1111

1212
from ...components.base import cached_property
13-
from ...components.data.data_array import FreqDataArray, FreqModeDataArray, TimeDataArray
14-
from ...components.data.monitor_data import FieldData, FieldTimeData, ModeSolverData
1513
from ...components.geometry.base import Geometry
1614
from ...components.types import ArrayFloat2D, Ax, Axis, Bound, Coordinate, Direction
1715
from ...components.viz import add_ax_if_none
1816
from ...constants import MICROMETER, fp_eps
19-
from ...exceptions import DataError, SetupError
17+
from ...exceptions import SetupError
2018
from .path_integrals import (
2119
AbstractAxesRH,
20+
AxisAlignedPathIntegral,
2221
CurrentIntegralAxisAligned,
2322
IntegralResultTypes,
2423
MonitorDataTypes,
2524
VoltageIntegralAxisAligned,
26-
_check_em_field_supported,
2725
)
2826
from .viz import (
2927
ARROW_CURRENT,
@@ -94,11 +92,9 @@ def compute_integral(
9492

9593
h_field_name = f"{field}{dim1}"
9694
v_field_name = f"{field}{dim2}"
95+
9796
# Validate that fields are present
98-
if h_field_name not in em_field.field_components:
99-
raise DataError(f"'field_name' '{h_field_name}' not found.")
100-
if v_field_name not in em_field.field_components:
101-
raise DataError(f"'field_name' '{v_field_name}' not found.")
97+
em_field._check_fields_stored([h_field_name, v_field_name])
10298

10399
# Select fields lying on the plane
104100
plane_indexer = {dim3: self.position}
@@ -133,18 +129,7 @@ def compute_integral(
133129
# Integrate along the path
134130
result = integrand.integrate(coord="s")
135131
result = result.reset_coords(drop=True)
136-
137-
if isinstance(em_field, FieldData):
138-
return FreqDataArray(data=result.data, coords=result.coords)
139-
elif isinstance(em_field, FieldTimeData):
140-
return TimeDataArray(data=result.data, coords=result.coords)
141-
else:
142-
if not isinstance(em_field, ModeSolverData):
143-
raise TypeError(
144-
f"Unsupported 'em_field' type: {type(em_field)}. "
145-
"Expected one of 'FieldData', 'FieldTimeData', 'ModeSolverData'."
146-
)
147-
return FreqModeDataArray(data=result.data, coords=result.coords)
132+
return AxisAlignedPathIntegral._make_result_data_array(result)
148133

149134
@staticmethod
150135
def _compute_dl_component(coord_array: xr.DataArray, closed_contour=False) -> np.array:
@@ -270,7 +255,7 @@ def compute_voltage(self, em_field: MonitorDataTypes) -> IntegralResultTypes:
270255
:class:`.IntegralResultTypes`
271256
Result of voltage computation over remaining dimensions (frequency, time, mode indices).
272257
"""
273-
_check_em_field_supported(em_field=em_field)
258+
AxisAlignedPathIntegral._check_monitor_data_supported(em_field=em_field)
274259
voltage = -1.0 * self.compute_integral(field="E", em_field=em_field)
275260
voltage = VoltageIntegralAxisAligned._set_data_array_attributes(voltage)
276261
return voltage
@@ -343,7 +328,7 @@ def compute_current(self, em_field: MonitorDataTypes) -> IntegralResultTypes:
343328
:class:`.IntegralResultTypes`
344329
Result of current computation over remaining dimensions (frequency, time, mode indices).
345330
"""
346-
_check_em_field_supported(em_field=em_field)
331+
AxisAlignedPathIntegral._check_monitor_data_supported(em_field=em_field)
347332
current = self.compute_integral(field="H", em_field=em_field)
348333
current = CurrentIntegralAxisAligned._set_data_array_attributes(current)
349334
return current

tidy3d/plugins/microwave/impedance_calculator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@
1313
from ...exceptions import ValidationError
1414
from .custom_path_integrals import CustomCurrentIntegral2D, CustomVoltageIntegral2D
1515
from .path_integrals import (
16+
AxisAlignedPathIntegral,
1617
CurrentIntegralAxisAligned,
1718
IntegralResultTypes,
1819
MonitorDataTypes,
1920
VoltageIntegralAxisAligned,
20-
_check_em_field_supported,
2121
)
2222

2323
VoltageIntegralTypes = Union[VoltageIntegralAxisAligned, CustomVoltageIntegral2D]
@@ -55,7 +55,7 @@ def compute_impedance(self, em_field: MonitorDataTypes) -> IntegralResultTypes:
5555
:class:`.IntegralResultTypes`
5656
Result of impedance computation over remaining dimensions (frequency, time, mode indices).
5757
"""
58-
_check_em_field_supported(em_field=em_field)
58+
AxisAlignedPathIntegral._check_monitor_data_supported(em_field=em_field)
5959

6060
# If both voltage and current integrals have been defined then impedance is computed directly
6161
if self.voltage_integral:

tidy3d/plugins/microwave/path_integrals.py

Lines changed: 30 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
from __future__ import annotations
44

55
from abc import ABC, abstractmethod
6-
from typing import Any, Union
6+
from typing import Union
77

88
import numpy as np
99
import pydantic.v1 as pd
1010
import shapely as shapely
11+
import xarray as xr
1112

1213
from ...components.base import Tidy3dBaseModel, cached_property
1314
from ...components.data.data_array import (
@@ -18,7 +19,7 @@
1819
ScalarModeFieldDataArray,
1920
TimeDataArray,
2021
)
21-
from ...components.data.monitor_data import FieldData, FieldTimeData, ModeSolverData
22+
from ...components.data.monitor_data import FieldData, FieldTimeData, ModeData, ModeSolverData
2223
from ...components.geometry.base import Box, Geometry
2324
from ...components.types import Ax, Axis, Coordinate2D, Direction
2425
from ...components.validators import assert_line, assert_plane
@@ -33,20 +34,11 @@
3334
plot_params_voltage_plus,
3435
)
3536

36-
MonitorDataTypes = Union[FieldData, FieldTimeData, ModeSolverData]
37+
MonitorDataTypes = Union[FieldData, FieldTimeData, ModeData, ModeSolverData]
3738
EMScalarFieldType = Union[ScalarFieldDataArray, ScalarFieldTimeDataArray, ScalarModeFieldDataArray]
3839
IntegralResultTypes = Union[FreqDataArray, FreqModeDataArray, TimeDataArray]
3940

4041

41-
def _check_em_field_supported(em_field: Any):
42-
"""Function for validating correct data arrays."""
43-
if not isinstance(em_field, (FieldData, FieldTimeData, ModeSolverData)):
44-
raise DataError(
45-
"'em_field' type not supported. Supported types are "
46-
"'FieldData', 'FieldTimeData', 'ModeSolverData'."
47-
)
48-
49-
5042
class AbstractAxesRH(Tidy3dBaseModel, ABC):
5143
"""Represents an axis-aligned right-handed coordinate system with one axis preferred.
5244
Typically `main_axis` would refer to the normal axis of a plane.
@@ -137,18 +129,7 @@ def compute_integral(self, scalar_field: EMScalarFieldType) -> IntegralResultTyp
137129
coords_interp, method=method, kwargs={"fill_value": "extrapolate"}
138130
)
139131
result = scalar_field.integrate(coord=coord)
140-
if isinstance(scalar_field, ScalarFieldDataArray):
141-
return FreqDataArray(data=result.data, coords=result.coords)
142-
elif isinstance(scalar_field, ScalarFieldTimeDataArray):
143-
return TimeDataArray(data=result.data, coords=result.coords)
144-
else:
145-
if not isinstance(scalar_field, ScalarModeFieldDataArray):
146-
raise TypeError(
147-
f"Unsupported 'scalar_field' type: {type(scalar_field)}. "
148-
"Expected one of 'ScalarFieldDataArray', 'ScalarFieldTimeDataArray', "
149-
"'ScalarModeFieldDataArray'."
150-
)
151-
return FreqModeDataArray(data=result.data, coords=result.coords)
132+
return self._make_result_data_array(result)
152133

153134
def _get_field_along_path(self, scalar_field: EMScalarFieldType) -> EMScalarFieldType:
154135
"""Returns a selection of the input ``scalar_field`` ready for integration."""
@@ -206,6 +187,26 @@ def _vertices_2D(self, axis: Axis) -> tuple[Coordinate2D, Coordinate2D]:
206187
v = [min[1], max[1]]
207188
return (u, v)
208189

190+
@staticmethod
191+
def _check_monitor_data_supported(em_field: MonitorDataTypes):
192+
"""Helper for validating that monitor data is supported."""
193+
if not isinstance(em_field, (FieldData, FieldTimeData, ModeData, ModeSolverData)):
194+
supported_types = list(MonitorDataTypes.__args__)
195+
raise DataError(
196+
f"'em_field' type {type(em_field)} not supported. Supported types are "
197+
f"{supported_types}"
198+
)
199+
200+
@staticmethod
201+
def _make_result_data_array(result: xr.DataArray) -> IntegralResultTypes:
202+
"""Helper for creating the proper result type."""
203+
if "t" in result.coords:
204+
return TimeDataArray(data=result.data, coords=result.coords)
205+
elif "f" in result.coords and "mode_index" in result.coords:
206+
return FreqModeDataArray(data=result.data, coords=result.coords)
207+
else:
208+
return FreqDataArray(data=result.data, coords=result.coords)
209+
209210

210211
class VoltageIntegralAxisAligned(AxisAlignedPathIntegral):
211212
"""Class for computing the voltage between two points defined by an axis-aligned line."""
@@ -218,12 +219,11 @@ class VoltageIntegralAxisAligned(AxisAlignedPathIntegral):
218219

219220
def compute_voltage(self, em_field: MonitorDataTypes) -> IntegralResultTypes:
220221
"""Compute voltage along path defined by a line."""
221-
_check_em_field_supported(em_field=em_field)
222+
self._check_monitor_data_supported(em_field=em_field)
222223
e_component = "xyz"[self.main_axis]
223224
field_name = f"E{e_component}"
224-
# Validate that the field is present
225-
if field_name not in em_field.field_components:
226-
raise DataError(f"'field_name' '{field_name}' not found.")
225+
# Validate that fields are present
226+
em_field._check_fields_stored([field_name])
227227
e_field = em_field.field_components[field_name]
228228

229229
voltage = self.compute_integral(e_field)
@@ -376,18 +376,15 @@ class CurrentIntegralAxisAligned(AbstractAxesRH, Box):
376376

377377
def compute_current(self, em_field: MonitorDataTypes) -> IntegralResultTypes:
378378
"""Compute current flowing in loop defined by the outer edge of a rectangle."""
379-
_check_em_field_supported(em_field=em_field)
379+
AxisAlignedPathIntegral._check_monitor_data_supported(em_field=em_field)
380380
ax1 = self.remaining_axes[0]
381381
ax2 = self.remaining_axes[1]
382382
h_component = "xyz"[ax1]
383383
v_component = "xyz"[ax2]
384384
h_field_name = f"H{h_component}"
385385
v_field_name = f"H{v_component}"
386386
# Validate that fields are present
387-
if h_field_name not in em_field.field_components:
388-
raise DataError(f"'field_name' '{h_field_name}' not found.")
389-
if v_field_name not in em_field.field_components:
390-
raise DataError(f"'field_name' '{v_field_name}' not found.")
387+
em_field._check_fields_stored([h_field_name, v_field_name])
391388
h_horizontal = em_field.field_components[h_field_name]
392389
h_vertical = em_field.field_components[v_field_name]
393390

0 commit comments

Comments
 (0)