mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 09:14:32 +03:00
Make training.loop return nlp object and path (#6520)
This commit is contained in:
parent
2c27093c5f
commit
6cfa66ed1c
|
@ -28,7 +28,7 @@ def train(
|
|||
use_gpu: int = -1,
|
||||
stdout: IO = sys.stdout,
|
||||
stderr: IO = sys.stderr,
|
||||
) -> None:
|
||||
) -> Tuple["Language", Optional[Path]]:
|
||||
"""Train a pipeline.
|
||||
|
||||
nlp (Language): The initialized nlp object with the full config.
|
||||
|
@ -40,7 +40,7 @@ def train(
|
|||
stderr (file): A second file-like object to write output messages. To disable
|
||||
printing, set to io.StringIO.
|
||||
|
||||
RETURNS (Path / None): The path to the final exported model.
|
||||
RETURNS (tuple): The final nlp object and the path to the exported model.
|
||||
"""
|
||||
# We use no_print here so we can respect the stdout/stderr options.
|
||||
msg = Printer(no_print=True)
|
||||
|
@ -105,17 +105,18 @@ def train(
|
|||
raise e
|
||||
finally:
|
||||
finalize_logger()
|
||||
if optimizer.averages:
|
||||
nlp.use_params(optimizer.averages)
|
||||
if output_path is not None:
|
||||
final_model_path = output_path / DIR_MODEL_LAST
|
||||
if optimizer.averages:
|
||||
with nlp.use_params(optimizer.averages):
|
||||
nlp.to_disk(final_model_path)
|
||||
else:
|
||||
nlp.to_disk(final_model_path)
|
||||
nlp.to_disk(final_model_path)
|
||||
# This will only run if we don't hit an error
|
||||
stdout.write(
|
||||
msg.good("Saved pipeline to output directory", final_model_path) + "\n"
|
||||
)
|
||||
return (nlp, final_model_path)
|
||||
else:
|
||||
return (nlp, None)
|
||||
|
||||
|
||||
def train_while_improving(
|
||||
|
|
Loading…
Reference in New Issue
Block a user