Skip to content

Commit

Permalink
Switch from model.predict(x) to model(x).numpy() for TF performance
Browse files Browse the repository at this point in the history
  • Loading branch information
Dobiasd committed Jan 19, 2025
1 parent d4cebe0 commit ab97032
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
2 changes: 1 addition & 1 deletion FAQ.md
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ To check if the input values really are the same, you can print them, in Python
input = ...
print(input)
print(input.shape)
result = model.predict([input])
result = model([input]).numpy()
print(result)
print(result.shape) # result[0].shape in case of multiple output tensors
```
Expand Down
9 changes: 6 additions & 3 deletions keras_export/convert_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def get_model_input_layers(model):
def measure_predict(model, data_in):
"""Returns output and duration in seconds"""
start_time = datetime.datetime.now()
data_out = model.predict(data_in)
data_out = model(data_in).numpy()
end_time = datetime.datetime.now()
duration = end_time - start_time
print('Forward pass took {} s.'.format(duration.total_seconds()))
Expand Down Expand Up @@ -558,11 +558,14 @@ def get_all_weights(model, prefix):
layers = model.layers
assert K.image_data_format() == 'channels_last'
for layer in layers:
layer_type = type(layer).__name__
for node in layer._inbound_nodes:
if "training" in node.arguments.kwargs:
assert node.arguments.kwargs["training"] is not True, \
is_layer_with_accidental_training_flag = layer_type in ("CenterCrop", "Resizing")
has_training = node.arguments.kwargs["training"] is True
assert not has_training or is_layer_with_accidental_training_flag, \
"training=true is not supported, see https://github.com/Dobiasd/frugally-deep/issues/284"
layer_type = type(layer).__name__

name = prefix + layer.name
assert is_ascii(name)
if name in result:
Expand Down

0 comments on commit ab97032

Please sign in to comment.