Skip to content

Commit 3a028e8

Browse files
fix: DEV-2523: Support webhook data loading in NER ml backend example
1 parent da74915 commit 3a028e8

File tree

4 files changed

+30
-6
lines changed

4 files changed

+30
-6
lines changed

label_studio_ml/examples/ner/ner.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525
from transformers import AdamW, get_linear_schedule_with_warmup
2626

2727
from label_studio_ml.model import LabelStudioMLBase
28+
from label_studio_ml.utils import get_annotated_dataset
2829
from utils import calc_slope
2930

30-
3131
logger = logging.getLogger(__name__)
3232

3333

@@ -342,7 +342,7 @@ def __init__(self, **kwargs):
342342
self.to_name = self.info['to_name'][0]
343343
self.value = self.info['inputs'][0]['value']
344344

345-
if not self.train_output:
345+
if not self.train_output or (not self.train_output.get('model_path')):
346346
self.labels = self.info['labels']
347347
else:
348348
self.load(self.train_output)
@@ -464,6 +464,13 @@ def fit(
464464
warmup_steps=0, save_steps=50, dump_dataset=True, cache_dir='~/.heartex/cache', train_logs=None,
465465
**kwargs
466466
):
467+
# check whether training was triggered by a webhook
468+
if kwargs.get('data'):
469+
project_id = kwargs['data']['project']['id']
470+
completions = get_annotated_dataset(project_id)
471+
# assert that there are annotations
472+
assert len(completions) > 0
473+
467474
train_logs = train_logs or os.path.join(workdir, 'train_logs')
468475
os.makedirs(train_logs, exist_ok=True)
469476
logger.debug('Prepare models')

label_studio_ml/examples/simple_text_classifier/simple_text_classifier.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __init__(self, **kwargs):
4040
self.to_name = self.info['to_name'][0]
4141
self.value = self.info['inputs'][0]['value']
4242

43-
if not self.train_output:
43+
if (not self.train_output) or (self.train_output and not self.train_output.get('model_file')):
4444
# If there are no trainings yet, cold-start the simple TF-IDF text classifier
4545
self.reset_model()
4646
# This is an array of <Choice> labels

label_studio_ml/model.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,14 @@
2929
from rq.job import Job
3030
from colorama import Fore
3131

32-
from label_studio_tools.core.utils.params import get_bool_env
32+
from label_studio_tools.core.utils.params import get_bool_env, get_env
3333
from label_studio_tools.core.label_config import parse_config
3434
from label_studio_tools.core.utils.io import get_local_path
3535

3636
logger = logging.getLogger(__name__)
3737

3838
LABEL_STUDIO_ML_BACKEND_V2_DEFAULT = False
39+
LABEL_STUDIO_STRICT_ERRORS = get_env("LS_STRICT_ERRORS", False)
3940

4041
@attr.s
4142
class ModelWrapper(object):
@@ -189,12 +190,12 @@ def _get_result_from_job_id(self, job_id):
189190
if not os.path.exists(job_dir):
190191
logger.warning(f"=> Warning: {job_id} dir doesn't exist. "
191192
f"It seems that you don't have specified model dir.")
192-
return None
193+
return None if LABEL_STUDIO_STRICT_ERRORS else {}
193194
result_file = os.path.join(job_dir, self.JOB_RESULT)
194195
if not os.path.exists(result_file):
195196
logger.warning(f"=> Warning: {job_id} dir doesn't contain result file. "
196197
f"It seems that previous training session ended with error.")
197-
return None
198+
return None if LABEL_STUDIO_STRICT_ERRORS else {}
198199
logger.debug(f'Read result from {result_file}')
199200
with open(result_file) as f:
200201
result = json.load(f)

label_studio_ml/utils.py

+16
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import json
12
import logging
3+
import requests
24

35
from PIL import Image
46

@@ -48,3 +50,17 @@ def get_image_local_path(url, image_cache_dir=None, project_dir=None, image_dir=
4850

4951
def get_image_size(filepath):
5052
return Image.open(filepath).size
53+
54+
55+
def get_annotated_dataset(project_id, hostname=None, api_key=None):
56+
"""Just for demo purposes: retrieve annotated data from Label Studio API"""
57+
if hostname is None:
58+
hostname = get_env('HOSTNAME')
59+
if api_key is None:
60+
api_key = get_env('API_KEY')
61+
download_url = f'{hostname.rstrip("/")}/api/projects/{project_id}/export'
62+
response = requests.get(download_url, headers={'Authorization': f'Token {api_key}'})
63+
if response.status_code != 200:
64+
raise Exception(f"Can't load task data using {download_url}, "
65+
f"response status_code = {response.status_code}")
66+
return json.loads(response.content)

0 commit comments

Comments
 (0)