spaCy/spacy/ml/staticvectors.py
Adriane Boyd 0fe43f40f1
Support registered vectors (#12492)
* Support registered vectors

* Format

* Auto-fill [nlp] on load from config and from bytes/disk

* Only auto-fill [nlp]

* Undo all changes to Language.from_disk

* Expand BaseVectors

These methods are needed in various places for training and vector
similarity.

* isort

* More linting

* Only fill [nlp.vectors]

* Update spacy/vocab.pyx

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Revert changes to test related to auto-filling [nlp]

* Add vectors registry

* Rephrase error about vocab methods for vectors

* Switch to dummy implementation for BaseVectors.to_ops

* Add initial draft of docs

* Remove example from BaseVectors docs

* Apply suggestions from code review

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Update website/docs/api/basevectors.mdx

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Fix type and lint bpemb example

* Update website/docs/api/basevectors.mdx

---------

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
2023-08-01 15:46:08 +02:00

126 lines
4.1 KiB
Python

import warnings
from typing import Callable, List, Optional, Sequence, Tuple, cast
from thinc.api import Model, Ops, registry
from thinc.initializers import glorot_uniform_init
from thinc.types import Floats1d, Floats2d, Ints1d, Ragged
from thinc.util import partial
from ..attrs import ORTH
from ..errors import Errors, Warnings
from ..tokens import Doc
from ..vectors import Mode, Vectors
from ..vocab import Vocab
@registry.layers("spacy.StaticVectors.v2")
def StaticVectors(
    nO: Optional[int] = None,
    nM: Optional[int] = None,
    *,
    dropout: Optional[float] = None,
    init_W: Callable = glorot_uniform_init,
    key_attr: str = "ORTH"
) -> Model[List[Doc], Ragged]:
    """Embed Doc objects with their vocab's vectors table, applying a learned
    linear projection to control the dimensionality. If a dropout rate is
    specified, the dropout is applied per dimension over the whole batch.

    nO (Optional[int]): Output width of the projection; may be left unset and
        inferred when the model is initialized.
    nM (Optional[int]): Width of the static vectors; may be left unset and
        inferred from the vocab's vectors table at initialization.
    dropout (Optional[float]): Per-dimension dropout rate used in training,
        or None to disable dropout.
    init_W (Callable): Initializer for the (nO, nM) projection matrix "W".
    key_attr (str): Deprecated. Any value other than "ORTH" emits warning
        W125; the forward pass reads the lookup attribute from the vectors
        table itself.
    RETURNS (Model[List[Doc], Ragged]): The "static_vectors" layer.
    """
    if key_attr != "ORTH":
        warnings.warn(Warnings.W125, DeprecationWarning)
    # Stored on the layer so forward/init can read them back later.
    layer_attrs = {"key_attr": key_attr, "dropout_rate": dropout}
    layer_dims = {"nO": nO, "nM": nM}
    return Model(
        "static_vectors",
        forward,
        init=partial(init, init_W),
        params={"W": None},
        attrs=layer_attrs,
        dims=layer_dims,
    )
def forward(
    model: Model[List[Doc], Ragged], docs: List[Doc], is_train: bool
) -> Tuple[Ragged, Callable]:
    """Look up each token's static vector and project it with the weights W.

    The lookup key per token is the vectors table's `attr` (ORTH if the table
    has none). The output is V @ W.T, one row per token, packed into a Ragged
    with one segment per Doc. During training an optional per-dimension
    dropout mask (from the "dropout_rate" attr) is applied to the output and,
    in backprop, to the incoming gradient.

    model (Model[List[Doc], Ragged]): The static_vectors layer.
    docs (List[Doc]): The batch to embed; all Docs are assumed to share the
        vocab of docs[0].
    is_train (bool): Whether to apply dropout.
    RETURNS (Tuple[Ragged, Callable]): The projected vectors and a backprop
        callback that accumulates the gradient for W and returns [].
    RAISES RuntimeError: Errors.E896 if the vectors table cannot be used for
        a batch lookup or its width is incompatible with W.
    """
    token_count = sum(len(doc) for doc in docs)
    if not token_count:
        return _handle_empty(model.ops, model.get_dim("nO"))
    vocab: Vocab = docs[0].vocab
    vectors = vocab.vectors
    key_attr: int = getattr(vectors, "attr", ORTH)
    keys = model.ops.flatten([cast(Ints1d, doc.to_array(key_attr)) for doc in docs])
    W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
    if isinstance(vectors, Vectors) and vectors.mode == Mode.default:
        table = model.ops.asarray(vectors.data)
        rows = vectors.find(keys=keys)
        # Fancy indexing makes a copy, so the in-place zeroing below never
        # touches the shared vectors table.
        V = model.ops.as_contig(table[rows])
        # Keys missing from the table get a negative row (-1), which would
        # otherwise select the table's last vector. Zero them in V itself, not
        # only in the output, so that backprop's d_output.T @ V doesn't push a
        # spurious gradient into W for unknown tokens. Their output rows then
        # come out as 0-vectors automatically.
        # TODO: more options for UNK tokens
        V[rows < 0] = 0
    elif hasattr(vectors, "get_batch"):
        # Floret-mode Vectors and registered vectors implementations compute
        # the batch (e.g. from subwords) themselves.
        V = model.ops.as_contig(vectors.get_batch(keys))
    else:
        raise RuntimeError(Errors.E896)
    try:
        vectors_data = model.ops.gemm(V, W, trans2=True)
    except ValueError as err:
        # Presumably raised on a shape mismatch between V and W (vector width
        # differing from nM).
        raise RuntimeError(Errors.E896) from err
    lengths = model.ops.asarray1i([len(doc) for doc in docs])
    output = Ragged(vectors_data, lengths)
    mask = None
    if is_train:
        mask = _get_drop_mask(model.ops, W.shape[0], model.attrs.get("dropout_rate"))
        if mask is not None:
            output.data *= mask

    def backprop(d_output: Ragged) -> List[Doc]:
        # Only W receives a gradient; the input Docs get none (empty list).
        if mask is not None:
            d_output.data *= mask
        model.inc_grad(
            "W",
            model.ops.gemm(
                cast(Floats2d, d_output.data),
                cast(Floats2d, model.ops.as_contig(V)),
                trans1=True,
            ),
        )
        return []

    return output, backprop
def init(
    init_W: Callable,
    model: Model[List[Doc], Ragged],
    X: Optional[List[Doc]] = None,
    Y: Optional[Ragged] = None,
) -> Model[List[Doc], Ragged]:
    """Resolve the layer's nM/nO dimensions and allocate the weights W.

    nM is taken from the first sample Doc's vocab vectors table (its second
    shape dimension) and nO from the sample output's width. When no sample
    is given, each falls back to the value already set on the model.

    init_W (Callable): Called as init_W(model.ops, (nO, nM)) to create W.
    model (Model[List[Doc], Ragged]): The layer to initialize.
    X (Optional[List[Doc]]): Sample input, used to infer nM.
    Y (Optional[Ragged]): Sample output, used to infer nO.
    RETURNS (Model[List[Doc], Ragged]): The same model, now initialized.
    RAISES ValueError: Errors.E905 if nM cannot be determined, otherwise
        Errors.E904 if nO cannot be determined.
    """

    def configured(dim: str) -> Optional[int]:
        # Value already set on the model for this dimension, or None if unset.
        return model.get_dim(dim) if model.has_dim(dim) else None

    if X is not None and len(X):
        width_in = X[0].vocab.vectors.shape[1]
    else:
        width_in = configured("nM")
    width_out = Y.data.shape[1] if Y is not None else configured("nO")

    if width_in is None:
        raise ValueError(Errors.E905)
    if width_out is None:
        raise ValueError(Errors.E904)

    model.set_dim("nM", width_in)
    model.set_dim("nO", width_out)
    model.set_param("W", init_W(model.ops, (width_out, width_in)))
    return model
def _handle_empty(ops: Ops, nO: int):
    """Build the forward result for a batch containing no tokens.

    Returns a Ragged with a (0, nO) data array and zero lengths, paired with
    a backprop callback that ignores its argument and returns an empty list.
    """

    def backprop(d_ragged):
        return []

    empty_data = ops.alloc2f(0, nO)
    empty_lengths = ops.alloc1i(0)
    return Ragged(empty_data, empty_lengths), backprop
def _get_drop_mask(ops: Ops, nO: int, rate: Optional[float]) -> Optional[Floats1d]:
    """Return a dropout mask of shape (nO,) for `rate`, or None if rate is None."""
    if rate is None:
        return None
    return ops.get_dropout_mask((nO,), rate)  # type: ignore