mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 16:07:41 +03:00 
			
		
		
		
	Allow pipeline components to set default scores and weights
This commit is contained in:
		
							parent
							
								
									787d066e22
								
							
						
					
					
						commit
						2470486543
					
				|  | @ -34,8 +34,8 @@ seed = 0 | ||||||
| accumulate_gradient = 1 | accumulate_gradient = 1 | ||||||
| use_pytorch_for_gpu_memory = false | use_pytorch_for_gpu_memory = false | ||||||
| # Control how scores are printed and checkpoints are evaluated. | # Control how scores are printed and checkpoints are evaluated. | ||||||
| scores = ["speed", "tag_acc", "dep_uas", "dep_las", "ents_f"] | scores = ["token_acc", "speed"] | ||||||
| score_weights = {"tag_acc": 0.2, "dep_las": 0.4, "ents_f": 0.4} | score_weights = {} | ||||||
| # These settings are invalid for the transformer models. | # These settings are invalid for the transformer models. | ||||||
| init_tok2vec = null | init_tok2vec = null | ||||||
| discard_oversize = false | discard_oversize = false | ||||||
|  |  | ||||||
|  | @ -21,7 +21,7 @@ from .pipe_analysis import analyze_pipes, analyze_all_pipes, validate_attrs | ||||||
| from .gold import Example | from .gold import Example | ||||||
| from .scorer import Scorer | from .scorer import Scorer | ||||||
| from .util import link_vectors_to_models, create_default_optimizer, registry | from .util import link_vectors_to_models, create_default_optimizer, registry | ||||||
| from .util import SimpleFrozenDict | from .util import SimpleFrozenDict, combine_score_weights | ||||||
| from .lang.tokenizer_exceptions import URL_MATCH, BASE_EXCEPTIONS | from .lang.tokenizer_exceptions import URL_MATCH, BASE_EXCEPTIONS | ||||||
| from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES | from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES | ||||||
| from .lang.punctuation import TOKENIZER_INFIXES | from .lang.punctuation import TOKENIZER_INFIXES | ||||||
|  | @ -218,16 +218,24 @@ class Language: | ||||||
|     @property |     @property | ||||||
|     def config(self) -> Config: |     def config(self) -> Config: | ||||||
|         self._config.setdefault("nlp", {}) |         self._config.setdefault("nlp", {}) | ||||||
|  |         self._config.setdefault("training", {}) | ||||||
|         self._config["nlp"]["lang"] = self.lang |         self._config["nlp"]["lang"] = self.lang | ||||||
|         # We're storing the filled config for each pipeline component and so |         # We're storing the filled config for each pipeline component and so | ||||||
|         # we can populate the config again later |         # we can populate the config again later | ||||||
|         pipeline = {} |         pipeline = {} | ||||||
|  |         scores = self._config["training"].get("scores", []) | ||||||
|  |         score_weights = [] | ||||||
|         for pipe_name in self.pipe_names: |         for pipe_name in self.pipe_names: | ||||||
|             pipe_meta = self.get_pipe_meta(pipe_name) |             pipe_meta = self.get_pipe_meta(pipe_name) | ||||||
|             pipe_config = self.get_pipe_config(pipe_name) |             pipe_config = self.get_pipe_config(pipe_name) | ||||||
|             pipeline[pipe_name] = {"factory": pipe_meta.factory, **pipe_config} |             pipeline[pipe_name] = {"factory": pipe_meta.factory, **pipe_config} | ||||||
|  |             scores.extend(pipe_meta.scores) | ||||||
|  |             if pipe_meta.score_weights: | ||||||
|  |                 score_weights.append(pipe_meta.score_weights) | ||||||
|         self._config["nlp"]["pipeline"] = self.pipe_names |         self._config["nlp"]["pipeline"] = self.pipe_names | ||||||
|         self._config["components"] = pipeline |         self._config["components"] = pipeline | ||||||
|  |         self._config["training"]["scores"] = list(scores) | ||||||
|  |         self._config["training"]["score_weights"] = combine_score_weights(score_weights) | ||||||
|         if not srsly.is_json_serializable(self._config): |         if not srsly.is_json_serializable(self._config): | ||||||
|             raise ValueError(Errors.E961.format(config=self._config)) |             raise ValueError(Errors.E961.format(config=self._config)) | ||||||
|         return self._config |         return self._config | ||||||
|  | @ -348,6 +356,8 @@ class Language: | ||||||
|         assigns: Iterable[str] = tuple(), |         assigns: Iterable[str] = tuple(), | ||||||
|         requires: Iterable[str] = tuple(), |         requires: Iterable[str] = tuple(), | ||||||
|         retokenizes: bool = False, |         retokenizes: bool = False, | ||||||
|  |         scores: Iterable[str] = tuple(), | ||||||
|  |         score_weights: Dict[str, float] = SimpleFrozenDict(), | ||||||
|         func: Optional[Callable] = None, |         func: Optional[Callable] = None, | ||||||
|     ) -> Callable: |     ) -> Callable: | ||||||
|         """Register a new pipeline component factory. Can be used as a decorator |         """Register a new pipeline component factory. Can be used as a decorator | ||||||
|  | @ -393,6 +403,8 @@ class Language: | ||||||
|                 default_config=default_config, |                 default_config=default_config, | ||||||
|                 assigns=validate_attrs(assigns), |                 assigns=validate_attrs(assigns), | ||||||
|                 requires=validate_attrs(requires), |                 requires=validate_attrs(requires), | ||||||
|  |                 scores=scores, | ||||||
|  |                 score_weights=score_weights, | ||||||
|                 retokenizes=retokenizes, |                 retokenizes=retokenizes, | ||||||
|             ) |             ) | ||||||
|             cls.set_factory_meta(name, factory_meta) |             cls.set_factory_meta(name, factory_meta) | ||||||
|  | @ -417,6 +429,8 @@ class Language: | ||||||
|         assigns: Iterable[str] = tuple(), |         assigns: Iterable[str] = tuple(), | ||||||
|         requires: Iterable[str] = tuple(), |         requires: Iterable[str] = tuple(), | ||||||
|         retokenizes: bool = False, |         retokenizes: bool = False, | ||||||
|  |         scores: Iterable[str] = tuple(), | ||||||
|  |         score_weights: Dict[str, float] = SimpleFrozenDict(), | ||||||
|         func: Optional[Callable[[Doc], Doc]] = None, |         func: Optional[Callable[[Doc], Doc]] = None, | ||||||
|     ) -> Callable: |     ) -> Callable: | ||||||
|         """Register a new pipeline component. Can be used for stateless function |         """Register a new pipeline component. Can be used for stateless function | ||||||
|  | @ -450,6 +464,8 @@ class Language: | ||||||
|                 assigns=assigns, |                 assigns=assigns, | ||||||
|                 requires=requires, |                 requires=requires, | ||||||
|                 retokenizes=retokenizes, |                 retokenizes=retokenizes, | ||||||
|  |                 scores=scores, | ||||||
|  |                 score_weights=score_weights, | ||||||
|                 func=factory_func, |                 func=factory_func, | ||||||
|             ) |             ) | ||||||
|             return component_func |             return component_func | ||||||
|  | @ -1484,6 +1500,8 @@ class FactoryMeta: | ||||||
|     assigns: Iterable[str] = tuple() |     assigns: Iterable[str] = tuple() | ||||||
|     requires: Iterable[str] = tuple() |     requires: Iterable[str] = tuple() | ||||||
|     retokenizes: bool = False |     retokenizes: bool = False | ||||||
|  |     scores: Iterable[str] = tuple() | ||||||
|  |     score_weights: Dict[str, float] = None | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def _get_config_overrides( | def _get_config_overrides( | ||||||
|  |  | ||||||
|  | @ -42,7 +42,9 @@ DEFAULT_PARSER_MODEL = Config().from_str(default_model_config)["model"] | ||||||
|         "learn_tokens": False, |         "learn_tokens": False, | ||||||
|         "min_action_freq": 30, |         "min_action_freq": 30, | ||||||
|         "model": DEFAULT_PARSER_MODEL, |         "model": DEFAULT_PARSER_MODEL, | ||||||
|     } |     }, | ||||||
|  |     scores=["dep_uas", "dep_las", "sents_f"], | ||||||
|  |     score_weights={"dep_uas": 0.5, "dep_las": 0.5, "sents_f": 0.0}, | ||||||
| ) | ) | ||||||
| def make_parser( | def make_parser( | ||||||
|     nlp: Language, |     nlp: Language, | ||||||
|  |  | ||||||
|  | @ -40,7 +40,9 @@ DEFAULT_NER_MODEL = Config().from_str(default_model_config)["model"] | ||||||
|         "learn_tokens": False, |         "learn_tokens": False, | ||||||
|         "min_action_freq": 30, |         "min_action_freq": 30, | ||||||
|         "model": DEFAULT_NER_MODEL, |         "model": DEFAULT_NER_MODEL, | ||||||
|     } |     }, | ||||||
|  |     scores=["ents_f", "ents_r", "ents_p"], | ||||||
|  |     score_weights={"ents_f": 1.0, "ents_r": 0.0, "ents_p": 0.0}, | ||||||
| ) | ) | ||||||
| def make_ner( | def make_ner( | ||||||
|     nlp: Language, |     nlp: Language, | ||||||
|  |  | ||||||
|  | @ -33,7 +33,9 @@ DEFAULT_SENTER_MODEL = Config().from_str(default_model_config)["model"] | ||||||
| @Language.factory( | @Language.factory( | ||||||
|     "senter", |     "senter", | ||||||
|     assigns=["token.is_sent_start"], |     assigns=["token.is_sent_start"], | ||||||
|     default_config={"model": DEFAULT_SENTER_MODEL} |     default_config={"model": DEFAULT_SENTER_MODEL}, | ||||||
|  |     scores=["sents_p", "sents_r", "sents_f"], | ||||||
|  |     score_weights={"sents_p": 0.0, "sents_r": 0.0, "sents_f": 1.0}, | ||||||
| ) | ) | ||||||
| def make_senter(nlp: Language, name: str, model: Model): | def make_senter(nlp: Language, name: str, model: Model): | ||||||
|     return SentenceRecognizer(nlp.vocab, model, name) |     return SentenceRecognizer(nlp.vocab, model, name) | ||||||
|  |  | ||||||
|  | @ -39,7 +39,9 @@ DEFAULT_TAGGER_MODEL = Config().from_str(default_model_config)["model"] | ||||||
| @Language.factory( | @Language.factory( | ||||||
|     "tagger", |     "tagger", | ||||||
|     assigns=["token.tag"], |     assigns=["token.tag"], | ||||||
|     default_config={"model": DEFAULT_TAGGER_MODEL, "set_morphology": False} |     default_config={"model": DEFAULT_TAGGER_MODEL, "set_morphology": False}, | ||||||
|  |     scores=["tag_acc", "pos_acc"], | ||||||
|  |     score_weights={"tag_acc": 0.5, "pos_acc": 0.5}, | ||||||
| ) | ) | ||||||
| def make_tagger(nlp: Language, name: str, model: Model, set_morphology: bool): | def make_tagger(nlp: Language, name: str, model: Model, set_morphology: bool): | ||||||
|     return Tagger(nlp.vocab, model, name, set_morphology=set_morphology) |     return Tagger(nlp.vocab, model, name, set_morphology=set_morphology) | ||||||
|  |  | ||||||
|  | @ -3,7 +3,7 @@ from spacy.language import Language | ||||||
| from spacy.lang.en import English | from spacy.lang.en import English | ||||||
| from spacy.lang.de import German | from spacy.lang.de import German | ||||||
| from spacy.tokens import Doc | from spacy.tokens import Doc | ||||||
| from spacy.util import registry, SimpleFrozenDict | from spacy.util import registry, SimpleFrozenDict, combine_score_weights | ||||||
| from thinc.api import Model, Linear | from thinc.api import Model, Linear | ||||||
| from thinc.config import ConfigValidationError | from thinc.config import ConfigValidationError | ||||||
| from pydantic import StrictInt, StrictStr | from pydantic import StrictInt, StrictStr | ||||||
|  | @ -328,3 +328,52 @@ def test_language_factories_invalid(): | ||||||
|     assert len(nlp.factories) |     assert len(nlp.factories) | ||||||
|     with pytest.raises(NotImplementedError): |     with pytest.raises(NotImplementedError): | ||||||
|         nlp.factories["foo"] = "bar" |         nlp.factories["foo"] = "bar" | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @pytest.mark.parametrize( | ||||||
|  |     "weights,expected", | ||||||
|  |     [ | ||||||
|  |         ([{"a": 1.0}, {"b": 1.0}, {"c": 1.0}], {"a": 0.33, "b": 0.33, "c": 0.33}), | ||||||
|  |         ([{"a": 1.0}, {"b": 50}, {"c": 123}], {"a": 0.33, "b": 0.33, "c": 0.33}), | ||||||
|  |         ( | ||||||
|  |             [{"a": 0.7, "b": 0.3}, {"c": 1.0}, {"d": 0.5, "e": 0.5}], | ||||||
|  |             {"a": 0.23, "b": 0.1, "c": 0.33, "d": 0.17, "e": 0.17}, | ||||||
|  |         ), | ||||||
|  |         ( | ||||||
|  |             [{"a": 100, "b": 400}, {"c": 0.5, "d": 0.5}], | ||||||
|  |             {"a": 0.1, "b": 0.4, "c": 0.25, "d": 0.25}, | ||||||
|  |         ), | ||||||
|  |     ], | ||||||
|  | ) | ||||||
|  | def test_language_factories_combine_score_weights(weights, expected): | ||||||
|  |     result = combine_score_weights(weights) | ||||||
|  |     assert sum(result.values()) in (0.99, 1.0) | ||||||
|  |     assert result == expected | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_language_factories_scores(): | ||||||
|  |     name = "test_language_factories_scores" | ||||||
|  |     func = lambda doc: doc | ||||||
|  |     scores1 = ["a1", "a2"] | ||||||
|  |     weights1 = {"a1": 0.5, "a2": 0.5} | ||||||
|  |     scores2 = ["b1", "b2", "b3"] | ||||||
|  |     weights2 = {"b1": 0.2, "b2": 0.7, "b3": 0.1} | ||||||
|  |     Language.component( | ||||||
|  |         f"{name}1", scores=scores1, score_weights=weights1, func=func, | ||||||
|  |     ) | ||||||
|  |     Language.component( | ||||||
|  |         f"{name}2", scores=scores2, score_weights=weights2, func=func, | ||||||
|  |     ) | ||||||
|  |     meta1 = Language.get_factory_meta(f"{name}1") | ||||||
|  |     assert meta1.scores == scores1 | ||||||
|  |     assert meta1.score_weights == weights1 | ||||||
|  |     meta2 = Language.get_factory_meta(f"{name}2") | ||||||
|  |     assert meta2.scores == scores2 | ||||||
|  |     assert meta2.score_weights == weights2 | ||||||
|  |     nlp = Language(config={"training": {"scores": ["speed"], "score_weights": {}}}) | ||||||
|  |     nlp.add_pipe(f"{name}1") | ||||||
|  |     nlp.add_pipe(f"{name}2") | ||||||
|  |     cfg = nlp.config["training"] | ||||||
|  |     assert cfg["scores"] == ["speed", *scores1, *scores2] | ||||||
|  |     expected_weights = {"a1": 0.25, "a2": 0.25, "b1": 0.1, "b2": 0.35, "b3": 0.05} | ||||||
|  |     assert cfg["score_weights"] == expected_weights | ||||||
|  |  | ||||||
|  | @ -1130,6 +1130,23 @@ def get_arg_names(func: Callable) -> List[str]: | ||||||
|     return list(set([*argspec.args, *argspec.kwonlyargs])) |     return list(set([*argspec.args, *argspec.kwonlyargs])) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def combine_score_weights(weights: List[Dict[str, float]]) -> Dict[str, float]: | ||||||
|  |     """Combine and normalize score weights defined by components, e.g. | ||||||
|  |     {"ents_r": 0.2, "ents_p": 0.3, "ents_f": 0.5} and {"some_other_score": 1.0}. | ||||||
|  | 
 | ||||||
|  |     weights (List[dict]): The weights defined by the components. | ||||||
|  |     RETURNS (Dict[str, float]): The combined and normalized weights. | ||||||
|  |     """ | ||||||
|  |     result = {} | ||||||
|  |     for w_dict in weights: | ||||||
|  |         # We need to account for weights that don't sum to 1.0 and normalize the | ||||||
|  |         # score weights accordingly, then divide score by the number of components | ||||||
|  |         total = sum([w for w in w_dict.values()]) | ||||||
|  |         for key, value in w_dict.items(): | ||||||
|  |             result[key] = round(value / total / len(weights), 2) | ||||||
|  |     return result | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class DummyTokenizer: | class DummyTokenizer: | ||||||
|     # add dummy methods for to_bytes, from_bytes, to_disk and from_disk to |     # add dummy methods for to_bytes, from_bytes, to_disk and from_disk to | ||||||
|     # allow serialization (see #1557) |     # allow serialization (see #1557) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user