diff --git a/predictor/prediction.py b/predictor/prediction.py index fd098ef..4bedf8d 100644 --- a/predictor/prediction.py +++ b/predictor/prediction.py @@ -12,11 +12,10 @@ try: import tflite_runtime.interpreter as tflite except ImportError: - print( - "TFlite_runtime is not installed , Predictions with .tflite extension won't work" - ) + print("TFlite_runtime is not installed.") try: - from tensorflow import keras + from tensorflow import keras, lite + except ImportError: print("Tensorflow is not installed , Predictions with .h5 or .tf won't work") @@ -66,7 +65,11 @@ def run_prediction( start = time.time() print(f"Using : {checkpoint_path}") if checkpoint_path.endswith(".tflite"): - interpreter = tflite.Interpreter(model_path=checkpoint_path) + try: + interpreter = tflite.Interpreter(model_path=checkpoint_path) + except Exception as ex: + interpreter = lite.Interpreter(model_path=checkpoint_path) + interpreter.resize_tensor_input( interpreter.get_input_details()[0]["index"], (BATCH_SIZE, 256, 256, 3) )