Mirror of https://github.com/explosion/spaCy.git (synced 2025-11-04 01:48:04 +03:00)

trainable_lemmatizer/entity_linker: add store_activations option

commit 1c9be0d8ab
parent 8772b9ccc4
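The diff below adds a store_activations option to the trainable_lemmatizer and entity_linker components, so that raw model activations can be stored on each processed Doc under doc.activations[component_name]. As a rough usage sketch (not part of this commit's diff; it assumes the pipeline is initialized and trained before calling it on text):

    import spacy

    nlp = spacy.blank("en")
    # store_activations accepts a bool (store all of the component's
    # activations) or a list of activation names (store a subset).
    nlp.add_pipe("trainable_lemmatizer", config={"store_activations": True})
    # After nlp.initialize()/training, processing text fills
    # doc.activations["trainable_lemmatizer"] with "probs" and "guesses".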
spacy/pipeline/edit_tree_lemmatizer.py

@@ -7,7 +7,7 @@ import numpy as np
 import srsly
 from thinc.api import Config, Model, SequenceCategoricalCrossentropy
-from thinc.types import Floats2d, Ints1d, Ints2d
+from thinc.types import ArrayXd, Floats2d, Ints1d
 
 from ._edit_tree_internals.edit_trees import EditTrees
 from ._edit_tree_internals.schemas import validate_edit_tree
@@ -21,6 +21,9 @@ from ..vocab import Vocab
 from .. import util
 
 
+ActivationsT = Dict[str, Union[List[Floats2d], List[Ints1d]]]
+
+
 default_model_config = """
 [model]
 @architectures = "spacy.Tagger.v2"
@@ -49,6 +52,7 @@ DEFAULT_EDIT_TREE_LEMMATIZER_MODEL = Config().from_str(default_model_config)["model"]
         "overwrite": False,
         "top_k": 1,
         "scorer": {"@scorers": "spacy.lemmatizer_scorer.v1"},
+        "store_activations": False,
     },
     default_score_weights={"lemma_acc": 1.0},
 )
@@ -61,6 +65,7 @@ def make_edit_tree_lemmatizer(
     overwrite: bool,
     top_k: int,
     scorer: Optional[Callable],
+    store_activations: Union[bool, List[str]],
 ):
     """Construct an EditTreeLemmatizer component."""
     return EditTreeLemmatizer(
@@ -72,6 +77,7 @@ def make_edit_tree_lemmatizer(
         overwrite=overwrite,
         top_k=top_k,
         scorer=scorer,
+        store_activations=store_activations,
     )
 
 
@@ -91,6 +97,7 @@ class EditTreeLemmatizer(TrainablePipe):
         overwrite: bool = False,
         top_k: int = 1,
         scorer: Optional[Callable] = lemmatizer_score,
+        store_activations=False,
     ):
         """
         Construct an edit tree lemmatizer.
@@ -116,6 +123,7 @@ class EditTreeLemmatizer(TrainablePipe):
 
         self.cfg: Dict[str, Any] = {"labels": []}
         self.scorer = scorer
+        self.store_activations = store_activations  # type: ignore
 
     def get_loss(
         self, examples: Iterable[Example], scores: List[Floats2d]
@@ -144,21 +152,24 @@ class EditTreeLemmatizer(TrainablePipe):
 
         return float(loss), d_scores
 
-    def predict(self, docs: Iterable[Doc]) -> List[Ints2d]:
+    def predict(self, docs: Iterable[Doc]) -> ActivationsT:
         n_docs = len(list(docs))
         if not any(len(doc) for doc in docs):
             # Handle cases where there are no tokens in any docs.
             n_labels = len(self.cfg["labels"])
-            guesses: List[Ints2d] = [
+            guesses: List[Ints1d] = [
                 self.model.ops.alloc((0,), dtype="i") for doc in docs
             ]
+            scores: List[Floats2d] = [
+                self.model.ops.alloc((0, n_labels), dtype="i") for doc in docs
+            ]
             assert len(guesses) == n_docs
-            return guesses
+            return {"probs": scores, "guesses": guesses}
         scores = self.model.predict(docs)
         assert len(scores) == n_docs
         guesses = self._scores2guesses(docs, scores)
         assert len(guesses) == n_docs
-        return guesses
+        return {"probs": scores, "guesses": guesses}
 
     def _scores2guesses(self, docs, scores):
         guesses = []
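Note the changed contract here: predict() now returns a dict of activation lists instead of the guessed tree IDs alone. A minimal sketch of consuming that dict (summarize is a hypothetical helper; the keys and types come from the hunk above):

    from typing import Dict, List, Union
    from thinc.types import Floats2d, Ints1d

    ActivationsT = Dict[str, Union[List[Floats2d], List[Ints1d]]]

    def summarize(activations: ActivationsT) -> None:
        # Hypothetical helper: "probs" holds one (n_tokens, n_labels) score
        # matrix per doc, "guesses" one (n_tokens,) array of tree IDs per doc.
        for probs, guesses in zip(activations["probs"], activations["guesses"]):
            print(probs.shape, guesses.shape)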
@@ -186,8 +197,12 @@ class EditTreeLemmatizer(TrainablePipe):
 
         return guesses
 
-    def set_annotations(self, docs: Iterable[Doc], batch_tree_ids):
+    def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT):
+        batch_tree_ids = activations["guesses"]
         for i, doc in enumerate(docs):
+            doc.activations[self.name] = {}
+            for activation in self.store_activations:
+                doc.activations[self.name][activation] = activations[activation][i]
             doc_tree_ids = batch_tree_ids[i]
             if hasattr(doc_tree_ids, "get"):
                 doc_tree_ids = doc_tree_ids.get()
@@ -377,3 +392,7 @@ class EditTreeLemmatizer(TrainablePipe):
             self.tree2label[tree_id] = len(self.cfg["labels"])
             self.cfg["labels"].append(tree_id)
         return self.tree2label[tree_id]
+
+    @property
+    def activations(self):
+        return ["probs", "guesses"]
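With the activations property above, store_activations = True resolves to storing both "probs" and "guesses". A minimal sketch of reading them back, assuming an already trained pipeline nlp that contains a trainable_lemmatizer:

    lemmatizer = nlp.get_pipe("trainable_lemmatizer")
    lemmatizer.store_activations = True  # or ["probs"] for a subset
    doc = nlp("This is a test.")
    probs = doc.activations["trainable_lemmatizer"]["probs"]      # (n_tokens, n_labels)
    guesses = doc.activations["trainable_lemmatizer"]["guesses"]  # (n_tokens,)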
spacy/pipeline/entity_linker.py

@@ -1,5 +1,7 @@
 from typing import Optional, Iterable, Callable, Dict, Union, List, Any
-from thinc.types import Floats2d
+from typing import cast
+from numpy import dtype
+from thinc.types import Floats2d, Ragged
 from pathlib import Path
 from itertools import islice
 import srsly
@@ -21,6 +23,9 @@ from ..util import SimpleFrozenList, registry
 from .. import util
 from ..scorer import Scorer
 
+
+ActivationsT = Dict[str, Union[List[Ragged], List[str]]]
+
 # See #9050
 BACKWARD_OVERWRITE = True
 
@@ -56,6 +61,7 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
         "overwrite": True,
         "scorer": {"@scorers": "spacy.entity_linker_scorer.v1"},
         "use_gold_ents": True,
+        "store_activations": False,
     },
     default_score_weights={
         "nel_micro_f": 1.0,
@@ -77,6 +83,7 @@ def make_entity_linker(
     overwrite: bool,
     scorer: Optional[Callable],
     use_gold_ents: bool,
+    store_activations: Union[bool, List[str]],
 ):
     """Construct an EntityLinker component.
 
@@ -121,6 +128,7 @@ def make_entity_linker(
         overwrite=overwrite,
         scorer=scorer,
         use_gold_ents=use_gold_ents,
+        store_activations=store_activations,
     )
 
 
@@ -156,6 +164,7 @@ class EntityLinker(TrainablePipe):
         overwrite: bool = BACKWARD_OVERWRITE,
         scorer: Optional[Callable] = entity_linker_score,
         use_gold_ents: bool,
+        store_activations=False,
     ) -> None:
         """Initialize an entity linker.
 
@@ -192,6 +201,7 @@ class EntityLinker(TrainablePipe):
         self.kb = empty_kb(entity_vector_length)(self.vocab)
         self.scorer = scorer
         self.use_gold_ents = use_gold_ents
+        self.store_activations = store_activations
 
     def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
         """Define the KB of this pipe by providing a function that will
@@ -377,7 +387,7 @@ class EntityLinker(TrainablePipe):
         loss = loss / len(entity_encodings)
         return float(loss), out
 
-    def predict(self, docs: Iterable[Doc]) -> List[str]:
+    def predict(self, docs: Iterable[Doc]) -> ActivationsT:
         """Apply the pipeline's model to a batch of docs, without modifying them.
         Returns the KB IDs for each entity in each doc, including NIL if there is
         no prediction.
@@ -390,13 +400,21 @@ class EntityLinker(TrainablePipe):
         self.validate_kb()
         entity_count = 0
         final_kb_ids: List[str] = []
-        xp = self.model.ops.xp
+        ops = self.model.ops
+        xp = ops.xp
+        docs_ents: List[Ragged] = []
+        docs_scores: List[Ragged] = []
         if not docs:
-            return final_kb_ids
+            return {"kb_ids": final_kb_ids, "ents": docs_ents, "scores": docs_scores}
         if isinstance(docs, Doc):
             docs = [docs]
-        for i, doc in enumerate(docs):
+        for doc in docs:
+            doc_ents = []
+            doc_scores = []
+            doc_scores_lens: List[int] = []
             if len(doc) == 0:
+                doc_scores.append(Ragged(ops.alloc1f(0), ops.alloc1i(0)))
+                doc_ents.append(Ragged(xp.zeros(0, dtype="uint64"), ops.alloc1i(0)))
                 continue
             sentences = [s for s in doc.sents]
             # Looping through each entity (TODO: rewrite)
@@ -419,11 +437,17 @@ class EntityLinker(TrainablePipe):
                 if ent.label_ in self.labels_discard:
                     # ignoring this entity - setting to NIL
                     final_kb_ids.append(self.NIL)
+                    self._add_activations(
+                        doc_scores, doc_scores_lens, doc_ents, [0.0], [0]
+                    )
                 else:
                     candidates = list(self.get_candidates(self.kb, ent))
                     if not candidates:
                         # no prediction possible for this entity - setting to NIL
                         final_kb_ids.append(self.NIL)
+                        self._add_activations(
+                            doc_scores, doc_scores_lens, doc_ents, [0.0], [0]
+                        )
                     elif len(candidates) == 1:
                         # shortcut for efficiency reasons: take the 1 candidate
                         # TODO: thresholding
@@ -456,30 +480,48 @@ class EntityLinker(TrainablePipe):
                                 raise ValueError(Errors.E161)
                             scores = prior_probs + sims - (prior_probs * sims)
                         # TODO: thresholding
+                        self._add_activations(
+                            doc_scores,
+                            doc_scores_lens,
+                            doc_ents,
+                            scores,
+                            [c.entity for c in candidates],
+                        )
                         best_index = scores.argmax().item()
                         best_candidate = candidates[best_index]
                         final_kb_ids.append(best_candidate.entity_)
+            self._add_doc_activations(
+                docs_scores, docs_ents, doc_scores, doc_scores_lens, doc_ents
+            )
         if not (len(final_kb_ids) == entity_count):
             err = Errors.E147.format(
                 method="predict", msg="result variables not of equal length"
             )
             raise RuntimeError(err)
-        return final_kb_ids
+        return {"kb_ids": final_kb_ids, "ents": docs_ents, "scores": docs_scores}
 
-    def set_annotations(self, docs: Iterable[Doc], kb_ids: List[str]) -> None:
+    def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT) -> None:
         """Modify a batch of documents, using pre-computed scores.
 
         docs (Iterable[Doc]): The documents to modify.
-        kb_ids (List[str]): The IDs to set, produced by EntityLinker.predict.
+        activations (List[str]): The activations used for setting annotations, produced
+                                 by EntityLinker.predict.
 
         DOCS: https://spacy.io/api/entitylinker#set_annotations
         """
+        kb_ids = cast(List[str], activations["kb_ids"])
         count_ents = len([ent for doc in docs for ent in doc.ents])
         if count_ents != len(kb_ids):
            raise ValueError(Errors.E148.format(ents=count_ents, ids=len(kb_ids)))
         i = 0
         overwrite = self.cfg["overwrite"]
-        for doc in docs:
+        for j, doc in enumerate(docs):
+            doc.activations[self.name] = {}
+            for activation in self.store_activations:
+                # We only copy activations that are Ragged.
+                doc.activations[self.name][activation] = cast(
+                    Ragged, activations[activation][j]
+                )
             for ent in doc.ents:
                 kb_id = kb_ids[i]
                 i += 1
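The _add_activations/_add_doc_activations helpers added in the next hunk buffer per-entity candidate scores and KB IDs, then pack each doc's buffers into Ragged arrays with one length entry per entity. A standalone sketch of that packing, using thinc's ops directly on made-up data:

    from thinc.api import get_current_ops
    from thinc.types import Ragged

    ops = get_current_ops()
    # Two entities, with 2 and 1 candidate scores respectively.
    doc_scores = [ops.asarray1f([0.7, 0.3]), ops.asarray1f([1.0])]
    lengths = [s.shape[0] for s in doc_scores]
    # Flatten the per-entity arrays and record their lengths in a Ragged.
    packed = Ragged(ops.flatten(doc_scores), ops.asarray1i(lengths))
    assert packed.lengths.tolist() == [2, 1]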
@@ -578,3 +620,30 @@ class EntityLinker(TrainablePipe):
 
     def add_label(self, label):
         raise NotImplementedError
+
+    @property
+    def activations(self):
+        return ["ents", "scores"]
+
+    def _add_doc_activations(
+        self, docs_scores, docs_ents, doc_scores, doc_scores_lens, doc_ents
+    ):
+        if len(self.store_activations) == 0:
+            return
+        ops = self.model.ops
+        docs_scores.append(
+            Ragged(ops.flatten(doc_scores), ops.asarray1i(doc_scores_lens))
+        )
+        docs_ents.append(
+            Ragged(
+                ops.flatten(doc_ents, dtype="uint64"), ops.asarray1i(doc_scores_lens)
+            )
+        )
+
+    def _add_activations(self, doc_scores, doc_scores_lens, doc_ents, scores, ents):
+        if len(self.store_activations) == 0:
+            return
+        ops = self.model.ops
+        doc_scores.append(ops.asarray1f(scores))
+        doc_scores_lens.append(doc_scores[-1].shape[0])
+        doc_ents.append(ops.xp.array(ents, dtype="uint64"))
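End to end, the entity linker's stored activations can then be read back from the doc. A sketch assuming a trained pipeline nlp with an entity_linker component, mirroring the test added below:

    entity_linker = nlp.get_pipe("entity_linker")
    entity_linker.store_activations = True  # or ["scores"] for a subset
    doc = nlp("Russ Cochran was a publisher")
    ents = doc.activations["entity_linker"]["ents"]      # Ragged, uint64 candidate KB hashes
    scores = doc.activations["entity_linker"]["scores"]  # Ragged, float32 candidate scores
    # One lengths entry per entity: the number of candidates recorded for it.
    print(ents.lengths, scores.data)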
spacy/tests/pipeline/test_edit_tree_lemmatizer.py

@@ -1,3 +1,4 @@
+from typing import cast
 import pickle
 import pytest
 from hypothesis import given
@@ -6,6 +7,7 @@ from spacy import util
 from spacy.lang.en import English
 from spacy.language import Language
 from spacy.pipeline._edit_tree_internals.edit_trees import EditTrees
+from spacy.pipeline.trainable_pipe import TrainablePipe
 from spacy.training import Example
 from spacy.strings import StringStore
 from spacy.util import make_tempdir
@@ -278,3 +280,28 @@ def test_empty_strings():
     no_change = trees.add("xyz", "xyz")
     empty = trees.add("", "")
     assert no_change == empty
+
+
+def test_store_activations():
+    nlp = English()
+    lemmatizer = cast(TrainablePipe, nlp.add_pipe("trainable_lemmatizer"))
+    lemmatizer.min_tree_freq = 1
+    train_examples = []
+    for t in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
+    nlp.initialize(get_examples=lambda: train_examples)
+    nO = lemmatizer.model.get_dim("nO")
+
+    doc = nlp("This is a test.")
+    assert len(list(doc.activations["trainable_lemmatizer"].keys())) == 0
+
+    lemmatizer.store_activations = True
+    doc = nlp("This is a test.")
+    assert list(doc.activations["trainable_lemmatizer"].keys()) == ["probs", "guesses"]
+    assert doc.activations["trainable_lemmatizer"]["probs"].shape == (5, nO)
+    assert doc.activations["trainable_lemmatizer"]["guesses"].shape == (5,)
+
+    lemmatizer.store_activations = ["probs"]
+    doc = nlp("This is a test.")
+    assert list(doc.activations["trainable_lemmatizer"].keys()) == ["probs"]
+    assert doc.activations["trainable_lemmatizer"]["probs"].shape == (5, nO)
spacy/tests/pipeline/test_entity_linker.py

@@ -1,7 +1,8 @@
-from typing import Callable, Iterable
+from typing import Callable, Iterable, cast
 
 import pytest
 from numpy.testing import assert_equal
+from thinc.types import Ragged
 
 from spacy import registry, util
 from spacy.attrs import ENT_KB_ID
@@ -9,7 +10,7 @@ from spacy.compat import pickle
 from spacy.kb import Candidate, KnowledgeBase, get_candidates
 from spacy.lang.en import English
 from spacy.ml import load_kb
-from spacy.pipeline import EntityLinker
+from spacy.pipeline import EntityLinker, TrainablePipe
 from spacy.pipeline.legacy import EntityLinker_v1
 from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
 from spacy.scorer import Scorer
@@ -1115,3 +1116,79 @@ def test_tokenization_mismatch():
 
     nlp.add_pipe("sentencizer", first=True)
     results = nlp.evaluate(train_examples)
+
+
+def test_store_activations():
+    nlp = English()
+    vector_length = 3
+    assert "Q2146908" not in nlp.vocab.strings
+
+    # Convert the texts to docs to make sure we have doc.ents set for the training examples
+    train_examples = []
+    for text, annotation in TRAIN_DATA:
+        doc = nlp(text)
+        train_examples.append(Example.from_dict(doc, annotation))
+
+    def create_kb(vocab):
+        # create artificial KB - assign same prior weight to the two russ cochran's
+        # Q2146908 (Russ Cochran): American golfer
+        # Q7381115 (Russ Cochran): publisher
+        mykb = KnowledgeBase(vocab, entity_vector_length=vector_length)
+        mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
+        mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7])
+        mykb.add_alias(
+            alias="Russ Cochran",
+            entities=["Q2146908", "Q7381115"],
+            probabilities=[0.5, 0.5],
+        )
+        return mykb
+
+    # Create the Entity Linker component and add it to the pipeline
+    entity_linker = cast(TrainablePipe, nlp.add_pipe("entity_linker", last=True))
+    assert isinstance(entity_linker, EntityLinker)
+    entity_linker.set_kb(create_kb)
+    assert "Q2146908" in entity_linker.vocab.strings
+    assert "Q2146908" in entity_linker.kb.vocab.strings
+
+    # initialize the NEL pipe
+    optimizer = nlp.initialize(get_examples=lambda: train_examples)
+
+    for i in range(2):
+        losses = {}
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
+
+    nO = entity_linker.model.get_dim("nO")
+
+    nlp.add_pipe("sentencizer", first=True)
+    patterns = [
+        {"label": "PERSON", "pattern": [{"LOWER": "russ"}, {"LOWER": "cochran"}]},
+        {"label": "ORG", "pattern": [{"LOWER": "ec"}, {"LOWER": "comics"}]},
+    ]
+    ruler = nlp.add_pipe("entity_ruler", before="entity_linker")
+    ruler.add_patterns(patterns)
+
+    doc = nlp("Russ Cochran was a publisher")
+    assert len(doc.activations["entity_linker"].keys()) == 0
+
+    entity_linker.store_activations = True
+    doc = nlp("Russ Cochran was a publisher")
+    assert set(doc.activations["entity_linker"].keys()) == {"ents", "scores"}
+    ents = doc.activations["entity_linker"]["ents"]
+    assert isinstance(ents, Ragged)
+    assert ents.data.shape == (2, 1)
+    assert ents.data.dtype == "uint64"
+    assert ents.lengths.shape == (1,)
+    scores = doc.activations["entity_linker"]["scores"]
+    assert isinstance(scores, Ragged)
+    assert scores.data.shape == (2, 1)
+    assert scores.data.dtype == "float32"
+    assert scores.lengths.shape == (1,)
+
+    entity_linker.store_activations = ["scores"]
+    doc = nlp("Russ Cochran was a publisher")
+    assert set(doc.activations["entity_linker"].keys()) == {"scores"}
+    scores = doc.activations["entity_linker"]["scores"]
+    assert isinstance(scores, Ragged)
+    assert scores.data.shape == (2, 1)
+    assert scores.data.dtype == "float32"
+    assert scores.lengths.shape == (1,)
spacy/tokens/doc.pyi

@@ -1,7 +1,7 @@
 from typing import Callable, Protocol, Iterable, Iterator, Optional
 from typing import Union, Tuple, List, Dict, Any, overload
 from cymem.cymem import Pool
-from thinc.types import ArrayXd, Floats1d, Floats2d, Ints2d
+from thinc.types import ArrayXd, Floats1d, Floats2d, Ints2d, Ragged
 from .span import Span
 from .token import Token
 from ._dict_proxies import SpanGroups
@@ -22,7 +22,7 @@ class Doc:
     max_length: int
     length: int
     sentiment: float
-    activations: Dict[str, Dict[str, ArrayXd]]
+    activations: Dict[str, Dict[str, Union[ArrayXd, Ragged]]]
     cats: Dict[str, float]
     user_hooks: Dict[str, Callable[..., Any]]
     user_token_hooks: Dict[str, Callable[..., Any]]
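Since the lemmatizer stores plain arrays while the entity linker stores Ragged values, the stub widens the value type accordingly. A small sketch of walking the container under that type, assuming a processed doc:

    for component, acts in doc.activations.items():
        for name, value in acts.items():
            # value is an ArrayXd for the lemmatizer, a Ragged for the entity linker
            print(component, name, type(value).__name__)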