-
Notifications
You must be signed in to change notification settings - Fork 449
/
Copy pathtest_trainer.py
54 lines (45 loc) · 2.08 KB
/
test_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
import shutil
import unittest

from fastNLP.core.dataset import DataSet
from fastNLP.core.field import TextField
from fastNLP.core.instance import Instance
from fastNLP.core.loss import Loss
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.models.sequence_modeling import SeqLabeling
class TestTrainer(unittest.TestCase):
    """Smoke test that runs ``SeqLabelTrainer`` on a tiny in-memory dataset."""

    @staticmethod
    def _remove_pickle_dir():
        """Delete the ``save`` directory (best effort) and report it.

        ``ignore_errors=True`` mirrors the ``-f`` of the former
        ``rm -rf save``: a missing directory is not an error.
        """
        shutil.rmtree("save", ignore_errors=True)
        print("pickle path deleted")

    def test_case_1(self):
        """Train a ``SeqLabeling`` model for 3 epochs on six toy sequences.

        The test makes no assertions: it passes when construction and
        ``trainer.train`` complete without raising.
        """
        # Registered before anything else so the directory is removed even
        # when a later step raises; previously cleanup only ran on success
        # and depended on a POSIX shell being available.
        self.addCleanup(self._remove_pickle_dir)

        # The same dict configures the trainer (as keyword arguments) and
        # the model (passed whole).
        args = {
            "epochs": 3,
            "batch_size": 2,
            "validate": True,
            "use_cuda": False,
            "pickle_path": "./save/",
            "save_best_dev": True,
            "model_name": "default_model_name.pkl",
            "loss": Loss(None),
            "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
            "vocab_size": 10,
            "word_emb_dim": 100,
            "rnn_hidden_units": 100,
            "num_classes": 5,
        }
        trainer = SeqLabelTrainer(**args)

        # Each example is a pair: [input tokens, label tokens].
        train_data = [
            [['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']],
            [['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
        ]
        # 10 input symbols and 5 label symbols, matching "vocab_size" and
        # "num_classes" in args.
        vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4,
                 '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
        label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}

        data_set = DataSet()
        for text, label in train_data:
            x = TextField(text, False)
            y = TextField(label, is_target=True)
            data_set.append(Instance(word_seq=x, label_seq=y))

        data_set.index_field("word_seq", vocab)
        data_set.index_field("label_seq", label_vocab)

        model = SeqLabeling(args)

        # The training set doubles as the dev set.
        # If this can run, everything is OK.
        trainer.train(network=model, train_data=data_set, dev_data=data_set)