WIP: Various small training changes (#6818)

* Allow output_path to be None during training

* Fix cat scoring (?)

* Improve error message for weighted None score

* Improve messages

So we can call this in other places etc.

* FIx output path check

* Use latest wasabi

* Revert "Improve error message for weighted None score"

This reverts commit 7059926763.

* Exclude None scores from final score by default

It's otherwise very difficult to keep track of the score weights if we modify a config programmatically, source components etc.

* Update warnings and use logger.warning
This commit is contained in:
Ines Montani 2021-01-26 14:51:52 +11:00 committed by GitHub
parent f049df1715
commit c0926c9088
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 23 additions and 17 deletions

View File

@ -6,7 +6,7 @@ thinc>=8.0.0,<8.1.0
blis>=0.4.0,<0.8.0 blis>=0.4.0,<0.8.0
ml_datasets==0.2.0a0 ml_datasets==0.2.0a0
murmurhash>=0.28.0,<1.1.0 murmurhash>=0.28.0,<1.1.0
wasabi>=0.8.0,<1.1.0 wasabi>=0.8.1,<1.1.0
srsly>=2.3.0,<3.0.0 srsly>=2.3.0,<3.0.0
catalogue>=2.0.1,<2.1.0 catalogue>=2.0.1,<2.1.0
typer>=0.3.0,<0.4.0 typer>=0.3.0,<0.4.0

View File

@ -43,7 +43,7 @@ install_requires =
preshed>=3.0.2,<3.1.0 preshed>=3.0.2,<3.1.0
thinc>=8.0.0,<8.1.0 thinc>=8.0.0,<8.1.0
blis>=0.4.0,<0.8.0 blis>=0.4.0,<0.8.0
wasabi>=0.8.0,<1.1.0 wasabi>=0.8.1,<1.1.0
srsly>=2.3.0,<3.0.0 srsly>=2.3.0,<3.0.0
catalogue>=2.0.1,<2.1.0 catalogue>=2.0.1,<2.1.0
typer>=0.3.0,<0.4.0 typer>=0.3.0,<0.4.0

View File

@ -112,6 +112,6 @@ def init_labels_cli(
if getattr(component, "label_data", None) is not None: if getattr(component, "label_data", None) is not None:
output_file = output_path / f"{name}.json" output_file = output_path / f"{name}.json"
srsly.write_json(output_file, component.label_data) srsly.write_json(output_file, component.label_data)
msg.good(f"Saving {name} labels to {output_file}") msg.good(f"Saving label data for component '{name}' to {output_file}")
else: else:
msg.info(f"No labels found for {name}") msg.info(f"No label data found for component '{name}'")

View File

@ -79,6 +79,13 @@ class Warnings:
"attribute or operator.") "attribute or operator.")
# TODO: fix numbering after merging develop into master # TODO: fix numbering after merging develop into master
W086 = ("Component '{listener}' will be (re)trained, but it needs the component "
"'{name}' which is frozen. You should either freeze both, or neither "
"of the two.")
W087 = ("Component '{name}' will be (re)trained, but the component '{listener}' "
"depends on it and is frozen. This means that the performance of "
"'{listener}' will be degraded. You should either freeze both, or "
"neither of the two.")
W088 = ("The pipeline component {name} implements a `begin_training` " W088 = ("The pipeline component {name} implements a `begin_training` "
"method, which won't be called by spaCy. As of v3.0, `begin_training` " "method, which won't be called by spaCy. As of v3.0, `begin_training` "
"has been renamed to `initialize`, so you likely want to rename the " "has been renamed to `initialize`, so you likely want to rename the "

View File

@ -461,7 +461,7 @@ class Scorer:
if gold_score is not None and gold_score > 0: if gold_score is not None and gold_score > 0:
f_per_type[gold_label].fn += 1 f_per_type[gold_label].fn += 1
elif pred_cats: elif pred_cats:
pred_label, pred_score = max(pred_cats, key=lambda it: it[1]) pred_label, pred_score = max(pred_cats.items(), key=lambda it: it[1])
if pred_score >= threshold: if pred_score >= threshold:
f_per_type[pred_label].fp += 1 f_per_type[pred_label].fp += 1
micro_prf = PRFScore() micro_prf = PRFScore()

View File

@ -11,7 +11,7 @@ import tqdm
from ..lookups import Lookups from ..lookups import Lookups
from ..vectors import Vectors from ..vectors import Vectors
from ..errors import Errors from ..errors import Errors, Warnings
from ..schemas import ConfigSchemaTraining from ..schemas import ConfigSchemaTraining
from ..util import registry, load_model_from_config, resolve_dot_names, logger from ..util import registry, load_model_from_config, resolve_dot_names, logger
from ..util import load_model, ensure_path, OOV_RANK, DEFAULT_OOV_PROB from ..util import load_model, ensure_path, OOV_RANK, DEFAULT_OOV_PROB
@ -71,15 +71,9 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
if getattr(proc, "listening_components", None): if getattr(proc, "listening_components", None):
for listener in proc.listening_components: for listener in proc.listening_components:
if listener in frozen_components and name not in frozen_components: if listener in frozen_components and name not in frozen_components:
logger.warn(f"Component '{name}' will be (re)trained, but the " logger.warning(Warnings.W087.format(name=name, listener=listener))
f"'{listener}' depends on it and is frozen. This means "
f"that the performance of the '{listener}' will be degraded. "
f"You should either freeze both, or neither of the two.")
if listener not in frozen_components and name in frozen_components: if listener not in frozen_components and name in frozen_components:
logger.warn(f"Component '{listener}' will be (re)trained, but it needs the " logger.warning(Warnings.W086.format(name=name, listener=listener))
f"'{name}' which is frozen. "
f"You should either freeze both, or neither of the two.")
return nlp return nlp

View File

@ -97,10 +97,11 @@ def train(
try: try:
for batch, info, is_best_checkpoint in training_step_iterator: for batch, info, is_best_checkpoint in training_step_iterator:
log_step(info if is_best_checkpoint is not None else None) log_step(info if is_best_checkpoint is not None else None)
if is_best_checkpoint is not None and output_path is not None: if is_best_checkpoint is not None:
with nlp.select_pipes(disable=frozen_components): with nlp.select_pipes(disable=frozen_components):
update_meta(T, nlp, info) update_meta(T, nlp, info)
save_checkpoint(is_best_checkpoint) if output_path is not None:
save_checkpoint(is_best_checkpoint)
except Exception as e: except Exception as e:
if output_path is not None: if output_path is not None:
stdout.write( stdout.write(
@ -113,7 +114,8 @@ def train(
raise e raise e
finally: finally:
finalize_logger() finalize_logger()
save_checkpoint(False) if output_path is not None:
save_checkpoint(False)
# This will only run if we did't hit an error # This will only run if we did't hit an error
if optimizer.averages: if optimizer.averages:
nlp.use_params(optimizer.averages) nlp.use_params(optimizer.averages)
@ -257,6 +259,7 @@ def create_evaluation_callback(
weights = {key: value for key, value in weights.items() if value is not None} weights = {key: value for key, value in weights.items() if value is not None}
def evaluate() -> Tuple[float, Dict[str, float]]: def evaluate() -> Tuple[float, Dict[str, float]]:
nonlocal weights
try: try:
scores = nlp.evaluate(dev_corpus(nlp)) scores = nlp.evaluate(dev_corpus(nlp))
except KeyError as e: except KeyError as e:
@ -264,6 +267,8 @@ def create_evaluation_callback(
# Calculate a weighted sum based on score_weights for the main score. # Calculate a weighted sum based on score_weights for the main score.
# We can only consider scores that are ints/floats, not dicts like # We can only consider scores that are ints/floats, not dicts like
# entity scores per type etc. # entity scores per type etc.
scores = {key: value for key, value in scores.items() if value is not None}
weights = {key: value for key, value in weights.items() if key in scores}
for key, value in scores.items(): for key, value in scores.items():
if key in weights and not isinstance(value, (int, float)): if key in weights and not isinstance(value, (int, float)):
raise ValueError(Errors.E915.format(name=key, score_type=type(value))) raise ValueError(Errors.E915.format(name=key, score_type=type(value)))