mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 16:07:41 +03:00 
			
		
		
		
	Change the type of the mentions for which entity candidates are looked up from Iterable[Span] to SpanGroup.
This commit is contained in:
		
							parent
							
								
									a97ef65b33
								
							
						
					
					
						commit
						8596fb8b88
					
				|  | @ -5,7 +5,7 @@ from typing import Iterable, Tuple, Union | ||||||
| from cymem.cymem cimport Pool | from cymem.cymem cimport Pool | ||||||
| 
 | 
 | ||||||
| from .candidate import Candidate | from .candidate import Candidate | ||||||
| from ..tokens import Span | from ..tokens import Span, SpanGroup | ||||||
| from ..util import SimpleFrozenList | from ..util import SimpleFrozenList | ||||||
| from ..errors import Errors | from ..errors import Errors | ||||||
| 
 | 
 | ||||||
|  | @ -30,12 +30,12 @@ cdef class KnowledgeBase: | ||||||
|         self.entity_vector_length = entity_vector_length |         self.entity_vector_length = entity_vector_length | ||||||
|         self.mem = Pool() |         self.mem = Pool() | ||||||
| 
 | 
 | ||||||
|     def get_candidates_batch(self, mentions: Iterable[Span]) -> Iterable[Iterable[Candidate]]: |     def get_candidates_batch(self, mentions: SpanGroup) -> Iterable[Iterable[Candidate]]: | ||||||
|         """ |         """ | ||||||
|         Return candidate entities for specified texts. Each candidate defines the entity, the original alias, |         Return candidate entities for specified texts. Each candidate defines the entity, the original alias, | ||||||
|         and the prior probability of that alias resolving to that entity. |         and the prior probability of that alias resolving to that entity. | ||||||
|         If no candidate is found for a given text, an empty list is returned. |         If no candidate is found for a given text, an empty list is returned. | ||||||
|         mentions (Iterable[Span]): Mentions for which to get candidates. |         mentions (SpanGroup): Mentions for which to get candidates. | ||||||
|         RETURNS (Iterable[Iterable[Candidate]]): Identified candidates. |         RETURNS (Iterable[Iterable[Candidate]]): Identified candidates. | ||||||
|         """ |         """ | ||||||
|         return [self.get_candidates(span) for span in mentions] |         return [self.get_candidates(span) for span in mentions] | ||||||
|  |  | ||||||
|  | @ -8,7 +8,7 @@ from ...util import registry | ||||||
| from ...kb import KnowledgeBase, InMemoryLookupKB | from ...kb import KnowledgeBase, InMemoryLookupKB | ||||||
| from ...kb import Candidate | from ...kb import Candidate | ||||||
| from ...vocab import Vocab | from ...vocab import Vocab | ||||||
| from ...tokens import Span, Doc | from ...tokens import Doc, Span, SpanGroup | ||||||
| from ..extract_spans import extract_spans | from ..extract_spans import extract_spans | ||||||
| from ...errors import Errors | from ...errors import Errors | ||||||
| 
 | 
 | ||||||
|  | @ -106,7 +106,7 @@ def create_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]: | ||||||
| 
 | 
 | ||||||
| @registry.misc("spacy.CandidateBatchGenerator.v1") | @registry.misc("spacy.CandidateBatchGenerator.v1") | ||||||
| def create_candidates_batch() -> Callable[ | def create_candidates_batch() -> Callable[ | ||||||
|     [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] |     [KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]] | ||||||
| ]: | ]: | ||||||
|     return get_candidates_batch |     return get_candidates_batch | ||||||
| 
 | 
 | ||||||
|  | @ -122,12 +122,12 @@ def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]: | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_candidates_batch( | def get_candidates_batch( | ||||||
|     kb: KnowledgeBase, mentions: Iterable[Span] |     kb: KnowledgeBase, mentions: SpanGroup | ||||||
| ) -> Iterable[Iterable[Candidate]]: | ) -> Iterable[Iterable[Candidate]]: | ||||||
|     """ |     """ | ||||||
|     Return candidate entities for the given mentions and fetching appropriate entries from the index. |     Return candidate entities for the given mentions and fetching appropriate entries from the index. | ||||||
|     kb (KnowledgeBase): Knowledge base to query. |     kb (KnowledgeBase): Knowledge base to query. | ||||||
|     mention (Iterable[Span]): Entity mentions for which to identify candidates. |     mention (SpanGroup): Entity mentions for which to identify candidates. | ||||||
|     RETURNS (Iterable[Iterable[Candidate]]): Identified candidates. |     RETURNS (Iterable[Iterable[Candidate]]): Identified candidates. | ||||||
|     """ |     """ | ||||||
|     return kb.get_candidates_batch(mentions) |     return kb.get_candidates_batch(mentions) | ||||||
|  |  | ||||||
|  | @ -9,9 +9,10 @@ import random | ||||||
| from thinc.api import CosineDistance, Model, Optimizer, Config | from thinc.api import CosineDistance, Model, Optimizer, Config | ||||||
| from thinc.api import set_dropout_rate | from thinc.api import set_dropout_rate | ||||||
| 
 | 
 | ||||||
