Remove the state argument from Language

This commit is contained in:
Matthew Honnibal 2017-05-19 13:25:42 -05:00
parent 09a877886b
commit 66ea9aebe7

View File

@ -145,7 +145,7 @@ class Language(object):
else:
self.pipeline = []
def __call__(self, text, state=None, **disabled):
def __call__(self, text, **disabled):
"""
Apply the pipeline to some text. The text can span multiple sentences,
and can contain arbtrary whitespace. Alignment into the original string
@ -153,7 +153,6 @@ class Language(object):
Args:
text (unicode): The text to be processed.
state: Arbitrary
Returns:
doc (Doc): A container for accessing the annotations.
@ -170,22 +169,20 @@ class Language(object):
name = getattr(proc, 'name', None)
if name in disabled and not disabled[name]:
continue
state = proc(doc, state=state)
proc(doc)
return doc
def update(self, docs, golds, state=None, drop=0., sgd=None):
def update(self, docs, golds, drop=0., sgd=None):
grads = {}
def get_grads(W, dW, key=None):
grads[key] = (W, dW)
state = {} if state is None else state
for process in self.pipeline:
if hasattr(process, 'update'):
state = process.update(docs, golds,
state=state,
drop=drop,
sgd=get_grads)
else:
process(docs, state=state)
tok2vec = self.pipeline[0]
feats = tok2vec.doc2feats(docs)
for proc in self.pipeline[1:]:
tokvecs, bp_tokvecs = tok2vec.model.begin_update(feats, drop=drop)
grads = {}
d_tokvecs = proc.update((docs, tokvecs), golds, sgd=get_grads, drop=drop)
bp_tokvecs(d_tokvecs, sgd=get_grads)
if sgd is not None:
for key, (W, dW) in grads.items():
# TODO: Unhack this when thinc improves
@ -194,7 +191,6 @@ class Language(object):
else:
sgd.ops = CupyOps()
sgd(W, dW, key=key)
return state
@contextmanager
def begin_training(self, gold_tuples, **cfg):
@ -248,18 +244,18 @@ class Language(object):
parse (bool)
entity (bool)
"""
#stream = ((self.make_doc(text), None) for text in texts)
stream = ((doc, {}) for doc in texts)
#docs = (self.make_doc(text) for text in texts)
docs = texts
for proc in self.pipeline:
name = getattr(proc, 'name', None)
if name in disabled and not disabled[name]:
continue
if hasattr(proc, 'pipe'):
stream = proc.pipe(stream, n_threads=n_threads, batch_size=batch_size)
docs = proc.pipe(docs, n_threads=n_threads, batch_size=batch_size)
else:
stream = (proc(doc, state) for doc, state in stream)
for doc, state in stream:
docs = (proc(doc) for doc in docs)
for doc in docs:
yield doc
def to_disk(self, path, **exclude):