-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_utils.py
28 lines (25 loc) · 1.83 KB
/
train_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from transformers import BertTokenizerFast, AutoTokenizer, RobertaTokenizerFast, XLMRobertaTokenizerFast
def initiate_tokenizer(settings):
    """Load the pretrained tokenizer named by ``settings['tokenizer']``.

    The tokenizer class is chosen from the model identifier, and the same
    identifier is then passed to that class's ``from_pretrained``.

    Args:
        settings: Mapping with a ``'tokenizer'`` key holding one of the
            supported model identifiers listed below.

    Returns:
        The tokenizer object returned by ``from_pretrained`` for that model.

    Raises:
        KeyError: If ``settings`` is a dict without a ``'tokenizer'`` key.
        ValueError: If the identifier is not one of the supported models
            (previously this surfaced as an ``UnboundLocalError``).
    """
    name = settings['tokenizer']

    # Identifiers loaded with BertTokenizerFast.
    # NOTE: per the original author's remark, passing model_max_length=512
    # for 'emanjavacas/GysBERT-v2' made no difference in performance, so no
    # extra keyword arguments are passed.
    bert_fast_models = (
        'emanjavacas/GysBERT-v2',
        'emanjavacas/GysBERT',
        'bert-base-multilingual-cased',
        'google-bert/bert-base-cased',
        'google-bert/bert-large-cased',
        'emanjavacas/MacBERTh',
    )
    # Identifiers loaded with RobertaTokenizerFast.
    roberta_fast_models = (
        'pdelobelle/robbert-v2-dutch-base',
        'FacebookAI/roberta-base',
        'FacebookAI/roberta-large',
    )

    if name in bert_fast_models:
        tokenizer_cls = BertTokenizerFast
    elif name in roberta_fast_models:
        tokenizer_cls = RobertaTokenizerFast
    elif name == 'FacebookAI/xlm-roberta-base':
        tokenizer_cls = XLMRobertaTokenizerFast
    elif name == 'GroNLP/bert-base-dutch-cased':
        tokenizer_cls = AutoTokenizer
    else:
        # Fail loudly with the offending value instead of falling through
        # to a return of an unbound local.
        supported = (
            bert_fast_models
            + roberta_fast_models
            + ('FacebookAI/xlm-roberta-base', 'GroNLP/bert-base-dutch-cased')
        )
        raise ValueError(
            f"Unsupported tokenizer {name!r}; expected one of: "
            + ", ".join(supported)
        )

    return tokenizer_cls.from_pretrained(name)