from typing import List, Tuple, Callable, Optional, Sequence, cast

from thinc.initializers import glorot_uniform_init
from thinc.util import partial
from thinc.types import Ragged, Floats2d, Floats1d, Ints1d
from thinc.api import Model, Ops, registry

from ..tokens import Doc
from ..errors import Errors
from ..vectors import Mode
from ..vocab import Vocab


@registry.layers("spacy.StaticVectors.v2")
def StaticVectors(
    nO: Optional[int] = None,
    nM: Optional[int] = None,
    *,
    dropout: Optional[float] = None,
    init_W: Callable = glorot_uniform_init,
    key_attr: str = "ORTH"
) -> Model[List[Doc], Ragged]:
    """Embed Doc objects with their vocab's vectors table, applying a learned
    linear projection to control the dimensionality. If a dropout rate is
    specified, the dropout is applied per dimension over the whole batch.

    nO (Optional[int]): Output width, i.e. the number of rows of the
        projection matrix W. May be left unset and inferred in `init`.
    nM (Optional[int]): Width of the static vectors, i.e. the number of
        columns of W. May be left unset and inferred in `init`.
    dropout (Optional[float]): Dropout rate; None disables dropout.
    init_W (Callable): Initializer called as `init_W(ops, (nO, nM))`.
    key_attr (str): Token attribute passed to `Doc.to_array` to obtain the
        keys that are looked up in the vectors table.
    RETURNS (Model[List[Doc], Ragged]): The static vectors layer.
    """
    return Model(
        "static_vectors",
        forward,
        init=partial(init, init_W),
        params={"W": None},
        attrs={"key_attr": key_attr, "dropout_rate": dropout},
        dims={"nO": nO, "nM": nM},
    )


def forward(
    model: Model[List[Doc], Ragged], docs: List[Doc], is_train: bool
) -> Tuple[Ragged, Callable]:
    """Look up the static vectors of every token and project them with W.

    model (Model): The layer; provides the "W" param, the "nO" dim and the
        "key_attr" / "dropout_rate" attrs.
    docs (List[Doc]): The batch. The vectors table of the first doc's vocab
        is used for the whole batch.
    is_train (bool): Whether to apply dropout (if a rate is configured).
    RETURNS (Tuple[Ragged, Callable]): One projected row per token plus the
        per-doc lengths, and the backprop callback. The callback accumulates
        the gradient of W and returns an empty list.
    RAISES (RuntimeError): Errors.E896 if the vectors mode is unsupported or
        the projection fails with a ValueError.
    """
    token_count = sum(len(doc) for doc in docs)
    if not token_count:
        return _handle_empty(model.ops, model.get_dim("nO"))
    key_attr: str = model.attrs["key_attr"]
    keys: Ints1d = model.ops.flatten(
        cast(Sequence, [doc.to_array(key_attr) for doc in docs])
    )
    vocab: Vocab = docs[0].vocab
    W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
    # Boolean mask over tokens whose key has no row in the table. Only set in
    # default mode, where lookups go through row indices.
    oov = None
    if vocab.vectors.mode == Mode.default:
        V = cast(Floats2d, model.ops.asarray(vocab.vectors.data))
        rows = vocab.vectors.find(keys=keys)
        # Indexing with an index array yields a copy, so the table itself is
        # never modified below.
        V = model.ops.as_contig(V[rows])
        oov = rows < 0
        # A negative row index silently selects a row from the end of the
        # table. Zero those inputs so that the backward pass does not add a
        # spurious contribution to W's gradient for tokens whose output is
        # forced to zero.
        V[oov] = 0
    elif vocab.vectors.mode == Mode.floret:
        V = cast(Floats2d, vocab.vectors.get_batch(keys))
        V = model.ops.as_contig(V)
    else:
        raise RuntimeError(Errors.E896)
    try:
        vectors_data = model.ops.gemm(V, W, trans2=True)
    except ValueError as err:
        raise RuntimeError(Errors.E896) from err
    if oov is not None:
        # Convert negative indices to 0-vectors
        # TODO: more options for UNK tokens
        vectors_data[oov] = 0
    output = Ragged(
        vectors_data, model.ops.asarray([len(doc) for doc in docs], dtype="i")  # type: ignore
    )
    mask = None
    if is_train:
        mask = _get_drop_mask(model.ops, W.shape[0], model.attrs.get("dropout_rate"))
        if mask is not None:
            output.data *= mask

    def backprop(d_output: Ragged) -> List[Doc]:
        """Accumulate the gradient of W. Note that the dropout mask is applied
        to `d_output.data` in place. Nothing is propagated to the docs."""
        if mask is not None:
            d_output.data *= mask
        model.inc_grad(
            "W",
            model.ops.gemm(
                cast(Floats2d, d_output.data), model.ops.as_contig(V), trans1=True
            ),
        )
        return []

    return output, backprop


def init(
    init_W: Callable,
    model: Model[List[Doc], Ragged],
    X: Optional[List[Doc]] = None,
    Y: Optional[Ragged] = None,
) -> Model[List[Doc], Ragged]:
    """Resolve the nM / nO dimensions and allocate the projection matrix W.

    Dimensions already set on the model are used as a starting point; sample
    data overrides them: nM is read from the vectors table of the first doc
    in X, and nO from the width of Y.

    init_W (Callable): Initializer called as `init_W(ops, (nO, nM))`.
    model (Model): The layer to initialize.
    X (Optional[List[Doc]]): Sample input docs.
    Y (Optional[Ragged]): Sample output.
    RETURNS (Model): The initialized model.
    RAISES (ValueError): Errors.E905 if nM cannot be determined, Errors.E904
        if nO cannot be determined.
    """
    nM = model.get_dim("nM") if model.has_dim("nM") else None
    nO = model.get_dim("nO") if model.has_dim("nO") else None
    if X is not None and len(X):
        nM = X[0].vocab.vectors.data.shape[1]
    if Y is not None:
        nO = Y.data.shape[1]

    if nM is None:
        raise ValueError(Errors.E905)
    if nO is None:
        raise ValueError(Errors.E904)
    model.set_dim("nM", nM)
    model.set_dim("nO", nO)
    model.set_param("W", init_W(model.ops, (nO, nM)))
    return model


def _handle_empty(ops: Ops, nO: int) -> Tuple[Ragged, Callable]:
    """Return a Ragged with zero rows of width nO and zero lengths, together
    with a backprop callback that ignores its input and returns an empty
    list."""
    return Ragged(ops.alloc2f(0, nO), ops.alloc1i(0)), lambda d_ragged: []


def _get_drop_mask(ops: Ops, nO: int, rate: Optional[float]) -> Optional[Floats1d]:
    """Create a dropout mask of shape (nO,), or return None if no rate is
    configured."""
    if rate is not None:
        mask = ops.get_dropout_mask((nO,), rate)
        return mask  # type: ignore
    return None