diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx
index 1eb03fdb4..c7170c747 100644
--- a/spacy/syntax/parser.pyx
+++ b/spacy/syntax/parser.pyx
@@ -54,8 +54,69 @@ def set_debug(val):
     DEBUG = val
 
 
-def get_templates(*args, **kwargs):
-    return []
+def get_greedy_model_for_batch(tokvecs, TransitionSystem moves, feat_maps, upper_model):
+    is_valid = upper_model.ops.allocate((len(tokvecs), moves.n_moves), dtype='i')
+    costs = upper_model.ops.allocate((len(tokvecs), moves.n_moves), dtype='f')
+    token_ids = upper_model.ops.allocate((len(tokvecs), StateClass.nr_context_tokens()),
+                                         dtype='uint64')
+    cached, backprops = zip(*[lyr.begin_update(tokvecs) for lyr in feat_maps])
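+    # Each feature map is applied to the whole batch of token vectors once,
+    # up front. Inside the greedy loop each state then only gathers and sums
+    # precomputed rows, so no matrix multiplication happens per transition.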
+
+    def forward(states, drop=0.):
+        nonlocal is_valid, costs, token_ids
+        is_valid = is_valid[:len(states)]
+        costs = costs[:len(states)]
+        token_ids = token_ids[:len(states)]
+        for i, state in enumerate(states):
+            state.set_context_tokens(&token_ids[i, 0])
+            moves.set_valid(&is_valid[i, 0], state.c)
+
+        features = sum(cached[f][token_ids[:, f]] for f in range(len(cached)))
+
+        scores, bp_scores = upper_model.begin_update(features, drop=drop)
+        softmaxed = upper_model.ops.softmax(scores)
+        # Renormalize for invalid actions
+        softmaxed *= is_valid
+        softmaxed /= softmaxed.sum(axis=1).reshape((softmaxed.shape[0], 1))
+
+        def backward(golds, sgd=None):
+            for i, (state, gold) in enumerate(zip(states, golds)):
+                moves.set_costs(&is_valid[i, 0], &costs[i, 0],
+                                state, gold)
+            d_scores = upper_model.ops.allocate(scores.shape, dtype='f')
+            set_log_loss(upper_model.ops, d_scores,
+                         scores, is_valid, costs)
+            # Return the token ids with the feature gradients, so the caller
+            # can scatter the gradients back onto the token vectors.
+            return token_ids, bp_scores(d_scores, sgd)
+
+        return softmaxed, backward
+
+    return layerize(forward)
+
+
+def set_log_loss(ops, gradients, scores, is_valid, costs):
+    """Do multi-label log loss"""
+    n = gradients.shape[0]
+    scores = scores * is_valid
+    g_scores = scores * is_valid * (costs <= 0.)
+    exps = ops.xp.exp(scores - scores.max(axis=1).reshape((n, 1)))
+    exps *= is_valid
+    g_exps = ops.xp.exp(g_scores - g_scores.max(axis=1).reshape((n, 1)))
+    g_exps *= costs <= 0.
+    g_exps *= is_valid
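+    # The gradient is the difference between two masked softmaxes: the
+    # distribution over all valid actions, minus the same distribution
+    # restricted to zero-cost (gold-compatible) actions.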
+    gradients[:] = exps / exps.sum(axis=1).reshape((n, 1))
+    gradients -= g_exps / g_exps.sum(axis=1).reshape((n, 1))
+
+
+def transition_batch(TransitionSystem moves, states, scores):
+    cdef StateClass state
+    cdef int guess
+    for state, guess in zip(states, scores.argmax(axis=1)):
+        action = moves.c[guess]
+        action.do(state.c, action.label)
 
 
 cdef class Parser:
@@ -114,10 +175,8 @@ cdef class Parser:
     def build_model(self, width=32, nr_vector=1000, nF=1, nB=1, nS=1, nL=1, nR=1, **_):
         nr_context_tokens = StateClass.nr_context_tokens(nF, nB, nS, nL, nR)
-
-        return build_model_precomputer(
-            build_model(state2vec, width*2, 2, self.moves.n_moves)
-            build_feature_maps(nr_context_tokens, width, nr_vector))
+        self.model = build_model(width*2, 2, self.moves.n_moves)
+        self.feature_maps = build_feature_maps(nr_context_tokens, width, nr_vector)
 
     def __call__(self, Doc tokens):
         """
@@ -129,7 +188,6 @@ cdef class Parser:
             None
         """
         self.parse_batch([tokens])
-        self.moves.finalize_doc(tokens)
 
     def pipe(self, stream, int batch_size=1000, int n_threads=2):
         """
@@ -167,14 +225,20 @@ cdef class Parser:
     def parse_batch(self, docs):
         cdef Doc doc
         cdef StateClass state
-        model, states = self.init_batch(docs)
+        model = get_greedy_model_for_batch([d.tensor for d in docs],
+                                           self.moves, self.feature_maps, self.model)
+        states = [StateClass.init(doc.c, doc.length) for doc in docs]
         todo = list(states)
         while todo:
-            todo = model(todo)
+            scores = model(todo)
+            transition_batch(self.moves, todo, scores)
+            todo = [st for st in states if not st.is_final()]
         for state, doc in zip(states, docs):
             self.moves.finalize_state(state.c)
             for i in range(doc.length):
                 doc.c[i] = state.c._sent[i]
+        for doc in docs:
+            self.moves.finalize_doc(doc)
 
     def update(self, docs, golds, drop=0., sgd=None):
         if isinstance(docs, Doc) and isinstance(golds, GoldParse):
@@ -182,20 +246,19 @@ cdef class Parser:
         for gold in golds:
             self.moves.preprocess_gold(gold)
 
-        model, states = self.init_batch(docs)
+        model = get_greedy_model_for_batch([d.tensor for d in docs],
+                                           self.moves, self.feature_maps, self.model)
+        states = [StateClass.init(doc.c, doc.length) for doc in docs]
 
         d_tokens = [self.model.ops.allocate(d.tensor.shape) for d in docs]
         output = list(d_tokens)
         todo = zip(states, golds, d_tokens)
         while todo:
             states, golds, d_tokens = zip(*todo)
-            states, finish_update = model.begin_update(states)
-            d_state_features = finish_update(golds, sgd=sgd)
-            for i, tok_ids in enumerate(token_ids):
-                for j, tok_i in enumerate(tok_ids):
-                    if tok_i >= 0:
-                        d_tokens[i][tok_i] += d_state_features[i, j]
-
+            scores, finish_update = model.begin_update(states)
+            token_ids, d_state_features = finish_update(golds, sgd=sgd)
+            for i, tok_ids in enumerate(token_ids):
+                d_tokens[i][tok_ids] += d_state_features[i]
+            transition_batch(self.moves, states, scores)
             # Get unfinished states (and their matching gold and token gradients)
             todo = filter(lambda sp: not sp[0].py_is_final(), todo)
         return output, sum(losses)
@@ -245,28 +308,6 @@ cdef class Parser:
         self.cfg.setdefault('extra_labels', []).append(label)
 
 
-def _transition_batch(self, states, scores):
-    cdef StateClass state
-    cdef int guess
-    for state, guess in zip(states, scores.argmax(axis=1)):
-        action = self.moves.c[guess]
-        action.do(state.c, action.label)
-
-def _set_gradient(self, gradients, scores, is_valid, costs):
-    """Do multi-label log loss"""
-    cdef double Z, gZ, max_, g_max
-    n = gradients.shape[0]
-    scores = scores * is_valid
-    g_scores = scores * is_valid * (costs <= 0.)
-    exps = numpy.exp(scores - scores.max(axis=1).reshape((n, 1)))
-    exps *= is_valid
-    g_exps = numpy.exp(g_scores - g_scores.max(axis=1).reshape((n, 1)))
-    g_exps *= costs <= 0.
-    g_exps *= is_valid
-    gradients[:] = exps / exps.sum(axis=1).reshape((n, 1))
-    gradients -= g_exps / g_exps.sum(axis=1).reshape((n, 1))
-
-
 def _begin_update(self, model, states, tokvecs, drop=0.):
     nr_class = self.moves.n_moves
     attr_names = self.model.ops.allocate((2,), dtype='i')