|
1 | 1 | {
|
2 | 2 | "cells": [
|
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "id": "ef6184e4", |
| 6 | + "metadata": {}, |
| 7 | + "source": [ |
| 8 | + "# Token Classification Benchmark " |
| 9 | + ] |
| 10 | + }, |
3 | 11 | {
|
4 | 12 | "cell_type": "markdown",
|
5 | 13 | "id": "fc2bb2f0",
|
|
18 | 26 | "name": "stderr",
|
19 | 27 | "output_type": "stream",
|
20 | 28 | "text": [
|
21 |
| - "2022-10-08 00:52:36.013979: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n", |
22 |
| - "2022-10-08 00:52:36.014003: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n" |
| 29 | + "2022-10-09 00:39:31.824063: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n" |
23 | 30 | ]
|
24 | 31 | }
|
25 | 32 | ],
|
26 | 33 | "source": [
|
27 | 34 | "import numpy as np\n",
|
28 |
| - "import string\n", |
29 | 35 | "import os \n",
|
30 | 36 | "from itertools import repeat \n",
|
31 | 37 | "from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline\n",
|
32 | 38 | "from cleanlab.rank import get_label_quality_scores as main_get_label_quality_scores\n",
|
33 | 39 | "from cleanlab.filter import find_label_issues as main_find_label_issues \n",
|
| 40 | + "from utils import readfile, get_probs, get_pred_probs \n", |
34 | 41 | "\n",
|
35 | 42 | "from cleanlab.internal.token_classification_utils import get_sentence, filter_sentence, mapping, merge_probs\n",
|
36 | 43 | "import matplotlib.pyplot as plt \n",
|
37 | 44 | "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
|
38 |
| - "from token_classification_utils import get_pred_probs\n", |
39 |
| - "import sklearn.metrics as metrics " |
| 45 | + "from sklearn import metrics " |
40 | 46 | ]
|
41 | 47 | },
|
42 | 48 | {
|
|
50 | 56 | {
|
51 | 57 | "cell_type": "code",
|
52 | 58 | "execution_count": 2,
|
53 |
| - "id": "2443bf37", |
54 |
| - "metadata": {}, |
55 |
| - "outputs": [], |
56 |
| - "source": [ |
57 |
| - "def readfile(filepath, sep=' '): \n", |
58 |
| - " \"\"\" \n", |
59 |
| - " Reads file in CoNLL format (IOB2) \n", |
60 |
| - " \"\"\"\n", |
61 |
| - " lines = open(filepath)\n", |
62 |
| - " \n", |
63 |
| - " data, sentence, label = [], [], []\n", |
64 |
| - " for line in lines:\n", |
65 |
| - " if len(line) == 0 or line.startswith('-DOCSTART') or line[0] == '\\n':\n", |
66 |
| - " if len(sentence) > 0:\n", |
67 |
| - " data.append((sentence, label))\n", |
68 |
| - " sentence, label = [], []\n", |
69 |
| - " continue\n", |
70 |
| - " splits = line.split(sep) \n", |
71 |
| - " word = splits[0]\n", |
72 |
| - " if len(word) > 0 and word[0].isalpha() and word.isupper():\n", |
73 |
| - " word = word[0] + word[1:].lower()\n", |
74 |
| - " sentence.append(word)\n", |
75 |
| - " label.append(entity_map[splits[-1][:-1]])\n", |
76 |
| - "\n", |
77 |
| - " if len(sentence) > 0:\n", |
78 |
| - " data.append((sentence, label))\n", |
79 |
| - " \n", |
80 |
| - " given_words = [d[0] for d in data] \n", |
81 |
| - " given_labels = [d[1] for d in data] \n", |
82 |
| - " \n", |
83 |
| - " return given_words, given_labels " |
84 |
| - ] |
85 |
| - }, |
86 |
| - { |
87 |
| - "cell_type": "code", |
88 |
| - "execution_count": 3, |
89 | 59 | "id": "68106a8d",
|
90 | 60 | "metadata": {},
|
91 | 61 | "outputs": [],
|
92 | 62 | "source": [
|
93 | 63 | "filepath = 'data/conll.txt'\n",
|
94 |
| - "entities = ['O', 'B-MISC', 'I-MISC', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC']\n", |
95 |
| - "entity_map = {entity: i for i, entity in enumerate(entities)} \n", |
96 |
| - "\n", |
97 | 64 | "given_words, given_labels_unmerged = readfile(filepath) \n",
|
98 | 65 | "sentences = list(map(get_sentence, given_words)) \n",
|
99 | 66 | "\n",
|
|
104 | 71 | },
|
105 | 72 | {
|
106 | 73 | "cell_type": "code",
|
107 |
| - "execution_count": 4, |
| 74 | + "execution_count": 3, |
108 | 75 | "id": "90ca23e4",
|
109 | 76 | "metadata": {},
|
110 | 77 | "outputs": [],
|
|
138 | 105 | },
|
139 | 106 | {
|
140 | 107 | "cell_type": "code",
|
141 |
| - "execution_count": 5, |
| 108 | + "execution_count": 4, |
142 | 109 | "id": "0e0100aa",
|
143 | 110 | "metadata": {},
|
144 | 111 | "outputs": [],
|
|
161 | 128 | },
|
162 | 129 | {
|
163 | 130 | "cell_type": "code",
|
164 |
| - "execution_count": 6, |
| 131 | + "execution_count": 5, |
165 | 132 | "id": "b9ba86c2",
|
166 | 133 | "metadata": {},
|
167 | 134 | "outputs": [],
|
|
280 | 247 | },
|
281 | 248 | {
|
282 | 249 | "cell_type": "code",
|
283 |
| - "execution_count": 7, |
284 |
| - "id": "edc20cee", |
285 |
| - "metadata": {}, |
286 |
| - "outputs": [], |
287 |
| - "source": [ |
288 |
| - "def get_probs(sentence): \n", |
289 |
| - " ''' \n", |
290 |
| - " @parameter sentence: string \n", |
291 |
| - " \n", |
292 |
| - " @return probs: np.array of shape (n, m) \n", |
293 |
| - " where n is the number of tokens in the sentence and m is the number of classes. \n", |
294 |
| - " probs[i][j] is the probability that the i'th sentence belongs to entity j. The \n", |
295 |
| - " first and last probs are excluded because the first and last tokens are always \n", |
296 |
| - " [CLS] and [SEP], to represent the start and end of the sentence, respectively. \n", |
297 |
| - " '''\n", |
298 |
| - " def softmax(logit): \n", |
299 |
| - " return np.exp(logit) / np.sum(np.exp(logit)) \n", |
300 |
| - " \n", |
301 |
| - " forward = pipe.forward(pipe.preprocess(sentence)) \n", |
302 |
| - " logits = forward['logits'][0].numpy() \n", |
303 |
| - " probs = np.array([softmax(logit) for logit in logits]) \n", |
304 |
| - " probs = probs[1:-1] \n", |
305 |
| - " return probs " |
306 |
| - ] |
307 |
| - }, |
308 |
| - { |
309 |
| - "cell_type": "code", |
310 |
| - "execution_count": 8, |
| 250 | + "execution_count": 6, |
311 | 251 | "id": "c09d4762",
|
312 | 252 | "metadata": {},
|
313 | 253 | "outputs": [],
|
|
326 | 266 | },
|
327 | 267 | {
|
328 | 268 | "cell_type": "code",
|
329 |
| - "execution_count": 9, |
| 269 | + "execution_count": 7, |
330 | 270 | "id": "5740ff1b",
|
331 | 271 | "metadata": {},
|
332 | 272 | "outputs": [],
|
|
371 | 311 | },
|
372 | 312 | {
|
373 | 313 | "cell_type": "code",
|
374 |
| - "execution_count": 10, |
| 314 | + "execution_count": 8, |
375 | 315 | "id": "7328bae5",
|
376 | 316 | "metadata": {},
|
377 | 317 | "outputs": [],
|
|
406 | 346 | },
|
407 | 347 | {
|
408 | 348 | "cell_type": "code",
|
409 |
| - "execution_count": 11, |
| 349 | + "execution_count": 9, |
410 | 350 | "id": "a991528f",
|
411 | 351 | "metadata": {},
|
412 | 352 | "outputs": [],
|
|
473 | 413 | },
|
474 | 414 | {
|
475 | 415 | "cell_type": "code",
|
476 |
| - "execution_count": 12, |
| 416 | + "execution_count": 10, |
477 | 417 | "id": "7be9c3b9",
|
478 | 418 | "metadata": {},
|
479 | 419 | "outputs": [],
|
|
485 | 425 | "\n",
|
486 | 426 | "sentence_tokens = [[tokenizer.decode(token) for token in tokenizer(sentence)['input_ids']] for sentence in sentences] \n",
|
487 | 427 | "sentence_tokens = [[token.replace('#', '') for token in sentence_token][1:-1] for sentence_token in sentence_tokens] \n",
|
488 |
| - "sentence_probs = list(map(get_probs, sentences)) \n", |
| 428 | + "sentence_probs = list(map(get_probs, repeat(pipe), sentences)) \n", |
489 | 429 | "\n",
|
490 | 430 | "model_maps = given_maps \n",
|
491 | 431 | "sentence_probs = list(map(merge_probs, sentence_probs, repeat(model_maps)))\n",
|
|
502 | 442 | },
|
503 | 443 | {
|
504 | 444 | "cell_type": "code",
|
505 |
| - "execution_count": 13, |
| 445 | + "execution_count": 11, |
506 | 446 | "id": "7a7d1982",
|
507 | 447 | "metadata": {},
|
508 | 448 | "outputs": [
|
|
525 | 465 | },
|
526 | 466 | {
|
527 | 467 | "cell_type": "code",
|
528 |
| - "execution_count": 14, |
| 468 | + "execution_count": 12, |
529 | 469 | "id": "296eae74",
|
530 | 470 | "metadata": {},
|
531 | 471 | "outputs": [
|
|
548 | 488 | },
|
549 | 489 | {
|
550 | 490 | "cell_type": "code",
|
551 |
| - "execution_count": 15, |
| 491 | + "execution_count": 13, |
552 | 492 | "id": "f76c0131",
|
553 | 493 | "metadata": {},
|
554 | 494 | "outputs": [
|
|
571 | 511 | },
|
572 | 512 | {
|
573 | 513 | "cell_type": "code",
|
574 |
| - "execution_count": 16, |
| 514 | + "execution_count": 14, |
575 | 515 | "id": "a52b93ba",
|
576 | 516 | "metadata": {},
|
577 | 517 | "outputs": [
|
|
594 | 534 | },
|
595 | 535 | {
|
596 | 536 | "cell_type": "code",
|
597 |
| - "execution_count": 17, |
| 537 | + "execution_count": 15, |
598 | 538 | "id": "faab2641",
|
599 | 539 | "metadata": {},
|
600 | 540 | "outputs": [
|
|
617 | 557 | },
|
618 | 558 | {
|
619 | 559 | "cell_type": "code",
|
620 |
| - "execution_count": 18, |
| 560 | + "execution_count": 16, |
621 | 561 | "id": "ad742241",
|
622 | 562 | "metadata": {},
|
623 | 563 | "outputs": [
|
|
648 | 588 | },
|
649 | 589 | {
|
650 | 590 | "cell_type": "code",
|
651 |
| - "execution_count": 19, |
| 591 | + "execution_count": 17, |
652 | 592 | "id": "2a873a0d",
|
653 | 593 | "metadata": {},
|
654 | 594 | "outputs": [],
|
|
660 | 600 | "\n",
|
661 | 601 | "sentence_tokens = [[tokenizer.decode(token) for token in tokenizer(sentence)['input_ids']] for sentence in sentences] \n",
|
662 | 602 | "sentence_tokens = [[token.replace('#', '') for token in sentence_token][1:-1] for sentence_token in sentence_tokens] \n",
|
663 |
| - "sentence_probs = list(map(get_probs, sentences)) \n", |
| 603 | + "sentence_probs = list(map(get_probs, repeat(pipe), sentences)) \n", |
664 | 604 | "\n",
|
665 | 605 | "model_maps = [4, 1, 3, 4, 1, 3, 2, 0] \n",
|
666 | 606 | "sentence_probs = list(map(merge_probs, sentence_probs, repeat(model_maps)))\n",
|
|
677 | 617 | },
|
678 | 618 | {
|
679 | 619 | "cell_type": "code",
|
680 |
| - "execution_count": 20, |
| 620 | + "execution_count": 18, |
681 | 621 | "id": "568de001",
|
682 | 622 | "metadata": {},
|
683 | 623 | "outputs": [
|
|
700 | 640 | },
|
701 | 641 | {
|
702 | 642 | "cell_type": "code",
|
703 |
| - "execution_count": 21, |
| 643 | + "execution_count": 19, |
704 | 644 | "id": "a9a43fed",
|
705 | 645 | "metadata": {},
|
706 | 646 | "outputs": [
|
|
723 | 663 | },
|
724 | 664 | {
|
725 | 665 | "cell_type": "code",
|
726 |
| - "execution_count": 22, |
| 666 | + "execution_count": 20, |
727 | 667 | "id": "268ffb85",
|
728 | 668 | "metadata": {},
|
729 | 669 | "outputs": [
|
|
746 | 686 | },
|
747 | 687 | {
|
748 | 688 | "cell_type": "code",
|
749 |
| - "execution_count": 23, |
| 689 | + "execution_count": 21, |
750 | 690 | "id": "1265eb42",
|
751 | 691 | "metadata": {},
|
752 | 692 | "outputs": [
|
|
769 | 709 | },
|
770 | 710 | {
|
771 | 711 | "cell_type": "code",
|
772 |
| - "execution_count": 24, |
| 712 | + "execution_count": 22, |
773 | 713 | "id": "d89485fa",
|
774 | 714 | "metadata": {},
|
775 | 715 | "outputs": [
|
|
792 | 732 | },
|
793 | 733 | {
|
794 | 734 | "cell_type": "code",
|
795 |
| - "execution_count": 25, |
| 735 | + "execution_count": 23, |
796 | 736 | "id": "d4e24ff3",
|
797 | 737 | "metadata": {},
|
798 | 738 | "outputs": [
|
|
823 | 763 | },
|
824 | 764 | {
|
825 | 765 | "cell_type": "code",
|
826 |
| - "execution_count": 26, |
| 766 | + "execution_count": 24, |
827 | 767 | "id": "ae613cfe",
|
828 | 768 | "metadata": {},
|
829 | 769 | "outputs": [],
|
|
835 | 775 | "\n",
|
836 | 776 | "sentence_tokens = [[tokenizer.decode(token) for token in tokenizer(sentence)['input_ids']] for sentence in sentences] \n",
|
837 | 777 | "sentence_tokens = [[token.replace('#', '') for token in sentence_token][1:-1] for sentence_token in sentence_tokens] \n",
|
838 |
| - "sentence_probs = list(map(get_probs, sentences)) \n", |
| 778 | + "sentence_probs = list(map(get_probs, repeat(pipe), sentences)) \n", |
839 | 779 | "pred_probs = list(map(get_pred_probs, sentence_probs, sentence_tokens, given_words)) \n",
|
840 | 780 | "\n",
|
841 | 781 | "statistics = {(method, cleanlab_method): evaluate(method, cleanlab_method, pred_probs, error_unmerged, unmerged=True) \n",
|
|
849 | 789 | },
|
850 | 790 | {
|
851 | 791 | "cell_type": "code",
|
852 |
| - "execution_count": 27, |
| 792 | + "execution_count": 25, |
853 | 793 | "id": "60066cc8",
|
854 | 794 | "metadata": {},
|
855 | 795 | "outputs": [
|
|
872 | 812 | },
|
873 | 813 | {
|
874 | 814 | "cell_type": "code",
|
875 |
| - "execution_count": 28, |
| 815 | + "execution_count": 26, |
876 | 816 | "id": "c258b3e8",
|
877 | 817 | "metadata": {},
|
878 | 818 | "outputs": [
|
|
895 | 835 | },
|
896 | 836 | {
|
897 | 837 | "cell_type": "code",
|
898 |
| - "execution_count": 29, |
| 838 | + "execution_count": 27, |
899 | 839 | "id": "19dcc23c",
|
900 | 840 | "metadata": {},
|
901 | 841 | "outputs": [
|
|
918 | 858 | },
|
919 | 859 | {
|
920 | 860 | "cell_type": "code",
|
921 |
| - "execution_count": 30, |
| 861 | + "execution_count": 28, |
922 | 862 | "id": "d576c6bf",
|
923 | 863 | "metadata": {},
|
924 | 864 | "outputs": [
|
|
941 | 881 | },
|
942 | 882 | {
|
943 | 883 | "cell_type": "code",
|
944 |
| - "execution_count": 31, |
| 884 | + "execution_count": 29, |
945 | 885 | "id": "ec4ba4de",
|
946 | 886 | "metadata": {},
|
947 | 887 | "outputs": [
|
|
964 | 904 | },
|
965 | 905 | {
|
966 | 906 | "cell_type": "code",
|
967 |
| - "execution_count": 32, |
| 907 | + "execution_count": 30, |
968 | 908 | "id": "a052f38b",
|
969 | 909 | "metadata": {},
|
970 | 910 | "outputs": [
|
|
0 commit comments