Skip to content

Commit

Permalink
#52: 完成launcher 和 config
Browse files Browse the repository at this point in the history
  • Loading branch information
cjopengler committed Nov 10, 2021
1 parent 68f374e commit 3dc0242
Show file tree
Hide file tree
Showing 11 changed files with 117 additions and 6 deletions.
86 changes: 86 additions & 0 deletions data/mrc_ner/config/config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
{
"training_dataset": {
"__type__": "MSRAFlatNerDataset",
"__name_space__": "mrc_ner",
"dataset_file_path": "/Users/panxu/MyProjects/github/easytext/data/dataset/mrc_msra_ner/sample.json"
},

"validation_dataset": {
"__type__": "MSRAFlatNerDataset",
"__name_space__": "mrc_ner",
"dataset_file_path": "/Users/panxu/MyProjects/github/easytext/data/dataset/mrc_msra_ner/sample.json"
},

"model_collate": {
"__type__": "BertModelCollate",
"__name_space__": "mrc_ner",
"tokenizer": {
"__type__": "BertTokenizer",
"bert_dir": "/Users/panxu/MyProjects/github/easytext/data/pretrained/bert/chinese_roberta_wwm_large_ext_pytorch"
},

"max_length": 128
},



"model": {
"__type__": "BertRnnWithCrf",
"__name_space__": "mrc_ner",
"bert_dir": "/Users/panxu/MyProjects/github/easytext/data/pretrained/bert/chinese_roberta_wwm_large_ext_pytorch",
"dropout": 0.1
},

"loss": {
"__type__": "MRCBCELoss",
"__name_space__": "mrc_ner",
"start_weight": 1.0,
"end_weight": 1.0,
"match_weight": 1.0
},

"metric": {
"__type__": "MrcModelMetricAdapter",
"__name_space__": "mrc_ner"
},

"optimizer": {
"__type__": "MRCOptimizer",
"__name_space__": "mrc_ner",
"lr": 0.00002,
"eps": 0.00000001,
"weight_decay": 0.01
},

"lr_scheduler": {
"__type__": "MRCLrScheduler",
"__name_space__": "mrc_ner",
"max_lr": 0.00002,
"final_div_factor": 10000,
"total_steps": null,
},
"grad_rescaled": null,

"process_group_parameter": {
"__type__": "ProcessGroupParameter",
"__name_space__": "__easytext__",
"host": "127.0.0.1",
"port": "2345",
"backend": "nccl"
},

"distributed_data_parallel_parameter": {
"__type__": "DistributedDataParallelParameter",
"__name_space__": "__easytext__",
"find_unused_parameters": false
},

"num_epoch": 500,
"patient": 100,
"num_check_point_keep": 1,

"devices": ["cpu"],
"serialize_dir": "/Users/panxu/MyProjects/github/easytext/data/mrc_ner/serialize",
"train_batch_size": 4,
"test_batch_size": 8
}
3 changes: 3 additions & 0 deletions mrc/data/bert_model_collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@

from mrc.data import MRCModelInputs

from easytext.component.register import ComponentRegister


@ComponentRegister.register(name_space="mrc_ner")
class BertModelCollate:
"""
Bert Model Collate
Expand Down
3 changes: 2 additions & 1 deletion mrc/data/msra_flat_ner_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@

from easytext.data import Instance
from easytext.data.dataset import Dataset
from easytext.data.tokenizer import ZhTokenizer
from easytext.component.register import ComponentRegister


@ComponentRegister.register(name_space="mrc_ner")
class MSRAFlatNerDataset(Dataset):
"""
MSRA flat ner dataset
Expand Down
17 changes: 13 additions & 4 deletions mrc/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,20 @@
from easytext.distributed import ProcessGroupParameter
from easytext.utils.json_util import json2str
from easytext.component.register import Registry
from easytext.utils.bert_tokenizer import bert_tokenizer

