KB & NEL to/from bytes (#8113)

* unit test for pickling KB
* add pickling test for NEL
* KB to_bytes and from_bytes
* NEL to_bytes and from_bytes
* xfail pickle tests for now
* fix docs
* cleanup
This commit is contained in:

parent f6128c06b0
commit 202943bc8c
115	spacy/kb.pyx
--- a/spacy/kb.pyx
+++ b/spacy/kb.pyx
@@ -93,6 +93,15 @@ cdef class KnowledgeBase:
         self.vocab = vocab
         self._create_empty_vectors(dummy_hash=self.vocab.strings[""])
 
+    def initialize_entities(self, int64_t nr_entities):
+        self._entry_index = PreshMap(nr_entities + 1)
+        self._entries = entry_vec(nr_entities + 1)
+        self._vectors_table = float_matrix(nr_entities + 1)
+
+    def initialize_aliases(self, int64_t nr_aliases):
+        self._alias_index = PreshMap(nr_aliases + 1)
+        self._aliases_table = alias_vec(nr_aliases + 1)
+
     @property
     def entity_vector_length(self):
         """RETURNS (uint64): length of the entity vectors"""
@@ -144,8 +153,7 @@ cdef class KnowledgeBase:
             raise ValueError(Errors.E140)
 
         nr_entities = len(set(entity_list))
-        self._entry_index = PreshMap(nr_entities+1)
-        self._entries = entry_vec(nr_entities+1)
+        self.initialize_entities(nr_entities)
 
         i = 0
         cdef KBEntryC entry
@@ -325,6 +333,102 @@ cdef class KnowledgeBase:
 
         return 0.0
 
+    def to_bytes(self, **kwargs):
+        """Serialize the current state to a binary string.
+        """
+        def serialize_header():
+            header = (self.get_size_entities(), self.get_size_aliases(), self.entity_vector_length)
+            return srsly.json_dumps(header)
+
+        def serialize_entries():
+            i = 1
+            tuples = []
+            for entry_hash, entry_index in sorted(self._entry_index.items(), key=lambda x: x[1]):
+                entry = self._entries[entry_index]
+                assert entry.entity_hash == entry_hash
+                assert entry_index == i
+                tuples.append((entry.entity_hash, entry.freq, entry.vector_index))
+                i = i + 1
+            return srsly.json_dumps(tuples)
+
+        def serialize_aliases():
+            i = 1
+            headers = []
+            indices_lists = []
+            probs_lists = []
+            for alias_hash, alias_index in sorted(self._alias_index.items(), key=lambda x: x[1]):
+                alias = self._aliases_table[alias_index]
+                assert alias_index == i
+                candidate_length = len(alias.entry_indices)
+                headers.append((alias_hash, candidate_length))
+                indices_lists.append(alias.entry_indices)
+                probs_lists.append(alias.probs)
+                i = i + 1
+            headers_dump = srsly.json_dumps(headers)
+            indices_dump = srsly.json_dumps(indices_lists)
+            probs_dump = srsly.json_dumps(probs_lists)
+            return srsly.json_dumps((headers_dump, indices_dump, probs_dump))
+
+        serializers = {
+            "header": serialize_header,
+            "entity_vectors": lambda: srsly.json_dumps(self._vectors_table),
+            "entries": serialize_entries,
+            "aliases": serialize_aliases,
+        }
+        return util.to_bytes(serializers, [])
+
+    def from_bytes(self, bytes_data, *, exclude=tuple()):
+        """Load state from a binary string.
+        """
+        def deserialize_header(b):
+            header = srsly.json_loads(b)
+            nr_entities = header[0]
+            nr_aliases = header[1]
+            entity_vector_length = header[2]
+            self.initialize_entities(nr_entities)
+            self.initialize_aliases(nr_aliases)
+            self.entity_vector_length = entity_vector_length
+
+        def deserialize_vectors(b):
+            self._vectors_table = srsly.json_loads(b)
+
+        def deserialize_entries(b):
+            cdef KBEntryC entry
+            tuples = srsly.json_loads(b)
+            i = 1
+            for (entity_hash, freq, vector_index) in tuples:
+                entry.entity_hash = entity_hash
+                entry.freq = freq
+                entry.vector_index = vector_index
+                entry.feats_row = -1  # Features table currently not implemented
+                self._entries[i] = entry
+                self._entry_index[entity_hash] = i
+                i += 1
+
+        def deserialize_aliases(b):
+            cdef AliasC alias
+            i = 1
+            all_data = srsly.json_loads(b)
+            headers = srsly.json_loads(all_data[0])
+            indices = srsly.json_loads(all_data[1])
+            probs = srsly.json_loads(all_data[2])
+            for header, indices, probs in zip(headers, indices, probs):
+                alias_hash, candidate_length = header
+                alias.entry_indices = indices
+                alias.probs = probs
+                self._aliases_table[i] = alias
+                self._alias_index[alias_hash] = i
+                i += 1
+
+        setters = {
+            "header": deserialize_header,
+            "entity_vectors": deserialize_vectors,
+            "entries": deserialize_entries,
+            "aliases": deserialize_aliases,
+        }
+        util.from_bytes(bytes_data, setters, exclude)
+        return self
+
     def to_disk(self, path, exclude: Iterable[str] = SimpleFrozenList()):
         path = ensure_path(path)
         if not path.exists():
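Both methods lean on spaCy's generic named-field helpers: every getter in the dict is called and the results are packed into one msgpack blob keyed by field name, and on load each named blob is routed to its setter unless the name is excluded. A simplified sketch of that behavior (the exact shape of `spacy/util.py` is assumed here, not quoted from this commit):

```python
import srsly

# Simplified sketch of spaCy's to_bytes/from_bytes helpers.
def to_bytes(getters, exclude):
    serialized = {}
    for key, getter in getters.items():
        if key not in exclude:
            serialized[key] = getter()  # e.g. the json_dumps results above
    return srsly.msgpack_dumps(serialized)

def from_bytes(bytes_data, setters, exclude):
    msg = srsly.msgpack_loads(bytes_data)
    for key, setter in setters.items():
        if key not in exclude and key in msg:
            setter(msg[key])
    return msg
```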
@@ -404,10 +508,8 @@ cdef class KnowledgeBase:
         cdef int64_t entity_vector_length
         reader.read_header(&nr_entities, &entity_vector_length)
 
+        self.initialize_entities(nr_entities)
         self.entity_vector_length = entity_vector_length
-        self._entry_index = PreshMap(nr_entities+1)
-        self._entries = entry_vec(nr_entities+1)
-        self._vectors_table = float_matrix(nr_entities+1)
 
         # STEP 1: load entity vectors
         cdef int i = 0
@@ -445,8 +547,7 @@ cdef class KnowledgeBase:
         # STEP 3: load aliases
         cdef int64_t nr_aliases
         reader.read_alias_length(&nr_aliases)
-        self._alias_index = PreshMap(nr_aliases+1)
-        self._aliases_table = alias_vec(nr_aliases+1)
+        self.initialize_aliases(nr_aliases)
 
         cdef int64_t nr_candidates
         cdef vector[int64_t] entry_indices
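Taken together, the `kb.pyx` changes give the KB a byte-level round trip that mirrors the existing `to_disk`/`from_disk` pair. A minimal usage sketch (the entity ID, frequency and vector are illustrative values taken from the tests below):

```python
from spacy.kb import KnowledgeBase
from spacy.lang.en import English

nlp = English()
kb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
kb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
kb.add_alias(alias="Russ Cochran", entities=["Q2146908"], probabilities=[0.8])

# Serialize to a bytestring and restore into a fresh KB over the same vocab.
kb_bytes = kb.to_bytes()
kb2 = KnowledgeBase(nlp.vocab, entity_vector_length=3)
kb2.from_bytes(kb_bytes)
assert kb2.contains_alias("Russ Cochran")
```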
--- a/spacy/pipeline/entity_linker.py
+++ b/spacy/pipeline/entity_linker.py
@@ -408,6 +408,48 @@ class EntityLinker(TrainablePipe):
         validate_examples(examples, "EntityLinker.score")
         return Scorer.score_links(examples, negative_labels=[self.NIL])
 
+    def to_bytes(self, *, exclude=tuple()):
+        """Serialize the pipe to a bytestring.
+
+        exclude (Iterable[str]): String names of serialization fields to exclude.
+        RETURNS (bytes): The serialized object.
+
+        DOCS: https://spacy.io/api/entitylinker#to_bytes
+        """
+        self._validate_serialization_attrs()
+        serialize = {}
+        if hasattr(self, "cfg") and self.cfg is not None:
+            serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
+        serialize["vocab"] = self.vocab.to_bytes
+        serialize["kb"] = self.kb.to_bytes
+        serialize["model"] = self.model.to_bytes
+        return util.to_bytes(serialize, exclude)
+
+    def from_bytes(self, bytes_data, *, exclude=tuple()):
+        """Load the pipe from a bytestring.
+
+        exclude (Iterable[str]): String names of serialization fields to exclude.
+        RETURNS (TrainablePipe): The loaded object.
+
+        DOCS: https://spacy.io/api/entitylinker#from_bytes
+        """
+        self._validate_serialization_attrs()
+
+        def load_model(b):
+            try:
+                self.model.from_bytes(b)
+            except AttributeError:
+                raise ValueError(Errors.E149) from None
+
+        deserialize = {}
+        if hasattr(self, "cfg") and self.cfg is not None:
+            deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
+        deserialize["vocab"] = lambda b: self.vocab.from_bytes(b)
+        deserialize["kb"] = lambda b: self.kb.from_bytes(b)
+        deserialize["model"] = load_model
+        util.from_bytes(bytes_data, deserialize, exclude)
+        return self
+
     def to_disk(
         self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
     ) -> None:
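With the pipe-level `to_bytes`/`from_bytes` in place, the `KnowledgeBase` also survives a whole-pipeline round trip, since `nlp.to_bytes()` delegates to each pipe. A condensed sketch of the flow the new tests exercise:

```python
from spacy.kb import KnowledgeBase
from spacy.lang.en import English

def create_kb(vocab):
    kb = KnowledgeBase(vocab, entity_vector_length=3)
    kb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
    kb.add_alias(alias="Russ Cochran", entities=["Q2146908"], probabilities=[0.8])
    return kb

nlp = English()
nlp.add_pipe("ner")
nlp.add_pipe("entity_linker", last=True).set_kb(create_kb)

# Restore into a second pipeline with the same components.
nlp2 = English()
nlp2.add_pipe("ner")
nlp2.add_pipe("entity_linker", last=True)
nlp2.from_bytes(nlp.to_bytes())
assert nlp2.get_pipe("entity_linker").kb.contains_alias("Russ Cochran")
```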
--- a/spacy/tests/pipeline/test_entity_linker.py
+++ b/spacy/tests/pipeline/test_entity_linker.py
@@ -2,7 +2,7 @@ from typing import Callable, Iterable
 import pytest
 from numpy.testing import assert_equal
 from spacy.attrs import ENT_KB_ID
+from spacy.compat import pickle
 from spacy.kb import KnowledgeBase, get_candidates, Candidate
 from spacy.vocab import Vocab
 
@@ -11,7 +11,7 @@ from spacy.ml import load_kb
 from spacy.scorer import Scorer
 from spacy.training import Example
 from spacy.lang.en import English
-from spacy.tests.util import make_tempdir
+from spacy.tests.util import make_tempdir, make_tempfile
 from spacy.tokens import Span
 
 
@@ -290,6 +290,9 @@ def test_vocab_serialization(nlp):
         assert candidates[0].alias == adam_hash
         assert candidates[0].alias_ == "adam"
 
+        assert kb_new_vocab.get_vector("Q2") == [2]
+        assert_almost_equal(kb_new_vocab.get_prior_prob("Q2", "douglas"), 0.4)
+
 
 def test_append_alias(nlp):
     """Test that we can append additional alias-entity pairs"""
@@ -546,6 +549,98 @@ def test_kb_serialization():
         assert "RandomWord" in nlp2.vocab.strings
 
 
+@pytest.mark.xfail(reason="Needs fixing")
+def test_kb_pickle():
+    # Test that the KB can be pickled
+    nlp = English()
+    kb_1 = KnowledgeBase(nlp.vocab, entity_vector_length=3)
+    kb_1.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
+    assert not kb_1.contains_alias("Russ Cochran")
+    kb_1.add_alias(alias="Russ Cochran", entities=["Q2146908"], probabilities=[0.8])
+    assert kb_1.contains_alias("Russ Cochran")
+    data = pickle.dumps(kb_1)
+    kb_2 = pickle.loads(data)
+    assert kb_2.contains_alias("Russ Cochran")
+
+
+@pytest.mark.xfail(reason="Needs fixing")
+def test_nel_pickle():
+    # Test that a pipeline with an EL component can be pickled
+    def create_kb(vocab):
+        kb = KnowledgeBase(vocab, entity_vector_length=3)
+        kb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
+        kb.add_alias(alias="Russ Cochran", entities=["Q2146908"], probabilities=[0.8])
+        return kb
+
+    nlp_1 = English()
+    nlp_1.add_pipe("ner")
+    entity_linker_1 = nlp_1.add_pipe("entity_linker", last=True)
+    entity_linker_1.set_kb(create_kb)
+    assert nlp_1.pipe_names == ["ner", "entity_linker"]
+    assert entity_linker_1.kb.contains_alias("Russ Cochran")
+
+    data = pickle.dumps(nlp_1)
+    nlp_2 = pickle.loads(data)
+    assert nlp_2.pipe_names == ["ner", "entity_linker"]
+    entity_linker_2 = nlp_2.get_pipe("entity_linker")
+    assert entity_linker_2.kb.contains_alias("Russ Cochran")
+
+
+def test_kb_to_bytes():
+    # Test that the KB's to_bytes method works correctly
+    nlp = English()
+    kb_1 = KnowledgeBase(nlp.vocab, entity_vector_length=3)
+    kb_1.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
+    kb_1.add_entity(entity="Q66", freq=9, entity_vector=[1, 2, 3])
+    kb_1.add_alias(alias="Russ Cochran", entities=["Q2146908"], probabilities=[0.8])
+    kb_1.add_alias(alias="Boeing", entities=["Q66"], probabilities=[0.5])
+    kb_1.add_alias(alias="Randomness", entities=["Q66", "Q2146908"], probabilities=[0.1, 0.2])
+    assert kb_1.contains_alias("Russ Cochran")
+    kb_bytes = kb_1.to_bytes()
+    kb_2 = KnowledgeBase(nlp.vocab, entity_vector_length=3)
+    assert not kb_2.contains_alias("Russ Cochran")
+    kb_2 = kb_2.from_bytes(kb_bytes)
+    # check that both KBs are exactly the same
+    assert kb_1.get_size_entities() == kb_2.get_size_entities()
+    assert kb_1.entity_vector_length == kb_2.entity_vector_length
+    assert kb_1.get_entity_strings() == kb_2.get_entity_strings()
+    assert kb_1.get_vector("Q2146908") == kb_2.get_vector("Q2146908")
+    assert kb_1.get_vector("Q66") == kb_2.get_vector("Q66")
+    assert kb_2.contains_alias("Russ Cochran")
+    assert kb_1.get_size_aliases() == kb_2.get_size_aliases()
+    assert kb_1.get_alias_strings() == kb_2.get_alias_strings()
+    assert len(kb_1.get_alias_candidates("Russ Cochran")) == len(kb_2.get_alias_candidates("Russ Cochran"))
+    assert len(kb_1.get_alias_candidates("Randomness")) == len(kb_2.get_alias_candidates("Randomness"))
+
+
+def test_nel_to_bytes():
+    # Test that a pipeline with an EL component can be converted to bytes
+    def create_kb(vocab):
+        kb = KnowledgeBase(vocab, entity_vector_length=3)
+        kb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
+        kb.add_alias(alias="Russ Cochran", entities=["Q2146908"], probabilities=[0.8])
+        return kb
+
+    nlp_1 = English()
+    nlp_1.add_pipe("ner")
+    entity_linker_1 = nlp_1.add_pipe("entity_linker", last=True)
+    entity_linker_1.set_kb(create_kb)
+    assert entity_linker_1.kb.contains_alias("Russ Cochran")
+    assert nlp_1.pipe_names == ["ner", "entity_linker"]
+
+    nlp_bytes = nlp_1.to_bytes()
+    nlp_2 = English()
+    nlp_2.add_pipe("ner")
+    nlp_2.add_pipe("entity_linker", last=True)
+    assert nlp_2.pipe_names == ["ner", "entity_linker"]
+    assert not nlp_2.get_pipe("entity_linker").kb.contains_alias("Russ Cochran")
+    nlp_2 = nlp_2.from_bytes(nlp_bytes)
+    kb_2 = nlp_2.get_pipe("entity_linker").kb
+    assert kb_2.contains_alias("Russ Cochran")
+    assert kb_2.get_vector("Q2146908") == [6, -4, 3]
+    assert_almost_equal(kb_2.get_prior_prob(entity="Q2146908", alias="Russ Cochran"), 0.8)
+
+
 def test_scorer_links():
     train_examples = []
     nlp = English()
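The two pickle tests are marked `xfail` for now: the commit only adds the byte-level protocol, and pickling a `cdef class` needs explicit support on top of it. One plausible way to wire that up later (a hypothetical sketch, not part of this commit) is a `__reduce__` that round-trips through `to_bytes`:

```python
# Hypothetical sketch only -- not what this commit implements.
# Pickle support could be layered on the new bytes protocol by
# rebuilding the KB from (vocab, vector length, serialized state).
def _rebuild_kb(vocab, entity_vector_length, kb_bytes):
    kb = KnowledgeBase(vocab, entity_vector_length=entity_vector_length)
    return kb.from_bytes(kb_bytes)

def __reduce__(self):
    return (_rebuild_kb, (self.vocab, self.entity_vector_length, self.to_bytes()))
```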
--- a/website/docs/api/entitylinker.md
+++ b/website/docs/api/entitylinker.md
@@ -213,10 +213,10 @@ if there is no prediction.
 > kb_ids = entity_linker.predict([doc1, doc2])
 > ```
 
-| Name        | Description                                 |
-| ----------- | ------------------------------------------- |
-| `docs`      | The documents to predict. ~~Iterable[Doc]~~ |
-| **RETURNS** | `List[str]`                                 | The predicted KB identifiers for the entities in the `docs`. ~~List[str]~~ |
+| Name        | Description                                                                 |
+| ----------- | --------------------------------------------------------------------------- |
+| `docs`      | The documents to predict. ~~Iterable[Doc]~~                                 |
+| **RETURNS** | The predicted KB identifiers for the entities in the `docs`. ~~List[str]~~  |
 
 ## EntityLinker.set_annotations {#set_annotations tag="method"}
 
@@ -341,6 +341,42 @@ Load the pipe from disk. Modifies the object in place and returns it.
 | `exclude`      | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~     |
 | **RETURNS**    | The modified `EntityLinker` object. ~~EntityLinker~~                                            |
 
+## EntityLinker.to_bytes {#to_bytes tag="method"}
+
+> #### Example
+>
+> ```python
+> entity_linker = nlp.add_pipe("entity_linker")
+> entity_linker_bytes = entity_linker.to_bytes()
+> ```
+
+Serialize the pipe to a bytestring, including the `KnowledgeBase`.
+
+| Name           | Description                                                                                  |
+| -------------- | -------------------------------------------------------------------------------------------- |
+| _keyword-only_ |                                                                                              |
+| `exclude`      | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~  |
+| **RETURNS**    | The serialized form of the `EntityLinker` object. ~~bytes~~                                  |
+
+## EntityLinker.from_bytes {#from_bytes tag="method"}
+
+Load the pipe from a bytestring. Modifies the object in place and returns it.
+
+> #### Example
+>
+> ```python
+> entity_linker_bytes = entity_linker.to_bytes()
+> entity_linker = nlp.add_pipe("entity_linker")
+> entity_linker.from_bytes(entity_linker_bytes)
+> ```
+
+| Name           | Description                                                                                  |
+| -------------- | -------------------------------------------------------------------------------------------- |
+| `bytes_data`   | The data to load from. ~~bytes~~                                                             |
+| _keyword-only_ |                                                                                              |
+| `exclude`      | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~  |
+| **RETURNS**    | The `EntityLinker` object. ~~EntityLinker~~                                                  |
+
 ## Serialization fields {#serialization-fields}
 
 During serialization, spaCy will export several data fields used to restore
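As elsewhere in spaCy, names passed via `exclude` are skipped on both serialization and deserialization. For instance, the component can be serialized without its shared vocab (a usage sketch, assuming the strings are restored separately):

```python
# "vocab" is one of the named serialization fields registered above.
pipe_bytes = entity_linker.to_bytes(exclude=["vocab"])
entity_linker.from_bytes(pipe_bytes, exclude=["vocab"])
```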