From f656a4c0f4cac20821084f900de866836bac65e8 Mon Sep 17 00:00:00 2001 From: dmarek Date: Tue, 18 Feb 2025 15:47:10 -0500 Subject: [PATCH] :wrench: :rocket: allow ModeData to be passed to path integral computations with small refactor to reduce duplicate code --- CHANGELOG.md | 1 + tests/test_plugins/test_microwave.py | 13 +++- tests/utils.py | 21 +++++-- .../microwave/custom_path_integrals.py | 29 +++------ .../plugins/microwave/impedance_calculator.py | 4 +- tidy3d/plugins/microwave/path_integrals.py | 63 +++++++++---------- 6 files changed, 67 insertions(+), 64 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 342e38d7e3..8b3823b968 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `num_freqs` is now set to 3 by default for the `PlaneWave`, `GaussianBeam`, and `AnalyticGaussianBeam` sources, which makes the injection more accurate in broadband cases. - Nonlinear models `KerrNonlinearity` and `TwoPhotonAbsorption` now default to using the physical real fields instead of complex fields. - Added warning when a lumped element is not completely within the simulation bounds, since now lumped elements will only have an effect on the `Simulation` when they are completely within the simulation bounds. +- Allow `ModeData` to be passed to path integral computations in the `microwave` plugin. ### Fixed - Make gauge selection for non-converged modes more robust. diff --git a/tests/test_plugins/test_microwave.py b/tests/test_plugins/test_microwave.py index 333c5a86ea..02006c1118 100644 --- a/tests/test_plugins/test_microwave.py +++ b/tests/test_plugins/test_microwave.py @@ -43,6 +43,14 @@ size=(1, 1, 0), freqs=FS, mode_spec=td.ModeSpec(num_modes=2), + name="mode_solver", + ), + td.ModeMonitor( + center=(0, 0, 0), + size=(1, 1, 0), + freqs=FS, + mode_spec=td.ModeSpec(num_modes=3), + store_fields_direction="+", name="mode", ), ], @@ -280,7 +288,7 @@ def test_time_monitor_voltage_integral(): def test_mode_solver_monitor_voltage_integral(): - """Check VoltageIntegralAxisAligned runs on mode solver data.""" + """Check VoltageIntegralAxisAligned runs on ModeData and ModeSolverData.""" length = 0.5 size = [0, 0, 0] size[1] = length @@ -291,6 +299,7 @@ def test_mode_solver_monitor_voltage_integral(): sign="+", ) + voltage_integral.compute_voltage(SIM_Z_DATA["mode_solver"]) voltage_integral.compute_voltage(SIM_Z_DATA["mode"]) @@ -495,12 +504,14 @@ def test_time_monitor_custom_current_integral(): def test_mode_solver_custom_current_integral(): + """Test that both ModeData and ModeSolverData are allowed types.""" length = 0.5 size = [0, 0, 0] size[1] = length # Make box vertices = [(0.2, -0.2), (0.2, 0.2), (-0.2, 0.2), (-0.2, -0.2), (0.2, -0.2)] current_integral = mw.CustomCurrentIntegral2D(axis=2, position=0, vertices=vertices) + current_integral.compute_current(SIM_Z_DATA["mode_solver"]) current_integral.compute_current(SIM_Z_DATA["mode"]) diff --git a/tests/utils.py b/tests/utils.py index 12ade06413..ee0ba3d47d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1024,21 +1024,30 @@ def make_diff_data(monitor: td.DiffractionMonitor) -> td.DiffractionData: def make_mode_data(monitor: td.ModeMonitor) -> td.ModeData: """make a random ModeData from a ModeMonitor.""" _ = np.arange(monitor.mode_spec.num_modes) - coords_ind = { - "f": list(monitor.freqs), - "mode_index": np.arange(monitor.mode_spec.num_modes), - } + index_coords = {} + index_coords["f"] = list(monitor.freqs) + index_coords["mode_index"] = np.arange(monitor.mode_spec.num_modes) n_complex = make_data( - coords=coords_ind, data_array_type=td.ModeIndexDataArray, is_complex=True + coords=index_coords, data_array_type=td.ModeIndexDataArray, is_complex=True ) coords_amps = dict(direction=["+", "-"]) - coords_amps.update(coords_ind) + coords_amps.update(index_coords) amps = make_data(coords=coords_amps, data_array_type=td.ModeAmpsDataArray, is_complex=True) + field_cmps = {} + if monitor.store_fields_direction is not None: + for field_name in ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]: + coords = get_spatial_coords_dict(simulation, monitor, field_name) + coords["f"] = list(monitor.freqs) + coords["mode_index"] = index_coords["mode_index"] + field_cmps[field_name] = make_data( + coords=coords, data_array_type=td.ScalarModeFieldDataArray, is_complex=True + ) return td.ModeData( monitor=monitor, n_complex=n_complex, amps=amps, grid_expanded=simulation.discretize_monitor(monitor), + **field_cmps, ) def make_flux_data(monitor: td.FluxMonitor) -> td.FluxData: diff --git a/tidy3d/plugins/microwave/custom_path_integrals.py b/tidy3d/plugins/microwave/custom_path_integrals.py index 5fcb981565..6b6c48469a 100644 --- a/tidy3d/plugins/microwave/custom_path_integrals.py +++ b/tidy3d/plugins/microwave/custom_path_integrals.py @@ -10,20 +10,18 @@ import xarray as xr from ...components.base import cached_property -from ...components.data.data_array import FreqDataArray, FreqModeDataArray, TimeDataArray -from ...components.data.monitor_data import FieldData, FieldTimeData, ModeSolverData from ...components.geometry.base import Geometry from ...components.types import ArrayFloat2D, Ax, Axis, Bound, Coordinate, Direction from ...components.viz import add_ax_if_none from ...constants import MICROMETER, fp_eps -from ...exceptions import DataError, SetupError +from ...exceptions import SetupError from .path_integrals import ( AbstractAxesRH, + AxisAlignedPathIntegral, CurrentIntegralAxisAligned, IntegralResultTypes, MonitorDataTypes, VoltageIntegralAxisAligned, - _check_em_field_supported, ) from .viz import ( ARROW_CURRENT, @@ -94,11 +92,9 @@ def compute_integral( h_field_name = f"{field}{dim1}" v_field_name = f"{field}{dim2}" + # Validate that fields are present - if h_field_name not in em_field.field_components: - raise DataError(f"'field_name' '{h_field_name}' not found.") - if v_field_name not in em_field.field_components: - raise DataError(f"'field_name' '{v_field_name}' not found.") + em_field._check_fields_stored([h_field_name, v_field_name]) # Select fields lying on the plane plane_indexer = {dim3: self.position} @@ -133,18 +129,7 @@ def compute_integral( # Integrate along the path result = integrand.integrate(coord="s") result = result.reset_coords(drop=True) - - if isinstance(em_field, FieldData): - return FreqDataArray(data=result.data, coords=result.coords) - elif isinstance(em_field, FieldTimeData): - return TimeDataArray(data=result.data, coords=result.coords) - else: - if not isinstance(em_field, ModeSolverData): - raise TypeError( - f"Unsupported 'em_field' type: {type(em_field)}. " - "Expected one of 'FieldData', 'FieldTimeData', 'ModeSolverData'." - ) - return FreqModeDataArray(data=result.data, coords=result.coords) + return AxisAlignedPathIntegral._make_result_data_array(result) @staticmethod def _compute_dl_component(coord_array: xr.DataArray, closed_contour=False) -> np.array: @@ -270,7 +255,7 @@ def compute_voltage(self, em_field: MonitorDataTypes) -> IntegralResultTypes: :class:`.IntegralResultTypes` Result of voltage computation over remaining dimensions (frequency, time, mode indices). """ - _check_em_field_supported(em_field=em_field) + AxisAlignedPathIntegral._check_monitor_data_supported(em_field=em_field) voltage = -1.0 * self.compute_integral(field="E", em_field=em_field) voltage = VoltageIntegralAxisAligned._set_data_array_attributes(voltage) return voltage @@ -343,7 +328,7 @@ def compute_current(self, em_field: MonitorDataTypes) -> IntegralResultTypes: :class:`.IntegralResultTypes` Result of current computation over remaining dimensions (frequency, time, mode indices). """ - _check_em_field_supported(em_field=em_field) + AxisAlignedPathIntegral._check_monitor_data_supported(em_field=em_field) current = self.compute_integral(field="H", em_field=em_field) current = CurrentIntegralAxisAligned._set_data_array_attributes(current) return current diff --git a/tidy3d/plugins/microwave/impedance_calculator.py b/tidy3d/plugins/microwave/impedance_calculator.py index 0d38874872..6fa2c54089 100644 --- a/tidy3d/plugins/microwave/impedance_calculator.py +++ b/tidy3d/plugins/microwave/impedance_calculator.py @@ -13,11 +13,11 @@ from ...exceptions import ValidationError from .custom_path_integrals import CustomCurrentIntegral2D, CustomVoltageIntegral2D from .path_integrals import ( + AxisAlignedPathIntegral, CurrentIntegralAxisAligned, IntegralResultTypes, MonitorDataTypes, VoltageIntegralAxisAligned, - _check_em_field_supported, ) VoltageIntegralTypes = Union[VoltageIntegralAxisAligned, CustomVoltageIntegral2D] @@ -55,7 +55,7 @@ def compute_impedance(self, em_field: MonitorDataTypes) -> IntegralResultTypes: :class:`.IntegralResultTypes` Result of impedance computation over remaining dimensions (frequency, time, mode indices). """ - _check_em_field_supported(em_field=em_field) + AxisAlignedPathIntegral._check_monitor_data_supported(em_field=em_field) # If both voltage and current integrals have been defined then impedance is computed directly if self.voltage_integral: diff --git a/tidy3d/plugins/microwave/path_integrals.py b/tidy3d/plugins/microwave/path_integrals.py index 60101f1a1b..fdd9cf08b5 100644 --- a/tidy3d/plugins/microwave/path_integrals.py +++ b/tidy3d/plugins/microwave/path_integrals.py @@ -3,11 +3,12 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Union +from typing import Union import numpy as np import pydantic.v1 as pd import shapely as shapely +import xarray as xr from ...components.base import Tidy3dBaseModel, cached_property from ...components.data.data_array import ( @@ -18,7 +19,7 @@ ScalarModeFieldDataArray, TimeDataArray, ) -from ...components.data.monitor_data import FieldData, FieldTimeData, ModeSolverData +from ...components.data.monitor_data import FieldData, FieldTimeData, ModeData, ModeSolverData from ...components.geometry.base import Box, Geometry from ...components.types import Ax, Axis, Coordinate2D, Direction from ...components.validators import assert_line, assert_plane @@ -33,20 +34,11 @@ plot_params_voltage_plus, ) -MonitorDataTypes = Union[FieldData, FieldTimeData, ModeSolverData] +MonitorDataTypes = Union[FieldData, FieldTimeData, ModeData, ModeSolverData] EMScalarFieldType = Union[ScalarFieldDataArray, ScalarFieldTimeDataArray, ScalarModeFieldDataArray] IntegralResultTypes = Union[FreqDataArray, FreqModeDataArray, TimeDataArray] -def _check_em_field_supported(em_field: Any): - """Function for validating correct data arrays.""" - if not isinstance(em_field, (FieldData, FieldTimeData, ModeSolverData)): - raise DataError( - "'em_field' type not supported. Supported types are " - "'FieldData', 'FieldTimeData', 'ModeSolverData'." - ) - - class AbstractAxesRH(Tidy3dBaseModel, ABC): """Represents an axis-aligned right-handed coordinate system with one axis preferred. Typically `main_axis` would refer to the normal axis of a plane. @@ -137,18 +129,7 @@ def compute_integral(self, scalar_field: EMScalarFieldType) -> IntegralResultTyp coords_interp, method=method, kwargs={"fill_value": "extrapolate"} ) result = scalar_field.integrate(coord=coord) - if isinstance(scalar_field, ScalarFieldDataArray): - return FreqDataArray(data=result.data, coords=result.coords) - elif isinstance(scalar_field, ScalarFieldTimeDataArray): - return TimeDataArray(data=result.data, coords=result.coords) - else: - if not isinstance(scalar_field, ScalarModeFieldDataArray): - raise TypeError( - f"Unsupported 'scalar_field' type: {type(scalar_field)}. " - "Expected one of 'ScalarFieldDataArray', 'ScalarFieldTimeDataArray', " - "'ScalarModeFieldDataArray'." - ) - return FreqModeDataArray(data=result.data, coords=result.coords) + return self._make_result_data_array(result) def _get_field_along_path(self, scalar_field: EMScalarFieldType) -> EMScalarFieldType: """Returns a selection of the input ``scalar_field`` ready for integration.""" @@ -206,6 +187,26 @@ def _vertices_2D(self, axis: Axis) -> tuple[Coordinate2D, Coordinate2D]: v = [min[1], max[1]] return (u, v) + @staticmethod + def _check_monitor_data_supported(em_field: MonitorDataTypes): + """Helper for validating that monitor data is supported.""" + if not isinstance(em_field, (FieldData, FieldTimeData, ModeData, ModeSolverData)): + supported_types = list(MonitorDataTypes.__args__) + raise DataError( + f"'em_field' type {type(em_field)} not supported. Supported types are " + f"{supported_types}" + ) + + @staticmethod + def _make_result_data_array(result: xr.DataArray) -> IntegralResultTypes: + """Helper for creating the proper result type.""" + if "t" in result.coords: + return TimeDataArray(data=result.data, coords=result.coords) + elif "f" in result.coords and "mode_index" in result.coords: + return FreqModeDataArray(data=result.data, coords=result.coords) + else: + return FreqDataArray(data=result.data, coords=result.coords) + class VoltageIntegralAxisAligned(AxisAlignedPathIntegral): """Class for computing the voltage between two points defined by an axis-aligned line.""" @@ -218,12 +219,11 @@ class VoltageIntegralAxisAligned(AxisAlignedPathIntegral): def compute_voltage(self, em_field: MonitorDataTypes) -> IntegralResultTypes: """Compute voltage along path defined by a line.""" - _check_em_field_supported(em_field=em_field) + self._check_monitor_data_supported(em_field=em_field) e_component = "xyz"[self.main_axis] field_name = f"E{e_component}" - # Validate that the field is present - if field_name not in em_field.field_components: - raise DataError(f"'field_name' '{field_name}' not found.") + # Validate that fields are present + em_field._check_fields_stored([field_name]) e_field = em_field.field_components[field_name] voltage = self.compute_integral(e_field) @@ -376,7 +376,7 @@ class CurrentIntegralAxisAligned(AbstractAxesRH, Box): def compute_current(self, em_field: MonitorDataTypes) -> IntegralResultTypes: """Compute current flowing in loop defined by the outer edge of a rectangle.""" - _check_em_field_supported(em_field=em_field) + AxisAlignedPathIntegral._check_monitor_data_supported(em_field=em_field) ax1 = self.remaining_axes[0] ax2 = self.remaining_axes[1] h_component = "xyz"[ax1] @@ -384,10 +384,7 @@ def compute_current(self, em_field: MonitorDataTypes) -> IntegralResultTypes: h_field_name = f"H{h_component}" v_field_name = f"H{v_component}" # Validate that fields are present - if h_field_name not in em_field.field_components: - raise DataError(f"'field_name' '{h_field_name}' not found.") - if v_field_name not in em_field.field_components: - raise DataError(f"'field_name' '{v_field_name}' not found.") + em_field._check_fields_stored([h_field_name, v_field_name]) h_horizontal = em_field.field_components[h_field_name] h_vertical = em_field.field_components[v_field_name]