mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 09:14:32 +03:00
Remove output dirs before training (#6204)
* Remove output dirs before training * Re-raise error if cleaning fails
This commit is contained in:
parent
3ee3649b52
commit
be99f1e4de
|
@ -456,6 +456,10 @@ class Errors:
|
|||
"issue tracker: http://github.com/explosion/spaCy/issues")
|
||||
|
||||
# TODO: fix numbering after merging develop into master
|
||||
E901 = ("Failed to remove existing output directory: {path}. If your "
|
||||
"config and the components you train change between runs, a "
|
||||
"non-empty output directory can lead to stale pipeline data. To "
|
||||
"solve this, remove the existing directories in the output directory.")
|
||||
E902 = ("The sentence-per-line IOB/IOB2 file is not formatted correctly. "
|
||||
"Try checking whitespace and delimiters. See "
|
||||
"https://nightly.spacy.io/api/cli#convert")
|
||||
|
|
|
@ -3,19 +3,24 @@ from typing import Optional, TYPE_CHECKING
|
|||
from pathlib import Path
|
||||
from timeit import default_timer as timer
|
||||
from thinc.api import Optimizer, Config, constant, fix_random_seed, set_gpu_allocator
|
||||
from wasabi import Printer
|
||||
import random
|
||||
import wasabi
|
||||
import sys
|
||||
import shutil
|
||||
|
||||
from .example import Example
|
||||
from ..schemas import ConfigSchemaTraining
|
||||
from ..errors import Errors
|
||||
from ..util import resolve_dot_names, registry
|
||||
from ..util import resolve_dot_names, registry, logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..language import Language # noqa: F401
|
||||
|
||||
|
||||
DIR_MODEL_BEST = "model-best"
|
||||
DIR_MODEL_LAST = "model-last"
|
||||
|
||||
|
||||
def train(
|
||||
nlp: "Language",
|
||||
output_path: Optional[Path] = None,
|
||||
|
@ -38,7 +43,7 @@ def train(
|
|||
RETURNS (Path / None): The path to the final exported model.
|
||||
"""
|
||||
# We use no_print here so we can respect the stdout/stderr options.
|
||||
msg = wasabi.Printer(no_print=True)
|
||||
msg = Printer(no_print=True)
|
||||
# Create iterator, which yields out info after each optimization step.
|
||||
config = nlp.config.interpolate()
|
||||
if config["training"]["seed"] is not None:
|
||||
|
@ -69,6 +74,7 @@ def train(
|
|||
eval_frequency=T["eval_frequency"],
|
||||
exclude=frozen_components,
|
||||
)
|
||||
clean_output_dir(output_path)
|
||||
stdout.write(msg.info(f"Pipeline: {nlp.pipe_names}") + "\n")
|
||||
if frozen_components:
|
||||
stdout.write(msg.info(f"Frozen components: {frozen_components}") + "\n")
|
||||
|
@ -83,7 +89,7 @@ def train(
|
|||
update_meta(T, nlp, info)
|
||||
with nlp.use_params(optimizer.averages):
|
||||
nlp = before_to_disk(nlp)
|
||||
nlp.to_disk(output_path / "model-best")
|
||||
nlp.to_disk(output_path / DIR_MODEL_BEST)
|
||||
except Exception as e:
|
||||
if output_path is not None:
|
||||
# We don't want to swallow the traceback if we don't have a
|
||||
|
@ -100,7 +106,7 @@ def train(
|
|||
finally:
|
||||
finalize_logger()
|
||||
if output_path is not None:
|
||||
final_model_path = output_path / "model-last"
|
||||
final_model_path = output_path / DIR_MODEL_LAST
|
||||
if optimizer.averages:
|
||||
with nlp.use_params(optimizer.averages):
|
||||
nlp.to_disk(final_model_path)
|
||||
|
@ -305,3 +311,19 @@ def create_before_to_disk_callback(
|
|||
return modified_nlp
|
||||
|
||||
return before_to_disk
|
||||
|
||||
|
||||
def clean_output_dir(path: Optional[Union[str, Path]]) -> None:
    """Remove the model directories inside an existing output directory.
    Typically used to ensure that a directory like model-best and its contents
    aren't just being overwritten by nlp.to_disk, which could preserve existing
    subdirectories (e.g. components that don't exist anymore).

    path (Optional[Union[str, Path]]): The output directory. If it is None or
        doesn't exist, nothing is removed.
    RAISES (IOError): If an existing model directory can't be removed. The
        original exception is attached as the cause.
    """
    if path is None:
        return
    # The signature accepts plain strings, so normalize before relying on the
    # Path-only API (.exists() and the "/" operator) below.
    path = Path(path)
    if not path.exists():
        return
    for subdir in (path / DIR_MODEL_BEST, path / DIR_MODEL_LAST):
        if not subdir.exists():
            continue
        # Keep only the removal itself in the try block, so that a failure
        # elsewhere (e.g. in logging) isn't reported as a failed removal.
        try:
            shutil.rmtree(str(subdir))
        except Exception as e:
            raise IOError(Errors.E901.format(path=path)) from e
        logger.debug(f"Removed existing output directory: {subdir}")
|
||||
|
|
Loading…
Reference in New Issue
Block a user