diff --git a/backend/core/views.py b/backend/core/views.py index a2163ac9..d920cd40 100644 --- a/backend/core/views.py +++ b/backend/core/views.py @@ -60,8 +60,7 @@ ModelSerializer, PredictionParamSerializer, ) -# from .tasks import train_model -from celery import Celery +from .tasks import train_model from .utils import get_dir_size, gpx_generator, process_rawdata, request_rawdata @@ -129,10 +128,8 @@ def create(self, validated_data): # create the model instance instance = Training.objects.create(**validated_data) - celery = Celery() - # run your function here - task = celery.train_model.delay( + task = train_model.delay( dataset_id=instance.model.dataset.id, training_id=instance.id, epochs=instance.epochs, @@ -474,9 +471,7 @@ def post(self, request, *args, **kwargs): batch_size=batch_size, source_imagery=training_instance.source_imagery, ) - celery = Celery() - - task = celery.train_model.delay( + task = train_model.delay( dataset_id=instance.model.dataset.id, training_id=instance.id, epochs=instance.epochs,