Skip to content

Commit 00bc981

Browse files
committed
pyright fixes
1 parent 5fefb84 commit 00bc981

File tree

4 files changed

+16
-6
lines changed

4 files changed

+16
-6
lines changed

bioimageio/core/backends/keras_backend.py

+2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
except Exception:
2828
import keras
2929

30+
tf_version = None
31+
3032

3133
class KerasModelAdapter(ModelAdapter):
3234
def __init__(

bioimageio/core/backends/onnx_backend.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# pyright: reportUnknownVariableType=false
12
import warnings
23
from typing import Any, List, Optional, Sequence, Union
34

@@ -25,8 +26,8 @@ def __init__(
2526

2627
local_path = download(model_description.weights.onnx.source).path
2728
self._session = rt.InferenceSession(local_path.read_bytes())
28-
onnx_inputs = self._session.get_inputs() # type: ignore
29-
self._input_names: List[str] = [ipt.name for ipt in onnx_inputs] # type: ignore
29+
onnx_inputs = self._session.get_inputs()
30+
self._input_names: List[str] = [ipt.name for ipt in onnx_inputs]
3031

3132
if devices is not None:
3233
warnings.warn(
@@ -40,11 +41,11 @@ def _forward_impl(
4041
None, dict(zip(self._input_names, input_arrays))
4142
)
4243
if is_list(result) or is_tuple(result):
43-
result_seq = result
44+
result_seq = list(result)
4445
else:
4546
result_seq = [result]
4647

47-
return result_seq # pyright: ignore[reportReturnType]
48+
return result_seq
4849

4950
def unload(self) -> None:
5051
warnings.warn(

bioimageio/core/backends/pytorch_backend.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -146,10 +146,16 @@ def load_torch_state_dict(
146146
state = torch.load(f, map_location=devices[0], weights_only=True)
147147

148148
incompatible = model.load_state_dict(state)
149-
if incompatible is not None and incompatible.missing_keys:
149+
if (
150+
incompatible is not None # pyright: ignore[reportUnnecessaryComparison]
151+
and incompatible.missing_keys
152+
):
150153
logger.warning("Missing state dict keys: {}", incompatible.missing_keys)
151154

152-
if incompatible is not None and incompatible.unexpected_keys:
155+
if (
156+
incompatible is not None # pyright: ignore[reportUnnecessaryComparison]
157+
and incompatible.unexpected_keys
158+
):
153159
logger.warning("Unexpected state dict keys: {}", incompatible.unexpected_keys)
154160

155161
return model

bioimageio/core/backends/torchscript_backend.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# pyright: reportUnknownVariableType=false
12
import gc
23
import warnings
34
from typing import Any, List, Optional, Sequence, Union

0 commit comments

Comments
 (0)