|  | from ..tokens import SpanGroup | ||||||
| from ..kb import KnowledgeBase, Candidate | from ..kb import KnowledgeBase, Candidate | ||||||
| from ..ml import empty_kb | from ..ml import empty_kb | ||||||
| from ..tokens import Doc, Span | from ..tokens import Doc, Span, SpanGroup | ||||||
| from .pipe import deserialize_config | from .pipe import deserialize_config | ||||||
| from .trainable_pipe import TrainablePipe | from .trainable_pipe import TrainablePipe | ||||||
| from ..language import Language | from ..language import Language | ||||||
|  | @ -82,7 +83,7 @@ def make_entity_linker( | ||||||
|     entity_vector_length: int, |     entity_vector_length: int, | ||||||
|     get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], |     get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], | ||||||
|     get_candidates_batch: Callable[ |     get_candidates_batch: Callable[ | ||||||
|         [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] |         [KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]] | ||||||
|     ], |     ], | ||||||
|     overwrite: bool, |     overwrite: bool, | ||||||
|     scorer: Optional[Callable], |     scorer: Optional[Callable], | ||||||
|  | @ -104,7 +105,7 @@ def make_entity_linker( | ||||||
|     get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that |     get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that | ||||||
|         produces a list of candidates, given a certain knowledge base and a textual mention. |         produces a list of candidates, given a certain knowledge base and a textual mention. | ||||||
|     get_candidates_batch ( |     get_candidates_batch ( | ||||||
|         Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]], Iterable[Candidate]] |         Callable[[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]], Iterable[Candidate]] | ||||||
|         ): Function that produces a list of candidates, given a certain knowledge base and several textual mentions. |         ): Function that produces a list of candidates, given a certain knowledge base and several textual mentions. | ||||||
|     scorer (Optional[Callable]): The scoring method. |     scorer (Optional[Callable]): The scoring method. | ||||||
|     use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another |     use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another | ||||||
|  | @ -186,7 +187,7 @@ class EntityLinker(TrainablePipe): | ||||||
|         entity_vector_length: int, |         entity_vector_length: int, | ||||||
|         get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], |         get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], | ||||||
|         get_candidates_batch: Callable[ |         get_candidates_batch: Callable[ | ||||||
|             [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] |             [KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]] | ||||||
|         ], |         ], | ||||||
|         overwrite: bool = False, |         overwrite: bool = False, | ||||||
|         scorer: Optional[Callable] = entity_linker_score, |         scorer: Optional[Callable] = entity_linker_score, | ||||||
|  | @ -209,9 +210,9 @@ class EntityLinker(TrainablePipe): | ||||||
|         get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that |         get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that | ||||||
|             produces a list of candidates, given a certain knowledge base and a textual mention. |             produces a list of candidates, given a certain knowledge base and a textual mention. | ||||||
|         get_candidates_batch ( |         get_candidates_batch ( | ||||||
|             Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]], |             Callable[[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]], | ||||||
|             Iterable[Candidate]] |             Iterable[Candidate]] | ||||||
|             ): Function that produces a list of candidates, given a certain knowledge base and several textual mentions. |         ): Function that produces a list of candidates, given a certain knowledge base and several textual mentions. | ||||||
|         overwrite (bool): Whether to overwrite existing non-empty annotations. |         overwrite (bool): Whether to overwrite existing non-empty annotations. | ||||||
|         scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links. |         scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links. | ||||||
|         use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another |         use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another | ||||||
|  | @ -485,7 +486,8 @@ class EntityLinker(TrainablePipe): | ||||||
| 
 | 
 | ||||||
|                 batch_candidates = list( |                 batch_candidates = list( | ||||||
|                     self.get_candidates_batch( |                     self.get_candidates_batch( | ||||||
|                         self.kb, [ent_batch[idx] for idx in valid_ent_idx] |                         self.kb, | ||||||
|  |                         SpanGroup(doc, spans=[ent_batch[idx] for idx in valid_ent_idx]), | ||||||
|                     ) |                     ) | ||||||
|                     if self.candidates_batch_size > 1 |                     if self.candidates_batch_size > 1 | ||||||
|                     else [ |                     else [ | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user