* Change Language class to use a .pipeline attribute, instead of having the pipeline hard coded

This commit is contained in:
Matthew Honnibal 2016-05-17 16:55:42 +02:00
parent 17137f5c0c
commit 4d7f5468bb

View File

@ -192,6 +192,13 @@ class Language(object):
if matcher in (None, True): if matcher in (None, True):
matcher = Matcher.from_package(package, self.vocab) matcher = Matcher.from_package(package, self.vocab)
self.matcher = matcher self.matcher = matcher
self.pipeline = [
self.tokenizer,
self.tagger,
self.entity,
self.parser,
self.matcher
]
def __reduce__(self): def __reduce__(self):
args = ( args = (
@ -222,37 +229,29 @@ class Language(object):
>>> tokens[0].orth_, tokens[0].head.tag_ >>> tokens[0].orth_, tokens[0].head.tag_
('An', 'NN') ('An', 'NN')
""" """
tokens = self.tokenizer(text) doc = self.pipeline[0](text)
if self.tagger and tag:
self.tagger(tokens)
if self.matcher and entity:
self.matcher(tokens)
if self.parser and parse:
self.parser(tokens)
if self.entity and entity: if self.entity and entity:
# Add any of the entity labels already set, in case we don't have them. # Add any of the entity labels already set, in case we don't have them.
for tok in tokens: for token in doc:
if tok.ent_type != 0: if token.ent_type != 0:
self.entity.add_label(tok.ent_type) self.entity.add_label(token.ent_type)
self.entity(tokens) skip = {self.tagger: not tag, self.parser: not parse, self.entity: not entity}
return tokens for proc in self.pipeline[1:]:
if proc and not skip.get(proc):
proc(doc)
return doc
def pipe(self, texts, tag=True, parse=True, entity=True, n_threads=2, def pipe(self, texts, tag=True, parse=True, entity=True, n_threads=2,
batch_size=1000): batch_size=1000):
stream = self.tokenizer.pipe(texts, skip = {self.tagger: not tag, self.parser: not parse, self.entity: not entity}
stream = self.pipeline[0].pipe(texts,
n_threads=n_threads, batch_size=batch_size) n_threads=n_threads, batch_size=batch_size)
if self.tagger and tag: for proc in self.pipeline[1:]:
stream = self.tagger.pipe(stream, if proc and not skip.get(proc):
n_threads=n_threads, batch_size=batch_size) if hasattr(proc, 'pipe'):
if self.matcher and entity: stream = proc.pipe(stream, n_threads=n_threads, batch_size=batch_size)
stream = self.matcher.pipe(stream, else:
n_threads=n_threads, batch_size=batch_size) stream = (proc(item) for item in stream)
if self.parser and parse:
stream = self.parser.pipe(stream,
n_threads=n_threads, batch_size=batch_size)
if self.entity and entity:
stream = self.entity.pipe(stream,
n_threads=1, batch_size=batch_size)
for doc in stream: for doc in stream:
yield doc yield doc