3
3
from __future__ import annotations
4
4
5
5
from abc import ABC , abstractmethod
6
- from typing import Any , Union
6
+ from typing import Union
7
7
8
8
import numpy as np
9
9
import pydantic .v1 as pd
10
10
import shapely as shapely
11
+ import xarray as xr
11
12
12
13
from ...components .base import Tidy3dBaseModel , cached_property
13
14
from ...components .data .data_array import (
18
19
ScalarModeFieldDataArray ,
19
20
TimeDataArray ,
20
21
)
21
- from ...components .data .monitor_data import FieldData , FieldTimeData , ModeSolverData
22
+ from ...components .data .monitor_data import FieldData , FieldTimeData , ModeData , ModeSolverData
22
23
from ...components .geometry .base import Box , Geometry
23
24
from ...components .types import Ax , Axis , Coordinate2D , Direction
24
25
from ...components .validators import assert_line , assert_plane
33
34
plot_params_voltage_plus ,
34
35
)
35
36
36
- MonitorDataTypes = Union [FieldData , FieldTimeData , ModeSolverData ]
37
+ MonitorDataTypes = Union [FieldData , FieldTimeData , ModeData , ModeSolverData ]
37
38
EMScalarFieldType = Union [ScalarFieldDataArray , ScalarFieldTimeDataArray , ScalarModeFieldDataArray ]
38
39
IntegralResultTypes = Union [FreqDataArray , FreqModeDataArray , TimeDataArray ]
39
40
40
41
41
- def _check_em_field_supported (em_field : Any ):
42
- """Function for validating correct data arrays."""
43
- if not isinstance (em_field , (FieldData , FieldTimeData , ModeSolverData )):
44
- raise DataError (
45
- "'em_field' type not supported. Supported types are "
46
- "'FieldData', 'FieldTimeData', 'ModeSolverData'."
47
- )
48
-
49
-
50
42
class AbstractAxesRH (Tidy3dBaseModel , ABC ):
51
43
"""Represents an axis-aligned right-handed coordinate system with one axis preferred.
52
44
Typically `main_axis` would refer to the normal axis of a plane.
@@ -137,18 +129,7 @@ def compute_integral(self, scalar_field: EMScalarFieldType) -> IntegralResultTyp
137
129
coords_interp , method = method , kwargs = {"fill_value" : "extrapolate" }
138
130
)
139
131
result = scalar_field .integrate (coord = coord )
140
- if isinstance (scalar_field , ScalarFieldDataArray ):
141
- return FreqDataArray (data = result .data , coords = result .coords )
142
- elif isinstance (scalar_field , ScalarFieldTimeDataArray ):
143
- return TimeDataArray (data = result .data , coords = result .coords )
144
- else :
145
- if not isinstance (scalar_field , ScalarModeFieldDataArray ):
146
- raise TypeError (
147
- f"Unsupported 'scalar_field' type: { type (scalar_field )} . "
148
- "Expected one of 'ScalarFieldDataArray', 'ScalarFieldTimeDataArray', "
149
- "'ScalarModeFieldDataArray'."
150
- )
151
- return FreqModeDataArray (data = result .data , coords = result .coords )
132
+ return self ._make_result_data_array (result )
152
133
153
134
def _get_field_along_path (self , scalar_field : EMScalarFieldType ) -> EMScalarFieldType :
154
135
"""Returns a selection of the input ``scalar_field`` ready for integration."""
@@ -206,6 +187,26 @@ def _vertices_2D(self, axis: Axis) -> tuple[Coordinate2D, Coordinate2D]:
206
187
v = [min [1 ], max [1 ]]
207
188
return (u , v )
208
189
190
+ @staticmethod
191
+ def _check_monitor_data_supported (em_field : MonitorDataTypes ):
192
+ """Helper for validating that monitor data is supported."""
193
+ if not isinstance (em_field , (FieldData , FieldTimeData , ModeData , ModeSolverData )):
194
+ supported_types = list (MonitorDataTypes .__args__ )
195
+ raise DataError (
196
+ f"'em_field' type { type (em_field )} not supported. Supported types are "
197
+ f"{ supported_types } "
198
+ )
199
+
200
+ @staticmethod
201
+ def _make_result_data_array (result : xr .DataArray ) -> IntegralResultTypes :
202
+ """Helper for creating the proper result type."""
203
+ if "t" in result .coords :
204
+ return TimeDataArray (data = result .data , coords = result .coords )
205
+ elif "f" in result .coords and "mode_index" in result .coords :
206
+ return FreqModeDataArray (data = result .data , coords = result .coords )
207
+ else :
208
+ return FreqDataArray (data = result .data , coords = result .coords )
209
+
209
210
210
211
class VoltageIntegralAxisAligned (AxisAlignedPathIntegral ):
211
212
"""Class for computing the voltage between two points defined by an axis-aligned line."""
@@ -218,12 +219,11 @@ class VoltageIntegralAxisAligned(AxisAlignedPathIntegral):
218
219
219
220
def compute_voltage (self , em_field : MonitorDataTypes ) -> IntegralResultTypes :
220
221
"""Compute voltage along path defined by a line."""
221
- _check_em_field_supported (em_field = em_field )
222
+ self . _check_monitor_data_supported (em_field = em_field )
222
223
e_component = "xyz" [self .main_axis ]
223
224
field_name = f"E{ e_component } "
224
- # Validate that the field is present
225
- if field_name not in em_field .field_components :
226
- raise DataError (f"'field_name' '{ field_name } ' not found." )
225
+ # Validate that fields are present
226
+ em_field ._check_fields_stored ([field_name ])
227
227
e_field = em_field .field_components [field_name ]
228
228
229
229
voltage = self .compute_integral (e_field )
@@ -376,18 +376,15 @@ class CurrentIntegralAxisAligned(AbstractAxesRH, Box):
376
376
377
377
def compute_current (self , em_field : MonitorDataTypes ) -> IntegralResultTypes :
378
378
"""Compute current flowing in loop defined by the outer edge of a rectangle."""
379
- _check_em_field_supported (em_field = em_field )
379
+ AxisAlignedPathIntegral . _check_monitor_data_supported (em_field = em_field )
380
380
ax1 = self .remaining_axes [0 ]
381
381
ax2 = self .remaining_axes [1 ]
382
382
h_component = "xyz" [ax1 ]
383
383
v_component = "xyz" [ax2 ]
384
384
h_field_name = f"H{ h_component } "
385
385
v_field_name = f"H{ v_component } "
386
386
# Validate that fields are present
387
- if h_field_name not in em_field .field_components :
388
- raise DataError (f"'field_name' '{ h_field_name } ' not found." )
389
- if v_field_name not in em_field .field_components :
390
- raise DataError (f"'field_name' '{ v_field_name } ' not found." )
387
+ em_field ._check_fields_stored ([h_field_name , v_field_name ])
391
388
h_horizontal = em_field .field_components [h_field_name ]
392
389
h_vertical = em_field .field_components [v_field_name ]
393
390
0 commit comments