mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-27 10:26:35 +03:00
fb0da3e097
* Support custom token/lexeme attribute for vectors * Fix imports * Back off to ORTH without Vectors.attr * Fallback if vectors.attr doesn't exist * Update docs
123 lines
3.9 KiB
Python
123 lines
3.9 KiB
Python
import warnings
|
|
from typing import Callable, List, Optional, Sequence, Tuple, cast
|
|
|
|
from thinc.api import Model, Ops, registry
|
|
from thinc.initializers import glorot_uniform_init
|
|
from thinc.types import Floats1d, Floats2d, Ints1d, Ragged
|
|
from thinc.util import partial
|
|
|
|
from ..attrs import ORTH
|
|
from ..errors import Errors, Warnings
|
|
from ..tokens import Doc
|
|
from ..vectors import Mode
|
|
from ..vocab import Vocab
|
|
|
|
|
|
@registry.layers("spacy.StaticVectors.v2")
def StaticVectors(
    nO: Optional[int] = None,
    nM: Optional[int] = None,
    *,
    dropout: Optional[float] = None,
    init_W: Callable = glorot_uniform_init,
    key_attr: str = "ORTH"
) -> Model[List[Doc], Ragged]:
    """Build the "static_vectors" layer, which embeds Doc objects using the
    vectors table of their vocab and a learned linear projection "W".

    nO: Output width of the projection; may be left unset until init.
    nM: Width of the vectors table; may be left unset until init.
    dropout: Optional dropout rate, stored as the "dropout_rate" attr. When
        set, dropout is applied per dimension over the whole batch.
    init_W: Initializer called as init_W(ops, (nO, nM)) to create "W".
    key_attr: Stored as the "key_attr" attr. Any value other than "ORTH"
        emits a DeprecationWarning (Warnings.W125).
    """
    uses_default_key = key_attr == "ORTH"
    if not uses_default_key:
        warnings.warn(Warnings.W125, DeprecationWarning)
    layer_attrs = {"key_attr": key_attr, "dropout_rate": dropout}
    layer_dims = {"nO": nO, "nM": nM}
    # "W" is registered empty here; init() allocates it once both dims are known.
    return Model(
        "static_vectors",
        forward,
        init=partial(init, init_W),
        params={"W": None},
        attrs=layer_attrs,
        dims=layer_dims,
    )
|
|
|
|
|
|
def forward(
    model: Model[List[Doc], Ragged], docs: List[Doc], is_train: bool
) -> Tuple[Ragged, Callable]:
    """Look up a static vector for every token and project it with "W".

    model: The layer. Reads the "W" param, the "nO" dim and the
        "dropout_rate" attr.
    docs: The batch to embed. The vectors table is taken from the vocab of
        the first doc.
    is_train: When true and a dropout rate is configured, one dropout mask
        over the output dimensions is applied to every row of the batch.

    Returns the projected vectors as a Ragged (one row per token, one length
    per doc) together with a backprop callback. The callback accumulates the
    gradient of "W" and returns an empty list, since nothing is propagated
    back to the docs.

    Raises RuntimeError (Errors.E896) if the vectors mode is not supported
    or if the matrix product fails with a ValueError.
    """
    token_count = sum(len(doc) for doc in docs)
    if not token_count:
        return _handle_empty(model.ops, model.get_dim("nO"))
    vocab: Vocab = docs[0].vocab
    # Fall back to ORTH when the vectors table does not expose an `attr`.
    key_attr: int = getattr(vocab.vectors, "attr", ORTH)
    keys = model.ops.flatten([cast(Ints1d, doc.to_array(key_attr)) for doc in docs])
    W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
    missing = None
    if vocab.vectors.mode == Mode.default:
        V = model.ops.asarray(vocab.vectors.data)
        rows = vocab.vectors.find(keys=keys)
        # Indexing with an array of rows gathers into a new array, so the
        # in-place write below does not touch the vocab's vectors table.
        V = model.ops.as_contig(V[rows])
        # Convert negative indices to 0-vectors
        # TODO: more options for UNK tokens
        # A negative row would otherwise wrap around to the end of the table
        # and that unrelated vector would leak into the gradient of W during
        # backprop, even though the output for such tokens is forced to zero.
        missing = rows < 0
        V[missing] = 0
    elif vocab.vectors.mode == Mode.floret:
        V = vocab.vectors.get_batch(keys)
        V = model.ops.as_contig(V)
    else:
        raise RuntimeError(Errors.E896)
    try:
        vectors_data = model.ops.gemm(V, W, trans2=True)
    except ValueError as err:
        raise RuntimeError(Errors.E896) from err
    if missing is not None:
        # Zero input rows already project to zero; writing the zeros
        # explicitly keeps the output exact even if W holds non-finite values.
        vectors_data[missing] = 0
    output = Ragged(vectors_data, model.ops.asarray1i([len(doc) for doc in docs]))
    mask = None
    if is_train:
        mask = _get_drop_mask(model.ops, W.shape[0], model.attrs.get("dropout_rate"))
        if mask is not None:
            output.data *= mask

    def backprop(d_output: Ragged) -> List[Doc]:
        d_data = cast(Floats2d, d_output.data)
        if mask is not None:
            # Multiply out of place: the incoming gradient array belongs to
            # the caller and may be shared with other layers.
            d_data = d_data * mask
        model.inc_grad(
            "W",
            model.ops.gemm(
                d_data,
                cast(Floats2d, model.ops.as_contig(V)),
                trans1=True,
            ),
        )
        return []

    return output, backprop
