KB & NEL to/from bytes (#8113)
* unit test for pickling KB
* add pickling test for NEL
* KB to_bytes and from_bytes
* NEL to_bytes and from_bytes
* xfail pickle tests for now
* fix docs
* cleanup
This commit is contained in:
parent f6128c06b0
commit 202943bc8c

spacy/kb.pyx: 115 changed lines
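At a glance, the changes below let a `KnowledgeBase` (and, through it, an `EntityLinker` pipe) round-trip through `bytes` rather than only through disk. A minimal usage sketch of the new API; the entity ID, frequency, vector and prior probability are illustrative values taken from the tests added in this commit:

```python
from spacy.kb import KnowledgeBase
from spacy.lang.en import English

nlp = English()
kb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
kb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
kb.add_alias(alias="Russ Cochran", entities=["Q2146908"], probabilities=[0.8])

kb_bytes = kb.to_bytes()                                    # new in this commit
kb_copy = KnowledgeBase(nlp.vocab, entity_vector_length=3)
kb_copy.from_bytes(kb_bytes)                                # new in this commit
assert kb_copy.contains_alias("Russ Cochran")
```

The diff then wires the same round trip into the `entity_linker` pipe so it survives `nlp.to_bytes()` / `nlp.from_bytes()`.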
spacy/kb.pyx
@@ -93,6 +93,15 @@ cdef class KnowledgeBase:
         self.vocab = vocab
         self._create_empty_vectors(dummy_hash=self.vocab.strings[""])
 
+    def initialize_entities(self, int64_t nr_entities):
+        self._entry_index = PreshMap(nr_entities + 1)
+        self._entries = entry_vec(nr_entities + 1)
+        self._vectors_table = float_matrix(nr_entities + 1)
+
+    def initialize_aliases(self, int64_t nr_aliases):
+        self._alias_index = PreshMap(nr_aliases + 1)
+        self._aliases_table = alias_vec(nr_aliases + 1)
+
     @property
     def entity_vector_length(self):
         """RETURNS (uint64): length of the entity vectors"""
@@ -144,8 +153,7 @@ cdef class KnowledgeBase:
             raise ValueError(Errors.E140)
 
         nr_entities = len(set(entity_list))
-        self._entry_index = PreshMap(nr_entities+1)
-        self._entries = entry_vec(nr_entities+1)
+        self.initialize_entities(nr_entities)
 
         i = 0
         cdef KBEntryC entry
@@ -325,6 +333,102 @@ cdef class KnowledgeBase:
 
         return 0.0
 
+    def to_bytes(self, **kwargs):
+        """Serialize the current state to a binary string.
+        """
+        def serialize_header():
+            header = (self.get_size_entities(), self.get_size_aliases(), self.entity_vector_length)
+            return srsly.json_dumps(header)
+
+        def serialize_entries():
+            i = 1
+            tuples = []
+            for entry_hash, entry_index in sorted(self._entry_index.items(), key=lambda x: x[1]):
+                entry = self._entries[entry_index]
+                assert entry.entity_hash == entry_hash
+                assert entry_index == i
+                tuples.append((entry.entity_hash, entry.freq, entry.vector_index))
+                i = i + 1
+            return srsly.json_dumps(tuples)
+
+        def serialize_aliases():
+            i = 1
+            headers = []
+            indices_lists = []
+            probs_lists = []
+            for alias_hash, alias_index in sorted(self._alias_index.items(), key=lambda x: x[1]):
+                alias = self._aliases_table[alias_index]
+                assert alias_index == i
+                candidate_length = len(alias.entry_indices)
+                headers.append((alias_hash, candidate_length))
+                indices_lists.append(alias.entry_indices)
+                probs_lists.append(alias.probs)
+                i = i + 1
+            headers_dump = srsly.json_dumps(headers)
+            indices_dump = srsly.json_dumps(indices_lists)
+            probs_dump = srsly.json_dumps(probs_lists)
+            return srsly.json_dumps((headers_dump, indices_dump, probs_dump))
+
+        serializers = {
+            "header": serialize_header,
+            "entity_vectors": lambda: srsly.json_dumps(self._vectors_table),
+            "entries": serialize_entries,
+            "aliases": serialize_aliases,
+        }
+        return util.to_bytes(serializers, [])
+
+    def from_bytes(self, bytes_data, *, exclude=tuple()):
+        """Load state from a binary string.
+        """
+        def deserialize_header(b):
+            header = srsly.json_loads(b)
+            nr_entities = header[0]
+            nr_aliases = header[1]
+            entity_vector_length = header[2]
+            self.initialize_entities(nr_entities)
+            self.initialize_aliases(nr_aliases)
+            self.entity_vector_length = entity_vector_length
+
+        def deserialize_vectors(b):
+            self._vectors_table = srsly.json_loads(b)
+
+        def deserialize_entries(b):
+            cdef KBEntryC entry
+            tuples = srsly.json_loads(b)
+            i = 1
+            for (entity_hash, freq, vector_index) in tuples:
+                entry.entity_hash = entity_hash
+                entry.freq = freq
+                entry.vector_index = vector_index
+                entry.feats_row = -1  # Features table currently not implemented
+                self._entries[i] = entry
+                self._entry_index[entity_hash] = i
+                i += 1
+
+        def deserialize_aliases(b):
+            cdef AliasC alias
+            i = 1
+            all_data = srsly.json_loads(b)
+            headers = srsly.json_loads(all_data[0])
+            indices = srsly.json_loads(all_data[1])
+            probs = srsly.json_loads(all_data[2])
+            for header, indices, probs in zip(headers, indices, probs):
+                alias_hash, candidate_length = header
+                alias.entry_indices = indices
+                alias.probs = probs
+                self._aliases_table[i] = alias
+                self._alias_index[alias_hash] = i
+                i += 1
+
+        setters = {
+            "header": deserialize_header,
+            "entity_vectors": deserialize_vectors,
+            "entries": deserialize_entries,
+            "aliases": deserialize_aliases,
+        }
+        util.from_bytes(bytes_data, setters, exclude)
+        return self
+
     def to_disk(self, path, exclude: Iterable[str] = SimpleFrozenList()):
         path = ensure_path(path)
         if not path.exists():
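One detail worth noting in `serialize_aliases` / `deserialize_aliases` above: the alias headers, entry indices and prior probabilities are each JSON-encoded on their own and then packed into one outer JSON document, so the reader has to call `json_loads` twice. A small self-contained sketch of that round trip, with made-up hash, index and probability values:

```python
import srsly

# One alias with one candidate entity: (alias_hash, candidate_length),
# plus its entry indices and prior probabilities.
headers = [(123456789, 1)]
indices_lists = [[1]]
probs_lists = [[0.8]]

# Outer dump wraps three inner JSON strings, mirroring serialize_aliases.
blob = srsly.json_dumps((
    srsly.json_dumps(headers),
    srsly.json_dumps(indices_lists),
    srsly.json_dumps(probs_lists),
))

# Mirroring deserialize_aliases: unpack the outer document, then each field.
outer = srsly.json_loads(blob)
assert srsly.json_loads(outer[0]) == [[123456789, 1]]  # JSON turns tuples into lists
assert srsly.json_loads(outer[1]) == [[1]]
assert srsly.json_loads(outer[2]) == [[0.8]]
```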
@@ -404,10 +508,8 @@ cdef class KnowledgeBase:
         cdef int64_t entity_vector_length
         reader.read_header(&nr_entities, &entity_vector_length)
 
+        self.initialize_entities(nr_entities)
         self.entity_vector_length = entity_vector_length
-        self._entry_index = PreshMap(nr_entities+1)
-        self._entries = entry_vec(nr_entities+1)
-        self._vectors_table = float_matrix(nr_entities+1)
 
         # STEP 1: load entity vectors
         cdef int i = 0
@@ -445,8 +547,7 @@ cdef class KnowledgeBase:
         # STEP 3: load aliases
         cdef int64_t nr_aliases
         reader.read_alias_length(&nr_aliases)
-        self._alias_index = PreshMap(nr_aliases+1)
-        self._aliases_table = alias_vec(nr_aliases+1)
+        self.initialize_aliases(nr_aliases)
 
         cdef int64_t nr_candidates
        cdef vector[int64_t] entry_indices
spacy/pipeline/entity_linker.py
@@ -408,6 +408,48 @@ class EntityLinker(TrainablePipe):
         validate_examples(examples, "EntityLinker.score")
         return Scorer.score_links(examples, negative_labels=[self.NIL])
 
+    def to_bytes(self, *, exclude=tuple()):
+        """Serialize the pipe to a bytestring.
+
+        exclude (Iterable[str]): String names of serialization fields to exclude.
+        RETURNS (bytes): The serialized object.
+
+        DOCS: https://spacy.io/api/entitylinker#to_bytes
+        """
+        self._validate_serialization_attrs()
+        serialize = {}
+        if hasattr(self, "cfg") and self.cfg is not None:
+            serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
+        serialize["vocab"] = self.vocab.to_bytes
+        serialize["kb"] = self.kb.to_bytes
+        serialize["model"] = self.model.to_bytes
+        return util.to_bytes(serialize, exclude)
+
+    def from_bytes(self, bytes_data, *, exclude=tuple()):
+        """Load the pipe from a bytestring.
+
+        exclude (Iterable[str]): String names of serialization fields to exclude.
+        RETURNS (TrainablePipe): The loaded object.
+
+        DOCS: https://spacy.io/api/entitylinker#from_bytes
+        """
+        self._validate_serialization_attrs()
+
+        def load_model(b):
+            try:
+                self.model.from_bytes(b)
+            except AttributeError:
+                raise ValueError(Errors.E149) from None
+
+        deserialize = {}
+        if hasattr(self, "cfg") and self.cfg is not None:
+            deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
+        deserialize["vocab"] = lambda b: self.vocab.from_bytes(b)
+        deserialize["kb"] = lambda b: self.kb.from_bytes(b)
+        deserialize["model"] = load_model
+        util.from_bytes(bytes_data, deserialize, exclude)
+        return self
+
     def to_disk(
         self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
     ) -> None:
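Both `KnowledgeBase.to_bytes`/`from_bytes` and the `EntityLinker` methods in this hunk rely on the same registry pattern: a dict mapping section names to zero-argument getters (or one-argument setters) is handed to `util.to_bytes` / `util.from_bytes`, which packs the named sections into a single binary blob and honours the `exclude` list. A rough standalone sketch of that pattern, not spaCy's actual implementation, using srsly's msgpack helpers and hypothetical helper names:

```python
import srsly

def pack_sections(getters, exclude):
    # Call each getter and pack the named sections into one msgpack blob.
    return srsly.msgpack_dumps(
        {key: getter() for key, getter in getters.items() if key not in exclude}
    )

def unpack_sections(bytes_data, setters, exclude):
    # Unpack the blob and hand each stored section to its setter.
    msg = srsly.msgpack_loads(bytes_data)
    for key, setter in setters.items():
        if key not in exclude and key in msg:
            setter(msg[key])

# Toy round trip with a single "header" section.
data = pack_sections({"header": lambda: srsly.json_dumps((2, 3, 64))}, exclude=[])
unpack_sections(data, {"header": lambda b: print(srsly.json_loads(b))}, exclude=[])
```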
spacy/tests/pipeline/test_entity_linker.py
@@ -2,7 +2,7 @@ from typing import Callable, Iterable
 import pytest
 from numpy.testing import assert_equal
 from spacy.attrs import ENT_KB_ID
-
+from spacy.compat import pickle
 from spacy.kb import KnowledgeBase, get_candidates, Candidate
 from spacy.vocab import Vocab
 
@@ -11,7 +11,7 @@ from spacy.ml import load_kb
 from spacy.scorer import Scorer
 from spacy.training import Example
 from spacy.lang.en import English
-from spacy.tests.util import make_tempdir
+from spacy.tests.util import make_tempdir, make_tempfile
 from spacy.tokens import Span
 
 
@@ -290,6 +290,9 @@ def test_vocab_serialization(nlp):
     assert candidates[0].alias == adam_hash
     assert candidates[0].alias_ == "adam"
 
+    assert kb_new_vocab.get_vector("Q2") == [2]
+    assert_almost_equal(kb_new_vocab.get_prior_prob("Q2", "douglas"), 0.4)
+
 
 def test_append_alias(nlp):
     """Test that we can append additional alias-entity pairs"""
@@ -546,6 +549,98 @@ def test_kb_serialization():
     assert "RandomWord" in nlp2.vocab.strings
 
 
+@pytest.mark.xfail(reason="Needs fixing")
+def test_kb_pickle():
+    # Test that the KB can be pickled
+    nlp = English()
+    kb_1 = KnowledgeBase(nlp.vocab, entity_vector_length=3)
+    kb_1.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
+    assert not kb_1.contains_alias("Russ Cochran")
+    kb_1.add_alias(alias="Russ Cochran", entities=["Q2146908"], probabilities=[0.8])
+    assert kb_1.contains_alias("Russ Cochran")
+    data = pickle.dumps(kb_1)
+    kb_2 = pickle.loads(data)
+    assert kb_2.contains_alias("Russ Cochran")
+
+
+@pytest.mark.xfail(reason="Needs fixing")
+def test_nel_pickle():
+    # Test that a pipeline with an EL component can be pickled
+    def create_kb(vocab):
+        kb = KnowledgeBase(vocab, entity_vector_length=3)
+        kb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
+        kb.add_alias(alias="Russ Cochran", entities=["Q2146908"], probabilities=[0.8])
+        return kb
+
+    nlp_1 = English()
+    nlp_1.add_pipe("ner")
+    entity_linker_1 = nlp_1.add_pipe("entity_linker", last=True)
+    entity_linker_1.set_kb(create_kb)
+    assert nlp_1.pipe_names == ["ner", "entity_linker"]
+    assert entity_linker_1.kb.contains_alias("Russ Cochran")
+
+    data = pickle.dumps(nlp_1)
+    nlp_2 = pickle.loads(data)
+    assert nlp_2.pipe_names == ["ner", "entity_linker"]
+    entity_linker_2 = nlp_2.get_pipe("entity_linker")
+    assert entity_linker_2.kb.contains_alias("Russ Cochran")
+
+
+def test_kb_to_bytes():
+    # Test that the KB's to_bytes method works correctly
+    nlp = English()
+    kb_1 = KnowledgeBase(nlp.vocab, entity_vector_length=3)
+    kb_1.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
+    kb_1.add_entity(entity="Q66", freq=9, entity_vector=[1, 2, 3])
+    kb_1.add_alias(alias="Russ Cochran", entities=["Q2146908"], probabilities=[0.8])
+    kb_1.add_alias(alias="Boeing", entities=["Q66"], probabilities=[0.5])
+    kb_1.add_alias(alias="Randomness", entities=["Q66", "Q2146908"], probabilities=[0.1, 0.2])
+    assert kb_1.contains_alias("Russ Cochran")
+    kb_bytes = kb_1.to_bytes()
+    kb_2 = KnowledgeBase(nlp.vocab, entity_vector_length=3)
+    assert not kb_2.contains_alias("Russ Cochran")
+    kb_2 = kb_2.from_bytes(kb_bytes)
+    # check that both KBs are exactly the same
+    assert kb_1.get_size_entities() == kb_2.get_size_entities()
+    assert kb_1.entity_vector_length == kb_2.entity_vector_length
+    assert kb_1.get_entity_strings() == kb_2.get_entity_strings()
+    assert kb_1.get_vector("Q2146908") == kb_2.get_vector("Q2146908")
+    assert kb_1.get_vector("Q66") == kb_2.get_vector("Q66")
+    assert kb_2.contains_alias("Russ Cochran")
+    assert kb_1.get_size_aliases() == kb_2.get_size_aliases()
+    assert kb_1.get_alias_strings() == kb_2.get_alias_strings()
+    assert len(kb_1.get_alias_candidates("Russ Cochran")) == len(kb_2.get_alias_candidates("Russ Cochran"))
+    assert len(kb_1.get_alias_candidates("Randomness")) == len(kb_2.get_alias_candidates("Randomness"))
+
+
+def test_nel_to_bytes():
+    # Test that a pipeline with an EL component can be converted to bytes
+    def create_kb(vocab):
+        kb = KnowledgeBase(vocab, entity_vector_length=3)
+        kb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
+        kb.add_alias(alias="Russ Cochran", entities=["Q2146908"], probabilities=[0.8])
+        return kb
+
+    nlp_1 = English()
+    nlp_1.add_pipe("ner")
+    entity_linker_1 = nlp_1.add_pipe("entity_linker", last=True)
+    entity_linker_1.set_kb(create_kb)
+    assert entity_linker_1.kb.contains_alias("Russ Cochran")
+    assert nlp_1.pipe_names == ["ner", "entity_linker"]
+
+    nlp_bytes = nlp_1.to_bytes()
+    nlp_2 = English()
+    nlp_2.add_pipe("ner")
+    nlp_2.add_pipe("entity_linker", last=True)
+    assert nlp_2.pipe_names == ["ner", "entity_linker"]
+    assert not nlp_2.get_pipe("entity_linker").kb.contains_alias("Russ Cochran")
+    nlp_2 = nlp_2.from_bytes(nlp_bytes)
+    kb_2 = nlp_2.get_pipe("entity_linker").kb
+    assert kb_2.contains_alias("Russ Cochran")
+    assert kb_2.get_vector("Q2146908") == [6, -4, 3]
+    assert_almost_equal(kb_2.get_prior_prob(entity="Q2146908", alias="Russ Cochran"), 0.8)
+
+
 def test_scorer_links():
     train_examples = []
     nlp = English()
website/docs/api/entitylinker.md
@@ -214,9 +214,9 @@ if there is no prediction.
 > ```
 
 | Name        | Description                                 |
-| ----------- | ------------------------------------------- |
+| ----------- | -------------------------------------------------------------------------- |
 | `docs`      | The documents to predict. ~~Iterable[Doc]~~ |
-| **RETURNS** | `List[str]` | The predicted KB identifiers for the entities in the `docs`. ~~List[str]~~ |
+| **RETURNS** | The predicted KB identifiers for the entities in the `docs`. ~~List[str]~~ |
 
 ## EntityLinker.set_annotations {#set_annotations tag="method"}
 
@@ -341,6 +341,42 @@ Load the pipe from disk. Modifies the object in place and returns it.
 | `exclude`   | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ |
 | **RETURNS** | The modified `EntityLinker` object. ~~EntityLinker~~                                        |
 
+## EntityLinker.to_bytes {#to_bytes tag="method"}
+
+> #### Example
+>
+> ```python
+> entity_linker = nlp.add_pipe("entity_linker")
+> entity_linker_bytes = entity_linker.to_bytes()
+> ```
+
+Serialize the pipe to a bytestring, including the `KnowledgeBase`.
+
+| Name           | Description                                                                                  |
+| -------------- | -------------------------------------------------------------------------------------------- |
+| _keyword-only_ |                                                                                              |
+| `exclude`      | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~  |
+| **RETURNS**    | The serialized form of the `EntityLinker` object. ~~bytes~~                                  |
+
+## EntityLinker.from_bytes {#from_bytes tag="method"}
+
+Load the pipe from a bytestring. Modifies the object in place and returns it.
+
+> #### Example
+>
+> ```python
+> entity_linker_bytes = entity_linker.to_bytes()
+> entity_linker = nlp.add_pipe("entity_linker")
+> entity_linker.from_bytes(entity_linker_bytes)
+> ```
+
+| Name           | Description                                                                                  |
+| -------------- | -------------------------------------------------------------------------------------------- |
+| `bytes_data`   | The data to load from. ~~bytes~~                                                             |
+| _keyword-only_ |                                                                                              |
+| `exclude`      | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~  |
+| **RETURNS**    | The `EntityLinker` object. ~~EntityLinker~~                                                  |
+
 ## Serialization fields {#serialization-fields}
 
 During serialization, spaCy will export several data fields used to restore
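As with other pipes, the `exclude` argument documented in the new sections skips the listed serialization fields on both write and read. A short hedged example, assuming an `nlp` pipeline whose `entity_linker` already has its knowledge base set:

```python
# Skip re-serializing the shared vocab when round-tripping just the component.
entity_linker = nlp.get_pipe("entity_linker")
pipe_bytes = entity_linker.to_bytes(exclude=["vocab"])
entity_linker.from_bytes(pipe_bytes, exclude=["vocab"])
```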