mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 01:04:34 +03:00
avoid dropout at runtime (#6247)
This commit is contained in:
parent
86d648740f
commit
f8a1c1afd6
|
@ -1,6 +1,6 @@
|
||||||
# fmt: off
|
# fmt: off
|
||||||
__title__ = "spacy-nightly"
|
__title__ = "spacy-nightly"
|
||||||
__version__ = "3.0.0a40"
|
__version__ = "3.0.0a41"
|
||||||
__download_url__ = "https://github.com/explosion/spacy-models/releases/download"
|
__download_url__ = "https://github.com/explosion/spacy-models/releases/download"
|
||||||
__compatibility__ = "https://raw.githubusercontent.com/explosion/spacy-models/master/compatibility.json"
|
__compatibility__ = "https://raw.githubusercontent.com/explosion/spacy-models/master/compatibility.json"
|
||||||
__projects__ = "https://github.com/explosion/projects"
|
__projects__ = "https://github.com/explosion/projects"
|
||||||
|
|
|
@ -39,7 +39,6 @@ def forward(
|
||||||
key_attr = model.attrs["key_attr"]
|
key_attr = model.attrs["key_attr"]
|
||||||
W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
|
W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
|
||||||
V = cast(Floats2d, docs[0].vocab.vectors.data)
|
V = cast(Floats2d, docs[0].vocab.vectors.data)
|
||||||
mask = _get_drop_mask(model.ops, W.shape[0], model.attrs.get("dropout_rate"))
|
|
||||||
rows = model.ops.flatten(
|
rows = model.ops.flatten(
|
||||||
[doc.vocab.vectors.find(keys=doc.to_array(key_attr)) for doc in docs]
|
[doc.vocab.vectors.find(keys=doc.to_array(key_attr)) for doc in docs]
|
||||||
)
|
)
|
||||||
|
@ -47,8 +46,11 @@ def forward(
|
||||||
model.ops.gemm(model.ops.as_contig(V[rows]), W, trans2=True),
|
model.ops.gemm(model.ops.as_contig(V[rows]), W, trans2=True),
|
||||||
model.ops.asarray([len(doc) for doc in docs], dtype="i"),
|
model.ops.asarray([len(doc) for doc in docs], dtype="i"),
|
||||||
)
|
)
|
||||||
if mask is not None:
|
mask = None
|
||||||
output.data *= mask
|
if is_train:
|
||||||
|
mask = _get_drop_mask(model.ops, W.shape[0], model.attrs.get("dropout_rate"))
|
||||||
|
if mask is not None:
|
||||||
|
output.data *= mask
|
||||||
|
|
||||||
def backprop(d_output: Ragged) -> List[Doc]:
|
def backprop(d_output: Ragged) -> List[Doc]:
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user