Fix __call__ method

This commit is contained in:
Matthew Honnibal 2017-05-28 08:11:58 -05:00
parent 5cf47b847b
commit bc97bc292c
3 changed files with 9 additions and 8 deletions

View File

@ -192,7 +192,7 @@ class Language(object):
name = getattr(proc, 'name', None) name = getattr(proc, 'name', None)
if name in disable: if name in disable:
continue continue
proc(doc) doc = proc(doc)
return doc return doc
def update(self, docs, golds, drop=0., sgd=None, losses=None): def update(self, docs, golds, drop=0., sgd=None, losses=None):

View File

@ -73,17 +73,16 @@ class TokenVectorEncoder(object):
self.doc2feats = doc2feats() self.doc2feats = doc2feats()
self.model = model self.model = model
def __call__(self, docs): def __call__(self, doc):
"""Add context-sensitive vectors to a `Doc`, e.g. from a CNN or LSTM """Add context-sensitive vectors to a `Doc`, e.g. from a CNN or LSTM
model. Vectors are set to the `Doc.tensor` attribute. model. Vectors are set to the `Doc.tensor` attribute.
docs (Doc or iterable): One or more documents to add vectors to. docs (Doc or iterable): One or more documents to add vectors to.
RETURNS (dict or None): Intermediate computations. RETURNS (dict or None): Intermediate computations.
""" """
if isinstance(docs, Doc): tokvecses = self.predict([doc])
docs = [docs] self.set_annotations([doc], tokvecses)
tokvecses = self.predict(docs) return doc
self.set_annotations(docs, tokvecses)
def pipe(self, stream, batch_size=128, n_threads=-1): def pipe(self, stream, batch_size=128, n_threads=-1):
"""Process `Doc` objects as a stream. """Process `Doc` objects as a stream.
@ -169,6 +168,7 @@ class NeuralTagger(object):
def __call__(self, doc): def __call__(self, doc):
tags = self.predict([doc.tensor]) tags = self.predict([doc.tensor])
self.set_annotations([doc], tags) self.set_annotations([doc], tags)
return doc
def pipe(self, stream, batch_size=128, n_threads=-1): def pipe(self, stream, batch_size=128, n_threads=-1):
for docs in cytoolz.partition_all(batch_size, stream): for docs in cytoolz.partition_all(batch_size, stream):

View File

@ -305,8 +305,9 @@ cdef class Parser:
Returns: Returns:
None None
""" """
states = self.parse_batch([doc], doc.tensor) states = self.parse_batch([doc], [doc.tensor])
self.set_annotations(doc, states[0]) self.set_annotations([doc], states)
return doc
def pipe(self, docs, int batch_size=1000, int n_threads=2): def pipe(self, docs, int batch_size=1000, int n_threads=2):
""" """