Skip to content

Commit 83b8210

Browse files
committed
修改了transformer的log_softmax部分（dim由1改为2）
1 parent 89dc9c6 commit 83b8210

8 files changed

+77
-23
lines changed

__pycache__/model.cpython-36.pyc

23 Bytes
Binary file not shown.

__pycache__/train_eval.cpython-36.pyc

828 Bytes
Binary file not shown.
47 Bytes
Binary file not shown.

__pycache__/utils.cpython-36.pyc

23 Bytes
Binary file not shown.

main_zh.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from torchtext import data
66
import torch.nn as nn
77
import torch.nn.functional as F
8-
from train_eval import train,test
8+
from train_eval import train,test, test_one_sentence
99

1010
from torch.autograd import Variable as V
1111

@@ -15,8 +15,8 @@
1515
class Config(object):
1616
def __init__(self):
1717
self.model_name="lm_model"
18-
#self.data_ori="/mnt/data3/wuchunsheng/data_all/data_mine/lm_data/"
19-
self.data_ori="E:/data/word_nlp/cnews_data/"
18+
self.data_ori="/mnt/data3/wuchunsheng/data_all/data_mine/lm_data/"
19+
#self.data_ori="E:/data/word_nlp/cnews_data/"
2020
self.train_path="train_0.csv"
2121
self.valid_path="train_0.csv"
2222
self.test_path="test_100.csv"
@@ -34,7 +34,7 @@ def __init__(self):
3434
self.hidden_size=200
3535
self.nlayers=1
3636
self.dropout=0.5
37-
self.epoch=1
37+
self.epoch=100
3838

3939
self.train_len=0
4040
self.test_len = 0
@@ -59,6 +59,9 @@ def __init__(self):
5959

6060
model=TransformerModel(config, TEXT).to(device)
6161

62-
#train(config,model,train_iter, valid_iter, test_iter)
62+
train(config,model,train_iter, valid_iter, test_iter)
6363

64-
test(config,model,TEXT, test_iter)## 测试的是一个正批量的
64+
#res=test(config,model,TEXT, test_iter)## 测试的是一个正批量的
65+
#print(res)
66+
res=test_one_sentence(config, model, TEXT, test_iter)
67+
print(res)

test.py

+24-8
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
class Config(object):
1717
def __init__(self):
1818
self.model_name="lm_model"
19-
#self.data_ori="/mnt/data3/wuchunsheng/data_all/data_mine/lm_data/"
20-
self.data_ori="E:/data/word_nlp/cnews_data/"
19+
self.data_ori="/mnt/data3/wuchunsheng/data_all/data_mine/lm_data/"
20+
#self.data_ori="E:/data/word_nlp/cnews_data/"
2121
self.train_path="train_0.csv"
2222
self.valid_path="train_0.csv"
2323
self.test_path="test_100.csv"
@@ -35,7 +35,7 @@ def __init__(self):
3535
self.hidden_size=200
3636
self.nlayers=1
3737
self.dropout=0.5
38-
self.epoch=2
38+
self.epoch=20
3939

4040
self.train_len=0
4141
self.test_len = 0
@@ -59,8 +59,24 @@ def __init__(self):
5959

6060
model =load_model(config, model)
6161

62-
sen="comment体育项目"
63-
sen="".join(['c', 'o', 'n', 't', 'e', 'x', 't', ',', 'l', 'a', 'b', 'e', 'l'])
64-
res=test_sentence(config, model ,TEXT, sen)
65-
print(res)
66-
62+
#sen="目"*50
63+
sen="体育快讯"
64+
#sen="".join(['c', 'o', 'n', 't', 'e', 'x', 't', ',', 'l', 'a', 'b', 'e', 'l'])
65+
#res=test_sentence(config, model ,TEXT, sen)
66+
#print(sen)
67+
#print(res)
68+
#res=test(config,model,TEXT, test_iter)
69+
#print(res)
70+
print("=========================")
71+
sen="篮球"
72+
#sen="体育"
73+
sen_ori=sen
74+
while(len(sen)<20):
75+
print("输入文本: ",sen)
76+
sen_pred=" ".join(test_sentence(config,model, TEXT,sen))
77+
sen+=sen_pred[1:]
78+
sen=sen.replace(" ","")
79+
print("文本生成: ", sen)
80+
print("*"*20)
81+
print("输入: ", sen_ori)
82+
print("生成: ", sen)

train_eval.py

+36-7
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77

88
def train(config,model,train_iter, valid_iter,test_iter):
99

