KB & NEL to/from bytes (#8113)

* unit test for pickling KB

* add pickling test for NEL

* KB to_bytes and from_bytes

* NEL to_bytes and from_bytes

* xfail pickle tests for now

* fix docs

* cleanup
This commit is contained in:
Sofie Van Landeghem 2021-05-20 10:11:30 +02:00 committed by GitHub
parent f6128c06b0
commit 202943bc8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 287 additions and 13 deletions

View File

@ -93,6 +93,15 @@ cdef class KnowledgeBase:
self.vocab = vocab
self._create_empty_vectors(dummy_hash=self.vocab.strings[""])
def initialize_entities(self, int64_t nr_entities):
self._entry_index = PreshMap(nr_entities + 1)
self._entries = entry_vec(nr_entities + 1)
self._vectors_table = float_matrix(nr_entities + 1)
def initialize_aliases(self, int64_t nr_aliases):
self._alias_index = PreshMap(nr_aliases + 1)
self._aliases_table = alias_vec(nr_aliases + 1)
@property
def entity_vector_length(self):
"""RETURNS (uint64): length of the entity vectors"""
@ -144,8 +153,7 @@ cdef class KnowledgeBase:
raise ValueError(Errors.E140)
nr_entities = len(set(entity_list))
self._entry_index = PreshMap(nr_entities+1)
self._entries = entry_vec(nr_entities+1)
self.initialize_entities(nr_entities)
i = 0
cdef KBEntryC entry
@ -325,6 +333,102 @@ cdef class KnowledgeBase:
return 0.0
def to_bytes(self, **kwargs):
"""Serialize the current state to a binary string.
"""
def serialize_header():
header = (self.get_size_entities(), self.get_size_aliases(), self.entity_vector_length)
return srsly.json_dumps(header)
def serialize_entries():
i = 1
tuples = []
for entry_hash, entry_index in sorted(self._entry_index.items(), key=lambda x: x[1]):
entry = self._entries[entry_index]
assert entry.entity_hash == entry_hash
assert entry_index == i
tuples.append((entry.entity_hash, entry.freq, entry.vector_index))
i = i + 1
return srsly.json_dumps(tuples)
def serialize_aliases():
i = 1
headers = []
indices_lists = []
probs_lists = []
for alias_hash, alias_index in sorted(self._alias_index.items(), key=lambda x: x[1]):
alias = self._aliases_table[alias_index]
assert alias_index == i
candidate_length = len(alias.entry_indices)
headers.append((alias_hash, candidate_length))
indices_lists.append(alias.entry_indices)
probs_lists.append(alias.probs)
i = i + 1
headers_dump = srsly.json_dumps(headers)
indices_dump = srsly.json_dumps(indices_lists)
probs_dump = srsly.json_dumps(probs_lists)
return srsly.json_dumps((headers_dump, indices_dump, probs_dump))
serializers = {
"header": serialize_header,
"entity_vectors": lambda: srsly.json_dumps(self._vectors_table),
"entries": serialize_entries,
"aliases": serialize_aliases,
}
return util.to_bytes(serializers, [])
def from_bytes(self, bytes_data, *, exclude=tuple()):
"""Load state from a binary string.
"""
def deserialize_header(b):
header = srsly.json_loads(b)
nr_entities = header[0]
nr_aliases = header[1]
entity_vector_length = header[2]
self.initialize_entities(nr_entities)
self.initialize_aliases(nr_aliases)
self.entity_vector_length = entity_vector_length
def deserialize_vectors(b):
self._vectors_table = srsly.json_loads(b)
def deserialize_entries(b):
cdef KBEntryC entry
tuples = srsly.json_loads(b)
i = 1
for (entity_hash, freq, vector_index) in tuples:
entry.entity_hash = entity_hash
entry.freq = freq
entry.vector_index = vector_index
entry.feats_row = -1 # Features table currently not implemented
self._entries[i] = entry
self._entry_index[entity_hash] = i
i += 1
def deserialize_aliases(b):
cdef AliasC alias
i = 1
all_data = srsly.json_loads(b)
headers = srsly.json_loads(all_data[0])
indices = srsly.json_loads(all_data[1])
probs = srsly.json_loads(all_data[2])
for header, indices, probs in zip(headers, indices, probs):
alias_hash, candidate_length = header
alias.entry_indices = indices
alias.probs = probs
self._aliases_table[i] = alias
self._alias_index[alias_hash] = i
i += 1
setters = {
"header": deserialize_header,
"entity_vectors": deserialize_vectors,
"entries": deserialize_entries,
"aliases": deserialize_aliases,
}
util.from_bytes(bytes_data, setters, exclude)
return self
def to_disk(self, path, exclude: Iterable[str] = SimpleFrozenList()):
path = ensure_path(path)
if not path.exists():
@ -404,10 +508,8 @@ cdef class KnowledgeBase:
cdef int64_t entity_vector_length
reader.read_header(&nr_entities, &entity_vector_length)
self.initialize_entities(nr_entities)
self.entity_vector_length = entity_vector_length
self._entry_index = PreshMap(nr_entities+1)
self._entries = entry_vec(nr_entities+1)
self._vectors_table = float_matrix(nr_entities+1)
# STEP 1: load entity vectors
cdef int i = 0
@ -445,8 +547,7 @@ cdef class KnowledgeBase:
# STEP 3: load aliases
cdef int64_t nr_aliases
reader.read_alias_length(&nr_aliases)
self._alias_index = PreshMap(nr_aliases+1)
self._aliases_table = alias_vec(nr_aliases+1)
self.initialize_aliases(nr_aliases)
cdef int64_t nr_candidates
cdef vector[int64_t] entry_indices

View File

@ -408,6 +408,48 @@ class EntityLinker(TrainablePipe):
validate_examples(examples, "EntityLinker.score")
return Scorer.score_links(examples, negative_labels=[self.NIL])
def to_bytes(self, *, exclude=tuple()):
"""Serialize the pipe to a bytestring.
exclude (Iterable[str]): String names of serialization fields to exclude.
RETURNS (bytes): The serialized object.
DOCS: https://spacy.io/api/entitylinker#to_bytes
"""
self._validate_serialization_attrs()
serialize = {}
if hasattr(self, "cfg") and self.cfg is not None:
serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
serialize["vocab"] = self.vocab.to_bytes
serialize["kb"] = self.kb.to_bytes
serialize["model"] = self.model.to_bytes
return util.to_bytes(serialize, exclude)
def from_bytes(self, bytes_data, *, exclude=tuple()):
"""Load the pipe from a bytestring.
exclude (Iterable[str]): String names of serialization fields to exclude.
RETURNS (TrainablePipe): The loaded object.
DOCS: https://spacy.io/api/entitylinker#from_bytes
"""
self._validate_serialization_attrs()
def load_model(b):
try:
self.model.from_bytes(b)
except AttributeError:
raise ValueError(Errors.E149) from None
deserialize = {}
if hasattr(self, "cfg") and self.cfg is not None:
deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
deserialize["vocab"] = lambda b: self.vocab.from_bytes(b)
deserialize["kb"] = lambda b: self.kb.from_bytes(b)
deserialize["model"] = load_model
util.from_bytes(bytes_data, deserialize, exclude)
return self
def to_disk(
self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
) -> None:

View File

@ -2,7 +2,7 @@ from typing import Callable, Iterable
import pytest
from numpy.testing import assert_equal
from spacy.attrs import ENT_KB_ID
from spacy.compat import pickle
from spacy.kb import KnowledgeBase, get_candidates, Candidate
from spacy.vocab import Vocab
@ -11,7 +11,7 @@ from spacy.ml import load_kb
from spacy.scorer import Scorer
from spacy.training import Example
from spacy.lang.en import English
from spacy.tests.util import make_tempdir
from spacy.tests.util import make_tempdir, make_tempfile
from spacy.tokens import Span
@ -290,6 +290,9 @@ def test_vocab_serialization(nlp):
assert candidates[0].alias == adam_hash
assert candidates[0].alias_ == "adam"
assert kb_new_vocab.get_vector("Q2") == [2]
assert_almost_equal(kb_new_vocab.get_prior_prob("Q2", "douglas"), 0.4)
def test_append_alias(nlp):
"""Test that we can append additional alias-entity pairs"""
@ -546,6 +549,98 @@ def test_kb_serialization():
assert "RandomWord" in nlp2.vocab.strings
@pytest.mark.xfail(reason="Needs fixing")
def test_kb_pickle():
# Test that the KB can be pickled
nlp = English()
kb_1 = KnowledgeBase(nlp.vocab, entity_vector_length=3)
kb_1.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
assert not kb_1.contains_alias("Russ Cochran")
kb_1.add_alias(alias="Russ Cochran", entities=["Q2146908"], probabilities=[0.8])
assert kb_1.contains_alias("Russ Cochran")
data = pickle.dumps(kb_1)
kb_2 = pickle.loads(data)
assert kb_2.contains_alias("Russ Cochran")
@pytest.mark.xfail(reason="Needs fixing")
def test_nel_pickle():
# Test that a pipeline with an EL component can be pickled
def create_kb(vocab):
kb = KnowledgeBase(vocab, entity_vector_length=3)
kb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
kb.add_alias(alias="Russ Cochran", entities=["Q2146908"], probabilities=[0.8])
return kb
nlp_1 = English()
nlp_1.add_pipe("ner")
entity_linker_1 = nlp_1.add_pipe("entity_linker", last=True)
entity_linker_1.set_kb(create_kb)
assert nlp_1.pipe_names == ["ner", "entity_linker"]
assert entity_linker_1.kb.contains_alias("Russ Cochran")
data = pickle.dumps(nlp_1)
nlp_2 = pickle.loads(data)
assert nlp_2.pipe_names == ["ner", "entity_linker"]
entity_linker_2 = nlp_2.get_pipe("entity_linker")
assert entity_linker_2.kb.contains_alias("Russ Cochran")
def test_kb_to_bytes():
# Test that the KB's to_bytes method works correctly
nlp = English()
kb_1 = KnowledgeBase(nlp.vocab, entity_vector_length=3)
kb_1.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
kb_1.add_entity(entity="Q66", freq=9, entity_vector=[1, 2, 3])
kb_1.add_alias(alias="Russ Cochran", entities=["Q2146908"], probabilities=[0.8])
kb_1.add_alias(alias="Boeing", entities=["Q66"], probabilities=[0.5])
kb_1.add_alias(alias="Randomness", entities=["Q66", "Q2146908"], probabilities=[0.1, 0.2])
assert kb_1.contains_alias("Russ Cochran")
kb_bytes = kb_1.to_bytes()
kb_2 = KnowledgeBase(nlp.vocab, entity_vector_length=3)
assert not kb_2.contains_alias("Russ Cochran")
kb_2 = kb_2.from_bytes(kb_bytes)
# check that both KBs are exactly the same
assert kb_1.get_size_entities() == kb_2.get_size_entities()
assert kb_1.entity_vector_length == kb_2.entity_vector_length
assert kb_1.get_entity_strings() == kb_2.get_entity_strings()
assert kb_1.get_vector("Q2146908") == kb_2.get_vector("Q2146908")
assert kb_1.get_vector("Q66") == kb_2.get_vector("Q66")
assert kb_2.contains_alias("Russ Cochran")
assert kb_1.get_size_aliases() == kb_2.get_size_aliases()
assert kb_1.get_alias_strings() == kb_2.get_alias_strings()
assert len(kb_1.get_alias_candidates("Russ Cochran")) == len(kb_2.get_alias_candidates("Russ Cochran"))
assert len(kb_1.get_alias_candidates("Randomness")) == len(kb_2.get_alias_candidates("Randomness"))
def test_nel_to_bytes():
# Test that a pipeline with an EL component can be converted to bytes
def create_kb(vocab):
kb = KnowledgeBase(vocab, entity_vector_length=3)
kb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
kb.add_alias(alias="Russ Cochran", entities=["Q2146908"], probabilities=[0.8])
return kb
nlp_1 = English()
nlp_1.add_pipe("ner")
entity_linker_1 = nlp_1.add_pipe("entity_linker", last=True)
entity_linker_1.set_kb(create_kb)
assert entity_linker_1.kb.contains_alias("Russ Cochran")
assert nlp_1.pipe_names == ["ner", "entity_linker"]
nlp_bytes = nlp_1.to_bytes()
nlp_2 = English()
nlp_2.add_pipe("ner")
nlp_2.add_pipe("entity_linker", last=True)
assert nlp_2.pipe_names == ["ner", "entity_linker"]
assert not nlp_2.get_pipe("entity_linker").kb.contains_alias("Russ Cochran")
nlp_2 = nlp_2.from_bytes(nlp_bytes)
kb_2 = nlp_2.get_pipe("entity_linker").kb
assert kb_2.contains_alias("Russ Cochran")
assert kb_2.get_vector("Q2146908") == [6, -4, 3]
assert_almost_equal(kb_2.get_prior_prob(entity="Q2146908", alias="Russ Cochran"), 0.8)
def test_scorer_links():
train_examples = []
nlp = English()

View File

@ -213,10 +213,10 @@ if there is no prediction.
> kb_ids = entity_linker.predict([doc1, doc2])
> ```
| Name | Description |
| ----------- | ------------------------------------------- |
| `docs` | The documents to predict. ~~Iterable[Doc]~~ |
| **RETURNS** | `List[str]` | The predicted KB identifiers for the entities in the `docs`. ~~List[str]~~ |
| Name | Description |
| ----------- | -------------------------------------------------------------------------- |
| `docs` | The documents to predict. ~~Iterable[Doc]~~ |
| **RETURNS** | The predicted KB identifiers for the entities in the `docs`. ~~List[str]~~ |
## EntityLinker.set_annotations {#set_annotations tag="method"}
@ -341,6 +341,42 @@ Load the pipe from disk. Modifies the object in place and returns it.
| `exclude` | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ |
| **RETURNS** | The modified `EntityLinker` object. ~~EntityLinker~~ |
## EntityLinker.to_bytes {#to_bytes tag="method"}
> #### Example
>
> ```python
> entity_linker = nlp.add_pipe("entity_linker")
> entity_linker_bytes = entity_linker.to_bytes()
> ```
Serialize the pipe to a bytestring, including the `KnowledgeBase`.
| Name | Description |
| -------------- | ------------------------------------------------------------------------------------------- |
| _keyword-only_ | |
| `exclude` | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ |
| **RETURNS** | The serialized form of the `EntityLinker` object. ~~bytes~~ |
## EntityLinker.from_bytes {#from_bytes tag="method"}
Load the pipe from a bytestring. Modifies the object in place and returns it.
> #### Example
>
> ```python
> entity_linker_bytes = entity_linker.to_bytes()
> entity_linker = nlp.add_pipe("entity_linker")
> entity_linker.from_bytes(entity_linker_bytes)
> ```
| Name | Description |
| -------------- | ------------------------------------------------------------------------------------------- |
| `bytes_data` | The data to load from. ~~bytes~~ |
| _keyword-only_ | |
| `exclude` | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ |
| **RETURNS** | The `EntityLinker` object. ~~EntityLinker~~ |
## Serialization fields {#serialization-fields}
During serialization, spaCy will export several data fields used to restore