Skip to content

Commit

Permalink
#52: 完成对 model 与 paper 模型比对测试。同时增加 seed 函数,用于控制随机值
Browse files Browse the repository at this point in the history
  • Loading branch information
cjopengler committed Nov 8, 2021
1 parent d55b8f2 commit 90bbda9
Show file tree
Hide file tree
Showing 16 changed files with 401 additions and 21 deletions.
46 changes: 46 additions & 0 deletions easytext/utils/nn/bert_init_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#!/usr/bin/env python 3
# -*- coding: utf-8 -*-

#
# Copyright (c) 2021 PanXu, Inc. All Rights Reserved
#
"""
使用 bert 的 init weights
Authors: PanXu
Date: 2021/11/08 08:44:00
"""

from torch.nn import Module
from torch import nn

from transformers import BertConfig


class BertInitWeights:
    """
    Callable that initializes a module's weights the way BERT does.

    Mirrors ``BertPreTrainedModel._init_weights``: Linear/Embedding weights
    are drawn from N(0, initializer_range), LayerNorm is reset to the
    identity transform (weight=1, bias=0), and Linear biases are zeroed.
    Intended to be passed to ``Module.apply``.
    """

    def __init__(self, bert_config: BertConfig):
        # Only ``initializer_range`` is read from the config.
        self.config = bert_config

    def __call__(self, module: Module) -> None:
        """
        Initialize one sub-module in place (see ``BertPreTrainedModel._init_weights``).

        :param module: the sub-module to initialize
        :return: None
        """
        std = self.config.initializer_range
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=std)
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()
        has_linear_bias = isinstance(module, nn.Linear) and module.bias is not None
        if has_linear_bias:
            module.bias.data.zero_()




