diff --git a/spacy/cli/train.py b/spacy/cli/train.py index c6ada957f..743fec9ea 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -58,6 +58,7 @@ from .. import about str, ), noise_level=("Amount of corruption for data augmentation", "option", "nl", float), + eval_beam_widths=("Beam widths to evaluate, e.g. 4,8", "option", "bw", str), gold_preproc=("Use gold preprocessing", "flag", "G", bool), learn_tokens=("Make parser learn gold-standard tokenization", "flag", "T", bool), verbose=("Display more information for debug", "flag", "VV", bool), @@ -81,6 +82,7 @@ def train( parser_multitasks="", entity_multitasks="", noise_level=0.0, + eval_beam_widths="", gold_preproc=False, learn_tokens=False, verbose=False, @@ -134,6 +136,15 @@ def train( util.env_opt("batch_compound", 1.001), ) + if not eval_beam_widths: + eval_beam_widths = [1] + else: + eval_beam_widths = [int(bw) for bw in eval_beam_widths.split(",")] + if 1 not in eval_beam_widths: + eval_beam_widths.append(1) + eval_beam_widths.sort() + has_beam_widths = eval_beam_widths != [1] + # Set up the base model and pipeline. If a base model is specified, load # the model and make sure the pipeline matches the pipeline setting. If # training starts from a blank model, intitalize the language class. @@ -200,12 +211,12 @@ def train( msg.text("Loaded pretrained tok2vec for: {}".format(components)) # fmt: off - row_head = ("Itn", "Dep Loss", "NER Loss", "UAS", "NER P", "NER R", "NER F", "Tag %", "Token %", "CPU WPS", "GPU WPS") - row_settings = { - "widths": (3, 10, 10, 7, 7, 7, 7, 7, 7, 7, 7), - "aligns": tuple(["r" for i in row_head]), - "spacing": 2 - } + row_head = ["Itn", "Dep Loss", "NER Loss", "UAS", "NER P", "NER R", "NER F", "Tag %", "Token %", "CPU WPS", "GPU WPS"] + row_widths = [3, 10, 10, 7, 7, 7, 7, 7, 7, 7, 7] + if has_beam_widths: + row_head.insert(1, "Beam W.") + row_widths.insert(1, 7) + row_settings = {"widths": row_widths, "aligns": tuple(["r" for i in row_head]), "spacing": 2} # fmt: on print("") msg.row(row_head, **row_settings) @@ -247,51 +258,76 @@ def train( epoch_model_path = output_path / ("model%d" % i) nlp.to_disk(epoch_model_path) nlp_loaded = util.load_model_from_path(epoch_model_path) - dev_docs = list(corpus.dev_docs(nlp_loaded, gold_preproc=gold_preproc)) - nwords = sum(len(doc_gold[0]) for doc_gold in dev_docs) - start_time = timer() - scorer = nlp_loaded.evaluate(dev_docs, debug) - end_time = timer() - if use_gpu < 0: - gpu_wps = None - cpu_wps = nwords / (end_time - start_time) - else: - gpu_wps = nwords / (end_time - start_time) - with Model.use_device("cpu"): - nlp_loaded = util.load_model_from_path(epoch_model_path) - dev_docs = list( - corpus.dev_docs(nlp_loaded, gold_preproc=gold_preproc) - ) - start_time = timer() - scorer = nlp_loaded.evaluate(dev_docs) - end_time = timer() + for beam_width in eval_beam_widths: + for name, component in nlp_loaded.pipeline: + if hasattr(component, "cfg"): + component.cfg["beam_width"] = beam_width + dev_docs = list( + corpus.dev_docs(nlp_loaded, gold_preproc=gold_preproc) + ) + nwords = sum(len(doc_gold[0]) for doc_gold in dev_docs) + start_time = timer() + scorer = nlp_loaded.evaluate(dev_docs, debug) + end_time = timer() + if use_gpu < 0: + gpu_wps = None cpu_wps = nwords / (end_time - start_time) - acc_loc = output_path / ("model%d" % i) / "accuracy.json" - srsly.write_json(acc_loc, scorer.scores) + else: + gpu_wps = nwords / (end_time - start_time) + with Model.use_device("cpu"): + nlp_loaded = util.load_model_from_path(epoch_model_path) + nlp_loaded.parser.cfg["beam_width"] + dev_docs = list( + corpus.dev_docs(nlp_loaded, gold_preproc=gold_preproc) + ) + start_time = timer() + scorer = nlp_loaded.evaluate(dev_docs) + end_time = timer() + cpu_wps = nwords / (end_time - start_time) + acc_loc = output_path / ("model%d" % i) / "accuracy.json" + srsly.write_json(acc_loc, scorer.scores) - # Update model meta.json - meta["lang"] = nlp.lang - meta["pipeline"] = nlp.pipe_names - meta["spacy_version"] = ">=%s" % about.__version__ - meta["accuracy"] = scorer.scores - meta["speed"] = {"nwords": nwords, "cpu": cpu_wps, "gpu": gpu_wps} - meta["vectors"] = { - "width": nlp.vocab.vectors_length, - "vectors": len(nlp.vocab.vectors), - "keys": nlp.vocab.vectors.n_keys, - "name": nlp.vocab.vectors.name - } - meta.setdefault("name", "model%d" % i) - meta.setdefault("version", version) - meta_loc = output_path / ("model%d" % i) / "meta.json" - srsly.write_json(meta_loc, meta) + # Update model meta.json + meta["lang"] = nlp.lang + meta["pipeline"] = nlp.pipe_names + meta["spacy_version"] = ">=%s" % about.__version__ + if beam_width == 1: + meta["speed"] = { + "nwords": nwords, + "cpu": cpu_wps, + "gpu": gpu_wps, + } + meta["accuracy"] = scorer.scores + else: + meta.setdefault("beam_accuracy", {}) + meta.setdefault("beam_speed", {}) + meta["beam_accuracy"][beam_width] = scorer.scores + meta["beam_speed"][beam_width] = { + "nwords": nwords, + "cpu": cpu_wps, + "gpu": gpu_wps, + } + meta["vectors"] = { + "width": nlp.vocab.vectors_length, + "vectors": len(nlp.vocab.vectors), + "keys": nlp.vocab.vectors.n_keys, + "name": nlp.vocab.vectors.name, + } + meta.setdefault("name", "model%d" % i) + meta.setdefault("version", version) + meta_loc = output_path / ("model%d" % i) / "meta.json" + srsly.write_json(meta_loc, meta) + util.set_env_log(verbose) - util.set_env_log(verbose) - - progress = _get_progress( - i, losses, scorer.scores, cpu_wps=cpu_wps, gpu_wps=gpu_wps - ) - msg.row(progress, **row_settings) + progress = _get_progress( + i, + losses, + scorer.scores, + beam_width=beam_width if has_beam_widths else None, + cpu_wps=cpu_wps, + gpu_wps=gpu_wps, + ) + msg.row(progress, **row_settings) finally: with nlp.use_params(optimizer.averages): final_model_path = output_path / "model-final" @@ -377,7 +413,7 @@ def _get_metrics(component): return ("token_acc",) -def _get_progress(itn, losses, dev_scores, cpu_wps=0.0, gpu_wps=0.0): +def _get_progress(itn, losses, dev_scores, beam_width=None, cpu_wps=0.0, gpu_wps=0.0): scores = {} for col in [ "dep_loss", @@ -398,7 +434,7 @@ def _get_progress(itn, losses, dev_scores, cpu_wps=0.0, gpu_wps=0.0): scores.update(dev_scores) scores["cpu_wps"] = cpu_wps scores["gpu_wps"] = gpu_wps or 0.0 - return [ + result = [ itn, "{:.3f}".format(scores["dep_loss"]), "{:.3f}".format(scores["ner_loss"]), @@ -411,3 +447,6 @@ def _get_progress(itn, losses, dev_scores, cpu_wps=0.0, gpu_wps=0.0): "{:.0f}".format(scores["cpu_wps"]), "{:.0f}".format(scores["gpu_wps"]), ] + if beam_width is not None: + result.insert(1, beam_width) + return result diff --git a/spacy/language.py b/spacy/language.py index ec365e12b..d47ec3f83 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -590,6 +590,8 @@ class Language(object): ): if scorer is None: scorer = Scorer() + if component_cfg is None: + component_cfg = {} docs, golds = zip(*docs_golds) docs = list(docs) golds = list(golds) diff --git a/spacy/syntax/_beam_utils.pxd b/spacy/syntax/_beam_utils.pxd index 7bae17558..36b0c05da 100644 --- a/spacy/syntax/_beam_utils.pxd +++ b/spacy/syntax/_beam_utils.pxd @@ -1,6 +1,9 @@ -from thinc.typedefs cimport class_t +from thinc.typedefs cimport class_t, hash_t # These are passed as callbacks to thinc.search.Beam cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1 cdef int check_final_state(void* _state, void* extra_args) except -1 + + +cdef hash_t hash_state(void* _state, void* _) except 0 diff --git a/spacy/syntax/_beam_utils.pyx b/spacy/syntax/_beam_utils.pyx index f06d54d9d..dc482f278 100644 --- a/spacy/syntax/_beam_utils.pyx +++ b/spacy/syntax/_beam_utils.pyx @@ -96,7 +96,7 @@ cdef class ParserBeam(object): self._set_scores(beam, scores[i]) if self.golds is not None: self._set_costs(beam, self.golds[i], follow_gold=follow_gold) - beam.advance(transition_state, NULL, self.moves.c) + beam.advance(transition_state, hash_state, self.moves.c) beam.check_done(check_final_state, NULL) # This handles the non-monotonic stuff for the parser. if beam.is_done and self.golds is not None: @@ -209,10 +209,6 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps, # Track the "maximum violation", to use in the update. for i, violn in enumerate(violns): violn.check_crf(pbeam[i], gbeam[i]) - # Use 'early update' if best gold is way out of contention. - if pbeam[i].loss > 0 and pbeam[i].min_score > (gbeam[i].score * 5.00): - pbeam.dones[i] = True - gbeam.dones[i] = True histories = [] losses = [] for violn in violns: diff --git a/spacy/syntax/_parser_model.pyx b/spacy/syntax/_parser_model.pyx index f664e6a2c..841e33432 100644 --- a/spacy/syntax/_parser_model.pyx +++ b/spacy/syntax/_parser_model.pyx @@ -156,7 +156,7 @@ cdef void cpu_log_loss(float* d_scores, """Do multi-label log loss""" cdef double max_, gmax, Z, gZ best = arg_max_if_gold(scores, costs, is_valid, O) - guess = arg_max_if_valid(scores, is_valid, O) + guess = Vec.arg_max(scores, O) if best == -1 or guess == -1: # These shouldn't happen, but if they do, we want to make sure we don't # cause an OOB access. @@ -166,14 +166,11 @@ cdef void cpu_log_loss(float* d_scores, max_ = scores[guess] gmax = scores[best] for i in range(O): - if is_valid[i]: - Z += exp(scores[i] - max_) - if costs[i] <= costs[best]: - gZ += exp(scores[i] - gmax) + Z += exp(scores[i] - max_) + if costs[i] <= costs[best]: + gZ += exp(scores[i] - gmax) for i in range(O): - if not is_valid[i]: - d_scores[i] = 0. - elif costs[i] <= costs[best]: + if costs[i] <= costs[best]: d_scores[i] = (exp(scores[i]-max_) / Z) - (exp(scores[i]-gmax)/gZ) else: d_scores[i] = exp(scores[i]-max_) / Z diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index 8458dd981..f7938d0a4 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -119,6 +119,8 @@ cdef class Parser: cfg['beam_width'] = util.env_opt('beam_width', 1) if 'beam_density' not in cfg: cfg['beam_density'] = util.env_opt('beam_density', 0.0) + if 'beam_update_prob' not in cfg: + cfg['beam_update_prob'] = util.env_opt('beam_update_prob', 1.0) cfg.setdefault('cnn_maxout_pieces', 3) self.cfg = cfg self.model = model @@ -381,7 +383,7 @@ cdef class Parser: self.moves.set_valid(beam.is_valid[i], state) memcpy(beam.scores[i], c_scores, scores.shape[1] * sizeof(float)) c_scores += scores.shape[1] - beam.advance(_beam_utils.transition_state, NULL, self.moves.c) + beam.advance(_beam_utils.transition_state, _beam_utils.hash_state, self.moves.c) beam.check_done(_beam_utils.check_final_state, NULL) return [b for b in beams if not b.is_done]