
Commit 45d0fa3: Add files via upload
Parent: 7e8437a
15 files changed: +2880 additions, -0 deletions

README.md

Lines changed: 44 additions & 0 deletions

# RTG4TE

Source code for the paper: Retrieval-enhanced Template Generation for Template Extraction (NLPCC 2024)

## Overview

![model11](figs/an example.png)

An example of template extraction: a generic template is extracted for the document-level REE task, while two event templates, an `Attack` event template and a `Bombing` event template, are extracted for the TF task.

All the required packages are listed in `requirements.txt`. To install all the dependencies, run

```
pip install -r requirements.txt
```

## Data

For the TF task, we downloaded the original dataset from [GTT](https://github.com/xinyadu/gtt). The extracted train, dev, and test files are located in `data/tf/`.
The original data are transformed into our internal format using `convert_tf.py`:

```
python convert_tf.py --input_path data/train.json --output_path data/tf_train.json
```

For the REE task, we downloaded the original dataset from [GRIT](https://github.com/xinyadu/grit_doc_event_entity/). The extracted train, dev, and test files are located in `data/ree/`.
The original data are transformed into our internal format using `convert_grit.py`:

```
python convert_grit.py --input_path data/grit_train.json --output_path data/ree_train.json
```
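
Both converters write a single JSON object mapping each document ID to its text and templates. A sketch of the internal format, abbreviated from the example in the `convert_tf.py` docstring (field values shortened for illustration):

```
{
  "DEV-MUC3-0001": {
    "document": "the arce battalion command has reported ...",
    "annotation": [
      {
        "incident_type": "kidnapping",
        "PerpInd": [["terrorists"]],
        "PerpOrg": [["farabundo marti national liberation front", "fmln"]]
      }
    ]
  }
}
```
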
## Usage

Template filling:

```
python train.py -c config/tf_generative_model.json
```

Role-filler entity extraction:

```
python train.py -c config/ree_generative_model.json
```

## Acknowledgement

We refer to the code of [TempGen](https://github.com/PlusLabNLP/TempGen); thanks for their contributions.

## Citation

config.py

Lines changed: 104 additions & 0 deletions
import copy
import json
import os
from constants import *

from transformers import AutoConfig


class Config(object):
    def __init__(self, **kwargs):
        self.coref = kwargs.pop('coref', False)
        # bert
        self.bert_model_name = kwargs.pop('bert_model_name', 'bert-large-cased')
        self.bert_cache_dir = kwargs.pop('bert_cache_dir', None)
        self.extra_bert = kwargs.pop('extra_bert', -1)
        self.use_extra_bert = kwargs.pop('use_extra_bert', False)
        # model
        self.bert_dropout = kwargs.pop('bert_dropout', .5)
        self.linear_dropout = kwargs.pop('linear_dropout', .4)
        self.linear_bias = kwargs.pop('linear_bias', True)
        self.linear_activation = kwargs.pop('linear_activation', 'relu')

        # decoding
        self.max_position_embeddings = kwargs.pop('max_position_embeddings', 2048)
        self.num_beams = kwargs.pop('num_beams', 4)
        self.decoding_method = kwargs.pop('decoding_method', "greedy")

        # files
        self.train_file = kwargs.pop('train_file', None)
        self.dev_file = kwargs.pop('dev_file', None)
        self.test_file = kwargs.pop('test_file', None)
        self.valid_pattern_path = kwargs.pop('valid_pattern_path', None)
        self.log_path = kwargs.pop('log_path', './log')
        self.output_path = kwargs.pop('output_path', './output')
        self.grit_dev_file = kwargs.pop('grit_dev_file', None)
        self.grit_test_file = kwargs.pop('grit_test_file', None)

        # training
        self.accumulate_step = kwargs.pop('accumulate_step', 1)
        self.batch_size = kwargs.pop('batch_size', 10)
        self.eval_batch_size = kwargs.pop('eval_batch_size', 5)
        self.max_epoch = kwargs.pop('max_epoch', 50)
        self.max_length = kwargs.pop('max_length', 128)
        self.learning_rate = kwargs.pop('learning_rate', 1e-3)
        self.bert_learning_rate = kwargs.pop('bert_learning_rate', 1e-5)
        self.weight_decay = kwargs.pop('weight_decay', 0.001)
        self.bert_weight_decay = kwargs.pop('bert_weight_decay', 0.00001)
        self.warmup_epoch = kwargs.pop('warmup_epoch', 5)
        self.grad_clipping = kwargs.pop('grad_clipping', 5.0)
        self.SOT_weights = kwargs.pop('SOT_weights', 100)
        self.permute_slots = kwargs.pop('permute_slots', False)

        # others
        self.use_gpu = kwargs.pop('use_gpu', True)
        self.gpu_device = kwargs.pop('gpu_device', 0)
        self.seed = kwargs.pop('seed', 0)
        self.use_copy = kwargs.pop('use_copy', False)
        self.use_SAGCopy = kwargs.pop('use_SAGCopy', False)
        self.k = kwargs.pop('k', 12)

    @classmethod
    def from_dict(cls, dict_obj):
        """Creates a Config object from a dictionary.

        Args:
            dict_obj (Dict[str, Any]): a dict whose keys are attribute names
                and whose values are the corresponding attribute values.
        """
        config = cls()
        for k, v in dict_obj.items():
            setattr(config, k, v)
        return config

    @classmethod
    def from_json_file(cls, path):
        with open(path, 'r', encoding='utf-8') as r:
            return cls.from_dict(json.load(r))

    def to_dict(self):
        output = copy.deepcopy(self.__dict__)
        return output

    def save_config(self, path):
        """Save a configuration object to a file.

        :param path (str): path to the output file or its parent directory.
        """
        if os.path.isdir(path):
            path = os.path.join(path, 'config.json')
        print('Save config to {}'.format(path))
        with open(path, 'w', encoding='utf-8') as w:
            w.write(json.dumps(self.to_dict(), indent=2,
                               sort_keys=True))

    @property
    def bert_config(self):
        return AutoConfig.from_pretrained(self.bert_model_name,
                                          cache_dir=self.bert_cache_dir,
                                          max_position_embeddings=self.max_position_embeddings)
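
For orientation, a minimal sketch of how `Config` is meant to be used; the field values below are hypothetical, not the repository's shipped configs:

```
# Hypothetical usage sketch: build a Config from a dict, mirroring what
# Config.from_json_file does with files like config/tf_generative_model.json.
from config import Config

cfg = Config.from_dict({
    'bert_model_name': 'bert-large-cased',  # hypothetical value
    'train_file': 'data/tf_train.json',     # hypothetical path
    'batch_size': 10,
})
print(cfg.batch_size)  # -> 10
cfg.save_config('./')  # writes ./config.json, since the path is a directory
```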

constants.py

Lines changed: 22 additions & 0 deletions
SEP_T = '<SEP_T>'
SEP = '<SEP>'
END_OF_SEP = '</SEP>'
PERP_IND = '<PerpInd>'
END_OF_PERP_IND = '</PerpInd>'
PERP_ORG = '<PerpOrg>'
END_OF_PERP_ORG = '</PerpOrg>'
TARGET = '<Target>'
END_OF_TARGET = '</Target>'
VICTIM = '<Victim>'
END_OF_VICTIM = '</Victim>'
WEAPON = '<Weapon>'
END_OF_WEAPON = '</Weapon>'
AND = "[and]"
NO_ROLE = "[None]"

ROLES = [SEP_T, SEP, END_OF_SEP, PERP_IND, END_OF_PERP_IND, PERP_ORG, END_OF_PERP_ORG,
         TARGET, END_OF_TARGET, VICTIM, END_OF_VICTIM, WEAPON, END_OF_WEAPON, AND, NO_ROLE]

# these variables are for decoding
SLOT_NAME_TAG = 0
ENTITY_TAG = 1

ROLE_FILLER_ENTITY_EXTRACTION = 'ree'
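
These role markers are only useful if the tokenizer treats each one as a single token. A plausible sketch of registering them, assuming a Hugging Face tokenizer is used (the model name is an assumption, not something this file specifies):

```
# Sketch under assumptions: register the role markers as special tokens so
# the generator emits e.g. <PerpInd> as one unit rather than subword pieces.
from transformers import AutoTokenizer
from constants import ROLES

tokenizer = AutoTokenizer.from_pretrained('bert-large-cased')  # assumed model
tokenizer.add_special_tokens({'additional_special_tokens': ROLES})
# the seq2seq model would then need: model.resize_token_embeddings(len(tokenizer))
```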

convert_grit.py

Lines changed: 105 additions & 0 deletions
import argparse
import json
import nltk

# these are for splitting doctext into sentences
nltk.download('punkt')
sent_tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')


def process_entities(entities):
    '''
    Keep only the mention strings, dropping the character offsets:

    [
        [
            ['guerrillas', 37],
            ['guerrilla column', 349]
        ],
        [
            ['apple', 45]
        ],
        [
            ['banana', 60]
        ]
    ]
    -> [['guerrillas', 'guerrilla column'], ['apple'], ['banana']]
    '''
    res = []
    for entity in entities:
        # take only the string of each mention
        res.append([mention[0] for mention in entity])
    return res


def convert(doc, capitalize=False):
    '''
    doc: a dictionary that has the following format:

    {'docid': 'TST1-MUC3-0001',
     'doctext': 'the guatemala army denied today that guerrillas attacked the "santo tomas" presidential farm, located on the pacific side, where president cerezo has been staying since 2 february. a report published by the "cerigua" news agency -- mouthpiece of the guatemalan national revolutionary unity (urng) -- whose main offices are in mexico, says that a guerrilla column attacked the farm 2 days ago. however, armed forces spokesman colonel luis arturo isaacs said that the attack, which resulted in the death of a civilian who was passing by at the time of the skirmish, was not against the farm, and that president cerezo is safe and sound. he added that on 3 february president cerezo met with the diplomatic corps accredited in guatemala. the government also issued a communique describing the rebel report as "false and incorrect," and stressing that the president was never in danger. col isaacs said that the guerrillas attacked the "la eminencia" farm located near the "santo tomas" farm, where they burned the facilities and stole food. a military patrol clashed with a rebel column and inflicted three casualties, which were taken away by the guerrillas who fled to the mountains, isaacs noted. he also reported that guerrillas killed a peasant in the city of flores, in the northern el peten department, and burned a tank truck.',
     'extracts': {'PerpInd': [[['guerrillas', 37], ['guerrilla column', 349]]],
                  'PerpOrg': [[['guatemalan national revolutionary unity', 253],
                               ['urng', 294]]],
                  'Target': [[['"santo tomas" presidential farm', 61],
                              ['presidential farm', 75]],
                             [['farm', 88], ['"la eminencia" farm', 947]],
                             [['facilities', 1026]],
                             [['tank truck', 1341], ['truck', 1346]]],
                  'Victim': [[['cerezo', 139]]],
                  'Weapon': []}}

    capitalize: whether to capitalize doctext or not
    '''
    res = {
        'docid': doc['docid'],
        'document': doc['doctext'],  # the raw text document
        'annotation': []  # a list of templates; in role-filler entity extraction each document has at most one template
    }

    if capitalize:
        # split doctext into sentences and capitalize each one
        sentences = sent_tokenizer.tokenize(doc['doctext'])
        capitalized_doctext = ' '.join([sent.capitalize() for sent in sentences])
        res['document'] = capitalized_doctext

    # TODO: add "tags" in the document

    annotation = doc['extracts']
    for role, entities in annotation.items():
        # skip roles whose entity list is empty
        if entities:
            # make sure res['annotation'] has exactly one dictionary
            if len(res['annotation']) == 0:
                res['annotation'].append({})
            res['annotation'][0][role] = process_entities(entities)

    return res


if __name__ == '__main__':
    p = argparse.ArgumentParser("Convert GRIT input data into our format.")

    p.add_argument('--input_path', type=str, help="input file in GRIT format.")
    p.add_argument('--output_path', type=str, help="path to store the output json file.")
    p.add_argument('--capitalize', action="store_true", help="whether to capitalize the first char of each sentence")
    args = p.parse_args()

    with open(args.input_path, 'r') as f:
        grit_inputs = [json.loads(l) for l in f.readlines()]

    all_processed_doc = dict()

    # iterate through and process all GRIT documents
    for grit_doc in grit_inputs:
        processed = convert(grit_doc, args.capitalize)
        doc_id = processed.pop('docid')
        # keep only documents that have at least one annotated template
        if processed['annotation']:
            all_processed_doc[doc_id] = processed

    with open(args.output_path, 'w') as f:
        f.write(json.dumps(all_processed_doc))
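
A minimal sketch of what `convert` produces, run on a trimmed version of the docstring example (the shortened `doctext` is for illustration only):

```
# Hypothetical usage: convert one trimmed GRIT-style document.
from convert_grit import convert

doc = {
    'docid': 'TST1-MUC3-0001',
    'doctext': 'guerrillas attacked the "santo tomas" presidential farm ...',
    'extracts': {'PerpInd': [[['guerrillas', 37], ['guerrilla column', 349]]],
                 'Weapon': []},
}
res = convert(doc)
# empty roles like 'Weapon' are skipped and character offsets are dropped:
# res['annotation'] == [{'PerpInd': [['guerrillas', 'guerrilla column']]}]
```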

convert_tf.py

Lines changed: 109 additions & 0 deletions
import argparse
import json
import nltk

# these are for splitting doctext into sentences
nltk.download('punkt')
sent_tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')


def process_entities(entities):
    '''
    Keep only the mention strings, dropping the character offsets:

    [
        [
            ['terrorists', 102]
        ]
    ]
    -> [['terrorists']]

    [
        [
            ['farabundo marti national liberation front', 120],
            ['fmln', 163]
        ]
    ]
    -> [['farabundo marti national liberation front', 'fmln']]
    '''
    res = []
    for entity in entities:
        # take only the string of each mention
        res.append([mention[0] for mention in entity])
    return res


def convert(doc, capitalize=False):
    '''
    doc: a dictionary that has the following format:

    {'docid': 'DEV-MUC3-0001',
     'doctext': "the arce battalion command has reported that about 50 peasants of various ages have been kidnapped by terrorists of the farabundo marti national liberation front (fmln) in san miguel department. according to that garrison, the mass kidnapping took place on 30 december in san luis de la reina. the source added that the terrorists forced the individuals, who were taken to an unknown location, out of their residences, presumably to incorporate them against their will into clandestine groups. meanwhile, three subversives were killed and seven others were wounded during clashes yesterday in usulutan and morazan departments. the atonal battalion reported that one extremist was killed and five others were wounded during a clash yesterday afternoon near la esperanza farm, santa elena jurisdiction, usulutan department. it was also reported that a soldier was wounded and taken to the military hospital in this capital. the same military unit reported that there was another clash that resulted in one dead terrorist and the seizure of various kinds of war materiel near san rafael farm in the same town. in the country's eastern region, military detachment no.4 reported that a terrorist was killed and two others were wounded during a clash in la ranera stream, san carlos, morazan department. an m-16 rifle, cartridge clips, and ammunition were seized there. meanwhile, the 3d infantry brigade reported that ponce battalion units found the decomposed body of a subversive in la finca hill, san miguel. an m-16 rifle, five grenades, and material for the production of explosives were found in the same place. the brigade, which is headquartered in san miguel, added that the seizure was made yesterday morning. national guard units guarding the las canas bridge, which is on the northern trunk highway in apopa, this morning repelled a terrorist attack that resulted in no casualties. the armed clash involved mortar and rifle fire and lasted 30 minutes. members of that security group are combing the area to determine the final outcome of the fighting.",
     'templates': [{'incident_type': 'kidnapping',
                    'PerpInd': [[['terrorists', 102]]],
                    'PerpOrg': [[['farabundo marti national liberation front', 120], ['fmln', 163]]],
                    'Target': [], 'Victim': [], 'Weapon': []},
                   {'incident_type': 'attack',
                    'PerpInd': [[['terrorist', 102]]],
                    'PerpOrg': [],
                    'Target': [[['las canas bridge', 1774]]],
                    'Victim': [],
                    'Weapon': [[['rifle', 1322]], [['mortar', 1940]]]}]}

    capitalize: whether to capitalize doctext or not
    '''
    res = {
        'docid': doc['docid'],
        'document': doc['doctext'],  # the raw text document
        'annotation': []  # a list of templates; a TF document may contain several event templates
    }

    if capitalize:
        # split doctext into sentences and capitalize each one
        sentences = sent_tokenizer.tokenize(doc['doctext'])
        capitalized_doctext = ' '.join([sent.capitalize() for sent in sentences])
        res['document'] = capitalized_doctext

    # TODO: add "tags" in the document

    annotation = doc['templates']
    for template in annotation:
        template_dic = {}
        for role, entities in template.items():
            # skip roles whose entity list is empty
            if entities:
                # incident_type is a plain string, not an entity list
                if role == "incident_type":
                    template_dic[role] = entities
                else:
                    template_dic[role] = process_entities(entities)
        if template_dic['incident_type'] in ['kidnapping', 'attack', 'bombing', 'arson', 'robbery']:
            res['annotation'].append(template_dic)
    return res


if __name__ == '__main__':
    p = argparse.ArgumentParser("Convert GTT input data into our format.")

    p.add_argument('--input_path', default="./data/train.json", type=str, help="input file in GTT format.")
    p.add_argument('--output_path', default="./data/tf_train.json", type=str, help="path to store the output json file.")
    p.add_argument('--capitalize', action="store_true", help="whether to capitalize the first char of each sentence")
    args = p.parse_args()

    with open(args.input_path, 'r') as f:
        gtt_inputs = [json.loads(l) for l in f.readlines()]

    all_processed_doc = dict()

    # iterate through and process all GTT documents
    for gtt_doc in gtt_inputs:
        processed = convert(gtt_doc, args.capitalize)
        doc_id = processed.pop('docid')
        all_processed_doc[doc_id] = processed

    with open(args.output_path, 'w') as f:
        f.write(json.dumps(all_processed_doc))
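
A minimal sketch of the incident-type filter in `convert`: templates whose `incident_type` is not one of the five listed types are dropped. The toy document and its 'flood' type are invented for illustration:

```
# Hypothetical usage: the second template is discarded by the type filter.
from convert_tf import convert

doc = {
    'docid': 'TOY-0001',
    'doctext': 'terrorists attacked the bridge.',
    'templates': [
        {'incident_type': 'attack',
         'PerpInd': [[['terrorists', 0]]], 'PerpOrg': [],
         'Target': [[['bridge', 24]]], 'Victim': [], 'Weapon': []},
        {'incident_type': 'flood', 'PerpInd': [], 'PerpOrg': [],
         'Target': [], 'Victim': [], 'Weapon': []},
    ],
}
res = convert(doc)
# res['annotation'] == [{'incident_type': 'attack',
#                        'PerpInd': [['terrorists']],
#                        'Target': [['bridge']]}]
```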