|
|
|
|
|
|
def init(
    init_W: Callable,
    model: Model[List[Doc], Ragged],
    X: Optional[List[Doc]] = None,
    Y: Optional[Ragged] = None,
) -> Model[List[Doc], Ragged]:
    """Resolve the "nM" and "nO" dims and allocate the projection "W".

    init_W: Initializer called as init_W(model.ops, (nO, nM)).
    model: The layer to initialize; it is modified in place and returned.
    X: Optional sample docs. If non-empty, "nM" is taken from the second
        axis of the vectors table of the first doc's vocab.
    Y: Optional sample output. If given, "nO" is taken from the second axis
        of its data.

    Sample data takes precedence over dims already set on the model.
    Raises ValueError with Errors.E905 if "nM" cannot be determined, or with
    Errors.E904 if "nO" cannot be determined.
    """
    if X is not None and len(X) > 0:
        vector_width = X[0].vocab.vectors.shape[1]
    elif model.has_dim("nM"):
        vector_width = model.get_dim("nM")
    else:
        vector_width = None

    if Y is not None:
        output_width = Y.data.shape[1]
    elif model.has_dim("nO"):
        output_width = model.get_dim("nO")
    else:
        output_width = None

    # The vector width is validated first, then the output width.
    if vector_width is None:
        raise ValueError(Errors.E905)
    if output_width is None:
        raise ValueError(Errors.E904)

    model.set_dim("nM", vector_width)
    model.set_dim("nO", output_width)
    weights = init_W(model.ops, (output_width, vector_width))
    model.set_param("W", weights)
    return model
|
|
|
|
|
|
def _handle_empty(ops: Ops, nO: int):
    """Return the forward result for a batch without tokens: a Ragged with
    zero rows of width nO and no lengths, plus a backprop callback that
    ignores its argument and returns an empty list.
    """

    def backprop(d_ragged):
        return []

    empty_data = ops.alloc2f(0, nO)
    empty_lengths = ops.alloc1i(0)
    return Ragged(empty_data, empty_lengths), backprop
|
|
|
|
|
|
def _get_drop_mask(ops: Ops, nO: int, rate: Optional[float]) -> Optional[Floats1d]:
    """Return a dropout mask of shape (nO,) created by ops for the given
    rate, or None when no rate is configured.
    """
    if rate is None:
        return None
    return ops.get_dropout_mask((nO,), rate)  # type: ignore
|