Commit 7fafa12

Add token-classification benchmark and token-level benchmark
1 parent 0b10362 commit 7fafa12

File tree

3 files changed: +462 -102 lines changed


token-classification-benchmark.ipynb

+42 -102
@@ -1,5 +1,13 @@
 {
  "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "ef6184e4",
+   "metadata": {},
+   "source": [
+    "# Token Classification Benchmark "
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "fc2bb2f0",
@@ -18,25 +26,23 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "2022-10-08 00:52:36.013979: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n",
-      "2022-10-08 00:52:36.014003: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n"
+      "2022-10-09 00:39:31.824063: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n"
      ]
     }
    ],
    "source": [
     "import numpy as np\n",
-     "import string\n",
     "import os \n",
     "from itertools import repeat \n",
     "from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline\n",
     "from cleanlab.rank import get_label_quality_scores as main_get_label_quality_scores\n",
     "from cleanlab.filter import find_label_issues as main_find_label_issues \n",
+     "from utils import readfile, get_probs, get_pred_probs \n",
     "\n",
     "from cleanlab.internal.token_classification_utils import get_sentence, filter_sentence, mapping, merge_probs\n",
     "import matplotlib.pyplot as plt \n",
     "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
-     "from token_classification_utils import get_pred_probs\n",
-     "import sklearn.metrics as metrics "
+     "from sklearn import metrics "
    ]
   },
   {
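Two notes on the import hunk above. The sklearn change is behavior-neutral: import sklearn.metrics as metrics and from sklearn import metrics bind the same module object to the name metrics. More substantively, readfile, get_probs, and get_pred_probs now come from utils, replacing helper cells that the hunks below delete from the notebook. utils.py itself is not shown in this commit, so the following is only a sketch of the relocated reader, assuming utils.py also carries over the entity list and entity_map that disappear from the notebook below:

# utils.py (hypothetical sketch; the actual file is not part of this diff)
entities = ['O', 'B-MISC', 'I-MISC', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC']
entity_map = {entity: i for i, entity in enumerate(entities)}

def readfile(filepath, sep=' '):
    """Read a CoNLL-format (IOB2) file into per-sentence word and label lists."""
    data, sentence, label = [], [], []
    with open(filepath) as lines:  # context manager added; the deleted cell leaked the handle
        for line in lines:
            # Blank lines and -DOCSTART records delimit sentences.
            if len(line) == 0 or line.startswith('-DOCSTART') or line[0] == '\n':
                if len(sentence) > 0:
                    data.append((sentence, label))
                    sentence, label = [], []
                continue
            splits = line.split(sep)
            word = splits[0]
            # Soften all-caps words, e.g. "JAPAN" -> "Japan".
            if len(word) > 0 and word[0].isalpha() and word.isupper():
                word = word[0] + word[1:].lower()
            sentence.append(word)
            label.append(entity_map[splits[-1][:-1]])  # [:-1] strips the trailing newline
    if len(sentence) > 0:
        data.append((sentence, label))
    given_words = [d[0] for d in data]
    given_labels = [d[1] for d in data]
    return given_words, given_labels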
@@ -50,50 +56,11 @@
   {
    "cell_type": "code",
    "execution_count": 2,
-   "id": "2443bf37",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def readfile(filepath, sep=' '): \n",
-    "    \"\"\" \n",
-    "    Reads file in CoNLL format (IOB2) \n",
-    "    \"\"\"\n",
-    "    lines = open(filepath)\n",
-    "    \n",
-    "    data, sentence, label = [], [], []\n",
-    "    for line in lines:\n",
-    "        if len(line) == 0 or line.startswith('-DOCSTART') or line[0] == '\\n':\n",
-    "            if len(sentence) > 0:\n",
-    "                data.append((sentence, label))\n",
-    "                sentence, label = [], []\n",
-    "            continue\n",
-    "        splits = line.split(sep) \n",
-    "        word = splits[0]\n",
-    "        if len(word) > 0 and word[0].isalpha() and word.isupper():\n",
-    "            word = word[0] + word[1:].lower()\n",
-    "        sentence.append(word)\n",
-    "        label.append(entity_map[splits[-1][:-1]])\n",
-    "\n",
-    "    if len(sentence) > 0:\n",
-    "        data.append((sentence, label))\n",
-    "    \n",
-    "    given_words = [d[0] for d in data] \n",
-    "    given_labels = [d[1] for d in data] \n",
-    "    \n",
-    "    return given_words, given_labels "
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
    "id": "68106a8d",
    "metadata": {},
    "outputs": [],
    "source": [
     "filepath = 'data/conll.txt'\n",
-     "entities = ['O', 'B-MISC', 'I-MISC', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC']\n",
-     "entity_map = {entity: i for i, entity in enumerate(entities)} \n",
-     "\n",
     "given_words, given_labels_unmerged = readfile(filepath) \n",
     "sentences = list(map(get_sentence, given_words)) \n",
     "\n",
@@ -104,7 +71,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 3,
    "id": "90ca23e4",
    "metadata": {},
    "outputs": [],
@@ -138,7 +105,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 4,
    "id": "0e0100aa",
    "metadata": {},
    "outputs": [],
@@ -161,7 +128,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 5,
    "id": "b9ba86c2",
    "metadata": {},
    "outputs": [],
@@ -280,34 +247,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
-   "id": "edc20cee",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def get_probs(sentence): \n",
-    "    ''' \n",
-    "    @parameter sentence: string \n",
-    "    \n",
-    "    @return probs: np.array of shape (n, m) \n",
-    "        where n is the number of tokens in the sentence and m is the number of classes. \n",
-    "        probs[i][j] is the probability that the i'th token belongs to entity j. The \n",
-    "        first and last probs are excluded because the first and last tokens are always \n",
-    "        [CLS] and [SEP], to represent the start and end of the sentence, respectively. \n",
-    "    '''\n",
-    "    def softmax(logit): \n",
-    "        return np.exp(logit) / np.sum(np.exp(logit)) \n",
-    "    \n",
-    "    forward = pipe.forward(pipe.preprocess(sentence)) \n",
-    "    logits = forward['logits'][0].numpy() \n",
-    "    probs = np.array([softmax(logit) for logit in logits]) \n",
-    "    probs = probs[1:-1] \n",
-    "    return probs "
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 6,
    "id": "c09d4762",
    "metadata": {},
    "outputs": [],
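The hunk above deletes the notebook's get_probs the same way: it now lives in utils, and the call sites later in the diff pass the pipeline explicitly via repeat(pipe) instead of reading a global. A sketch of the relocated helper under that assumption (again, utils.py is not part of this diff):

import numpy as np

def get_probs(pipe, sentence):
    """Return an (n, m) array of per-token class probabilities for one sentence.

    The first and last rows of the pipeline output are dropped because they
    correspond to the [CLS] and [SEP] special tokens, not to sentence tokens.
    """
    def softmax(logit):
        return np.exp(logit) / np.sum(np.exp(logit))

    forward = pipe.forward(pipe.preprocess(sentence))
    logits = forward['logits'][0].numpy()
    probs = np.array([softmax(logit) for logit in logits])
    return probs[1:-1]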
@@ -326,7 +266,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 7,
    "id": "5740ff1b",
    "metadata": {},
    "outputs": [],
@@ -371,7 +311,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 8,
    "id": "7328bae5",
    "metadata": {},
    "outputs": [],
@@ -406,7 +346,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 9,
    "id": "a991528f",
    "metadata": {},
    "outputs": [],
@@ -473,7 +413,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 10,
    "id": "7be9c3b9",
    "metadata": {},
    "outputs": [],
@@ -485,7 +425,7 @@
     "\n",
     "sentence_tokens = [[tokenizer.decode(token) for token in tokenizer(sentence)['input_ids']] for sentence in sentences] \n",
     "sentence_tokens = [[token.replace('#', '') for token in sentence_token][1:-1] for sentence_token in sentence_tokens] \n",
-    "sentence_probs = list(map(get_probs, sentences)) \n",
+    "sentence_probs = list(map(get_probs, repeat(pipe), sentences)) \n",
     "\n",
     "model_maps = given_maps \n",
     "sentence_probs = list(map(merge_probs, sentence_probs, repeat(model_maps)))\n",
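The repeat(pipe) change in the hunk above follows from the new signature: map consumes its iterables in lockstep and stops at the shortest one, so an unbounded itertools.repeat supplies the same pipeline object for every sentence. The two forms below are equivalent:

from itertools import repeat

sentence_probs = list(map(get_probs, repeat(pipe), sentences))
sentence_probs = [get_probs(pipe, sentence) for sentence in sentences]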
@@ -502,7 +442,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 11,
    "id": "7a7d1982",
    "metadata": {},
    "outputs": [
@@ -525,7 +465,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": 12,
    "id": "296eae74",
    "metadata": {},
    "outputs": [
@@ -548,7 +488,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 13,
    "id": "f76c0131",
    "metadata": {},
    "outputs": [
@@ -571,7 +511,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": 14,
    "id": "a52b93ba",
    "metadata": {},
    "outputs": [
@@ -594,7 +534,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 17,
+   "execution_count": 15,
    "id": "faab2641",
    "metadata": {},
    "outputs": [
@@ -617,7 +557,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 18,
+   "execution_count": 16,
    "id": "ad742241",
    "metadata": {},
    "outputs": [
@@ -648,7 +588,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": 17,
    "id": "2a873a0d",
    "metadata": {},
    "outputs": [],
@@ -660,7 +600,7 @@
     "\n",
     "sentence_tokens = [[tokenizer.decode(token) for token in tokenizer(sentence)['input_ids']] for sentence in sentences] \n",
     "sentence_tokens = [[token.replace('#', '') for token in sentence_token][1:-1] for sentence_token in sentence_tokens] \n",
-    "sentence_probs = list(map(get_probs, sentences)) \n",
+    "sentence_probs = list(map(get_probs, repeat(pipe), sentences)) \n",
     "\n",
     "model_maps = [4, 1, 3, 4, 1, 3, 2, 0] \n",
     "sentence_probs = list(map(merge_probs, sentence_probs, repeat(model_maps)))\n",
@@ -677,7 +617,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 20,
+   "execution_count": 18,
    "id": "568de001",
    "metadata": {},
    "outputs": [
@@ -700,7 +640,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": 19,
    "id": "a9a43fed",
    "metadata": {},
    "outputs": [
@@ -723,7 +663,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": 20,
    "id": "268ffb85",
    "metadata": {},
    "outputs": [
@@ -746,7 +686,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 23,
+   "execution_count": 21,
    "id": "1265eb42",
    "metadata": {},
    "outputs": [
@@ -769,7 +709,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 24,
+   "execution_count": 22,
    "id": "d89485fa",
    "metadata": {},
    "outputs": [
@@ -792,7 +732,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 25,
+   "execution_count": 23,
    "id": "d4e24ff3",
    "metadata": {},
    "outputs": [
@@ -823,7 +763,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 26,
+   "execution_count": 24,
    "id": "ae613cfe",
    "metadata": {},
    "outputs": [],
@@ -835,7 +775,7 @@
     "\n",
     "sentence_tokens = [[tokenizer.decode(token) for token in tokenizer(sentence)['input_ids']] for sentence in sentences] \n",
     "sentence_tokens = [[token.replace('#', '') for token in sentence_token][1:-1] for sentence_token in sentence_tokens] \n",
-    "sentence_probs = list(map(get_probs, sentences)) \n",
+    "sentence_probs = list(map(get_probs, repeat(pipe), sentences)) \n",
     "pred_probs = list(map(get_pred_probs, sentence_probs, sentence_tokens, given_words)) \n",
     "\n",
     "statistics = {(method, cleanlab_method): evaluate(method, cleanlab_method, pred_probs, error_unmerged, unmerged=True) \n",
@@ -849,7 +789,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 27,
+   "execution_count": 25,
    "id": "60066cc8",
    "metadata": {},
    "outputs": [
@@ -872,7 +812,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 28,
+   "execution_count": 26,
    "id": "c258b3e8",
    "metadata": {},
    "outputs": [
@@ -895,7 +835,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 29,
+   "execution_count": 27,
    "id": "19dcc23c",
    "metadata": {},
    "outputs": [
@@ -918,7 +858,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 30,
+   "execution_count": 28,
    "id": "d576c6bf",
    "metadata": {},
    "outputs": [
@@ -941,7 +881,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 31,
+   "execution_count": 29,
    "id": "ec4ba4de",
    "metadata": {},
    "outputs": [
@@ -964,7 +904,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 32,
+   "execution_count": 30,
    "id": "a052f38b",
    "metadata": {},
    "outputs": [
