mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-25 05:01:02 +03:00 
			
		
		
		
	Fix kb/kb_in_memory.pyx.
This commit is contained in:
		
							parent
							
								
									6c483971b7
								
							
						
					
					
						commit
						a0bf50661b
					
				|  | @ -1,5 +1,5 @@ | |||
| # cython: infer_types=True, profile=True | ||||
| from typing import Any, Callable, Dict, Iterable, Union | ||||
| from typing import Any, Callable, Dict, Iterable | ||||
| 
 | ||||
| import srsly | ||||
| 
 | ||||
|  | @ -27,8 +27,9 @@ from .candidate import Candidate as Candidate | |||
| 
 | ||||
| 
 | ||||
| cdef class InMemoryLookupKB(KnowledgeBase): | ||||
|     """An `InMemoryLookupKB` instance stores unique identifiers for entities and their textual aliases, | ||||
|     to support entity linking of named entities to real-world concepts. | ||||
|     """An `InMemoryLookupKB` instance stores unique identifiers for entities | ||||
|     and their textual aliases, to support entity linking of named entities to | ||||
|     real-world concepts. | ||||
| 
 | ||||
|     DOCS: https://spacy.io/api/inmemorylookupkb | ||||
|     """ | ||||
|  | @ -71,7 +72,8 @@ cdef class InMemoryLookupKB(KnowledgeBase): | |||
| 
 | ||||
|     def add_entity(self, str entity, float freq, vector[float] entity_vector): | ||||
|         """ | ||||
|         Add an entity to the KB, optionally specifying its log probability based on corpus frequency | ||||
|         Add an entity to the KB, optionally specifying its log probability | ||||
|         based on corpus frequency. | ||||
|         Return the hash of the entity ID/name at the end. | ||||
|         """ | ||||
|         cdef hash_t entity_hash = self.vocab.strings.add(entity) | ||||
|  | @ -83,14 +85,20 @@ cdef class InMemoryLookupKB(KnowledgeBase): | |||
| 
 | ||||
|         # Raise an error if the provided entity vector is not of the correct length | ||||
|         if len(entity_vector) != self.entity_vector_length: | ||||
|             raise ValueError(Errors.E141.format(found=len(entity_vector), required=self.entity_vector_length)) | ||||
|             raise ValueError( | ||||
|                 Errors.E141.format( | ||||
|                     found=len(entity_vector), required=self.entity_vector_length | ||||
|                 ) | ||||
|             ) | ||||
| 
 | ||||
|         vector_index = self.c_add_vector(entity_vector=entity_vector) | ||||
| 
 | ||||
|         new_index = self.c_add_entity(entity_hash=entity_hash, | ||||
|                                       freq=freq, | ||||
|                                       vector_index=vector_index, | ||||
|                                       feats_row=-1)  # Features table currently not implemented | ||||
|         new_index = self.c_add_entity( | ||||
|             entity_hash=entity_hash, | ||||
|             freq=freq, | ||||
|             vector_index=vector_index, | ||||
|             feats_row=-1 | ||||
|         )  # Features table currently not implemented | ||||
|         self._entry_index[entity_hash] = new_index | ||||
| 
 | ||||
|         return entity_hash | ||||
|  | @ -115,7 +123,12 @@ cdef class InMemoryLookupKB(KnowledgeBase): | |||
|             else: | ||||
|                 entity_vector = vector_list[i] | ||||
|                 if len(entity_vector) != self.entity_vector_length: | ||||
|                     raise ValueError(Errors.E141.format(found=len(entity_vector), required=self.entity_vector_length)) | ||||
|                     raise ValueError( | ||||
|                         Errors.E141.format( | ||||
|                             found=len(entity_vector), | ||||
|                             required=self.entity_vector_length | ||||
|                         ) | ||||
|                     ) | ||||
| 
 | ||||
|                 entry.entity_hash = entity_hash | ||||
|                 entry.freq = freq_list[i] | ||||
|  | @ -149,11 +162,15 @@ cdef class InMemoryLookupKB(KnowledgeBase): | |||
|         previous_alias_nr = self.get_size_aliases() | ||||
|         # Throw an error if the length of entities and probabilities are not the same | ||||
|         if not len(entities) == len(probabilities): | ||||
|             raise ValueError(Errors.E132.format(alias=alias, | ||||
|                                                 entities_length=len(entities), | ||||
|                                                 probabilities_length=len(probabilities))) | ||||
|             raise ValueError( | ||||
|                 Errors.E132.format( | ||||
|                     alias=alias, | ||||
|                     entities_length=len(entities), | ||||
|                     probabilities_length=len(probabilities)) | ||||
|             ) | ||||
| 
 | ||||
|         # Throw an error if the probabilities sum up to more than 1 (allow for some rounding errors) | ||||
|         # Throw an error if the probabilities sum up to more than 1 (allow for | ||||
|         # some rounding errors) | ||||
|         prob_sum = sum(probabilities) | ||||
|         if prob_sum > 1.00001: | ||||
|             raise ValueError(Errors.E133.format(alias=alias, sum=prob_sum)) | ||||
|  | @ -170,40 +187,47 @@ cdef class InMemoryLookupKB(KnowledgeBase): | |||
| 
 | ||||
|         for entity, prob in zip(entities, probabilities): | ||||
|             entity_hash = self.vocab.strings[entity] | ||||
|             if not entity_hash in self._entry_index: | ||||
|             if entity_hash not in self._entry_index: | ||||
|                 raise ValueError(Errors.E134.format(entity=entity)) | ||||
| 
 | ||||
|             entry_index = <int64_t>self._entry_index.get(entity_hash) | ||||
|             entry_indices.push_back(int(entry_index)) | ||||
|             probs.push_back(float(prob)) | ||||
| 
 | ||||
|         new_index = self.c_add_aliases(alias_hash=alias_hash, entry_indices=entry_indices, probs=probs) | ||||
|         new_index = self.c_add_aliases( | ||||
|             alias_hash=alias_hash, entry_indices=entry_indices, probs=probs | ||||
|         ) | ||||
|         self._alias_index[alias_hash] = new_index | ||||
| 
 | ||||
|         if previous_alias_nr + 1 != self.get_size_aliases(): | ||||
|             raise RuntimeError(Errors.E891.format(alias=alias)) | ||||
|         return alias_hash | ||||
| 
 | ||||
|     def append_alias(self, str alias, str entity, float prior_prob, ignore_warnings=False): | ||||
|     def append_alias( | ||||
|         self, str alias, str entity, float prior_prob, ignore_warnings=False | ||||
|     ): | ||||
|         """ | ||||
|         For an alias already existing in the KB, extend its potential entities with one more. | ||||
|         For an alias already existing in the KB, extend its potential entities | ||||
|         with one more. | ||||
|         Throw a warning if either the alias or the entity is unknown, | ||||
|         or when the combination is already previously recorded. | ||||
|         Throw an error if this entity+prior prob would exceed the sum of 1. | ||||
|         For efficiency, it's best to use the method `add_alias` as much as possible instead of this one. | ||||
|         For efficiency, it's best to use the method `add_alias` as much as | ||||
|         possible instead of this one. | ||||
|         """ | ||||
|         # Check if the alias exists in the KB | ||||
|         cdef hash_t alias_hash = self.vocab.strings[alias] | ||||
|         if not alias_hash in self._alias_index: | ||||
|         if alias_hash not in self._alias_index: | ||||
|             raise ValueError(Errors.E176.format(alias=alias)) | ||||
| 
 | ||||
|         # Check if the entity exists in the KB | ||||
|         cdef hash_t entity_hash = self.vocab.strings[entity] | ||||
|         if not entity_hash in self._entry_index: | ||||
|         if entity_hash not in self._entry_index: | ||||
|             raise ValueError(Errors.E134.format(entity=entity)) | ||||
|         entry_index = <int64_t>self._entry_index.get(entity_hash) | ||||
| 
 | ||||
|         # Throw an error if the prior probabilities (including the new one) sum up to more than 1 | ||||
|         # Throw an error if the prior probabilities (including the new one) | ||||
|         # sum up to more than 1 | ||||
|         alias_index = <int64_t>self._alias_index.get(alias_hash) | ||||
|         alias_entry = self._aliases_table[alias_index] | ||||
|         current_sum = sum([p for p in alias_entry.probs]) | ||||
|  | @ -236,12 +260,13 @@ cdef class InMemoryLookupKB(KnowledgeBase): | |||
| 
 | ||||
|     def get_alias_candidates(self, str alias) -> Iterable[Candidate]: | ||||
|         """ | ||||
|         Return candidate entities for an alias. Each candidate defines the entity, the original alias, | ||||
|         and the prior probability of that alias resolving to that entity. | ||||
|         Return candidate entities for an alias. Each candidate defines the | ||||
|         entity, the original alias, and the prior probability of that alias | ||||
|         resolving to that entity. | ||||
|         If the alias is not known in the KB, and empty list is returned. | ||||
|         """ | ||||
|         cdef hash_t alias_hash = self.vocab.strings[alias] | ||||
|         if not alias_hash in self._alias_index: | ||||
|         if alias_hash not in self._alias_index: | ||||
|             return [] | ||||
|         alias_index = <int64_t>self._alias_index.get(alias_hash) | ||||
|         alias_entry = self._aliases_table[alias_index] | ||||
|  | @ -249,10 +274,14 @@ cdef class InMemoryLookupKB(KnowledgeBase): | |||
|         return [Candidate(kb=self, | ||||
|                           entity_hash=self._entries[entry_index].entity_hash, | ||||
|                           entity_freq=self._entries[entry_index].freq, | ||||
|                           entity_vector=self._vectors_table[self._entries[entry_index].vector_index], | ||||
|                           entity_vector=self._vectors_table[ | ||||
|                               self._entries[entry_index].vector_index | ||||
|                           ], | ||||
|                           alias_hash=alias_hash, | ||||
|                           prior_prob=prior_prob) | ||||
|                 for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs) | ||||
|                 for (entry_index, prior_prob) in zip( | ||||
|                     alias_entry.entry_indices, alias_entry.probs | ||||
|                 ) | ||||
|                 if entry_index != 0] | ||||
| 
 | ||||
|     def get_vector(self, str entity): | ||||
|  | @ -266,8 +295,9 @@ cdef class InMemoryLookupKB(KnowledgeBase): | |||
|         return self._vectors_table[self._entries[entry_index].vector_index] | ||||
| 
 | ||||
|     def get_prior_prob(self, str entity, str alias): | ||||
|         """ Return the prior probability of a given alias being linked to a given entity, | ||||
|         or return 0.0 when this combination is not known in the knowledge base""" | ||||
|         """ Return the prior probability of a given alias being linked to a | ||||
|         given entity, or return 0.0 when this combination is not known in the | ||||
|         knowledge base.""" | ||||
|         cdef hash_t alias_hash = self.vocab.strings[alias] | ||||
|         cdef hash_t entity_hash = self.vocab.strings[entity] | ||||
| 
 | ||||
|  | @ -278,7 +308,9 @@ cdef class InMemoryLookupKB(KnowledgeBase): | |||
|         entry_index = self._entry_index[entity_hash] | ||||
| 
 | ||||
|         alias_entry = self._aliases_table[alias_index] | ||||
|         for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs): | ||||
|         for (entry_index, prior_prob) in zip( | ||||
|             alias_entry.entry_indices, alias_entry.probs | ||||
|         ): | ||||
|             if self._entries[entry_index].entity_hash == entity_hash: | ||||
|                 return prior_prob | ||||
| 
 | ||||
|  | @ -288,13 +320,19 @@ cdef class InMemoryLookupKB(KnowledgeBase): | |||
|         """Serialize the current state to a binary string. | ||||
|         """ | ||||
|         def serialize_header(): | ||||
|             header = (self.get_size_entities(), self.get_size_aliases(), self.entity_vector_length) | ||||
|             header = ( | ||||
|                 self.get_size_entities(), | ||||
|                 self.get_size_aliases(), | ||||
|                 self.entity_vector_length | ||||
|             ) | ||||
|             return srsly.json_dumps(header) | ||||
| 
 | ||||
|         def serialize_entries(): | ||||
|             i = 1 | ||||
|             tuples = [] | ||||
|             for entry_hash, entry_index in sorted(self._entry_index.items(), key=lambda x: x[1]): | ||||
|             for entry_hash, entry_index in sorted( | ||||
|                 self._entry_index.items(), key=lambda x: x[1] | ||||
|             ): | ||||
|                 entry = self._entries[entry_index] | ||||
|                 assert entry.entity_hash == entry_hash | ||||
|                 assert entry_index == i | ||||
|  | @ -307,7 +345,9 @@ cdef class InMemoryLookupKB(KnowledgeBase): | |||
|             headers = [] | ||||
|             indices_lists = [] | ||||
|             probs_lists = [] | ||||
|             for alias_hash, alias_index in sorted(self._alias_index.items(), key=lambda x: x[1]): | ||||
|             for alias_hash, alias_index in sorted( | ||||
|                 self._alias_index.items(), key=lambda x: x[1] | ||||
|             ): | ||||
|                 alias = self._aliases_table[alias_index] | ||||
|                 assert alias_index == i | ||||
|                 candidate_length = len(alias.entry_indices) | ||||
|  | @ -365,7 +405,7 @@ cdef class InMemoryLookupKB(KnowledgeBase): | |||
|             indices = srsly.json_loads(all_data[1]) | ||||
|             probs = srsly.json_loads(all_data[2]) | ||||
|             for header, indices, probs in zip(headers, indices, probs): | ||||
|                 alias_hash, candidate_length = header | ||||
|                 alias_hash, _candidate_length = header | ||||
|                 alias.entry_indices = indices | ||||
|                 alias.probs = probs | ||||
|                 self._aliases_table[i] = alias | ||||
|  | @ -414,10 +454,14 @@ cdef class InMemoryLookupKB(KnowledgeBase): | |||
|                 writer.write_vector_element(element) | ||||
|             i = i+1 | ||||
| 
 | ||||
|         # dumping the entry records in the order in which they are in the _entries vector. | ||||
|         # index 0 is a dummy object not stored in the _entry_index and can be ignored. | ||||
|         # dumping the entry records in the order in which they are in the | ||||
|         # _entries vector. | ||||
|         # index 0 is a dummy object not stored in the _entry_index and can | ||||
|         # be ignored. | ||||
|         i = 1 | ||||
|         for entry_hash, entry_index in sorted(self._entry_index.items(), key=lambda x: x[1]): | ||||
|         for entry_hash, entry_index in sorted( | ||||
|             self._entry_index.items(), key=lambda x: x[1] | ||||
|         ): | ||||
|             entry = self._entries[entry_index] | ||||
|             assert entry.entity_hash == entry_hash | ||||
|             assert entry_index == i | ||||
|  | @ -429,7 +473,9 @@ cdef class InMemoryLookupKB(KnowledgeBase): | |||
|         # dumping the aliases in the order in which they are in the _alias_index vector. | ||||
|         # index 0 is a dummy object not stored in the _aliases_table and can be ignored. | ||||
|         i = 1 | ||||
|         for alias_hash, alias_index in sorted(self._alias_index.items(), key=lambda x: x[1]): | ||||
|         for alias_hash, alias_index in sorted( | ||||
|                 self._alias_index.items(), key=lambda x: x[1] | ||||
|         ): | ||||
|             alias = self._aliases_table[alias_index] | ||||
|             assert alias_index == i | ||||
| 
 | ||||
|  | @ -535,7 +581,8 @@ cdef class Writer: | |||
|     def __init__(self, path): | ||||
|         assert isinstance(path, Path) | ||||
|         content = bytes(path) | ||||
|         cdef bytes bytes_loc = content.encode('utf8') if type(content) == str else content | ||||
|         cdef bytes bytes_loc = content.encode('utf8') \ | ||||
|             if type(content) == str else content | ||||
|         self._fp = fopen(<char*>bytes_loc, 'wb') | ||||
|         if not self._fp: | ||||
|             raise IOError(Errors.E146.format(path=path)) | ||||
|  | @ -545,14 +592,18 @@ cdef class Writer: | |||
|         cdef size_t status = fclose(self._fp) | ||||
|         assert status == 0 | ||||
| 
 | ||||
|     cdef int write_header(self, int64_t nr_entries, int64_t entity_vector_length) except -1: | ||||
|     cdef int write_header( | ||||
|         self, int64_t nr_entries, int64_t entity_vector_length | ||||
|     ) except -1: | ||||
|         self._write(&nr_entries, sizeof(nr_entries)) | ||||
|         self._write(&entity_vector_length, sizeof(entity_vector_length)) | ||||
| 
 | ||||
|     cdef int write_vector_element(self, float element) except -1: | ||||
|         self._write(&element, sizeof(element)) | ||||
| 
 | ||||
|     cdef int write_entry(self, hash_t entry_hash, float entry_freq, int32_t vector_index) except -1: | ||||
|     cdef int write_entry( | ||||
|         self, hash_t entry_hash, float entry_freq, int32_t vector_index | ||||
|     ) except -1: | ||||
|         self._write(&entry_hash, sizeof(entry_hash)) | ||||
|         self._write(&entry_freq, sizeof(entry_freq)) | ||||
|         self._write(&vector_index, sizeof(vector_index)) | ||||
|  | @ -561,7 +612,9 @@ cdef class Writer: | |||
|     cdef int write_alias_length(self, int64_t alias_length) except -1: | ||||
|         self._write(&alias_length, sizeof(alias_length)) | ||||
| 
 | ||||
|     cdef int write_alias_header(self, hash_t alias_hash, int64_t candidate_length) except -1: | ||||
|     cdef int write_alias_header( | ||||
|         self, hash_t alias_hash, int64_t candidate_length | ||||
|     ) except -1: | ||||
|         self._write(&alias_hash, sizeof(alias_hash)) | ||||
|         self._write(&candidate_length, sizeof(candidate_length)) | ||||
| 
 | ||||
|  | @ -577,16 +630,19 @@ cdef class Writer: | |||
| cdef class Reader: | ||||
|     def __init__(self, path): | ||||
|         content = bytes(path) | ||||
|         cdef bytes bytes_loc = content.encode('utf8') if type(content) == str else content | ||||
|         cdef bytes bytes_loc = content.encode('utf8') \ | ||||
|             if type(content) == str else content | ||||
|         self._fp = fopen(<char*>bytes_loc, 'rb') | ||||
|         if not self._fp: | ||||
|             PyErr_SetFromErrno(IOError) | ||||
|         status = fseek(self._fp, 0, 0)  # this can be 0 if there is no header | ||||
|         fseek(self._fp, 0, 0)  # this can be 0 if there is no header | ||||
| 
 | ||||
|     def __dealloc__(self): | ||||
|         fclose(self._fp) | ||||
| 
 | ||||
|     cdef int read_header(self, int64_t* nr_entries, int64_t* entity_vector_length) except -1: | ||||
|     cdef int read_header( | ||||
|         self, int64_t* nr_entries, int64_t* entity_vector_length | ||||
|     ) except -1: | ||||
|         status = self._read(nr_entries, sizeof(int64_t)) | ||||
|         if status < 1: | ||||
|             if feof(self._fp): | ||||
|  | @ -606,7 +662,9 @@ cdef class Reader: | |||
|                 return 0  # end of file | ||||
|             raise IOError(Errors.E145.format(param="vector element")) | ||||
| 
 | ||||
|     cdef int read_entry(self, hash_t* entity_hash, float* freq, int32_t* vector_index) except -1: | ||||
|     cdef int read_entry( | ||||
|         self, hash_t* entity_hash, float* freq, int32_t* vector_index | ||||
|     ) except -1: | ||||
|         status = self._read(entity_hash, sizeof(hash_t)) | ||||
|         if status < 1: | ||||
|             if feof(self._fp): | ||||
|  | @ -637,7 +695,9 @@ cdef class Reader: | |||
|                 return 0  # end of file | ||||
|             raise IOError(Errors.E145.format(param="alias length")) | ||||
| 
 | ||||
|     cdef int read_alias_header(self, hash_t* alias_hash, int64_t* candidate_length) except -1: | ||||
|     cdef int read_alias_header( | ||||
|         self, hash_t* alias_hash, int64_t* candidate_length | ||||
|     ) except -1: | ||||
|         status = self._read(alias_hash, sizeof(hash_t)) | ||||
|         if status < 1: | ||||
|             if feof(self._fp): | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user