mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
Remove output dirs before training (#6204)
* Remove output dirs before training * Re-raise error if cleaning fails
This commit is contained in:
parent
3ee3649b52
commit
be99f1e4de
|
@ -456,6 +456,10 @@ class Errors:
|
||||||
"issue tracker: http://github.com/explosion/spaCy/issues")
|
"issue tracker: http://github.com/explosion/spaCy/issues")
|
||||||
|
|
||||||
# TODO: fix numbering after merging develop into master
|
# TODO: fix numbering after merging develop into master
|
||||||
|
E901 = ("Failed to remove existing output directory: {path}. If your "
|
||||||
|
"config and the components you train change between runs, a "
|
||||||
|
"non-empty output directory can lead to stale pipeline data. To "
|
||||||
|
"solve this, remove the existing directories in the output directory.")
|
||||||
E902 = ("The sentence-per-line IOB/IOB2 file is not formatted correctly. "
|
E902 = ("The sentence-per-line IOB/IOB2 file is not formatted correctly. "
|
||||||
"Try checking whitespace and delimiters. See "
|
"Try checking whitespace and delimiters. See "
|
||||||
"https://nightly.spacy.io/api/cli#convert")
|
"https://nightly.spacy.io/api/cli#convert")
|
||||||
|
|
|
@ -3,19 +3,24 @@ from typing import Optional, TYPE_CHECKING
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from timeit import default_timer as timer
|
from timeit import default_timer as timer
|
||||||
from thinc.api import Optimizer, Config, constant, fix_random_seed, set_gpu_allocator
|
from thinc.api import Optimizer, Config, constant, fix_random_seed, set_gpu_allocator
|
||||||
|
from wasabi import Printer
|
||||||
import random
|
import random
|
||||||
import wasabi
|
|
||||||
import sys
|
import sys
|
||||||
|
import shutil
|
||||||
|
|
||||||
from .example import Example
|
from .example import Example
|
||||||
from ..schemas import ConfigSchemaTraining
|
from ..schemas import ConfigSchemaTraining
|
||||||
from ..errors import Errors
|
from ..errors import Errors
|
||||||
from ..util import resolve_dot_names, registry
|
from ..util import resolve_dot_names, registry, logger
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..language import Language # noqa: F401
|
from ..language import Language # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
|
DIR_MODEL_BEST = "model-best"
|
||||||
|
DIR_MODEL_LAST = "model-last"
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
nlp: "Language",
|
nlp: "Language",
|
||||||
output_path: Optional[Path] = None,
|
output_path: Optional[Path] = None,
|
||||||
|
@ -38,7 +43,7 @@ def train(
|
||||||
RETURNS (Path / None): The path to the final exported model.
|
RETURNS (Path / None): The path to the final exported model.
|
||||||
"""
|
"""
|
||||||
# We use no_print here so we can respect the stdout/stderr options.
|
# We use no_print here so we can respect the stdout/stderr options.
|
||||||
msg = wasabi.Printer(no_print=True)
|
msg = Printer(no_print=True)
|
||||||
# Create iterator, which yields out info after each optimization step.
|
# Create iterator, which yields out info after each optimization step.
|
||||||
config = nlp.config.interpolate()
|
config = nlp.config.interpolate()
|
||||||
if config["training"]["seed"] is not None:
|
if config["training"]["seed"] is not None:
|
||||||
|
@ -69,6 +74,7 @@ def train(
|
||||||
eval_frequency=T["eval_frequency"],
|
eval_frequency=T["eval_frequency"],
|
||||||
exclude=frozen_components,
|
exclude=frozen_components,
|
||||||
)
|
)
|
||||||
|
clean_output_dir(output_path)
|
||||||
stdout.write(msg.info(f"Pipeline: {nlp.pipe_names}") + "\n")
|
stdout.write(msg.info(f"Pipeline: {nlp.pipe_names}") + "\n")
|
||||||
if frozen_components:
|
if frozen_components:
|
||||||
stdout.write(msg.info(f"Frozen components: {frozen_components}") + "\n")
|
stdout.write(msg.info(f"Frozen components: {frozen_components}") + "\n")
|
||||||
|
@ -83,7 +89,7 @@ def train(
|
||||||
update_meta(T, nlp, info)
|
update_meta(T, nlp, info)
|
||||||
with nlp.use_params(optimizer.averages):
|
with nlp.use_params(optimizer.averages):
|
||||||
nlp = before_to_disk(nlp)
|
nlp = before_to_disk(nlp)
|
||||||
nlp.to_disk(output_path / "model-best")
|
nlp.to_disk(output_path / DIR_MODEL_BEST)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if output_path is not None:
|
if output_path is not None:
|
||||||
# We don't want to swallow the traceback if we don't have a
|
# We don't want to swallow the traceback if we don't have a
|
||||||
|
@ -100,7 +106,7 @@ def train(
|
||||||
finally:
|
finally:
|
||||||
finalize_logger()
|
finalize_logger()
|
||||||
if output_path is not None:
|
if output_path is not None:
|
||||||
final_model_path = output_path / "model-last"
|
final_model_path = output_path / DIR_MODEL_LAST
|
||||||
if optimizer.averages:
|
if optimizer.averages:
|
||||||
with nlp.use_params(optimizer.averages):
|
with nlp.use_params(optimizer.averages):
|
||||||
nlp.to_disk(final_model_path)
|
nlp.to_disk(final_model_path)
|
||||||
|
@ -305,3 +311,19 @@ def create_before_to_disk_callback(
|
||||||
return modified_nlp
|
return modified_nlp
|
||||||
|
|
||||||
return before_to_disk
|
return before_to_disk
|
||||||
|
|
||||||
|
|
||||||
|
def clean_output_dir(path: Union[str, Path]) -> None:
|
||||||
|
"""Remove an existing output directory. Typically used to ensure that that
|
||||||
|
a directory like model-best and its contents aren't just being overwritten
|
||||||
|
by nlp.to_disk, which could preserve existing subdirectories (e.g.
|
||||||
|
components that don't exist anymore).
|
||||||
|
"""
|
||||||
|
if path is not None and path.exists():
|
||||||
|
for subdir in [path / DIR_MODEL_BEST, path / DIR_MODEL_LAST]:
|
||||||
|
if subdir.exists():
|
||||||
|
try:
|
||||||
|
shutil.rmtree(str(subdir))
|
||||||
|
logger.debug(f"Removed existing output directory: {subdir}")
|
||||||
|
except Exception as e:
|
||||||
|
raise IOError(Errors.E901.format(path=path)) from e
|
||||||
|
|
Loading…
Reference in New Issue
Block a user