* Add a .pipe method, that takes a stream of input, operates on it, and streams the output. Internally, the stream may be buffered, to allow multi-threading.

This commit is contained in:
Matthew Honnibal 2016-02-03 02:04:55 +01:00
parent fcfc17a164
commit 84b247ef83
4 changed files with 53 additions and 24 deletions

View File

@ -269,16 +269,24 @@ class Language(object):
self.entity(tokens)
return tokens
def batch(self, texts, tag=True, parse=True, entity=True):
if tag is False:
return [self(text, tag=tag, parse=parse, entity=entity)
for text in texts]
docs = []
for text in texts:
docs.append(self(text, tag=True, parse=False, entity=entity))
def pipe(self, texts, tag=True, parse=True, entity=True, n_threads=2,
batch_size=1000):
stream = self.tokenizer.stream(texts,
n_threads=n_threads, batch_size=batch_size)
if self.tagger and tag:
stream = self.tagger.stream(stream,
n_threads=n_threads, batch_size=batch_size)
if self.matcher and entity:
stream = self.matcher.stream(stream,
n_threads=n_threads, batch_size=batch_size)
if self.parser and parse:
self.parser.parse_batch(docs)
return docs
stream = self.parser.stream(stream,
n_threads=n_threads, batch_size=batch_size)
if self.entity and entity:
stream = self.entity.stream(stream,
n_threads=n_threads, batch_size=batch_size)
for doc in stream:
yield doc
def end_training(self, data_dir=None):
if data_dir is None:

View File

@ -250,6 +250,10 @@ cdef class Matcher:
doc.ents = [(e.label, e.start, e.end) for e in doc.ents] + filtered
return matches
def pipe(self, texts, batch_size=1000, n_threads=2):
for text in texts:
yield self(text)
cdef class PhraseMatcher:
cdef Pool mem
@ -303,6 +307,11 @@ cdef class PhraseMatcher:
doc.merge(*match)
return matches
def pipe(self, stream, batch_size=1000, n_threads=2):
for doc in stream:
self(doc)
yield doc
def accept_match(self, Doc doc, int label, int start, int end):
assert (end - start) < self.max_length
cdef int i, j

View File

@ -114,26 +114,33 @@ cdef class Parser:
# Check for KeyboardInterrupt etc. Untested
PyErr_CheckSignals()
def parse_batch(self, batch):
cdef TokenC** doc_ptr = <TokenC**>calloc(len(batch), sizeof(TokenC*))
cdef int* lengths = <int*>calloc(len(batch), sizeof(int))
def pipe(self, stream, int batch_size=1000, int n_threads=2):
cdef Pool mem = Pool()
cdef TokenC** doc_ptr = <TokenC**>mem.alloc(batch_size, sizeof(TokenC*))
cdef int* lengths = <int*>mem.alloc(batch_size, sizeof(int))
cdef Doc doc
cdef int i
for i, doc in enumerate(batch):
doc_ptr[i] = doc.c
lengths[i] = doc.length
cdef int nr_class = self.moves.n_moves
cdef int nr_feat = self.model.nr_feat
cdef int nr_doc = len(batch)
with nogil:
for i in range(nr_doc):
queue = []
for doc in stream:
queue.append(doc)
doc_ptr[len(queue)] = doc.c
lengths[len(queue)] = doc.length
if len(queue) == batch_size:
for i in cython.parallel.prange(batch_size, nogil=True,
num_threads=n_threads):
self.parseC(doc_ptr[i], lengths[i], nr_feat, nr_class)
for doc in batch:
doc.is_parsed = True
# Check for KeyboardInterrupt etc. Untested
PyErr_CheckSignals()
free(doc_ptr)
free(lengths)
for doc in queue:
yield doc
queue = []
batch_size = len(queue)
for i in range(batch_size):
self.parseC(doc_ptr[i], lengths[i], nr_feat, nr_class)
for doc in queue:
yield doc
PyErr_CheckSignals()
cdef void parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) nogil:
cdef ExampleC eg

View File

@ -213,6 +213,11 @@ cdef class Tagger:
tokens.is_tagged = True
tokens._py_tokens = [None] * tokens.length
def pipe(self, stream, batch_size=1000, n_threads=2):
for doc in stream:
self(doc)
yield doc
def train(self, Doc tokens, object gold_tag_strs):
assert len(tokens) == len(gold_tag_strs)
for tag in gold_tag_strs: