mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Fix entity linker batching (#9669)
* Partial fix of entity linker batching * Add import * Better name * Add `use_gold_ents` option, docs * Change to v2, create stub v1, update docs etc. * Fix error type Honestly no idea what the right type to use here is. ConfigValidationError seems wrong. Maybe a NotImplementedError? * Make mypy happy * Add hacky fix for init issue * Add legacy pipeline entity linker * Fix references to class name * Add __init__.py for legacy * Attempted fix for loss issue * Remove placeholder V1 * formatting * slightly more interesting train data * Handle batches with no usable examples This adds a test for batches that have docs but not entities, and a check in the component that detects such cases and skips the update step as though the batch were empty. * Remove todo about data verification Check for empty data was moved further up so this should be OK now - the case in question shouldn't be possible. * Fix gradient calculation The model doesn't know which entities are not in the kb, so it generates embeddings for the context of all of them. However, the loss does know which entities aren't in the kb, and it ignores them, as there's no sensible gradient. This has the issue that the gradient will not be calculated for some of the input embeddings, which causes a dimension mismatch in backprop. That should have caused a clear error, but with numpyops it was causing nans to happen, which is another problem that should be addressed separately. This commit changes the loss to give a zero gradient for entities not in the kb. * add failing test for v1 EL legacy architecture * Add nasty but simple working check for legacy arch * Clarify why init hack works the way it does * Clarify use_gold_ents use case * Fix use gold ents related handling * Add tests for no gold ents and fix other tests * Use aligned ents function (not working) This doesn't actually work because the "aligned" ents are gold-only. 
But if I have a different function that returns the intersection, *then* this will work as desired. * Use proper matching ent check This changes the process when gold ents are not used so that the intersection of ents in the pred and gold is used. * Move get_matching_ents to Example * Use model attribute to check for legacy arch * Rename flag * bump spacy-legacy to lower 3.0.9 Co-authored-by: svlandeg <svlandeg@github.com>
This commit is contained in:
parent
8e93fa8507
commit
91acc3ea75
|
@ -1,5 +1,5 @@
|
||||||
# Our libraries
|
# Our libraries
|
||||||
spacy-legacy>=3.0.8,<3.1.0
|
spacy-legacy>=3.0.9,<3.1.0
|
||||||
spacy-loggers>=1.0.0,<2.0.0
|
spacy-loggers>=1.0.0,<2.0.0
|
||||||
cymem>=2.0.2,<2.1.0
|
cymem>=2.0.2,<2.1.0
|
||||||
preshed>=3.0.2,<3.1.0
|
preshed>=3.0.2,<3.1.0
|
||||||
|
|
|
@ -41,7 +41,7 @@ setup_requires =
|
||||||
thinc>=8.0.12,<8.1.0
|
thinc>=8.0.12,<8.1.0
|
||||||
install_requires =
|
install_requires =
|
||||||
# Our libraries
|
# Our libraries
|
||||||
spacy-legacy>=3.0.8,<3.1.0
|
spacy-legacy>=3.0.9,<3.1.0
|
||||||
spacy-loggers>=1.0.0,<2.0.0
|
spacy-loggers>=1.0.0,<2.0.0
|
||||||
murmurhash>=0.28.0,<1.1.0
|
murmurhash>=0.28.0,<1.1.0
|
||||||
cymem>=2.0.2,<2.1.0
|
cymem>=2.0.2,<2.1.0
|
||||||
|
|
|
@ -131,7 +131,7 @@ incl_context = true
|
||||||
incl_prior = true
|
incl_prior = true
|
||||||
|
|
||||||
[components.entity_linker.model]
|
[components.entity_linker.model]
|
||||||
@architectures = "spacy.EntityLinker.v1"
|
@architectures = "spacy.EntityLinker.v2"
|
||||||
nO = null
|
nO = null
|
||||||
|
|
||||||
[components.entity_linker.model.tok2vec]
|
[components.entity_linker.model.tok2vec]
|
||||||
|
@ -303,7 +303,7 @@ incl_context = true
|
||||||
incl_prior = true
|
incl_prior = true
|
||||||
|
|
||||||
[components.entity_linker.model]
|
[components.entity_linker.model]
|
||||||
@architectures = "spacy.EntityLinker.v1"
|
@architectures = "spacy.EntityLinker.v2"
|
||||||
nO = null
|
nO = null
|
||||||
|
|
||||||
[components.entity_linker.model.tok2vec]
|
[components.entity_linker.model.tok2vec]
|
||||||
|
|
|
@ -63,4 +63,4 @@ def _get_span_indices(ops, spans: Ragged, lengths: Ints1d) -> Ints1d:
|
||||||
|
|
||||||
|
|
||||||
def _ensure_cpu(spans: Ragged, lengths: Ints1d) -> Tuple[Ragged, Ints1d]:
|
def _ensure_cpu(spans: Ragged, lengths: Ints1d) -> Tuple[Ragged, Ints1d]:
|
||||||
return (Ragged(to_numpy(spans.dataXd), to_numpy(spans.lengths)), to_numpy(lengths))
|
return Ragged(to_numpy(spans.dataXd), to_numpy(spans.lengths)), to_numpy(lengths)
|
||||||
|
|
|
@ -1,34 +1,82 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Callable, Iterable, List
|
from typing import Optional, Callable, Iterable, List, Tuple
|
||||||
from thinc.types import Floats2d
|
from thinc.types import Floats2d
|
||||||
from thinc.api import chain, clone, list2ragged, reduce_mean, residual
|
from thinc.api import chain, clone, list2ragged, reduce_mean, residual
|
||||||
from thinc.api import Model, Maxout, Linear
|
from thinc.api import Model, Maxout, Linear, noop, tuplify, Ragged
|
||||||
|
|
||||||
from ...util import registry
|
from ...util import registry
|
||||||
from ...kb import KnowledgeBase, Candidate, get_candidates
|
from ...kb import KnowledgeBase, Candidate, get_candidates
|
||||||
from ...vocab import Vocab
|
from ...vocab import Vocab
|
||||||
from ...tokens import Span, Doc
|
from ...tokens import Span, Doc
|
||||||
|
from ..extract_spans import extract_spans
|
||||||
|
from ...errors import Errors
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures("spacy.EntityLinker.v1")
|
@registry.architectures("spacy.EntityLinker.v2")
|
||||||
def build_nel_encoder(
|
def build_nel_encoder(
|
||||||
tok2vec: Model, nO: Optional[int] = None
|
tok2vec: Model, nO: Optional[int] = None
|
||||||
) -> Model[List[Doc], Floats2d]:
|
) -> Model[List[Doc], Floats2d]:
|
||||||
with Model.define_operators({">>": chain, "**": clone}):
|
with Model.define_operators({">>": chain, "&": tuplify}):
|
||||||
token_width = tok2vec.maybe_get_dim("nO")
|
token_width = tok2vec.maybe_get_dim("nO")
|
||||||
output_layer = Linear(nO=nO, nI=token_width)
|
output_layer = Linear(nO=nO, nI=token_width)
|
||||||
model = (
|
model = (
|
||||||
tok2vec
|
((tok2vec >> list2ragged()) & build_span_maker())
|
||||||
>> list2ragged()
|
>> extract_spans()
|
||||||
>> reduce_mean()
|
>> reduce_mean()
|
||||||
>> residual(Maxout(nO=token_width, nI=token_width, nP=2, dropout=0.0)) # type: ignore[arg-type]
|
>> residual(Maxout(nO=token_width, nI=token_width, nP=2, dropout=0.0)) # type: ignore[arg-type]
|
||||||
>> output_layer
|
>> output_layer
|
||||||
)
|
)
|
||||||
model.set_ref("output_layer", output_layer)
|
model.set_ref("output_layer", output_layer)
|
||||||
model.set_ref("tok2vec", tok2vec)
|
model.set_ref("tok2vec", tok2vec)
|
||||||
|
# flag to show this isn't legacy
|
||||||
|
model.attrs["include_span_maker"] = True
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def build_span_maker(n_sents: int = 0) -> Model:
    """Create the "span_maker" layer used to pick context spans for entities.

    n_sents (int): Value stored under the "n_sents" attr of the returned
        model, where the forward function reads it.
    RETURNS (Model): A model named "span_maker" whose forward function is
        `span_maker_forward`.
    """
    return Model(
        "span_maker",
        forward=span_maker_forward,
        attrs={"n_sents": n_sents},
    )
|
||||||
|
|
||||||
|
|
||||||
|
def span_maker_forward(model, docs: List[Doc], is_train: bool) -> Tuple[Ragged, Callable]:
    """Compute one context-window token span per entity in each doc.

    For every entity, the window covers the entity's sentence plus up to
    `n_sents` sentences before and after it, clipped to the doc's sentences.

    model (Model): The span maker layer. Its `ops` and its "n_sents" attr
        are read.
    docs (List[Doc]): Docs whose `ents` need context spans. If iterating
        `doc.sents` raises a ValueError, the doc's tokens are marked so that
        the whole doc is a single sentence.
    is_train (bool): Not read by this function.
    RETURNS (Tuple[Ragged, Callable]): A Ragged with one
        (start_token, end_token) row per entity and one length per doc, plus
        a backprop callback that returns an empty list.
    RAISES (RuntimeError): With E030 if looking up `ent.sent` raises an
        AttributeError.
    """
    ops = model.ops
    n_sents = model.attrs["n_sents"]
    candidates = []
    for doc in docs:
        cands = []
        try:
            sentences = list(doc.sents)
        except ValueError:
            # no sentence info, normal in initialization
            for tok in doc:
                tok.is_sent_start = tok.i == 0
            sentences = [doc[:]]
        for ent in doc.ents:
            try:
                # find the sentence in the list of sentences.
                sent_index = sentences.index(ent.sent)
            except AttributeError:
                # Catch the exception when ent.sent is None and provide a user-friendly warning
                raise RuntimeError(Errors.E030) from None
            # get n previous sentences, if there are any
            start_sentence = max(0, sent_index - n_sents)
            # get n posterior sentences, or as many < n as there are
            end_sentence = min(len(sentences) - 1, sent_index + n_sents)
            # get token positions
            start_token = sentences[start_sentence].start
            end_token = sentences[end_sentence].end
            # save positions for extraction
            cands.append((start_token, end_token))

        # A doc without entities gives an empty list, which converts to a
        # 1-dimensional array. Force every per-doc array to shape (n, 2) so
        # that docs with and without entities can be concatenated together.
        candidates.append(ops.asarray2i(cands).reshape((-1, 2)))
    candlens = ops.asarray1i([len(cands) for cands in candidates])
    candidates = ops.xp.concatenate(candidates)
    outputs = Ragged(candidates, candlens)
    # because this is just rearranging docs, the backprop does nothing
    return outputs, lambda x: []
|
||||||
|
|
||||||
|
|
||||||
@registry.misc("spacy.KBFromFile.v1")
|
@registry.misc("spacy.KBFromFile.v1")
|
||||||
def load_kb(kb_path: Path) -> Callable[[Vocab], KnowledgeBase]:
|
def load_kb(kb_path: Path) -> Callable[[Vocab], KnowledgeBase]:
|
||||||
def kb_from_file(vocab):
|
def kb_from_file(vocab):
|
||||||
|
|
|
@ -6,17 +6,17 @@ import srsly
|
||||||
import random
|
import random
|
||||||
from thinc.api import CosineDistance, Model, Optimizer, Config
|
from thinc.api import CosineDistance, Model, Optimizer, Config
|
||||||
from thinc.api import set_dropout_rate
|
from thinc.api import set_dropout_rate
|
||||||
import warnings
|
|
||||||
|
|
||||||
from ..kb import KnowledgeBase, Candidate
|
from ..kb import KnowledgeBase, Candidate
|
||||||
from ..ml import empty_kb
|
from ..ml import empty_kb
|
||||||
from ..tokens import Doc, Span
|
from ..tokens import Doc, Span
|
||||||
from .pipe import deserialize_config
|
from .pipe import deserialize_config
|
||||||
|
from .legacy.entity_linker import EntityLinker_v1
|
||||||
from .trainable_pipe import TrainablePipe
|
from .trainable_pipe import TrainablePipe
|
||||||
from ..language import Language
|
from ..language import Language
|
||||||
from ..vocab import Vocab
|
from ..vocab import Vocab
|
||||||
from ..training import Example, validate_examples, validate_get_examples
|
from ..training import Example, validate_examples, validate_get_examples
|
||||||
from ..errors import Errors, Warnings
|
from ..errors import Errors
|
||||||
from ..util import SimpleFrozenList, registry
|
from ..util import SimpleFrozenList, registry
|
||||||
from .. import util
|
from .. import util
|
||||||
from ..scorer import Scorer
|
from ..scorer import Scorer
|
||||||
|
@ -26,7 +26,7 @@ BACKWARD_OVERWRITE = True
|
||||||
|
|
||||||
default_model_config = """
|
default_model_config = """
|
||||||
[model]
|
[model]
|
||||||
@architectures = "spacy.EntityLinker.v1"
|
@architectures = "spacy.EntityLinker.v2"
|
||||||
|
|
||||||
[model.tok2vec]
|
[model.tok2vec]
|
||||||
@architectures = "spacy.HashEmbedCNN.v2"
|
@architectures = "spacy.HashEmbedCNN.v2"
|
||||||
|
@ -55,6 +55,7 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
|
||||||
"get_candidates": {"@misc": "spacy.CandidateGenerator.v1"},
|
"get_candidates": {"@misc": "spacy.CandidateGenerator.v1"},
|
||||||
"overwrite": True,
|
"overwrite": True,
|
||||||
"scorer": {"@scorers": "spacy.entity_linker_scorer.v1"},
|
"scorer": {"@scorers": "spacy.entity_linker_scorer.v1"},
|
||||||
|
"use_gold_ents": True,
|
||||||
},
|
},
|
||||||
default_score_weights={
|
default_score_weights={
|
||||||
"nel_micro_f": 1.0,
|
"nel_micro_f": 1.0,
|
||||||
|
@ -75,6 +76,7 @@ def make_entity_linker(
|
||||||
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
|
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
|
||||||
overwrite: bool,
|
overwrite: bool,
|
||||||
scorer: Optional[Callable],
|
scorer: Optional[Callable],
|
||||||
|
use_gold_ents: bool,
|
||||||
):
|
):
|
||||||
"""Construct an EntityLinker component.
|
"""Construct an EntityLinker component.
|
||||||
|
|
||||||
|
@ -90,6 +92,22 @@ def make_entity_linker(
|
||||||
produces a list of candidates, given a certain knowledge base and a textual mention.
|
produces a list of candidates, given a certain knowledge base and a textual mention.
|
||||||
scorer (Optional[Callable]): The scoring method.
|
scorer (Optional[Callable]): The scoring method.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if not model.attrs.get("include_span_maker", False):
|
||||||
|
# The only difference in arguments here is that use_gold_ents is not available
|
||||||
|
return EntityLinker_v1(
|
||||||
|
nlp.vocab,
|
||||||
|
model,
|
||||||
|
name,
|
||||||
|
labels_discard=labels_discard,
|
||||||
|
n_sents=n_sents,
|
||||||
|
incl_prior=incl_prior,
|
||||||
|
incl_context=incl_context,
|
||||||
|
entity_vector_length=entity_vector_length,
|
||||||
|
get_candidates=get_candidates,
|
||||||
|
overwrite=overwrite,
|
||||||
|
scorer=scorer,
|
||||||
|
)
|
||||||
return EntityLinker(
|
return EntityLinker(
|
||||||
nlp.vocab,
|
nlp.vocab,
|
||||||
model,
|
model,
|
||||||
|
@ -102,6 +120,7 @@ def make_entity_linker(
|
||||||
get_candidates=get_candidates,
|
get_candidates=get_candidates,
|
||||||
overwrite=overwrite,
|
overwrite=overwrite,
|
||||||
scorer=scorer,
|
scorer=scorer,
|
||||||
|
use_gold_ents=use_gold_ents,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -136,6 +155,7 @@ class EntityLinker(TrainablePipe):
|
||||||
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
|
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
|
||||||
overwrite: bool = BACKWARD_OVERWRITE,
|
overwrite: bool = BACKWARD_OVERWRITE,
|
||||||
scorer: Optional[Callable] = entity_linker_score,
|
scorer: Optional[Callable] = entity_linker_score,
|
||||||
|
use_gold_ents: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize an entity linker.
|
"""Initialize an entity linker.
|
||||||
|
|
||||||
|
@ -152,6 +172,8 @@ class EntityLinker(TrainablePipe):
|
||||||
produces a list of candidates, given a certain knowledge base and a textual mention.
|
produces a list of candidates, given a certain knowledge base and a textual mention.
|
||||||
scorer (Optional[Callable]): The scoring method. Defaults to
|
scorer (Optional[Callable]): The scoring method. Defaults to
|
||||||
Scorer.score_links.
|
Scorer.score_links.
|
||||||
|
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
|
||||||
|
component must provide entity annotations.
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/entitylinker#init
|
DOCS: https://spacy.io/api/entitylinker#init
|
||||||
"""
|
"""
|
||||||
|
@ -169,6 +191,7 @@ class EntityLinker(TrainablePipe):
|
||||||
# create an empty KB by default. If you want to load a predefined one, specify it in 'initialize'.
|
# create an empty KB by default. If you want to load a predefined one, specify it in 'initialize'.
|
||||||
self.kb = empty_kb(entity_vector_length)(self.vocab)
|
self.kb = empty_kb(entity_vector_length)(self.vocab)
|
||||||
self.scorer = scorer
|
self.scorer = scorer
|
||||||
|
self.use_gold_ents = use_gold_ents
|
||||||
|
|
||||||
def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
|
def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
|
||||||
"""Define the KB of this pipe by providing a function that will
|
"""Define the KB of this pipe by providing a function that will
|
||||||
|
@ -212,14 +235,48 @@ class EntityLinker(TrainablePipe):
|
||||||
doc_sample = []
|
doc_sample = []
|
||||||
vector_sample = []
|
vector_sample = []
|
||||||
for example in islice(get_examples(), 10):
|
for example in islice(get_examples(), 10):
|
||||||
doc_sample.append(example.x)
|
doc = example.x
|
||||||
|
if self.use_gold_ents:
|
||||||
|
doc.ents = example.y.ents
|
||||||
|
doc_sample.append(doc)
|
||||||
vector_sample.append(self.model.ops.alloc1f(nO))
|
vector_sample.append(self.model.ops.alloc1f(nO))
|
||||||
assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
|
assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
|
||||||
assert len(vector_sample) > 0, Errors.E923.format(name=self.name)
|
assert len(vector_sample) > 0, Errors.E923.format(name=self.name)
|
||||||
|
|
||||||
|
# XXX In order for size estimation to work, there has to be at least
|
||||||
|
# one entity. It's not used for training so it doesn't have to be real,
|
||||||
|
# so we add a fake one if none are present.
|
||||||
|
# We can't use Doc.has_annotation here because it can be True for docs
|
||||||
|
# that have been through an NER component but got no entities.
|
||||||
|
has_annotations = any([doc.ents for doc in doc_sample])
|
||||||
|
if not has_annotations:
|
||||||
|
doc = doc_sample[0]
|
||||||
|
ent = doc[0:1]
|
||||||
|
ent.label_ = "XXX"
|
||||||
|
doc.ents = (ent,)
|
||||||
|
|
||||||
self.model.initialize(
|
self.model.initialize(
|
||||||
X=doc_sample, Y=self.model.ops.asarray(vector_sample, dtype="float32")
|
X=doc_sample, Y=self.model.ops.asarray(vector_sample, dtype="float32")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not has_annotations:
|
||||||
|
# Clean up dummy annotation
|
||||||
|
doc.ents = []
|
||||||
|
|
||||||
|
def batch_has_learnable_example(self, examples):
    """Report whether any predicted entity in the batch has KB candidates.

    The candidate generator is called for the entities of each example's
    predicted doc, in order, stopping at the first entity that yields a
    non-empty candidate list. If none does, the update step needs to be
    skipped.

    examples: The batch of examples to inspect.
    RETURNS (bool): True if at least one entity has candidates, else False.
    """
    return any(
        list(self.get_candidates(self.kb, ent))
        for eg in examples
        for ent in eg.predicted.ents
    )
|
||||||
|
|
||||||
def update(
|
def update(
|
||||||
self,
|
self,
|
||||||
examples: Iterable[Example],
|
examples: Iterable[Example],
|
||||||
|
@ -247,35 +304,29 @@ class EntityLinker(TrainablePipe):
|
||||||
if not examples:
|
if not examples:
|
||||||
return losses
|
return losses
|
||||||
validate_examples(examples, "EntityLinker.update")
|
validate_examples(examples, "EntityLinker.update")
|
||||||
sentence_docs = []
|
|
||||||
for eg in examples:
|
|
||||||
sentences = [s for s in eg.reference.sents]
|
|
||||||
kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
|
|
||||||
for ent in eg.reference.ents:
|
|
||||||
# KB ID of the first token is the same as the whole span
|
|
||||||
kb_id = kb_ids[ent.start]
|
|
||||||
if kb_id:
|
|
||||||
try:
|
|
||||||
# find the sentence in the list of sentences.
|
|
||||||
sent_index = sentences.index(ent.sent)
|
|
||||||
except AttributeError:
|
|
||||||
# Catch the exception when ent.sent is None and provide a user-friendly warning
|
|
||||||
raise RuntimeError(Errors.E030) from None
|
|
||||||
# get n previous sentences, if there are any
|
|
||||||
start_sentence = max(0, sent_index - self.n_sents)
|
|
||||||
# get n posterior sentences, or as many < n as there are
|
|
||||||
end_sentence = min(len(sentences) - 1, sent_index + self.n_sents)
|
|
||||||
# get token positions
|
|
||||||
start_token = sentences[start_sentence].start
|
|
||||||
end_token = sentences[end_sentence].end
|
|
||||||
# append that span as a doc to training
|
|
||||||
sent_doc = eg.predicted[start_token:end_token].as_doc()
|
|
||||||
sentence_docs.append(sent_doc)
|
|
||||||
set_dropout_rate(self.model, drop)
|
set_dropout_rate(self.model, drop)
|
||||||
if not sentence_docs:
|
docs = [eg.predicted for eg in examples]
|
||||||
warnings.warn(Warnings.W093.format(name="Entity Linker"))
|
# save to restore later
|
||||||
|
old_ents = [doc.ents for doc in docs]
|
||||||
|
|
||||||
|
for doc, ex in zip(docs, examples):
|
||||||
|
if self.use_gold_ents:
|
||||||
|
doc.ents = ex.reference.ents
|
||||||
|
else:
|
||||||
|
# only keep matching ents
|
||||||
|
doc.ents = ex.get_matching_ents()
|
||||||
|
|
||||||
|
# make sure we have something to learn from, if not, short-circuit
|
||||||
|
if not self.batch_has_learnable_example(examples):
|
||||||
return losses
|
return losses
|
||||||
sentence_encodings, bp_context = self.model.begin_update(sentence_docs)
|
|
||||||
|
sentence_encodings, bp_context = self.model.begin_update(docs)
|
||||||
|
|
||||||
|
# now restore the ents
|
||||||
|
for doc, old in zip(docs, old_ents):
|
||||||
|
doc.ents = old
|
||||||
|
|
||||||
loss, d_scores = self.get_loss(
|
loss, d_scores = self.get_loss(
|
||||||
sentence_encodings=sentence_encodings, examples=examples
|
sentence_encodings=sentence_encodings, examples=examples
|
||||||
)
|
)
|
||||||
|
@ -288,24 +339,38 @@ class EntityLinker(TrainablePipe):
|
||||||
def get_loss(self, examples: Iterable[Example], sentence_encodings: Floats2d):
|
def get_loss(self, examples: Iterable[Example], sentence_encodings: Floats2d):
|
||||||
validate_examples(examples, "EntityLinker.get_loss")
|
validate_examples(examples, "EntityLinker.get_loss")
|
||||||
entity_encodings = []
|
entity_encodings = []
|
||||||
|
eidx = 0 # indices in gold entities to keep
|
||||||
|
keep_ents = [] # indices in sentence_encodings to keep
|
||||||
|
|
||||||
for eg in examples:
|
for eg in examples:
|
||||||
kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
|
kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
|
||||||
|
|
||||||
for ent in eg.reference.ents:
|
for ent in eg.reference.ents:
|
||||||
kb_id = kb_ids[ent.start]
|
kb_id = kb_ids[ent.start]
|
||||||
if kb_id:
|
if kb_id:
|
||||||
entity_encoding = self.kb.get_vector(kb_id)
|
entity_encoding = self.kb.get_vector(kb_id)
|
||||||
entity_encodings.append(entity_encoding)
|
entity_encodings.append(entity_encoding)
|
||||||
|
keep_ents.append(eidx)
|
||||||
|
|
||||||
|
eidx += 1
|
||||||
entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
|
entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
|
||||||
if sentence_encodings.shape != entity_encodings.shape:
|
selected_encodings = sentence_encodings[keep_ents]
|
||||||
|
|
||||||
|
# If the entity encodings list is empty, then
|
||||||
|
if selected_encodings.shape != entity_encodings.shape:
|
||||||
err = Errors.E147.format(
|
err = Errors.E147.format(
|
||||||
method="get_loss", msg="gold entities do not match up"
|
method="get_loss", msg="gold entities do not match up"
|
||||||
)
|
)
|
||||||
raise RuntimeError(err)
|
raise RuntimeError(err)
|
||||||
# TODO: fix typing issue here
|
# TODO: fix typing issue here
|
||||||
gradients = self.distance.get_grad(sentence_encodings, entity_encodings) # type: ignore
|
gradients = self.distance.get_grad(selected_encodings, entity_encodings) # type: ignore
|
||||||
loss = self.distance.get_loss(sentence_encodings, entity_encodings) # type: ignore
|
# to match the input size, we need to give a zero gradient for items not in the kb
|
||||||
|
out = self.model.ops.alloc2f(*sentence_encodings.shape)
|
||||||
|
out[keep_ents] = gradients
|
||||||
|
|
||||||
|
loss = self.distance.get_loss(selected_encodings, entity_encodings) # type: ignore
|
||||||
loss = loss / len(entity_encodings)
|
loss = loss / len(entity_encodings)
|
||||||
return float(loss), gradients
|
return float(loss), out
|
||||||
|
|
||||||
def predict(self, docs: Iterable[Doc]) -> List[str]:
|
def predict(self, docs: Iterable[Doc]) -> List[str]:
|
||||||
"""Apply the pipeline's model to a batch of docs, without modifying them.
|
"""Apply the pipeline's model to a batch of docs, without modifying them.
|
||||||
|
|
3
spacy/pipeline/legacy/__init__.py
Normal file
3
spacy/pipeline/legacy/__init__.py
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
from .entity_linker import EntityLinker_v1
|
||||||
|
|
||||||
|
__all__ = ["EntityLinker_v1"]
|
427
spacy/pipeline/legacy/entity_linker.py
Normal file
427
spacy/pipeline/legacy/entity_linker.py
Normal file
|
@ -0,0 +1,427 @@
|
||||||
|
# This file is present to provide a prior version of the EntityLinker component
|
||||||
|
# for backwards compatibility. For details see #9669.
|
||||||
|
|
||||||
|
from typing import Optional, Iterable, Callable, Dict, Union, List, Any
|
||||||
|
from thinc.types import Floats2d
|
||||||
|
from pathlib import Path
|
||||||
|
from itertools import islice
|
||||||
|
import srsly
|
||||||
|
import random
|
||||||
|
from thinc.api import CosineDistance, Model, Optimizer, Config
|
||||||
|
from thinc.api import set_dropout_rate
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from ...kb import KnowledgeBase, Candidate
|
||||||
|
from ...ml import empty_kb
|
||||||
|
from ...tokens import Doc, Span
|
||||||
|
from ..pipe import deserialize_config
|
||||||
|
from ..trainable_pipe import TrainablePipe
|
||||||
|
from ...language import Language
|
||||||
|
from ...vocab import Vocab
|
||||||
|
from ...training import Example, validate_examples, validate_get_examples
|
||||||
|
from ...errors import Errors, Warnings
|
||||||
|
from ...util import SimpleFrozenList, registry
|
||||||
|
from ... import util
|
||||||
|
from ...scorer import Scorer
|
||||||
|
|
||||||
|
# See #9050
|
||||||
|
BACKWARD_OVERWRITE = True
|
||||||
|
|
||||||
|
|
||||||
|
def entity_linker_score(examples, **kwargs):
    """Score entity links via Scorer.score_links.

    `EntityLinker_v1.NIL` is passed as the only negative label; any extra
    keyword arguments are forwarded unchanged.
    """
    nil_labels = [EntityLinker_v1.NIL]
    return Scorer.score_links(examples, negative_labels=nil_labels, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class EntityLinker_v1(TrainablePipe):
|
||||||
|
"""Pipeline component for named entity linking.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/entitylinker
|
||||||
|
"""
|
||||||
|
|
||||||
|
NIL = "NIL" # string used to refer to a non-existing link
|
||||||
|
|
||||||
|
def __init__(
    self,
    vocab: Vocab,
    model: Model,
    name: str = "entity_linker",
    *,
    labels_discard: Iterable[str],
    n_sents: int,
    incl_prior: bool,
    incl_context: bool,
    entity_vector_length: int,
    get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
    overwrite: bool = BACKWARD_OVERWRITE,
    scorer: Optional[Callable] = entity_linker_score,
) -> None:
    """Set up the legacy (v1) entity linker.

    vocab (Vocab): The shared vocabulary.
    model (thinc.api.Model): The Thinc Model powering the pipeline component.
    name (str): The component instance name, used as the key for this
        component's entry in the losses during training.
    labels_discard (Iterable[str]): NER labels that will automatically get
        a "NIL" prediction. Stored as a list.
    n_sents (int): The number of neighbouring sentences to take into account.
    incl_prior (bool): Whether to include prior probabilities from the KB
        in the model.
    incl_context (bool): Whether to include the local context in the model.
    entity_vector_length (int): Size of encoding vectors in the KB; used to
        create the initial empty KB.
    get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]):
        Function that produces a list of candidates, given a certain
        knowledge base and a textual mention.
    overwrite (bool): Stored in `self.cfg` under "overwrite".
    scorer (Optional[Callable]): The scoring method. Defaults to
        `entity_linker_score`.

    DOCS: https://spacy.io/api/entitylinker#init
    """
    # Core pipe state.
    self.vocab = vocab
    self.model = model
    self.name = name
    self.scorer = scorer
    self.cfg: Dict[str, Any] = {"overwrite": overwrite}
    # Linking behaviour.
    self.labels_discard = list(labels_discard)
    self.n_sents = n_sents  # how many neighbour sentences to take into account
    self.incl_prior = incl_prior
    self.incl_context = incl_context
    self.get_candidates = get_candidates
    self.distance = CosineDistance(normalize=False)
    # Start from an empty KB. To load a predefined one, specify it in
    # 'initialize'.
    make_kb = empty_kb(entity_vector_length)
    self.kb = make_kb(self.vocab)
|
||||||
|
|
||||||
|
def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
    """Replace this pipe's KB with one created from its own vocab.

    kb_loader (Callable[[Vocab], KnowledgeBase]): Called with `self.vocab`;
        its return value becomes `self.kb`.
    RAISES (ValueError): With E885 if `kb_loader` is not callable.
    """
    if callable(kb_loader):
        self.kb = kb_loader(self.vocab)
        return
    raise ValueError(Errors.E885.format(arg_type=type(kb_loader)))
|
||||||
|
|
||||||
|
def validate_kb(self) -> None:
    """Check that a knowledge base is set and has non-zero length.

    RAISES (ValueError): With E1018 if `self.kb` is None, or with E139 if
        `len(self.kb)` is 0.
    """
    kb = self.kb
    if kb is None:
        raise ValueError(Errors.E1018.format(name=self.name))
    if not len(kb):
        raise ValueError(Errors.E139.format(name=self.name))
|
||||||
|
|
||||||
|
def initialize(
    self,
    get_examples: Callable[[], Iterable[Example]],
    *,
    nlp: Optional[Language] = None,
    kb_loader: Optional[Callable[[Vocab], KnowledgeBase]] = None,
):
    """Initialize the pipe for training from a sample of examples.

    Up to ten examples are read; their `x` docs, together with one
    zero-allocated vector per doc of the KB's entity vector length, are
    passed to the model's own `initialize`.

    get_examples (Callable[[], Iterable[Example]]): Function that
        returns a representative sample of gold-standard Example objects.
    nlp (Language): The current nlp object the component is part of.
        Not read by this method.
    kb_loader (Callable[[Vocab], KnowledgeBase]): A function that creates a
        KnowledgeBase from a Vocab instance. Note that providing this
        argument will overwrite all data accumulated in the current KB.
        Use this only when loading a KB as-such from file.

    DOCS: https://spacy.io/api/entitylinker#initialize
    """
    validate_get_examples(get_examples, "EntityLinker_v1.initialize")
    if kb_loader is not None:
        self.set_kb(kb_loader)
    self.validate_kb()
    ops = self.model.ops
    width = self.kb.entity_vector_length
    doc_sample = [eg.x for eg in islice(get_examples(), 10)]
    # One placeholder output vector per sampled doc.
    vector_sample = [ops.alloc1f(width) for _ in doc_sample]
    assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
    assert len(vector_sample) > 0, Errors.E923.format(name=self.name)
    sample_outputs = ops.asarray(vector_sample, dtype="float32")
    self.model.initialize(X=doc_sample, Y=sample_outputs)
|
||||||
|
|
||||||
|
def update(
    self,
    examples: Iterable[Example],
    *,
    drop: float = 0.0,
    sgd: Optional[Optimizer] = None,
    losses: Optional[Dict[str, float]] = None,
) -> Dict[str, float]:
    """Learn from a batch of documents and gold-standard information,
    updating the pipe's model.

    For every gold entity that has a KB ID, the surrounding sentence window
    of the predicted doc is copied into its own doc. Those docs are encoded
    by the model, and the loss from `get_loss` is backpropagated.

    examples (Iterable[Example]): A batch of Example objects.
    drop (float): The dropout rate.
    sgd (thinc.api.Optimizer): The optimizer. If None, the optimizer step
        is not run.
    losses (Dict[str, float]): Optional record of the loss during training.
        Updated using the component name as the key.
    RETURNS (Dict[str, float]): The updated losses dictionary.
    RAISES (RuntimeError): With E030 if looking up `ent.sent` raises an
        AttributeError.

    DOCS: https://spacy.io/api/entitylinker#update
    """
    self.validate_kb()
    if losses is None:
        losses = {}
    losses.setdefault(self.name, 0.0)
    if not examples:
        return losses
    validate_examples(examples, "EntityLinker_v1.update")
    sentence_docs = []
    for eg in examples:
        sentences = list(eg.reference.sents)
        kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
        for ent in eg.reference.ents:
            # KB ID of the first token is the same as the whole span
            if not kb_ids[ent.start]:
                continue
            try:
                # find the sentence in the list of sentences.
                sent_index = sentences.index(ent.sent)
            except AttributeError:
                # Catch the exception when ent.sent is None and provide a user-friendly warning
                raise RuntimeError(Errors.E030) from None
            # Window of up to n_sents sentences on each side, clipped to
            # the sentences that exist.
            first_sent = sentences[max(0, sent_index - self.n_sents)]
            last_sent = sentences[min(len(sentences) - 1, sent_index + self.n_sents)]
            # append that span as a doc to training
            context = eg.predicted[first_sent.start : last_sent.end]
            sentence_docs.append(context.as_doc())
    set_dropout_rate(self.model, drop)
    if not sentence_docs:
        warnings.warn(Warnings.W093.format(name="Entity Linker"))
        return losses
    sentence_encodings, bp_context = self.model.begin_update(sentence_docs)
    loss, d_scores = self.get_loss(
        sentence_encodings=sentence_encodings, examples=examples
    )
    bp_context(d_scores)
    if sgd is not None:
        self.finish_update(sgd)
    losses[self.name] += loss
    return losses
|
||||||
|
|
||||||
|
def get_loss(self, examples: Iterable[Example], sentence_encodings: Floats2d):
    """Compute the loss and gradient for a batch of context encodings.

    One KB vector is collected for every reference entity whose first token
    has an aligned KB ID, in the order the entities occur in the examples.
    These vectors are compared against `sentence_encodings` using
    `self.distance`.

    examples (Iterable[Example]): The batch of examples.
    sentence_encodings (Floats2d): The context encodings to score.
    RETURNS (tuple): The loss as a float, divided by the number of collected
        entity vectors, and the gradient returned by `self.distance.get_grad`.
    RAISES (RuntimeError): If the shape of `sentence_encodings` differs from
        the shape of the collected entity vectors.
    """
    validate_examples(examples, "EntityLinker_v1.get_loss")
    gold_vectors = []
    for example in examples:
        aligned_ids = example.get_aligned("ENT_KB_ID", as_string=True)
        for span in example.reference.ents:
            # the KB ID of the first token stands for the whole span
            gold_id = aligned_ids[span.start]
            if not gold_id:
                continue
            gold_vectors.append(self.kb.get_vector(gold_id))
    gold_matrix = self.model.ops.asarray(gold_vectors, dtype="float32")
    if sentence_encodings.shape != gold_matrix.shape:
        raise RuntimeError(
            Errors.E147.format(method="get_loss", msg="gold entities do not match up")
        )
    # TODO: fix typing issue here
    gradients = self.distance.get_grad(sentence_encodings, gold_matrix)  # type: ignore
    total = self.distance.get_loss(sentence_encodings, gold_matrix)  # type: ignore
    return float(total / len(gold_matrix)), gradients
|
||||||
|
|
||||||
|
def predict(self, docs: Iterable[Doc]) -> List[str]:
    """Apply the pipeline's model to a batch of docs, without modifying them.
    Returns the KB IDs for each entity in each doc, including NIL if there is
    no prediction.

    docs (Iterable[Doc]): The documents to predict.
    RETURNS (List[str]): The model's prediction for each entity, as one flat
        list following the order of the docs and of their entities.
    RAISES (RuntimeError): If the candidate vectors and prior probabilities
        differ in length, or if the number of results does not match the
        number of entities seen.
    RAISES (ValueError): If the similarities and the prior probabilities
        differ in shape.

    DOCS: https://spacy.io/api/entitylinker#predict
    """
    self.validate_kb()
    entity_count = 0
    final_kb_ids: List[str] = []
    if not docs:
        return final_kb_ids
    if isinstance(docs, Doc):
        docs = [docs]
    xp = self.model.ops.xp
    for doc in docs:
        sentences = list(doc.sents)
        if len(doc) == 0:
            continue
        # Looping through each entity (TODO: rewrite)
        for ent in doc.ents:
            sent_index = sentences.index(ent.sent)
            # get n_neighbour sentences, clipped to the length of the document
            first_sent = sentences[max(0, sent_index - self.n_sents)]
            last_sent = sentences[min(len(sentences) - 1, sent_index + self.n_sents)]
            sent_doc = doc[first_sent.start : last_sent.end].as_doc()
            # currently, the context is the same for each entity in a sentence (should be refined)
            if self.incl_context:
                sentence_encoding = self.model.predict([sent_doc])[0]
                sentence_encoding_t = sentence_encoding.T
                sentence_norm = xp.linalg.norm(sentence_encoding_t)
            entity_count += 1
            if ent.label_ in self.labels_discard:
                # ignoring this entity - setting to NIL
                final_kb_ids.append(self.NIL)
                continue
            candidates = list(self.get_candidates(self.kb, ent))
            if not candidates:
                # no prediction possible for this entity - setting to NIL
                final_kb_ids.append(self.NIL)
                continue
            if len(candidates) == 1:
                # shortcut for efficiency reasons: take the 1 candidate
                # TODO: thresholding
                final_kb_ids.append(candidates[0].entity_)
                continue
            random.shuffle(candidates)
            # all prior probabilities are 0 if incl_prior=False
            if self.incl_prior:
                prior_probs = xp.asarray([c.prior_prob for c in candidates])
            else:
                prior_probs = xp.asarray([0.0 for _ in candidates])
            scores = prior_probs
            # add in similarity from the context
            if self.incl_context:
                entity_encodings = xp.asarray([c.entity_vector for c in candidates])
                entity_norm = xp.linalg.norm(entity_encodings, axis=1)
                if len(entity_encodings) != len(prior_probs):
                    raise RuntimeError(
                        Errors.E147.format(
                            method="predict",
                            msg="vectors not of equal length",
                        )
                    )
                # cosine similarity
                sims = xp.dot(entity_encodings, sentence_encoding_t) / (
                    sentence_norm * entity_norm
                )
                if sims.shape != prior_probs.shape:
                    raise ValueError(Errors.E161)
                scores = prior_probs + sims - (prior_probs * sims)
            # TODO: thresholding
            best_candidate = candidates[scores.argmax().item()]
            final_kb_ids.append(best_candidate.entity_)
    if len(final_kb_ids) != entity_count:
        raise RuntimeError(
            Errors.E147.format(
                method="predict", msg="result variables not of equal length"
            )
        )
    return final_kb_ids
|
||||||
|
|
||||||
|
def set_annotations(self, docs: Iterable[Doc], kb_ids: List[str]) -> None:
    """Modify a batch of documents, using pre-computed scores.

    docs (Iterable[Doc]): The documents to modify.
    kb_ids (List[str]): The IDs to set, produced by EntityLinker.predict.
    RAISES (ValueError): If the number of IDs differs from the total number
        of entities in the docs.

    DOCS: https://spacy.io/api/entitylinker#set_annotations
    """
    # Collect the entities in a single pass. `docs` is only required to be an
    # iterable, so counting the entities first and then looping over `docs`
    # again would find a one-shot iterator already exhausted and silently
    # annotate nothing.
    ents = [ent for doc in docs for ent in doc.ents]
    if len(ents) != len(kb_ids):
        raise ValueError(Errors.E148.format(ents=len(ents), ids=len(kb_ids)))
    overwrite = self.cfg["overwrite"]
    for ent, kb_id in zip(ents, kb_ids):
        for token in ent:
            # existing annotation is kept unless overwriting is enabled
            if token.ent_kb_id == 0 or overwrite:
                token.ent_kb_id_ = kb_id
|
||||||
|
|
||||||
|
def to_bytes(self, *, exclude=tuple()):
    """Serialize the pipe to a bytestring.

    exclude (Iterable[str]): String names of serialization fields to exclude.
    RETURNS (bytes): The serialized object.

    DOCS: https://spacy.io/api/entitylinker#to_bytes
    """
    self._validate_serialization_attrs()

    def dump_cfg():
        return srsly.json_dumps(self.cfg)

    def dump_vocab():
        return self.vocab.to_bytes(exclude=exclude)

    getters = {}
    # the config is only serialized when the component actually has one
    if getattr(self, "cfg", None) is not None:
        getters["cfg"] = dump_cfg
    getters.update(vocab=dump_vocab, kb=self.kb.to_bytes, model=self.model.to_bytes)
    return util.to_bytes(getters, exclude)
|
||||||
|
|
||||||
|
def from_bytes(self, bytes_data, *, exclude=tuple()):
    """Load the pipe from a bytestring.

    exclude (Iterable[str]): String names of serialization fields to exclude.
    RETURNS (TrainablePipe): The loaded object.
    RAISES (ValueError): If loading the model bytes fails with an
        AttributeError.

    DOCS: https://spacy.io/api/entitylinker#from_bytes
    """
    self._validate_serialization_attrs()

    def load_cfg(data):
        return self.cfg.update(srsly.json_loads(data))

    def load_vocab(data):
        return self.vocab.from_bytes(data, exclude=exclude)

    def load_kb(data):
        return self.kb.from_bytes(data)

    def load_model(data):
        try:
            self.model.from_bytes(data)
        except AttributeError:
            raise ValueError(Errors.E149) from None

    setters = {}
    # the config is only restored when the component actually has one
    if getattr(self, "cfg", None) is not None:
        setters["cfg"] = load_cfg
    setters.update(vocab=load_vocab, kb=load_kb, model=load_model)
    util.from_bytes(bytes_data, setters, exclude)
    return self
|
||||||
|
|
||||||
|
def to_disk(
    self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
) -> None:
    """Serialize the pipe to disk.

    path (str / Path): Path to a directory.
    exclude (Iterable[str]): String names of serialization fields to exclude.

    DOCS: https://spacy.io/api/entitylinker#to_disk
    """
    # one writer per serialization field, each called with that field's path
    writers = {
        "vocab": lambda location: self.vocab.to_disk(location, exclude=exclude),
        "cfg": lambda location: srsly.write_json(location, self.cfg),
        "kb": lambda location: self.kb.to_disk(location),
        "model": lambda location: self.model.to_disk(location),
    }
    util.to_disk(path, writers, exclude)
|
||||||
|
|
||||||
|
def from_disk(
    self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
) -> "EntityLinker_v1":
    """Load the pipe from disk. Modifies the object in place and returns it.

    path (str / Path): Path to a directory.
    exclude (Iterable[str]): String names of serialization fields to exclude.
    RETURNS (EntityLinker): The modified EntityLinker object.
    RAISES (ValueError): If reading the model file fails with an
        AttributeError.

    DOCS: https://spacy.io/api/entitylinker#from_disk
    """

    def load_model(location):
        try:
            with location.open("rb") as model_file:
                self.model.from_bytes(model_file.read())
        except AttributeError:
            raise ValueError(Errors.E149) from None

    # one reader per serialization field, each called with that field's path
    readers: Dict[str, Callable[[Any], Any]] = {
        "cfg": lambda location: self.cfg.update(deserialize_config(location)),
        "vocab": lambda location: self.vocab.from_disk(location, exclude=exclude),
        "kb": lambda location: self.kb.from_disk(location),
        "model": load_model,
    }
    util.from_disk(path, readers, exclude)
    return self
|
||||||
|
|
||||||
|
def rehearse(self, examples, *, sgd=None, losses=None, **config):
    """Not implemented for this component.

    RAISES (NotImplementedError): Always.
    """
    raise NotImplementedError
|
||||||
|
|
||||||
|
def add_label(self, label):
    """Not implemented for this component.

    RAISES (NotImplementedError): Always.
    """
    raise NotImplementedError
|
|
@ -9,6 +9,9 @@ from spacy.compat import pickle
|
||||||
from spacy.kb import Candidate, KnowledgeBase, get_candidates
|
from spacy.kb import Candidate, KnowledgeBase, get_candidates
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
from spacy.ml import load_kb
|
from spacy.ml import load_kb
|
||||||
|
from spacy.pipeline import EntityLinker
|
||||||
|
from spacy.pipeline.legacy import EntityLinker_v1
|
||||||
|
from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
|
||||||
from spacy.scorer import Scorer
|
from spacy.scorer import Scorer
|
||||||
from spacy.tests.util import make_tempdir
|
from spacy.tests.util import make_tempdir
|
||||||
from spacy.tokens import Span
|
from spacy.tokens import Span
|
||||||
|
@ -168,6 +171,45 @@ def test_issue7065_b():
|
||||||
assert doc
|
assert doc
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_entities():
    """Training and evaluating on data without any entities must not crash."""
    vector_length = 3
    train_data = [
        ("The sky is blue.", {"sent_starts": [1, 0, 0, 0, 0]}),
    ]
    nlp = English()
    train_examples = [
        Example.from_dict(nlp(text), annotation) for text, annotation in train_data
    ]

    def create_kb(vocab):
        # artificial KB holding a single entity and one alias for it
        kb = KnowledgeBase(vocab, entity_vector_length=vector_length)
        kb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
        kb.add_alias("Russ Cochran", ["Q2146908"], [0.9])
        return kb

    # Create and train the Entity Linker
    entity_linker = nlp.add_pipe("entity_linker", last=True)
    entity_linker.set_kb(create_kb)
    optimizer = nlp.initialize(get_examples=lambda: train_examples)
    for _ in range(2):
        nlp.update(train_examples, sgd=optimizer, losses={})

    # adding additional components that are required for the entity_linker
    nlp.add_pipe("sentencizer", first=True)

    # this will run the pipeline on the examples and shouldn't crash
    nlp.evaluate(train_examples)
|
||||||
|
|
||||||
|
|
||||||
def test_partial_links():
|
def test_partial_links():
|
||||||
# Test that having some entities on the doc without gold links, doesn't crash
|
# Test that having some entities on the doc without gold links, doesn't crash
|
||||||
TRAIN_DATA = [
|
TRAIN_DATA = [
|
||||||
|
@ -650,7 +692,7 @@ TRAIN_DATA = [
|
||||||
"sent_starts": [1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}),
|
"sent_starts": [1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}),
|
||||||
("Russ Cochran his reprints include EC Comics.",
|
("Russ Cochran his reprints include EC Comics.",
|
||||||
{"links": {(0, 12): {"Q7381115": 1.0, "Q2146908": 0.0}},
|
{"links": {(0, 12): {"Q7381115": 1.0, "Q2146908": 0.0}},
|
||||||
"entities": [(0, 12, "PERSON")],
|
"entities": [(0, 12, "PERSON"), (34, 43, "ART")],
|
||||||
"sent_starts": [1, -1, 0, 0, 0, 0, 0, 0]}),
|
"sent_starts": [1, -1, 0, 0, 0, 0, 0, 0]}),
|
||||||
("Russ Cochran has been publishing comic art.",
|
("Russ Cochran has been publishing comic art.",
|
||||||
{"links": {(0, 12): {"Q7381115": 1.0, "Q2146908": 0.0}},
|
{"links": {(0, 12): {"Q7381115": 1.0, "Q2146908": 0.0}},
|
||||||
|
@ -693,6 +735,7 @@ def test_overfitting_IO():
|
||||||
|
|
||||||
# Create the Entity Linker component and add it to the pipeline
|
# Create the Entity Linker component and add it to the pipeline
|
||||||
entity_linker = nlp.add_pipe("entity_linker", last=True)
|
entity_linker = nlp.add_pipe("entity_linker", last=True)
|
||||||
|
assert isinstance(entity_linker, EntityLinker)
|
||||||
entity_linker.set_kb(create_kb)
|
entity_linker.set_kb(create_kb)
|
||||||
assert "Q2146908" in entity_linker.vocab.strings
|
assert "Q2146908" in entity_linker.vocab.strings
|
||||||
assert "Q2146908" in entity_linker.kb.vocab.strings
|
assert "Q2146908" in entity_linker.kb.vocab.strings
|
||||||
|
@ -922,3 +965,109 @@ def test_scorer_links():
|
||||||
|
|
||||||
assert scores["nel_micro_p"] == 2 / 3
|
assert scores["nel_micro_p"] == 2 / 3
|
||||||
assert scores["nel_micro_r"] == 2 / 4
|
assert scores["nel_micro_r"] == 2 / 4
|
||||||
|
|
||||||
|
|
||||||
|
# fmt: off
@pytest.mark.parametrize(
    "name,config",
    [
        ("entity_linker", {"@architectures": "spacy.EntityLinker.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL}),
        ("entity_linker", {"@architectures": "spacy.EntityLinker.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL}),
    ],
)
# fmt: on
def test_legacy_architectures(name, config):
    """Ensure that the legacy architectures still work."""
    vector_length = 3
    nlp = English()
    train_examples = [
        Example.from_dict(nlp.make_doc(text), annotation)
        for text, annotation in TRAIN_DATA
    ]

    def create_kb(vocab):
        # two entities sharing one ambiguous alias
        kb = KnowledgeBase(vocab, entity_vector_length=vector_length)
        for qid, vector in (("Q2146908", [6, -4, 3]), ("Q7381115", [9, 1, -7])):
            kb.add_entity(entity=qid, freq=12, entity_vector=vector)
        kb.add_alias(
            alias="Russ Cochran",
            entities=["Q2146908", "Q7381115"],
            probabilities=[0.5, 0.5],
        )
        return kb

    entity_linker = nlp.add_pipe(name, config={"model": config})
    # the component class that gets created depends on the model architecture
    is_v1 = config["@architectures"] == "spacy.EntityLinker.v1"
    expected_cls = EntityLinker_v1 if is_v1 else EntityLinker
    assert isinstance(entity_linker, expected_cls)
    entity_linker.set_kb(create_kb)
    optimizer = nlp.initialize(get_examples=lambda: train_examples)

    for _ in range(2):
        nlp.update(train_examples, sgd=optimizer, losses={})
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "patterns",
    [
        # perfect case
        [{"label": "CHARACTER", "pattern": "Kirby"}],
        # typo for false negative
        [{"label": "PERSON", "pattern": "Korby"}],
        # random stuff for false positive
        [{"label": "IS", "pattern": "is"}, {"label": "COLOR", "pattern": "pink"}],
    ],
)
def test_no_gold_ents(patterns):
    """Test that annotating components work when `use_gold_ents` is False."""
    vector_length = 3
    train_data = [
        (
            "Kirby is pink",
            {
                "links": {(0, 5): {"Q613241": 1.0}},
                "entities": [(0, 5, "CHARACTER")],
                "sent_starts": [1, 0, 0],
            },
        ),
    ]
    nlp = English()
    train_examples = [
        Example.from_dict(nlp(text), annotation) for text, annotation in train_data
    ]

    # Create a ruler to mark entities
    ruler = nlp.add_pipe("entity_ruler")
    ruler.add_patterns(patterns)

    # Apply ruler to examples. In a real pipeline this would be an annotating component.
    for example in train_examples:
        example.predicted = ruler(example.predicted)

    def create_kb(vocab):
        # create artificial KB
        kb = KnowledgeBase(vocab, entity_vector_length=vector_length)
        kb.add_entity(entity="Q613241", freq=12, entity_vector=[6, -4, 3])
        kb.add_alias("Kirby", ["Q613241"], [0.9])
        # Placeholder
        kb.add_entity(entity="pink", freq=12, entity_vector=[7, 2, -5])
        kb.add_alias("pink", ["pink"], [0.9])
        return kb

    # Create and train the Entity Linker
    entity_linker = nlp.add_pipe(
        "entity_linker", config={"use_gold_ents": False}, last=True
    )
    entity_linker.set_kb(create_kb)
    assert entity_linker.use_gold_ents == False

    optimizer = nlp.initialize(get_examples=lambda: train_examples)
    for _ in range(2):
        nlp.update(train_examples, sgd=optimizer, losses={})

    # adding additional components that are required for the entity_linker
    nlp.add_pipe("sentencizer", first=True)

    # this will run the pipeline on the examples and shouldn't crash
    nlp.evaluate(train_examples)
|
||||||
|
|
|
@ -256,6 +256,29 @@ cdef class Example:
|
||||||
x_ents, x_tags = self.get_aligned_ents_and_ner()
|
x_ents, x_tags = self.get_aligned_ents_and_ner()
|
||||||
return x_tags
|
return x_tags
|
||||||
|
|
||||||
|
def get_matching_ents(self, check_label=True):
    """Return entities that are shared between predicted and reference docs.

    If `check_label` is True, entities must have matching labels to be
    kept. Otherwise only the character indices need to match.

    check_label (bool): Whether the entity labels have to match as well.
    RETURNS (list): The entities of the predicted doc that have a counterpart
        in the reference doc, in the order they occur in the predicted doc.
    """
    # Iterate over `.ents`, not over the docs themselves: iterating a doc
    # yields its tokens, which have no `start_char` / `end_char`.
    gold = {
        (ent.start_char, ent.end_char): ent.label for ent in self.reference.ents
    }

    keep = []
    for ent in self.predicted.ents:
        key = (ent.start_char, ent.end_char)
        if key not in gold:
            continue
        if check_label and ent.label != gold[key]:
            continue
        keep.append(ent)

    return keep
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
return {
|
return {
|
||||||
"doc_annotation": {
|
"doc_annotation": {
|
||||||
|
|
|
@ -858,13 +858,13 @@ into the "real world". This requires 3 main components:
|
||||||
- A machine learning [`Model`](https://thinc.ai/docs/api-model) that picks the
|
- A machine learning [`Model`](https://thinc.ai/docs/api-model) that picks the
|
||||||
most plausible ID from the set of candidates.
|
most plausible ID from the set of candidates.
|
||||||
|
|
||||||
### spacy.EntityLinker.v1 {#EntityLinker}
|
### spacy.EntityLinker.v2 {#EntityLinker}
|
||||||
|
|
||||||
> #### Example Config
|
> #### Example Config
|
||||||
>
|
>
|
||||||
> ```ini
|
> ```ini
|
||||||
> [model]
|
> [model]
|
||||||
> @architectures = "spacy.EntityLinker.v1"
|
> @architectures = "spacy.EntityLinker.v2"
|
||||||
> nO = null
|
> nO = null
|
||||||
>
|
>
|
||||||
> [model.tok2vec]
|
> [model.tok2vec]
|
||||||
|
|
|
@ -59,6 +59,7 @@ architectures and their arguments and hyperparameters.
|
||||||
| `incl_context` | Whether or not to include the local context in the model. Defaults to `True`. ~~bool~~ |
|
| `incl_context` | Whether or not to include the local context in the model. Defaults to `True`. ~~bool~~ |
|
||||||
| `model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. Defaults to [EntityLinker](/api/architectures#EntityLinker). ~~Model~~ |
|
| `model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. Defaults to [EntityLinker](/api/architectures#EntityLinker). ~~Model~~ |
|
||||||
| `entity_vector_length` | Size of encoding vectors in the KB. Defaults to `64`. ~~int~~ |
|
| `entity_vector_length` | Size of encoding vectors in the KB. Defaults to `64`. ~~int~~ |
|
||||||
|
| `use_gold_ents` | Whether to copy entities from the gold docs or not. Defaults to `True`. If `False`, entities must be set in the training data or by an annotating component in the pipeline. ~~bool~~ |
|
||||||
| `get_candidates` | Function that generates plausible candidates for a given `Span` object. Defaults to [CandidateGenerator](/api/architectures#CandidateGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Span], Iterable[Candidate]]~~ |
|
| `get_candidates` | Function that generates plausible candidates for a given `Span` object. Defaults to [CandidateGenerator](/api/architectures#CandidateGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Span], Iterable[Candidate]]~~ |
|
||||||
| `overwrite` <Tag variant="new">3.2</Tag> | Whether existing annotation is overwritten. Defaults to `True`. ~~bool~~ |
|
| `overwrite` <Tag variant="new">3.2</Tag> | Whether existing annotation is overwritten. Defaults to `True`. ~~bool~~ |
|
||||||
| `scorer` <Tag variant="new">3.2</Tag> | The scoring method. Defaults to [`Scorer.score_links`](/api/scorer#score_links). ~~Optional[Callable]~~ |
|
| `scorer` <Tag variant="new">3.2</Tag> | The scoring method. Defaults to [`Scorer.score_links`](/api/scorer#score_links). ~~Optional[Callable]~~ |
|
||||||
|
|
Loading…
Reference in New Issue
Block a user