Merge pull request #3416 from explosion/feature/improve-beam

Improve beam search support
This commit is contained in:
Matthew Honnibal 2019-03-16 18:42:18 +01:00 committed by GitHub
commit 58d562d9b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 104 additions and 65 deletions

View File

@ -58,6 +58,7 @@ from .. import about
str, str,
), ),
noise_level=("Amount of corruption for data augmentation", "option", "nl", float), noise_level=("Amount of corruption for data augmentation", "option", "nl", float),
eval_beam_widths=("Beam widths to evaluate, e.g. 4,8", "option", "bw", str),
gold_preproc=("Use gold preprocessing", "flag", "G", bool), gold_preproc=("Use gold preprocessing", "flag", "G", bool),
learn_tokens=("Make parser learn gold-standard tokenization", "flag", "T", bool), learn_tokens=("Make parser learn gold-standard tokenization", "flag", "T", bool),
verbose=("Display more information for debug", "flag", "VV", bool), verbose=("Display more information for debug", "flag", "VV", bool),
@ -81,6 +82,7 @@ def train(
parser_multitasks="", parser_multitasks="",
entity_multitasks="", entity_multitasks="",
noise_level=0.0, noise_level=0.0,
eval_beam_widths="",
gold_preproc=False, gold_preproc=False,
learn_tokens=False, learn_tokens=False,
verbose=False, verbose=False,
@ -134,6 +136,15 @@ def train(
util.env_opt("batch_compound", 1.001), util.env_opt("batch_compound", 1.001),
) )
if not eval_beam_widths:
eval_beam_widths = [1]
else:
eval_beam_widths = [int(bw) for bw in eval_beam_widths.split(",")]
if 1 not in eval_beam_widths:
eval_beam_widths.append(1)
eval_beam_widths.sort()
has_beam_widths = eval_beam_widths != [1]
# Set up the base model and pipeline. If a base model is specified, load # Set up the base model and pipeline. If a base model is specified, load
# the model and make sure the pipeline matches the pipeline setting. If # the model and make sure the pipeline matches the pipeline setting. If
# training starts from a blank model, intitalize the language class. # training starts from a blank model, intitalize the language class.
@ -200,12 +211,12 @@ def train(
msg.text("Loaded pretrained tok2vec for: {}".format(components)) msg.text("Loaded pretrained tok2vec for: {}".format(components))
# fmt: off # fmt: off
row_head = ("Itn", "Dep Loss", "NER Loss", "UAS", "NER P", "NER R", "NER F", "Tag %", "Token %", "CPU WPS", "GPU WPS") row_head = ["Itn", "Dep Loss", "NER Loss", "UAS", "NER P", "NER R", "NER F", "Tag %", "Token %", "CPU WPS", "GPU WPS"]
row_settings = { row_widths = [3, 10, 10, 7, 7, 7, 7, 7, 7, 7, 7]
"widths": (3, 10, 10, 7, 7, 7, 7, 7, 7, 7, 7), if has_beam_widths:
"aligns": tuple(["r" for i in row_head]), row_head.insert(1, "Beam W.")
"spacing": 2 row_widths.insert(1, 7)
} row_settings = {"widths": row_widths, "aligns": tuple(["r" for i in row_head]), "spacing": 2}
# fmt: on # fmt: on
print("") print("")
msg.row(row_head, **row_settings) msg.row(row_head, **row_settings)
@ -247,51 +258,76 @@ def train(
epoch_model_path = output_path / ("model%d" % i) epoch_model_path = output_path / ("model%d" % i)
nlp.to_disk(epoch_model_path) nlp.to_disk(epoch_model_path)
nlp_loaded = util.load_model_from_path(epoch_model_path) nlp_loaded = util.load_model_from_path(epoch_model_path)
dev_docs = list(corpus.dev_docs(nlp_loaded, gold_preproc=gold_preproc)) for beam_width in eval_beam_widths:
nwords = sum(len(doc_gold[0]) for doc_gold in dev_docs) for name, component in nlp_loaded.pipeline:
start_time = timer() if hasattr(component, "cfg"):
scorer = nlp_loaded.evaluate(dev_docs, debug) component.cfg["beam_width"] = beam_width
end_time = timer() dev_docs = list(
if use_gpu < 0: corpus.dev_docs(nlp_loaded, gold_preproc=gold_preproc)
gpu_wps = None )
cpu_wps = nwords / (end_time - start_time) nwords = sum(len(doc_gold[0]) for doc_gold in dev_docs)
else: start_time = timer()
gpu_wps = nwords / (end_time - start_time) scorer = nlp_loaded.evaluate(dev_docs, debug)
with Model.use_device("cpu"): end_time = timer()
nlp_loaded = util.load_model_from_path(epoch_model_path) if use_gpu < 0:
dev_docs = list( gpu_wps = None
corpus.dev_docs(nlp_loaded, gold_preproc=gold_preproc)
)
start_time = timer()
scorer = nlp_loaded.evaluate(dev_docs)
end_time = timer()
cpu_wps = nwords / (end_time - start_time) cpu_wps = nwords / (end_time - start_time)
acc_loc = output_path / ("model%d" % i) / "accuracy.json" else:
srsly.write_json(acc_loc, scorer.scores) gpu_wps = nwords / (end_time - start_time)
with Model.use_device("cpu"):
nlp_loaded = util.load_model_from_path(epoch_model_path)
nlp_loaded.parser.cfg["beam_width"]
dev_docs = list(
corpus.dev_docs(nlp_loaded, gold_preproc=gold_preproc)
)
start_time = timer()
scorer = nlp_loaded.evaluate(dev_docs)
end_time = timer()
cpu_wps = nwords / (end_time - start_time)
acc_loc = output_path / ("model%d" % i) / "accuracy.json"
srsly.write_json(acc_loc, scorer.scores)
# Update model meta.json # Update model meta.json
meta["lang"] = nlp.lang meta["lang"] = nlp.lang
meta["pipeline"] = nlp.pipe_names meta["pipeline"] = nlp.pipe_names
meta["spacy_version"] = ">=%s" % about.__version__ meta["spacy_version"] = ">=%s" % about.__version__
meta["accuracy"] = scorer.scores if beam_width == 1:
meta["speed"] = {"nwords": nwords, "cpu": cpu_wps, "gpu": gpu_wps} meta["speed"] = {
meta["vectors"] = { "nwords": nwords,
"width": nlp.vocab.vectors_length, "cpu": cpu_wps,
"vectors": len(nlp.vocab.vectors), "gpu": gpu_wps,
"keys": nlp.vocab.vectors.n_keys, }
"name": nlp.vocab.vectors.name meta["accuracy"] = scorer.scores
} else:
meta.setdefault("name", "model%d" % i) meta.setdefault("beam_accuracy", {})
meta.setdefault("version", version) meta.setdefault("beam_speed", {})
meta_loc = output_path / ("model%d" % i) / "meta.json" meta["beam_accuracy"][beam_width] = scorer.scores
srsly.write_json(meta_loc, meta) meta["beam_speed"][beam_width] = {
"nwords": nwords,
"cpu": cpu_wps,
"gpu": gpu_wps,
}
meta["vectors"] = {
"width": nlp.vocab.vectors_length,
"vectors": len(nlp.vocab.vectors),
"keys": nlp.vocab.vectors.n_keys,
"name": nlp.vocab.vectors.name,
}
meta.setdefault("name", "model%d" % i)
meta.setdefault("version", version)
meta_loc = output_path / ("model%d" % i) / "meta.json"
srsly.write_json(meta_loc, meta)
util.set_env_log(verbose)
util.set_env_log(verbose) progress = _get_progress(
i,
progress = _get_progress( losses,
i, losses, scorer.scores, cpu_wps=cpu_wps, gpu_wps=gpu_wps scorer.scores,
) beam_width=beam_width if has_beam_widths else None,
msg.row(progress, **row_settings) cpu_wps=cpu_wps,
gpu_wps=gpu_wps,
)
msg.row(progress, **row_settings)
finally: finally:
with nlp.use_params(optimizer.averages): with nlp.use_params(optimizer.averages):
final_model_path = output_path / "model-final" final_model_path = output_path / "model-final"
@ -377,7 +413,7 @@ def _get_metrics(component):
return ("token_acc",) return ("token_acc",)
def _get_progress(itn, losses, dev_scores, cpu_wps=0.0, gpu_wps=0.0): def _get_progress(itn, losses, dev_scores, beam_width=None, cpu_wps=0.0, gpu_wps=0.0):
scores = {} scores = {}
for col in [ for col in [
"dep_loss", "dep_loss",
@ -398,7 +434,7 @@ def _get_progress(itn, losses, dev_scores, cpu_wps=0.0, gpu_wps=0.0):
scores.update(dev_scores) scores.update(dev_scores)
scores["cpu_wps"] = cpu_wps scores["cpu_wps"] = cpu_wps
scores["gpu_wps"] = gpu_wps or 0.0 scores["gpu_wps"] = gpu_wps or 0.0
return [ result = [
itn, itn,
"{:.3f}".format(scores["dep_loss"]), "{:.3f}".format(scores["dep_loss"]),
"{:.3f}".format(scores["ner_loss"]), "{:.3f}".format(scores["ner_loss"]),
@ -411,3 +447,6 @@ def _get_progress(itn, losses, dev_scores, cpu_wps=0.0, gpu_wps=0.0):
"{:.0f}".format(scores["cpu_wps"]), "{:.0f}".format(scores["cpu_wps"]),
"{:.0f}".format(scores["gpu_wps"]), "{:.0f}".format(scores["gpu_wps"]),
] ]
if beam_width is not None:
result.insert(1, beam_width)
return result

View File

@ -590,6 +590,8 @@ class Language(object):
): ):
if scorer is None: if scorer is None:
scorer = Scorer() scorer = Scorer()
if component_cfg is None:
component_cfg = {}
docs, golds = zip(*docs_golds) docs, golds = zip(*docs_golds)
docs = list(docs) docs = list(docs)
golds = list(golds) golds = list(golds)

View File

@ -1,6 +1,9 @@
from thinc.typedefs cimport class_t from thinc.typedefs cimport class_t, hash_t
# These are passed as callbacks to thinc.search.Beam # These are passed as callbacks to thinc.search.Beam
cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1 cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1
cdef int check_final_state(void* _state, void* extra_args) except -1 cdef int check_final_state(void* _state, void* extra_args) except -1
cdef hash_t hash_state(void* _state, void* _) except 0

View File

@ -96,7 +96,7 @@ cdef class ParserBeam(object):
self._set_scores(beam, scores[i]) self._set_scores(beam, scores[i])
if self.golds is not None: if self.golds is not None:
self._set_costs(beam, self.golds[i], follow_gold=follow_gold) self._set_costs(beam, self.golds[i], follow_gold=follow_gold)
beam.advance(transition_state, NULL, <void*>self.moves.c) beam.advance(transition_state, hash_state, <void*>self.moves.c)
beam.check_done(check_final_state, NULL) beam.check_done(check_final_state, NULL)
# This handles the non-monotonic stuff for the parser. # This handles the non-monotonic stuff for the parser.
if beam.is_done and self.golds is not None: if beam.is_done and self.golds is not None:
@ -209,10 +209,6 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
# Track the "maximum violation", to use in the update. # Track the "maximum violation", to use in the update.
for i, violn in enumerate(violns): for i, violn in enumerate(violns):
violn.check_crf(pbeam[i], gbeam[i]) violn.check_crf(pbeam[i], gbeam[i])
# Use 'early update' if best gold is way out of contention.
if pbeam[i].loss > 0 and pbeam[i].min_score > (gbeam[i].score * 5.00):
pbeam.dones[i] = True
gbeam.dones[i] = True
histories = [] histories = []
losses = [] losses = []
for violn in violns: for violn in violns:

View File

@ -156,7 +156,7 @@ cdef void cpu_log_loss(float* d_scores,
"""Do multi-label log loss""" """Do multi-label log loss"""
cdef double max_, gmax, Z, gZ cdef double max_, gmax, Z, gZ
best = arg_max_if_gold(scores, costs, is_valid, O) best = arg_max_if_gold(scores, costs, is_valid, O)
guess = arg_max_if_valid(scores, is_valid, O) guess = Vec.arg_max(scores, O)
if best == -1 or guess == -1: if best == -1 or guess == -1:
# These shouldn't happen, but if they do, we want to make sure we don't # These shouldn't happen, but if they do, we want to make sure we don't
# cause an OOB access. # cause an OOB access.
@ -166,14 +166,11 @@ cdef void cpu_log_loss(float* d_scores,
max_ = scores[guess] max_ = scores[guess]
gmax = scores[best] gmax = scores[best]
for i in range(O): for i in range(O):
if is_valid[i]: Z += exp(scores[i] - max_)
Z += exp(scores[i] - max_) if costs[i] <= costs[best]:
if costs[i] <= costs[best]: gZ += exp(scores[i] - gmax)
gZ += exp(scores[i] - gmax)
for i in range(O): for i in range(O):
if not is_valid[i]: if costs[i] <= costs[best]:
d_scores[i] = 0.
elif costs[i] <= costs[best]:
d_scores[i] = (exp(scores[i]-max_) / Z) - (exp(scores[i]-gmax)/gZ) d_scores[i] = (exp(scores[i]-max_) / Z) - (exp(scores[i]-gmax)/gZ)
else: else:
d_scores[i] = exp(scores[i]-max_) / Z d_scores[i] = exp(scores[i]-max_) / Z

View File

@ -119,6 +119,8 @@ cdef class Parser:
cfg['beam_width'] = util.env_opt('beam_width', 1) cfg['beam_width'] = util.env_opt('beam_width', 1)
if 'beam_density' not in cfg: if 'beam_density' not in cfg:
cfg['beam_density'] = util.env_opt('beam_density', 0.0) cfg['beam_density'] = util.env_opt('beam_density', 0.0)
if 'beam_update_prob' not in cfg:
cfg['beam_update_prob'] = util.env_opt('beam_update_prob', 1.0)
cfg.setdefault('cnn_maxout_pieces', 3) cfg.setdefault('cnn_maxout_pieces', 3)
self.cfg = cfg self.cfg = cfg
self.model = model self.model = model
@ -381,7 +383,7 @@ cdef class Parser:
self.moves.set_valid(beam.is_valid[i], state) self.moves.set_valid(beam.is_valid[i], state)
memcpy(beam.scores[i], c_scores, scores.shape[1] * sizeof(float)) memcpy(beam.scores[i], c_scores, scores.shape[1] * sizeof(float))
c_scores += scores.shape[1] c_scores += scores.shape[1]
beam.advance(_beam_utils.transition_state, NULL, <void*>self.moves.c) beam.advance(_beam_utils.transition_state, _beam_utils.hash_state, <void*>self.moves.c)
beam.check_done(_beam_utils.check_final_state, NULL) beam.check_done(_beam_utils.check_final_state, NULL)
return [b for b in beams if not b.is_done] return [b for b in beams if not b.is_done]