Add deprojectivize to pipeline

This commit is contained in:
Matthew Honnibal 2017-05-21 18:43:31 -05:00
parent 1b5fa68996
commit 5738d373d5

View File

@ -93,10 +93,12 @@ class BaseDefaults(object):
factories = { factories = {
'make_doc': create_tokenizer, 'make_doc': create_tokenizer,
'token_vectors': lambda nlp, **cfg: TokenVectorEncoder(nlp.vocab, **cfg), 'token_vectors': lambda nlp, **cfg: [TokenVectorEncoder(nlp.vocab, **cfg)],
'tags': lambda nlp, **cfg: NeuralTagger(nlp.vocab, **cfg), 'tags': lambda nlp, **cfg: [NeuralTagger(nlp.vocab, **cfg)],
'dependencies': lambda nlp, **cfg: NeuralDependencyParser(nlp.vocab, **cfg), 'dependencies': lambda nlp, **cfg: [
'entities': lambda nlp, **cfg: NeuralEntityRecognizer(nlp.vocab, **cfg), NeuralDependencyParser(nlp.vocab, **cfg),
PseudoProjectivity.deprojectivize],
'entities': lambda nlp, **cfg: [NeuralEntityRecognizer(nlp.vocab, **cfg)],
} }
token_match = TOKEN_MATCH token_match = TOKEN_MATCH
@ -162,6 +164,13 @@ class Language(object):
self.pipeline[i] = factory(self, **meta.get(entry, {})) self.pipeline[i] = factory(self, **meta.get(entry, {}))
else: else:
self.pipeline = [] self.pipeline = []
flat_list = []
for pipe in self.pipeline:
if isinstance(pipe, list):
flat_list.extend(pipe)
else:
flat_list.append(pipe)
self.pipeline = flat_list
def __call__(self, text, **disabled): def __call__(self, text, **disabled):
"""'Apply the pipeline to some text. The text can span multiple sentences, """'Apply the pipeline to some text. The text can span multiple sentences,
@ -207,6 +216,8 @@ class Language(object):
tok2vec = self.pipeline[0] tok2vec = self.pipeline[0]
feats = tok2vec.doc2feats(docs) feats = tok2vec.doc2feats(docs)
for proc in self.pipeline[1:]: for proc in self.pipeline[1:]:
if not hasattr(proc, 'update'):
continue
grads = {} grads = {}
tokvecses, bp_tokvecses = tok2vec.model.begin_update(feats, drop=drop) tokvecses, bp_tokvecses = tok2vec.model.begin_update(feats, drop=drop)
d_tokvecses = proc.update((docs, tokvecses), golds, sgd=get_grads, drop=drop) d_tokvecses = proc.update((docs, tokvecses), golds, sgd=get_grads, drop=drop)
@ -326,7 +337,8 @@ class Language(object):
if hasattr(proc, 'pipe'): if hasattr(proc, 'pipe'):
docs = proc.pipe(docs, n_threads=n_threads, batch_size=batch_size) docs = proc.pipe(docs, n_threads=n_threads, batch_size=batch_size)
else: else:
docs = (proc(doc) for doc in docs) # Apply the function, but yield the doc
docs = _pipe(proc, docs)
for doc in docs: for doc in docs:
yield doc yield doc
@ -402,3 +414,8 @@ class Language(object):
if key not in exclude: if key not in exclude:
setattr(self, key, value) setattr(self, key, value)
return self return self
def _pipe(func, docs):
for doc in docs:
func(doc)
yield doc