Mirror of https://github.com/explosion/spaCy.git
Synced 2024-12-25 17:36:30 +03:00

Add eval_beam_widths argument to spacy train
This commit is contained in:
parent b13b2aeb54
commit daa8c3787a

@@ -58,6 +58,7 @@ from .. import about
         str,
     ),
     noise_level=("Amount of corruption for data augmentation", "option", "nl", float),
+    eval_beam_widths=("Beam widths to evaluate, e.g. 4,8", "option", "bw", str),
     gold_preproc=("Use gold preprocessing", "flag", "G", bool),
     learn_tokens=("Make parser learn gold-standard tokenization", "flag", "T", bool),
     verbose=("Display more information for debug", "flag", "VV", bool),

@@ -81,6 +82,7 @@ def train(
     parser_multitasks="",
     entity_multitasks="",
     noise_level=0.0,
+    eval_beam_widths="",
     gold_preproc=False,
     learn_tokens=False,
     verbose=False,

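For context, spaCy's CLI at this point is built on plac: the annotation tuple added above supplies the help text, the argument kind ("option"), the short flag ("bw"), and the value type (str), and the raw comma-separated string is passed through to train() with "" as the default. A minimal standalone sketch of that mechanism, assuming plac is installed; the demo function and file name are hypothetical, not part of the commit:

# demo_bw.py -- hypothetical, shows how a plac "option" annotation like the
# one added above surfaces as a command-line flag.
import plac


@plac.annotations(
    eval_beam_widths=("Beam widths to evaluate, e.g. 4,8", "option", "bw", str),
)
def demo(eval_beam_widths=""):
    # plac hands over whatever follows -bw as a raw string, e.g. "4,8";
    # the empty-string default means no extra beam widths were requested.
    print(repr(eval_beam_widths))


if __name__ == "__main__":
    plac.call(demo)  # e.g. `python demo_bw.py -bw 4,8` prints '4,8'
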
@@ -134,6 +136,14 @@ def train(
         util.env_opt("batch_compound", 1.001),
     )

+    if not eval_beam_widths:
+        eval_beam_widths = [1]
+    else:
+        eval_beam_widths = [int(bw) for bw in eval_beam_widths.split(",")]
+        if 1 not in eval_beam_widths:
+            eval_beam_widths.append(1)
+        eval_beam_widths.sort()
+
     # Set up the base model and pipeline. If a base model is specified, load
     # the model and make sure the pipeline matches the pipeline setting. If
     # training starts from a blank model, intitalize the language class.

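The block added above normalizes the option: an empty string falls back to [1] (greedy decoding only), while a value such as "4,8" is split on commas, converted to ints, forced to include 1, and sorted. A small standalone sketch of the same behaviour; the helper name is hypothetical, the commit inlines this logic in train():

def parse_beam_widths(eval_beam_widths=""):
    # Mirrors the normalization added in train(): always evaluate with
    # beam width 1 (greedy), plus any widths requested on the command line.
    if not eval_beam_widths:
        return [1]
    widths = [int(bw) for bw in eval_beam_widths.split(",")]
    if 1 not in widths:
        widths.append(1)
    widths.sort()
    return widths


if __name__ == "__main__":
    print(parse_beam_widths(""))     # [1]
    print(parse_beam_widths("4,8"))  # [1, 4, 8]
    print(parse_beam_widths("8,4"))  # [1, 4, 8]
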
@@ -247,7 +257,7 @@ def train(
             epoch_model_path = output_path / ("model%d" % i)
             nlp.to_disk(epoch_model_path)
             nlp_loaded = util.load_model_from_path(epoch_model_path)
-            for beam_width in [1, 4, 16, 128]:
+            for beam_width in eval_beam_widths:
                 for name, component in nlp_loaded.pipeline:
                     if hasattr(component, "cfg"):
                         component.cfg["beam_width"] = beam_width

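The final hunk swaps the previously hard-coded widths [1, 4, 16, 128] for the user-supplied list: after each epoch the freshly saved model is reloaded and, for every requested beam width, each pipeline component that exposes a cfg dict gets its "beam_width" set before the dev set is scored. A hedged sketch of that per-width configuration step in isolation, assuming a spaCy 2.x model directory at a hypothetical path:

from pathlib import Path

from spacy import util

# Hypothetical values -- in the commit these come from spacy train itself.
epoch_model_path = Path("/tmp/model0")  # spacy train writes one model%d dir per epoch
eval_beam_widths = [1, 4, 8]

nlp_loaded = util.load_model_from_path(epoch_model_path)
for beam_width in eval_beam_widths:
    # In spaCy 2.x, nlp.pipeline is a list of (name, component) pairs;
    # beam-aware components (parser, NER) read "beam_width" from their cfg.
    for name, component in nlp_loaded.pipeline:
        if hasattr(component, "cfg"):
            component.cfg["beam_width"] = beam_width
    # ...the training loop then scores the dev set with this configuration
    # (e.g. via nlp_loaded.evaluate) so each beam width gets its own results.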