
Commit ddc432a

Merge pull request #160 from Bone-Fish/dev
[FEATURE] Add Jiuzhang model
2 parents 7abc7d1 + e05f640

File tree

13 files changed: +2010 -5 lines


AUTHORS.md (+2 -1)

@@ -25,4 +25,5 @@
 [Heng Yu](https://github.com/GNEHUY)
 
 [Tianyun Ji](https://github.com/KINGNEWBLUSH)
-The stared contributors are the corresponding authors.
+
+[Chaokun Wang](https://github.com/Bone-Fish)

EduNLP/I2V/__init__.py (+1 -1)

@@ -2,4 +2,4 @@
 # 2021/8/1 @ tongshiwei
 
 from .i2v import I2V, get_pretrained_i2v
-from .i2v import D2V, W2V, Elmo, Bert, HfAuto, DisenQ, QuesNet
+from .i2v import D2V, W2V, Elmo, Bert, HfAuto, DisenQ, QuesNet, Jiuzhang

EduNLP/I2V/i2v.py (+73 -3)
@@ -11,10 +11,11 @@
 from longling import path_append
 from EduData import get_data
 from ..Tokenizer import Tokenizer, get_tokenizer
-from EduNLP.Pretrain import ElmoTokenizer, BertTokenizer, HfAutoTokenizer, DisenQTokenizer, QuesNetTokenizer, Question
+from EduNLP.Pretrain import ElmoTokenizer, BertTokenizer, HfAutoTokenizer
+from EduNLP.Pretrain import DisenQTokenizer, QuesNetTokenizer, JiuzhangTokenizer
 from EduNLP import logger
 
-__all__ = ["I2V", "D2V", "W2V", "Elmo", "Bert", "HfAuto", "DisenQ", "QuesNet", "get_pretrained_i2v"]
+__all__ = ["I2V", "D2V", "W2V", "Elmo", "Bert", "HfAuto", "DisenQ", "QuesNet", "get_pretrained_i2v", "Jiuzhang"]
 
 
 class I2V(object):
@@ -69,6 +70,9 @@ def __init__(self, tokenizer, t2v, *args, tokenizer_kwargs: dict = None,
         if tokenizer == 'bert':
             self.tokenizer = BertTokenizer.from_pretrained(
                 **tokenizer_kwargs if tokenizer_kwargs is not None else {})
+        elif tokenizer == 'jiuzhang':
+            self.tokenizer = JiuzhangTokenizer.from_pretrained(
+                **tokenizer_kwargs if tokenizer_kwargs is not None else {})
         elif tokenizer == 'hf_auto':
             self.tokenizer = HfAutoTokenizer.from_pretrained(
                 **tokenizer_kwargs if tokenizer_kwargs is not None else {})
@@ -606,14 +610,80 @@ def from_pretrained(cls, name, model_dir=MODEL_DIR, device='cpu', *args, **kwargs):
                    tokenizer_kwargs=tokenizer_kwargs)
 
 
+class Jiuzhang(I2V):
+    """
+    The model transfers items and tokens to vectors with Jiuzhang.
+
+    Bases
+    -----
+    I2V
+
+    Parameters
+    ----------
+    tokenizer: str
+        the tokenizer name
+    t2v: str
+        the name of the token2vector model
+    args:
+        the parameters passed to t2v
+    tokenizer_kwargs: dict
+        the parameters passed to the tokenizer
+    pretrained_t2v: bool
+        True: use a pretrained t2v model
+        False: use your own t2v model
+    kwargs:
+        the parameters passed to t2v
+
+    Returns
+    -------
+    i2v model: Jiuzhang
+    """
+
+    def infer_vector(self, items: Tuple[List[str], List[dict], str, dict],
+                     *args, key=lambda x: x, return_tensors='pt', **kwargs) -> tuple:
+        """
+        Convert items to vectors. The model must be loaded before this function is used.
+
+        Parameters
+        ----------
+        items : str or dict or list
+            a question item, or a list of question items
+        return_tensors: str
+            tensor type used in the tokenizer
+        args:
+            the parameters passed to t2v
+        kwargs:
+            the parameters passed to t2v
+
+        Returns
+        -------
+        vector: list
+        """
+        is_batch = isinstance(items, list)
+        items = items if is_batch else [items]
+        inputs = self.tokenize(items, key=key, return_tensors=return_tensors)
+        return self.t2v.infer_vector(inputs, *args, **kwargs), self.t2v.infer_tokens(inputs, *args, **kwargs)
+
+    @classmethod
+    def from_pretrained(cls, name, model_dir=MODEL_DIR, device='cpu', *args, **kwargs):
+        model_path = path_append(model_dir, get_pretrained_model_info(name)[0].split('/')[-1], to_str=True)
+        for i in [".tar.gz", ".tar.bz2", ".tar.bz", ".tar.tgz", ".tar", ".tgz", ".zip", ".rar"]:
+            model_path = model_path.replace(i, "")
+        logger.info("model_path: %s" % model_path)
+        tokenizer_kwargs = {"tokenizer_config_dir": model_path}
+        return cls("jiuzhang", name, pretrained_t2v=True, model_dir=model_dir, device=device,
+                   tokenizer_kwargs=tokenizer_kwargs)
+
+
 MODEL_MAP = {
     "w2v": W2V,
     "d2v": D2V,
     "bert": Bert,
     "hf_auto": HfAuto,
     "disenq": DisenQ,
     "quesnet": QuesNet,
-    "elmo": Elmo
+    "elmo": Elmo,
+    "jiuzhang": Jiuzhang,
 }
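A minimal usage sketch for the new wrapper, assuming a Jiuzhang checkpoint is registered in the pretrained-model index; the model name "jiuzhang_public" and the item text below are placeholders, not names shipped by this PR:

from EduNLP.I2V import Jiuzhang

# from_pretrained resolves the registered name to a local model directory,
# builds a JiuzhangTokenizer from it, and loads the matching t2v model.
i2v = Jiuzhang.from_pretrained("jiuzhang_public", model_dir="./model")

# A single item or a list of items is accepted; the return value is a pair
# (item-level vectors, token-level vectors).
item_vector, token_vectors = i2v.infer_vector(["已知 $x^2=4$，求 $x$ 的值"])

The pair return value mirrors the other I2V subclasses: infer_vector delegates to t2v.infer_vector for the pooled item embedding and to t2v.infer_tokens for the per-token embeddings.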

EduNLP/ModelZoo/__init__.py (+1)

@@ -4,3 +4,4 @@
 from .rnn import *
 from .disenqnet import *
 from .quesnet import *
+from .jiuzhang import *

EduNLP/ModelZoo/jiuzhang/__init__.py (+2)

@@ -0,0 +1,2 @@
+from .jiuzhang import *
+from .modeling import CPTModel as JiuzhangModel
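This two-line package aliases the CPT backbone defined in modeling.py as JiuzhangModel, so downstream code can import it under one stable name regardless of the underlying architecture:

# JiuzhangModel is an alias of CPTModel; both downstream heads are re-exported too.
from EduNLP.ModelZoo.jiuzhang import JiuzhangModel, JiuzhangForPropertyPrediction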

EduNLP/ModelZoo/jiuzhang/jiuzhang.py (+167)

@@ -0,0 +1,167 @@
+import torch
+from torch import nn
+import json
+import os
+from ..base_model import BaseModel
+from ..utils import PropertyPredictionOutput, KnowledgePredictionOutput
+from transformers import PretrainedConfig
+from typing import List
+from ..rnn.harnn import HAM
+from transformers import BartConfig as JiuzhangConfig
+from .modeling import CPTModel as JiuzhangModel
+
+
+__all__ = ["JiuzhangForPropertyPrediction", "JiuzhangForKnowledgePrediction"]
+
+
+class JiuzhangForPropertyPrediction(BaseModel):
+    def __init__(self, pretrained_model_dir=None, head_dropout=0.5, init=True):
+        super(JiuzhangForPropertyPrediction, self).__init__()
+        jiuzhang_config = JiuzhangConfig.from_pretrained(pretrained_model_dir)
+        if init:
+            print(f'Load Jiuzhang from checkpoint: {pretrained_model_dir}')
+            self.jiuzhang = JiuzhangModel.from_pretrained(pretrained_model_dir, ignore_mismatched_sizes=True)
+        else:
+            print(f'Load Jiuzhang from config: {pretrained_model_dir}')
+            self.jiuzhang = JiuzhangModel(jiuzhang_config)
+        self.hidden_size = self.jiuzhang.config.hidden_size
+        self.head_dropout = head_dropout
+        self.dropout = nn.Dropout(head_dropout)
+        self.classifier = nn.Linear(self.hidden_size, 1)
+        self.sigmoid = nn.Sigmoid()
+        self.criterion = nn.MSELoss()
+
+        self.config = {k: v for k, v in locals().items() if k not in ["self", "__class__", "jiuzhang_config"]}
+        self.config['architecture'] = 'JiuzhangForPropertyPrediction'
+        self.config = PretrainedConfig.from_dict(self.config)
+
+    def forward(self,
+                input_ids=None,
+                attention_mask=None,
+                token_type_ids=None,
+                labels=None):
+        outputs = self.jiuzhang(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
+        # outputs = self.jiuzhang(input_ids=input_ids, attention_mask=attention_mask)
+        item_embeds = outputs.last_hidden_state[:, 0, :]
+        item_embeds = self.dropout(item_embeds)
+
+        logits = self.sigmoid(self.classifier(item_embeds)).squeeze(1)
+        loss = None
+        if labels is not None:
+            loss = self.criterion(logits, labels)
+        return PropertyPredictionOutput(
+            loss=loss,
+            logits=logits,
+        )
+
+    @classmethod
+    def from_config(cls, config_path, **kwargs):
+        config_path = os.path.join(os.path.dirname(config_path), 'model_config.json')
+        with open(config_path, "r", encoding="utf-8") as rf:
+            model_config = json.load(rf)
+            model_config['pretrained_model_dir'] = os.path.dirname(config_path)
+            model_config.update(kwargs)
+            return cls(
+                pretrained_model_dir=model_config['pretrained_model_dir'],
+                head_dropout=model_config.get("head_dropout", 0.5),
+                init=model_config.get('init', False)
+            )
+
+    def save_config(self, config_dir):
+        config_path = os.path.join(config_dir, "model_config.json")
+        with open(config_path, "w", encoding="utf-8") as wf:
+            json.dump(self.config.to_dict(), wf, ensure_ascii=False, indent=2)
+        self.jiuzhang.config.save_pretrained(config_dir)
+
+
+class JiuzhangForKnowledgePrediction(BaseModel):
+    def __init__(self,
+                 pretrained_model_dir=None,
+                 num_classes_list: List[int] = None,
+                 num_total_classes: int = None,
+                 head_dropout=0.5,
+                 flat_cls_weight=0.5,
+                 attention_unit_size=256,
+                 fc_hidden_size=512,
+                 beta=0.5,
+                 init=True
+                 ):
+        super(JiuzhangForKnowledgePrediction, self).__init__()
+        jiuzhang_config = JiuzhangConfig.from_pretrained(pretrained_model_dir)
+        if init:
+            print(f'Load Jiuzhang from checkpoint: {pretrained_model_dir}')
+            self.jiuzhang = JiuzhangModel.from_pretrained(pretrained_model_dir, ignore_mismatched_sizes=True)
+        else:
+            print(f'Load Jiuzhang from config: {pretrained_model_dir}')
+            self.jiuzhang = JiuzhangModel(jiuzhang_config)
+        self.hidden_size = self.jiuzhang.config.hidden_size
+        self.head_dropout = head_dropout
+        self.dropout = nn.Dropout(head_dropout)
+        self.sigmoid = nn.Sigmoid()
+        self.criterion = nn.MSELoss()
+        self.flat_classifier = nn.Linear(self.hidden_size, num_total_classes)
+        self.ham_classifier = HAM(
+            num_classes_list=num_classes_list,
+            num_total_classes=num_total_classes,
+            sequence_model_hidden_size=self.jiuzhang.config.hidden_size,
+            attention_unit_size=attention_unit_size,
+            fc_hidden_size=fc_hidden_size,
+            beta=beta,
+            dropout_rate=head_dropout
+        )
+        self.flat_cls_weight = flat_cls_weight
+        self.num_classes_list = num_classes_list
+        self.num_total_classes = num_total_classes
+
+        self.config = {k: v for k, v in locals().items() if k not in ["self", "__class__", "jiuzhang_config"]}
+        self.config['architecture'] = 'JiuzhangForKnowledgePrediction'
+        self.config = PretrainedConfig.from_dict(self.config)
+
+    def forward(self,
+                input_ids=None,
+                attention_mask=None,
+                token_type_ids=None,
+                labels=None):
+        outputs = self.jiuzhang(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
+        item_embeds = outputs.last_hidden_state[:, 0, :]
+        item_embeds = self.dropout(item_embeds)
+        tokens_embeds = outputs.last_hidden_state
+        tokens_embeds = self.dropout(tokens_embeds)
+        flat_logits = self.sigmoid(self.flat_classifier(item_embeds))
+        ham_outputs = self.ham_classifier(tokens_embeds)
+        ham_logits = self.sigmoid(ham_outputs.scores)
+        logits = self.flat_cls_weight * flat_logits + (1 - self.flat_cls_weight) * ham_logits
+        loss = None
+        if labels is not None:
+            labels = torch.sum(torch.nn.functional.one_hot(labels, num_classes=self.num_total_classes), dim=1)
+            labels = labels.float()
+            loss = self.criterion(logits, labels)
+        return KnowledgePredictionOutput(
+            loss=loss,
+            logits=logits,
+        )
+
+    @classmethod
+    def from_config(cls, config_path, **kwargs):
+        config_path = os.path.join(os.path.dirname(config_path), 'model_config.json')
+        with open(config_path, "r", encoding="utf-8") as rf:
+            model_config = json.load(rf)
+            model_config['pretrained_model_dir'] = os.path.dirname(config_path)
+            model_config.update(kwargs)
+            return cls(
+                pretrained_model_dir=model_config['pretrained_model_dir'],
+                head_dropout=model_config.get("head_dropout", 0.5),
+                num_classes_list=model_config.get('num_classes_list'),
+                num_total_classes=model_config.get('num_total_classes'),
+                flat_cls_weight=model_config.get('flat_cls_weight', 0.5),
+                attention_unit_size=model_config.get('attention_unit_size', 256),
+                fc_hidden_size=model_config.get('fc_hidden_size', 512),
+                beta=model_config.get('beta', 0.5),
+                init=model_config.get('init', False)
+            )
+
+    def save_config(self, config_dir):
+        config_path = os.path.join(config_dir, "model_config.json")
+        with open(config_path, "w", encoding="utf-8") as wf:
+            json.dump(self.config.to_dict(), wf, ensure_ascii=False, indent=2)
+        self.jiuzhang.config.save_pretrained(config_dir)
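A forward-pass sketch for the property-prediction head. The checkpoint path below is a placeholder, and the sketch assumes the CPT backbone accepts the token_type_ids argument that forward passes through (the commented-out call in the source suggests this may need to be dropped for some backbones):

import torch
from EduNLP.ModelZoo.jiuzhang import JiuzhangForPropertyPrediction

# "./jiuzhang_ckpt" is a placeholder directory containing a BART-style
# config.json and, since init=True by default, CPT weights.
model = JiuzhangForPropertyPrediction(pretrained_model_dir="./jiuzhang_ckpt")
model.eval()

input_ids = torch.randint(0, model.jiuzhang.config.vocab_size, (2, 16))
attention_mask = torch.ones_like(input_ids)
labels = torch.tensor([0.3, 0.8])  # e.g. difficulty scores in [0, 1]

with torch.no_grad():
    out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
print(out.logits.shape)  # torch.Size([2]); out.loss holds the MSE against labels

The config helpers round-trip through model_config.json: save_config(config_dir) writes it next to the backbone config, and from_config(config_path) reads model_config.json from the directory containing config_path.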

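The knowledge-prediction head mixes a flat classifier over the first-token embedding with the hierarchical HAM scores (weighted by flat_cls_weight), and it folds each item's list of knowledge-point indices into a multi-hot target before computing the MSE. A self-contained illustration of that label transform:

import torch

num_total_classes = 6
# two items, each tagged with two knowledge-point indices
labels = torch.tensor([[1, 3], [0, 5]])

multi_hot = torch.sum(
    torch.nn.functional.one_hot(labels, num_classes=num_total_classes), dim=1
).float()
print(multi_hot)
# tensor([[0., 1., 0., 1., 0., 0.],
#         [1., 0., 0., 0., 0., 1.]])

Note that a knowledge index repeated within one item would sum to 2.0 rather than clamp to 1.0, so the per-item label lists are expected to be duplicate-free.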