Merge pull request #3416 from explosion/feature/improve-beam

Improve beam search support
Matthew Honnibal, 2019-03-16 18:42:18 +01:00, committed by GitHub
commit 58d562d9b0
6 changed files with 104 additions and 65 deletions


@@ -58,6 +58,7 @@ from .. import about
str,
),
noise_level=("Amount of corruption for data augmentation", "option", "nl", float),
eval_beam_widths=("Beam widths to evaluate, e.g. 4,8", "option", "bw", str),
gold_preproc=("Use gold preprocessing", "flag", "G", bool),
learn_tokens=("Make parser learn gold-standard tokenization", "flag", "T", bool),
verbose=("Display more information for debug", "flag", "VV", bool),
@@ -81,6 +82,7 @@ def train(
parser_multitasks="",
entity_multitasks="",
noise_level=0.0,
eval_beam_widths="",
gold_preproc=False,
learn_tokens=False,
verbose=False,
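The new eval_beam_widths keyword takes the same comma-separated string as the -bw CLI option. A hedged sketch of calling the command programmatically (the positional arguments and paths here are assumptions for illustration, not part of this hunk):

from spacy.cli import train

# Assumed call shape for illustration: language code, output directory, then
# the training and development data; eval_beam_widths mirrors "-bw 4,8".
train("en", "/tmp/model", "train.json", "dev.json", eval_beam_widths="4,8")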
@@ -134,6 +136,15 @@ def train(
util.env_opt("batch_compound", 1.001),
)
if not eval_beam_widths:
eval_beam_widths = [1]
else:
eval_beam_widths = [int(bw) for bw in eval_beam_widths.split(",")]
if 1 not in eval_beam_widths:
eval_beam_widths.append(1)
eval_beam_widths.sort()
has_beam_widths = eval_beam_widths != [1]
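The widths string is split on commas, the greedy width 1 is always kept so the baseline stays comparable, and has_beam_widths records whether extra table columns are needed. A minimal sketch of the same parsing in isolation:

def parse_beam_widths(eval_beam_widths=""):
    # Mirrors the block above: no option means greedy decoding only.
    if not eval_beam_widths:
        return [1], False
    widths = [int(bw) for bw in eval_beam_widths.split(",")]
    if 1 not in widths:          # always evaluate the greedy baseline too
        widths.append(1)
    widths.sort()
    return widths, widths != [1]

# parse_beam_widths("4,8") -> ([1, 4, 8], True)
# parse_beam_widths("")    -> ([1], False)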
# Set up the base model and pipeline. If a base model is specified, load
# the model and make sure the pipeline matches the pipeline setting. If
# training starts from a blank model, initialize the language class.
@@ -200,12 +211,12 @@ def train(
msg.text("Loaded pretrained tok2vec for: {}".format(components))
# fmt: off
row_head = ("Itn", "Dep Loss", "NER Loss", "UAS", "NER P", "NER R", "NER F", "Tag %", "Token %", "CPU WPS", "GPU WPS")
row_settings = {
"widths": (3, 10, 10, 7, 7, 7, 7, 7, 7, 7, 7),
"aligns": tuple(["r" for i in row_head]),
"spacing": 2
}
row_head = ["Itn", "Dep Loss", "NER Loss", "UAS", "NER P", "NER R", "NER F", "Tag %", "Token %", "CPU WPS", "GPU WPS"]
row_widths = [3, 10, 10, 7, 7, 7, 7, 7, 7, 7, 7]
if has_beam_widths:
row_head.insert(1, "Beam W.")
row_widths.insert(1, 7)
row_settings = {"widths": row_widths, "aligns": tuple(["r" for i in row_head]), "spacing": 2}
# fmt: on
print("")
msg.row(row_head, **row_settings)
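Because the headers and widths are plain lists, the optional column is just an insert at index 1. A small self-contained sketch of how the header row renders with wasabi (columns abbreviated; the full header is the row_head above):

from wasabi import Printer

msg = Printer()
row_head = ["Itn", "Dep Loss", "UAS", "CPU WPS"]   # abbreviated for the sketch
row_widths = [3, 10, 7, 7]
has_beam_widths = True                             # e.g. "-bw 4,8" was passed
if has_beam_widths:
    row_head.insert(1, "Beam W.")                  # extra column after "Itn"
    row_widths.insert(1, 7)
row_settings = {"widths": row_widths, "aligns": tuple("r" for _ in row_head), "spacing": 2}
msg.row(row_head, **row_settings)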
@@ -247,51 +258,76 @@ def train(
epoch_model_path = output_path / ("model%d" % i)
nlp.to_disk(epoch_model_path)
nlp_loaded = util.load_model_from_path(epoch_model_path)
dev_docs = list(corpus.dev_docs(nlp_loaded, gold_preproc=gold_preproc))
nwords = sum(len(doc_gold[0]) for doc_gold in dev_docs)
start_time = timer()
scorer = nlp_loaded.evaluate(dev_docs, debug)
end_time = timer()
if use_gpu < 0:
gpu_wps = None
cpu_wps = nwords / (end_time - start_time)
else:
gpu_wps = nwords / (end_time - start_time)
with Model.use_device("cpu"):
nlp_loaded = util.load_model_from_path(epoch_model_path)
dev_docs = list(
corpus.dev_docs(nlp_loaded, gold_preproc=gold_preproc)
)
start_time = timer()
scorer = nlp_loaded.evaluate(dev_docs)
end_time = timer()
for beam_width in eval_beam_widths:
for name, component in nlp_loaded.pipeline:
if hasattr(component, "cfg"):
component.cfg["beam_width"] = beam_width
dev_docs = list(
corpus.dev_docs(nlp_loaded, gold_preproc=gold_preproc)
)
nwords = sum(len(doc_gold[0]) for doc_gold in dev_docs)
start_time = timer()
scorer = nlp_loaded.evaluate(dev_docs, debug)
end_time = timer()
if use_gpu < 0:
gpu_wps = None
cpu_wps = nwords / (end_time - start_time)
acc_loc = output_path / ("model%d" % i) / "accuracy.json"
srsly.write_json(acc_loc, scorer.scores)
else:
gpu_wps = nwords / (end_time - start_time)
with Model.use_device("cpu"):
nlp_loaded = util.load_model_from_path(epoch_model_path)
nlp_loaded.parser.cfg["beam_width"] = beam_width  # re-apply the width after reloading on CPU
dev_docs = list(
corpus.dev_docs(nlp_loaded, gold_preproc=gold_preproc)
)
start_time = timer()
scorer = nlp_loaded.evaluate(dev_docs)
end_time = timer()
cpu_wps = nwords / (end_time - start_time)
acc_loc = output_path / ("model%d" % i) / "accuracy.json"
srsly.write_json(acc_loc, scorer.scores)
# Update model meta.json
meta["lang"] = nlp.lang
meta["pipeline"] = nlp.pipe_names
meta["spacy_version"] = ">=%s" % about.__version__
meta["accuracy"] = scorer.scores
meta["speed"] = {"nwords": nwords, "cpu": cpu_wps, "gpu": gpu_wps}
meta["vectors"] = {
"width": nlp.vocab.vectors_length,
"vectors": len(nlp.vocab.vectors),
"keys": nlp.vocab.vectors.n_keys,
"name": nlp.vocab.vectors.name
}
meta.setdefault("name", "model%d" % i)
meta.setdefault("version", version)
meta_loc = output_path / ("model%d" % i) / "meta.json"
srsly.write_json(meta_loc, meta)
# Update model meta.json
meta["lang"] = nlp.lang
meta["pipeline"] = nlp.pipe_names
meta["spacy_version"] = ">=%s" % about.__version__
if beam_width == 1:
meta["speed"] = {
"nwords": nwords,
"cpu": cpu_wps,
"gpu": gpu_wps,
}
meta["accuracy"] = scorer.scores
else:
meta.setdefault("beam_accuracy", {})
meta.setdefault("beam_speed", {})
meta["beam_accuracy"][beam_width] = scorer.scores
meta["beam_speed"][beam_width] = {
"nwords": nwords,
"cpu": cpu_wps,
"gpu": gpu_wps,
}
meta["vectors"] = {
"width": nlp.vocab.vectors_length,
"vectors": len(nlp.vocab.vectors),
"keys": nlp.vocab.vectors.n_keys,
"name": nlp.vocab.vectors.name,
}
meta.setdefault("name", "model%d" % i)
meta.setdefault("version", version)
meta_loc = output_path / ("model%d" % i) / "meta.json"
srsly.write_json(meta_loc, meta)
util.set_env_log(verbose)
util.set_env_log(verbose)
progress = _get_progress(
i, losses, scorer.scores, cpu_wps=cpu_wps, gpu_wps=gpu_wps
)
msg.row(progress, **row_settings)
progress = _get_progress(
i,
losses,
scorer.scores,
beam_width=beam_width if has_beam_widths else None,
cpu_wps=cpu_wps,
gpu_wps=gpu_wps,
)
msg.row(progress, **row_settings)
finally:
with nlp.use_params(optimizer.averages):
final_model_path = output_path / "model-final"
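The heart of the new hunk is that each saved epoch model is now scored once per requested beam width: the width is pushed into every component's cfg before evaluation, and the greedy (width 1) pass keeps feeding the existing accuracy and speed fields while wider beams are recorded separately. A condensed, hypothetical restatement of that loop (eval_all_beam_widths is not a function in the diff):

def eval_all_beam_widths(nlp_loaded, corpus, eval_beam_widths, gold_preproc=False):
    # Hypothetical condensation of the loop above: set the beam width on every
    # component that exposes a cfg, re-run evaluation, and collect the scores
    # keyed by width. Device handling and timing are omitted here.
    beam_scores = {}
    for beam_width in eval_beam_widths:
        for name, component in nlp_loaded.pipeline:
            if hasattr(component, "cfg"):
                component.cfg["beam_width"] = beam_width
        dev_docs = list(corpus.dev_docs(nlp_loaded, gold_preproc=gold_preproc))
        scorer = nlp_loaded.evaluate(dev_docs)
        beam_scores[beam_width] = scorer.scores
    return beam_scores

In the meta.json written above, the width-1 results stay under the existing "accuracy" and "speed" keys, and every other width is filed under "beam_accuracy" / "beam_speed" keyed by the width, so existing consumers of meta.json are unaffected.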
@@ -377,7 +413,7 @@ def _get_metrics(component):
return ("token_acc",)
def _get_progress(itn, losses, dev_scores, cpu_wps=0.0, gpu_wps=0.0):
def _get_progress(itn, losses, dev_scores, beam_width=None, cpu_wps=0.0, gpu_wps=0.0):
scores = {}
for col in [
"dep_loss",
@@ -398,7 +434,7 @@ def _get_progress(itn, losses, dev_scores, cpu_wps=0.0, gpu_wps=0.0):
scores.update(dev_scores)
scores["cpu_wps"] = cpu_wps
scores["gpu_wps"] = gpu_wps or 0.0
return [
result = [
itn,
"{:.3f}".format(scores["dep_loss"]),
"{:.3f}".format(scores["ner_loss"]),
@@ -411,3 +447,6 @@ def _get_progress(itn, losses, dev_scores, cpu_wps=0.0, gpu_wps=0.0):
"{:.0f}".format(scores["cpu_wps"]),
"{:.0f}".format(scores["gpu_wps"]),
]
if beam_width is not None:
result.insert(1, beam_width)
return result
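When beam_width is not None it is spliced into the row at index 1, lining up with the optional "Beam W." header column; passing None keeps the original eleven-column row. For example:

row = _get_progress(3, losses, scorer.scores, beam_width=8,
                    cpu_wps=cpu_wps, gpu_wps=gpu_wps)
# row == [3, 8, "<dep_loss>", "<ner_loss>", "<uas>", "<ner_p>", "<ner_r>",
#         "<ner_f>", "<tag_acc>", "<token_acc>", "<cpu_wps>", "<gpu_wps>"]
# i.e. the beam width sits at index 1, under the "Beam W." header column;
# beam_width=None gives the original 11-cell row.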


@@ -590,6 +590,8 @@ class Language(object):
):
if scorer is None:
scorer = Scorer()
if component_cfg is None:
component_cfg = {}
docs, golds = zip(*docs_golds)
docs = list(docs)
golds = list(golds)
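The added guard lets callers omit component_cfg entirely. Assuming the rest of evaluate forwards these per-component settings to each pipe (only the default-to-{} guard is visible in this hunk), a beam width can be requested for a single evaluation without touching component.cfg. A hedged usage sketch reusing the names from the training loop above:

# Hedged sketch: the component names and the "beam_width" key are assumptions
# about how evaluate() forwards component_cfg; only the guard is in this diff.
dev_docs = list(corpus.dev_docs(nlp_loaded, gold_preproc=False))
scorer = nlp_loaded.evaluate(
    dev_docs,
    component_cfg={"parser": {"beam_width": 8}, "ner": {"beam_width": 8}},
)
print(scorer.scores)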


@@ -1,6 +1,9 @@
from thinc.typedefs cimport class_t
from thinc.typedefs cimport class_t, hash_t
# These are passed as callbacks to thinc.search.Beam
cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1
cdef int check_final_state(void* _state, void* extra_args) except -1
cdef hash_t hash_state(void* _state, void* _) except 0
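hash_state is the new callback passed to thinc.search.Beam alongside transition_state and check_final_state; giving the beam a way to hash parser states lets it recognise candidates that reach an equivalent state through different action sequences. As a rough Python analogy only (not thinc's actual merging logic), assuming equivalent states hash equal:

def keep_best_per_state(candidates):
    # candidates: iterable of (state, score) pairs proposed for the next step.
    # Keep the best-scoring candidate for each distinct state hash, which is
    # roughly what a hashable state allows the beam to exploit as it advances.
    best = {}
    for state, score in candidates:
        key = hash(state)
        if key not in best or score > best[key][1]:
            best[key] = (state, score)
    return sorted(best.values(), key=lambda item: item[1], reverse=True)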


@@ -96,7 +96,7 @@ cdef class ParserBeam(object):
self._set_scores(beam, scores[i])
if self.golds is not None:
self._set_costs(beam, self.golds[i], follow_gold=follow_gold)
beam.advance(transition_state, NULL, <void*>self.moves.c)
beam.advance(transition_state, hash_state, <void*>self.moves.c)
beam.check_done(check_final_state, NULL)
# This handles the non-monotonic stuff for the parser.
if beam.is_done and self.golds is not None:
@@ -209,10 +209,6 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
# Track the "maximum violation", to use in the update.
for i, violn in enumerate(violns):
violn.check_crf(pbeam[i], gbeam[i])
# Use 'early update' if best gold is way out of contention.
if pbeam[i].loss > 0 and pbeam[i].min_score > (gbeam[i].score * 5.00):
pbeam.dones[i] = True
gbeam.dones[i] = True
histories = []
losses = []
for violn in violns:
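The deleted block was an "early update" shortcut: whenever the predicted beam already had a loss and even its worst surviving candidate outscored the gold beam by a wide margin, both beams were marked done so the update happened straight away. After this commit the update relies solely on the max-violation tracking done by violn.check_crf above. The removed condition, restated as a sketch for reference:

def early_update_triggered(pred_beam_loss, pred_beam_min_score, gold_beam_score,
                           margin=5.0):
    # The heuristic this commit removes: give up on the gold beam once it is
    # far out of contention relative to the predicted beam's worst survivor.
    return pred_beam_loss > 0 and pred_beam_min_score > gold_beam_score * margin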


@@ -156,7 +156,7 @@ cdef void cpu_log_loss(float* d_scores,
"""Do multi-label log loss"""
cdef double max_, gmax, Z, gZ
best = arg_max_if_gold(scores, costs, is_valid, O)
guess = arg_max_if_valid(scores, is_valid, O)
guess = Vec.arg_max(scores, O)
if best == -1 or guess == -1:
# These shouldn't happen, but if they do, we want to make sure we don't
# cause an OOB access.
@@ -166,14 +166,11 @@ cdef void cpu_log_loss(float* d_scores,
max_ = scores[guess]
gmax = scores[best]
for i in range(O):
if is_valid[i]:
Z += exp(scores[i] - max_)
if costs[i] <= costs[best]:
gZ += exp(scores[i] - gmax)
Z += exp(scores[i] - max_)
if costs[i] <= costs[best]:
gZ += exp(scores[i] - gmax)
for i in range(O):
if not is_valid[i]:
d_scores[i] = 0.
elif costs[i] <= costs[best]:
if costs[i] <= costs[best]:
d_scores[i] = (exp(scores[i]-max_) / Z) - (exp(scores[i]-gmax)/gZ)
else:
d_scores[i] = exp(scores[i]-max_) / Z
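The reworked loss drops the is_valid masking: the partition function Z now sums over every action and the gradient is no longer zeroed for invalid ones, while the gold normaliser gZ still covers only the zero-extra-cost actions. A NumPy restatement of the new gradient, as a sketch (arg_max_if_gold is approximated by the best-scoring minimum-cost action; the initialisation of Z and gZ falls outside the visible hunk):

import numpy as np

def log_loss_grad(scores, costs):
    """Gradient of the multi-label log loss, mirroring the new cpu_log_loss."""
    gold = costs <= costs.min()                       # zero-extra-cost actions
    best = np.where(gold, scores, -np.inf).argmax()   # best-scoring gold action
    max_, gmax = scores.max(), scores[best]
    Z = np.exp(scores - max_).sum()                   # over *all* actions now
    gZ = np.exp(scores[gold] - gmax).sum()            # over gold actions only
    d_scores = np.exp(scores - max_) / Z
    d_scores[gold] -= np.exp(scores[gold] - gmax) / gZ
    return d_scores

# e.g. log_loss_grad(np.array([2.0, 1.0, 0.5]), np.array([1.0, 0.0, 0.0]))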


@@ -119,6 +119,8 @@ cdef class Parser:
cfg['beam_width'] = util.env_opt('beam_width', 1)
if 'beam_density' not in cfg:
cfg['beam_density'] = util.env_opt('beam_density', 0.0)
if 'beam_update_prob' not in cfg:
cfg['beam_update_prob'] = util.env_opt('beam_update_prob', 1.0)
cfg.setdefault('cnn_maxout_pieces', 3)
self.cfg = cfg
self.model = model
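beam_update_prob joins beam_width and beam_density as a parser cfg entry that can be overridden from the environment. A simplified sketch of the defaulting pattern, assuming util.env_opt reads a SPACY_-prefixed environment variable (e.g. SPACY_BEAM_UPDATE_PROB) and otherwise returns the default; how the probability is consumed during training is outside this hunk:

import os

def env_opt(name, default):
    # Simplified stand-in for spacy.util.env_opt: an environment variable
    # overrides the default, cast to the default's type.
    raw = os.environ.get("SPACY_" + name.upper())
    return default if raw is None else type(default)(raw)

cfg = {}
cfg.setdefault("beam_width", env_opt("beam_width", 1))
cfg.setdefault("beam_density", env_opt("beam_density", 0.0))
cfg.setdefault("beam_update_prob", env_opt("beam_update_prob", 1.0))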
@@ -381,7 +383,7 @@ cdef class Parser:
self.moves.set_valid(beam.is_valid[i], state)
memcpy(beam.scores[i], c_scores, scores.shape[1] * sizeof(float))
c_scores += scores.shape[1]
beam.advance(_beam_utils.transition_state, NULL, <void*>self.moves.c)
beam.advance(_beam_utils.transition_state, _beam_utils.hash_state, <void*>self.moves.c)
beam.check_done(_beam_utils.check_final_state, NULL)
return [b for b in beams if not b.is_done]