Skip to content

Commit 587ed9f

Browse files
aaa123githzhwcmhf
authored andcommittedAug 26, 2019
replace pytorch_pretrained_bert with pytorch-transformers (#324)
1 parent 9baf86a commit 587ed9f

File tree

3 files changed

+8
-5
lines changed

3 files changed

+8
-5
lines changed
 

‎cotk/dataloader/bert_dataloader.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from .._utils import trim_before_target
66

77
try:
8-
from pytorch_pretrained_bert import BertTokenizer
8+
from pytorch_transformers import BertTokenizer
99
except ImportError as err:
1010
from .._utils.imports import DummyObject
1111
BertTokenizer = DummyObject(err)
@@ -22,8 +22,8 @@ class BERTLanguageProcessingBase(LanguageProcessingBase):
2222

2323
BERT_VOCAB_NAME = r"""
2424
bert_vocab_name (str): A string indicates which bert model is used, it will be a
25-
parameter passed to `pytorch-pretrained-BERT.BertTokenizer.from_pretrained
26-
<https://github.com/huggingface/pytorch-pretrained-BERT#berttokenizer>`_.
25+
parameter passed to `pytorch-transformers.BertTokenizer.from_pretrained
26+
<https://github.com/huggingface/pytorch-transformers#berttokenizer>`_.
2727
It can be 'bert-[base|large]-[uncased|cased]' or a local path."""
2828

2929
ARGUMENTS = LanguageProcessingBase.ARGUMENTS + BERT_VOCAB_NAME

‎setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def run_tests(self):
4545
"pytest>=3.6.0",
4646
"pytest-cov==2.4.0",
4747
"checksumdir",
48-
"pytorch-pretrained-bert>=0.6.0"
48+
"pytorch-transformers>=1.1.0"
4949
]
5050
},
5151
cmdclass={'test': LibTest},

‎tests/metric/metric_base.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import itertools
33
import copy
44
import json
5+
import re
56

67
import numpy as np
78

@@ -385,7 +386,9 @@ def version_test(metric_class, dataloader=None):
385386
metric.forward(**batch)
386387
res = metric.close()
387388
for key, val in res.items():
388-
if isinstance(val, (np.float, np.float16, np.float32, np.float64, np.float128, float)):
389+
if isinstance(val, float) or re.match(r"<class 'numpy\.float\d*'>", str(type(val))):
389390
res[key] = float(val)
391+
elif isinstance(val, int) or re.match(r"<class 'numpy\.int\d*'>", str(type(val))):
392+
res[key] = int(val)
390393
assert same_dict(res, data['output'], exact_equal=False), "Version {} error".format(version)
391394
# assert metric.close() == data['output'], "Version {} error".format(version)

0 commit comments

Comments
 (0)