from mrc.data import MSRAFlatNerDataset
from mrc.data import BertModelCollate
from mrc.models import MRCNer
from mrc.loss import MRCBCELoss
from mrc.metric import MrcModelMetricAdapter
from mrc.optimizer import MRCOptimizer
from mrc.optimizer import MRCLrScheduler

class MrcLauncher(Launcher):

class MrcNerLauncher(Launcher):
"""
ner 训练的启动器
mrc ner 训练的启动器
"""

NEW_TRAIN = 0
Expand All @@ -63,7 +72,7 @@ def _preprocess(self):
logging.info(f"config:\n{self.config}\n")

serialize_dir = self.config.serialize_dir
if self._train_type == MrcLauncher.NEW_TRAIN:
if self._train_type == MrcNerLauncher.NEW_TRAIN:
# 清理 serialize dir
if os.path.isdir(serialize_dir):
shutil.rmtree(serialize_dir)
Expand Down Expand Up @@ -131,5 +140,5 @@ def _start(self, rank: Optional[int], world_size: int, device: torch.device) ->
logging.fatal("--config 参数为空!")
exit(-1)
logging.info(f"config file path: {parsed_args.config}")
ner_launcher = MrcLauncher(config_file_path=parsed_args.config, train_type=MrcLauncher.NEW_TRAIN)
ner_launcher = MrcNerLauncher(config_file_path=parsed_args.config, train_type=MrcNerLauncher.NEW_TRAIN)
ner_launcher()
2 changes: 2 additions & 0 deletions mrc/loss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,5 @@
Authors: PanXu
Date: 2021/10/27 14:11:00
"""

from .mrc_bce_loss import MRCBCELoss
3 changes: 3 additions & 0 deletions mrc/loss/mrc_bce_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@

from mrc.models import MRCNerOutput

from easytext.component.register import ComponentRegister


@ComponentRegister.register(name_space="mrc_ner")
class MRCBCELoss:
"""
基于 bce 的 los
Expand Down
2 changes: 1 addition & 1 deletion mrc/metric/mrc_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from mrc.metric import MRCF1Metric


@ComponentRegister.register(name_space="mrc")
@ComponentRegister.register(name_space="mrc_ner")
class MrcModelMetricAdapter(ModelMetricAdapter):
"""
Ner Model Metric Adapter
Expand Down
2 changes: 2 additions & 0 deletions mrc/models/mrc_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from easytext.utils.seed_util import set_seed

from mrc.models import MRCNerOutput
from easytext.component.register import ComponentRegister


class MultiNonLinearClassifier(Module):
Expand All @@ -45,6 +46,7 @@ def forward(self, input_features):
return features_output2


@ComponentRegister.register(name_space="mrc_ner")
class MRCNer(Module):
"""
基于 MRC 的 ner 模型
Expand Down
1 change: 1 addition & 0 deletions mrc/optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
"""

from .mrc_optimizer import MRCOptimizer
from .mrc_lr_scheduler import MRCLrScheduler
2 changes: 2 additions & 0 deletions mrc/optimizer/mrc_lr_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@

from easytext.model import Model
from easytext.optimizer import LRSchedulerFactory
from easytext.component.register import ComponentRegister


@ComponentRegister.register(name_space="mrc_ner")
class MRCLrScheduler(LRSchedulerFactory):

def __init__(self, max_lr: float, final_div_factor: float, total_steps: int = None, pct_start: float = 0):
Expand Down
2 changes: 2 additions & 0 deletions mrc/optimizer/mrc_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
"""
from easytext.model import Model
from easytext.optimizer import OptimizerFactory
from easytext.component.register import ComponentRegister

from transformers import AdamW


@ComponentRegister.register(name_space="mrc_ner")
class MRCOptimizer(OptimizerFactory):

def __init__(self, lr: float, eps: float, weight_decay: float):
Expand Down

0 comments on commit 3dc0242

Please sign in to comment.