trainable_lemmatizer/entity_linker: add store_activations option

commit 1c9be0d8ab
parent 8772b9ccc4
Author: Daniël de Kok
Date:   2022-06-23 15:47:00 +02:00

5 changed files with 211 additions and 19 deletions
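The change adds a `store_activations` option (a bool, or a list of activation names) to both components: when enabled, `predict` returns a dict of activations instead of bare predictions, and `set_annotations` copies the requested entries into `Doc.activations`, keyed by component name. A minimal usage sketch, assuming a trained pipeline `nlp` that contains a "trainable_lemmatizer" component (this snippet is illustrative and not part of the commit itself):

    lemmatizer = nlp.get_pipe("trainable_lemmatizer")
    lemmatizer.store_activations = True           # store every activation the component exposes
    # lemmatizer.store_activations = ["probs"]    # ...or only a named subset

    doc = nlp("This is a test.")
    # Activations are keyed by component name, then by activation name.
    probs = doc.activations["trainable_lemmatizer"]["probs"]      # per-token label scores
    guesses = doc.activations["trainable_lemmatizer"]["guesses"]  # predicted edit-tree ids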

File: spacy/pipeline/edit_tree_lemmatizer.py

@@ -7,7 +7,7 @@ import numpy as np
 import srsly
 from thinc.api import Config, Model, SequenceCategoricalCrossentropy
-from thinc.types import Floats2d, Ints1d, Ints2d
+from thinc.types import ArrayXd, Floats2d, Ints1d
 from ._edit_tree_internals.edit_trees import EditTrees
 from ._edit_tree_internals.schemas import validate_edit_tree
@@ -21,6 +21,9 @@ from ..vocab import Vocab
 from .. import util
 
+ActivationsT = Dict[str, Union[List[Floats2d], List[Ints1d]]]
+
+
 default_model_config = """
 [model]
 @architectures = "spacy.Tagger.v2"
@@ -49,6 +52,7 @@ DEFAULT_EDIT_TREE_LEMMATIZER_MODEL = Config().from_str(default_model_config)["mo
         "overwrite": False,
         "top_k": 1,
         "scorer": {"@scorers": "spacy.lemmatizer_scorer.v1"},
+        "store_activations": False,
     },
     default_score_weights={"lemma_acc": 1.0},
 )
@@ -61,6 +65,7 @@ def make_edit_tree_lemmatizer(
     overwrite: bool,
     top_k: int,
     scorer: Optional[Callable],
+    store_activations: Union[bool, List[str]],
 ):
     """Construct an EditTreeLemmatizer component."""
     return EditTreeLemmatizer(
@@ -72,6 +77,7 @@ def make_edit_tree_lemmatizer(
         overwrite=overwrite,
         top_k=top_k,
         scorer=scorer,
+        store_activations=store_activations,
     )
@@ -91,6 +97,7 @@ class EditTreeLemmatizer(TrainablePipe):
         overwrite: bool = False,
         top_k: int = 1,
         scorer: Optional[Callable] = lemmatizer_score,
+        store_activations=False,
     ):
         """
         Construct an edit tree lemmatizer.
@@ -116,6 +123,7 @@ class EditTreeLemmatizer(TrainablePipe):
         self.cfg: Dict[str, Any] = {"labels": []}
         self.scorer = scorer
+        self.store_activations = store_activations  # type: ignore
 
     def get_loss(
         self, examples: Iterable[Example], scores: List[Floats2d]
@@ -144,21 +152,24 @@ class EditTreeLemmatizer(TrainablePipe):
         return float(loss), d_scores
 
-    def predict(self, docs: Iterable[Doc]) -> List[Ints2d]:
+    def predict(self, docs: Iterable[Doc]) -> ActivationsT:
         n_docs = len(list(docs))
         if not any(len(doc) for doc in docs):
             # Handle cases where there are no tokens in any docs.
             n_labels = len(self.cfg["labels"])
-            guesses: List[Ints2d] = [
+            guesses: List[Ints1d] = [
+                self.model.ops.alloc((0,), dtype="i") for doc in docs
+            ]
+            scores: List[Floats2d] = [
                 self.model.ops.alloc((0, n_labels), dtype="i") for doc in docs
             ]
             assert len(guesses) == n_docs
-            return guesses
+            return {"probs": scores, "guesses": guesses}
 
         scores = self.model.predict(docs)
         assert len(scores) == n_docs
         guesses = self._scores2guesses(docs, scores)
         assert len(guesses) == n_docs
-        return guesses
+        return {"probs": scores, "guesses": guesses}
 
     def _scores2guesses(self, docs, scores):
         guesses = []
@@ -186,8 +197,12 @@ class EditTreeLemmatizer(TrainablePipe):
         return guesses
 
-    def set_annotations(self, docs: Iterable[Doc], batch_tree_ids):
+    def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT):
+        batch_tree_ids = activations["guesses"]
         for i, doc in enumerate(docs):
+            doc.activations[self.name] = {}
+            for activation in self.store_activations:
+                doc.activations[self.name][activation] = activations[activation][i]
             doc_tree_ids = batch_tree_ids[i]
             if hasattr(doc_tree_ids, "get"):
                 doc_tree_ids = doc_tree_ids.get()
@@ -377,3 +392,7 @@ class EditTreeLemmatizer(TrainablePipe):
             self.tree2label[tree_id] = len(self.cfg["labels"])
         self.cfg["labels"].append(tree_id)
         return self.tree2label[tree_id]
+
+    @property
+    def activations(self):
+        return ["probs", "guesses"]

File: spacy/pipeline/entity_linker.py

@@ -1,5 +1,7 @@
 from typing import Optional, Iterable, Callable, Dict, Union, List, Any
-from thinc.types import Floats2d
+from typing import cast
+from numpy import dtype
+from thinc.types import Floats2d, Ragged
 from pathlib import Path
 from itertools import islice
 import srsly
@@ -21,6 +23,9 @@ from ..util import SimpleFrozenList, registry
 from .. import util
 from ..scorer import Scorer
 
+ActivationsT = Dict[str, Union[List[Ragged], List[str]]]
+
+
 # See #9050
 BACKWARD_OVERWRITE = True
@@ -56,6 +61,7 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
         "overwrite": True,
         "scorer": {"@scorers": "spacy.entity_linker_scorer.v1"},
         "use_gold_ents": True,
+        "store_activations": False,
     },
     default_score_weights={
         "nel_micro_f": 1.0,
@@ -77,6 +83,7 @@ def make_entity_linker(
     overwrite: bool,
     scorer: Optional[Callable],
     use_gold_ents: bool,
+    store_activations: Union[bool, List[str]],
 ):
     """Construct an EntityLinker component.
@@ -121,6 +128,7 @@ def make_entity_linker(
         overwrite=overwrite,
         scorer=scorer,
         use_gold_ents=use_gold_ents,
+        store_activations=store_activations,
     )
@@ -156,6 +164,7 @@ class EntityLinker(TrainablePipe):
         overwrite: bool = BACKWARD_OVERWRITE,
         scorer: Optional[Callable] = entity_linker_score,
         use_gold_ents: bool,
+        store_activations=False,
     ) -> None:
         """Initialize an entity linker.
@@ -192,6 +201,7 @@ class EntityLinker(TrainablePipe):
         self.kb = empty_kb(entity_vector_length)(self.vocab)
         self.scorer = scorer
         self.use_gold_ents = use_gold_ents
+        self.store_activations = store_activations
 
     def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
         """Define the KB of this pipe by providing a function that will
@@ -377,7 +387,7 @@ class EntityLinker(TrainablePipe):
         loss = loss / len(entity_encodings)
         return float(loss), out
 
-    def predict(self, docs: Iterable[Doc]) -> List[str]:
+    def predict(self, docs: Iterable[Doc]) -> ActivationsT:
         """Apply the pipeline's model to a batch of docs, without modifying them.
         Returns the KB IDs for each entity in each doc, including NIL if there is
         no prediction.
@@ -390,13 +400,21 @@ class EntityLinker(TrainablePipe):
         self.validate_kb()
         entity_count = 0
         final_kb_ids: List[str] = []
-        xp = self.model.ops.xp
+        ops = self.model.ops
+        xp = ops.xp
+        docs_ents: List[Ragged] = []
+        docs_scores: List[Ragged] = []
         if not docs:
-            return final_kb_ids
+            return {"kb_ids": final_kb_ids, "ents": docs_ents, "scores": docs_scores}
         if isinstance(docs, Doc):
             docs = [docs]
-        for i, doc in enumerate(docs):
+        for doc in docs:
+            doc_ents = []
+            doc_scores = []
+            doc_scores_lens: List[int] = []
             if len(doc) == 0:
+                doc_scores.append(Ragged(ops.alloc1f(0), ops.alloc1i(0)))
+                doc_ents.append(Ragged(xp.zeros(0, dtype="uint64"), ops.alloc1i(0)))
                 continue
             sentences = [s for s in doc.sents]
             # Looping through each entity (TODO: rewrite)
@@ -419,11 +437,17 @@ class EntityLinker(TrainablePipe):
                 if ent.label_ in self.labels_discard:
                     # ignoring this entity - setting to NIL
                     final_kb_ids.append(self.NIL)
+                    self._add_activations(
+                        doc_scores, doc_scores_lens, doc_ents, [0.0], [0]
+                    )
                 else:
                     candidates = list(self.get_candidates(self.kb, ent))
                     if not candidates:
                         # no prediction possible for this entity - setting to NIL
                         final_kb_ids.append(self.NIL)
+                        self._add_activations(
+                            doc_scores, doc_scores_lens, doc_ents, [0.0], [0]
+                        )
                     elif len(candidates) == 1:
                         # shortcut for efficiency reasons: take the 1 candidate
                         # TODO: thresholding
@@ -456,30 +480,48 @@ class EntityLinker(TrainablePipe):
                                 raise ValueError(Errors.E161)
                             scores = prior_probs + sims - (prior_probs * sims)
                         # TODO: thresholding
+                        self._add_activations(
+                            doc_scores,
+                            doc_scores_lens,
+                            doc_ents,
+                            scores,
+                            [c.entity for c in candidates],
+                        )
                         best_index = scores.argmax().item()
                         best_candidate = candidates[best_index]
                         final_kb_ids.append(best_candidate.entity_)
+            self._add_doc_activations(
+                docs_scores, docs_ents, doc_scores, doc_scores_lens, doc_ents
+            )
         if not (len(final_kb_ids) == entity_count):
             err = Errors.E147.format(
                 method="predict", msg="result variables not of equal length"
             )
             raise RuntimeError(err)
-        return final_kb_ids
+        return {"kb_ids": final_kb_ids, "ents": docs_ents, "scores": docs_scores}
 
-    def set_annotations(self, docs: Iterable[Doc], kb_ids: List[str]) -> None:
+    def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT) -> None:
         """Modify a batch of documents, using pre-computed scores.
 
         docs (Iterable[Doc]): The documents to modify.
-        kb_ids (List[str]): The IDs to set, produced by EntityLinker.predict.
+        activations (List[str]): The activations used for setting annotations, produced
+            by EntityLinker.predict.
 
         DOCS: https://spacy.io/api/entitylinker#set_annotations
         """
+        kb_ids = cast(List[str], activations["kb_ids"])
         count_ents = len([ent for doc in docs for ent in doc.ents])
         if count_ents != len(kb_ids):
             raise ValueError(Errors.E148.format(ents=count_ents, ids=len(kb_ids)))
         i = 0
         overwrite = self.cfg["overwrite"]
-        for doc in docs:
+        for j, doc in enumerate(docs):
+            doc.activations[self.name] = {}
+            for activation in self.store_activations:
+                # We only copy activations that are Ragged.
+                doc.activations[self.name][activation] = cast(
+                    Ragged, activations[activation][j]
+                )
             for ent in doc.ents:
                 kb_id = kb_ids[i]
                 i += 1
@@ -578,3 +620,30 @@ class EntityLinker(TrainablePipe):
     def add_label(self, label):
         raise NotImplementedError
+
+    @property
+    def activations(self):
+        return ["ents", "scores"]
+
+    def _add_doc_activations(
+        self, docs_scores, docs_ents, doc_scores, doc_scores_lens, doc_ents
+    ):
+        if len(self.store_activations) == 0:
+            return
+        ops = self.model.ops
+        docs_scores.append(
+            Ragged(ops.flatten(doc_scores), ops.asarray1i(doc_scores_lens))
+        )
+        docs_ents.append(
+            Ragged(
+                ops.flatten(doc_ents, dtype="uint64"), ops.asarray1i(doc_scores_lens)
+            )
+        )
+
+    def _add_activations(self, doc_scores, doc_scores_lens, doc_ents, scores, ents):
+        if len(self.store_activations) == 0:
+            return
+        ops = self.model.ops
+        doc_scores.append(ops.asarray1f(scores))
+        doc_scores_lens.append(doc_scores[-1].shape[0])
+        doc_ents.append(ops.xp.array(ents, dtype="uint64"))
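The entity linker packs each doc's activations into `Ragged` values: `data` concatenates the candidate scores (or the candidate KB ids as uint64 hashes) for all entities in the doc, and `lengths` records how many candidates each entity had. A small sketch of that layout with toy values, independent of the component itself:

    from thinc.api import get_current_ops
    from thinc.types import Ragged

    ops = get_current_ops()
    # Two entities: the first with three candidates, the second with one.
    scores = Ragged(
        ops.asarray1f([0.2, 0.5, 0.3, 1.0]),  # concatenated per-candidate scores
        ops.asarray1i([3, 1]),                # number of candidates per entity
    )
    assert int(scores.lengths.sum()) == scores.data.shape[0]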

File: spacy/tests/pipeline/test_edit_tree_lemmatizer.py

@@ -1,3 +1,4 @@
+from typing import cast
 import pickle
 import pytest
 from hypothesis import given
@@ -6,6 +7,7 @@ from spacy import util
 from spacy.lang.en import English
 from spacy.language import Language
 from spacy.pipeline._edit_tree_internals.edit_trees import EditTrees
+from spacy.pipeline.trainable_pipe import TrainablePipe
 from spacy.training import Example
 from spacy.strings import StringStore
 from spacy.util import make_tempdir
@@ -278,3 +280,28 @@ def test_empty_strings():
     no_change = trees.add("xyz", "xyz")
     empty = trees.add("", "")
     assert no_change == empty
+
+
+def test_store_activations():
+    nlp = English()
+    lemmatizer = cast(TrainablePipe, nlp.add_pipe("trainable_lemmatizer"))
+    lemmatizer.min_tree_freq = 1
+    train_examples = []
+    for t in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
+    nlp.initialize(get_examples=lambda: train_examples)
+    nO = lemmatizer.model.get_dim("nO")
+
+    doc = nlp("This is a test.")
+    assert len(list(doc.activations["trainable_lemmatizer"].keys())) == 0
+
+    lemmatizer.store_activations = True
+    doc = nlp("This is a test.")
+    assert list(doc.activations["trainable_lemmatizer"].keys()) == ["probs", "guesses"]
+    assert doc.activations["trainable_lemmatizer"]["probs"].shape == (5, nO)
+    assert doc.activations["trainable_lemmatizer"]["guesses"].shape == (5,)
+
+    lemmatizer.store_activations = ["probs"]
+    doc = nlp("This is a test.")
+    assert list(doc.activations["trainable_lemmatizer"].keys()) == ["probs"]
+    assert doc.activations["trainable_lemmatizer"]["probs"].shape == (5, nO)

File: spacy/tests/pipeline/test_entity_linker.py

@@ -1,7 +1,8 @@
-from typing import Callable, Iterable
+from typing import Callable, Iterable, cast
 import pytest
 from numpy.testing import assert_equal
+from thinc.types import Ragged
 
 from spacy import registry, util
 from spacy.attrs import ENT_KB_ID
@@ -9,7 +10,7 @@ from spacy.compat import pickle
 from spacy.kb import Candidate, KnowledgeBase, get_candidates
 from spacy.lang.en import English
 from spacy.ml import load_kb
-from spacy.pipeline import EntityLinker
+from spacy.pipeline import EntityLinker, TrainablePipe
 from spacy.pipeline.legacy import EntityLinker_v1
 from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
 from spacy.scorer import Scorer
@@ -1115,3 +1116,79 @@ def test_tokenization_mismatch():
     nlp.add_pipe("sentencizer", first=True)
     results = nlp.evaluate(train_examples)
+
+
+def test_store_activations():
+    nlp = English()
+    vector_length = 3
+    assert "Q2146908" not in nlp.vocab.strings
+
+    # Convert the texts to docs to make sure we have doc.ents set for the training examples
+    train_examples = []
+    for text, annotation in TRAIN_DATA:
+        doc = nlp(text)
+        train_examples.append(Example.from_dict(doc, annotation))
+
+    def create_kb(vocab):
+        # create artificial KB - assign same prior weight to the two russ cochran's
+        # Q2146908 (Russ Cochran): American golfer
+        # Q7381115 (Russ Cochran): publisher
+        mykb = KnowledgeBase(vocab, entity_vector_length=vector_length)
+        mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
+        mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7])
+        mykb.add_alias(
+            alias="Russ Cochran",
+            entities=["Q2146908", "Q7381115"],
+            probabilities=[0.5, 0.5],
+        )
+        return mykb
+
+    # Create the Entity Linker component and add it to the pipeline
+    entity_linker = cast(TrainablePipe, nlp.add_pipe("entity_linker", last=True))
+    assert isinstance(entity_linker, EntityLinker)
+    entity_linker.set_kb(create_kb)
+    assert "Q2146908" in entity_linker.vocab.strings
+    assert "Q2146908" in entity_linker.kb.vocab.strings
+
+    # initialize the NEL pipe
+    optimizer = nlp.initialize(get_examples=lambda: train_examples)
+    for i in range(2):
+        losses = {}
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
+
+    nO = entity_linker.model.get_dim("nO")
+
+    nlp.add_pipe("sentencizer", first=True)
+    patterns = [
+        {"label": "PERSON", "pattern": [{"LOWER": "russ"}, {"LOWER": "cochran"}]},
+        {"label": "ORG", "pattern": [{"LOWER": "ec"}, {"LOWER": "comics"}]},
+    ]
+    ruler = nlp.add_pipe("entity_ruler", before="entity_linker")
+    ruler.add_patterns(patterns)
+
+    doc = nlp("Russ Cochran was a publisher")
+    assert len(doc.activations["entity_linker"].keys()) == 0
+
+    entity_linker.store_activations = True
+    doc = nlp("Russ Cochran was a publisher")
+    assert set(doc.activations["entity_linker"].keys()) == {"ents", "scores"}
+    ents = doc.activations["entity_linker"]["ents"]
+    assert isinstance(ents, Ragged)
+    assert ents.data.shape == (2, 1)
+    assert ents.data.dtype == "uint64"
+    assert ents.lengths.shape == (1,)
+    scores = doc.activations["entity_linker"]["scores"]
+    assert isinstance(scores, Ragged)
+    assert scores.data.shape == (2, 1)
+    assert scores.data.dtype == "float32"
+    assert scores.lengths.shape == (1,)
+
+    entity_linker.store_activations = ["scores"]
+    doc = nlp("Russ Cochran was a publisher")
+    assert set(doc.activations["entity_linker"].keys()) == {"scores"}
+    scores = doc.activations["entity_linker"]["scores"]
+    assert isinstance(scores, Ragged)
+    assert scores.data.shape == (2, 1)
+    assert scores.data.dtype == "float32"
+    assert scores.lengths.shape == (1,)

File: spacy/tokens/doc.pyi

@@ -1,7 +1,7 @@
 from typing import Callable, Protocol, Iterable, Iterator, Optional
 from typing import Union, Tuple, List, Dict, Any, overload
 from cymem.cymem import Pool
-from thinc.types import ArrayXd, Floats1d, Floats2d, Ints2d
+from thinc.types import ArrayXd, Floats1d, Floats2d, Ints2d, Ragged
 from .span import Span
 from .token import Token
 from ._dict_proxies import SpanGroups
@@ -22,7 +22,7 @@ class Doc:
     max_length: int
     length: int
     sentiment: float
-    activations: Dict[str, Dict[str, ArrayXd]]
+    activations: Dict[str, Dict[str, Union[ArrayXd, Ragged]]]
     cats: Dict[str, float]
     user_hooks: Dict[str, Callable[..., Any]]
     user_token_hooks: Dict[str, Callable[..., Any]]
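Because `Doc.activations` can now hold either plain arrays or `Ragged` values, downstream code that wants per-entity arrays has to unpack the entity-linker activations itself. A hedged helper sketch; `unpack_ragged` is a hypothetical name, not a spaCy API:

    from thinc.types import Ragged

    def unpack_ragged(r: Ragged):
        # Split the flat data array back into one slice per entity, using
        # the per-entity candidate counts stored in r.lengths.
        out = []
        start = 0
        for length in r.lengths:
            out.append(r.data[start : start + int(length)])
            start += int(length)
        return out

    # e.g.: unpack_ragged(doc.activations["entity_linker"]["scores"])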