Skip to content

Commit b591c51

Browse files
committed
fix pipeline
1 parent 36d4dac commit b591c51

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

.flake8

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,5 @@ exclude =
88
build,
99
dist,
1010
tests
11-
max-line-length = 88
11+
max-line-length = 88 # black default
1212
ignore = D202,W503,E203 # conflicts with black

Makefile

+2-4
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ NO_OF_REPORT_FILES := $(words $(filter-out reports/.gitkeep, $(SRC_FILES)))
1515

1616
clean: ## clean artifacts
1717
@echo ">>> cleaning files"
18-
rm ./data/predictions/* ./data/transformed/* ./models/*.joblib
18+
rm ./data/predictions/* ./data/transformed/* ./models/*.joblib || true
1919

2020
generate-dataset: ## run ETL pipeline
2121
@echo ">>> generating dataset"
@@ -29,9 +29,7 @@ serve: ## Serve model with a REST API using dploy-kickstart
2929
@echo ">>> serving the trained model"
3030
kickstart serve -e ml_skeleton_py/model/predict.py -l .
3131

32-
run-pipeline: ## Run entire pipeline (clean artifacts -> generate-dataset -> train model -> serve model)
33-
@echo ">>> running the pipeline"
34-
clean clean generate-dataset train serve ## clean artifacts generate dataset & train the model & serve
32+
run-pipeline: clean generate-dataset train serve ## clean artifacts generate dataset & train the model & serve
3533

3634
lint: ## flake8 linting and black code style
3735
@echo ">>> black files"

scripts/train.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import argparse
2+
import os
23

34
from ml_skeleton_py import model
5+
from ml_skeleton_py import settings as s
46

57

68
def train() -> None:
@@ -19,7 +21,8 @@ def train() -> None:
1921
help="the serialized model name default lr " "referring to logistic regression",
2022
)
2123
args = parser.parse_args()
22-
model.train(args.dataset, args.model_name)
24+
transformed_data_dir = os.path.join(s.DATA_TRANSFORMED, args.dataset)
25+
model.train(transformed_data_dir, s.MODEL_DIR, args.model_name)
2326

2427

2528
if __name__ == "__main__":

0 commit comments

Comments
 (0)