mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Allow pipeline components to set default scores and weights
This commit is contained in:
parent
787d066e22
commit
2470486543
|
@ -34,8 +34,8 @@ seed = 0
|
|||
accumulate_gradient = 1
|
||||
use_pytorch_for_gpu_memory = false
|
||||
# Control how scores are printed and checkpoints are evaluated.
|
||||
scores = ["speed", "tag_acc", "dep_uas", "dep_las", "ents_f"]
|
||||
score_weights = {"tag_acc": 0.2, "dep_las": 0.4, "ents_f": 0.4}
|
||||
scores = ["token_acc", "speed"]
|
||||
score_weights = {}
|
||||
# These settings are invalid for the transformer models.
|
||||
init_tok2vec = null
|
||||
discard_oversize = false
|
||||
|
|
|
@ -21,7 +21,7 @@ from .pipe_analysis import analyze_pipes, analyze_all_pipes, validate_attrs
|
|||
from .gold import Example
|
||||
from .scorer import Scorer
|
||||
from .util import link_vectors_to_models, create_default_optimizer, registry
|
||||
from .util import SimpleFrozenDict
|
||||
from .util import SimpleFrozenDict, combine_score_weights
|
||||
from .lang.tokenizer_exceptions import URL_MATCH, BASE_EXCEPTIONS
|
||||
from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES
|
||||
from .lang.punctuation import TOKENIZER_INFIXES
|
||||
|
@ -218,16 +218,24 @@ class Language:
|
|||
@property
|
||||
def config(self) -> Config:
|
||||
self._config.setdefault("nlp", {})
|
||||
self._config.setdefault("training", {})
|
||||
self._config["nlp"]["lang"] = self.lang
|
||||
# We're storing the filled config for each pipeline component and so
|
||||
# we can populate the config again later
|
||||
pipeline = {}
|
||||
scores = self._config["training"].get("scores", [])
|
||||
score_weights = []
|
||||
for pipe_name in self.pipe_names:
|
||||
pipe_meta = self.get_pipe_meta(pipe_name)
|
||||
pipe_config = self.get_pipe_config(pipe_name)
|
||||
pipeline[pipe_name] = {"factory": pipe_meta.factory, **pipe_config}
|
||||
scores.extend(pipe_meta.scores)
|
||||
if pipe_meta.score_weights:
|
||||
score_weights.append(pipe_meta.score_weights)
|
||||
self._config["nlp"]["pipeline"] = self.pipe_names
|
||||
self._config["components"] = pipeline
|
||||
self._config["training"]["scores"] = list(scores)
|
||||
self._config["training"]["score_weights"] = combine_score_weights(score_weights)
|
||||
if not srsly.is_json_serializable(self._config):
|
||||
raise ValueError(Errors.E961.format(config=self._config))
|
||||
return self._config
|
||||
|
@ -348,6 +356,8 @@ class Language:
|
|||
assigns: Iterable[str] = tuple(),
|
||||
requires: Iterable[str] = tuple(),
|
||||
retokenizes: bool = False,
|
||||
scores: Iterable[str] = tuple(),
|
||||
score_weights: Dict[str, float] = SimpleFrozenDict(),
|
||||
func: Optional[Callable] = None,
|
||||
) -> Callable:
|
||||
"""Register a new pipeline component factory. Can be used as a decorator
|
||||
|
@ -393,6 +403,8 @@ class Language:
|
|||
default_config=default_config,
|
||||
assigns=validate_attrs(assigns),
|
||||
requires=validate_attrs(requires),
|
||||
scores=scores,
|
||||
score_weights=score_weights,
|
||||
retokenizes=retokenizes,
|
||||
)
|
||||
cls.set_factory_meta(name, factory_meta)
|
||||
|
@ -417,6 +429,8 @@ class Language:
|
|||
assigns: Iterable[str] = tuple(),
|
||||
requires: Iterable[str] = tuple(),
|
||||
retokenizes: bool = False,
|
||||
scores: Iterable[str] = tuple(),
|
||||
score_weights: Dict[str, float] = SimpleFrozenDict(),
|
||||
func: Optional[Callable[[Doc], Doc]] = None,
|
||||
) -> Callable:
|
||||
"""Register a new pipeline component. Can be used for stateless function
|
||||
|
@ -450,6 +464,8 @@ class Language:
|
|||
assigns=assigns,
|
||||
requires=requires,
|
||||
retokenizes=retokenizes,
|
||||
scores=scores,
|
||||
score_weights=score_weights,
|
||||
func=factory_func,
|
||||
)
|
||||
return component_func
|
||||
|
@ -1484,6 +1500,8 @@ class FactoryMeta:
|
|||
assigns: Iterable[str] = tuple()
|
||||
requires: Iterable[str] = tuple()
|
||||
retokenizes: bool = False
|
||||
scores: Iterable[str] = tuple()
|
||||
score_weights: Dict[str, float] = None
|
||||
|
||||
|
||||
def _get_config_overrides(
|
||||
|
|
|
@ -42,7 +42,9 @@ DEFAULT_PARSER_MODEL = Config().from_str(default_model_config)["model"]
|
|||
"learn_tokens": False,
|
||||
"min_action_freq": 30,
|
||||
"model": DEFAULT_PARSER_MODEL,
|
||||
}
|
||||
},
|
||||
scores=["dep_uas", "dep_las", "sents_f"],
|
||||
score_weights={"dep_uas": 0.5, "dep_las": 0.5, "sents_f": 0.0},
|
||||
)
|
||||
def make_parser(
|
||||
nlp: Language,
|
||||
|
|
|
@ -40,7 +40,9 @@ DEFAULT_NER_MODEL = Config().from_str(default_model_config)["model"]
|
|||
"learn_tokens": False,
|
||||
"min_action_freq": 30,
|
||||
"model": DEFAULT_NER_MODEL,
|
||||
}
|
||||
},
|
||||
scores=["ents_f", "ents_r", "ents_p"],
|
||||
score_weights={"ents_f": 1.0, "ents_r": 0.0, "ents_p": 0.0},
|
||||
)
|
||||
def make_ner(
|
||||
nlp: Language,
|
||||
|
|
|
@ -33,7 +33,9 @@ DEFAULT_SENTER_MODEL = Config().from_str(default_model_config)["model"]
|
|||
@Language.factory(
|
||||
"senter",
|
||||
assigns=["token.is_sent_start"],
|
||||
default_config={"model": DEFAULT_SENTER_MODEL}
|
||||
default_config={"model": DEFAULT_SENTER_MODEL},
|
||||
scores=["sents_p", "sents_r", "sents_f"],
|
||||
score_weights={"sents_p": 0.0, "sents_r": 0.0, "sents_f": 1.0},
|
||||
)
|
||||
def make_senter(nlp: Language, name: str, model: Model):
|
||||
return SentenceRecognizer(nlp.vocab, model, name)
|
||||
|
|
|
@ -39,7 +39,9 @@ DEFAULT_TAGGER_MODEL = Config().from_str(default_model_config)["model"]
|
|||
@Language.factory(
|
||||
"tagger",
|
||||
assigns=["token.tag"],
|
||||
default_config={"model": DEFAULT_TAGGER_MODEL, "set_morphology": False}
|
||||
default_config={"model": DEFAULT_TAGGER_MODEL, "set_morphology": False},
|
||||
scores=["tag_acc", "pos_acc"],
|
||||
score_weights={"tag_acc": 0.5, "pos_acc": 0.5},
|
||||
)
|
||||
def make_tagger(nlp: Language, name: str, model: Model, set_morphology: bool):
|
||||
return Tagger(nlp.vocab, model, name, set_morphology=set_morphology)
|
||||
|
|
|
@ -3,7 +3,7 @@ from spacy.language import Language
|
|||
from spacy.lang.en import English
|
||||
from spacy.lang.de import German
|
||||
from spacy.tokens import Doc
|
||||
from spacy.util import registry, SimpleFrozenDict
|
||||
from spacy.util import registry, SimpleFrozenDict, combine_score_weights
|
||||
from thinc.api import Model, Linear
|
||||
from thinc.config import ConfigValidationError
|
||||
from pydantic import StrictInt, StrictStr
|
||||
|
@ -328,3 +328,52 @@ def test_language_factories_invalid():
|
|||
assert len(nlp.factories)
|
||||
with pytest.raises(NotImplementedError):
|
||||
nlp.factories["foo"] = "bar"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"weights,expected",
|
||||
[
|
||||
([{"a": 1.0}, {"b": 1.0}, {"c": 1.0}], {"a": 0.33, "b": 0.33, "c": 0.33}),
|
||||
([{"a": 1.0}, {"b": 50}, {"c": 123}], {"a": 0.33, "b": 0.33, "c": 0.33}),
|
||||
(
|
||||
[{"a": 0.7, "b": 0.3}, {"c": 1.0}, {"d": 0.5, "e": 0.5}],
|
||||
{"a": 0.23, "b": 0.1, "c": 0.33, "d": 0.17, "e": 0.17},
|
||||
),
|
||||
(
|
||||
[{"a": 100, "b": 400}, {"c": 0.5, "d": 0.5}],
|
||||
{"a": 0.1, "b": 0.4, "c": 0.25, "d": 0.25},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_language_factories_combine_score_weights(weights, expected):
|
||||
result = combine_score_weights(weights)
|
||||
assert sum(result.values()) in (0.99, 1.0)
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_language_factories_scores():
|
||||
name = "test_language_factories_scores"
|
||||
func = lambda doc: doc
|
||||
scores1 = ["a1", "a2"]
|
||||
weights1 = {"a1": 0.5, "a2": 0.5}
|
||||
scores2 = ["b1", "b2", "b3"]
|
||||
weights2 = {"b1": 0.2, "b2": 0.7, "b3": 0.1}
|
||||
Language.component(
|
||||
f"{name}1", scores=scores1, score_weights=weights1, func=func,
|
||||
)
|
||||
Language.component(
|
||||
f"{name}2", scores=scores2, score_weights=weights2, func=func,
|
||||
)
|
||||
meta1 = Language.get_factory_meta(f"{name}1")
|
||||
assert meta1.scores == scores1
|
||||
assert meta1.score_weights == weights1
|
||||
meta2 = Language.get_factory_meta(f"{name}2")
|
||||
assert meta2.scores == scores2
|
||||
assert meta2.score_weights == weights2
|
||||
nlp = Language(config={"training": {"scores": ["speed"], "score_weights": {}}})
|
||||
nlp.add_pipe(f"{name}1")
|
||||
nlp.add_pipe(f"{name}2")
|
||||
cfg = nlp.config["training"]
|
||||
assert cfg["scores"] == ["speed", *scores1, *scores2]
|
||||
expected_weights = {"a1": 0.25, "a2": 0.25, "b1": 0.1, "b2": 0.35, "b3": 0.05}
|
||||
assert cfg["score_weights"] == expected_weights
|
||||
|
|
|
@ -1130,6 +1130,23 @@ def get_arg_names(func: Callable) -> List[str]:
|
|||
return list(set([*argspec.args, *argspec.kwonlyargs]))
|
||||
|
||||
|
||||
def combine_score_weights(weights: List[Dict[str, float]]) -> Dict[str, float]:
|
||||
"""Combine and normalize score weights defined by components, e.g.
|
||||
{"ents_r": 0.2, "ents_p": 0.3, "ents_f": 0.5} and {"some_other_score": 1.0}.
|
||||
|
||||
weights (List[dict]): The weights defined by the components.
|
||||
RETURNS (Dict[str, float]): The combined and normalized weights.
|
||||
"""
|
||||
result = {}
|
||||
for w_dict in weights:
|
||||
# We need to account for weights that don't sum to 1.0 and normalize the
|
||||
# score weights accordingly, then divide score by the number of components
|
||||
total = sum([w for w in w_dict.values()])
|
||||
for key, value in w_dict.items():
|
||||
result[key] = round(value / total / len(weights), 2)
|
||||
return result
|
||||
|
||||
|
||||
class DummyTokenizer:
|
||||
# add dummy methods for to_bytes, from_bytes, to_disk and from_disk to
|
||||
# allow serialization (see #1557)
|
||||
|
|
Loading…
Reference in New Issue
Block a user