diff --git a/requirements.txt b/requirements.txt
index ca4099be5..b8970f686 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
# Our libraries
-spacy-legacy>=3.0.8,<3.1.0
+spacy-legacy>=3.0.9,<3.1.0
spacy-loggers>=1.0.0,<2.0.0
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
diff --git a/setup.cfg b/setup.cfg
index 586a044ff..ed3bf63ce 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -41,7 +41,7 @@ setup_requires =
thinc>=8.0.12,<8.1.0
install_requires =
# Our libraries
- spacy-legacy>=3.0.8,<3.1.0
+ spacy-legacy>=3.0.9,<3.1.0
spacy-loggers>=1.0.0,<2.0.0
murmurhash>=0.28.0,<1.1.0
cymem>=2.0.2,<2.1.0
diff --git a/spacy/cli/templates/quickstart_training.jinja b/spacy/cli/templates/quickstart_training.jinja
index fb79a4f60..da533b767 100644
--- a/spacy/cli/templates/quickstart_training.jinja
+++ b/spacy/cli/templates/quickstart_training.jinja
@@ -131,7 +131,7 @@ incl_context = true
incl_prior = true
[components.entity_linker.model]
-@architectures = "spacy.EntityLinker.v1"
+@architectures = "spacy.EntityLinker.v2"
nO = null
[components.entity_linker.model.tok2vec]
@@ -303,7 +303,7 @@ incl_context = true
incl_prior = true
[components.entity_linker.model]
-@architectures = "spacy.EntityLinker.v1"
+@architectures = "spacy.EntityLinker.v2"
nO = null
[components.entity_linker.model.tok2vec]
diff --git a/spacy/ml/extract_spans.py b/spacy/ml/extract_spans.py
index edc86ff9c..d5e9bc07c 100644
--- a/spacy/ml/extract_spans.py
+++ b/spacy/ml/extract_spans.py
@@ -63,4 +63,4 @@ def _get_span_indices(ops, spans: Ragged, lengths: Ints1d) -> Ints1d:
def _ensure_cpu(spans: Ragged, lengths: Ints1d) -> Tuple[Ragged, Ints1d]:
- return (Ragged(to_numpy(spans.dataXd), to_numpy(spans.lengths)), to_numpy(lengths))
+ return Ragged(to_numpy(spans.dataXd), to_numpy(spans.lengths)), to_numpy(lengths)
diff --git a/spacy/ml/models/entity_linker.py b/spacy/ml/models/entity_linker.py
index 831fee90f..0149bea89 100644
--- a/spacy/ml/models/entity_linker.py
+++ b/spacy/ml/models/entity_linker.py
@@ -1,34 +1,82 @@
from pathlib import Path
-from typing import Optional, Callable, Iterable, List
+from typing import Optional, Callable, Iterable, List, Tuple
from thinc.types import Floats2d
-from thinc.api import chain, clone, list2ragged, reduce_mean, residual
-from thinc.api import Model, Maxout, Linear
+from thinc.api import chain, list2ragged, reduce_mean, residual
+from thinc.api import Model, Maxout, Linear, tuplify, Ragged
from ...util import registry
from ...kb import KnowledgeBase, Candidate, get_candidates
from ...vocab import Vocab
from ...tokens import Span, Doc
+from ..extract_spans import extract_spans
+from ...errors import Errors
-@registry.architectures("spacy.EntityLinker.v1")
+@registry.architectures("spacy.EntityLinker.v2")
def build_nel_encoder(
tok2vec: Model, nO: Optional[int] = None
) -> Model[List[Doc], Floats2d]:
- with Model.define_operators({">>": chain, "**": clone}):
+ with Model.define_operators({">>": chain, "&": tuplify}):
token_width = tok2vec.maybe_get_dim("nO")
output_layer = Linear(nO=nO, nI=token_width)
model = (
- tok2vec
- >> list2ragged()
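+            # tuplify pairs the ragged token vectors with the spans from the
+            # span maker, so extract_spans can pull out each entity's
+            # sentence window for pooling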
+ ((tok2vec >> list2ragged()) & build_span_maker())
+ >> extract_spans()
>> reduce_mean()
>> residual(Maxout(nO=token_width, nI=token_width, nP=2, dropout=0.0)) # type: ignore[arg-type]
>> output_layer
)
model.set_ref("output_layer", output_layer)
model.set_ref("tok2vec", tok2vec)
+ # flag to show this isn't legacy
+ model.attrs["include_span_maker"] = True
return model
+def build_span_maker(n_sents: int = 0) -> Model:
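+    """Build a Model that produces the candidate spans for the entity linker.
+
+    n_sents (int): The number of neighbouring sentences to include on each
+        side of an entity's sentence. 0 keeps only the entity's own sentence.
+    """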
+ model: Model = Model("span_maker", forward=span_maker_forward)
+ model.attrs["n_sents"] = n_sents
+ return model
+
+
+def span_maker_forward(model, docs: List[Doc], is_train) -> Tuple[Ragged, Callable]:
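+    """Compute a Ragged array of (start, end) token offsets: one sentence
+    window per entity in each doc, widened by n_sents on either side."""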
+ ops = model.ops
+ n_sents = model.attrs["n_sents"]
+ candidates = []
+ for doc in docs:
+ cands = []
+ try:
+ sentences = [s for s in doc.sents]
+ except ValueError:
+ # no sentence info, normal in initialization
+ for tok in doc:
+ tok.is_sent_start = tok.i == 0
+ sentences = [doc[:]]
+ for ent in doc.ents:
+ try:
+ # find the sentence in the list of sentences.
+ sent_index = sentences.index(ent.sent)
+ except AttributeError:
+                # Catch the exception when ent.sent is None and provide a user-friendly error
+ raise RuntimeError(Errors.E030) from None
+ # get n previous sentences, if there are any
+ start_sentence = max(0, sent_index - n_sents)
+            # get n following sentences, or as many as are available
+ end_sentence = min(len(sentences) - 1, sent_index + n_sents)
+ # get token positions
+ start_token = sentences[start_sentence].start
+ end_token = sentences[end_sentence].end
+ # save positions for extraction
+ cands.append((start_token, end_token))
+
+ candidates.append(ops.asarray2i(cands))
+ candlens = ops.asarray1i([len(cands) for cands in candidates])
+ candidates = ops.xp.concatenate(candidates)
+ outputs = Ragged(candidates, candlens)
+ # because this is just rearranging docs, the backprop does nothing
+ return outputs, lambda x: []
+
+
@registry.misc("spacy.KBFromFile.v1")
def load_kb(kb_path: Path) -> Callable[[Vocab], KnowledgeBase]:
def kb_from_file(vocab):
diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py
index 1169e898d..89e7576bf 100644
--- a/spacy/pipeline/entity_linker.py
+++ b/spacy/pipeline/entity_linker.py
@@ -6,17 +6,17 @@ import srsly
import random
from thinc.api import CosineDistance, Model, Optimizer, Config
from thinc.api import set_dropout_rate
-import warnings
from ..kb import KnowledgeBase, Candidate
from ..ml import empty_kb
from ..tokens import Doc, Span
from .pipe import deserialize_config
+from .legacy.entity_linker import EntityLinker_v1
from .trainable_pipe import TrainablePipe
from ..language import Language
from ..vocab import Vocab
from ..training import Example, validate_examples, validate_get_examples
-from ..errors import Errors, Warnings
+from ..errors import Errors
from ..util import SimpleFrozenList, registry
from .. import util
from ..scorer import Scorer
@@ -26,7 +26,7 @@ BACKWARD_OVERWRITE = True
default_model_config = """
[model]
-@architectures = "spacy.EntityLinker.v1"
+@architectures = "spacy.EntityLinker.v2"
[model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v2"
@@ -55,6 +55,7 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
"get_candidates": {"@misc": "spacy.CandidateGenerator.v1"},
"overwrite": True,
"scorer": {"@scorers": "spacy.entity_linker_scorer.v1"},
+ "use_gold_ents": True,
},
default_score_weights={
"nel_micro_f": 1.0,
@@ -75,6 +76,7 @@ def make_entity_linker(
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
overwrite: bool,
scorer: Optional[Callable],
+ use_gold_ents: bool,
):
"""Construct an EntityLinker component.
@@ -90,6 +92,22 @@ def make_entity_linker(
produces a list of candidates, given a certain knowledge base and a textual mention.
scorer (Optional[Callable]): The scoring method.
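+    use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
+        component must provide entity annotations.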
"""
+
+ if not model.attrs.get("include_span_maker", False):
+ # The only difference in arguments here is that use_gold_ents is not available
+ return EntityLinker_v1(
+ nlp.vocab,
+ model,
+ name,
+ labels_discard=labels_discard,
+ n_sents=n_sents,
+ incl_prior=incl_prior,
+ incl_context=incl_context,
+ entity_vector_length=entity_vector_length,
+ get_candidates=get_candidates,
+ overwrite=overwrite,
+ scorer=scorer,
+ )
return EntityLinker(
nlp.vocab,
model,
@@ -102,6 +120,7 @@ def make_entity_linker(
get_candidates=get_candidates,
overwrite=overwrite,
scorer=scorer,
+ use_gold_ents=use_gold_ents,
)
@@ -136,6 +155,7 @@ class EntityLinker(TrainablePipe):
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
overwrite: bool = BACKWARD_OVERWRITE,
scorer: Optional[Callable] = entity_linker_score,
+ use_gold_ents: bool,
) -> None:
"""Initialize an entity linker.
@@ -152,6 +172,8 @@ class EntityLinker(TrainablePipe):
produces a list of candidates, given a certain knowledge base and a textual mention.
scorer (Optional[Callable]): The scoring method. Defaults to
Scorer.score_links.
+ use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
+ component must provide entity annotations.
DOCS: https://spacy.io/api/entitylinker#init
"""
@@ -169,6 +191,7 @@ class EntityLinker(TrainablePipe):
# create an empty KB by default. If you want to load a predefined one, specify it in 'initialize'.
self.kb = empty_kb(entity_vector_length)(self.vocab)
self.scorer = scorer
+ self.use_gold_ents = use_gold_ents
def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
"""Define the KB of this pipe by providing a function that will
@@ -212,14 +235,48 @@ class EntityLinker(TrainablePipe):
doc_sample = []
vector_sample = []
for example in islice(get_examples(), 10):
- doc_sample.append(example.x)
+ doc = example.x
+ if self.use_gold_ents:
+ doc.ents = example.y.ents
+ doc_sample.append(doc)
vector_sample.append(self.model.ops.alloc1f(nO))
assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
assert len(vector_sample) > 0, Errors.E923.format(name=self.name)
+
+ # XXX In order for size estimation to work, there has to be at least
+ # one entity. It's not used for training so it doesn't have to be real,
+ # so we add a fake one if none are present.
+ # We can't use Doc.has_annotation here because it can be True for docs
+ # that have been through an NER component but got no entities.
+        has_annotations = any(doc.ents for doc in doc_sample)
+ if not has_annotations:
+ doc = doc_sample[0]
+ ent = doc[0:1]
+ ent.label_ = "XXX"
+ doc.ents = (ent,)
+
self.model.initialize(
X=doc_sample, Y=self.model.ops.asarray(vector_sample, dtype="float32")
)
+ if not has_annotations:
+ # Clean up dummy annotation
+ doc.ents = []
+
+ def batch_has_learnable_example(self, examples):
+ """Check if a batch contains a learnable example.
+
+ If one isn't present, then the update step needs to be skipped.
+ """
+
+ for eg in examples:
+ for ent in eg.predicted.ents:
+ candidates = list(self.get_candidates(self.kb, ent))
+ if candidates:
+ return True
+
+ return False
+
def update(
self,
examples: Iterable[Example],
@@ -247,35 +304,29 @@ class EntityLinker(TrainablePipe):
if not examples:
return losses
validate_examples(examples, "EntityLinker.update")
- sentence_docs = []
- for eg in examples:
- sentences = [s for s in eg.reference.sents]
- kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
- for ent in eg.reference.ents:
- # KB ID of the first token is the same as the whole span
- kb_id = kb_ids[ent.start]
- if kb_id:
- try:
- # find the sentence in the list of sentences.
- sent_index = sentences.index(ent.sent)
- except AttributeError:
- # Catch the exception when ent.sent is None and provide a user-friendly warning
- raise RuntimeError(Errors.E030) from None
- # get n previous sentences, if there are any
- start_sentence = max(0, sent_index - self.n_sents)
- # get n posterior sentences, or as many < n as there are
- end_sentence = min(len(sentences) - 1, sent_index + self.n_sents)
- # get token positions
- start_token = sentences[start_sentence].start
- end_token = sentences[end_sentence].end
- # append that span as a doc to training
- sent_doc = eg.predicted[start_token:end_token].as_doc()
- sentence_docs.append(sent_doc)
+
set_dropout_rate(self.model, drop)
- if not sentence_docs:
- warnings.warn(Warnings.W093.format(name="Entity Linker"))
+ docs = [eg.predicted for eg in examples]
+ # save to restore later
+ old_ents = [doc.ents for doc in docs]
+
+ for doc, ex in zip(docs, examples):
+ if self.use_gold_ents:
+ doc.ents = ex.reference.ents
+ else:
+ # only keep matching ents
+ doc.ents = ex.get_matching_ents()
+
+ # make sure we have something to learn from, if not, short-circuit
+ if not self.batch_has_learnable_example(examples):
return losses
- sentence_encodings, bp_context = self.model.begin_update(sentence_docs)
+
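+        # the model's span maker reads doc.ents, so encode the full docs
+        # while the temporary entities are still set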
+ sentence_encodings, bp_context = self.model.begin_update(docs)
+
+ # now restore the ents
+ for doc, old in zip(docs, old_ents):
+ doc.ents = old
+
loss, d_scores = self.get_loss(
sentence_encodings=sentence_encodings, examples=examples
)
@@ -288,24 +339,38 @@ class EntityLinker(TrainablePipe):
def get_loss(self, examples: Iterable[Example], sentence_encodings: Floats2d):
validate_examples(examples, "EntityLinker.get_loss")
entity_encodings = []
+        eidx = 0  # running index over the gold entities
+ keep_ents = [] # indices in sentence_encodings to keep
+
for eg in examples:
kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
+
for ent in eg.reference.ents:
kb_id = kb_ids[ent.start]
if kb_id:
entity_encoding = self.kb.get_vector(kb_id)
entity_encodings.append(entity_encoding)
+ keep_ents.append(eidx)
+
+ eidx += 1
entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
- if sentence_encodings.shape != entity_encodings.shape:
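+        # keep only the encodings of entities that have a gold KB ID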
+ selected_encodings = sentence_encodings[keep_ents]
+
+        # If the shapes don't match (e.g. the entity encodings list is
+        # empty), the gold entities do not line up and no loss can be computed
+ if selected_encodings.shape != entity_encodings.shape:
err = Errors.E147.format(
method="get_loss", msg="gold entities do not match up"
)
raise RuntimeError(err)
# TODO: fix typing issue here
- gradients = self.distance.get_grad(sentence_encodings, entity_encodings) # type: ignore
- loss = self.distance.get_loss(sentence_encodings, entity_encodings) # type: ignore
+ gradients = self.distance.get_grad(selected_encodings, entity_encodings) # type: ignore
+ # to match the input size, we need to give a zero gradient for items not in the kb
+ out = self.model.ops.alloc2f(*sentence_encodings.shape)
+ out[keep_ents] = gradients
+
+ loss = self.distance.get_loss(selected_encodings, entity_encodings) # type: ignore
loss = loss / len(entity_encodings)
- return float(loss), gradients
+ return float(loss), out
def predict(self, docs: Iterable[Doc]) -> List[str]:
"""Apply the pipeline's model to a batch of docs, without modifying them.
diff --git a/spacy/pipeline/legacy/__init__.py b/spacy/pipeline/legacy/__init__.py
new file mode 100644
index 000000000..f216840dc
--- /dev/null
+++ b/spacy/pipeline/legacy/__init__.py
@@ -0,0 +1,3 @@
+from .entity_linker import EntityLinker_v1
+
+__all__ = ["EntityLinker_v1"]
diff --git a/spacy/pipeline/legacy/entity_linker.py b/spacy/pipeline/legacy/entity_linker.py
new file mode 100644
index 000000000..6440c18e5
--- /dev/null
+++ b/spacy/pipeline/legacy/entity_linker.py
@@ -0,0 +1,427 @@
+# This file is present to provide a prior version of the EntityLinker component
+# for backwards compatibility. For details see #9669.
+
+from typing import Optional, Iterable, Callable, Dict, Union, List, Any
+from thinc.types import Floats2d
+from pathlib import Path
+from itertools import islice
+import srsly
+import random
+from thinc.api import CosineDistance, Model, Optimizer, Config
+from thinc.api import set_dropout_rate
+import warnings
+
+from ...kb import KnowledgeBase, Candidate
+from ...ml import empty_kb
+from ...tokens import Doc, Span
+from ..pipe import deserialize_config
+from ..trainable_pipe import TrainablePipe
+from ...language import Language
+from ...vocab import Vocab
+from ...training import Example, validate_examples, validate_get_examples
+from ...errors import Errors, Warnings
+from ...util import SimpleFrozenList, registry
+from ... import util
+from ...scorer import Scorer
+
+# See #9050
+BACKWARD_OVERWRITE = True
+
+
+def entity_linker_score(examples, **kwargs):
+ return Scorer.score_links(examples, negative_labels=[EntityLinker_v1.NIL], **kwargs)
+
+
+class EntityLinker_v1(TrainablePipe):
+ """Pipeline component for named entity linking.
+
+ DOCS: https://spacy.io/api/entitylinker
+ """
+
+ NIL = "NIL" # string used to refer to a non-existing link
+
+ def __init__(
+ self,
+ vocab: Vocab,
+ model: Model,
+ name: str = "entity_linker",
+ *,
+ labels_discard: Iterable[str],
+ n_sents: int,
+ incl_prior: bool,
+ incl_context: bool,
+ entity_vector_length: int,
+ get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
+ overwrite: bool = BACKWARD_OVERWRITE,
+ scorer: Optional[Callable] = entity_linker_score,
+ ) -> None:
+ """Initialize an entity linker.
+
+ vocab (Vocab): The shared vocabulary.
+ model (thinc.api.Model): The Thinc Model powering the pipeline component.
+ name (str): The component instance name, used to add entries to the
+ losses during training.
+ labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction.
+ n_sents (int): The number of neighbouring sentences to take into account.
+ incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
+ incl_context (bool): Whether or not to include the local context in the model.
+ entity_vector_length (int): Size of encoding vectors in the KB.
+ get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that
+ produces a list of candidates, given a certain knowledge base and a textual mention.
+ scorer (Optional[Callable]): The scoring method. Defaults to
+ Scorer.score_links.
+
+ DOCS: https://spacy.io/api/entitylinker#init
+ """
+ self.vocab = vocab
+ self.model = model
+ self.name = name
+ self.labels_discard = list(labels_discard)
+ self.n_sents = n_sents
+ self.incl_prior = incl_prior
+ self.incl_context = incl_context
+ self.get_candidates = get_candidates
+ self.cfg: Dict[str, Any] = {"overwrite": overwrite}
+ self.distance = CosineDistance(normalize=False)
+ # create an empty KB by default. If you want to load a predefined one, specify it in 'initialize'.
+ self.kb = empty_kb(entity_vector_length)(self.vocab)
+ self.scorer = scorer
+
+ def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
+ """Define the KB of this pipe by providing a function that will
+ create it using this object's vocab."""
+ if not callable(kb_loader):
+ raise ValueError(Errors.E885.format(arg_type=type(kb_loader)))
+
+ self.kb = kb_loader(self.vocab)
+
+ def validate_kb(self) -> None:
+ # Raise an error if the knowledge base is not initialized.
+ if self.kb is None:
+ raise ValueError(Errors.E1018.format(name=self.name))
+ if len(self.kb) == 0:
+ raise ValueError(Errors.E139.format(name=self.name))
+
+ def initialize(
+ self,
+ get_examples: Callable[[], Iterable[Example]],
+ *,
+ nlp: Optional[Language] = None,
+ kb_loader: Optional[Callable[[Vocab], KnowledgeBase]] = None,
+ ):
+ """Initialize the pipe for training, using a representative set
+ of data examples.
+
+ get_examples (Callable[[], Iterable[Example]]): Function that
+ returns a representative sample of gold-standard Example objects.
+ nlp (Language): The current nlp object the component is part of.
+ kb_loader (Callable[[Vocab], KnowledgeBase]): A function that creates a KnowledgeBase from a Vocab instance.
+            Note that providing this argument will overwrite all data accumulated in the current KB.
+            Use this only when loading a KB directly from file.
+
+ DOCS: https://spacy.io/api/entitylinker#initialize
+ """
+ validate_get_examples(get_examples, "EntityLinker_v1.initialize")
+ if kb_loader is not None:
+ self.set_kb(kb_loader)
+ self.validate_kb()
+ nO = self.kb.entity_vector_length
+ doc_sample = []
+ vector_sample = []
+ for example in islice(get_examples(), 10):
+ doc_sample.append(example.x)
+ vector_sample.append(self.model.ops.alloc1f(nO))
+ assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
+ assert len(vector_sample) > 0, Errors.E923.format(name=self.name)
+ self.model.initialize(
+ X=doc_sample, Y=self.model.ops.asarray(vector_sample, dtype="float32")
+ )
+
+ def update(
+ self,
+ examples: Iterable[Example],
+ *,
+ drop: float = 0.0,
+ sgd: Optional[Optimizer] = None,
+ losses: Optional[Dict[str, float]] = None,
+ ) -> Dict[str, float]:
+ """Learn from a batch of documents and gold-standard information,
+ updating the pipe's model. Delegates to predict and get_loss.
+
+ examples (Iterable[Example]): A batch of Example objects.
+ drop (float): The dropout rate.
+ sgd (thinc.api.Optimizer): The optimizer.
+ losses (Dict[str, float]): Optional record of the loss during training.
+ Updated using the component name as the key.
+ RETURNS (Dict[str, float]): The updated losses dictionary.
+
+ DOCS: https://spacy.io/api/entitylinker#update
+ """
+ self.validate_kb()
+ if losses is None:
+ losses = {}
+ losses.setdefault(self.name, 0.0)
+ if not examples:
+ return losses
+ validate_examples(examples, "EntityLinker_v1.update")
+ sentence_docs = []
+ for eg in examples:
+ sentences = [s for s in eg.reference.sents]
+ kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
+ for ent in eg.reference.ents:
+ # KB ID of the first token is the same as the whole span
+ kb_id = kb_ids[ent.start]
+ if kb_id:
+ try:
+ # find the sentence in the list of sentences.
+ sent_index = sentences.index(ent.sent)
+ except AttributeError:
+                        # Catch the exception when ent.sent is None and provide a user-friendly error
+ raise RuntimeError(Errors.E030) from None
+ # get n previous sentences, if there are any
+ start_sentence = max(0, sent_index - self.n_sents)
+ # get n posterior sentences, or as many < n as there are
+ end_sentence = min(len(sentences) - 1, sent_index + self.n_sents)
+ # get token positions
+ start_token = sentences[start_sentence].start
+ end_token = sentences[end_sentence].end
+ # append that span as a doc to training
+ sent_doc = eg.predicted[start_token:end_token].as_doc()
+ sentence_docs.append(sent_doc)
+ set_dropout_rate(self.model, drop)
+ if not sentence_docs:
+ warnings.warn(Warnings.W093.format(name="Entity Linker"))
+ return losses
+ sentence_encodings, bp_context = self.model.begin_update(sentence_docs)
+ loss, d_scores = self.get_loss(
+ sentence_encodings=sentence_encodings, examples=examples
+ )
+ bp_context(d_scores)
+ if sgd is not None:
+ self.finish_update(sgd)
+ losses[self.name] += loss
+ return losses
+
+ def get_loss(self, examples: Iterable[Example], sentence_encodings: Floats2d):
+ validate_examples(examples, "EntityLinker_v1.get_loss")
+ entity_encodings = []
+ for eg in examples:
+ kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
+ for ent in eg.reference.ents:
+ kb_id = kb_ids[ent.start]
+ if kb_id:
+ entity_encoding = self.kb.get_vector(kb_id)
+ entity_encodings.append(entity_encoding)
+ entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
+ if sentence_encodings.shape != entity_encodings.shape:
+ err = Errors.E147.format(
+ method="get_loss", msg="gold entities do not match up"
+ )
+ raise RuntimeError(err)
+ # TODO: fix typing issue here
+ gradients = self.distance.get_grad(sentence_encodings, entity_encodings) # type: ignore
+ loss = self.distance.get_loss(sentence_encodings, entity_encodings) # type: ignore
+ loss = loss / len(entity_encodings)
+ return float(loss), gradients
+
+ def predict(self, docs: Iterable[Doc]) -> List[str]:
+ """Apply the pipeline's model to a batch of docs, without modifying them.
+ Returns the KB IDs for each entity in each doc, including NIL if there is
+ no prediction.
+
+ docs (Iterable[Doc]): The documents to predict.
+        RETURNS (List[str]): The model's prediction for each document.
+
+ DOCS: https://spacy.io/api/entitylinker#predict
+ """
+ self.validate_kb()
+ entity_count = 0
+ final_kb_ids: List[str] = []
+ if not docs:
+ return final_kb_ids
+ if isinstance(docs, Doc):
+ docs = [docs]
+ for i, doc in enumerate(docs):
+ sentences = [s for s in doc.sents]
+ if len(doc) > 0:
+ # Looping through each entity (TODO: rewrite)
+ for ent in doc.ents:
+ sent = ent.sent
+ sent_index = sentences.index(sent)
+ assert sent_index >= 0
+ # get n_neighbour sentences, clipped to the length of the document
+ start_sentence = max(0, sent_index - self.n_sents)
+ end_sentence = min(len(sentences) - 1, sent_index + self.n_sents)
+ start_token = sentences[start_sentence].start
+ end_token = sentences[end_sentence].end
+ sent_doc = doc[start_token:end_token].as_doc()
+ # currently, the context is the same for each entity in a sentence (should be refined)
+ xp = self.model.ops.xp
+ if self.incl_context:
+ sentence_encoding = self.model.predict([sent_doc])[0]
+ sentence_encoding_t = sentence_encoding.T
+ sentence_norm = xp.linalg.norm(sentence_encoding_t)
+ entity_count += 1
+ if ent.label_ in self.labels_discard:
+ # ignoring this entity - setting to NIL
+ final_kb_ids.append(self.NIL)
+ else:
+ candidates = list(self.get_candidates(self.kb, ent))
+ if not candidates:
+ # no prediction possible for this entity - setting to NIL
+ final_kb_ids.append(self.NIL)
+ elif len(candidates) == 1:
+ # shortcut for efficiency reasons: take the 1 candidate
+ # TODO: thresholding
+ final_kb_ids.append(candidates[0].entity_)
+ else:
+ random.shuffle(candidates)
+ # set all prior probabilities to 0 if incl_prior=False
+ prior_probs = xp.asarray([c.prior_prob for c in candidates])
+ if not self.incl_prior:
+ prior_probs = xp.asarray([0.0 for _ in candidates])
+ scores = prior_probs
+ # add in similarity from the context
+ if self.incl_context:
+ entity_encodings = xp.asarray(
+ [c.entity_vector for c in candidates]
+ )
+ entity_norm = xp.linalg.norm(entity_encodings, axis=1)
+ if len(entity_encodings) != len(prior_probs):
+ raise RuntimeError(
+ Errors.E147.format(
+ method="predict",
+ msg="vectors not of equal length",
+ )
+ )
+ # cosine similarity
+ sims = xp.dot(entity_encodings, sentence_encoding_t) / (
+ sentence_norm * entity_norm
+ )
+ if sims.shape != prior_probs.shape:
+ raise ValueError(Errors.E161)
+ scores = prior_probs + sims - (prior_probs * sims)
+ # TODO: thresholding
+ best_index = scores.argmax().item()
+ best_candidate = candidates[best_index]
+ final_kb_ids.append(best_candidate.entity_)
+ if not (len(final_kb_ids) == entity_count):
+ err = Errors.E147.format(
+ method="predict", msg="result variables not of equal length"
+ )
+ raise RuntimeError(err)
+ return final_kb_ids
+
+ def set_annotations(self, docs: Iterable[Doc], kb_ids: List[str]) -> None:
+ """Modify a batch of documents, using pre-computed scores.
+
+ docs (Iterable[Doc]): The documents to modify.
+ kb_ids (List[str]): The IDs to set, produced by EntityLinker.predict.
+
+ DOCS: https://spacy.io/api/entitylinker#set_annotations
+ """
+ count_ents = len([ent for doc in docs for ent in doc.ents])
+ if count_ents != len(kb_ids):
+ raise ValueError(Errors.E148.format(ents=count_ents, ids=len(kb_ids)))
+ i = 0
+ overwrite = self.cfg["overwrite"]
+ for doc in docs:
+ for ent in doc.ents:
+ kb_id = kb_ids[i]
+ i += 1
+ for token in ent:
+ if token.ent_kb_id == 0 or overwrite:
+ token.ent_kb_id_ = kb_id
+
+ def to_bytes(self, *, exclude=tuple()):
+ """Serialize the pipe to a bytestring.
+
+ exclude (Iterable[str]): String names of serialization fields to exclude.
+ RETURNS (bytes): The serialized object.
+
+ DOCS: https://spacy.io/api/entitylinker#to_bytes
+ """
+ self._validate_serialization_attrs()
+ serialize = {}
+ if hasattr(self, "cfg") and self.cfg is not None:
+ serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
+ serialize["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude)
+ serialize["kb"] = self.kb.to_bytes
+ serialize["model"] = self.model.to_bytes
+ return util.to_bytes(serialize, exclude)
+
+ def from_bytes(self, bytes_data, *, exclude=tuple()):
+ """Load the pipe from a bytestring.
+
+ exclude (Iterable[str]): String names of serialization fields to exclude.
+ RETURNS (TrainablePipe): The loaded object.
+
+ DOCS: https://spacy.io/api/entitylinker#from_bytes
+ """
+ self._validate_serialization_attrs()
+
+ def load_model(b):
+ try:
+ self.model.from_bytes(b)
+ except AttributeError:
+ raise ValueError(Errors.E149) from None
+
+ deserialize = {}
+ if hasattr(self, "cfg") and self.cfg is not None:
+ deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
+ deserialize["vocab"] = lambda b: self.vocab.from_bytes(b, exclude=exclude)
+ deserialize["kb"] = lambda b: self.kb.from_bytes(b)
+ deserialize["model"] = load_model
+ util.from_bytes(bytes_data, deserialize, exclude)
+ return self
+
+ def to_disk(
+ self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
+ ) -> None:
+ """Serialize the pipe to disk.
+
+ path (str / Path): Path to a directory.
+ exclude (Iterable[str]): String names of serialization fields to exclude.
+
+ DOCS: https://spacy.io/api/entitylinker#to_disk
+ """
+ serialize = {}
+ serialize["vocab"] = lambda p: self.vocab.to_disk(p, exclude=exclude)
+ serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
+ serialize["kb"] = lambda p: self.kb.to_disk(p)
+ serialize["model"] = lambda p: self.model.to_disk(p)
+ util.to_disk(path, serialize, exclude)
+
+ def from_disk(
+ self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
+ ) -> "EntityLinker_v1":
+ """Load the pipe from disk. Modifies the object in place and returns it.
+
+ path (str / Path): Path to a directory.
+ exclude (Iterable[str]): String names of serialization fields to exclude.
+ RETURNS (EntityLinker): The modified EntityLinker object.
+
+ DOCS: https://spacy.io/api/entitylinker#from_disk
+ """
+
+ def load_model(p):
+ try:
+ with p.open("rb") as infile:
+ self.model.from_bytes(infile.read())
+ except AttributeError:
+ raise ValueError(Errors.E149) from None
+
+ deserialize: Dict[str, Callable[[Any], Any]] = {}
+ deserialize["cfg"] = lambda p: self.cfg.update(deserialize_config(p))
+ deserialize["vocab"] = lambda p: self.vocab.from_disk(p, exclude=exclude)
+ deserialize["kb"] = lambda p: self.kb.from_disk(p)
+ deserialize["model"] = load_model
+ util.from_disk(path, deserialize, exclude)
+ return self
+
+ def rehearse(self, examples, *, sgd=None, losses=None, **config):
+ raise NotImplementedError
+
+ def add_label(self, label):
+ raise NotImplementedError
diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py
index 3740e430e..7d1382741 100644
--- a/spacy/tests/pipeline/test_entity_linker.py
+++ b/spacy/tests/pipeline/test_entity_linker.py
@@ -9,6 +9,9 @@ from spacy.compat import pickle
from spacy.kb import Candidate, KnowledgeBase, get_candidates
from spacy.lang.en import English
from spacy.ml import load_kb
+from spacy.pipeline import EntityLinker
+from spacy.pipeline.legacy import EntityLinker_v1
+from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
from spacy.scorer import Scorer
from spacy.tests.util import make_tempdir
from spacy.tokens import Span
@@ -168,6 +171,45 @@ def test_issue7065_b():
assert doc
+def test_no_entities():
+ # Test that having no entities doesn't crash the model
+ TRAIN_DATA = [
+ (
+ "The sky is blue.",
+ {
+ "sent_starts": [1, 0, 0, 0, 0],
+ },
+ )
+ ]
+ nlp = English()
+ vector_length = 3
+ train_examples = []
+ for text, annotation in TRAIN_DATA:
+ doc = nlp(text)
+ train_examples.append(Example.from_dict(doc, annotation))
+
+ def create_kb(vocab):
+ # create artificial KB
+ mykb = KnowledgeBase(vocab, entity_vector_length=vector_length)
+ mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
+ mykb.add_alias("Russ Cochran", ["Q2146908"], [0.9])
+ return mykb
+
+ # Create and train the Entity Linker
+ entity_linker = nlp.add_pipe("entity_linker", last=True)
+ entity_linker.set_kb(create_kb)
+ optimizer = nlp.initialize(get_examples=lambda: train_examples)
+ for i in range(2):
+ losses = {}
+ nlp.update(train_examples, sgd=optimizer, losses=losses)
+
+ # adding additional components that are required for the entity_linker
+ nlp.add_pipe("sentencizer", first=True)
+
+ # this will run the pipeline on the examples and shouldn't crash
+ results = nlp.evaluate(train_examples)
+
+
def test_partial_links():
# Test that having some entities on the doc without gold links, doesn't crash
TRAIN_DATA = [
@@ -650,7 +692,7 @@ TRAIN_DATA = [
"sent_starts": [1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}),
("Russ Cochran his reprints include EC Comics.",
{"links": {(0, 12): {"Q7381115": 1.0, "Q2146908": 0.0}},
- "entities": [(0, 12, "PERSON")],
+ "entities": [(0, 12, "PERSON"), (34, 43, "ART")],
"sent_starts": [1, -1, 0, 0, 0, 0, 0, 0]}),
("Russ Cochran has been publishing comic art.",
{"links": {(0, 12): {"Q7381115": 1.0, "Q2146908": 0.0}},
@@ -693,6 +735,7 @@ def test_overfitting_IO():
# Create the Entity Linker component and add it to the pipeline
entity_linker = nlp.add_pipe("entity_linker", last=True)
+ assert isinstance(entity_linker, EntityLinker)
entity_linker.set_kb(create_kb)
assert "Q2146908" in entity_linker.vocab.strings
assert "Q2146908" in entity_linker.kb.vocab.strings
@@ -922,3 +965,109 @@ def test_scorer_links():
assert scores["nel_micro_p"] == 2 / 3
assert scores["nel_micro_r"] == 2 / 4
+
+
+# fmt: off
+@pytest.mark.parametrize(
+ "name,config",
+ [
+ ("entity_linker", {"@architectures": "spacy.EntityLinker.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL}),
+ ("entity_linker", {"@architectures": "spacy.EntityLinker.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL}),
+ ],
+)
+# fmt: on
+def test_legacy_architectures(name, config):
+ # Ensure that the legacy architectures still work
+ vector_length = 3
+ nlp = English()
+
+ train_examples = []
+ for text, annotation in TRAIN_DATA:
+ doc = nlp.make_doc(text)
+ train_examples.append(Example.from_dict(doc, annotation))
+
+ def create_kb(vocab):
+ mykb = KnowledgeBase(vocab, entity_vector_length=vector_length)
+ mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
+ mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7])
+ mykb.add_alias(
+ alias="Russ Cochran",
+ entities=["Q2146908", "Q7381115"],
+ probabilities=[0.5, 0.5],
+ )
+ return mykb
+
+ entity_linker = nlp.add_pipe(name, config={"model": config})
+ if config["@architectures"] == "spacy.EntityLinker.v1":
+ assert isinstance(entity_linker, EntityLinker_v1)
+ else:
+ assert isinstance(entity_linker, EntityLinker)
+ entity_linker.set_kb(create_kb)
+ optimizer = nlp.initialize(get_examples=lambda: train_examples)
+
+ for i in range(2):
+ losses = {}
+ nlp.update(train_examples, sgd=optimizer, losses=losses)
+
+
+@pytest.mark.parametrize(
+    "patterns",
+    [
+        # perfect case
+        [{"label": "CHARACTER", "pattern": "Kirby"}],
+        # typo for false negative
+        [{"label": "PERSON", "pattern": "Korby"}],
+        # random stuff for false positive
+        [{"label": "IS", "pattern": "is"}, {"label": "COLOR", "pattern": "pink"}],
+    ],
+)
+def test_no_gold_ents(patterns):
+    # test that annotating components work
+    TRAIN_DATA = [
+        (
+            "Kirby is pink",
+            {
+                "links": {(0, 5): {"Q613241": 1.0}},
+                "entities": [(0, 5, "CHARACTER")],
+                "sent_starts": [1, 0, 0],
+            },
+        )
+    ]
+    nlp = English()
+    vector_length = 3
+    train_examples = []
+    for text, annotation in TRAIN_DATA:
+        doc = nlp(text)
+        train_examples.append(Example.from_dict(doc, annotation))
+
+    # Create a ruler to mark entities
+    ruler = nlp.add_pipe("entity_ruler")
+    ruler.add_patterns(patterns)
+
+    # Apply ruler to examples. In a real pipeline this would be an annotating component.
+    for eg in train_examples:
+        eg.predicted = ruler(eg.predicted)
+
+    def create_kb(vocab):
+        # create artificial KB
+        mykb = KnowledgeBase(vocab, entity_vector_length=vector_length)
+        mykb.add_entity(entity="Q613241", freq=12, entity_vector=[6, -4, 3])
+        mykb.add_alias("Kirby", ["Q613241"], [0.9])
+        # Placeholder
+        mykb.add_entity(entity="pink", freq=12, entity_vector=[7, 2, -5])
+        mykb.add_alias("pink", ["pink"], [0.9])
+        return mykb
+
+    # Create and train the Entity Linker
+    entity_linker = nlp.add_pipe(
+        "entity_linker", config={"use_gold_ents": False}, last=True
+    )
+    entity_linker.set_kb(create_kb)
+    assert entity_linker.use_gold_ents is False
+
+    optimizer = nlp.initialize(get_examples=lambda: train_examples)
+    for i in range(2):
+        losses = {}
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
+
+    # adding additional components that are required for the entity_linker
+    nlp.add_pipe("sentencizer", first=True)
+
+    # this will run the pipeline on the examples and shouldn't crash
+    results = nlp.evaluate(train_examples)
diff --git a/spacy/training/example.pyx b/spacy/training/example.pyx
index d792c9bbf..a2c5e08e9 100644
--- a/spacy/training/example.pyx
+++ b/spacy/training/example.pyx
@@ -256,6 +256,29 @@ cdef class Example:
x_ents, x_tags = self.get_aligned_ents_and_ner()
return x_tags
+ def get_matching_ents(self, check_label=True):
+ """Return entities that are shared between predicted and reference docs.
+
+ If `check_label` is True, entities must have matching labels to be
+ kept. Otherwise only the character indices need to match.
+ """
+ gold = {}
+        for ent in self.reference.ents:
+ gold[(ent.start_char, ent.end_char)] = ent.label
+
+ keep = []
+        for ent in self.predicted.ents:
+ key = (ent.start_char, ent.end_char)
+ if key not in gold:
+ continue
+
+ if check_label and ent.label != gold[key]:
+ continue
+
+ keep.append(ent)
+
+ return keep
+
def to_dict(self):
return {
"doc_annotation": {
diff --git a/website/docs/api/architectures.md b/website/docs/api/architectures.md
index 07b76393f..5fb3546a7 100644
--- a/website/docs/api/architectures.md
+++ b/website/docs/api/architectures.md
@@ -858,13 +858,13 @@ into the "real world". This requires 3 main components:
- A machine learning [`Model`](https://thinc.ai/docs/api-model) that picks the
most plausible ID from the set of candidates.
-### spacy.EntityLinker.v1 {#EntityLinker}
+### spacy.EntityLinker.v2 {#EntityLinker}
> #### Example Config
>
> ```ini
> [model]
-> @architectures = "spacy.EntityLinker.v1"
+> @architectures = "spacy.EntityLinker.v2"
> nO = null
>
> [model.tok2vec]
diff --git a/website/docs/api/entitylinker.md b/website/docs/api/entitylinker.md
index 3d3372679..8e0d6087a 100644
--- a/website/docs/api/entitylinker.md
+++ b/website/docs/api/entitylinker.md
@@ -59,6 +59,7 @@ architectures and their arguments and hyperparameters.
| `incl_context` | Whether or not to include the local context in the model. Defaults to `True`. ~~bool~~ |
| `model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. Defaults to [EntityLinker](/api/architectures#EntityLinker). ~~Model~~ |
| `entity_vector_length` | Size of encoding vectors in the KB. Defaults to `64`. ~~int~~ |
+| `use_gold_ents`        | Whether to copy entities from the gold docs or not. Defaults to `True`. If `False`, entities must be set in the training data or by an annotating component in the pipeline. ~~bool~~ |
| `get_candidates` | Function that generates plausible candidates for a given `Span` object. Defaults to [CandidateGenerator](/api/architectures#CandidateGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Span], Iterable[Candidate]]~~ |
| `overwrite` 3.2 | Whether existing annotation is overwritten. Defaults to `True`. ~~bool~~ |
| `scorer` 3.2 | The scoring method. Defaults to [`Scorer.score_links`](/api/scorer#score_links). ~~Optional[Callable]~~ |
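+
+For example, a pipeline that gets its entities from an `entity_ruler` rather
+than from the gold annotations could be configured like this (a minimal
+sketch; the ruler's patterns are up to the user):
+
+```python
+ruler = nlp.add_pipe("entity_ruler")
+ruler.add_patterns([{"label": "PERSON", "pattern": "Russ Cochran"}])
+entity_linker = nlp.add_pipe("entity_linker", config={"use_gold_ents": False})
+```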