WIP: Various small training changes (#6818)

* Allow output_path to be None during training

* Fix cat scoring (?)

* Improve error message for weighted None score

* Improve messages

So we can call this in other places etc.

* FIx output path check

* Use latest wasabi

* Revert "Improve error message for weighted None score"

This reverts commit 7059926763.

* Exclude None scores from final score by default

It's otherwise very difficult to keep track of the score weights if we modify a config programmatically, source components etc.

* Update warnings and use logger.warning
This commit is contained in:
Ines Montani 2021-01-26 14:51:52 +11:00 committed by GitHub
parent f049df1715
commit c0926c9088
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 23 additions and 17 deletions

View File

@ -6,7 +6,7 @@ thinc>=8.0.0,<8.1.0
blis>=0.4.0,<0.8.0
ml_datasets==0.2.0a0
murmurhash>=0.28.0,<1.1.0
wasabi>=0.8.0,<1.1.0
wasabi>=0.8.1,<1.1.0
srsly>=2.3.0,<3.0.0
catalogue>=2.0.1,<2.1.0
typer>=0.3.0,<0.4.0

View File

@ -43,7 +43,7 @@ install_requires =
preshed>=3.0.2,<3.1.0
thinc>=8.0.0,<8.1.0
blis>=0.4.0,<0.8.0
wasabi>=0.8.0,<1.1.0
wasabi>=0.8.1,<1.1.0
srsly>=2.3.0,<3.0.0
catalogue>=2.0.1,<2.1.0
typer>=0.3.0,<0.4.0

View File

@ -112,6 +112,6 @@ def init_labels_cli(
if getattr(component, "label_data", None) is not None:
output_file = output_path / f"{name}.json"
srsly.write_json(output_file, component.label_data)
msg.good(f"Saving {name} labels to {output_file}")
msg.good(f"Saving label data for component '{name}' to {output_file}")
else:
msg.info(f"No labels found for {name}")
msg.info(f"No label data found for component '{name}'")

View File

@ -79,6 +79,13 @@ class Warnings:
"attribute or operator.")
# TODO: fix numbering after merging develop into master
W086 = ("Component '{listener}' will be (re)trained, but it needs the component "
"'{name}' which is frozen. You should either freeze both, or neither "
"of the two.")
W087 = ("Component '{name}' will be (re)trained, but the component '{listener}' "
"depends on it and is frozen. This means that the performance of "
"'{listener}' will be degraded. You should either freeze both, or "
"neither of the two.")
W088 = ("The pipeline component {name} implements a `begin_training` "
"method, which won't be called by spaCy. As of v3.0, `begin_training` "
"has been renamed to `initialize`, so you likely want to rename the "

View File

@ -461,7 +461,7 @@ class Scorer:
if gold_score is not None and gold_score > 0:
f_per_type[gold_label].fn += 1
elif pred_cats:
pred_label, pred_score = max(pred_cats, key=lambda it: it[1])
pred_label, pred_score = max(pred_cats.items(), key=lambda it: it[1])
if pred_score >= threshold:
f_per_type[pred_label].fp += 1
micro_prf = PRFScore()

View File

@ -11,7 +11,7 @@ import tqdm
from ..lookups import Lookups
from ..vectors import Vectors
from ..errors import Errors
from ..errors import Errors, Warnings
from ..schemas import ConfigSchemaTraining
from ..util import registry, load_model_from_config, resolve_dot_names, logger
from ..util import load_model, ensure_path, OOV_RANK, DEFAULT_OOV_PROB
@ -71,15 +71,9 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
if getattr(proc, "listening_components", None):
for listener in proc.listening_components:
if listener in frozen_components and name not in frozen_components:
logger.warn(f"Component '{name}' will be (re)trained, but the "
f"'{listener}' depends on it and is frozen. This means "
f"that the performance of the '{listener}' will be degraded. "
f"You should either freeze both, or neither of the two.")
logger.warning(Warnings.W087.format(name=name, listener=listener))
if listener not in frozen_components and name in frozen_components:
logger.warn(f"Component '{listener}' will be (re)trained, but it needs the "
f"'{name}' which is frozen. "
f"You should either freeze both, or neither of the two.")
logger.warning(Warnings.W086.format(name=name, listener=listener))
return nlp

View File

@ -97,9 +97,10 @@ def train(
try:
for batch, info, is_best_checkpoint in training_step_iterator:
log_step(info if is_best_checkpoint is not None else None)
if is_best_checkpoint is not None and output_path is not None:
if is_best_checkpoint is not None:
with nlp.select_pipes(disable=frozen_components):
update_meta(T, nlp, info)
if output_path is not None:
save_checkpoint(is_best_checkpoint)
except Exception as e:
if output_path is not None:
@ -113,6 +114,7 @@ def train(
raise e
finally:
finalize_logger()
if output_path is not None:
save_checkpoint(False)
# This will only run if we did't hit an error
if optimizer.averages:
@ -257,6 +259,7 @@ def create_evaluation_callback(
weights = {key: value for key, value in weights.items() if value is not None}
def evaluate() -> Tuple[float, Dict[str, float]]:
nonlocal weights
try:
scores = nlp.evaluate(dev_corpus(nlp))
except KeyError as e:
@ -264,6 +267,8 @@ def create_evaluation_callback(
# Calculate a weighted sum based on score_weights for the main score.
# We can only consider scores that are ints/floats, not dicts like
# entity scores per type etc.
scores = {key: value for key, value in scores.items() if value is not None}
weights = {key: value for key, value in weights.items() if key in scores}
for key, value in scores.items():
if key in weights and not isinstance(value, (int, float)):
raise ValueError(Errors.E915.format(name=key, score_type=type(value)))