avoid dropout at runtime (#6247)

This commit is contained in:
Sofie Van Landeghem 2020-10-13 14:39:59 +02:00 committed by GitHub
parent 86d648740f
commit f8a1c1afd6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 4 deletions

View File

@ -1,6 +1,6 @@
# fmt: off # fmt: off
__title__ = "spacy-nightly" __title__ = "spacy-nightly"
__version__ = "3.0.0a40" __version__ = "3.0.0a41"
__download_url__ = "https://github.com/explosion/spacy-models/releases/download" __download_url__ = "https://github.com/explosion/spacy-models/releases/download"
__compatibility__ = "https://raw.githubusercontent.com/explosion/spacy-models/master/compatibility.json" __compatibility__ = "https://raw.githubusercontent.com/explosion/spacy-models/master/compatibility.json"
__projects__ = "https://github.com/explosion/projects" __projects__ = "https://github.com/explosion/projects"

View File

@ -39,7 +39,6 @@ def forward(
key_attr = model.attrs["key_attr"] key_attr = model.attrs["key_attr"]
W = cast(Floats2d, model.ops.as_contig(model.get_param("W"))) W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
V = cast(Floats2d, docs[0].vocab.vectors.data) V = cast(Floats2d, docs[0].vocab.vectors.data)
mask = _get_drop_mask(model.ops, W.shape[0], model.attrs.get("dropout_rate"))
rows = model.ops.flatten( rows = model.ops.flatten(
[doc.vocab.vectors.find(keys=doc.to_array(key_attr)) for doc in docs] [doc.vocab.vectors.find(keys=doc.to_array(key_attr)) for doc in docs]
) )
@ -47,8 +46,11 @@ def forward(
model.ops.gemm(model.ops.as_contig(V[rows]), W, trans2=True), model.ops.gemm(model.ops.as_contig(V[rows]), W, trans2=True),
model.ops.asarray([len(doc) for doc in docs], dtype="i"), model.ops.asarray([len(doc) for doc in docs], dtype="i"),
) )
if mask is not None: mask = None
output.data *= mask if is_train:
mask = _get_drop_mask(model.ops, W.shape[0], model.attrs.get("dropout_rate"))
if mask is not None:
output.data *= mask
def backprop(d_output: Ragged) -> List[Doc]: def backprop(d_output: Ragged) -> List[Doc]:
if mask is not None: if mask is not None: