Skip to content

Commit 87a33b4

Browse files
Update examples to use webhook training
1 parent 3a028e8 commit 87a33b4

File tree

9 files changed

+38
-6
lines changed

9 files changed

+38
-6
lines changed

label_studio_ml/examples/bert/bert_classifier.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
1212

1313
from label_studio_ml.model import LabelStudioMLBase
1414

15-
from utils import prepare_texts, calc_slope
16-
15+
from utils import prepare_texts, calc_slope, get_annotated_dataset
1716

1817
if torch.cuda.is_available():
1918
device = torch.device("cuda")
@@ -128,6 +127,10 @@ def predict(self, tasks, **kwargs):
128127
return predictions
129128

130129
def fit(self, completions, workdir=None, cache_dir=None, **kwargs):
130+
# check if training is from web hook and load tasks from api
131+
if kwargs.get('data'):
132+
project_id = kwargs['data']['project']['id']
133+
completions = get_annotated_dataset(project_id)
131134
input_texts = []
132135
output_labels, output_labels_idx = [], []
133136
label2idx = {l: i for i, l in enumerate(self.labels)}

label_studio_ml/examples/flair/ner_ml_backend.py

+7
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
import os
1010

1111
#writing class with inheritance
12+
from label_studio_ml.utils import get_annotated_dataset
13+
14+
1215
class SequenceTaggerModel(LabelStudioMLBase):
1316
def __init__(self, **kwargs):
1417
#initialize base class
@@ -87,6 +90,10 @@ def convert_to_ls_annotation(self, flair_sentences):
8790
return results
8891

8992
def fit(self, completions, workdir=None, **kwargs):
93+
# check if training is from web hook
94+
if kwargs.get('data'):
95+
project_id = kwargs['data']['project']['id']
96+
completions = get_annotated_dataset(project_id)
9097
#completions contain ALL the annotated samples.
9198
#train a model from scratch here.
9299
flair_sents = []

label_studio_ml/examples/mmdetection/mmdetection.py

+3
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,9 @@ def predict(self, tasks, **kwargs):
128128
'score': avg_score
129129
}]
130130

131+
def fit(self, completions, workdir=None, **kwargs):
132+
return {}
133+
131134

132135
def json_load(file, int_keys=False):
133136
with io.open(file, encoding='utf8') as f:

label_studio_ml/examples/ner/ner.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ def fit(
464464
warmup_steps=0, save_steps=50, dump_dataset=True, cache_dir='~/.heartex/cache', train_logs=None,
465465
**kwargs
466466
):
467-
# check if training is from web hook
467+
# check if training is from web hook and load tasks from api
468468
if kwargs.get('data'):
469469
project_id = kwargs['data']['project']['id']
470470
completions = get_annotated_dataset(project_id)

label_studio_ml/examples/pytorch_transfer_learning/pytorch_transfer_learning.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from torchvision import models, transforms
1515

1616
from label_studio_ml.model import LabelStudioMLBase
17-
from label_studio_ml.utils import get_single_tag_keys, get_choice, is_skipped, get_local_path
17+
from label_studio_ml.utils import get_single_tag_keys, get_choice, is_skipped, get_local_path, get_annotated_dataset
1818

1919
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
2020

@@ -177,6 +177,10 @@ def predict(self, tasks, **kwargs):
177177
return predictions
178178

179179
def fit(self, completions, workdir=None, batch_size=32, num_epochs=10, **kwargs):
180+
# check if training is from web hook and load tasks from api
181+
if kwargs.get('data'):
182+
project_id = kwargs['data']['project']['id']
183+
completions = get_annotated_dataset(project_id)
180184
image_urls, image_classes = [], []
181185
print('Collecting annotations...')
182186
for completion in completions:

label_studio_ml/examples/simple_text_classifier/simple_text_classifier.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def _get_annotated_dataset(self, project_id):
102102
return json.loads(response.content)
103103

104104
def fit(self, annotations, workdir=None, **kwargs):
105-
# check if training is from web hook
105+
# check if training is from web hook and load tasks from api
106106
if kwargs.get('data'):
107107
project_id = kwargs['data']['project']['id']
108108
tasks = self._get_annotated_dataset(project_id)

label_studio_ml/examples/substring_matching/substring_matching.py

+4
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,7 @@ def _extract_meta(task):
9191
meta['start'] = task['value']['start']
9292
meta['end'] = task['value']['end']
9393
return meta
94+
95+
def fit(self, completions, workdir=None, **kwargs):
96+
# save some training outputs to the job result
97+
return {'random': random.randint(1, 10)}

label_studio_ml/examples/tensorflow/mobilenet_finetune.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66

77
from PIL import Image
88
from label_studio_ml.model import LabelStudioMLBase
9-
from label_studio_ml.utils import get_image_local_path, get_single_tag_keys, get_choice, is_skipped
9+
from label_studio_ml.utils import get_image_local_path, get_single_tag_keys, get_choice, is_skipped, \
10+
get_annotated_dataset
1011

1112
logger = logging.getLogger(__name__)
1213
feature_extractor_model = 'https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4'
@@ -62,6 +63,10 @@ def predict(self, tasks, **kwargs):
6263
}]
6364

6465
def fit(self, completions, workdir=None, **kwargs):
66+
# check if training is from web hook and load tasks from api
67+
if kwargs.get('data'):
68+
project_id = kwargs['data']['project']['id']
69+
completions = get_annotated_dataset(project_id)
6570

6671
annotations = []
6772
for completion in completions:

label_studio_ml/examples/tesseract/tesseract.py

+6
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import random
2+
13
from PIL import Image
24
import pytesseract as pt
35
from label_studio_ml.model import LabelStudioMLBase
@@ -74,3 +76,7 @@ def _extract_meta(task):
7476
meta["original_width"] = task['original_width']
7577
meta["original_height"] = task['original_height']
7678
return meta
79+
80+
def fit(self, completions, workdir=None, **kwargs):
81+
# save some training outputs to the job result
82+
return {'random': random.randint(1, 10)}

0 commit comments

Comments
 (0)