diff --git a/spacy/_ml.py b/spacy/_ml.py index 5cb974910..bb228da3f 100644 --- a/spacy/_ml.py +++ b/spacy/_ml.py @@ -1,4 +1,4 @@ -from thinc.api import layerize, chain, clone, concatenate +from thinc.api import layerize, chain, clone, concatenate, with_flatten from thinc.neural import Model, Maxout, Softmax from thinc.neural._classes.hash_embed import HashEmbed @@ -10,88 +10,137 @@ from .attrs import ID, PREFIX, SUFFIX, SHAPE, TAG, DEP def get_col(idx): def forward(X, drop=0.): - return Model.ops.xp.ascontiguousarray(X[:, idx]), None + output = Model.ops.xp.ascontiguousarray(X[:, idx]) + return output, None return layerize(forward) def build_model(state2vec, width, depth, nr_class): with Model.define_operators({'>>': chain, '**': clone}): - model = state2vec >> Maxout(width) ** depth >> Softmax(nr_class) + model = ( + state2vec + >> Maxout(width, 1344) + >> Maxout(width, width) + >> Softmax(nr_class, width) + ) return model def build_parser_state2vec(width, nr_vector=1000, nF=1, nB=0, nS=1, nL=2, nR=2): - embed_tags = _reshape(chain(get_col(0), HashEmbed(width, nr_vector))) - embed_deps = _reshape(chain(get_col(1), HashEmbed(width, nr_vector))) + embed_tags = _reshape(chain(get_col(0), HashEmbed(16, nr_vector))) + embed_deps = _reshape(chain(get_col(1), HashEmbed(16, nr_vector))) ops = embed_tags.ops - attr_names = ops.asarray([TAG, DEP], dtype='i') - extract = build_feature_extractor(attr_names, nF, nB, nS, nL, nR) - def forward(states, drop=0.): - tokens, attr_vals, tokvecs = extract(states) + def forward(tokens_attrs_vectors, drop=0.): + tokens, attr_vals, tokvecs = tokens_attrs_vectors tagvecs, bp_tagvecs = embed_deps.begin_update(attr_vals, drop=drop) depvecs, bp_depvecs = embed_tags.begin_update(attr_vals, drop=drop) - + orig_tokvecs_shape = tokvecs.shape tokvecs = tokvecs.reshape((tokvecs.shape[0], tokvecs.shape[1] * tokvecs.shape[2])) - vector = ops.concatenate((tagvecs, depvecs, tokvecs)) - shapes = (tagvecs.shape, depvecs.shape, tokvecs.shape) + assert tagvecs.shape[0] == depvecs.shape[0] == tokvecs.shape[0], shapes + vector = ops.xp.hstack((tagvecs, depvecs, tokvecs)) + def backward(d_vector, sgd=None): - d_depvecs, d_tagvecs, d_tokvecs = ops.backprop_concatenate(d_vector, shapes) + d_tagvecs, d_depvecs, d_tokvecs = backprop_concatenate(d_vector, shapes) + assert d_tagvecs.shape == shapes[0], (d_tagvecs.shape, shapes) + assert d_depvecs.shape == shapes[1], (d_depvecs.shape, shapes) + assert d_tokvecs.shape == shapes[2], (d_tokvecs.shape, shapes) bp_tagvecs(d_tagvecs) bp_depvecs(d_depvecs) - d_tokvecs = d_tokvecs.reshape((len(states), tokens.shape[1], tokvecs.shape[2])) - return (d_tokvecs, tokens) + d_tokvecs = d_tokvecs.reshape(orig_tokvecs_shape) + + return (tokens, d_tokvecs) return vector, backward model = layerize(forward) model._layers = [embed_tags, embed_deps] return model -def build_feature_extractor(attr_names, nF, nB, nS, nL, nR): - def forward(states, drop=0.): - ops = model.ops - n_tokens = states[0].nr_context_tokens(nF, nB, nS, nL, nR) - vector_length = states[0].token_vector_length - tokens = ops.allocate((len(states), n_tokens), dtype='i') - features = ops.allocate((len(states), n_tokens, attr_names.shape[0]), dtype='i') - tokvecs = ops.allocate((len(states), n_tokens, vector_length), dtype='f') - for i, state in enumerate(states): - state.set_context_tokens(tokens[i], nF, nB, nS, nL, nR) - state.set_attributes(features[i], tokens[i], attr_names) - state.set_token_vectors(tokvecs[i], tokens[i]) - def backward(d_features, sgd=None): - return d_features - return (tokens, features, tokvecs), backward - model = layerize(forward) - return model +def backprop_concatenate(gradient, shapes): + grads = [] + start = 0 + for shape in shapes: + end = start + shape[1] + grads.append(gradient[:, start : end]) + start = end + return grads def _reshape(layer): - def forward(X, drop=0.): - Xh = X.reshape((X.shape[0] * X.shape[1], X.shape[2])) - yh, bp_yh = layer.begin_update(Xh, drop=drop) - n = X.shape[0] - old_shape = X.shape - def backward(d_y, sgd=None): - d_yh = d_y.reshape((n, d_y.size / n)) - d_Xh = bp_yh(d_yh, sgd) - return d_Xh.reshape(old_shape) - return yh.reshape((n, yh.shape / n)), backward + '''Transforms input with shape + (states, tokens, features) + into input with shape: + (states * tokens, features) + So that it can be used with a token-wise feature extraction layer, e.g. + an embedding layer. The embedding layer outputs: + (states * tokens, ndim) + But we want to concatenate the vectors for the tokens, so we produce: + (states, tokens * ndim) + We then need to reverse the transforms to do the backward pass. Recall + the simple rule here: each layer is a map: + inputs -> (outputs, (d_outputs->d_inputs)) + So the shapes must match like this: + shape of forward input == shape of backward output + shape of backward input == shape of forward output + ''' + def forward(X__bfm, drop=0.): + b, f, m = X__bfm.shape + B = b*f + M = f*m + X__Bm = X__bfm.reshape((B, m)) + y__Bn, bp_yBn = layer.begin_update(X__Bm, drop=drop) + n = y__Bn.shape[1] + N = f * n + y__bN = y__Bn.reshape((b, N)) + def backward(dy__bN, sgd=None): + dy__Bn = dy__bN.reshape((B, n)) + dX__Bm = bp_yBn(dy__Bn, sgd) + if dX__Bm is None: + return None + else: + return dX__Bm.reshape((b, f, m)) + return y__bN, backward model = layerize(forward) model._layers.append(layer) return model -def build_tok2vec(lang, width, depth, embed_size, cols): + +@layerize +def flatten(seqs, drop=0.): + ops = Model.ops + def finish_update(d_X, sgd=None): + return d_X + X = ops.xp.concatenate([ops.asarray(seq) for seq in seqs]) + return X, finish_update + + +def build_tok2vec(lang, width, depth=2, embed_size=1000): + cols = [ID, PREFIX, SUFFIX, SHAPE] with Model.define_operators({'>>': chain, '|': concatenate, '**': clone}): - static = get_col(cols.index(ID)) >> StaticVectors(lang, width) + #static = get_col(cols.index(ID)) >> StaticVectors(lang, width) + lower = get_col(cols.index(ID)) >> HashEmbed(width, embed_size) prefix = get_col(cols.index(PREFIX)) >> HashEmbed(width, embed_size) suffix = get_col(cols.index(SUFFIX)) >> HashEmbed(width, embed_size) shape = get_col(cols.index(SHAPE)) >> HashEmbed(width, embed_size) tok2vec = ( - (static | prefix | suffix | shape) - >> Maxout(width, width*4) - >> (ExtractWindow(nW=1) >> Maxout(width, width*3)) ** depth + doc2feats(cols) + >> with_flatten( + #(static | prefix | suffix | shape) + (lower | prefix | suffix | shape) + >> Maxout(width, width*4) + >> (ExtractWindow(nW=1) >> Maxout(width, width*3)) + >> (ExtractWindow(nW=1) >> Maxout(width, width*3)) + ) ) return tok2vec + + +def doc2feats(cols): + def forward(docs, drop=0.): + feats = [doc.to_array(cols) for doc in docs] + feats = [model.ops.asarray(f, dtype='uint64') for f in feats] + return feats, None + model = layerize(forward) + return model diff --git a/spacy/es/tag_map.py b/spacy/es/tag_map.py index dce29c921..3568e409e 100644 --- a/spacy/es/tag_map.py +++ b/spacy/es/tag_map.py @@ -304,5 +304,24 @@ TAG_MAP = { "VERB__VerbForm=Ger": {"morph": "VerbForm=Ger", "pos": "VERB"}, "VERB__VerbForm=Inf": {"morph": "VerbForm=Inf", "pos": "VERB"}, "X___": {"morph": "_", "pos": "X"}, - "SP": {"morph": "_", "pos": "SPACE"} + "SP": {"morph": "_", "pos": "SPACE"}, + "ADV": {POS: ADV}, + "NOUN": {POS: NOUN}, + "ADP": {POS: ADP}, + "PRON": {POS: PRON}, + "SCONJ": {POS: SCONJ}, + "PROPN": {POS: PROPN}, + "DET": {POS: DET}, + "SYM": {POS: SYM}, + "INTJ": {POS: INTJ}, + "PUNCT": {POS: PUNCT}, + "NUM": {POS: NUM}, + "AUX": {POS: AUX}, + "X": {POS: X}, + "CONJ": {POS: CONJ}, + "CCONJ": {POS: CCONJ}, # U20 + "ADJ": {POS: ADJ}, + "VERB": {POS: VERB}, + "PART": {POS: PART}, + "_": {POS: PUNCT} } diff --git a/spacy/pipeline.pxd b/spacy/pipeline.pxd index b08f18f33..e9b7f0f73 100644 --- a/spacy/pipeline.pxd +++ b/spacy/pipeline.pxd @@ -1,5 +1,5 @@ from .syntax.parser cimport Parser -from .syntax.beam_parser cimport BeamParser +#from .syntax.beam_parser cimport BeamParser from .syntax.ner cimport BiluoPushDown from .syntax.arc_eager cimport ArcEager from .tagger cimport Tagger @@ -13,9 +13,9 @@ cdef class DependencyParser(Parser): pass -cdef class BeamEntityRecognizer(BeamParser): - pass - - -cdef class BeamDependencyParser(BeamParser): - pass +#cdef class BeamEntityRecognizer(BeamParser): +# pass +# +# +#cdef class BeamDependencyParser(BeamParser): +# pass diff --git a/spacy/pipeline.pyx b/spacy/pipeline.pyx index 147746a27..61c71c2bb 100644 --- a/spacy/pipeline.pyx +++ b/spacy/pipeline.pyx @@ -1,11 +1,15 @@ # coding: utf8 from __future__ import unicode_literals +from thinc.api import chain, layerize, with_getitem +from thinc.neural import Model, Softmax + from .syntax.parser cimport Parser -from .syntax.beam_parser cimport BeamParser +#from .syntax.beam_parser cimport BeamParser from .syntax.ner cimport BiluoPushDown from .syntax.arc_eager cimport ArcEager from .tagger import Tagger +from ._ml import build_tok2vec # TODO: The disorganization here is pretty embarrassing. At least it's only # internals. @@ -13,6 +17,39 @@ from .syntax.parser import get_templates as get_feature_templates from .attrs import DEP, ENT_TYPE +class TokenVectorEncoder(object): + '''Assign position-sensitive vectors to tokens, using a CNN or RNN.''' + def __init__(self, vocab, **cfg): + self.vocab = vocab + self.model = build_tok2vec(vocab.lang, 64, **cfg) + self.tagger = chain( + self.model, + Softmax(self.vocab.morphology.n_tags)) + + def __call__(self, doc): + doc.tensor = self.model([doc])[0] + + def begin_update(self, docs, drop=0.): + tensors, bp_tensors = self.model.begin_update(docs, drop=drop) + for i, doc in enumerate(docs): + doc.tensor = tensors[i] + return tensors, bp_tensors + + def update(self, docs, golds, drop=0., sgd=None): + scores, finish_update = self.tagger.begin_update(docs, drop=drop) + losses = scores.copy() + loss = 0.0 + idx = 0 + for i, gold in enumerate(golds): + for j, tag in enumerate(gold.tags): + tag_id = docs[0].vocab.morphology.tag_names.index(tag) + losses[idx, tag_id] -= 1.0 + loss += 1-scores[idx, tag_id] + idx += 1 + finish_update(losses, sgd) + return loss + + cdef class EntityRecognizer(Parser): """ Annotate named entities on Doc objects. @@ -31,25 +68,25 @@ cdef class EntityRecognizer(Parser): freqs.append([label, 1]) self.vocab._serializer = None - -cdef class BeamEntityRecognizer(BeamParser): - """ - Annotate named entities on Doc objects. - """ - TransitionSystem = BiluoPushDown - - feature_templates = get_feature_templates('ner') - - def add_label(self, label): - Parser.add_label(self, label) - if isinstance(label, basestring): - label = self.vocab.strings[label] - # Set label into serializer. Super hacky :( - for attr, freqs in self.vocab.serializer_freqs: - if attr == ENT_TYPE and label not in freqs: - freqs.append([label, 1]) - self.vocab._serializer = None - +# +#cdef class BeamEntityRecognizer(BeamParser): +# """ +# Annotate named entities on Doc objects. +# """ +# TransitionSystem = BiluoPushDown +# +# feature_templates = get_feature_templates('ner') +# +# def add_label(self, label): +# Parser.add_label(self, label) +# if isinstance(label, basestring): +# label = self.vocab.strings[label] +# # Set label into serializer. Super hacky :( +# for attr, freqs in self.vocab.serializer_freqs: +# if attr == ENT_TYPE and label not in freqs: +# freqs.append([label, 1]) +# self.vocab._serializer = None +# cdef class DependencyParser(Parser): TransitionSystem = ArcEager @@ -66,21 +103,22 @@ cdef class DependencyParser(Parser): # Super hacky :( self.vocab._serializer = None +# +#cdef class BeamDependencyParser(BeamParser): +# TransitionSystem = ArcEager +# +# feature_templates = get_feature_templates('basic') +# +# def add_label(self, label): +# Parser.add_label(self, label) +# if isinstance(label, basestring): +# label = self.vocab.strings[label] +# for attr, freqs in self.vocab.serializer_freqs: +# if attr == DEP and label not in freqs: +# freqs.append([label, 1]) +# # Super hacky :( +# self.vocab._serializer = None +# -cdef class BeamDependencyParser(BeamParser): - TransitionSystem = ArcEager - - feature_templates = get_feature_templates('basic') - - def add_label(self, label): - Parser.add_label(self, label) - if isinstance(label, basestring): - label = self.vocab.strings[label] - for attr, freqs in self.vocab.serializer_freqs: - if attr == DEP and label not in freqs: - freqs.append([label, 1]) - # Super hacky :( - self.vocab._serializer = None - - -__all__ = [Tagger, DependencyParser, EntityRecognizer, BeamDependencyParser, BeamEntityRecognizer] +#__all__ = [Tagger, DependencyParser, EntityRecognizer, BeamDependencyParser, BeamEntityRecognizer] +__all__ = [Tagger, DependencyParser, EntityRecognizer] diff --git a/spacy/syntax/beam_parser.pxd b/spacy/syntax/beam_parser.pxd index 35a60cbf3..81b33696f 100644 --- a/spacy/syntax/beam_parser.pxd +++ b/spacy/syntax/beam_parser.pxd @@ -3,8 +3,8 @@ from ..structs cimport TokenC from thinc.typedefs cimport weight_t -cdef class BeamParser(Parser): - cdef public int beam_width - cdef public weight_t beam_density - - cdef int _parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) except -1 +#cdef class BeamParser(Parser): +# cdef public int beam_width +# cdef public weight_t beam_density +# +# #cdef int _parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) except -1 diff --git a/spacy/syntax/beam_parser.pyx b/spacy/syntax/beam_parser.pyx index e96e28fcf..9f509efe8 100644 --- a/spacy/syntax/beam_parser.pyx +++ b/spacy/syntax/beam_parser.pyx @@ -56,130 +56,130 @@ def get_templates(name): cdef int BEAM_WIDTH = 16 cdef weight_t BEAM_DENSITY = 0.001 -cdef class BeamParser(Parser): - def __init__(self, *args, **kwargs): - self.beam_width = kwargs.get('beam_width', BEAM_WIDTH) - self.beam_density = kwargs.get('beam_density', BEAM_DENSITY) - Parser.__init__(self, *args, **kwargs) - - cdef int parseC(self, TokenC* tokens, int length, int nr_feat) nogil: - with gil: - self._parseC(tokens, length, nr_feat, self.moves.n_moves) - - cdef int _parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) except -1: - cdef Beam beam = Beam(self.moves.n_moves, self.beam_width, min_density=self.beam_density) - # TODO: How do we handle new labels here? This increases nr_class - beam.initialize(self.moves.init_beam_state, length, tokens) - beam.check_done(_check_final_state, NULL) - if beam.is_done: - _cleanup(beam) - return 0 - while not beam.is_done: - self._advance_beam(beam, None, False) - state = beam.at(0) - self.moves.finalize_state(state.c) - for i in range(length): - tokens[i] = state.c._sent[i] - _cleanup(beam) - - def update(self, Doc tokens, GoldParse gold_parse, itn=0): - self.moves.preprocess_gold(gold_parse) - cdef Beam pred = Beam(self.moves.n_moves, self.beam_width) - pred.initialize(self.moves.init_beam_state, tokens.length, tokens.c) - pred.check_done(_check_final_state, NULL) - # Hack for NER - for i in range(pred.size): - stcls = pred.at(i) - self.moves.initialize_state(stcls.c) - - cdef Beam gold = Beam(self.moves.n_moves, self.beam_width, min_density=0.0) - gold.initialize(self.moves.init_beam_state, tokens.length, tokens.c) - gold.check_done(_check_final_state, NULL) - violn = MaxViolation() - while not pred.is_done and not gold.is_done: - # We search separately here, to allow for ambiguity in the gold parse. - self._advance_beam(pred, gold_parse, False) - self._advance_beam(gold, gold_parse, True) - violn.check_crf(pred, gold) - if pred.loss > 0 and pred.min_score > (gold.score + self.model.time): - break - else: - # The non-monotonic oracle makes it difficult to ensure final costs are - # correct. Therefore do final correction - for i in range(pred.size): - if is_gold(pred.at(i), gold_parse, self.moves.strings): - pred._states[i].loss = 0.0 - elif pred._states[i].loss == 0.0: - pred._states[i].loss = 1.0 - violn.check_crf(pred, gold) - if pred.size < 1: - raise Exception("No candidates", tokens.length) - if gold.size < 1: - raise Exception("No gold", tokens.length) - if pred.loss == 0: - self.model.update_from_histories(self.moves, tokens, [(0.0, [])]) - elif True: - #_check_train_integrity(pred, gold, gold_parse, self.moves) - histories = list(zip(violn.p_probs, violn.p_hist)) + \ - list(zip(violn.g_probs, violn.g_hist)) - self.model.update_from_histories(self.moves, tokens, histories, min_grad=0.001**(itn+1)) - else: - self.model.update_from_histories(self.moves, tokens, - [(1.0, violn.p_hist[0]), (-1.0, violn.g_hist[0])]) - _cleanup(pred) - _cleanup(gold) - return pred.loss - - def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold): - cdef atom_t[CONTEXT_SIZE] context - cdef Pool mem = Pool() - features = mem.alloc(self.model.nr_feat, sizeof(FeatureC)) - if False: - mb = Minibatch(self.model.widths, beam.size) - for i in range(beam.size): - stcls = beam.at(i) - if stcls.c.is_final(): - nr_feat = 0 - else: - nr_feat = self.model.set_featuresC(context, features, stcls.c) - self.moves.set_valid(beam.is_valid[i], stcls.c) - mb.c.push_back(features, nr_feat, beam.costs[i], beam.is_valid[i], 0) - self.model(mb) - for i in range(beam.size): - memcpy(beam.scores[i], mb.c.scores(i), mb.c.nr_out() * sizeof(beam.scores[i][0])) - else: - for i in range(beam.size): - stcls = beam.at(i) - if not stcls.is_final(): - nr_feat = self.model.set_featuresC(context, features, stcls.c) - self.moves.set_valid(beam.is_valid[i], stcls.c) - self.model.set_scoresC(beam.scores[i], features, nr_feat) - if gold is not None: - n_gold = 0 - lines = [] - for i in range(beam.size): - stcls = beam.at(i) - if not stcls.c.is_final(): - self.moves.set_costs(beam.is_valid[i], beam.costs[i], stcls, gold) - if follow_gold: - for j in range(self.moves.n_moves): - if beam.costs[i][j] >= 1: - beam.is_valid[i][j] = 0 - lines.append((stcls.B(0), stcls.B(1), - stcls.B_(0).ent_iob, stcls.B_(1).ent_iob, - stcls.B_(1).sent_start, - j, - beam.is_valid[i][j], 'set invalid', - beam.costs[i][j], self.moves.c[j].move, self.moves.c[j].label)) - n_gold += 1 if beam.is_valid[i][j] else 0 - if follow_gold and n_gold == 0: - raise Exception("No gold") - if follow_gold: - beam.advance(_transition_state, NULL, self.moves.c) - else: - beam.advance(_transition_state, _hash_state, self.moves.c) - beam.check_done(_check_final_state, NULL) - +#cdef class BeamParser(Parser): +# def __init__(self, *args, **kwargs): +# self.beam_width = kwargs.get('beam_width', BEAM_WIDTH) +# self.beam_density = kwargs.get('beam_density', BEAM_DENSITY) +# Parser.__init__(self, *args, **kwargs) +# +# #cdef int parseC(self, TokenC* tokens, int length, int nr_feat) nogil: +# # with gil: +# # self._parseC(tokens, length, nr_feat, self.moves.n_moves) +# +# #cdef int _parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) except -1: +# # cdef Beam beam = Beam(self.moves.n_moves, self.beam_width, min_density=self.beam_density) +# # # TODO: How do we handle new labels here? This increases nr_class +# # beam.initialize(self.moves.init_beam_state, length, tokens) +# # beam.check_done(_check_final_state, NULL) +# # if beam.is_done: +# # _cleanup(beam) +# # return 0 +# # while not beam.is_done: +# # self._advance_beam(beam, None, False) +# # state = beam.at(0) +# # self.moves.finalize_state(state.c) +# # for i in range(length): +# # tokens[i] = state.c._sent[i] +# # _cleanup(beam) +# +# def update(self, Doc tokens, GoldParse gold_parse, itn=0): +# self.moves.preprocess_gold(gold_parse) +# cdef Beam pred = Beam(self.moves.n_moves, self.beam_width) +# pred.initialize(self.moves.init_beam_state, tokens.length, tokens.c) +# pred.check_done(_check_final_state, NULL) +# # Hack for NER +# for i in range(pred.size): +# stcls = pred.at(i) +# self.moves.initialize_state(stcls.c) +# +# cdef Beam gold = Beam(self.moves.n_moves, self.beam_width, min_density=0.0) +# gold.initialize(self.moves.init_beam_state, tokens.length, tokens.c) +# gold.check_done(_check_final_state, NULL) +# violn = MaxViolation() +# while not pred.is_done and not gold.is_done: +# # We search separately here, to allow for ambiguity in the gold parse. +# self._advance_beam(pred, gold_parse, False) +# self._advance_beam(gold, gold_parse, True) +# violn.check_crf(pred, gold) +# if pred.loss > 0 and pred.min_score > (gold.score + self.model.time): +# break +# else: +# # The non-monotonic oracle makes it difficult to ensure final costs are +# # correct. Therefore do final correction +# for i in range(pred.size): +# if is_gold(pred.at(i), gold_parse, self.moves.strings): +# pred._states[i].loss = 0.0 +# elif pred._states[i].loss == 0.0: +# pred._states[i].loss = 1.0 +# violn.check_crf(pred, gold) +# if pred.size < 1: +# raise Exception("No candidates", tokens.length) +# if gold.size < 1: +# raise Exception("No gold", tokens.length) +# if pred.loss == 0: +# self.model.update_from_histories(self.moves, tokens, [(0.0, [])]) +# elif True: +# #_check_train_integrity(pred, gold, gold_parse, self.moves) +# histories = list(zip(violn.p_probs, violn.p_hist)) + \ +# list(zip(violn.g_probs, violn.g_hist)) +# self.model.update_from_histories(self.moves, tokens, histories, min_grad=0.001**(itn+1)) +# else: +# self.model.update_from_histories(self.moves, tokens, +# [(1.0, violn.p_hist[0]), (-1.0, violn.g_hist[0])]) +# _cleanup(pred) +# _cleanup(gold) +# return pred.loss +# +# def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold): +# cdef atom_t[CONTEXT_SIZE] context +# cdef Pool mem = Pool() +# features = mem.alloc(self.model.nr_feat, sizeof(FeatureC)) +# if False: +# mb = Minibatch(self.model.widths, beam.size) +# for i in range(beam.size): +# stcls = beam.at(i) +# if stcls.c.is_final(): +# nr_feat = 0 +# else: +# nr_feat = self.model.set_featuresC(context, features, stcls.c) +# self.moves.set_valid(beam.is_valid[i], stcls.c) +# mb.c.push_back(features, nr_feat, beam.costs[i], beam.is_valid[i], 0) +# self.model(mb) +# for i in range(beam.size): +# memcpy(beam.scores[i], mb.c.scores(i), mb.c.nr_out() * sizeof(beam.scores[i][0])) +# else: +# for i in range(beam.size): +# stcls = beam.at(i) +# if not stcls.is_final(): +# nr_feat = self.model.set_featuresC(context, features, stcls.c) +# self.moves.set_valid(beam.is_valid[i], stcls.c) +# self.model.set_scoresC(beam.scores[i], features, nr_feat) +# if gold is not None: +# n_gold = 0 +# lines = [] +# for i in range(beam.size): +# stcls = beam.at(i) +# if not stcls.c.is_final(): +# self.moves.set_costs(beam.is_valid[i], beam.costs[i], stcls, gold) +# if follow_gold: +# for j in range(self.moves.n_moves): +# if beam.costs[i][j] >= 1: +# beam.is_valid[i][j] = 0 +# lines.append((stcls.B(0), stcls.B(1), +# stcls.B_(0).ent_iob, stcls.B_(1).ent_iob, +# stcls.B_(1).sent_start, +# j, +# beam.is_valid[i][j], 'set invalid', +# beam.costs[i][j], self.moves.c[j].move, self.moves.c[j].label)) +# n_gold += 1 if beam.is_valid[i][j] else 0 +# if follow_gold and n_gold == 0: +# raise Exception("No gold") +# if follow_gold: +# beam.advance(_transition_state, NULL, self.moves.c) +# else: +# beam.advance(_transition_state, _hash_state, self.moves.c) +# beam.check_done(_check_final_state, NULL) +# # These are passed as callbacks to thinc.search.Beam cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1: diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index d97a2f519..67643617a 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -40,6 +40,9 @@ from ..structs cimport TokenC from ..tokens.doc cimport Doc from ..strings cimport StringStore from ..gold cimport GoldParse +from ..attrs cimport TAG, DEP + +from .._ml import build_parser_state2vec, build_model USE_FTRL = True @@ -107,6 +110,11 @@ cdef class Parser: def __reduce__(self): return (Parser, (self.vocab, self.moves, self.model), None, None) + def build_model(self, width=8, nr_vector=1000, nF=1, nB=1, nS=1, nL=1, nR=1, **_): + state2vec = build_parser_state2vec(width, nr_vector, nF, nB, nL, nR) + model = build_model(state2vec, width, 2, self.moves.n_moves) + return model + def __call__(self, Doc tokens): """ Apply the parser or entity recognizer, setting the annotations onto the Doc object. @@ -118,25 +126,7 @@ cdef class Parser: """ self.parse_batch([tokens]) self.moves.finalize_doc(tokens) - - def parse_batch(self, docs): - states = self._init_states(docs) - nr_class = self.moves.n_moves - cdef StateClass state - cdef int guess - is_valid = self.model.ops.allocate((len(docs), nr_class), dtype='i') - todo = list(states) - while todo: - scores = self.model.predict(todo) - self._validate_batch(is_valid, states) - scores *= is_valid - for state, guess in zip(todo, scores.argmax(axis=1)): - action = self.moves.c[guess] - action.do(state.c, action.label) - todo = [state for state in todo if not state.is_final()] - for state, doc in zip(states, docs): - self.moves.finalize_state(state.c) - + def pipe(self, stream, int batch_size=1000, int n_threads=2): """ Process a stream of documents. @@ -170,53 +160,106 @@ cdef class Parser: self.moves.finalize_doc(doc) yield doc + def parse_batch(self, docs): + states = self._init_states(docs) + nr_class = self.moves.n_moves + cdef Doc doc + cdef StateClass state + cdef int guess + is_valid = self.model.ops.allocate((len(docs), nr_class), dtype='i') + tokvecs = [d.tensor for d in docs] + attr_names = self.model.ops.allocate((2,), dtype='i') + attr_names[0] = TAG + attr_names[1] = DEP + all_states = list(states) + todo = zip(states, tokvecs) + while todo: + states, tokvecs = zip(*todo) + features = self._get_features(states, tokvecs, attr_names) + scores = self.model.predict(features) + self._validate_batch(is_valid, states) + scores *= is_valid + for state, guess in zip(states, scores.argmax(axis=1)): + action = self.moves.c[guess] + action.do(state.c, action.label) + todo = filter(lambda sp: not sp[0].is_final(), todo) + for state, doc in zip(all_states, docs): + self.moves.finalize_state(state.c) + for i in range(doc.length): + doc.c[i] = state.c._sent[i] + + def update(self, docs, golds, drop=0., sgd=None): if isinstance(docs, Doc) and isinstance(golds, GoldParse): return self.update([docs], [golds], drop=drop) + for gold in golds: + self.moves.preprocess_gold(gold) states = self._init_states(docs) + tokvecs = [d.tensor for d in docs] d_tokens = [self.model.ops.allocate(d.tensor.shape) for d in docs] nr_class = self.moves.n_moves costs = self.model.ops.allocate((len(docs), nr_class), dtype='f') + gradients = self.model.ops.allocate((len(docs), nr_class), dtype='f') is_valid = self.model.ops.allocate((len(docs), nr_class), dtype='i') + attr_names = self.model.ops.allocate((2,), dtype='i') + attr_names[0] = TAG + attr_names[1] = DEP + output = list(d_tokens) + todo = zip(states, tokvecs, golds, d_tokens) + assert len(states) == len(todo) + loss = 0. + while todo: + states, tokvecs, golds, d_tokens = zip(*todo) + features = self._get_features(states, tokvecs, attr_names) - todo = zip(states, golds, d_tokens) - while states: - states, golds, d_tokens = zip(*todo) - scores, finish_update = self.model.begin_update(states, drop=drop) - - self._cost_batch(is_valid, costs, states, golds) + scores, finish_update = self.model.begin_update(features, drop=drop) + assert scores.shape == (len(states), self.moves.n_moves), (len(states), scores.shape) + + self._cost_batch(costs, is_valid, states, golds) scores *= is_valid self._set_gradient(gradients, scores, costs) + loss += numpy.abs(gradients).sum() / gradients.shape[0] token_ids, batch_token_grads = finish_update(gradients, sgd=sgd) for i, tok_i in enumerate(token_ids): - d_tokens[tok_i] += batch_token_grads[i] + d_tokens[i][tok_i] += batch_token_grads[i] self._transition_batch(states, scores) # Get unfinished states (and their matching gold and token gradients) - todo = zip(states, golds, d_tokens) - todo = filter(todo, lambda sp: sp[0].is_final) - - gradients = gradients[:len(todo)] + todo = filter(lambda sp: not sp[0].is_final(), todo) costs = costs[:len(todo)] is_valid = is_valid[:len(todo)] + gradients = gradients[:len(todo)] gradients.fill(0) costs.fill(0) is_valid.fill(1) - return 0 + return output, loss def _init_states(self, docs): states = [] cdef Doc doc cdef StateClass state for i, doc in enumerate(docs): - state = StateClass(doc) + state = StateClass.init(doc.c, doc.length) self.moves.initialize_state(state.c) states.append(state) return states + def _get_features(self, states, all_tokvecs, attr_names, + nF=1, nB=0, nS=2, nL=2, nR=2): + n_tokens = states[0].nr_context_tokens(nF, nB, nS, nL, nR) + vector_length = all_tokvecs[0].shape[1] + tokens = self.model.ops.allocate((len(states), n_tokens), dtype='int32') + features = self.model.ops.allocate((len(states), n_tokens, attr_names.shape[0]), dtype='uint64') + tokvecs = self.model.ops.allocate((len(states), n_tokens, vector_length), dtype='f') + for i, state in enumerate(states): + state.set_context_tokens(tokens[i], nF, nB, nS, nL, nR) + state.set_attributes(features[i], tokens[i], attr_names) + state.set_token_vectors(tokvecs[i], all_tokvecs[i], tokens[i]) + return (tokens, features, tokvecs) + def _validate_batch(self, int[:, ::1] is_valid, states): cdef StateClass state cdef int i @@ -242,13 +285,13 @@ cdef class Parser: """Do multi-label log loss""" cdef double Z, gZ, max_, g_max g_scores = scores * (costs <= 0) - maxes = scores.max(axis=1) - g_maxes = g_scores.max(axis=1) - exps = (scores-maxes).exp() - g_exps = (g_scores-g_maxes).exp() + maxes = scores.max(axis=1).reshape((scores.shape[0], 1)) + g_maxes = g_scores.max(axis=1).reshape((g_scores.shape[0], 1)) + exps = numpy.exp((scores-maxes)) + g_exps = numpy.exp(g_scores-g_maxes) - Zs = exps.sum(axis=1) - gZs = g_exps.sum(axis=1) + Zs = exps.sum(axis=1).reshape((exps.shape[0], 1)) + gZs = g_exps.sum(axis=1).reshape((g_exps.shape[0], 1)) logprob = exps / Zs g_logprob = g_exps / gZs gradients[:] = logprob - g_logprob diff --git a/spacy/syntax/stateclass.pxd b/spacy/syntax/stateclass.pxd index f47ceff43..0eaf1fe55 100644 --- a/spacy/syntax/stateclass.pxd +++ b/spacy/syntax/stateclass.pxd @@ -1,6 +1,7 @@ from libc.string cimport memcpy, memset from cymem.cymem cimport Pool +cimport cython from ..structs cimport TokenC, Entity @@ -8,7 +9,7 @@ from ..vocab cimport EMPTY_LEXEME from ._state cimport StateC - +@cython.final cdef class StateClass: cdef Pool mem cdef StateC* c diff --git a/spacy/syntax/stateclass.pyx b/spacy/syntax/stateclass.pyx index 541df2509..5dedd1501 100644 --- a/spacy/syntax/stateclass.pyx +++ b/spacy/syntax/stateclass.pyx @@ -1,14 +1,17 @@ # coding: utf-8 +# cython: infer_types=True from __future__ import unicode_literals from libc.string cimport memcpy, memset -from libc.stdint cimport uint32_t +from libc.stdint cimport uint32_t, uint64_t from ..vocab cimport EMPTY_LEXEME from ..structs cimport Entity from ..lexeme cimport Lexeme from ..symbols cimport punct from ..attrs cimport IS_SPACE +from ..attrs cimport attr_id_t +from ..tokens.token cimport Token cdef class StateClass: @@ -27,6 +30,13 @@ cdef class StateClass: def queue(self): return {self.B(i) for i in range(self.c.buffer_length())} + @property + def token_vector_lenth(self): + return self.doc.tensor.shape[1] + + def is_final(self): + return self.c.is_final() + def print_state(self, words): words = list(words) + ['_'] top = words[self.S(0)] + '_%d' % self.S_(0).head @@ -35,3 +45,33 @@ cdef class StateClass: n0 = words[self.B(0)] n1 = words[self.B(1)] return ' '.join((third, second, top, '|', n0, n1)) + + def nr_context_tokens(self, int nF, int nB, int nS, int nL, int nR): + return 1+nF+nB+nS + nL + (nS * nL) + (nS * nR) + + def set_context_tokens(self, int[:] output, nF=1, nB=0, nS=2, + nL=2, nR=2): + output[0] = self.B(0) + output[1] = self.S(0) + output[2] = self.S(1) + output[3] = self.L(self.S(0), 1) + output[4] = self.L(self.S(0), 2) + output[5] = self.R(self.S(0), 1) + output[6] = self.R(self.S(0), 2) + output[7] = self.L(self.S(1), 1) + output[8] = self.L(self.S(1), 2) + output[9] = self.R(self.S(1), 1) + output[10] = self.R(self.S(1), 2) + + def set_attributes(self, uint64_t[:, :] vals, int[:] tokens, int[:] names): + cdef int i, j, tok_i + for i in range(tokens.shape[0]): + tok_i = tokens[i] + token = &self.c._sent[tok_i] + for j in range(names.shape[0]): + vals[i, j] = Token.get_struct_attr(token, names[j]) + + def set_token_vectors(self, float[:, :] tokvecs, + float[:, :] all_tokvecs, int[:] indices): + for i in range(indices.shape[0]): + tokvecs[i] = all_tokvecs[indices[i]]