Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to use dispersion fitter without rich.progress #2258

Draft
wants to merge 1 commit into
base: pre/2.8
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions tests/test_plugins/test_dispersion_fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
import responses
import tidy3d as td
from tidy3d.components.dispersion_fitter import fit
from tidy3d.exceptions import SetupError, ValidationError
from tidy3d.plugins.dispersion import (
AdvancedFastFitterParam,
Expand Down Expand Up @@ -285,3 +286,18 @@ def test_dispersion_loss_samples():
ep = nAlGaN_mat.eps_model(freq_list)
for e in ep:
assert e.imag >= 0


@responses.activate
def test_fit_no_progress(random_data):
    """Smoke test for the pole-residue fitter with the progress bar disabled.

    Builds a complex permittivity from the ``random_data`` (n, k) fixture and
    runs the low-level ``fit`` entry point with ``show_progress=False`` to
    verify the code path that avoids importing ``rich.progress``.
    """
    wvl_um, n_data, k_data = random_data
    # eps = (n + i k)^2 converts refractive-index data to complex permittivity
    eps_complex = (n_data + 1j * k_data) ** 2
    omega = 2 * np.pi * td.C_0 / wvl_um

    # BUGFIX: 'advanced_param' was referenced but never defined (NameError at
    # test runtime). Construct default advanced fitter parameters explicitly.
    advanced_param = AdvancedFastFitterParam()

    medium, rms = fit(
        omega_data=omega,
        resp_data=eps_complex,
        scale_factor=td.HBAR,
        advanced_param=advanced_param,
        show_progress=False,
    )

    # Basic sanity on the returned weighted RMS error.
    assert rms >= 0
82 changes: 81 additions & 1 deletion tidy3d/components/dispersion_fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import numpy as np
import scipy
from pydantic.v1 import Field, NonNegativeFloat, PositiveFloat, PositiveInt, validator
from rich.progress import Progress

from ..constants import fp_eps
from ..exceptions import ValidationError
Expand Down Expand Up @@ -759,6 +758,7 @@ def fit(
tolerance_rms: NonNegativeFloat = DEFAULT_TOLERANCE_RMS,
advanced_param: AdvancedFastFitterParam = None,
scale_factor: PositiveFloat = 1,
show_progress: bool = True,
) -> Tuple[Tuple[float, ArrayComplex1D, ArrayComplex1D], float]:
"""Fit data using a fast fitting algorithm.

Expand Down Expand Up @@ -815,6 +815,8 @@ def fit(
Advanced parameters for fitting.
scale_factor : PositiveFloat, optional
Factor to rescale frequency by before fitting.
show_progress : bool = True
Whether to show a progress bar for the fitting.

Returns
-------
Expand All @@ -823,6 +825,8 @@ def fit(
The dispersive medium parameters have the form (resp_inf, poles, residues)
and are in the original unscaled units.
"""
if show_progress:
from rich.progress import Progress

if max_num_poles < min_num_poles:
raise ValidationError(
Expand Down Expand Up @@ -862,6 +866,82 @@ def make_configs():

configs = make_configs()

if not show_progress:
for num_poles, relaxed, smooth, logspacing, optimize_eps_inf in configs:
model = init_model.updated_copy(
num_poles=num_poles,
relaxed=relaxed,
smooth=smooth,
logspacing=logspacing,
optimize_eps_inf=optimize_eps_inf,
)
model = _fit_fixed_parameters((min_num_poles, max_num_poles), model)

if model.rms_error < best_model.rms_error:
log.debug(
f"Fitter: possible improved fit with "
f"rms_error={model.rms_error:.3g} found using "
f"relaxed={model.relaxed}, "
f"smooth={model.smooth}, "
f"logspacing={model.logspacing}, "
f"optimize_eps_inf={model.optimize_eps_inf}, "
f"loss_in_bounds={model.loss_in_bounds}, "
f"passivity_optimized={model.passivity_optimized}, "
f"sellmeier_passivity={model.sellmeier_passivity}."
)
if model.loss_in_bounds and model.sellmeier_passivity:
best_model = model
else:
if not warned_about_passivity_num_iters and model.passivity_num_iters_too_small:
warned_about_passivity_num_iters = True
log.warning(
"Did not finish enforcing passivity in dispersion fitter. "
"If the fit is not good enough, consider increasing "
"'AdvancedFastFitterParam.passivity_num_iters'."
)
if (
not warned_about_slsqp_constraint_scale
and model.slsqp_constraint_scale_too_small
):
warned_about_slsqp_constraint_scale = True
log.warning(
"SLSQP constraint scale may be too small. "
"If the fit is not good enough, consider increasing "
"'AdvancedFastFitterParam.slsqp_constraint_scale'."
)

# if below tolerance, return
if best_model.rms_error < tolerance_rms:
log.info(
"Found optimal fit with weighted RMS error %.3g",
best_model.rms_error,
)
if best_model.show_unweighted_rms:
log.info(
"Unweighted RMS error %.3g",
best_model.unweighted_rms_error,
)
return (
best_model.pole_residue,
best_model.rms_error,
)

# if exited loop, did not reach tolerance (warn)
log.warning(
"Unable to fit with weighted RMS error under 'tolerance_rms' of %.3g", tolerance_rms
)
log.info("Returning best fit with weighted RMS error %.3g", best_model.rms_error)
if best_model.show_unweighted_rms:
log.info(
"Unweighted RMS error %.3g",
best_model.unweighted_rms_error,
)

return (
best_model.pole_residue,
best_model.rms_error,
)

with Progress(console=get_logging_console()) as progress:
task = progress.add_task(
f"Fitting to weighted RMS of {tolerance_rms}...",
Expand Down