Skip to content

Commit d45de88

Browse files
committed
make predict for single observation available in package
1 parent 81d2a78 commit d45de88

File tree

4 files changed

+13
-14
lines changed

4 files changed

+13
-14
lines changed

Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ test-package:
4242
coverage run -m pytest ./tests/
4343
coverage report -m --include=./tests/*
4444

45-
test: generate-dataset train prediction test-package ## run extensive tests
45+
test: generate-dataset train prediction clean test-package ## run extensive tests
4646

4747
help: ## show help on available commands
4848
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'

ml_skeleton_py/model/__init__.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22

33
from .train import train
44

5-
from .predict import predict_from_file
5+
from .predict import predict_from_file, predict
66

7-
__all__ = [
8-
"train",
9-
"predict_from_file",
10-
]
7+
__all__ = ["train", "predict", "predict_from_file"]

ml_skeleton_py/model/predict.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def load_model(model_name: str) -> BaseEstimator:
2828
Uses lru for caching.
2929
3030
Parameters:
31-
model_name (str): model name e.g. "lr.p"
31+
model_name (str): model name (including extension) e.g. "lr.p"
3232
3333
Returns:
3434
model (Pipeline or BaseEstimator): a model that can make predictions
@@ -38,14 +38,16 @@ def load_model(model_name: str) -> BaseEstimator:
3838
return model
3939

4040

41-
def predict(model_name: str, observation: np.array) -> float:
41+
def predict(observation: np.array, model_name: str = "lr.p") -> float:
4242
"""
4343
Predict one single observation.
4444
4545
Parameters:
46-
model_name (str): the name of the model file that you want to use
46+
observation (np.array): the input observation
4747
48-
observation (np.array): the variables of the input observation
48+
model_name (str): the name of the model file that you want to load
49+
(including extension)
50+
default value = "lr.p"
4951
5052
Return:
5153
prediction (float): the prediction
@@ -82,7 +84,7 @@ def predict_from_file(model: str, input_df: str, output_df: str) -> np.array:
8284
logger.info(f"running predictions for input: {input_data}")
8385

8486
# make predictions
85-
preds = [predict(model, [x]) for x in np.array(input_data)]
87+
preds = [predict([x], model) for x in np.array(input_data)]
8688
# transform single axis array to a column
8789
preds = np.array(preds).reshape(-1, 1)
8890

tests/test_predict.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def test_predict_1() -> None:
4747
-1.25077894,
4848
-0.30741284,
4949
]
50-
prediction = predict(model_name, [observation])
50+
prediction = predict([observation], model_name)
5151
assert type(float(prediction)) == float
5252

5353

@@ -88,7 +88,7 @@ def test_predict_2() -> None:
8888
-0.67713242,
8989
-0.20959966,
9090
]
91-
prediction = predict(model_name, [observation])
91+
prediction = predict([observation], model_name)
9292
assert type(float(prediction)) == float
9393

9494

@@ -129,7 +129,7 @@ def test_predict_3() -> None:
129129
-1.08784478,
130130
3.03081115,
131131
]
132-
prediction = predict(model_name, [observation])
132+
prediction = predict([observation], model_name)
133133
assert type(float(prediction)) == float
134134

135135

0 commit comments

Comments
 (0)