10-
optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.7, 0.99))
10+
#optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.7, 0.99))
11+
optimizer = optim.Adam(model.parameters(), lr=1e-2, betas=(0.7, 0.99))
1112
criterion = nn.CrossEntropyLoss()
1213

1314
for epoch in range(1, config.epoch + 1):
@@ -55,32 +56,60 @@ def train(config,model,train_iter, valid_iter,test_iter):
5556
torch.save(model.state_dict(), config.save_path)
5657

5758
def test(config, model, TEXT, test_iter):
58-
59+
print("save_path: ", config.save_path)
60+
model.load_state_dict(torch.load(config.save_path))
61+
#print(model)
5962
b = next(iter(test_iter))
60-
print("输入: ", b.text[0])
63+
print(b.text.shape)
64+
#print("输入: ", b.text[0])
6165
#print("输入的句子: ", word_ids_to_sentence(b.text[0],TEXT.vocab))
6266
#print("", word_sentence_to_ids(b.text[0],TEXT.vocab))
6367

6468

69+
print("单条数据: ",b.text[:,1].shape)
6570

6671
inputs_word = word_ids_to_sentence(b.text.cuda().data, TEXT.vocab)
67-
print(inputs_word)
68-
print(len(inputs_word))
72+
#print(inputs_word)
73+
#print(len(inputs_word))
6974

7075
arrs = model(b.text.cuda()).cuda().data.cpu().numpy()
7176
print(arrs.shape)
7277
preds = word_ids_to_sentence(np.argmax(arrs, axis=2), TEXT.vocab)
78+
return preds
79+
#print(preds)
80+
81+
def test_one_sentence(config, model , TEXT,test_iter):
82+
print("save_path: ", config.save_path)
83+
model.load_state_dict(torch.load(config.save_path))
84+
#print(model)
85+
b = next(iter(test_iter))
86+
print(b.text.shape)
87+
print("单条数据: ",b.text[:,1].shape)
88+
print("单条数据: ",b.text[:,1].view(-1,1).shape)
89+
inputs_word = word_ids_to_sentence(b.text[:,1].view(-1,1).cuda().data, TEXT.vocab)
90+
print("inputs_word: ", inputs_word)
91+
arrs = model(b.text[:,1].view(-1,1).cuda()).cuda().data.cpu().numpy()
92+
preds = word_ids_to_sentence(np.argmax(arrs, axis=2), TEXT.vocab)
93+
print("preds----------", preds)
7394

74-
print(preds)
7595

7696

7797
def test_sentence(config, model, TEXT, sentence):
98+
model.load_state_dict(torch.load(config.save_path))
99+
#print(model)
78100
inputs = torch.Tensor([TEXT.vocab.stoi[one] for one in sentence]).long().to(config.device)
101+
print("inputs: ", inputs)
79102
inputs = inputs.view(-1, 1)
80-
# print(inputs.shape)
103+
#print("inputs: ", inputs)
104+
print("inputs shape: ", inputs.shape)
105+
#print(inputs)
81106
arrs = model(inputs)
82107
print("arrs shape: ",arrs.shape)
108+
#print(arrs)
109+
x=np.sum(np.array(arrs.detach().cpu()),axis=2)
83110
preds = word_ids_to_sentence(np.argmax(arrs.detach().cpu(), axis=2), TEXT.vocab)
111+
#print(x)
112+
print("preds: ",preds)
84113

85114
return preds
86115

transformer.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ def __init__(self,config,TEXT):
6868

6969
self.encoder=nn.Embedding(embeddings.shape[0],config.embedding_dim)
7070
self.decoder=nn.Linear(config.embedding_dim,embeddings.shape[0])
71-
self.init_weights()
71+
if config.mode!="test":
72+
self.init_weights()
7273
config.n_tokens=embeddings.shape[0]
7374

7475

@@ -92,11 +93,16 @@ def forward(self,inputs,has_mask=True):
9293
self.src_mask=mask
9394
else:
9495
self.src_mask=None
96+
#print("1: ", inputs)
9597
inputs=self.encoder(inputs)*math.sqrt(self.embedding_dim)
9698
inputs=self.pos_encoder(inputs)
99+
#print("2: ", inputs)
97100
output=self.transformer_encoder(inputs, self.src_mask)
101+
#print("3: output: ", output)
98102
#print("output shape11: ", output.shape)
99103
output=self.decoder(output)
104+
#print("4:", output)
100105
#print("output shape22: ", output.shape)
101-
return F.log_softmax(output, dim=1)
106+
#print(F.log_softmax(output, dim=2))
107+
return F.log_softmax(output, dim=2)
102108

0 commit comments

Comments
 (0)