mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 16:07:41 +03:00 
			
		
		
		
	Inherit from SpanCat instead of TrainablePipe
This commit changes the inheritance structure of Exclusive_SpanCategorizer: it now inherits from SpanCategorizer rather than TrainablePipe. This allows me to remove duplicate methods that are already present in the parent class.
This commit is contained in:
		
							parent
							
								
									bdf2a1d1fe
								
							
						
					
					
						commit
						8548e2c311
					
				|  | @ -13,7 +13,7 @@ from ..tokens import Doc, Span, SpanGroup | |||
| from ..training import Example, validate_examples | ||||
| from ..vocab import Vocab | ||||
| from .spancat import spancat_score, build_ngram_suggester | ||||
| from .trainable_pipe import TrainablePipe | ||||
| from .spancat import SpanCategorizer | ||||
| 
 | ||||
| 
 | ||||
| spancat_exclusive_default_config = """ | ||||
|  | @ -71,7 +71,7 @@ def make_spancat( | |||
|     scorer: Optional[Callable], | ||||
|     negative_weight: float = 1.0, | ||||
|     allow_overlap: bool = True, | ||||
| ) -> "SpanCategorizerExclusive": | ||||
| ) -> "Exclusive_SpanCategorizer": | ||||
|     """Create a SpanCategorizerExclusive component. The span categorizer consists of two | ||||
|     parts: a suggester function that proposes candidate spans, and a labeller | ||||
|     model that predicts a single label for each span. | ||||
|  | @ -94,7 +94,7 @@ def make_spancat( | |||
|     allow_overlap (bool): If True the data is assumed to | ||||
|         contain overlapping spans. | ||||
|     """ | ||||
|     return SpanCategorizerExclusive( | ||||
|     return Exclusive_SpanCategorizer( | ||||
|         nlp.vocab, | ||||
|         suggester=suggester, | ||||
|         model=model, | ||||
|  | @ -127,8 +127,8 @@ class Ranges: | |||
|         return False | ||||
| 
 | ||||
| 
 | ||||
| class SpanCategorizerExclusive(TrainablePipe): | ||||
|     """Pipeline component to label spans of text. | ||||
| class Exclusive_SpanCategorizer(SpanCategorizer): | ||||
|     """Pipeline component to label non-overlapping spans of text. | ||||
| 
 | ||||
|     DOCS: https://spacy.io/api/spancategorizerexclusive | ||||
|     """ | ||||
|  | @ -176,47 +176,6 @@ class SpanCategorizerExclusive(TrainablePipe): | |||
|         self.name = name | ||||
|         self.scorer = scorer | ||||
| 
 | ||||
|     @property | ||||
|     def key(self) -> str: | ||||
|         """Key of the doc.spans dict to save the spans under. During | ||||
|         initialization and training, the component will look for spans on the | ||||
|         reference document under the same key. | ||||
|         """ | ||||
|         return str(self.cfg["spans_key"]) | ||||
| 
 | ||||
|     def add_label(self, label: str) -> int: | ||||
|         """Add a new label to the pipe. | ||||
| 
 | ||||
|         label (str): The label to add. | ||||
|         RETURNS (int): 0 if label is already present, otherwise 1. | ||||
| 
 | ||||
|         DOCS: https://spacy.io/api/spancategorizerexclusive#add_label | ||||
|         """ | ||||
|         if not isinstance(label, str): | ||||
|             raise ValueError(Errors.E187) | ||||
|         if label in self.labels: | ||||
|             return 0 | ||||
|         self._allow_extra_label() | ||||
|         self.cfg["labels"].append(label)  # type: ignore | ||||
|         self.vocab.strings.add(label) | ||||
|         return 1 | ||||
| 
 | ||||
|     @property | ||||
|     def labels(self) -> Tuple[str]: | ||||
|         """RETURNS (Tuple[str]): The labels currently added to the component. | ||||
| 
 | ||||
|         DOCS: https://spacy.io/api/spancategorizerexclusive#labels | ||||
|         """ | ||||
|         return tuple(self.cfg["labels"])  # type: ignore | ||||
| 
 | ||||
|     @property | ||||
|     def label_data(self) -> List[str]: | ||||
|         """RETURNS (List[str]): Information about the component's labels. | ||||
| 
 | ||||
|         DOCS: https://spacy.io/api/spancategorizerexclusive#label_data | ||||
|         """ | ||||
|         return list(self.labels) | ||||
| 
 | ||||
|     @property | ||||
|     def label_map(self) -> Dict[str, int]: | ||||
|         """RETURNS (Dict[str, int]): The label map.""" | ||||
|  | @ -232,37 +191,6 @@ class SpanCategorizerExclusive(TrainablePipe): | |||
|         """RETURNS (int): Number of labels including the negative label.""" | ||||
|         return len(self.label_data) + 1 | ||||
| 
 | ||||
|     def predict(self, docs: Iterable[Doc]): | ||||
|         """Apply the pipeline's model to a batch of docs, without modifying them. | ||||
| 
 | ||||
|         docs (Iterable[Doc]): The documents to predict. | ||||
|         RETURNS: The models prediction for each document. | ||||
| 
 | ||||
|         DOCS: https://spacy.io/api/spancategorizerexclusive#predict | ||||
|         """ | ||||
|         indices = self.suggester(docs, ops=self.model.ops) | ||||
|         scores = self.model.predict((docs, indices))  # type: ignore | ||||
|         return indices, scores | ||||
| 
 | ||||
|     def set_candidates( | ||||
|         self, docs: Iterable[Doc], *, candidates_key: str = "candidates" | ||||
|     ) -> None: | ||||
|         """Use the spancat suggester to add a list of span candidates to a | ||||
|         list of docs. Intended to be used for debugging purposes. | ||||
| 
 | ||||
|         docs (Iterable[Doc]): The documents to modify. | ||||
|         candidates_key (str): Key of the Doc.spans dict to save the | ||||
|             candidate spans under. | ||||
| 
 | ||||
|         DOCS: https://spacy.io/api/spancategorizerexclusive#set_candidates | ||||
|         """ | ||||
|         suggester_output = self.suggester(docs, ops=self.model.ops) | ||||
| 
 | ||||
|         for candidates, doc in zip(suggester_output, docs):  # type: ignore | ||||
|             doc.spans[candidates_key] = [] | ||||
|             for index in candidates.dataXd: | ||||
|                 doc.spans[candidates_key].append(doc[index[0] : index[1]]) | ||||
| 
 | ||||
|     def set_annotations(self, docs: Iterable[Doc], indices_scores) -> None: | ||||
|         """Modify a batch of Doc objects, using pre-computed scores. | ||||
| 
 | ||||
|  | @ -286,47 +214,6 @@ class SpanCategorizerExclusive(TrainablePipe): | |||
|             ) | ||||
|             offset += indices.lengths[i] | ||||
| 
 | ||||
|     def update( | ||||
|         self, | ||||
|         examples: Iterable[Example], | ||||
|         *, | ||||
|         drop: float = 0.0, | ||||
|         sgd: Optional[Optimizer] = None, | ||||
|         losses: Optional[Dict[str, float]] = None, | ||||
|     ) -> Dict[str, float]: | ||||
|         """Learn from a batch of documents and gold-standard information, | ||||
|         updating the pipe's model. Delegates to predict and get_loss. | ||||
|         examples (Iterable[Example]): A batch of Example objects. | ||||
| 
 | ||||
|         drop (float): The dropout rate. | ||||
|         sgd (thinc.api.Optimizer): The optimizer. | ||||
|         losses (Dict[str, float]): Optional record of the loss during training. | ||||
|             Updated using the component name as the key. | ||||
|         RETURNS (Dict[str, float]): The updated losses dictionary. | ||||
| 
 | ||||
|         DOCS: https://spacy.io/api/spancategorizerexclusive#update | ||||
|         """ | ||||
|         if losses is None: | ||||
|             losses = {} | ||||
|         losses.setdefault(self.name, 0.0) | ||||
|         validate_examples(examples, "SpanCategorizer.update") | ||||
|         self._validate_categories(examples) | ||||
|         if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples): | ||||
|             # Handle cases where there are no tokens in any docs. | ||||
|             return losses | ||||
|         docs = [eg.predicted for eg in examples] | ||||
|         spans = self.suggester(docs, ops=self.model.ops) | ||||
|         if spans.lengths.sum() == 0: | ||||
|             return losses | ||||
|         set_dropout_rate(self.model, drop) | ||||
|         scores, backprop_scores = self.model.begin_update((docs, spans)) | ||||
|         loss, d_scores = self.get_loss(examples, (spans, scores)) | ||||
|         backprop_scores(d_scores)  # type: ignore | ||||
|         if sgd is not None: | ||||
|             self.finish_update(sgd) | ||||
|         losses[self.name] += loss | ||||
|         return losses | ||||
| 
 | ||||
|     def get_loss( | ||||
|         self, examples: Iterable[Example], spans_scores: Tuple[Ragged, Floats2d] | ||||
|     ) -> Tuple[float, float]: | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user