mirror of
https://github.com/explosion/spaCy.git
synced 2025-04-27 04:13:41 +03:00
Make training.loop return nlp object and path (#6520)
This commit is contained in:
parent
2c27093c5f
commit
6cfa66ed1c
|
@ -28,7 +28,7 @@ def train(
|
||||||
use_gpu: int = -1,
|
use_gpu: int = -1,
|
||||||
stdout: IO = sys.stdout,
|
stdout: IO = sys.stdout,
|
||||||
stderr: IO = sys.stderr,
|
stderr: IO = sys.stderr,
|
||||||
) -> None:
|
) -> Tuple["Language", Optional[Path]]:
|
||||||
"""Train a pipeline.
|
"""Train a pipeline.
|
||||||
|
|
||||||
nlp (Language): The initialized nlp object with the full config.
|
nlp (Language): The initialized nlp object with the full config.
|
||||||
|
@ -40,7 +40,7 @@ def train(
|
||||||
stderr (file): A second file-like object to write output messages. To disable
|
stderr (file): A second file-like object to write output messages. To disable
|
||||||
printing, set to io.StringIO.
|
printing, set to io.StringIO.
|
||||||
|
|
||||||
RETURNS (Path / None): The path to the final exported model.
|
RETURNS (tuple): The final nlp object and the path to the exported model.
|
||||||
"""
|
"""
|
||||||
# We use no_print here so we can respect the stdout/stderr options.
|
# We use no_print here so we can respect the stdout/stderr options.
|
||||||
msg = Printer(no_print=True)
|
msg = Printer(no_print=True)
|
||||||
|
@ -105,17 +105,18 @@ def train(
|
||||||
raise e
|
raise e
|
||||||
finally:
|
finally:
|
||||||
finalize_logger()
|
finalize_logger()
|
||||||
|
if optimizer.averages:
|
||||||
|
nlp.use_params(optimizer.averages)
|
||||||
if output_path is not None:
|
if output_path is not None:
|
||||||
final_model_path = output_path / DIR_MODEL_LAST
|
final_model_path = output_path / DIR_MODEL_LAST
|
||||||
if optimizer.averages:
|
nlp.to_disk(final_model_path)
|
||||||
with nlp.use_params(optimizer.averages):
|
|
||||||
nlp.to_disk(final_model_path)
|
|
||||||
else:
|
|
||||||
nlp.to_disk(final_model_path)
|
|
||||||
# This will only run if we don't hit an error
|
# This will only run if we don't hit an error
|
||||||
stdout.write(
|
stdout.write(
|
||||||
msg.good("Saved pipeline to output directory", final_model_path) + "\n"
|
msg.good("Saved pipeline to output directory", final_model_path) + "\n"
|
||||||
)
|
)
|
||||||
|
return (nlp, final_model_path)
|
||||||
|
else:
|
||||||
|
return (nlp, None)
|
||||||
|
|
||||||
|
|
||||||
def train_while_improving(
|
def train_while_improving(
|
||||||
|
|
Loading…
Reference in New Issue
Block a user