Skip to content

Commit 5faee94

Browse files
committed
fix network call
1 parent 2643929 commit 5faee94

File tree

1 file changed

+6
-10
lines changed

1 file changed

+6
-10
lines changed

bioimageio/core/model_adapters/_tensorflow_model_adapter.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -200,19 +200,15 @@ def _forward_keras( # pyright: ignore[reportUnknownParameterType]
200200
None if ipt is None else tf.convert_to_tensor(ipt) for ipt in input_tensors
201201
]
202202

203-
result = self._network.call( # pyright: ignore[reportUnknownVariableType]
204-
*tf_tensor
205-
)
203+
result = self._network(*tf_tensor) # pyright: ignore[reportUnknownVariableType]
204+
205+
assert isinstance(result, dict)
206206

207-
if not isinstance(result, (tuple, list)):
208-
result = [result] # pyright: ignore[reportUnknownVariableType]
207+
# TODO: Use RDF's `outputs[i].id` here
208+
result = list(result.values())
209209

210210
return [ # pyright: ignore[reportUnknownVariableType]
211-
(
212-
None
213-
if r is None
214-
else r if isinstance(r, np.ndarray) else tf.make_ndarray(r)
215-
)
211+
(None if r is None else r if isinstance(r, np.ndarray) else r.numpy())
216212
for r in result # pyright: ignore[reportUnknownVariableType]
217213
]
218214

0 commit comments

Comments
 (0)