Mirror of https://github.com/explosion/spaCy.git
trainable_lemmatizer/entity_linker: add store_activations option
Commit 1c9be0d8ab (parent 8772b9ccc4)
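Before the diff: this commit adds a `store_activations` option to the `trainable_lemmatizer` and `entity_linker` factories. When enabled, the selected activations are stored on `Doc.activations`, keyed first by component name and then by activation name. A minimal usage sketch, assuming a pipeline built from this branch (the config key and the `doc.activations` attribute come from the diff below; training is elided):

```python
import spacy

nlp = spacy.blank("en")
# The factory now accepts store_activations: either a bool or a list of
# activation names such as ["probs"].
lemmatizer = nlp.add_pipe(
    "trainable_lemmatizer", config={"store_activations": True}
)

# ... after nlp.initialize(...) and training:
# doc = nlp("This is a test.")
# doc.activations["trainable_lemmatizer"]["probs"]    # per-token label scores
# doc.activations["trainable_lemmatizer"]["guesses"]  # per-token tree-id guesses
```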
```diff
--- a/spacy/pipeline/edit_tree_lemmatizer.py
+++ b/spacy/pipeline/edit_tree_lemmatizer.py
@@ -7,7 +7,7 @@ import numpy as np
 
 import srsly
 from thinc.api import Config, Model, SequenceCategoricalCrossentropy
-from thinc.types import Floats2d, Ints1d, Ints2d
+from thinc.types import ArrayXd, Floats2d, Ints1d
 
 from ._edit_tree_internals.edit_trees import EditTrees
 from ._edit_tree_internals.schemas import validate_edit_tree
@@ -21,6 +21,9 @@ from ..vocab import Vocab
 from .. import util
 
 
+ActivationsT = Dict[str, Union[List[Floats2d], List[Ints1d]]]
+
+
 default_model_config = """
 [model]
 @architectures = "spacy.Tagger.v2"
@@ -49,6 +52,7 @@ DEFAULT_EDIT_TREE_LEMMATIZER_MODEL = Config().from_str(default_model_config)["mo
         "overwrite": False,
         "top_k": 1,
         "scorer": {"@scorers": "spacy.lemmatizer_scorer.v1"},
+        "store_activations": False,
     },
     default_score_weights={"lemma_acc": 1.0},
 )
@@ -61,6 +65,7 @@ def make_edit_tree_lemmatizer(
     overwrite: bool,
     top_k: int,
     scorer: Optional[Callable],
+    store_activations: Union[bool, List[str]],
 ):
     """Construct an EditTreeLemmatizer component."""
     return EditTreeLemmatizer(
@@ -72,6 +77,7 @@ def make_edit_tree_lemmatizer(
         overwrite=overwrite,
         top_k=top_k,
         scorer=scorer,
+        store_activations=store_activations,
     )
 
 
@@ -91,6 +97,7 @@ class EditTreeLemmatizer(TrainablePipe):
         overwrite: bool = False,
         top_k: int = 1,
         scorer: Optional[Callable] = lemmatizer_score,
+        store_activations=False,
     ):
         """
         Construct an edit tree lemmatizer.
@@ -116,6 +123,7 @@ class EditTreeLemmatizer(TrainablePipe):
 
         self.cfg: Dict[str, Any] = {"labels": []}
         self.scorer = scorer
+        self.store_activations = store_activations  # type: ignore
 
     def get_loss(
         self, examples: Iterable[Example], scores: List[Floats2d]
@@ -144,21 +152,24 @@ class EditTreeLemmatizer(TrainablePipe):
 
         return float(loss), d_scores
 
-    def predict(self, docs: Iterable[Doc]) -> List[Ints2d]:
+    def predict(self, docs: Iterable[Doc]) -> ActivationsT:
         n_docs = len(list(docs))
         if not any(len(doc) for doc in docs):
             # Handle cases where there are no tokens in any docs.
             n_labels = len(self.cfg["labels"])
-            guesses: List[Ints2d] = [
+            guesses: List[Ints1d] = [
+                self.model.ops.alloc((0,), dtype="i") for doc in docs
+            ]
+            scores: List[Floats2d] = [
                 self.model.ops.alloc((0, n_labels), dtype="i") for doc in docs
             ]
             assert len(guesses) == n_docs
-            return guesses
+            return {"probs": scores, "guesses": guesses}
         scores = self.model.predict(docs)
         assert len(scores) == n_docs
         guesses = self._scores2guesses(docs, scores)
         assert len(guesses) == n_docs
-        return guesses
+        return {"probs": scores, "guesses": guesses}
 
     def _scores2guesses(self, docs, scores):
         guesses = []
@@ -186,8 +197,12 @@ class EditTreeLemmatizer(TrainablePipe):
 
         return guesses
 
-    def set_annotations(self, docs: Iterable[Doc], batch_tree_ids):
+    def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT):
+        batch_tree_ids = activations["guesses"]
         for i, doc in enumerate(docs):
+            doc.activations[self.name] = {}
+            for activation in self.store_activations:
+                doc.activations[self.name][activation] = activations[activation][i]
             doc_tree_ids = batch_tree_ids[i]
             if hasattr(doc_tree_ids, "get"):
                 doc_tree_ids = doc_tree_ids.get()
@@ -377,3 +392,7 @@ class EditTreeLemmatizer(TrainablePipe):
         self.tree2label[tree_id] = len(self.cfg["labels"])
         self.cfg["labels"].append(tree_id)
         return self.tree2label[tree_id]
+
+    @property
+    def activations(self):
+        return ["probs", "guesses"]
```
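The substantive change above is the `predict`/`set_annotations` contract: `predict` now returns a dict of activation lists (one array per doc under each name) instead of the bare guesses, and `set_annotations` copies onto each doc only the activations selected through `store_activations`. A standalone sketch of that filtering step, using plain numpy arrays instead of the spaCy classes (how `store_activations=True` expands to all names is handled elsewhere on this branch, so the expansion below is illustrative):

```python
from typing import Dict, List, Union

import numpy

# One list entry per doc under each activation name (mirrors ActivationsT above).
activations: Dict[str, List[numpy.ndarray]] = {
    "probs": [numpy.zeros((5, 10), dtype="f")],  # per-token label scores
    "guesses": [numpy.zeros((5,), dtype="i")],   # per-token tree-id guesses
}

store_activations: Union[bool, List[str]] = ["probs"]  # or True for all names
names = ["probs", "guesses"] if store_activations is True else store_activations

# What set_annotations does per doc: keep only the selected activations.
doc_activations: Dict[str, numpy.ndarray] = {}
for name in names:
    doc_activations[name] = activations[name][0]  # activations for doc 0

assert list(doc_activations) == ["probs"]
```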
```diff
--- a/spacy/pipeline/entity_linker.py
+++ b/spacy/pipeline/entity_linker.py
@@ -1,5 +1,7 @@
 from typing import Optional, Iterable, Callable, Dict, Union, List, Any
-from thinc.types import Floats2d
+from typing import cast
+from numpy import dtype
+from thinc.types import Floats2d, Ragged
 from pathlib import Path
 from itertools import islice
 import srsly
@@ -21,6 +23,9 @@ from ..util import SimpleFrozenList, registry
 from .. import util
 from ..scorer import Scorer
 
+
+ActivationsT = Dict[str, Union[List[Ragged], List[str]]]
+
 # See #9050
 BACKWARD_OVERWRITE = True
 
@@ -56,6 +61,7 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
         "overwrite": True,
         "scorer": {"@scorers": "spacy.entity_linker_scorer.v1"},
         "use_gold_ents": True,
+        "store_activations": False,
     },
     default_score_weights={
         "nel_micro_f": 1.0,
@@ -77,6 +83,7 @@ def make_entity_linker(
     overwrite: bool,
     scorer: Optional[Callable],
     use_gold_ents: bool,
+    store_activations: Union[bool, List[str]],
 ):
     """Construct an EntityLinker component.
 
@@ -121,6 +128,7 @@ def make_entity_linker(
         overwrite=overwrite,
         scorer=scorer,
         use_gold_ents=use_gold_ents,
+        store_activations=store_activations,
     )
 
 
@@ -156,6 +164,7 @@ class EntityLinker(TrainablePipe):
         overwrite: bool = BACKWARD_OVERWRITE,
         scorer: Optional[Callable] = entity_linker_score,
         use_gold_ents: bool,
+        store_activations=False,
     ) -> None:
         """Initialize an entity linker.
 
@@ -192,6 +201,7 @@ class EntityLinker(TrainablePipe):
         self.kb = empty_kb(entity_vector_length)(self.vocab)
         self.scorer = scorer
         self.use_gold_ents = use_gold_ents
+        self.store_activations = store_activations
 
     def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
         """Define the KB of this pipe by providing a function that will
@@ -377,7 +387,7 @@ class EntityLinker(TrainablePipe):
             loss = loss / len(entity_encodings)
         return float(loss), out
 
-    def predict(self, docs: Iterable[Doc]) -> List[str]:
+    def predict(self, docs: Iterable[Doc]) -> ActivationsT:
         """Apply the pipeline's model to a batch of docs, without modifying them.
         Returns the KB IDs for each entity in each doc, including NIL if there is
         no prediction.
@@ -390,13 +400,21 @@ class EntityLinker(TrainablePipe):
         self.validate_kb()
         entity_count = 0
         final_kb_ids: List[str] = []
-        xp = self.model.ops.xp
+        ops = self.model.ops
+        xp = ops.xp
+        docs_ents: List[Ragged] = []
+        docs_scores: List[Ragged] = []
         if not docs:
-            return final_kb_ids
+            return {"kb_ids": final_kb_ids, "ents": docs_ents, "scores": docs_scores}
         if isinstance(docs, Doc):
             docs = [docs]
-        for i, doc in enumerate(docs):
+        for doc in docs:
+            doc_ents = []
+            doc_scores = []
+            doc_scores_lens: List[int] = []
             if len(doc) == 0:
+                doc_scores.append(Ragged(ops.alloc1f(0), ops.alloc1i(0)))
+                doc_ents.append(Ragged(xp.zeros(0, dtype="uint64"), ops.alloc1i(0)))
                 continue
             sentences = [s for s in doc.sents]
             # Looping through each entity (TODO: rewrite)
@@ -419,11 +437,17 @@ class EntityLinker(TrainablePipe):
                 if ent.label_ in self.labels_discard:
                     # ignoring this entity - setting to NIL
                     final_kb_ids.append(self.NIL)
+                    self._add_activations(
+                        doc_scores, doc_scores_lens, doc_ents, [0.0], [0]
+                    )
                 else:
                     candidates = list(self.get_candidates(self.kb, ent))
                     if not candidates:
                         # no prediction possible for this entity - setting to NIL
                         final_kb_ids.append(self.NIL)
+                        self._add_activations(
+                            doc_scores, doc_scores_lens, doc_ents, [0.0], [0]
+                        )
                     elif len(candidates) == 1:
                         # shortcut for efficiency reasons: take the 1 candidate
                         # TODO: thresholding
@@ -456,30 +480,48 @@ class EntityLinker(TrainablePipe):
                             raise ValueError(Errors.E161)
                         scores = prior_probs + sims - (prior_probs * sims)
                         # TODO: thresholding
+                        self._add_activations(
+                            doc_scores,
+                            doc_scores_lens,
+                            doc_ents,
+                            scores,
+                            [c.entity for c in candidates],
+                        )
                         best_index = scores.argmax().item()
                         best_candidate = candidates[best_index]
                         final_kb_ids.append(best_candidate.entity_)
+            self._add_doc_activations(
+                docs_scores, docs_ents, doc_scores, doc_scores_lens, doc_ents
+            )
         if not (len(final_kb_ids) == entity_count):
             err = Errors.E147.format(
                 method="predict", msg="result variables not of equal length"
             )
             raise RuntimeError(err)
-        return final_kb_ids
+        return {"kb_ids": final_kb_ids, "ents": docs_ents, "scores": docs_scores}
 
-    def set_annotations(self, docs: Iterable[Doc], kb_ids: List[str]) -> None:
+    def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT) -> None:
         """Modify a batch of documents, using pre-computed scores.
 
         docs (Iterable[Doc]): The documents to modify.
-        kb_ids (List[str]): The IDs to set, produced by EntityLinker.predict.
+        activations (List[str]): The activations used for setting annotations, produced
+            by EntityLinker.predict.
 
         DOCS: https://spacy.io/api/entitylinker#set_annotations
         """
+        kb_ids = cast(List[str], activations["kb_ids"])
         count_ents = len([ent for doc in docs for ent in doc.ents])
         if count_ents != len(kb_ids):
            raise ValueError(Errors.E148.format(ents=count_ents, ids=len(kb_ids)))
         i = 0
         overwrite = self.cfg["overwrite"]
-        for doc in docs:
+        for j, doc in enumerate(docs):
+            doc.activations[self.name] = {}
+            for activation in self.store_activations:
+                # We only copy activations that are Ragged.
+                doc.activations[self.name][activation] = cast(
+                    Ragged, activations[activation][j]
+                )
             for ent in doc.ents:
                 kb_id = kb_ids[i]
                 i += 1
@@ -578,3 +620,30 @@ class EntityLinker(TrainablePipe):
 
     def add_label(self, label):
         raise NotImplementedError
+
+    @property
+    def activations(self):
+        return ["ents", "scores"]
+
+    def _add_doc_activations(
+        self, docs_scores, docs_ents, doc_scores, doc_scores_lens, doc_ents
+    ):
+        if len(self.store_activations) == 0:
+            return
+        ops = self.model.ops
+        docs_scores.append(
+            Ragged(ops.flatten(doc_scores), ops.asarray1i(doc_scores_lens))
+        )
+        docs_ents.append(
+            Ragged(
+                ops.flatten(doc_ents, dtype="uint64"), ops.asarray1i(doc_scores_lens)
+            )
+        )
+
+    def _add_activations(self, doc_scores, doc_scores_lens, doc_ents, scores, ents):
+        if len(self.store_activations) == 0:
+            return
+        ops = self.model.ops
+        doc_scores.append(ops.asarray1f(scores))
+        doc_scores_lens.append(doc_scores[-1].shape[0])
+        doc_ents.append(ops.xp.array(ents, dtype="uint64"))
```
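Unlike the lemmatizer, the entity linker stores its activations as `Ragged` values: per doc, `scores` flattens the candidate scores of all entities, and `lengths` records how many candidates each entity contributed (`_add_activations` collects per-entity rows, `_add_doc_activations` flattens them). A small reading sketch; the `Ragged` built by hand here stands in for a value taken from `doc.activations["entity_linker"]["scores"]`:

```python
from thinc.api import NumpyOps
from thinc.types import Ragged

ops = NumpyOps()
# One entity with two candidate scores, shaped the way _add_doc_activations
# builds it: data has one row per candidate, lengths one entry per entity.
scores = Ragged(ops.asarray1f([0.7, 0.4]).reshape((2, 1)), ops.asarray1i([2]))

start = 0
for n_candidates in scores.lengths:
    entity_scores = scores.data[start : start + n_candidates]
    start += int(n_candidates)
    print(entity_scores)  # candidate scores for one entity
```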
```diff
--- a/spacy/tests/pipeline/test_edit_tree_lemmatizer.py
+++ b/spacy/tests/pipeline/test_edit_tree_lemmatizer.py
@@ -1,3 +1,4 @@
+from typing import cast
 import pickle
 import pytest
 from hypothesis import given
@@ -6,6 +7,7 @@ from spacy import util
 from spacy.lang.en import English
 from spacy.language import Language
 from spacy.pipeline._edit_tree_internals.edit_trees import EditTrees
+from spacy.pipeline.trainable_pipe import TrainablePipe
 from spacy.training import Example
 from spacy.strings import StringStore
 from spacy.util import make_tempdir
@@ -278,3 +280,28 @@ def test_empty_strings():
     no_change = trees.add("xyz", "xyz")
     empty = trees.add("", "")
     assert no_change == empty
+
+
+def test_store_activations():
+    nlp = English()
+    lemmatizer = cast(TrainablePipe, nlp.add_pipe("trainable_lemmatizer"))
+    lemmatizer.min_tree_freq = 1
+    train_examples = []
+    for t in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
+    nlp.initialize(get_examples=lambda: train_examples)
+    nO = lemmatizer.model.get_dim("nO")
+
+    doc = nlp("This is a test.")
+    assert len(list(doc.activations["trainable_lemmatizer"].keys())) == 0
+
+    lemmatizer.store_activations = True
+    doc = nlp("This is a test.")
+    assert list(doc.activations["trainable_lemmatizer"].keys()) == ["probs", "guesses"]
+    assert doc.activations["trainable_lemmatizer"]["probs"].shape == (5, nO)
+    assert doc.activations["trainable_lemmatizer"]["guesses"].shape == (5,)
+
+    lemmatizer.store_activations = ["probs"]
+    doc = nlp("This is a test.")
+    assert list(doc.activations["trainable_lemmatizer"].keys()) == ["probs"]
+    assert doc.activations["trainable_lemmatizer"]["probs"].shape == (5, nO)
```
```diff
--- a/spacy/tests/pipeline/test_entity_linker.py
+++ b/spacy/tests/pipeline/test_entity_linker.py
@@ -1,7 +1,8 @@
-from typing import Callable, Iterable
+from typing import Callable, Iterable, cast
 
 import pytest
 from numpy.testing import assert_equal
+from thinc.types import Ragged
 
 from spacy import registry, util
 from spacy.attrs import ENT_KB_ID
@@ -9,7 +10,7 @@ from spacy.compat import pickle
 from spacy.kb import Candidate, KnowledgeBase, get_candidates
 from spacy.lang.en import English
 from spacy.ml import load_kb
-from spacy.pipeline import EntityLinker
+from spacy.pipeline import EntityLinker, TrainablePipe
 from spacy.pipeline.legacy import EntityLinker_v1
 from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
 from spacy.scorer import Scorer
@@ -1115,3 +1116,79 @@ def test_tokenization_mismatch():
 
     nlp.add_pipe("sentencizer", first=True)
     results = nlp.evaluate(train_examples)
+
+
+def test_store_activations():
+    nlp = English()
+    vector_length = 3
+    assert "Q2146908" not in nlp.vocab.strings
+
+    # Convert the texts to docs to make sure we have doc.ents set for the training examples
+    train_examples = []
+    for text, annotation in TRAIN_DATA:
+        doc = nlp(text)
+        train_examples.append(Example.from_dict(doc, annotation))
+
+    def create_kb(vocab):
+        # create artificial KB - assign same prior weight to the two russ cochran's
+        # Q2146908 (Russ Cochran): American golfer
+        # Q7381115 (Russ Cochran): publisher
+        mykb = KnowledgeBase(vocab, entity_vector_length=vector_length)
+        mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
+        mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7])
+        mykb.add_alias(
+            alias="Russ Cochran",
+            entities=["Q2146908", "Q7381115"],
+            probabilities=[0.5, 0.5],
+        )
+        return mykb
+
+    # Create the Entity Linker component and add it to the pipeline
+    entity_linker = cast(TrainablePipe, nlp.add_pipe("entity_linker", last=True))
+    assert isinstance(entity_linker, EntityLinker)
+    entity_linker.set_kb(create_kb)
+    assert "Q2146908" in entity_linker.vocab.strings
+    assert "Q2146908" in entity_linker.kb.vocab.strings
+
+    # initialize the NEL pipe
+    optimizer = nlp.initialize(get_examples=lambda: train_examples)
+
+    for i in range(2):
+        losses = {}
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
+
+    nO = entity_linker.model.get_dim("nO")
+
+    nlp.add_pipe("sentencizer", first=True)
+    patterns = [
+        {"label": "PERSON", "pattern": [{"LOWER": "russ"}, {"LOWER": "cochran"}]},
+        {"label": "ORG", "pattern": [{"LOWER": "ec"}, {"LOWER": "comics"}]},
+    ]
+    ruler = nlp.add_pipe("entity_ruler", before="entity_linker")
+    ruler.add_patterns(patterns)
+
+    doc = nlp("Russ Cochran was a publisher")
+    assert len(doc.activations["entity_linker"].keys()) == 0
+
+    entity_linker.store_activations = True
+    doc = nlp("Russ Cochran was a publisher")
+    assert set(doc.activations["entity_linker"].keys()) == {"ents", "scores"}
+    ents = doc.activations["entity_linker"]["ents"]
+    assert isinstance(ents, Ragged)
+    assert ents.data.shape == (2, 1)
+    assert ents.data.dtype == "uint64"
+    assert ents.lengths.shape == (1,)
+    scores = doc.activations["entity_linker"]["scores"]
+    assert isinstance(scores, Ragged)
+    assert scores.data.shape == (2, 1)
+    assert scores.data.dtype == "float32"
+    assert scores.lengths.shape == (1,)
+
+    entity_linker.store_activations = ["scores"]
+    doc = nlp("Russ Cochran was a publisher")
+    assert set(doc.activations["entity_linker"].keys()) == {"scores"}
+    scores = doc.activations["entity_linker"]["scores"]
+    assert isinstance(scores, Ragged)
+    assert scores.data.shape == (2, 1)
+    assert scores.data.dtype == "float32"
+    assert scores.lengths.shape == (1,)
```
```diff
--- a/spacy/tokens/doc.pyi
+++ b/spacy/tokens/doc.pyi
@@ -1,7 +1,7 @@
 from typing import Callable, Protocol, Iterable, Iterator, Optional
 from typing import Union, Tuple, List, Dict, Any, overload
 from cymem.cymem import Pool
-from thinc.types import ArrayXd, Floats1d, Floats2d, Ints2d
+from thinc.types import ArrayXd, Floats1d, Floats2d, Ints2d, Ragged
 from .span import Span
 from .token import Token
 from ._dict_proxies import SpanGroups
@@ -22,7 +22,7 @@ class Doc:
     max_length: int
     length: int
     sentiment: float
-    activations: Dict[str, Dict[str, ArrayXd]]
+    activations: Dict[str, Dict[str, Union[ArrayXd, Ragged]]]
     cats: Dict[str, float]
     user_hooks: Dict[str, Callable[..., Any]]
     user_token_hooks: Dict[str, Callable[..., Any]]
```
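The stub change above widens the declared type of `Doc.activations` so the entity linker's `Ragged` values type-check next to plain arrays. A typed accessor sketch (the helper name is illustrative, not part of the commit):

```python
from typing import Dict, Union

from thinc.types import ArrayXd, Ragged

# Matches the updated annotation: component name -> activation name -> value.
ActivationsDict = Dict[str, Dict[str, Union[ArrayXd, Ragged]]]

def get_activation(activations: ActivationsDict, pipe_name: str, name: str):
    """Look up one stored activation, e.g. ("entity_linker", "scores")."""
    return activations[pipe_name][name]
```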