4 changes: 2 additions & 2 deletions easytext/utils/nn/tensor_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ def is_tensor_equal(tensor1: torch.Tensor, tensor2: torch.Tensor, epsilon: float
assert tensor1.size() == tensor2.size(), f"tensor1 size: {tensor1.size()} 与 tensor2 size: {tensor2.size()} 不匹配"

if tensor1.dtype == torch.long and tensor2.dtype == torch.long:
# 都是整数的时候,需要将 epsilon 设置为 0
epsilon = 0
# 都是整数的时候,直接使用 equal
return torch.equal(tensor1, tensor2)

delta = tensor1 - tensor2

Expand Down
29 changes: 29 additions & 0 deletions easytext/utils/seed_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#!/usr/bin/env python 3
# -*- coding: utf-8 -*-

#
# Copyright (c) 2021 PanXu, Inc. All Rights Reserved
#
"""
设置随机数种子
Authors: PanXu
Date: 2021/11/07 12:17:00
"""

import torch
import numpy as np
import random


def set_seed(seed: int = 7) -> None:
    """
    Seed every random number generator used by the project so runs are reproducible.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices),
    and forces cuDNN into deterministic mode.

    :param seed: the random seed to apply to every generator
    :return: None
    """
    torch.manual_seed(seed)
    # No-op when CUDA is unavailable, so this is always safe to call.
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark mode must be off as well: with it on, cuDNN auto-tunes
    # convolution algorithms and results can still vary between runs.
    torch.backends.cudnn.benchmark = False
2 changes: 1 addition & 1 deletion mrc/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,5 +131,5 @@ def _start(self, rank: Optional[int], world_size: int, device: torch.device) ->
logging.fatal("--config 参数为空!")
exit(-1)
logging.info(f"config file path: {parsed_args.config}")
ner_launcher = MrcLauncher(config_file_path=parsed_args.config, train_type=NerLauncher.NEW_TRAIN)
ner_launcher = MrcLauncher(config_file_path=parsed_args.config, train_type=MrcLauncher.NEW_TRAIN)
ner_launcher()
21 changes: 15 additions & 6 deletions mrc/models/mrc_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@
from torch.nn import Module
from torch.nn import Linear
from torch.nn import Dropout, GELU
from torch.nn import Embedding, LayerNorm

from transformers import BertModel, BertConfig
from transformers import BertModel, BertConfig, BertPreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPooling

from easytext.utils.nn.bert_init_weights import BertInitWeights
from easytext.utils.seed_util import set_seed

from mrc.models import MRCNerOutput


Expand Down Expand Up @@ -47,20 +51,22 @@ class MRCNer(Module):
"""

def __init__(self, bert_dir: str, dropout: float):
super().__init__()

super().__init__()
self.bert = BertModel.from_pretrained(bert_dir)

bert_config: BertConfig = self.bert.config
bert_config = self.bert.config

self.start_classifier = Linear(bert_config.hidden_size, 1)
self.end_classifier = Linear(bert_config.hidden_size, 1)
self.match_classifier = MultiNonLinearClassifier(bert_config.hidden_size * 1, 1, dropout)
self.match_classifier = MultiNonLinearClassifier(bert_config.hidden_size * 2, 1, dropout)

self.init_weights = BertInitWeights(bert_config=bert_config)
self.reset_parameters()

def reset_parameters(self):
pass
self.start_classifier.apply(self.init_weights)
self.end_classifier.apply(self.init_weights)
self.match_classifier.apply(self.init_weights)

def forward(self,
input_ids: torch.Tensor,
Expand All @@ -86,8 +92,11 @@ def forward(self,
sequence_length = sequence_output.size(1)

start_logits = self.start_classifier(sequence_output)
# 最后一个维度去掉
start_logits = start_logits.squeeze(-1)

end_logits = self.end_classifier(sequence_output)
end_logits = end_logits.squeeze(-1)

# 将每一个 i 与 j 连接在一起, 所以是 N*N的拼接,使用了 expand, 进行 两个方向的扩展
# 产生一个 match matrix
Expand Down
1 change: 1 addition & 0 deletions mrc/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@
from unittest import TestCase

ASSERT = TestCase()
SEED = 7
11 changes: 1 addition & 10 deletions mrc/tests/data/conftest.py → mrc/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,17 @@
"""

import os
from typing import Dict, Union

import pytest

import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizer
from tokenizers import BertWordPieceTokenizer

from easytext.data import Vocabulary, LabelVocabulary, PretrainedVocabulary

from ner.data import VocabularyCollate

from easytext.data import Vocabulary

from mrc import ROOT_PATH
from mrc.data import MSRAFlatNerDataset


from mrc.tests.data.paper_src.mrc_ner_dataset import MRCNERDataset
from mrc.tests.paper.mrc_ner_dataset import MRCNERDataset


@pytest.fixture(scope="session")
Expand Down
2 changes: 1 addition & 1 deletion mrc/tests/data/test_bert_model_collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from mrc.tests import ASSERT


from mrc.tests.data.paper_src.collate_functions import collate_to_max_length
from mrc.tests.paper.collate_functions import collate_to_max_length


def test_bert_model_collate(mrc_msra_ner_dataset, paper_mrc_msra_ner_dataset):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@
brief
Authors: PanXu
Date: 2021/10/25 19:30:00
Date: 2021/11/07 11:45:00
"""
176 changes: 176 additions & 0 deletions mrc/tests/models/test_mrc_ner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
#!/usr/bin/env python 3
# -*- coding: utf-8 -*-

#
# Copyright (c) 2021 PanXu, Inc. All Rights Reserved
#
"""
测试 mrc ner
Authors: PanXu
Date: 2021/11/07 11:45:00
"""

import torch

from mrc.models import MRCNer, MRCNerOutput


import os

import torch
import logging

from transformers import BertConfig

from easytext.utils.bert_tokenizer import bert_tokenizer
from easytext.utils.nn.tensor_util import is_tensor_equal

from mrc import ROOT_PATH

from mrc.tests import ASSERT
from mrc.data.bert_model_collate import BertModelCollate
from mrc.tests.paper.collate_functions import collate_to_max_length
from mrc.models import MRCNer, MRCNerOutput
from mrc.tests.paper.bert_query_ner import BertQueryNER
from mrc.tests.paper.query_ner_config import BertQueryNerConfig

from easytext.utils.seed_util import set_seed


def fake_model_weight(module: torch.nn.Module):
    """
    Overwrite a Linear layer's parameters with a known fake state.

    The weight is replaced by uniform random values (reproducible under a
    fixed seed, since ``torch.rand`` is consumed exactly once) and the bias
    is cleared to zero. Any non-Linear module is left untouched. Intended
    to be passed to ``Module.apply``.

    :param module: sub-module visited by ``Module.apply``
    """
    if not isinstance(module, torch.nn.Linear):
        return
    module.weight.data.copy_(torch.rand(module.weight.size()))
    module.bias.data.fill_(0.)


def test_mrc_ner(mrc_msra_ner_dataset, paper_mrc_msra_ner_dataset):
    """
    Compare our MRCNer model against the paper's reference implementation.

    First verifies that both data pipelines (our ``BertModelCollate`` vs the
    paper's ``collate_to_max_length``) produce identical model inputs, then
    builds both models under the same seed with the same faked classifier
    weights and asserts the start/end/match logits agree.

    :param mrc_msra_ner_dataset: fixture yielding our MSRA MRC dataset instances
    :param paper_mrc_msra_ner_dataset: fixture yielding the paper's dataset instances
    :return: None
    """

    # Fix the random seed so every run produces identical results.
    set_seed()

    max_length = 128

    # NOTE(review): assumes the pretrained checkpoint exists under ROOT_PATH.
    bert_dir = "data/pretrained/bert/chinese_roberta_wwm_large_ext_pytorch"
    bert_dir = os.path.join(ROOT_PATH, bert_dir)

    bert_config = BertConfig.from_pretrained(bert_dir)

    bert_model_collate = BertModelCollate(tokenizer=bert_tokenizer(bert_dir), max_length=max_length)

    # Collate our instances into model inputs.
    instances = [instance for instance in mrc_msra_ner_dataset]
    model_inputs = bert_model_collate(instances=instances)

    inputs = model_inputs.model_inputs

    # Collate the paper's instances with the paper's collate function.
    paper_instances = [instance for instance in paper_mrc_msra_ner_dataset]
    paper_model_inputs = collate_to_max_length(paper_instances)

    # Token ids must be identical.
    paper_token_ids = paper_model_inputs[0]
    token_ids = inputs["input_ids"]

    ASSERT.assertTrue(is_tensor_equal(paper_token_ids, token_ids, epsilon=0))

    # Token type (segment) ids must be identical.
    paper_type_ids = paper_model_inputs[1]
    type_ids = inputs["token_type_ids"]

    ASSERT.assertTrue(is_tensor_equal(paper_type_ids, type_ids, epsilon=0))

    # Start-position labels must be identical.
    paper_start_label_indices = paper_model_inputs[2]

    start_label_indices = model_inputs.labels["start_position_labels"]

    ASSERT.assertTrue(is_tensor_equal(paper_start_label_indices, start_label_indices, epsilon=0))

    # End-position labels must be identical.
    paper_end_label_indices = paper_model_inputs[3]

    end_label_indices = model_inputs.labels["end_position_labels"]

    ASSERT.assertTrue(is_tensor_equal(paper_end_label_indices, end_label_indices, epsilon=0))

    # The paper keeps separate start/end masks; ours is a single sequence
    # mask, so it must equal both of theirs.
    paper_start_label_mask = paper_model_inputs[4]
    sequence_mask = inputs["sequence_mask"].long()

    ASSERT.assertTrue(is_tensor_equal(paper_start_label_mask, sequence_mask, epsilon=0))

    paper_end_label_mask = paper_model_inputs[5]
    sequence_mask = inputs["sequence_mask"].long()

    ASSERT.assertTrue(is_tensor_equal(paper_end_label_mask, sequence_mask, epsilon=0))

    # Span (match) labels must be identical.
    paper_match_labels = paper_model_inputs[6]
    match_labels = model_inputs.labels["match_position_labels"]

    ASSERT.assertTrue(is_tensor_equal(paper_match_labels, match_labels, epsilon=0))

    logging.info(f"begin mrc ner")
    # Re-seed so our model is built from the same RNG state as the paper's below.
    set_seed()

    mrc_model = MRCNer(bert_dir=bert_dir, dropout=0)

    # Re-seed and overwrite the classifier weights with fixed fake values,
    # so both models end up with identical parameters.
    set_seed()
    mrc_model.start_classifier.apply(fake_model_weight)
    mrc_model.end_classifier.apply(fake_model_weight)
    mrc_model.match_classifier.apply(fake_model_weight)

    # fake_start = torch.rand(mrc_model.start_classifier.weight.size())
    # mrc_model.start_classifier.weight.data.copy_(fake_start)

    # fake_end = torch.rand(mrc_model.end_classifier.weight.size())
    # mrc_model.end_classifier.weight.data.copy_(fake_end)

    # fake_match = torch.rand(mrc_model.match_classifier.weight)

    logging.info(f"mrc ner forward")
    mrc_model_output = mrc_model.forward(**model_inputs.model_inputs)

    logging.info(f"end mrc ner")

    logging.info(f"begin paper ner")

    set_seed()
    # Build the paper model's config; dropout disabled so outputs are deterministic.
    bert_config = BertQueryNerConfig.from_pretrained(bert_dir,
                                                     mrc_dropout=0)

    # Load the paper's reference model.
    paper_model = BertQueryNER.from_pretrained(bert_dir, config=bert_config)

    # paper_model.start_outputs.weight.data.copy_(fake_start)
    # Same seed + same apply order => the same torch.rand draws as above.
    set_seed()
    paper_model.start_outputs.apply(fake_model_weight)
    paper_model.end_outputs.apply(fake_model_weight)
    paper_model.span_embedding.apply(fake_model_weight)

    logging.info(f"paper ner forward")
    paper_attention_mask = (paper_token_ids != 0).long()
    paper_output = paper_model.forward(input_ids=paper_token_ids,
                                       token_type_ids=paper_type_ids,
                                       attention_mask=paper_attention_mask)

    paper_start_logits, paper_end_logits, paper_span_logits = paper_output

    logging.info(f"end paper ner")

    # Finally the two models' logits must agree to within 1e-10.
    ASSERT.assertTrue(is_tensor_equal(mrc_model_output.start_logits, paper_start_logits, epsilon=1e-10))
    ASSERT.assertTrue(is_tensor_equal(mrc_model_output.end_logits, paper_end_logits, epsilon=1e-10))
    ASSERT.assertTrue(is_tensor_equal(mrc_model_output.match_logits, paper_span_logits, epsilon=1e-10))













12 changes: 12 additions & 0 deletions mrc/tests/paper/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/usr/bin/env python 3
# -*- coding: utf-8 -*-

#
# Copyright (c) 2021 PanXu, Inc. All Rights Reserved
#
"""
brief
Authors: PanXu
Date: 2021/11/07 12:10:00
"""
Loading

0 comments on commit 90bbda9

Please sign in to comment.