mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 16:07:41 +03:00 
			
		
		
		
	Change type for mentions to look up entity candidates for to SpanGroup from Iterable[Span].
This commit is contained in:
		
							parent
							
								
									a97ef65b33
								
							
						
					
					
						commit
						8596fb8b88
					
				|  | @ -5,7 +5,7 @@ from typing import Iterable, Tuple, Union | |||
| from cymem.cymem cimport Pool | ||||
| 
 | ||||
| from .candidate import Candidate | ||||
| from ..tokens import Span | ||||
| from ..tokens import Span, SpanGroup | ||||
| from ..util import SimpleFrozenList | ||||
| from ..errors import Errors | ||||
| 
 | ||||
|  | @ -30,12 +30,12 @@ cdef class KnowledgeBase: | |||
|         self.entity_vector_length = entity_vector_length | ||||
|         self.mem = Pool() | ||||
| 
 | ||||
|     def get_candidates_batch(self, mentions: Iterable[Span]) -> Iterable[Iterable[Candidate]]: | ||||
|     def get_candidates_batch(self, mentions: SpanGroup) -> Iterable[Iterable[Candidate]]: | ||||
|         """ | ||||
|         Return candidate entities for specified texts. Each candidate defines the entity, the original alias, | ||||
|         and the prior probability of that alias resolving to that entity. | ||||
|         If no candidate is found for a given text, an empty list is returned. | ||||
|         mentions (Iterable[Span]): Mentions for which to get candidates. | ||||
|         mentions (SpanGroup): Mentions for which to get candidates. | ||||
|         RETURNS (Iterable[Iterable[Candidate]]): Identified candidates. | ||||
|         """ | ||||
|         return [self.get_candidates(span) for span in mentions] | ||||
|  |  | |||
|  | @ -8,7 +8,7 @@ from ...util import registry | |||
| from ...kb import KnowledgeBase, InMemoryLookupKB | ||||
| from ...kb import Candidate | ||||
| from ...vocab import Vocab | ||||
| from ...tokens import Span, Doc | ||||
| from ...tokens import Doc, Span, SpanGroup | ||||
| from ..extract_spans import extract_spans | ||||
| from ...errors import Errors | ||||
| 
 | ||||
|  | @ -106,7 +106,7 @@ def create_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]: | |||
| 
 | ||||
| @registry.misc("spacy.CandidateBatchGenerator.v1") | ||||
| def create_candidates_batch() -> Callable[ | ||||
|     [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] | ||||
|     [KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]] | ||||
| ]: | ||||
|     return get_candidates_batch | ||||
| 
 | ||||
|  | @ -122,12 +122,12 @@ def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]: | |||
| 
 | ||||
| 
 | ||||
| def get_candidates_batch( | ||||
|     kb: KnowledgeBase, mentions: Iterable[Span] | ||||
|     kb: KnowledgeBase, mentions: SpanGroup | ||||
| ) -> Iterable[Iterable[Candidate]]: | ||||
|     """ | ||||
|     Return candidate entities for the given mentions and fetching appropriate entries from the index. | ||||
|     kb (KnowledgeBase): Knowledge base to query. | ||||
|     mention (Iterable[Span]): Entity mentions for which to identify candidates. | ||||
|     mention (SpanGroup): Entity mentions for which to identify candidates. | ||||
|     RETURNS (Iterable[Iterable[Candidate]]): Identified candidates. | ||||
|     """ | ||||
|     return kb.get_candidates_batch(mentions) | ||||
|  |  | |||
|  | @ -9,9 +9,10 @@ import random | |||
| from thinc.api import CosineDistance, Model, Optimizer, Config | ||||
| from thinc.api import set_dropout_rate | ||||
| 
 | ||||
| from ..tokens import SpanGroup | ||||
| from ..kb import KnowledgeBase, Candidate | ||||
| from ..ml import empty_kb | ||||
| from ..tokens import Doc, Span | ||||
| from ..tokens import Doc, Span, SpanGroup | ||||
| from .pipe import deserialize_config | ||||
| from .trainable_pipe import TrainablePipe | ||||
| from ..language import Language | ||||
|  | @ -82,7 +83,7 @@ def make_entity_linker( | |||
|     entity_vector_length: int, | ||||
|     get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], | ||||
|     get_candidates_batch: Callable[ | ||||
|         [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] | ||||
|         [KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]] | ||||
|     ], | ||||
|     overwrite: bool, | ||||
|     scorer: Optional[Callable], | ||||
|  | @ -104,7 +105,7 @@ def make_entity_linker( | |||
|     get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that | ||||
|         produces a list of candidates, given a certain knowledge base and a textual mention. | ||||
|     get_candidates_batch ( | ||||
|         Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]], Iterable[Candidate]] | ||||
|         Callable[[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]], Iterable[Candidate]] | ||||
|         ): Function that produces a list of candidates, given a certain knowledge base and several textual mentions. | ||||
|     scorer (Optional[Callable]): The scoring method. | ||||
|     use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another | ||||
|  | @ -186,7 +187,7 @@ class EntityLinker(TrainablePipe): | |||
|         entity_vector_length: int, | ||||
|         get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], | ||||
|         get_candidates_batch: Callable[ | ||||
|             [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] | ||||
|             [KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]] | ||||
|         ], | ||||
|         overwrite: bool = False, | ||||
|         scorer: Optional[Callable] = entity_linker_score, | ||||
|  | @ -209,7 +210,7 @@ class EntityLinker(TrainablePipe): | |||
|         get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that | ||||
|             produces a list of candidates, given a certain knowledge base and a textual mention. | ||||
|         get_candidates_batch ( | ||||
|             Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]], | ||||
|             Callable[[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]], | ||||
|             Iterable[Candidate]] | ||||
|         ): Function that produces a list of candidates, given a certain knowledge base and several textual mentions. | ||||
|         overwrite (bool): Whether to overwrite existing non-empty annotations. | ||||
|  | @ -485,7 +486,8 @@ class EntityLinker(TrainablePipe): | |||
| 
 | ||||
|                 batch_candidates = list( | ||||
|                     self.get_candidates_batch( | ||||
|                         self.kb, [ent_batch[idx] for idx in valid_ent_idx] | ||||
|                         self.kb, | ||||
|                         SpanGroup(doc, spans=[ent_batch[idx] for idx in valid_ent_idx]), | ||||
|                     ) | ||||
|                     if self.candidates_batch_size > 1 | ||||
|                     else [ | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user