mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 17:24:41 +03:00
WIP: Various small training changes (#6818)
* Allow output_path to be None during training
* Fix cat scoring (?)
* Improve error message for weighted None score
* Improve messages
So we can call this in other places etc.
* FIx output path check
* Use latest wasabi
* Revert "Improve error message for weighted None score"
This reverts commit 7059926763
.
* Exclude None scores from final score by default
It's otherwise very difficult to keep track of the score weights if we modify a config programmatically, source components etc.
* Update warnings and use logger.warning
This commit is contained in:
parent
f049df1715
commit
c0926c9088
|
@ -6,7 +6,7 @@ thinc>=8.0.0,<8.1.0
|
|||
blis>=0.4.0,<0.8.0
|
||||
ml_datasets==0.2.0a0
|
||||
murmurhash>=0.28.0,<1.1.0
|
||||
wasabi>=0.8.0,<1.1.0
|
||||
wasabi>=0.8.1,<1.1.0
|
||||
srsly>=2.3.0,<3.0.0
|
||||
catalogue>=2.0.1,<2.1.0
|
||||
typer>=0.3.0,<0.4.0
|
||||
|
|
|
@ -43,7 +43,7 @@ install_requires =
|
|||
preshed>=3.0.2,<3.1.0
|
||||
thinc>=8.0.0,<8.1.0
|
||||
blis>=0.4.0,<0.8.0
|
||||
wasabi>=0.8.0,<1.1.0
|
||||
wasabi>=0.8.1,<1.1.0
|
||||
srsly>=2.3.0,<3.0.0
|
||||
catalogue>=2.0.1,<2.1.0
|
||||
typer>=0.3.0,<0.4.0
|
||||
|
|
|
@ -112,6 +112,6 @@ def init_labels_cli(
|
|||
if getattr(component, "label_data", None) is not None:
|
||||
output_file = output_path / f"{name}.json"
|
||||
srsly.write_json(output_file, component.label_data)
|
||||
msg.good(f"Saving {name} labels to {output_file}")
|
||||
msg.good(f"Saving label data for component '{name}' to {output_file}")
|
||||
else:
|
||||
msg.info(f"No labels found for {name}")
|
||||
msg.info(f"No label data found for component '{name}'")
|
||||
|
|
|
@ -79,6 +79,13 @@ class Warnings:
|
|||
"attribute or operator.")
|
||||
|
||||
# TODO: fix numbering after merging develop into master
|
||||
W086 = ("Component '{listener}' will be (re)trained, but it needs the component "
|
||||
"'{name}' which is frozen. You should either freeze both, or neither "
|
||||
"of the two.")
|
||||
W087 = ("Component '{name}' will be (re)trained, but the component '{listener}' "
|
||||
"depends on it and is frozen. This means that the performance of "
|
||||
"'{listener}' will be degraded. You should either freeze both, or "
|
||||
"neither of the two.")
|
||||
W088 = ("The pipeline component {name} implements a `begin_training` "
|
||||
"method, which won't be called by spaCy. As of v3.0, `begin_training` "
|
||||
"has been renamed to `initialize`, so you likely want to rename the "
|
||||
|
|
|
@ -461,7 +461,7 @@ class Scorer:
|
|||
if gold_score is not None and gold_score > 0:
|
||||
f_per_type[gold_label].fn += 1
|
||||
elif pred_cats:
|
||||
pred_label, pred_score = max(pred_cats, key=lambda it: it[1])
|
||||
pred_label, pred_score = max(pred_cats.items(), key=lambda it: it[1])
|
||||
if pred_score >= threshold:
|
||||
f_per_type[pred_label].fp += 1
|
||||
micro_prf = PRFScore()
|
||||
|
|
|
@ -11,7 +11,7 @@ import tqdm
|
|||
|
||||
from ..lookups import Lookups
|
||||
from ..vectors import Vectors
|
||||
from ..errors import Errors
|
||||
from ..errors import Errors, Warnings
|
||||
from ..schemas import ConfigSchemaTraining
|
||||
from ..util import registry, load_model_from_config, resolve_dot_names, logger
|
||||
from ..util import load_model, ensure_path, OOV_RANK, DEFAULT_OOV_PROB
|
||||
|
@ -71,15 +71,9 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
|
|||
if getattr(proc, "listening_components", None):
|
||||
for listener in proc.listening_components:
|
||||
if listener in frozen_components and name not in frozen_components:
|
||||
logger.warn(f"Component '{name}' will be (re)trained, but the "
|
||||
f"'{listener}' depends on it and is frozen. This means "
|
||||
f"that the performance of the '{listener}' will be degraded. "
|
||||
f"You should either freeze both, or neither of the two.")
|
||||
|
||||
logger.warning(Warnings.W087.format(name=name, listener=listener))
|
||||
if listener not in frozen_components and name in frozen_components:
|
||||
logger.warn(f"Component '{listener}' will be (re)trained, but it needs the "
|
||||
f"'{name}' which is frozen. "
|
||||
f"You should either freeze both, or neither of the two.")
|
||||
logger.warning(Warnings.W086.format(name=name, listener=listener))
|
||||
return nlp
|
||||
|
||||
|
||||
|
|
|
@ -97,10 +97,11 @@ def train(
|
|||
try:
|
||||
for batch, info, is_best_checkpoint in training_step_iterator:
|
||||
log_step(info if is_best_checkpoint is not None else None)
|
||||
if is_best_checkpoint is not None and output_path is not None:
|
||||
if is_best_checkpoint is not None:
|
||||
with nlp.select_pipes(disable=frozen_components):
|
||||
update_meta(T, nlp, info)
|
||||
save_checkpoint(is_best_checkpoint)
|
||||
if output_path is not None:
|
||||
save_checkpoint(is_best_checkpoint)
|
||||
except Exception as e:
|
||||
if output_path is not None:
|
||||
stdout.write(
|
||||
|
@ -113,7 +114,8 @@ def train(
|
|||
raise e
|
||||
finally:
|
||||
finalize_logger()
|
||||
save_checkpoint(False)
|
||||
if output_path is not None:
|
||||
save_checkpoint(False)
|
||||
# This will only run if we did't hit an error
|
||||
if optimizer.averages:
|
||||
nlp.use_params(optimizer.averages)
|
||||
|
@ -257,6 +259,7 @@ def create_evaluation_callback(
|
|||
weights = {key: value for key, value in weights.items() if value is not None}
|
||||
|
||||
def evaluate() -> Tuple[float, Dict[str, float]]:
|
||||
nonlocal weights
|
||||
try:
|
||||
scores = nlp.evaluate(dev_corpus(nlp))
|
||||
except KeyError as e:
|
||||
|
@ -264,6 +267,8 @@ def create_evaluation_callback(
|
|||
# Calculate a weighted sum based on score_weights for the main score.
|
||||
# We can only consider scores that are ints/floats, not dicts like
|
||||
# entity scores per type etc.
|
||||
scores = {key: value for key, value in scores.items() if value is not None}
|
||||
weights = {key: value for key, value in weights.items() if key in scores}
|
||||
for key, value in scores.items():
|
||||
if key in weights and not isinstance(value, (int, float)):
|
||||
raise ValueError(Errors.E915.format(name=key, score_type=type(value)))
|
||||
|
|
Loading…
Reference in New Issue
Block a user