Skip to content

Commit 89dc9c6

Browse files
committed
修改了transformer中embedding导入的形式,之前是错误的
1 parent 0de4960 commit 89dc9c6

10 files changed

+290
-26
lines changed

__pycache__/model.cpython-36.pyc

-23 Bytes
Binary file not shown.

__pycache__/train_eval.cpython-36.pyc

84 Bytes
Binary file not shown.
3.21 KB
Binary file not shown.

__pycache__/utils.cpython-36.pyc

-23 Bytes
Binary file not shown.

load_model.ipynb

+137
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,143 @@
365365
"test_sentence(model, TEXT, sentence)"
366366
]
367367
},
368+
{
369+
"cell_type": "code",
370+
"execution_count": 24,
371+
"metadata": {},
372+
"outputs": [],
373+
"source": [
374+
"import codecs\n",
375+
"path=\"E:/data/word_nlp/cnews_data/bert_embedding\"\n",
376+
"\n",
377+
"lines = codecs.open(path, encoding=\"utf-8\")"
378+
]
379+
},
380+
{
381+
"cell_type": "code",
382+
"execution_count": 25,
383+
"metadata": {},
384+
"outputs": [],
385+
"source": [
386+
"\n",
387+
"vocab={'<unk>': 0, '<pad>': 1, ',': 2, '的': 3, '。': 4, '是':5}\n",
388+
"embedding_vec = [line.replace(\"\\n\", \"\") for line in lines if line.split(\" \")[0] in vocab]"
389+
]
390+
},
391+
{
392+
"cell_type": "code",
393+
"execution_count": 28,
394+
"metadata": {},
395+
"outputs": [
396+
{
397+
"data": {
398+
"text/plain": [
399+
"4"
400+
]
401+
},
402+
"execution_count": 28,
403+
"metadata": {},
404+
"output_type": "execute_result"
405+
}
406+
],
407+
"source": [
408+
"len(embedding_vec)\n",
409+
"#embedding_vec\n",
410+
"#embedding_vec"
411+
]
412+
},
413+
{
414+
"cell_type": "code",
415+
"execution_count": 21,
416+
"metadata": {},
417+
"outputs": [
418+
{
419+
"name": "stdout",
420+
"output_type": "stream",
421+
"text": [
422+
"21131\n"
423+
]
424+
}
425+
],
426+
"source": [
427+
"\n",
428+
"lines = codecs.open(path, encoding=\"utf-8\")\n",
429+
"embedding_vec = [line for line in lines ]\n",
430+
"print(len(embedding_vec))\n",
431+
"#embedding_vec"
432+
]
433+
},
434+
{
435+
"cell_type": "code",
436+
"execution_count": 23,
437+
"metadata": {},
438+
"outputs": [
439+
{
440+
"name": "stdout",
441+
"output_type": "stream",
442+
"text": [
443+
"\n",
444+
"\n",
445+
"\n",
446+
"\n"
447+
]
448+
}
449+
],
450+
"source": [
451+
"for one in embedding_vec:\n",
452+
" if one.split(\" \")[0] in vocab:\n",
453+
" print(one.split(\" \")[0])"
454+
]
455+
},
456+
{
457+
"cell_type": "code",
458+
"execution_count": 6,
459+
"metadata": {},
460+
"outputs": [],
461+
"source": [
462+
"#lines = codecs.open(path, encoding=\"utf-8\")\n",
463+
"#for line in list(lines)[0:10]:\n",
464+
"# print(line[0])"
465+
]
466+
},
467+
{
468+
"cell_type": "code",
469+
"execution_count": 17,
470+
"metadata": {},
471+
"outputs": [
472+
{
473+
"data": {
474+
"text/plain": [
475+
"4"
476+
]
477+
},
478+
"execution_count": 17,
479+
"metadata": {},
480+
"output_type": "execute_result"
481+
}
482+
],
483+
"source": [
484+
"vocab[\"\"]"
485+
]
486+
},
487+
{
488+
"cell_type": "code",
489+
"execution_count": 19,
490+
"metadata": {},
491+
"outputs": [
492+
{
493+
"name": "stdout",
494+
"output_type": "stream",
495+
"text": [
496+
"yes\n"
497+
]
498+
}
499+
],
500+
"source": [
501+
"if \"\" in vocab:\n",
502+
" print(\"yes\")"
503+
]
504+
},
368505
{
369506
"cell_type": "code",
370507
"execution_count": null,

main_zh.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@
1010
from torch.autograd import Variable as V
1111

1212
from model import RNNModel
13+
from transformer import TransformerModel
1314

1415
class Config(object):
1516
def __init__(self):
1617
self.model_name="lm_model"
17-
self.data_ori="/mnt/data3/wuchunsheng/data_all/data_mine/lm_data/"
18+
#self.data_ori="/mnt/data3/wuchunsheng/data_all/data_mine/lm_data/"
19+
self.data_ori="E:/data/word_nlp/cnews_data/"
1820
self.train_path="train_0.csv"
1921
self.valid_path="train_0.csv"
2022
self.test_path="test_100.csv"
@@ -32,12 +34,18 @@ def __init__(self):
3234
self.hidden_size=200
3335
self.nlayers=1
3436
self.dropout=0.5
35-
self.epoch=2
37+
self.epoch=1
3638

3739
self.train_len=0
3840
self.test_len = 0
3941
self.valid_len = 0
4042
self.mode="train"
43+
44+
## transformer的参数
45+
self.dropout=0.5
46+
self.max_len=5000
47+
self.nhead=2
48+
4149
#data_path="E:/study_series/2020_3/re_write_classify/data/"
4250
#data_path="/mnt/data3/wuchunsheng/code/nlper/NLP_task/text_classification/my_classification_cnews/2020_3_30/text_classify/data/"
4351

@@ -47,8 +55,10 @@ def __init__(self):
4755

4856

4957
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
50-
model = RNNModel(config, TEXT).to(device)
58+
#model = RNNModel(config, TEXT).to(device)
59+
60+
model=TransformerModel(config, TEXT).to(device)
5161

52-
train(config,model,train_iter, valid_iter, test_iter)
62+
#train(config,model,train_iter, valid_iter, test_iter)
5363

5464
test(config,model,TEXT, test_iter)## 测试的是一个正批量的

temp.py

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import torch
2+
import torch.nn as nn
3+
from torch.nn import TransformerEncoder, TransformerEncoderLayer

test.py

+28-20
Original file line numberDiff line numberDiff line change
@@ -11,36 +11,41 @@
1111
from train_eval import test_sentence,load_model
1212

1313
from model import RNNModel
14-
14+
from transformer import TransformerModel
1515

1616
class Config(object):
1717
def __init__(self):
18-
self.model_name = "lm_model"
19-
self.data_ori = "/mnt/data3/wuchunsheng/data_all/data_mine/lm_data/"
20-
self.train_path = "train_0.csv"
21-
self.valid_path = "train_0.csv"
22-
self.test_path = "test_100.csv"
23-
self.sen_max_length = 150
24-
# self.embedding_path = "need_bertembedding"
18+
self.model_name="lm_model"
19+
#self.data_ori="/mnt/data3/wuchunsheng/data_all/data_mine/lm_data/"
20+
self.data_ori="E:/data/word_nlp/cnews_data/"
21+
self.train_path="train_0.csv"
22+
self.valid_path="train_0.csv"
23+
self.test_path="test_100.csv"
24+
self.sen_max_length=150
25+
#self.embedding_path = "need_bertembedding"
2526
self.embedding_path = "bert_embedding"
26-
self.embedding_dim = 768
27-
self.vocab_maxsize = 4000
28-
self.vocab_minfreq = 10
29-
self.save_path = "lm_ckpt"
27+
self.embedding_dim=768
28+
self.vocab_maxsize=4000
29+
self.vocab_minfreq=10
30+
self.save_path="lm_ckpt"
3031

3132
self.batch_size = 64
3233
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
3334

34-
self.hidden_size = 200
35-
self.nlayers = 1
36-
self.dropout = 0.5
37-
self.epoch = 2
35+
self.hidden_size=200
36+
self.nlayers=1
37+
self.dropout=0.5
38+
self.epoch=2
3839

39-
self.train_len = 0
40+
self.train_len=0
4041
self.test_len = 0
4142
self.valid_len = 0
43+
self.mode="test"
4244

43-
self.mode = "test"
45+
## transformer的参数
46+
self.dropout=0.5
47+
self.max_len=5000
48+
self.nhead=2
4449

4550

4651
# data_path="E:/study_series/2020_3/re_write_classify/data/"
@@ -49,10 +54,13 @@ def __init__(self):
4954
config = Config()
5055
train_iter, valid_iter, test_iter, TEXT = generate_data(config)
5156

52-
model = RNNModel(config, TEXT).to(config.device)
57+
#model = RNNModel(config, TEXT).to(config.device)
58+
model=TransformerModel(config, TEXT).to(config.device)
59+
5360
model =load_model(config, model)
5461

55-
sen="体育"
62+
sen="comment体育项目"
63+
sen="".join(['c', 'o', 'n', 't', 'e', 'x', 't', ',', 'l', 'a', 'b', 'e', 'l'])
5664
res=test_sentence(config, model ,TEXT, sen)
5765
print(res)
5866

train_eval.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ def train(config,model,train_iter, valid_iter,test_iter):
1717
for batch in train_iter:
1818
# reset the hidden state or else the model will try to backpropagate to the
1919
# beginning of the dataset, requiring lots of time and a lot of memory
20-
model.reset_history()
20+
if(model.mode_type!="transformer"):
21+
model.reset_history()
2122

2223
optimizer.zero_grad()
2324

@@ -42,7 +43,8 @@ def train(config,model,train_iter, valid_iter,test_iter):
4243
model.eval()
4344
# model.train()
4445
for batch in valid_iter:
45-
model.reset_history()
46+
if(model.mode_type!="transformer"):
47+
model.reset_history()
4648
text, targets = batch.text.to(config.device), batch.target.to(config.device)
4749
prediction = model(text).to(config.device)
4850
loss = criterion(prediction.view(-1, config.n_tokens), targets.view(-1)).to(config.device)
@@ -66,6 +68,7 @@ def test(config, model, TEXT, test_iter):
6668
print(len(inputs_word))
6769

6870
arrs = model(b.text.cuda()).cuda().data.cpu().numpy()
71+
print(arrs.shape)
6972
preds = word_ids_to_sentence(np.argmax(arrs, axis=2), TEXT.vocab)
7073

7174
print(preds)
@@ -76,6 +79,7 @@ def test_sentence(config, model, TEXT, sentence):
7679
inputs = inputs.view(-1, 1)
7780
# print(inputs.shape)
7881
arrs = model(inputs)
82+
print("arrs shape: ",arrs.shape)
7983
preds = word_ids_to_sentence(np.argmax(arrs.detach().cpu(), axis=2), TEXT.vocab)
8084

8185
return preds

0 commit comments

Comments
 (0)