Mirror of https://github.com/explosion/spaCy.git (synced 2025-10-30 23:47:31 +03:00)
			
		
		
		
	Make the Tagger neg_prefix configurable (#9802)
This commit is contained in:
		
parent b56b9e7f31
commit 9964243eb2
					
@@ -45,7 +45,7 @@ DEFAULT_TAGGER_MODEL = Config().from_str(default_model_config)["model"]
 @Language.factory(
     "tagger",
     assigns=["token.tag"],
-    default_config={"model": DEFAULT_TAGGER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.tagger_scorer.v1"}},
+    default_config={"model": DEFAULT_TAGGER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.tagger_scorer.v1"}, "neg_prefix": "!"},
     default_score_weights={"tag_acc": 1.0},
 )
 def make_tagger(
@@ -54,6 +54,7 @@ def make_tagger(
     model: Model,
     overwrite: bool,
     scorer: Optional[Callable],
+    neg_prefix: str,
 ):
     """Construct a part-of-speech tagger component.
 
@@ -62,7 +63,7 @@ def make_tagger(
         in size, and be normalized as probabilities (all scores between 0 and 1,
         with the rows summing to 1).
     """
-    return Tagger(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer)
+    return Tagger(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, neg_prefix=neg_prefix)
 
 
 def tagger_score(examples, **kwargs):
@@ -87,6 +88,7 @@ class Tagger(TrainablePipe):
         *,
         overwrite=BACKWARD_OVERWRITE,
         scorer=tagger_score,
+        neg_prefix="!",
     ):
         """Initialize a part-of-speech tagger.
 
@@ -103,7 +105,7 @@ class Tagger(TrainablePipe):
         self.model = model
         self.name = name
         self._rehearsal_model = None
-        cfg = {"labels": [], "overwrite": overwrite}
+        cfg = {"labels": [], "overwrite": overwrite, "neg_prefix": neg_prefix}
         self.cfg = dict(sorted(cfg.items()))
         self.scorer = scorer
 
@@ -253,7 +255,7 @@ class Tagger(TrainablePipe):
         DOCS: https://spacy.io/api/tagger#get_loss
         """
         validate_examples(examples, "Tagger.get_loss")
-        loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False, neg_prefix="!")
+        loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False, neg_prefix=self.cfg["neg_prefix"])
         # Convert empty tag "" to missing value None so that both misaligned
         # tokens and tokens with missing annotation have the default missing
         # value None.
		Loading…
	
		Reference in New Issue
	
	Block a user