Skip to content

Commit 7a4e315

Browse files
fix: DEV-1693: Fix model extraction for ML_BACKEND_V2 (#86)
* fix: DEV-1693: Fix model extraction for ML_BACKEND_V2 * Change default value for LABEL_STUDIO_ML_BACKEND_V2
1 parent c5e317d commit 7a4e315

File tree

2 files changed

+32
-14
lines changed

2 files changed

+32
-14
lines changed

label_studio_ml/api.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from flask import Flask, request, jsonify
55
from rq.exceptions import NoSuchJobError
66

7-
from .model import LabelStudioMLManager
7+
from .model import LabelStudioMLManager, LABEL_STUDIO_ML_BACKEND_V2_DEFAULT
88
from .exceptions import exception_handler
99

1010
logger = logging.getLogger(__name__)
@@ -95,7 +95,7 @@ def health():
9595
return jsonify({
9696
'status': 'UP',
9797
'model_dir': _manager.model_dir,
98-
'v2': os.getenv('LABEL_STUDIO_ML_BACKEND_V2', default=True)
98+
'v2': os.getenv('LABEL_STUDIO_ML_BACKEND_V2', default=LABEL_STUDIO_ML_BACKEND_V2_DEFAULT)
9999
})
100100

101101

label_studio_ml/model.py

+30-12
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535

3636
logger = logging.getLogger(__name__)
3737

38+
LABEL_STUDIO_ML_BACKEND_V2_DEFAULT = False
39+
3840
@attr.s
3941
class ModelWrapper(object):
4042
model = attr.ib()
@@ -444,24 +446,40 @@ def _key(cls, project):
444446

445447
@classmethod
446448
def has_active_model(cls, project):
447-
return cls._key(project) in cls._current_model
449+
if not os.getenv('LABEL_STUDIO_ML_BACKEND_V2', default=LABEL_STUDIO_ML_BACKEND_V2_DEFAULT):
450+
# TODO: Deprecated branch since LS 1.5
451+
return cls._key(project) in cls._current_model
452+
else:
453+
return cls._current_model is not None
448454

449455
@classmethod
450456
def get(cls, project):
451-
key = cls._key(project)
452-
logger.debug('Get project ' + str(key))
453-
return cls._current_model.get(key)
457+
if not os.getenv('LABEL_STUDIO_ML_BACKEND_V2', default=LABEL_STUDIO_ML_BACKEND_V2_DEFAULT):
458+
# TODO: Deprecated branch since LS 1.5
459+
key = cls._key(project)
460+
logger.debug('Get project ' + str(key))
461+
return cls._current_model.get(key)
462+
else:
463+
return cls._current_model
454464

455465
@classmethod
456466
def create(cls, project=None, label_config=None, train_output=None, version=None, **kwargs):
457467
key = cls._key(project)
458468
logger.debug('Create project ' + str(key))
459469
kwargs.update(cls.init_kwargs)
460-
cls._current_model[key] = ModelWrapper(
461-
model=cls.model_class(label_config=label_config, train_output=train_output, **kwargs),
462-
model_version=version or cls._generate_version()
463-
)
464-
return cls._current_model[key]
470+
if not os.getenv('LABEL_STUDIO_ML_BACKEND_V2', default=LABEL_STUDIO_ML_BACKEND_V2_DEFAULT):
471+
# TODO: Deprecated branch since LS 1.5
472+
cls._current_model[key] = ModelWrapper(
473+
model=cls.model_class(label_config=label_config, train_output=train_output, **kwargs),
474+
model_version=version or cls._generate_version()
475+
)
476+
return cls._current_model[key]
477+
else:
478+
cls._current_model = ModelWrapper(
479+
model=cls.model_class(label_config=label_config, train_output=train_output, **kwargs),
480+
model_version=version or cls._generate_version()
481+
)
482+
return cls._current_model
465483

466484
@classmethod
467485
def get_or_create(
@@ -476,8 +494,8 @@ def get_or_create(
476494

477495
@classmethod
478496
def fetch(cls, project=None, label_config=None, force_reload=False, **kwargs):
479-
if not os.getenv('LABEL_STUDIO_ML_BACKEND_V2', default=True):
480-
# TODO: Deprecated branch
497+
if not os.getenv('LABEL_STUDIO_ML_BACKEND_V2', default=LABEL_STUDIO_ML_BACKEND_V2_DEFAULT):
498+
# TODO: Deprecated branch since LS 1.5
481499
if cls.without_redis():
482500
logger.debug('Fetch ' + project + ' from local directory')
483501
job_result = cls._get_latest_job_result_from_workdir(project) or {}
@@ -554,7 +572,7 @@ def is_training(cls, project):
554572
def predict(
555573
cls, tasks, project=None, label_config=None, force_reload=False, try_fetch=True, **kwargs
556574
):
557-
if not os.getenv('LABEL_STUDIO_ML_BACKEND_V2', default=True):
575+
if not os.getenv('LABEL_STUDIO_ML_BACKEND_V2', default=LABEL_STUDIO_ML_BACKEND_V2_DEFAULT):
558576
if try_fetch:
559577
m = cls.fetch(project, label_config, force_reload)
560578
else:

0 commit comments

Comments
 (0)