Skip to content

Commit 0de4960

Browse files
committed
修改了两处bug
1 parent 3db0f41 commit 0de4960

File tree

6 files changed

+5
-4
lines changed

6 files changed

+5
-4
lines changed

__pycache__/model.cpython-36.pyc

38 Bytes
Binary file not shown.

__pycache__/train_eval.cpython-36.pyc

46 Bytes
Binary file not shown.

__pycache__/utils.cpython-36.pyc

38 Bytes
Binary file not shown.

main_zh.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
class Config(object):
1515
def __init__(self):
1616
self.model_name="lm_model"
17-
self.data_ori="E:/data/word_nlp/cnews_data/"
17+
self.data_ori="/mnt/data3/wuchunsheng/data_all/data_mine/lm_data/"
1818
self.train_path="train_0.csv"
1919
self.valid_path="train_0.csv"
2020
self.test_path="test_100.csv"
@@ -37,6 +37,7 @@ def __init__(self):
3737
self.train_len=0
3838
self.test_len = 0
3939
self.valid_len = 0
40+
self.mode="train"
4041
#data_path="E:/study_series/2020_3/re_write_classify/data/"
4142
#data_path="/mnt/data3/wuchunsheng/code/nlper/NLP_task/text_classification/my_classification_cnews/2020_3_30/text_classify/data/"
4243

@@ -50,4 +51,4 @@ def __init__(self):
5051

5152
train(config,model,train_iter, valid_iter, test_iter)
5253

53-
test(config,model,TEXT, test_iter)## 测试的是一个正批量的
54+
test(config,model,TEXT, test_iter)## 测试的是一个正批量的

test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
class Config(object):
1717
def __init__(self):
1818
self.model_name = "lm_model"
19-
self.data_ori = "E:/data/word_nlp/cnews_data/"
19+
self.data_ori = "/mnt/data3/wuchunsheng/data_all/data_mine/lm_data/"
2020
self.train_path = "train_0.csv"
2121
self.valid_path = "train_0.csv"
2222
self.test_path = "test_100.csv"

train_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def test(config, model, TEXT, test_iter):
7272

7373

7474
def test_sentence(config, model, TEXT, sentence):
75-
inputs = torch.Tensor([TEXT.vocab[one] for one in sentence]).long().to(config.device)
75+
inputs = torch.Tensor([TEXT.vocab.stoi[one] for one in sentence]).long().to(config.device)
7676
inputs = inputs.view(-1, 1)
7777
# print(inputs.shape)
7878
arrs = model(inputs)

0 commit comments

Comments
 (0)