Adjust the code for training a spaCy model to spaCy version 3.0
The old code no longer ran with the new spaCy version. There must still be a problem with the new code, however, since the training loss is always 0; I have not yet been able to figure out why. I followed the instructions in the official spaCy entity linking tutorial here:
https://github.com/explosion/projects/blob/v3/tutorials/nel_emerson/notebooks/notebook_video.ipynb

Relevant to #14
flackbash committed Nov 29, 2023
1 parent f9a2531 commit f4c98d6
Showing 8 changed files with 72 additions and 130 deletions.
6 changes: 3 additions & 3 deletions create_entity_word_vectors.py
@@ -16,7 +16,7 @@ def preprocess_description(description: str) -> str:


def save_vectors(vectors, file_no):
path = settings.VECTORS_DIRECTORY + "%i.pkl" % file_no
path = settings.VECTORS_DIRECTORY + "vectors%i.pkl" % file_no
logger.info("Saving vectors to %s" % path)
with open(path, "wb") as f:
pickle.dump(vectors, f)
@@ -62,9 +62,9 @@ def main(args):
parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
description=__doc__)

parser.add_argument("min_score", type=int, required=True,
parser.add_argument("min_score", type=int,
help="Minimum score.")
parser.add_argument("start_line", type=int, default=0,
parser.add_argument("--start_line", type=int, default=0,
help="Start line.")

logger = log.setup_logger(sys.argv[0])
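The argument changes above follow from how argparse treats positionals: a positional argument is required by definition, so passing required=True to it raises a TypeError, and a default value is only picked up automatically for optional "--" flags (or positionals declared with nargs="?"). A minimal sketch of the corrected pattern; the argument names mirror the diff, the parsed value is invented:

import argparse

parser = argparse.ArgumentParser(description="Sketch of the corrected argument setup.")
# Positional arguments are always required; required=True is not a valid keyword here.
parser.add_argument("min_score", type=int, help="Minimum score.")
# Optional flag with a default, as in the new --start_line argument.
parser.add_argument("--start_line", type=int, default=0, help="Start line.")

args = parser.parse_args(["3"])
print(args.min_score, args.start_line)  # prints: 3 0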
4 changes: 2 additions & 2 deletions create_knowledge_base_wikipedia.py
@@ -20,9 +20,9 @@ def main():
logger.info("Creating directory %s ..." % save_path)
os.mkdir(save_path)

kb_path = save_path + "kb"
kb_path = save_path
logger.info("Writing knowledge base to %s ..." % kb_path)
kb.dump(kb_path)
kb.to_disk(kb_path)
vocab_file = save_path + "vocab"
logger.info("Writing vocab to %s ..." % vocab_file)
kb.vocab.to_disk(vocab_file)
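In spaCy 3.x the knowledge base is saved with to_disk() and restored with from_disk() instead of the old dump() / load_bulk() pair. A minimal spaCy 3.0-style round-trip sketch under assumed values; the entity ID, vector length and target path are placeholders, not this repository's data:

import spacy
from spacy.kb import KnowledgeBase

nlp = spacy.blank("en")
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=3)
kb.add_entity(entity="Q42", freq=100, entity_vector=[0.1, 0.2, 0.3])
kb.add_alias(alias="Adams", entities=["Q42"], probabilities=[1.0])

kb.to_disk("kb_dir")  # spaCy 3.x replacement for kb.dump(path)

kb_restored = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=3)
kb_restored.from_disk("kb_dir")  # spaCy 3.x replacement for kb.load_bulk(path)
print(kb_restored.get_size_entities(), kb_restored.get_size_aliases())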
32 changes: 7 additions & 25 deletions src/helpers/entity_linker_loader.py
@@ -16,33 +16,15 @@ class EntityLinkerLoader:
@staticmethod
def load_trained_linker(name: str, kb_name: Optional[str] = None):
logger.info("Loading linker model...")

model = spacy.load("en_core_web_lg")

# Load the previously saved entity linker
path = settings.SPACY_MODEL_DIRECTORY + name
with open(path, "rb") as f:
model_bytes = f.read()
model = spacy.blank("en")
model: Language
pipeline = ['tagger', 'parser', 'ner', 'entity_linker']
for pipe_name in pipeline:
pipe = model.create_pipe(pipe_name)
model.add_pipe(pipe)
model.from_bytes(model_bytes)
logger.info("-> Linker model loaded.")
model.add_pipe("entity_linker")
model.get_pipe("entity_linker").from_disk(path)

logger.info("Loading knowledge base...")
if kb_name is None:
vocab_path = settings.VOCAB_DIRECTORY
kb_path = settings.KB_FILE
else:
load_path = settings.KB_DIRECTORY + kb_name + "/"
vocab_path = load_path + "vocab"
kb_path = load_path + "kb"
vocab = Vocab().from_disk(vocab_path)
kb = KnowledgeBase(vocab=vocab, entity_vector_length=vocab.vectors.shape[1])
kb.load_bulk(kb_path)
model.get_pipe("entity_linker").set_kb(kb)
logger.info("-> Knowledge base loaded.")

model.disable_pipes(["tagger"])
logger.info("-> Linker model loaded.")

return model

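The loader now follows the spaCy 3.0 component API: the trained entity_linker (including the knowledge base it was trained with) is restored by adding a fresh pipe and deserializing it with the component's from_disk(). A hedged sketch of that pattern with a placeholder path, not this repository's layout; depending on the spaCy version, the pipe may need to be created with the same config it was trained with:

import spacy

nlp = spacy.load("en_core_web_lg")
# Add an empty entity_linker pipe, then load the trained weights and KB into it.
entity_linker = nlp.add_pipe("entity_linker", last=True)
entity_linker.from_disk("path/to/saved/entity_linker")  # placeholder path

doc = nlp("Douglas Adams wrote The Hitchhiker's Guide to the Galaxy.")
for ent in doc.ents:
    print(ent.text, ent.kb_id_)  # kb_id_ holds the linked entity ID, "" if unlinked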
4 changes: 2 additions & 2 deletions src/helpers/knowledge_base_creator.py
@@ -26,14 +26,14 @@ def create_kb_wikipedia() -> KnowledgeBase:
logger.info("-> Spacy model loaded.")

logger.info("Load vectors...")
vector_dir = settings.VECTORS_ABSTRACTS_DIRECTORY
vector_dir = settings.VECTORS_DIRECTORY
for entity_id, vector in VectorLoader.iterate(vector_dir):
frequency = entity_db.get_entity_frequency(entity_id)
kb.add_entity(entity=entity_id, freq=frequency, entity_vector=vector)
logger.info("-> Vectors loaded. Knowledge base contains %d entities." % len(kb))

logger.info("Adding aliases...")
for alias in entity_db.aliases:
for alias in entity_db.link_aliases:
if len(alias) > 0:
alias_entity_ids = [entity_id for entity_id in entity_db.get_candidates(alias)
if kb.contains_entity(entity_id)]
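create_kb_wikipedia() fills the knowledge base with one vector-backed entry per entity and one alias entry per link text. A small illustration of those two calls with invented IDs, frequencies and priors; note that the prior probabilities for one alias must not sum to more than 1.0:

from spacy.kb import KnowledgeBase
from spacy.vocab import Vocab

kb = KnowledgeBase(vocab=Vocab(), entity_vector_length=2)
kb.add_entity(entity="Q42", freq=100, entity_vector=[0.5, 0.5])
kb.add_entity(entity="Q5", freq=1000, entity_vector=[0.1, 0.9])

# One alias can point to several candidate entities with prior probabilities.
kb.add_alias(alias="Adams", entities=["Q42", "Q5"], probabilities=[0.8, 0.2])
print(kb.get_size_entities(), kb.get_size_aliases())  # prints: 2 1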
31 changes: 18 additions & 13 deletions src/helpers/label_generator.py
@@ -1,28 +1,30 @@
from typing import Iterator, Tuple, Dict, Optional

import spacy
from spacy.training import Example
from spacy.kb import KnowledgeBase
from spacy.language import Language, Doc
from spacy.language import Doc
from spacy.pipeline import Sentencizer

from src import settings
from src.helpers.wikipedia_corpus import WikipediaCorpus


class LabelGenerator:
DISABLE_TRAINING = ["entity_linker", "tagger"]

def __init__(self,
model: Language,
kb: KnowledgeBase,
mapping: Dict[str, str]):
def __init__(self, kb: KnowledgeBase, mapping: Dict[str, str]):
"""
Provide data for the generation of entity linker training examples.
:param model: The language pipeline to be used to generate spacy documents. Used for entity recognition.
:param kb: The knowledge base to link to. Candidate and ground truth entities come from the knowledge base.
:param mapping: Mapping between link targets and entity IDs.
"""
self.model = model
self.model = spacy.load(settings.LARGE_MODEL_NAME)
self.model.add_pipe("sentencizer")
self.kb = kb
self.mapping = mapping
self.sentencizer = Sentencizer()

def read_examples(self,
n: int = -1,
@@ -44,7 +46,7 @@ def read_examples(self,
:param test: articles are read from the development instead of the training set
:return: iterator over training examples
"""
disable_training = [pipe for pipe in LabelGenerator.DISABLE_TRAINING if pipe in self.model.pipeline]
# disable_training = [pipe for pipe in LabelGenerator.DISABLE_TRAINING if pipe in self.model.pipeline]
iterator = WikipediaCorpus.training_articles(n) if not test else WikipediaCorpus.development_articles(n)
for article in iterator:
link_dict = {}
@@ -56,20 +58,23 @@ def read_examples(self,
if self.kb.contains_entity(entity_id):
begin, end = span
snippet = article.text[begin:end]
candidate_entities = [candidate.entity_ for candidate in self.kb.get_candidates(snippet)]
candidate_entities = [candidate.entity_ for candidate in self.kb.get_alias_candidates(snippet)]
# check if the ground truth entity is in the candidates:
if entity_id in candidate_entities:
# generate ground truth labels for all candidates:
link_dict[span] = {candidate_id: 1.0 if candidate_id == entity_id else 0.0
for candidate_id in candidate_entities}
# skip if no link could be mapped:
if len(link_dict) > 0:
with self.model.disable_pipes(disable_training):
doc = self.model(article.text)
doc = self.model(article.text)
ner_dict = {(ent.start_char, ent.end_char): ent.label for ent in doc.ents}
# filter entities not recognized by the NER:
doc_entity_spans = {(e.start_char, e.end_char) for e in doc.ents}
link_dict = {span: link_dict[span] for span in link_dict if span in doc_entity_spans}
entities = [(span[0], span[1], ner_dict.get(span, "UNK")) for span in link_dict]
# skip if no entities remain after filtering:
if len(link_dict) > 0:
labels = {"links": link_dict}
yield doc, labels
annotation = {"links": link_dict, "entities": entities}
example = Example.from_dict(self.model.make_doc(article.text), annotation)
example.reference = self.sentencizer(example.reference)
yield example
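read_examples() now yields spacy.training.Example objects instead of (doc, labels) tuples: the gold annotation carries both the NER spans and the link dictionary, and the reference doc needs sentence boundaries because the entity linker works per sentence. A self-contained sketch with invented text, character offsets and QIDs:

import spacy
from spacy.training import Example
from spacy.pipeline import Sentencizer

nlp = spacy.blank("en")
text = "Douglas Adams wrote a novel."

# Gold labels: character span -> {entity ID: probability}, plus the NER spans.
links = {(0, 13): {"Q42": 1.0, "Q20": 0.0}}
entities = [(0, 13, "PERSON")]

example = Example.from_dict(nlp.make_doc(text), {"links": links, "entities": entities})
# The sentencizer sets sentence boundaries on the reference doc in place.
Sentencizer()(example.reference)
print([(ent.text, ent.label_) for ent in example.reference.ents])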
2 changes: 1 addition & 1 deletion src/linkers/spacy_linker.py
@@ -52,7 +52,7 @@ def predict(self,
return predictions

def get_candidates(self, snippet: str) -> Set[str]:
return {candidate.entity_ for candidate in self.kb.get_candidates(snippet)}
return {candidate.entity_ for candidate in self.kb.get_alias_candidates(snippet)}

def contains_entity(self, entity_id: str) -> bool:
return entity_id in self.known_entities
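In spaCy 3.x the string-based candidate lookup is named get_alias_candidates(), while get_candidates() operates on a Span. A tiny sketch of the call and the Candidate fields used here, with a throwaway KB and invented values:

from spacy.kb import KnowledgeBase
from spacy.vocab import Vocab

kb = KnowledgeBase(vocab=Vocab(), entity_vector_length=1)
kb.add_entity(entity="Q42", freq=10, entity_vector=[0.3])
kb.add_alias(alias="Adams", entities=["Q42"], probabilities=[1.0])

for candidate in kb.get_alias_candidates("Adams"):
    print(candidate.entity_, candidate.prior_prob)  # prints: Q42 1.0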
2 changes: 1 addition & 1 deletion src/settings.py
@@ -79,7 +79,7 @@
KB_FILE = DATA_DIRECTORY + "linker_files/spacy/knowledge_bases/wikidata/kb"
KB_DIRECTORY = DATA_DIRECTORY + "linker_files/spacy/knowledge_bases/"
VOCAB_DIRECTORY = DATA_DIRECTORY + "linker_files/spacy/knowledge_bases/wikidata/vocab"
VECTORS_DIRECTORY = DATA_DIRECTORY + "linker_files/spacy/knowledge_bases/vectors/vectors"
VECTORS_DIRECTORY = DATA_DIRECTORY + "linker_files/spacy/knowledge_bases/vectors/"
VECTORS_ABSTRACTS_DIRECTORY = DATA_DIRECTORY + "linker_files/spacy/knowledge_bases/vectors/vectors_abstracts/"

# Linker files
121 changes: 38 additions & 83 deletions train_spacy_entity_linker.py
@@ -1,111 +1,66 @@
import argparse
import log
import sys
import spacy
from spacy.kb import KnowledgeBase
from spacy.language import Language

from src import settings
from src.helpers.entity_database_reader import EntityDatabaseReader
from src.helpers.label_generator import LabelGenerator


def save_model(model: Language, model_name: str):
path = settings.SPACY_MODEL_DIRECTORY + model_name
model_bytes = model.to_bytes()
with open(path, "wb") as f:
f.write(model_bytes)
logger.info("Saved model to %s" % path)


PRINT_EVERY = 1
SAVE_EVERY = 10000


def main(args):
# make pipeline:
if args.kb_name == "0":
vocab_path = settings.VOCAB_DIRECTORY
kb_path = settings.KB_FILE
else:
load_path = settings.KB_DIRECTORY + args.kb_name + "/"
vocab_path = load_path + "vocab"
kb_path = load_path + "kb"
def train():
load_path = settings.KB_DIRECTORY + "wikipedia/"
vocab_path = load_path + "vocab"
kb_path = load_path

logger.info("Loading model ...")
nlp = spacy.load(settings.LARGE_MODEL_NAME)
nlp.vocab.from_disk(vocab_path)
# kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=300)
# kb.from_disk(kb_path)
# entity_linker = nlp.create_pipe("entity_linker", {"incl_prior": args.prior})

def create_kb(vocab):
kb = KnowledgeBase(vocab=vocab, entity_vector_length=300)
kb.from_disk(kb_path)
logger.info("Knowledge base contains %d entities." % kb.get_size_entities())
logger.info("Knowledge base contains %d aliases." % kb.get_size_aliases())
return kb

# create entity linker with the knowledge base and add it to the pipeline:
entity_linker = nlp.add_pipe("entity_linker", config={"incl_prior": False}, last=True)
logger.info("Loading knowledge base ...")
entity_linker = nlp.create_pipe("entity_linker", {"incl_prior": args.prior})
kb = KnowledgeBase(vocab=nlp.vocab)
kb.load_bulk(kb_path)
logger.info("Knowledge base contains %d entities." % kb.get_size_entities())
logger.info("Knowledge base contains %d aliases." % kb.get_size_aliases())
entity_linker.set_kb(kb)
nlp.add_pipe(entity_linker, last=True)

pipe_exceptions = ["entity_linker", "trf_wordpiecer", "trf_tok2vec"]
other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions]

# initialize model:
optimizer = nlp.begin_training()

# initialize label generator:
entity_linker.set_kb(create_kb)
kb = nlp.get_pipe("entity_linker").kb
mapping = EntityDatabaseReader.get_wikipedia_to_wikidata_db()
generator = LabelGenerator(nlp, kb, mapping)

# iterate over training examples (batch size 1):
logger.info("Training ...")
n_batches = 0
n_articles = 0
n_entities = 0
loss_sum = 0
if args.n_batches != 0:
for doc, labels in generator.read_examples():
batch_docs = [doc]
batch_labels = [labels]
generator = LabelGenerator(kb, mapping)
example = next(generator.read_examples())
entity_linker.initialize(get_examples=lambda: [example])
print(f"hyperparameters: {entity_linker.cfg}")

from spacy.util import minibatch, compounding

with nlp.select_pipes(enable=["entity_linker"]): # train only the entity_linker
optimizer = nlp.resume_training()
print(nlp.get_pipe("entity_linker").cfg)
for itn in range(1):
batches = minibatch(generator.read_examples(n=100), size=1)
losses = {}
with nlp.disable_pipes(*other_pipes):
for batch in batches:
nlp.update(
batch_docs,
batch_labels,
batch,
drop=0.2, # prevent overfitting
losses=losses,
sgd=optimizer,
losses=losses
)
n_batches += 1
n_articles += len(batch_docs)
n_entities += len(labels["links"])
loss = losses["entity_linker"]
loss_sum += loss
if n_batches % PRINT_EVERY == 0:
loss_mean = loss_sum / n_batches
print("\r%i batches\t%i articles\t%i entities\tloss: %f\tmean: %f" %
(n_batches, n_articles, n_entities, loss, loss_mean), end='')
if n_batches == args.n_batches:
break
elif n_batches % SAVE_EVERY == 0:
print()
save_model(nlp, args.name)
print()
save_model(nlp, args.name)
if itn % 10 == 0:
print(itn, "Losses", losses) # print the training loss
print(itn, "Losses", losses)
entity_linker.to_disk(settings.SPACY_MODEL_DIRECTORY + "spacy_batch1_model")


if __name__ == "__main__":
parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
description=__doc__)

parser.add_argument("name", type=str,
help="Linker name.")
parser.add_argument("n_batches", type=int,
help="Number of batches.")
parser.add_argument("kb_name", type=str,
help="KB name.")
parser.add_argument("-p", "--prior", type=str, action="store_true",
help="Use prior probabilities.")

logger = log.setup_logger(sys.argv[0])
logger.debug(' '.join(sys.argv))

main(parser.parse_args())
train()
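For reference, a hedged sketch of the spaCy 3.0 training pattern this script now follows, adapted from the official NEL tutorial: the knowledge base is attached through a creation callback, the component is initialized from sample examples, and only the entity_linker is updated. The paths, the vector length and the examples list are placeholders that would have to be filled in with this repository's data before the sketch can run:

import spacy
from spacy.kb import KnowledgeBase
from spacy.util import minibatch

nlp = spacy.load("en_core_web_lg")

def create_kb(vocab):
    # The entity_linker expects a callback that builds the KB from a Vocab.
    kb = KnowledgeBase(vocab=vocab, entity_vector_length=300)
    kb.from_disk("path/to/kb")  # placeholder path
    return kb

entity_linker = nlp.add_pipe("entity_linker", config={"incl_prior": False}, last=True)
entity_linker.set_kb(create_kb)

train_examples = []  # fill with spacy.training.Example objects, e.g. from LabelGenerator
entity_linker.initialize(get_examples=lambda: train_examples)

with nlp.select_pipes(enable=["entity_linker"]):  # update only the entity_linker
    optimizer = nlp.resume_training()
    for itn in range(10):
        losses = {}
        for batch in minibatch(train_examples, size=2):
            nlp.update(batch, drop=0.2, sgd=optimizer, losses=losses)
        print(itn, "Losses", losses)

entity_linker.to_disk("path/to/output_model")  # placeholder path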
