Break the tokenization stage out of the pipeline into a function 'make_doc'. This allows all pipeline methods to have the same signature.

Matthew Honnibal 2016-10-14 17:38:29 +02:00
parent 2cc515b2ed
commit 6d8cb515ac
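
The change in one line: tokenization moves out of self.pipeline and behind a
single self.make_doc(text) call, so every component left in the pipeline
shares the same signature, taking a Doc and modifying it in place. A minimal
sketch of the resulting calling convention (illustrative, not an excerpt
from the patch; apply_pipeline is a made-up name):

    def apply_pipeline(nlp, text):
        # Tokenization now happens outside the pipeline proper.
        doc = nlp.make_doc(text)
        # Every remaining stage (tagger, parser, entity recognizer, ...)
        # has the same shape: it takes the Doc and mutates it in place.
        for proc in nlp.pipeline:
            if proc:
                proc(doc)
        return doc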

spacy/language.py

@@ -35,7 +35,6 @@ from .syntax.parser import get_templates
 from .syntax.nonproj import PseudoProjectivity
 
 
 class BaseDefaults(object):
     def __init__(self, lang, path):
         self.path = path
@@ -125,8 +124,11 @@ class BaseDefaults(object):
         else:
             return Matcher(vocab)
 
+    def MakeDoc(self, nlp, **cfg):
+        return nlp.tokenizer.__call__
+
     def Pipeline(self, nlp, **cfg):
-        pipeline = [nlp.tokenizer]
+        pipeline = []
         if nlp.tagger:
             pipeline.append(nlp.tagger)
         if nlp.parser:
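
Note that MakeDoc returns the tokenizer's bound __call__, so the default
make_doc is simply the tokenizer applied as a plain text -> Doc function.
Roughly (a sketch assuming the spaCy 1.x English entry point and installed
model data):

    from spacy.en import English

    nlp = English()
    make_doc = nlp.tokenizer.__call__   # what defaults.MakeDoc(nlp) hands back
    text = u'Break the tokenizer out of the pipeline.'
    doc = make_doc(text)
    assert [t.orth_ for t in doc] == [t.orth_ for t in nlp.tokenizer(text)]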
@@ -265,6 +267,7 @@ class Language(object):
             matcher=True,
             serializer=True,
             vectors=True,
+            make_doc=True,
             pipeline=True,
             defaults=True,
             data_dir=None):
@@ -303,6 +306,11 @@ class Language(object):
         self.entity = entity if entity is not True else defaults.Entity(self.vocab)
         self.parser = parser if parser is not True else defaults.Parser(self.vocab)
         self.matcher = matcher if matcher is not True else defaults.Matcher(self.vocab)
+        if make_doc in (None, True, False):
+            self.make_doc = defaults.MakeDoc(self)
+        else:
+            self.make_doc = make_doc
         if pipeline in (None, False):
             self.pipeline = []
         elif pipeline is True:
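
Because any value other than None/True/False is stored verbatim, a custom
text -> Doc callable can replace the default tokenization wholesale. A
sketch, assuming the Doc(vocab, words=...) constructor available in spaCy
1.x (whitespace_doc is a made-up example):

    from spacy.en import English
    from spacy.tokens import Doc

    nlp = English()

    def whitespace_doc(text):
        # Illustrative replacement: split on whitespace instead of
        # running the rule-based tokenizer.
        return Doc(nlp.vocab, words=text.split())

    nlp.make_doc = whitespace_doc
    doc = nlp(u'one two three')   # Doc now built by whitespace_doc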
@@ -339,24 +347,22 @@ class Language(object):
         >>> tokens[0].orth_, tokens[0].head.tag_
         ('An', 'NN')
         """
-        doc = self.pipeline[0](text)
+        doc = self.make_doc(text)
         if self.entity and entity:
             # Add any of the entity labels already set, in case we don't have them.
             for token in doc:
                 if token.ent_type != 0:
                     self.entity.add_label(token.ent_type)
         skip = {self.tagger: not tag, self.parser: not parse, self.entity: not entity}
-        for proc in self.pipeline[1:]:
+        for proc in self.pipeline:
             if proc and not skip.get(proc):
                 proc(doc)
         return doc
 
-    def pipe(self, texts, tag=True, parse=True, entity=True, n_threads=2,
-             batch_size=1000):
+    def pipe(self, texts, tag=True, parse=True, entity=True, n_threads=2, batch_size=1000):
         skip = {self.tagger: not tag, self.parser: not parse, self.entity: not entity}
-        stream = self.pipeline[0].pipe(texts,
-                                       n_threads=n_threads, batch_size=batch_size)
-        for proc in self.pipeline[1:]:
+        stream = (self.make_doc(text) for text in texts)
+        for proc in self.pipeline:
             if proc and not skip.get(proc):
                 if hasattr(proc, 'pipe'):
                     stream = proc.pipe(stream, n_threads=n_threads, batch_size=batch_size)
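
pipe() gets the same treatment: rather than asking pipeline[0] (formerly the
tokenizer) to produce the stream, a generator over make_doc feeds whichever
components can stream. A condensed sketch of the flow; the branch for
components without a .pipe() method is truncated in the diff above, so its
shape here is an assumption (stream_texts and _apply are made-up names):

    def _apply(proc, docs):
        # Apply a non-streaming component doc-by-doc, in place. Binding
        # proc via a function call avoids generator late-binding bugs.
        for doc in docs:
            proc(doc)
            yield doc

    def stream_texts(nlp, texts, n_threads=2, batch_size=1000):
        stream = (nlp.make_doc(text) for text in texts)
        for proc in nlp.pipeline:
            if hasattr(proc, 'pipe'):
                # Streaming components wrap the generator themselves.
                stream = proc.pipe(stream, n_threads=n_threads,
                                   batch_size=batch_size)
            else:
                stream = _apply(proc, stream)
        for doc in stream:
            yield doc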