Fix tagging model

This commit is contained in:
Matthew Honnibal 2017-08-06 01:50:08 +02:00
parent 468c138ab3
commit e9ab800e15
2 changed files with 16 additions and 23 deletions

View File

@ -346,16 +346,16 @@ def get_token_vectors(tokens_attrs_vectors, drop=0.):
def fine_tune(model1, combine=None): def fine_tune(model1, combine=None):
def fine_tune_fwd(docs, drop=0.): def fine_tune_fwd(docs_tokvecs, drop=0.):
docs, tokvecs = docs_tokvecs
lengths = model.ops.asarray([len(doc) for doc in docs], dtype='i')
X1, bp_X1 = model1.begin_update(docs) X1, bp_X1 = model1.begin_update(docs)
lengths = [len(doc) for doc in docs]
X2 = model1.ops.flatten(X1)
def fine_tune_bwd(d_output, sgd=None): def fine_tune_bwd(d_output, sgd=None):
bp_X1(d_output, sgd=sgd) bp_X1(model1.ops.flatten(d_output), sgd=sgd)
return d_output return d_output
return (X1+X2, lengths), fine_tune_bwd return model1.ops.unflatten(X1+X2, lengths), fine_tune_bwd
model = wrap(fine_tune_fwd) model = wrap(fine_tune_fwd)
return model return model
@ -410,32 +410,23 @@ def preprocess_doc(docs, drop=0.):
def build_tagger_model(nr_class, token_vector_width, **cfg): def build_tagger_model(nr_class, token_vector_width, **cfg):
with Model.define_operators({'>>': chain, '+': add}): with Model.define_operators({'>>': chain, '+': add}):
# Input: (doc, tensor) tuples # Input: (doc, tensor) tuples
embed_docs = with_getitem(0, embed_docs = (
FeatureExtracter([NORM]) FeatureExtracter([NORM])
>> flatten
>> HashEmbed(token_vector_width, 1000) >> HashEmbed(token_vector_width, 1000)
>> flatten_add_lengths
) )
model = ( model = (
fine_tune(embed_docs) fine_tune(embed_docs)
>>
with_getitem(0,
FeatureExtracter([NORM])
>> HashEmbed(token_vector_width, 1000)
>> flatten_add_lengths
)
>> with_getitem(1,
flatten_add_lengths)
>> add_tuples
>> with_flatten( >> with_flatten(
Maxout(token_vector_width, token_vector_width) Maxout(token_vector_width, token_vector_width)
>> Softmax(nr_class, token_vector_width) >> Softmax(nr_class, token_vector_width)
) )
) )
model.nI = None
return model return model
def build_text_classifier(nr_class, width=64, **cfg): def build_text_classifier(nr_class, width=64, **cfg):
nr_vector = cfg.get('nr_vector', 200) nr_vector = cfg.get('nr_vector', 200)
with Model.define_operators({'>>': chain, '+': add, '|': concatenate, '**': clone}): with Model.define_operators({'>>': chain, '+': add, '|': concatenate, '**': clone}):

View File

@ -253,23 +253,25 @@ class NeuralTagger(BaseThincComponent):
self.cfg = dict(cfg) self.cfg = dict(cfg)
def __call__(self, doc): def __call__(self, doc):
tags = self.predict([doc.tensor]) tags = self.predict(([doc], [doc.tensor]))
self.set_annotations([doc], tags) self.set_annotations([doc], tags)
return doc return doc
def pipe(self, stream, batch_size=128, n_threads=-1): def pipe(self, stream, batch_size=128, n_threads=-1):
for docs in cytoolz.partition_all(batch_size, stream): for docs in cytoolz.partition_all(batch_size, stream):
docs = list(docs)
tokvecs = [d.tensor for d in docs] tokvecs = [d.tensor for d in docs]
tag_ids = self.predict(tokvecs) tag_ids = self.predict((docs, tokvecs))
self.set_annotations(docs, tag_ids) self.set_annotations(docs, tag_ids)
yield from docs yield from docs
def predict(self, tokvecs): def predict(self, docs_tokvecs):
scores = self.model(tokvecs) scores = self.model(docs_tokvecs)
scores = self.model.ops.flatten(scores) scores = self.model.ops.flatten(scores)
guesses = scores.argmax(axis=1) guesses = scores.argmax(axis=1)
if not isinstance(guesses, numpy.ndarray): if not isinstance(guesses, numpy.ndarray):
guesses = guesses.get() guesses = guesses.get()
tokvecs = docs_tokvecs[1]
guesses = self.model.ops.unflatten(guesses, guesses = self.model.ops.unflatten(guesses,
[tv.shape[0] for tv in tokvecs]) [tv.shape[0] for tv in tokvecs])
return guesses return guesses
@ -295,7 +297,7 @@ class NeuralTagger(BaseThincComponent):
if self.model.nI is None: if self.model.nI is None:
self.model.nI = tokvecs[0].shape[1] self.model.nI = tokvecs[0].shape[1]
tag_scores, bp_tag_scores = self.model.begin_update(tokvecs, drop=drop) tag_scores, bp_tag_scores = self.model.begin_update(docs_tokvecs, drop=drop)
loss, d_tag_scores = self.get_loss(docs, golds, tag_scores) loss, d_tag_scores = self.get_loss(docs, golds, tag_scores)
d_tokvecs = bp_tag_scores(d_tag_scores, sgd=sgd) d_tokvecs = bp_tag_scores(d_tag_scores, sgd=sgd)