Remove the state argument from Language

This commit is contained in:
Matthew Honnibal 2017-05-19 13:25:42 -05:00
parent 09a877886b
commit 66ea9aebe7

View File

@ -145,7 +145,7 @@ class Language(object):
else: else:
self.pipeline = [] self.pipeline = []
def __call__(self, text, state=None, **disabled): def __call__(self, text, **disabled):
""" """
Apply the pipeline to some text. The text can span multiple sentences, Apply the pipeline to some text. The text can span multiple sentences,
and can contain arbitrary whitespace. Alignment into the original string and can contain arbitrary whitespace. Alignment into the original string
@ -153,7 +153,6 @@ class Language(object):
Args: Args:
text (unicode): The text to be processed. text (unicode): The text to be processed.
state: Arbitrary
Returns: Returns:
doc (Doc): A container for accessing the annotations. doc (Doc): A container for accessing the annotations.
@ -170,31 +169,28 @@ class Language(object):
name = getattr(proc, 'name', None) name = getattr(proc, 'name', None)
if name in disabled and not disabled[name]: if name in disabled and not disabled[name]:
continue continue
state = proc(doc, state=state) proc(doc)
return doc return doc
def update(self, docs, golds, state=None, drop=0., sgd=None): def update(self, docs, golds, drop=0., sgd=None):
grads = {} grads = {}
def get_grads(W, dW, key=None): def get_grads(W, dW, key=None):
grads[key] = (W, dW) grads[key] = (W, dW)
state = {} if state is None else state tok2vec = self.pipeline[0]
for process in self.pipeline: feats = tok2vec.doc2feats(docs)
if hasattr(process, 'update'): for proc in self.pipeline[1:]:
state = process.update(docs, golds, tokvecs, bp_tokvecs = tok2vec.model.begin_update(feats, drop=drop)
state=state, grads = {}
drop=drop, d_tokvecs = proc.update((docs, tokvecs), golds, sgd=get_grads, drop=drop)
sgd=get_grads) bp_tokvecs(d_tokvecs, sgd=get_grads)
else: if sgd is not None:
process(docs, state=state) for key, (W, dW) in grads.items():
if sgd is not None: # TODO: Unhack this when thinc improves
for key, (W, dW) in grads.items(): if isinstance(W, numpy.ndarray):
# TODO: Unhack this when thinc improves sgd.ops = NumpyOps()
if isinstance(W, numpy.ndarray): else:
sgd.ops = NumpyOps() sgd.ops = CupyOps()
else: sgd(W, dW, key=key)
sgd.ops = CupyOps()
sgd(W, dW, key=key)
return state
@contextmanager @contextmanager
def begin_training(self, gold_tuples, **cfg): def begin_training(self, gold_tuples, **cfg):
@ -248,18 +244,18 @@ class Language(object):
parse (bool) parse (bool)
entity (bool) entity (bool)
""" """
#stream = ((self.make_doc(text), None) for text in texts) #docs = (self.make_doc(text) for text in texts)
stream = ((doc, {}) for doc in texts) docs = texts
for proc in self.pipeline: for proc in self.pipeline:
name = getattr(proc, 'name', None) name = getattr(proc, 'name', None)
if name in disabled and not disabled[name]: if name in disabled and not disabled[name]:
continue continue
if hasattr(proc, 'pipe'): if hasattr(proc, 'pipe'):
stream = proc.pipe(stream, n_threads=n_threads, batch_size=batch_size) docs = proc.pipe(docs, n_threads=n_threads, batch_size=batch_size)
else: else:
stream = (proc(doc, state) for doc, state in stream) docs = (proc(doc) for doc in docs)
for doc, state in stream: for doc in docs:
yield doc yield doc
def to_disk(self, path, **exclude): def to_disk(self, path, **exclude):