spaCy/spacy/ml/staticvectors.py

100 lines
3.1 KiB
Python
Raw Normal View History

2020-07-28 13:17:09 +03:00
from typing import List, Tuple, Callable, Optional, cast
from thinc.initializers import glorot_uniform_init
from thinc.util import partial
from thinc.types import Ragged, Floats2d, Floats1d
from thinc.api import Model, Ops, registry
from ..tokens import Doc
2020-10-04 12:16:31 +03:00
from ..errors import Errors
2020-07-28 13:17:09 +03:00
@registry.layers("spacy.StaticVectors.v1")
def StaticVectors(
    nO: Optional[int] = None,
    nM: Optional[int] = None,
    *,
    dropout: Optional[float] = None,
    init_W: Callable = glorot_uniform_init,
    key_attr: str = "ORTH",
) -> Model[List[Doc], Ragged]:
    """Build a layer that embeds Docs with their vocab's static vectors table.

    Each token's vector is passed through a learned linear projection W
    (shape (nO, nM)) so the output width can differ from the table width.
    When a dropout rate is given, one mask over the output dimensions is
    drawn per batch and shared by every token in it.

    nO (Optional[int]): Output width; may be inferred during initialization.
    nM (Optional[int]): Width of the vectors table; may be inferred during
        initialization from the Docs' vocab.
    dropout (Optional[float]): Dropout rate applied while training, or None
        to disable dropout.
    init_W (Callable): Called as init_W(ops, (nO, nM)) to create W.
    key_attr (str): Token attribute whose values are looked up in the
        vectors table.
    RETURNS (Model[List[Doc], Ragged]): The embedding layer.
    """
    initializer = partial(init, init_W)
    layer_attrs = {"key_attr": key_attr, "dropout_rate": dropout}
    layer_dims = {"nO": nO, "nM": nM}
    return Model(
        "static_vectors",
        forward,
        init=initializer,
        params={"W": None},
        attrs=layer_attrs,
        dims=layer_dims,
    )
def forward(
    model: Model[List[Doc], Ragged], docs: List[Doc], is_train: bool
) -> Tuple[Ragged, Callable]:
    """Look up each token's static vector and project it through W.

    docs (List[Doc]): The batch to embed. The vectors table is read from
        docs[0].vocab while row lookups use each doc's own vocab, so all docs
        are presumably expected to share one vocab (NOTE(review): confirm).
    is_train (bool): Whether to apply the dropout mask.
    RETURNS (Tuple[Ragged, Callable]): A Ragged array holding one row of
        width nO per token, segmented by doc length, and the backprop
        callback. Tokens whose key is missing from the vectors table are
        embedded from a zero vector.
    RAISES (RuntimeError): Errors.E896 when the projection gemm raises a
        ValueError (presumably a width mismatch between the table and W).
    """
    if not sum(len(doc) for doc in docs):
        return _handle_empty(model.ops, model.get_dim("nO"))
    key_attr = model.attrs["key_attr"]
    W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
    V = cast(Floats2d, docs[0].vocab.vectors.data)
    rows = model.ops.flatten(
        [doc.vocab.vectors.find(keys=doc.to_array(key_attr)) for doc in docs]
    )
    # Gather every token's vector once; backprop reuses this same array.
    # Integer-array indexing returns a copy, so the write below never
    # modifies the shared vectors table V.
    token_vectors = cast(Floats2d, model.ops.as_contig(V[rows]))
    # Vectors.find reports keys absent from the table as row -1. Indexing
    # with -1 would silently hand those tokens the table's last vector, so
    # zero them instead. Zero rows also contribute nothing to dW.
    token_vectors[rows < 0] = 0
    try:
        vectors_data = model.ops.gemm(token_vectors, W, trans2=True)
    except ValueError as err:
        raise RuntimeError(Errors.E896) from err
    output = Ragged(
        vectors_data, model.ops.asarray([len(doc) for doc in docs], dtype="i")
    )
    mask = None
    if is_train:
        # One mask entry per output column (W.shape[0] == nO), shared by
        # every token in the batch.
        mask = _get_drop_mask(model.ops, W.shape[0], model.attrs.get("dropout_rate"))
        if mask is not None:
            output.data *= mask

    def backprop(d_output: Ragged) -> List[Doc]:
        # Apply the same dropout mask to the gradient. Note that this scales
        # d_output.data in place.
        if mask is not None:
            d_output.data *= mask
        # dL/dW = d_output.T @ token_vectors, shape (nO, nM), matching W.
        model.inc_grad(
            "W",
            model.ops.gemm(d_output.data, token_vectors, trans1=True),
        )
        # The Docs are not differentiable inputs, so there is nothing to return.
        return []

    return output, backprop
def init(
    init_W: Callable,
    model: Model[List[Doc], Ragged],
    X: Optional[List[Doc]] = None,
    Y: Optional[Ragged] = None,
) -> Model[List[Doc], Ragged]:
    """Resolve the nM and nO dimensions and allocate the projection W.

    Dimensions already set on the model are the starting point. A non-empty
    sample X overrides nM with the column count of X[0]'s vectors table, and
    a sample Y overrides nO with its data width.

    init_W (Callable): Called as init_W(model.ops, (nO, nM)) to create W.
    model (Model): The static-vectors layer being initialized.
    X (Optional[List[Doc]]): Optional sample input.
    Y (Optional[Ragged]): Optional sample output.
    RETURNS (Model): The same model, with dims and W set.
    RAISES (ValueError): Errors.E905 if nM cannot be determined, or
        Errors.E904 if nO cannot be determined.
    """

    def current_dim(name: str) -> Optional[int]:
        # Only read the dimension when the model reports it as set.
        return model.get_dim(name) if model.has_dim(name) else None

    vectors_width = current_dim("nM")
    output_width = current_dim("nO")
    if X is not None and len(X) > 0:
        vectors_width = X[0].vocab.vectors.data.shape[1]
    if Y is not None:
        output_width = Y.data.shape[1]

    if vectors_width is None:
        raise ValueError(Errors.E905)
    if output_width is None:
        raise ValueError(Errors.E904)

    model.set_dim("nM", vectors_width)
    model.set_dim("nO", output_width)
    model.set_param("W", init_W(model.ops, (output_width, vectors_width)))
    return model
def _handle_empty(ops: Ops, nO: int):
    """Return a zero-row (0, nO) output and a backprop that yields no gradients.

    Used when the batch contains no tokens at all.
    """
    empty_output = Ragged(ops.alloc2f(0, nO), ops.alloc1i(0))

    def backprop(d_ragged):
        return []

    return empty_output, backprop
def _get_drop_mask(ops: Ops, nO: int, rate: Optional[float]) -> Optional[Floats1d]:
    """Return a length-nO dropout mask for the given rate, or None if rate is None."""
    if rate is None:
        return None
    return ops.get_dropout_mask((nO,), rate)