mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-12 01:02:23 +03:00
Fix update when update_shared=False
This commit is contained in:
parent
7a6edeea68
commit
78a5f842e9
|
@ -245,7 +245,7 @@ class Language(object):
|
||||||
def matcher(self):
|
def matcher(self):
|
||||||
return self.get_component('matcher')
|
return self.get_component('matcher')
|
||||||
|
|
||||||
def get_component(self, name):
|
def get_component(self, name):
|
||||||
if self.pipeline in (True, None):
|
if self.pipeline in (True, None):
|
||||||
return None
|
return None
|
||||||
for proc in self.pipeline:
|
for proc in self.pipeline:
|
||||||
|
@ -322,8 +322,8 @@ class Language(object):
|
||||||
all_d_tokvecses[i] += d_tv
|
all_d_tokvecses[i] += d_tv
|
||||||
if update_shared and bp_tokvecses is not None:
|
if update_shared and bp_tokvecses is not None:
|
||||||
bp_tokvecses(all_d_tokvecses, sgd=sgd)
|
bp_tokvecses(all_d_tokvecses, sgd=sgd)
|
||||||
for key, (W, dW) in grads.items():
|
for key, (W, dW) in grads.items():
|
||||||
sgd(W, dW, key=key)
|
sgd(W, dW, key=key)
|
||||||
# Clear the tensor variable, to free GPU memory.
|
# Clear the tensor variable, to free GPU memory.
|
||||||
# If we don't do this, the memory leak gets pretty
|
# If we don't do this, the memory leak gets pretty
|
||||||
# bad, because we may be holding part of a batch.
|
# bad, because we may be holding part of a batch.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user