diff --git a/spacy/kb.pyx b/spacy/kb.pyx index 4d02b89d0..d8514b54c 100644 --- a/spacy/kb.pyx +++ b/spacy/kb.pyx @@ -93,6 +93,15 @@ cdef class KnowledgeBase: self.vocab = vocab self._create_empty_vectors(dummy_hash=self.vocab.strings[""]) + def initialize_entities(self, int64_t nr_entities): + self._entry_index = PreshMap(nr_entities + 1) + self._entries = entry_vec(nr_entities + 1) + self._vectors_table = float_matrix(nr_entities + 1) + + def initialize_aliases(self, int64_t nr_aliases): + self._alias_index = PreshMap(nr_aliases + 1) + self._aliases_table = alias_vec(nr_aliases + 1) + @property def entity_vector_length(self): """RETURNS (uint64): length of the entity vectors""" @@ -144,8 +153,7 @@ cdef class KnowledgeBase: raise ValueError(Errors.E140) nr_entities = len(set(entity_list)) - self._entry_index = PreshMap(nr_entities+1) - self._entries = entry_vec(nr_entities+1) + self.initialize_entities(nr_entities) i = 0 cdef KBEntryC entry @@ -325,6 +333,102 @@ cdef class KnowledgeBase: return 0.0 + def to_bytes(self, **kwargs): + """Serialize the current state to a binary string. + """ + def serialize_header(): + header = (self.get_size_entities(), self.get_size_aliases(), self.entity_vector_length) + return srsly.json_dumps(header) + + def serialize_entries(): + i = 1 + tuples = [] + for entry_hash, entry_index in sorted(self._entry_index.items(), key=lambda x: x[1]): + entry = self._entries[entry_index] + assert entry.entity_hash == entry_hash + assert entry_index == i + tuples.append((entry.entity_hash, entry.freq, entry.vector_index)) + i = i + 1 + return srsly.json_dumps(tuples) + + def serialize_aliases(): + i = 1 + headers = [] + indices_lists = [] + probs_lists = [] + for alias_hash, alias_index in sorted(self._alias_index.items(), key=lambda x: x[1]): + alias = self._aliases_table[alias_index] + assert alias_index == i + candidate_length = len(alias.entry_indices) + headers.append((alias_hash, candidate_length)) + indices_lists.append(alias.entry_indices) + probs_lists.append(alias.probs) + i = i + 1 + headers_dump = srsly.json_dumps(headers) + indices_dump = srsly.json_dumps(indices_lists) + probs_dump = srsly.json_dumps(probs_lists) + return srsly.json_dumps((headers_dump, indices_dump, probs_dump)) + + serializers = { + "header": serialize_header, + "entity_vectors": lambda: srsly.json_dumps(self._vectors_table), + "entries": serialize_entries, + "aliases": serialize_aliases, + } + return util.to_bytes(serializers, []) + + def from_bytes(self, bytes_data, *, exclude=tuple()): + """Load state from a binary string. + """ + def deserialize_header(b): + header = srsly.json_loads(b) + nr_entities = header[0] + nr_aliases = header[1] + entity_vector_length = header[2] + self.initialize_entities(nr_entities) + self.initialize_aliases(nr_aliases) + self.entity_vector_length = entity_vector_length + + def deserialize_vectors(b): + self._vectors_table = srsly.json_loads(b) + + def deserialize_entries(b): + cdef KBEntryC entry + tuples = srsly.json_loads(b) + i = 1 + for (entity_hash, freq, vector_index) in tuples: + entry.entity_hash = entity_hash + entry.freq = freq + entry.vector_index = vector_index + entry.feats_row = -1 # Features table currently not implemented + self._entries[i] = entry + self._entry_index[entity_hash] = i + i += 1 + + def deserialize_aliases(b): + cdef AliasC alias + i = 1 + all_data = srsly.json_loads(b) + headers = srsly.json_loads(all_data[0]) + indices = srsly.json_loads(all_data[1]) + probs = srsly.json_loads(all_data[2]) + for header, indices, probs in zip(headers, indices, probs): + alias_hash, candidate_length = header + alias.entry_indices = indices + alias.probs = probs + self._aliases_table[i] = alias + self._alias_index[alias_hash] = i + i += 1 + + setters = { + "header": deserialize_header, + "entity_vectors": deserialize_vectors, + "entries": deserialize_entries, + "aliases": deserialize_aliases, + } + util.from_bytes(bytes_data, setters, exclude) + return self + def to_disk(self, path, exclude: Iterable[str] = SimpleFrozenList()): path = ensure_path(path) if not path.exists(): @@ -404,10 +508,8 @@ cdef class KnowledgeBase: cdef int64_t entity_vector_length reader.read_header(&nr_entities, &entity_vector_length) + self.initialize_entities(nr_entities) self.entity_vector_length = entity_vector_length - self._entry_index = PreshMap(nr_entities+1) - self._entries = entry_vec(nr_entities+1) - self._vectors_table = float_matrix(nr_entities+1) # STEP 1: load entity vectors cdef int i = 0 @@ -445,8 +547,7 @@ cdef class KnowledgeBase: # STEP 3: load aliases cdef int64_t nr_aliases reader.read_alias_length(&nr_aliases) - self._alias_index = PreshMap(nr_aliases+1) - self._aliases_table = alias_vec(nr_aliases+1) + self.initialize_aliases(nr_aliases) cdef int64_t nr_candidates cdef vector[int64_t] entry_indices diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index 66070916e..002ea71a7 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -408,6 +408,48 @@ class EntityLinker(TrainablePipe): validate_examples(examples, "EntityLinker.score") return Scorer.score_links(examples, negative_labels=[self.NIL]) + def to_bytes(self, *, exclude=tuple()): + """Serialize the pipe to a bytestring. + + exclude (Iterable[str]): String names of serialization fields to exclude. + RETURNS (bytes): The serialized object. + + DOCS: https://spacy.io/api/entitylinker#to_bytes + """ + self._validate_serialization_attrs() + serialize = {} + if hasattr(self, "cfg") and self.cfg is not None: + serialize["cfg"] = lambda: srsly.json_dumps(self.cfg) + serialize["vocab"] = self.vocab.to_bytes + serialize["kb"] = self.kb.to_bytes + serialize["model"] = self.model.to_bytes + return util.to_bytes(serialize, exclude) + + def from_bytes(self, bytes_data, *, exclude=tuple()): + """Load the pipe from a bytestring. + + exclude (Iterable[str]): String names of serialization fields to exclude. + RETURNS (TrainablePipe): The loaded object. + + DOCS: https://spacy.io/api/entitylinker#from_bytes + """ + self._validate_serialization_attrs() + + def load_model(b): + try: + self.model.from_bytes(b) + except AttributeError: + raise ValueError(Errors.E149) from None + + deserialize = {} + if hasattr(self, "cfg") and self.cfg is not None: + deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b)) + deserialize["vocab"] = lambda b: self.vocab.from_bytes(b) + deserialize["kb"] = lambda b: self.kb.from_bytes(b) + deserialize["model"] = load_model + util.from_bytes(bytes_data, deserialize, exclude) + return self + def to_disk( self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList() ) -> None: diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py index 4883cceb8..a7f9364e9 100644 --- a/spacy/tests/pipeline/test_entity_linker.py +++ b/spacy/tests/pipeline/test_entity_linker.py @@ -2,7 +2,7 @@ from typing import Callable, Iterable import pytest from numpy.testing import assert_equal from spacy.attrs import ENT_KB_ID - +from spacy.compat import pickle from spacy.kb import KnowledgeBase, get_candidates, Candidate from spacy.vocab import Vocab @@ -11,7 +11,7 @@ from spacy.ml import load_kb from spacy.scorer import Scorer from spacy.training import Example from spacy.lang.en import English -from spacy.tests.util import make_tempdir +from spacy.tests.util import make_tempdir, make_tempfile from spacy.tokens import Span @@ -290,6 +290,9 @@ def test_vocab_serialization(nlp): assert candidates[0].alias == adam_hash assert candidates[0].alias_ == "adam" + assert kb_new_vocab.get_vector("Q2") == [2] + assert_almost_equal(kb_new_vocab.get_prior_prob("Q2", "douglas"), 0.4) + def test_append_alias(nlp): """Test that we can append additional alias-entity pairs""" @@ -546,6 +549,98 @@ def test_kb_serialization(): assert "RandomWord" in nlp2.vocab.strings +@pytest.mark.xfail(reason="Needs fixing") +def test_kb_pickle(): + # Test that the KB can be pickled + nlp = English() + kb_1 = KnowledgeBase(nlp.vocab, entity_vector_length=3) + kb_1.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3]) + assert not kb_1.contains_alias("Russ Cochran") + kb_1.add_alias(alias="Russ Cochran", entities=["Q2146908"], probabilities=[0.8]) + assert kb_1.contains_alias("Russ Cochran") + data = pickle.dumps(kb_1) + kb_2 = pickle.loads(data) + assert kb_2.contains_alias("Russ Cochran") + + +@pytest.mark.xfail(reason="Needs fixing") +def test_nel_pickle(): + # Test that a pipeline with an EL component can be pickled + def create_kb(vocab): + kb = KnowledgeBase(vocab, entity_vector_length=3) + kb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3]) + kb.add_alias(alias="Russ Cochran", entities=["Q2146908"], probabilities=[0.8]) + return kb + + nlp_1 = English() + nlp_1.add_pipe("ner") + entity_linker_1 = nlp_1.add_pipe("entity_linker", last=True) + entity_linker_1.set_kb(create_kb) + assert nlp_1.pipe_names == ["ner", "entity_linker"] + assert entity_linker_1.kb.contains_alias("Russ Cochran") + + data = pickle.dumps(nlp_1) + nlp_2 = pickle.loads(data) + assert nlp_2.pipe_names == ["ner", "entity_linker"] + entity_linker_2 = nlp_2.get_pipe("entity_linker") + assert entity_linker_2.kb.contains_alias("Russ Cochran") + + +def test_kb_to_bytes(): + # Test that the KB's to_bytes method works correctly + nlp = English() + kb_1 = KnowledgeBase(nlp.vocab, entity_vector_length=3) + kb_1.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3]) + kb_1.add_entity(entity="Q66", freq=9, entity_vector=[1, 2, 3]) + kb_1.add_alias(alias="Russ Cochran", entities=["Q2146908"], probabilities=[0.8]) + kb_1.add_alias(alias="Boeing", entities=["Q66"], probabilities=[0.5]) + kb_1.add_alias(alias="Randomness", entities=["Q66", "Q2146908"], probabilities=[0.1, 0.2]) + assert kb_1.contains_alias("Russ Cochran") + kb_bytes = kb_1.to_bytes() + kb_2 = KnowledgeBase(nlp.vocab, entity_vector_length=3) + assert not kb_2.contains_alias("Russ Cochran") + kb_2 = kb_2.from_bytes(kb_bytes) + # check that both KBs are exactly the same + assert kb_1.get_size_entities() == kb_2.get_size_entities() + assert kb_1.entity_vector_length == kb_2.entity_vector_length + assert kb_1.get_entity_strings() == kb_2.get_entity_strings() + assert kb_1.get_vector("Q2146908") == kb_2.get_vector("Q2146908") + assert kb_1.get_vector("Q66") == kb_2.get_vector("Q66") + assert kb_2.contains_alias("Russ Cochran") + assert kb_1.get_size_aliases() == kb_2.get_size_aliases() + assert kb_1.get_alias_strings() == kb_2.get_alias_strings() + assert len(kb_1.get_alias_candidates("Russ Cochran")) == len(kb_2.get_alias_candidates("Russ Cochran")) + assert len(kb_1.get_alias_candidates("Randomness")) == len(kb_2.get_alias_candidates("Randomness")) + + +def test_nel_to_bytes(): + # Test that a pipeline with an EL component can be converted to bytes + def create_kb(vocab): + kb = KnowledgeBase(vocab, entity_vector_length=3) + kb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3]) + kb.add_alias(alias="Russ Cochran", entities=["Q2146908"], probabilities=[0.8]) + return kb + + nlp_1 = English() + nlp_1.add_pipe("ner") + entity_linker_1 = nlp_1.add_pipe("entity_linker", last=True) + entity_linker_1.set_kb(create_kb) + assert entity_linker_1.kb.contains_alias("Russ Cochran") + assert nlp_1.pipe_names == ["ner", "entity_linker"] + + nlp_bytes = nlp_1.to_bytes() + nlp_2 = English() + nlp_2.add_pipe("ner") + nlp_2.add_pipe("entity_linker", last=True) + assert nlp_2.pipe_names == ["ner", "entity_linker"] + assert not nlp_2.get_pipe("entity_linker").kb.contains_alias("Russ Cochran") + nlp_2 = nlp_2.from_bytes(nlp_bytes) + kb_2 = nlp_2.get_pipe("entity_linker").kb + assert kb_2.contains_alias("Russ Cochran") + assert kb_2.get_vector("Q2146908") == [6, -4, 3] + assert_almost_equal(kb_2.get_prior_prob(entity="Q2146908", alias="Russ Cochran"), 0.8) + + def test_scorer_links(): train_examples = [] nlp = English() diff --git a/website/docs/api/entitylinker.md b/website/docs/api/entitylinker.md index b3a1054fc..2994d934b 100644 --- a/website/docs/api/entitylinker.md +++ b/website/docs/api/entitylinker.md @@ -213,10 +213,10 @@ if there is no prediction. > kb_ids = entity_linker.predict([doc1, doc2]) > ``` -| Name | Description | -| ----------- | ------------------------------------------- | -| `docs` | The documents to predict. ~~Iterable[Doc]~~ | -| **RETURNS** | `List[str]` | The predicted KB identifiers for the entities in the `docs`. ~~List[str]~~ | +| Name | Description | +| ----------- | -------------------------------------------------------------------------- | +| `docs` | The documents to predict. ~~Iterable[Doc]~~ | +| **RETURNS** | The predicted KB identifiers for the entities in the `docs`. ~~List[str]~~ | ## EntityLinker.set_annotations {#set_annotations tag="method"} @@ -341,6 +341,42 @@ Load the pipe from disk. Modifies the object in place and returns it. | `exclude` | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ | | **RETURNS** | The modified `EntityLinker` object. ~~EntityLinker~~ | +## EntityLinker.to_bytes {#to_bytes tag="method"} + +> #### Example +> +> ```python +> entity_linker = nlp.add_pipe("entity_linker") +> entity_linker_bytes = entity_linker.to_bytes() +> ``` + +Serialize the pipe to a bytestring, including the `KnowledgeBase`. + +| Name | Description | +| -------------- | ------------------------------------------------------------------------------------------- | +| _keyword-only_ | | +| `exclude` | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ | +| **RETURNS** | The serialized form of the `EntityLinker` object. ~~bytes~~ | + +## EntityLinker.from_bytes {#from_bytes tag="method"} + +Load the pipe from a bytestring. Modifies the object in place and returns it. + +> #### Example +> +> ```python +> entity_linker_bytes = entity_linker.to_bytes() +> entity_linker = nlp.add_pipe("entity_linker") +> entity_linker.from_bytes(entity_linker_bytes) +> ``` + +| Name | Description | +| -------------- | ------------------------------------------------------------------------------------------- | +| `bytes_data` | The data to load from. ~~bytes~~ | +| _keyword-only_ | | +| `exclude` | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ | +| **RETURNS** | The `EntityLinker` object. ~~EntityLinker~~ | + ## Serialization fields {#serialization-fields} During serialization, spaCy will export several data fields used to restore