mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-11 04:08:09 +03:00
Add update_tensors flag to Language.update. Experimental, re #1182
This commit is contained in:
parent
4cfb7a54e7
commit
cc19ea0e7c
|
@ -277,7 +277,8 @@ class Language(object):
|
|||
def make_doc(self, text):
|
||||
return self.tokenizer(text)
|
||||
|
||||
def update(self, docs, golds, drop=0., sgd=None, losses=None):
|
||||
def update(self, docs, golds, drop=0., sgd=None, losses=None,
|
||||
update_tensors=False):
|
||||
"""Update the models in the pipeline.
|
||||
|
||||
docs (iterable): A batch of `Doc` objects.
|
||||
|
@ -310,7 +311,7 @@ class Language(object):
|
|||
tokvecses, bp_tokvecses = tok2vec.model.begin_update(feats, drop=drop)
|
||||
d_tokvecses = proc.update((docs, tokvecses), golds,
|
||||
drop=drop, sgd=get_grads, losses=losses)
|
||||
if d_tokvecses is not None:
|
||||
if update_tensors and d_tokvecses is not None:
|
||||
bp_tokvecses(d_tokvecses, sgd=sgd)
|
||||
for key, (W, dW) in grads.items():
|
||||
sgd(W, dW, key=key)
|
||||
|
|
Loading…
Reference in New Issue
Block a user