diff --git a/spacy/syntax/beam_parser.pyx b/spacy/syntax/beam_parser.pyx new file mode 100644 index 000000000..447bbc811 --- /dev/null +++ b/spacy/syntax/beam_parser.pyx @@ -0,0 +1,259 @@ +# cython: profile=True +# cython: experimental_cpp_class_def=True +# cython: cdivision=True +# cython: infer_types=True +""" +MALT-style dependency parser +""" +from __future__ import unicode_literals +cimport cython + +from cpython.ref cimport PyObject, Py_INCREF, Py_XDECREF + +from libc.stdint cimport uint32_t, uint64_t +from libc.string cimport memset, memcpy +from libc.stdlib cimport rand +from libc.math cimport log, exp, isnan, isinf +import random +import os.path +from os import path +import shutil +import json +import math + +from cymem.cymem cimport Pool, Address +from murmurhash.mrmr cimport real_hash64 as hash64 +from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t + + +from util import Config + +from thinc.linear.features cimport ConjunctionExtracter +from thinc.structs cimport FeatureC, ExampleC + +from thinc.extra.search cimport Beam +from thinc.extra.search cimport MaxViolation +from thinc.extra.eg cimport Example +from thinc.extra.mb cimport Minibatch + +from ..structs cimport TokenC + +from ..tokens.doc cimport Doc +from ..strings cimport StringStore + +from .transition_system cimport TransitionSystem, Transition + +from ..gold cimport GoldParse + +from . import _parse_features +from ._parse_features cimport CONTEXT_SIZE +from ._parse_features cimport fill_context +from .stateclass cimport StateClass +from .parser cimport Parser + + +DEBUG = False +def set_debug(val): + global DEBUG + DEBUG = val + + +def get_templates(name): + pf = _parse_features + if name == 'ner': + return pf.ner + elif name == 'debug': + return pf.unigrams + else: + return (pf.unigrams + pf.s0_n0 + pf.s1_n0 + pf.s1_s0 + pf.s0_n1 + pf.n0_n1 + \ + pf.tree_shape + pf.trigrams) + + +cdef int BEAM_WIDTH = 16 +cdef weight_t BEAM_DENSITY = 0.01 + +cdef class BeamParser(Parser): + def __init__(self, *args, **kwargs): + self.beam_width = kwargs.get('beam_width', BEAM_WIDTH) + self.beam_density = kwargs.get('beam_density', BEAM_DENSITY) + Parser.__init__(self, *args, **kwargs) + + cdef int parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) with gil: + self._parseC(tokens, length, nr_feat, nr_class) + + cdef int _parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) except -1: + cdef Beam beam = Beam(self.moves.n_moves, self.beam_width, min_density=self.beam_density) + beam.initialize(_init_state, length, tokens) + beam.check_done(_check_final_state, NULL) + if beam.is_done: + _cleanup(beam) + return 0 + while not beam.is_done: + self._advance_beam(beam, None, False) + state = beam.at(0) + self.moves.finalize_state(state.c) + for i in range(length): + tokens[i] = state.c._sent[i] + _cleanup(beam) + + def train(self, Doc tokens, GoldParse gold_parse, itn=0): + self.moves.preprocess_gold(gold_parse) + cdef Beam pred = Beam(self.moves.n_moves, self.beam_width) + pred.initialize(_init_state, tokens.length, tokens.c) + pred.check_done(_check_final_state, NULL) + + cdef Beam gold = Beam(self.moves.n_moves, self.beam_width, min_density=0.0) + gold.initialize(_init_state, tokens.length, tokens.c) + gold.check_done(_check_final_state, NULL) + violn = MaxViolation() + itn = 0 + while not pred.is_done and not gold.is_done: + # We search separately here, to allow for ambiguity in the gold parse. + self._advance_beam(pred, gold_parse, False) + self._advance_beam(gold, gold_parse, True) + violn.check_crf(pred, gold) + if pred.loss > 0 and pred.min_score > (gold.score + self.model.time): + break + itn += 1 + else: + # The non-monotonic oracle makes it difficult to ensure final costs are + # correct. Therefore do final correction + for i in range(pred.size): + if is_gold(pred.at(i), gold_parse, self.moves.strings): + pred._states[i].loss = 0.0 + elif pred._states[i].loss == 0.0: + pred._states[i].loss = 1.0 + violn.check_crf(pred, gold) + assert pred.size >= 1 + assert gold.size >= 1 + #_check_train_integrity(pred, gold, gold_parse, self.moves) + histories = zip(violn.p_probs, violn.p_hist) + zip(violn.g_probs, violn.g_hist) + min_grad = 0.001 ** (itn+1) + histories = [(grad, hist) for grad, hist in histories if abs(grad) >= min_grad] + random.shuffle(histories) + for grad, hist in histories: + assert not math.isnan(grad) and not math.isinf(grad), hist + self.model._update_from_history(self.moves, tokens, hist, grad) + _cleanup(pred) + _cleanup(gold) + return pred.loss + + def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold): + cdef atom_t[CONTEXT_SIZE] context + cdef Pool mem = Pool() + features = mem.alloc(self.model.nr_feat, sizeof(FeatureC)) + if False: + mb = Minibatch(self.model.widths, beam.size) + for i in range(beam.size): + stcls = beam.at(i) + if stcls.c.is_final(): + nr_feat = 0 + else: + nr_feat = self.model.set_featuresC(context, features, stcls.c) + self.moves.set_valid(beam.is_valid[i], stcls.c) + mb.c.push_back(features, nr_feat, beam.costs[i], beam.is_valid[i], 0) + self.model(mb) + for i in range(beam.size): + memcpy(beam.scores[i], mb.c.scores(i), mb.c.nr_out() * sizeof(beam.scores[i][0])) + else: + for i in range(beam.size): + stcls = beam.at(i) + if not stcls.is_final(): + nr_feat = self.model.set_featuresC(context, features, stcls.c) + self.moves.set_valid(beam.is_valid[i], stcls.c) + self.model.set_scoresC(beam.scores[i], features, nr_feat) + if gold is not None: + for i in range(beam.size): + stcls = beam.at(i) + if not stcls.c.is_final(): + self.moves.set_costs(beam.is_valid[i], beam.costs[i], stcls, gold) + if follow_gold: + for j in range(self.moves.n_moves): + beam.is_valid[i][j] *= beam.costs[i][j] < 1 + if follow_gold: + beam.advance(_transition_state, NULL, self.moves.c) + else: + beam.advance(_transition_state, NULL, self.moves.c) + beam.check_done(_check_final_state, NULL) + + +# These are passed as callbacks to thinc.search.Beam +cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1: + dest = _dest + src = _src + moves = _moves + dest.clone(src) + moves[clas].do(dest.c, moves[clas].label) + + +cdef void* _init_state(Pool mem, int length, void* tokens) except NULL: + cdef StateClass st = StateClass.init(tokens, length) + # Ensure sent_start is set to 0 throughout + for i in range(st.c.length): + st.c._sent[i].sent_start = False + st.c._sent[i].l_edge = i + st.c._sent[i].r_edge = i + st.fast_forward() + Py_INCREF(st) + return st + + +cdef int _check_final_state(void* _state, void* extra_args) except -1: + return (_state).is_final() + + +def _cleanup(Beam beam): + for i in range(beam.width): + Py_XDECREF(beam._states[i].content) + Py_XDECREF(beam._parents[i].content) + + +cdef hash_t _hash_state(void* _state, void* _) except 0: + state = _state + if state.c.is_final(): + return 1 + else: + return state.c.hash() + + +def _check_train_integrity(Beam pred, Beam gold, GoldParse gold_parse, TransitionSystem moves): + for i in range(pred.size): + if not pred._states[i].is_done or pred._states[i].loss == 0: + continue + state = pred.at(i) + if is_gold(state, gold_parse, moves.strings) == True: + for dep in gold_parse.orig_annot: + print(dep[1], dep[3], dep[4]) + print("Cost", pred._states[i].loss) + for j in range(gold_parse.length): + print(gold_parse.orig_annot[j][1], state.H(j), moves.strings[state.safe_get(j).dep]) + acts = [moves.c[clas].move for clas in pred.histories[i]] + labels = [moves.c[clas].label for clas in pred.histories[i]] + print([moves.move_name(move, label) for move, label in zip(acts, labels)]) + raise Exception("Predicted state is gold-standard") + for i in range(gold.size): + if not gold._states[i].is_done: + continue + state = gold.at(i) + if is_gold(state, gold_parse, moves.strings) == False: + print("Truth") + for dep in gold_parse.orig_annot: + print(dep[1], dep[3], dep[4]) + print("Predicted good") + for j in range(gold_parse.length): + print(gold_parse.orig_annot[j][1], state.H(j), moves.strings[state.safe_get(j).dep]) + raise Exception("Gold parse is not gold-standard") + + +def is_gold(StateClass state, GoldParse gold, StringStore strings): + predicted = set() + truth = set() + for i in range(gold.length): + if state.safe_get(i).dep: + predicted.add((i, state.H(i), strings[state.safe_get(i).dep])) + else: + predicted.add((i, state.H(i), 'ROOT')) + id_, word, tag, head, dep, ner = gold.orig_annot[i] + truth.add((id_, head, dep)) + return truth == predicted + diff --git a/spacy/syntax/parser.pxd b/spacy/syntax/parser.pxd index 4370f5c6f..aaed10303 100644 --- a/spacy/syntax/parser.pxd +++ b/spacy/syntax/parser.pxd @@ -1,6 +1,6 @@ from thinc.linear.avgtron cimport AveragedPerceptron -from thinc.extra.eg cimport Example -from thinc.structs cimport ExampleC +from thinc.typedefs cimport atom_t +from thinc.structs cimport FeatureC from .stateclass cimport StateClass from .arc_eager cimport TransitionSystem @@ -11,7 +11,8 @@ from ._state cimport StateC cdef class ParserModel(AveragedPerceptron): - cdef void set_featuresC(self, ExampleC* eg, const StateC* state) nogil + cdef int set_featuresC(self, atom_t* context, FeatureC* features, + const StateC* state) nogil cdef class Parser: @@ -20,4 +21,4 @@ cdef class Parser: cdef readonly TransitionSystem moves cdef readonly object cfg - cdef int parseC(self, TokenC* tokens, int length, int nr_feat) nogil + cdef int parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) with gil diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index 093186518..410dbc029 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -27,7 +27,10 @@ from thinc.linalg cimport VecVec from thinc.structs cimport SparseArrayC from preshed.maps cimport MapStruct from preshed.maps cimport map_get + from thinc.structs cimport FeatureC +from thinc.structs cimport ExampleC +from thinc.extra.eg cimport Example from util import Config @@ -68,9 +71,43 @@ def get_templates(name): cdef class ParserModel(AveragedPerceptron): - cdef void set_featuresC(self, ExampleC* eg, const StateC* state) nogil: - fill_context(eg.atoms, state) - eg.nr_feat = self.extracter.set_features(eg.features, eg.atoms) + cdef int set_featuresC(self, atom_t* context, FeatureC* features, + const StateC* state) nogil: + fill_context(context, state) + nr_feat = self.extracter.set_features(features, context) + return nr_feat + + def update(self, Example eg): + '''Does regression on negative cost. Sort of cute?''' + self.time += 1 + cdef weight_t loss = 0.0 + best = arg_max_if_gold(eg.c.scores, eg.c.costs, eg.c.nr_class) + for clas in range(eg.c.nr_class): + if not eg.c.is_valid[clas]: + continue + if eg.c.scores[clas] < eg.c.scores[best]: + continue + loss += (-eg.c.costs[clas] - eg.c.scores[clas]) ** 2 + d_loss = -2 * (-eg.c.costs[clas] - eg.c.scores[clas]) + for feat in eg.c.features[:eg.c.nr_feat]: + self.update_weight_ftrl(feat.key, clas, feat.value * d_loss) + return int(loss) + + def update_from_history(self, TransitionSystem moves, Doc doc, history, weight_t grad): + cdef Pool mem = Pool() + features = mem.alloc(self.nr_feat, sizeof(FeatureC)) + + cdef StateClass stcls = StateClass.init(doc.c, doc.length) + moves.initialize_state(stcls.c) + + cdef class_t clas + self.time += 1 + cdef atom_t[CONTEXT_SIZE] atoms + for clas in history: + nr_feat = self.set_featuresC(atoms, features, stcls.c) + for feat in features[:nr_feat]: + self.update_weight(feat.key, clas, feat.value * grad) + moves.c[clas].do(stcls.c, moves.c[clas].label) cdef class Parser: @@ -141,7 +178,7 @@ cdef class Parser: """ cdef int nr_feat = self.model.nr_feat with nogil: - status = self.parseC(tokens.c, tokens.length, nr_feat) + status = self.parseC(tokens.c, tokens.length, nr_feat, self.moves.n_moves) # Check for KeyboardInterrupt etc. Untested PyErr_CheckSignals() if status != 0: @@ -174,7 +211,7 @@ cdef class Parser: if len(queue) == batch_size: with nogil: for i in cython.parallel.prange(batch_size, num_threads=n_threads): - status = self.parseC(doc_ptr[i], lengths[i], nr_feat) + status = self.parseC(doc_ptr[i], lengths[i], nr_feat, self.moves.n_moves) if status != 0: with gil: raise ParserStateError(queue[i]) @@ -186,7 +223,7 @@ cdef class Parser: batch_size = len(queue) with nogil: for i in cython.parallel.prange(batch_size, num_threads=n_threads): - status = self.parseC(doc_ptr[i], lengths[i], nr_feat) + status = self.parseC(doc_ptr[i], lengths[i], nr_feat, self.moves.n_moves) if status != 0: with gil: raise ParserStateError(queue[i]) @@ -195,11 +232,10 @@ cdef class Parser: self.moves.finalize_doc(doc) yield doc - cdef int parseC(self, TokenC* tokens, int length, int nr_feat) nogil: + cdef int parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) with gil: state = new StateC(tokens, length) # NB: This can change self.moves.n_moves! self.moves.initialize_state(state) - nr_class = self.moves.n_moves cdef ExampleC eg eg.nr_feat = nr_feat @@ -211,7 +247,7 @@ cdef class Parser: eg.is_valid = calloc(sizeof(int), nr_class) cdef int i while not state.is_final(): - self.model.set_featuresC(&eg, state) + eg.nr_feat = self.model.set_featuresC(eg.atoms, eg.features, state) self.moves.set_valid(eg.is_valid, state) self.model.set_scoresC(eg.scores, eg.features, eg.nr_feat) @@ -257,16 +293,17 @@ cdef class Parser: cdef weight_t loss = 0 cdef Transition action while not stcls.is_final(): - self.model.set_featuresC(&eg.c, stcls.c) + eg.c.nr_feat = self.model.set_featuresC(eg.c.atoms, eg.c.features, + stcls.c) self.moves.set_costs(eg.c.is_valid, eg.c.costs, stcls, gold) self.model.set_scoresC(eg.c.scores, eg.c.features, eg.c.nr_feat) - self.model.time += 1 guess = VecVec.arg_max_if_true(eg.c.scores, eg.c.is_valid, eg.c.nr_class) if eg.c.costs[guess] > 0: - best = arg_max_if_gold(eg.c.scores, eg.c.costs, eg.c.nr_class) - for feat in eg.c.features[:eg.c.nr_feat]: - self.model.update_weight_ftrl(feat.key, best, -feat.value * eg.c.costs[guess]) - self.model.update_weight_ftrl(feat.key, guess, feat.value * eg.c.costs[guess]) + self.model.update(eg) + #best = arg_max_if_gold(eg.c.scores, eg.c.costs, eg.c.nr_class) + #for feat in eg.c.features[:eg.c.nr_feat]: + # self.model.update_weight_ftrl(feat.key, best, -feat.value * eg.c.costs[guess]) + # self.model.update_weight_ftrl(feat.key, guess, feat.value * eg.c.costs[guess]) action = self.moves.c[guess] action.do(stcls.c, action.label) @@ -350,7 +387,8 @@ cdef class StepwiseState: def predict(self): self.eg.reset() - self.parser.model.set_featuresC(&self.eg.c, self.stcls.c) + self.eg.c.nr_feat = self.parser.model.set_featuresC(self.eg.c.atoms, self.eg.c.features, + self.stcls.c) self.parser.moves.set_valid(self.eg.c.is_valid, self.stcls.c) self.parser.model.set_scoresC(self.eg.c.scores, self.eg.c.features, self.eg.c.nr_feat)