
Commit cb723d0

initial commit
1 parent f49f0a0 commit cb723d0

31 files changed (+2386 −1 lines)

.gitignore

+145
@@ -0,0 +1,145 @@
data/
saved_models/
bert_ckpts/
saved_reranking_models/
pretrained_models/
pretrained_reranking_models/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

CONSTANTS.py

+10
@@ -0,0 +1,10 @@
SNOMED_CORE_DIR = 'data/UMLS/SNOMEDCT_CORE_SUBSET_202008'
DATA_DIR = {'SNOMED_CT_CORE': 'data/SNOMED-CT-Core/',
            'FB15K_237': 'data/FB15K-237/',
            'FB15K_237_SPARSE': 'data/FB15K-237-Sparse/',
            'CN100K': 'data/CN100K/'}
COLUMN_NAMES = {'SNOMED_CT_CORE': ('CUI2_id', 'RELA_id', 'CUI1_id'),
                'FB15K_237': ('entity1_id', 'rel_id', 'entity2_id'),
                'FB15K_237_SPARSE': ('entity1_id', 'rel_id', 'entity2_id'),
                'CN100K': ('entity1_id', 'rel_id', 'entity2_id')}
UMLS_SOURCE_DIR = 'data/UMLS/2020AA/META'
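
These constants are the lookup tables used throughout the repository: `DATA_DIR` maps a dataset key to its directory, and `COLUMN_NAMES` gives the (e1, rel, e2) column order of that dataset's CSV files. A minimal sketch of how `dataset.py` below resolves them (the `FB15K_237` key is just an example):

```python
# Sketch: resolving a dataset key to its train CSV and column order,
# mirroring the lookups performed in dataset.py below.
import os
from CONSTANTS import DATA_DIR, COLUMN_NAMES

dataset = 'FB15K_237'
csv_file = os.path.join(DATA_DIR[dataset], 'df_train.csv')  # 'data/FB15K-237/df_train.csv'
e1_col, rel_col, e2_col = COLUMN_NAMES[dataset]             # ('entity1_id', 'rel_id', 'entity2_id')
```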

README.md

+48 −1
@@ -1 +1,48 @@
-# robust-kg-completion
# Robust Knowledge Graph Completion with Stacked Convolutions and a Student Re-Ranking Network

This repository contains the implementation for our paper:

**Robust Knowledge Graph Completion with Stacked Convolutions and a Student Re-Ranking Network** \
Justin Lovelace, Denis Newman-Griffis, Shikhar Vashishth, Jill Fain Lehman, and Carolyn Penstein Rosé \
Annual Meeting of the Association for Computational Linguistics and the International Joint Conference on Natural Language Processing (**ACL-IJCNLP**) 2021

## Dependencies

Our work was performed with Python 3.8. The dependencies can be installed from `requirements.txt`.

## Data Preparation

- We build upon the existing FB15K-237 and CN100K datasets, and we additionally developed the FB15K-237-Sparse and SNOMED-CT Core datasets for this work.
- Running `./scripts/prepare_datasets.sh` unzips the dataset files and processes them for use by our models.
- Because the SNOMED-CT Core dataset was derived from the UMLS, we cannot directly release the dataset files. See [here](snomed_ct_core.md) for full instructions on how to recreate the dataset.
- The BERT embeddings can be downloaded from [here](https://drive.google.com/drive/folders/1gfbZcJoay69BUzQLQku-qHB6IZ5zxLQw?usp=sharing). The `bert_emb.pt` files should be stored in the corresponding dataset directories, e.g. `data/CN100K/bert_emb.pt`; a quick way to check that a file is in place is sketched below.
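
A minimal sanity check for the downloaded embeddings, assuming each `bert_emb.pt` holds a tensor saved with `torch.save` (the dataset path here is just one example):

```python
# Sketch: confirm a downloaded bert_emb.pt is in place and loadable.
# Assumption: the file holds a tensor with one row per entity.
import torch

emb = torch.load('data/CN100K/bert_emb.pt', map_location='cpu')
print(type(emb), getattr(emb, 'shape', None))
```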

## Training Ranking Models

We provide scripts to train our proposed ranking model, denoted BERT-ResNet in our paper, for all four datasets.

- FB15K-237: `./scripts/train_resnet_fb15k237.sh`
- FB15K-237-Sparse: `./scripts/train_resnet_fb15k237_sparse.sh`
- CN100K: `./scripts/train_resnet_cn100k.sh`
- SNOMED-CT Core: `./scripts/train_resnet_snomed.sh`

## Training Re-Ranking Models

A re-ranking model can only be trained after the ranking model for the corresponding dataset has finished training. First, download the BERT checkpoints used for our training from [here](https://drive.google.com/drive/folders/1BsxeWEtFpZuHD_bCQKsIy0zfwl7xi6Bq?usp=sharing), and unzip them into `reranking/bert_ckpts`. A re-ranking model can then be trained with the provided scripts, as above.

- FB15K-237: `./scripts/train_reranking_fb15k237.sh`
- FB15K-237-Sparse: `./scripts/train_reranking_fb15k237_sparse.sh`
- CN100K: `./scripts/train_reranking_cn100k.sh`
- SNOMED-CT Core: `./scripts/train_reranking_snomed.sh`

## Evaluating Pretrained Ranking Models

Pretrained ranking models can be downloaded from [here](https://drive.google.com/drive/folders/1q20hhUq20wt5OSbHbOWsvAiviFUZ8s8r?usp=sharing). After unzipping them into a `robust-kg-completion/pretrained_models` directory, they can be evaluated by running `./scripts/eval_pretrained_ranking_model.sh {DATASET}`, where `{DATASET}` is one of `SNOMED_CT_CORE`, `FB15K_237`, `FB15K_237_SPARSE`, or `CN100K`.

## Evaluating Pretrained Re-Ranking Models

Pretrained re-ranking models can be downloaded from [here](https://drive.google.com/drive/folders/1q20hhUq20wt5OSbHbOWsvAiviFUZ8s8r?usp=sharing). After unzipping them into a `robust-kg-completion/reranking/pretrained_reranking_models` directory, they can be evaluated with the following commands.

- FB15K-237: `./scripts/eval_pretrained_reranking_model.sh FB15K_237 0.75`
- FB15K-237-Sparse: `./scripts/eval_pretrained_reranking_model.sh FB15K_237_SPARSE 0.75`
- CN100K: `./scripts/eval_pretrained_reranking_model.sh CN100K 1.0`
- SNOMED-CT Core: `./scripts/eval_pretrained_reranking_model.sh SNOMED_CT_CORE 0.5`

dataset.py

+96
@@ -0,0 +1,96 @@
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import argparse
import os
from collections import defaultdict
import sys
from CONSTANTS import DATA_DIR, COLUMN_NAMES


class KbDataset(Dataset):
    def __init__(self, args: argparse.Namespace):
        self.args = args
        csv_file = os.path.join(DATA_DIR[args.dataset], 'df_train.csv')
        e1_col, rel_col, e2_col = COLUMN_NAMES[args.dataset]
        print(f'Loading dataset from {csv_file}')
        df_kg = pd.read_csv(csv_file)
        self.pos_samples = df_kg[[e1_col, rel_col, e2_col]].to_numpy(np.int64)
        if args.strategy == 'one_to_n':
            self.gen_one_to_n_data(df_kg)
        elif args.strategy == 'k_to_n' or args.strategy == 'gen_triplets':
            self.gen_k_to_n_data(df_kg)
        elif args.strategy == 'softmax':
            self.gen_softmax_data(df_kg)
        else:
            raise NotImplementedError

    def gen_one_to_n_data(self, df_kg):
        # One example per training triple; the multi-hot label vector marks
        # only the e2 from that triple.
        e1_col, rel_col, e2_col = COLUMN_NAMES[self.args.dataset]
        self.data = df_kg[[e1_col, rel_col]].to_numpy(np.int64)
        self.labels = np.zeros(
            (len(self.data), self.args.num_entities), dtype=np.float32)
        self.labels[np.arange(self.labels.shape[0]),
                    df_kg[e2_col].to_numpy(np.int64)] = 1

    def gen_softmax_data(self, df_kg):
        # Single-label targets: each (e1, rel) query is paired with the index
        # of its gold e2, for use with a softmax/cross-entropy loss.
        e1_col, rel_col, e2_col = COLUMN_NAMES[self.args.dataset]
        self.data = df_kg[[e1_col, rel_col]].to_numpy(np.int64)
        self.labels = df_kg[e2_col].to_numpy(np.int64)

    def gen_k_to_n_data(self, df_kg):
        # One example per unique (e1, rel) query; the multi-hot label vector
        # marks every e2 observed with that query in the training graph.
        e1_col, rel_col, e2_col = COLUMN_NAMES[self.args.dataset]
        self.data = df_kg[[e1_col, rel_col]].drop_duplicates().to_numpy(np.int64)
        e2_lookup = defaultdict(set)
        for e1, r, e2 in zip(df_kg[e1_col], df_kg[rel_col], df_kg[e2_col]):
            e2_lookup[(e1, r)].add(e2)
        self.labels = np.zeros(
            (len(self.data), self.args.num_entities), dtype=np.float32)
        for idx, query in enumerate(self.data):
            e1, r = query[0], query[1]
            for e2 in e2_lookup[(e1, r)]:
                self.labels[idx, e2] = 1

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


class KbEvalGenerator(Dataset):
    def __init__(self, eval_split: str, args: argparse.Namespace):
        self.args = args
        pos_samples = defaultdict(set)
        splits = ['train', 'valid', 'test']
        e1_col, rel_col, e2_col = COLUMN_NAMES[args.dataset]
        assert eval_split in splits
        # Collect every known-true (e1, r) -> e2 across all splits so that
        # other correct answers can be filtered out during ranking.
        for spl in splits:
            csv_file = os.path.join(DATA_DIR[args.dataset], f'df_{spl}.csv')
            df_data = pd.read_csv(csv_file)
            if spl == eval_split:
                self.queries = df_data[[e1_col, rel_col]].to_numpy(np.int64)
                e2_list = df_data[e2_col].tolist()
            for e1, r, e2 in zip(df_data[e1_col], df_data[rel_col], df_data[e2_col]):
                pos_samples[(e1, r)].add(e2)
        # labels marks the gold e2 for each query; filtered_labels marks the
        # remaining known-true answers to mask before computing metrics.
        self.labels = np.zeros((self.queries.shape[0], self.args.num_entities))
        self.filtered_labels = np.zeros(
            (self.queries.shape[0], self.args.num_entities))
        for i, query in enumerate(self.queries):
            e1, r = query[0], query[1]
            e2 = e2_list[i]
            self.labels[i, e2] = 1
            self.filtered_labels[i, list(pos_samples[(e1, r)] - {e2})] = 1

    def __len__(self):
        return self.queries.shape[0]

    def __getitem__(self, idx):
        return self.queries[idx], self.labels[idx], self.filtered_labels[idx]
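
A minimal usage sketch, not part of the commit: the constructors above only read `args.dataset`, `args.strategy`, and `args.num_entities`, so a bare `argparse.Namespace` with those fields is enough to build loaders (the batch size and `num_entities` value here are illustrative).

```python
# Sketch (not from the commit): wiring KbDataset and KbEvalGenerator into
# DataLoaders. Values below are illustrative; num_entities must match the
# entity vocabulary of the chosen dataset.
import argparse
from torch.utils.data import DataLoader
from dataset import KbDataset, KbEvalGenerator

args = argparse.Namespace(dataset='FB15K_237', strategy='k_to_n',
                          num_entities=14541)
train_loader = DataLoader(KbDataset(args), batch_size=128, shuffle=True)
valid_loader = DataLoader(KbEvalGenerator('valid', args), batch_size=128)

for queries, labels in train_loader:
    # queries: (batch, 2) int64 holding (e1, rel); labels: (batch, num_entities)
    break
```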
