diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx
index 68301238d..7b7a35700 100644
--- a/spacy/syntax/nn_parser.pyx
+++ b/spacy/syntax/nn_parser.pyx
@@ -1,6 +1,7 @@
 # cython: infer_types=True
 # cython: cdivision=True
 # cython: boundscheck=False
+# cython: profile=True
 # coding: utf-8
 from __future__ import unicode_literals, print_function
 
@@ -322,15 +323,17 @@ cdef class Parser:
         beam_density = self.cfg.get('beam_density', 0.0)
         cdef Beam beam
         if beam_width == 1:
-            states = self.parse_batch([doc])
-            self.set_annotations([doc], states)
+            states, tokvecs = self.parse_batch([doc])
+            self.set_annotations([doc], states, tensors=tokvecs)
             return doc
         else:
-            beam = self.beam_parse([doc],
-                beam_width=beam_width, beam_density=beam_density)[0]
+            beams, tokvecs = self.beam_parse([doc],
+                                             beam_width=beam_width,
+                                             beam_density=beam_density)
+            beam = beams[0]
             output = self.moves.get_beam_annot(beam)
             state = beam.at(0)
-            self.set_annotations([doc], [state])
+            self.set_annotations([doc], [state], tensors=tokvecs)
             _cleanup(beam)
             return output
 
@@ -356,15 +359,16 @@ cdef class Parser:
             for subbatch in cytoolz.partition_all(8, by_length):
                 subbatch = list(subbatch)
                 if beam_width == 1:
-                    parse_states = self.parse_batch(subbatch)
+                    parse_states, tokvecs = self.parse_batch(subbatch)
                     beams = []
                 else:
-                    beams = self.beam_parse(subbatch, beam_width=beam_width,
-                        beam_density=beam_density)
+                    beams, tokvecs = self.beam_parse(subbatch,
+                                                     beam_width=beam_width,
+                                                     beam_density=beam_density)
                     parse_states = []
                     for beam in beams:
                         parse_states.append(beam.at(0))
-                self.set_annotations(subbatch, parse_states)
+                self.set_annotations(subbatch, parse_states, tensors=tokvecs)
                 yield from batch
 
     def parse_batch(self, docs):
@@ -411,7 +415,9 @@ cdef class Parser:
                     feat_weights, bias, hW, hb,
                     nr_class, nr_hidden, nr_feat, nr_piece)
         PyErr_CheckSignals()
-        return state_objs
+        tokvecs = self.model[0].ops.unflatten(tokvecs,
+                                              [len(doc) for doc in docs])
+        return state_objs, tokvecs
 
     cdef void _parseC(self, StateC* state,
             const float* feat_weights, const float* bias,
@@ -508,7 +514,9 @@ cdef class Parser:
                 beam.advance(_transition_state, _hash_state, self.moves.c)
                 beam.check_done(_check_final_state, NULL)
             beams.append(beam)
-        return beams
+        tokvecs = self.model[0].ops.unflatten(tokvecs,
+                                              [len(doc) for doc in docs])
+        return beams, tokvecs
 
     def update(self, docs, golds, drop=0., sgd=None, losses=None):
         if not any(self.moves.has_gold(gold) for gold in golds):
@@ -730,13 +738,15 @@ cdef class Parser:
             c_d_scores += d_scores.shape[1]
         return d_scores
 
-    def set_annotations(self, docs, states):
+    def set_annotations(self, docs, states, tensors=None):
         cdef StateClass state
         cdef Doc doc
-        for state, doc in zip(states, docs):
+        for i, (state, doc) in enumerate(zip(states, docs)):
             self.moves.finalize_state(state.c)
-            for i in range(doc.length):
-                doc.c[i] = state.c._sent[i]
+            for j in range(doc.length):
+                doc.c[j] = state.c._sent[j]
+            if tensors is not None:
+                doc.extend_tensor(tensors[i])
             self.moves.finalize_doc(doc)
         for hook in self.postprocesses:
             for doc in docs:
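
For context, a minimal sketch of what the new unflatten step does before the token vectors reach set_annotations: the batch's stacked token vectors (one row per token across all docs) are split back into one array per Doc, so each piece can be attached to that Doc's tensor via extend_tensor. This mirrors the behaviour of thinc's Ops.unflatten used above; the standalone helper and the shapes below are illustrative assumptions, not the library code.

    import numpy

    def unflatten(flat, lengths):
        # Split a (total_tokens, width) array into consecutive per-doc pieces.
        pieces = []
        start = 0
        for length in lengths:
            pieces.append(flat[start:start + length])
            start += length
        return pieces

    tokvecs = numpy.zeros((7, 128), dtype='float32')  # 7 tokens across the whole batch
    lengths = [3, 4]                                  # e.g. two docs with len(doc) == 3 and 4
    per_doc = unflatten(tokvecs, lengths)
    assert [piece.shape[0] for piece in per_doc] == lengths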