15 changes: 11 additions & 4 deletions README.md
@@ -25,7 +25,7 @@ The content of this page covers the following topics:
- The project is installed as follows:

```
-git clone https://github.com/samuelbroscheit/open_link_prediction_benchmark.git
+git clone https://github.com/samuelbroscheit/open_knowledge_graph_embeddings.git
cd open_knowledge_graph_embeddings
pip install -r requirements.txt
```
@@ -73,6 +73,15 @@ All top level options can also be set on the command line and override the yaml

The first time you run training on a dataset, some indexes are created and cached. For OLPBENCH this can take around 30 minutes and 10-20 GB of main memory. Once the cached files exist, startup takes under 1 minute.

For example, a working command to train is:

```
python scripts/train.py config/acl2020-openlink/wikiopenlink-thorough-complex-lstm.yaml
```
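As noted above, all top-level yaml options can also be set on the command line and override the yaml. A minimal sketch of that precedence (illustrative merge logic, not the repo's actual argument parser):

```python
# Command-line values win over yaml values for top-level options;
# flags left unset (None) fall back to the yaml.
def merge_config(yaml_config, cli_overrides):
    merged = dict(yaml_config)
    merged.update({k: v for k, v in cli_overrides.items() if v is not None})
    return merged

cfg = merge_config({"epochs": 100, "batch_size": 4096}, {"epochs": 15, "resume": None})
print(cfg)  # {'epochs': 15, 'batch_size': 4096}
```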

_--resume_ expects the path to a checkpoint file. During training, checkpoints of the current state and of the best model(s) w.r.t. the model selection metric are saved under _data/experiments_ by default. When resuming from a checkpoint, note that _epochs_ in the config is the total number of epochs: it must exceed the number of epochs the checkpoint was already trained for, and it is not the number of additional epochs. You can set _--epochs_ and _--resume_ on the command line or directly in the config file.
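Since _epochs_ counts total epochs rather than additional ones, the arithmetic when resuming looks like this (hypothetical numbers):

```python
# --epochs must be the total budget: a checkpoint already trained for
# 15 epochs that should run 10 more needs --epochs 25, not --epochs 10.
def epochs_flag_when_resuming(epochs_in_checkpoint, additional_epochs):
    return epochs_in_checkpoint + additional_epochs

print(epochs_flag_when_resuming(15, 10))  # 25
```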

The output log file is written to the model's subdirectory under _data/experiments_. Each _resume_ creates a new log file in the same directory.

##### Prepared configurations

@@ -98,11 +107,9 @@ An example standard KGE model for the Freebase FB15k-237 benchmark.
After training, run evaluation on the test data with:

```
-python scripts/train.py --resume data/experiments/.../checkpoint.pth.tar --evaluate True --evaluate_on_validation False
+python scripts/train.py config/acl2020-openlink/wikiopenlink-thorough-complex-lstm.yaml --resume data/experiments/.../checkpoint.pth.tar --evaluate True --evaluate_on_validation False
```

-_--resume_ expects the path to a checkpoint file. Checkpoints of the current state and also the best model(s) w.r.t. a model selection metric are saved during training.

_--evaluate_on_validation False_ sets the evaluation to run on the test data.


14 changes: 7 additions & 7 deletions config/acl2020-openlink/wikiopenlink-thorough-complex-lstm.yaml
@@ -19,14 +19,14 @@ evaluate: false
############### MODEL

# configure the model class
-model: LSTMComplexRelationModel
+model: LookupComplexRelationModel
# configure the model's arguments
model_config:
dropout: 0.1
-entity_slot_size: 512
+entity_slot_size: 384
init_std: 0.1
normalize: batchnorm
-relation_slot_size: 512
+relation_slot_size: 384
sparse: false

experiment_settings:
@@ -43,9 +43,9 @@ experiment_settings:
############### TRAINING

# max epochs to run
-epochs: 100
+epochs: 15
# batch size
-batch_size: 4096
+batch_size: 512
# label smoothing for BCE loss
bce_label_smoothing: 0.0
# learning rate scheduler config kwargs dict to tunnel through to pytorch;
@@ -84,7 +84,7 @@ model_select_metric:
patience_metric_change: 1.0e-05
patience_metric_max_treshold: null
patience_metric_min_treshold: null
-patience_epochs: 50
+patience_epochs: 5


############### DATASET
@@ -150,7 +150,7 @@ training_dataset_class: OneToNMentionRelationDataset
# training data settings
train_data_config:
input_file: train_data_thorough.txt
-batch_size: 4096 # if batch size undefined here then global batch size is used
+batch_size: 512 # if batch size undefined here then global batch size is used
use_batch_shared_entities: True
min_size_batch_labels: 4096
max_size_prefix_label: 64
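The _patience_ settings in the config above drive early stopping on the model-selection metric. A minimal sketch of patience-based stopping, assuming a higher-is-better metric (the trainer's exact logic may differ):

```python
def should_stop(metric_history, patience_epochs=5, min_change=1e-5):
    # stop when the metric has not improved by at least min_change
    # within the last patience_epochs epochs (higher is better)
    if len(metric_history) <= patience_epochs:
        return False
    best_before = max(metric_history[:-patience_epochs])
    recent_best = max(metric_history[-patience_epochs:])
    return recent_best - best_before < min_change

# the metric plateaued for the last five epochs -> stop
print(should_stop([0.10, 0.20, 0.21, 0.21, 0.21, 0.21, 0.21, 0.21]))  # True
```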
4 changes: 3 additions & 1 deletion requirements.txt
@@ -2,4 +2,6 @@ torch>=1.10.0
elasticsearch
avro
numpy
-pandas
+pandas==1.5.3
+tqdm
+pyyaml
15 changes: 15 additions & 0 deletions scripts/train.py
@@ -86,6 +86,13 @@ def main(args, hyper_setting='', time_stamp=datetime.now().strftime('%Y-%m-%d_%H
args["model_config"]['train_data'] = train_data.get_dataset_meta_dict()

model = getattr(Models, args["model"])(**args["model_config"])

# FP16: cast model weights to half precision, but keep every batch-norm
# variant in FP32 so its running statistics remain numerically stable
model.half()
for layer in model.modules():
    if isinstance(layer, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        layer.float()

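Half precision has a sharply reduced dynamic range and resolution, which is why the batch-norm layers above are kept in FP32. A torch-free illustration of FP16 rounding, using the standard library's IEEE 754 half-precision format:

```python
import struct

def to_fp16(x):
    # round-trip a Python float through IEEE 754 half precision ('e' format)
    return struct.unpack('e', struct.pack('e', x))[0]

print(to_fp16(0.1))     # 0.0999755859375: only ~3 decimal digits survive
print(to_fp16(2049.0))  # 2048.0: integers above 2048 are no longer exact
```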
logging.info(model)

# define data loaders
@@ -96,12 +103,20 @@ def main(args, hyper_setting='', time_stamp=datetime.now().strftime('%Y-%m-%d_%H
drop_last=True,
)

# FP16: input tensors must be cast with .half() per batch inside the
# training loop; a standalone conversion pass over train_loader here
# would exhaust the loader and discard the converted tensors

val_loader = evaluation_data.get_loader(
shuffle=False,
num_workers=0,
drop_last=False,
)

# FP16: likewise, validation inputs must be cast per batch during
# evaluation rather than in a separate pass over val_loader

# create trainer

trainer = Trainer(