Skip to content

Commit

Permalink
fix ruff tests since last check
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerflex committed May 15, 2024
1 parent 2cee466 commit be8e8a1
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 33 deletions.
14 changes: 1 addition & 13 deletions tests/test_plugins/test_component_modeler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,6 @@
)
from tidy3d.exceptions import SetupError, Tidy3dKeyError
from ..utils import run_emulated
from ..test_web.test_webapi import (
mock_upload,
mock_metadata,
mock_get_info,
mock_start,
mock_monitor,
mock_download,
mock_load,
mock_job_status,
mock_load,
set_api_key,
)

# Waveguide height
wg_height = 0.22
Expand Down Expand Up @@ -402,7 +390,7 @@ def test_import_smatrix_smatrix():

def test_to_from_file_batch(monkeypatch, tmp_path):
modeler = make_component_modeler(path_dir=str(tmp_path))
s_matrix = run_component_modeler(monkeypatch, modeler)
_ = run_component_modeler(monkeypatch, modeler)

batch = td.web.Batch(simulations=dict())

Expand Down
34 changes: 15 additions & 19 deletions tests/test_plugins/test_invdes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,20 @@
import tidy3d as td
import tidy3d.plugins.adjoint as tda
import tidy3d.plugins.invdes as tdi
import matplotlib.pyplot as plt

# use single threading pipeline
from . import test_adjoint as ta

ta.NUM_PROC_PARALLEL = 1

from .test_adjoint import use_emulated_run, use_emulated_run_async
from ..utils import run_emulated, log_capture, assert_log_level, AssertLogLevel
from ..utils import run_emulated, assert_log_level, AssertLogLevel

FREQ0 = 1e14
L_SIM = 1.0
MNT_NAME1 = "mnt_name1"
MNT_NAME2 = "mnt_name2"
HISTORY_FNAME = "tests/data/invdes_history.pkl"

ta.NUM_PROC_PARALLEL = 1

mnt1 = td.FieldMonitor(
center=(L_SIM / 3.0, 0, 0), size=(0, td.inf, td.inf), freqs=[FREQ0], name=MNT_NAME1
)
Expand Down Expand Up @@ -72,10 +70,10 @@ def test_region_params():

design_region = make_design_region()

PARAMS_0 = np.random.random(design_region.params_shape)
PARAMS_0 = design_region.params_random
PARAMS_0 = design_region.params_ones
PARAMS_0 = design_region.params_zeros
_ = np.random.random(design_region.params_shape)
_ = design_region.params_random
_ = design_region.params_ones
_ = design_region.params_zeros


def test_region_penalties():
Expand Down Expand Up @@ -287,7 +285,7 @@ def test_invdes_multi_same_length():
output_monitor_names = [([MNT_NAME1, MNT_NAME2], None)[i % 2] for i in range(n)]
invdes = invdes.updated_copy(output_monitor_names=output_monitor_names)

ds = invdes.designs
_ = invdes.designs


def make_optimizer():
Expand Down Expand Up @@ -316,7 +314,7 @@ def test_default_params(use_emulated_run):

optimizer = make_optimizer()

PARAMS_0 = np.random.random(optimizer.design.design_region.params_shape)
_ = np.random.random(optimizer.design.design_region.params_shape)

optimizer.run()

Expand Down Expand Up @@ -366,7 +364,7 @@ def test_result_store_full_results_is_false(use_emulated_run):
assert len(result.history[key]) == optimizer.num_steps

# this should still work, even if ``store_full_results == False``
val_last1 = result.last["params"]
_ = result.last["params"]


def test_continue_run_fns(use_emulated_run):
Expand Down Expand Up @@ -411,21 +409,21 @@ def test_result(use_emulated_run, use_emulated_run_async, tmp_path):
assert np.allclose(val_last1, val_last2)

result.plot_optimization()
sim_data_last = result.sim_data_last(task_name="last")
_ = result.sim_data_last(task_name="last")


def test_result_data(use_emulated_run):
"""Test methods of the ``InverseDesignResult`` object."""

result = make_result(use_emulated_run)
sim_last = result.sim_last
sim_data_last = result.sim_data_last(task_name="last")
_ = result.sim_last
_ = result.sim_data_last(task_name="last")


def test_result_data_multi(use_emulated_run_async, tmp_path):
result_multi = make_result_multi(use_emulated_run_async)
sim_last = result_multi.sim_last
sim_data_last = result_multi.sim_data_last(task_name="last")
_ = result_multi.sim_last
_ = result_multi.sim_data_last(task_name="last")


def test_result_empty():
Expand Down Expand Up @@ -496,8 +494,6 @@ def test_jax_array_impl_import_pass(tmp_path, log_capture):
def test_fn_source_error(monkeypatch, exception, ok):
"""Make sure type errors are caught when grabbing function source code."""

import inspect

def getsource_error(*args, **kwargs):
raise exception

Expand Down

0 comments on commit be8e8a1

Please sign in to comment.