Support registered vectors

Adriane Boyd 2023-03-31 18:25:55 +02:00
parent 476a2e7a0a
commit d3b8b9162c
7 changed files with 92 additions and 12 deletions

spacy/default_config.cfg

@@ -26,6 +26,9 @@ batch_size = 1000
 
 [nlp.tokenizer]
 @tokenizers = "spacy.Tokenizer.v1"
 
+[nlp.vectors]
+@misc = "spacy.Vectors.v1"
+
 # The pipeline components and their models
 [components]
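With this block in the default config, the vectors implementation is resolved from the misc registry instead of being hard-coded. A minimal sketch of overriding it when creating a pipeline (the override dict and blank pipeline are illustrative, not part of this commit):

import spacy

# "spacy.Vectors.v1" is the built-in factory this commit registers; a
# plugin would substitute its own registered name here.
config = {"nlp": {"vectors": {"@misc": "spacy.Vectors.v1"}}}
nlp = spacy.blank("en", config=config)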

spacy/errors.py

@@ -549,6 +549,7 @@ class Errors(metaclass=ErrorsWithCodes):
              "during training, make sure to include it in 'annotating components'")
 
     # New errors added in v3.x
+    E849 = ("Unable to {action} vectors for vectors of type {vectors_type}.")
     E850 = ("The PretrainVectors objective currently only supports default or "
             "floret vectors, not {mode} vectors.")
     E851 = ("The 'textcat' component labels should only have values of 0 or 1, "

spacy/language.py

@@ -20,6 +20,8 @@ import traceback
 
 from . import ty
 from .tokens.underscore import Underscore
+from .strings import StringStore
+from .vectors import BaseVectors
 from .vocab import Vocab, create_vocab
 from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis
 from .training import Example, validate_examples
@@ -134,6 +136,7 @@ class Language:
         max_length: int = 10**6,
         meta: Dict[str, Any] = {},
         create_tokenizer: Optional[Callable[["Language"], Callable[[str], Doc]]] = None,
+        create_vectors: Optional[Callable[["Vocab"], BaseVectors]] = None,
         batch_size: int = 1000,
         **kwargs,
     ) -> None:
@@ -174,6 +177,10 @@ class Language:
         if vocab is True:
             vectors_name = meta.get("vectors", {}).get("name")
             vocab = create_vocab(self.lang, self.Defaults, vectors_name=vectors_name)
+            if not create_vectors:
+                vectors_cfg = {"vectors": self._config["nlp"]["vectors"]}
+                create_vectors = registry.resolve(vectors_cfg)["vectors"]
+            vocab.vectors = create_vectors(vocab)
         else:
             if (self.lang and vocab.lang) and (self.lang != vocab.lang):
                 raise ValueError(Errors.E150.format(nlp=self.lang, vocab=vocab.lang))
@@ -1750,6 +1757,7 @@ class Language:
             filled["nlp"], validate=validate, schema=ConfigSchemaNlp
         )
         create_tokenizer = resolved_nlp["tokenizer"]
+        create_vectors = resolved_nlp["vectors"]
         before_creation = resolved_nlp["before_creation"]
         after_creation = resolved_nlp["after_creation"]
         after_pipeline_creation = resolved_nlp["after_pipeline_creation"]
@@ -1770,7 +1778,7 @@ class Language:
         # inside stuff like the spacy train function. If we loaded them here,
         # then we would load them twice at runtime: once when we make from config,
         # and then again when we load from disk.
-        nlp = lang_cls(vocab=vocab, create_tokenizer=create_tokenizer, meta=meta)
+        nlp = lang_cls(vocab=vocab, create_tokenizer=create_tokenizer, create_vectors=create_vectors, meta=meta)
         if after_creation is not None:
             nlp = after_creation(nlp)
         if not isinstance(nlp, cls):
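Sketched in isolation, the resolution step added above amounts to the following (assuming the built-in "spacy.Vectors.v1" entry this commit registers):

from spacy.util import registry
from spacy.vocab import Vocab

# Resolve the [nlp.vectors] block into a factory, then call it with the
# vocab, mirroring Language.__init__ when no create_vectors is passed.
vectors_cfg = {"vectors": {"@misc": "spacy.Vectors.v1"}}
create_vectors = registry.resolve(vectors_cfg)["vectors"]
vocab = Vocab()
vocab.vectors = create_vectors(vocab)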

spacy/ml/staticvectors.py

@@ -6,7 +6,7 @@ from thinc.api import Model, Ops, registry
 
 from ..tokens import Doc
 from ..errors import Errors
-from ..vectors import Mode
+from ..vectors import Vectors, Mode
 from ..vocab import Vocab
@@ -43,11 +43,14 @@ def forward(
     keys = model.ops.flatten([cast(Ints1d, doc.to_array(key_attr)) for doc in docs])
     vocab: Vocab = docs[0].vocab
     W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
-    if vocab.vectors.mode == Mode.default:
+    if isinstance(vocab.vectors, Vectors) and vocab.vectors.mode == Mode.default:
         V = model.ops.asarray(vocab.vectors.data)
         rows = vocab.vectors.find(keys=keys)
         V = model.ops.as_contig(V[rows])
-    elif vocab.vectors.mode == Mode.floret:
+    elif isinstance(vocab.vectors, Vectors) and vocab.vectors.mode == Mode.floret:
+        V = vocab.vectors.get_batch(keys)
+        V = model.ops.as_contig(V)
+    elif hasattr(vocab.vectors, "get_batch"):
         V = vocab.vectors.get_batch(keys)
         V = model.ops.as_contig(V)
     else:
@@ -56,7 +59,7 @@ def forward(
         vectors_data = model.ops.gemm(V, W, trans2=True)
     except ValueError:
         raise RuntimeError(Errors.E896)
-    if vocab.vectors.mode == Mode.default:
+    if isinstance(vocab.vectors, Vectors) and vocab.vectors.mode == Mode.default:
         # Convert negative indices to 0-vectors
         # TODO: more options for UNK tokens
         vectors_data[rows < 0] = 0
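The new hasattr branch is the whole contract StaticVectors places on a registered implementation: get_batch(keys) must return one row per hash key. A minimal sketch of a class satisfying it (hypothetical, returns zero vectors only):

import numpy
from spacy.vectors import BaseVectors

class ZeroVectors(BaseVectors):  # hypothetical demo implementation
    def __init__(self, *, strings=None, width=300):
        self.strings = strings
        self.width = width

    @property
    def vectors_length(self):
        return self.width

    def get_batch(self, keys):
        # One row per key; a real implementation would compute or look
        # up embeddings here.
        return numpy.zeros((len(keys), self.width), dtype="f")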

spacy/schemas.py

@@ -375,6 +375,7 @@ class ConfigSchemaNlp(BaseModel):
     after_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after creation and before the pipeline is constructed")
     after_pipeline_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after the pipeline is constructed")
     batch_size: Optional[int] = Field(..., title="Default batch size")
+    vectors: Callable = Field(..., title="Vectors implementation")
     # fmt: on
 
     class Config:

spacy/vectors.pyx

@@ -1,3 +1,5 @@
+# cython: infer_types=True, profile=True, binding=True
+from typing import Callable
 cimport numpy as np
 from libc.stdint cimport uint32_t, uint64_t
 from cython.operator cimport dereference as deref
@@ -6,7 +8,8 @@ from murmurhash.mrmr cimport hash128_x64
 
 import functools
 import numpy
-from typing import cast
+from pathlib import Path
+from typing import cast, TYPE_CHECKING, Union
 import warnings
 from enum import Enum
 import srsly
@@ -21,6 +24,10 @@ from .errors import Errors, Warnings
 from . import util
 
+if TYPE_CHECKING:
+    from .vocab import Vocab  # noqa: F401
+
 
 def unpickle_vectors(bytes_data):
     return Vectors().from_bytes(bytes_data)
@@ -34,7 +41,51 @@ class Mode(str, Enum):
         return list(cls.__members__.keys())
 
 
-cdef class Vectors:
+cdef class BaseVectors:
+    def __init__(self, *, strings=None):
+        # Make sure abstract BaseVectors is not instantiated.
+        if self.__class__ == BaseVectors:
+            raise TypeError(
+                Errors.E1046.format(cls_name=self.__class__.__name__)
+            )
+
+    def __getitem__(self, key):
+        raise NotImplementedError
+
+    def get_batch(self, keys):
+        raise NotImplementedError
+
+    @property
+    def vectors_length(self):
+        raise NotImplementedError
+
+    def add(self, key, *, vector=None):
+        raise NotImplementedError
+
+    # add dummy methods for to_bytes, from_bytes, to_disk and from_disk to
+    # allow serialization
+    def to_bytes(self, **kwargs):
+        return b""
+
+    def from_bytes(self, data: bytes, **kwargs):
+        return self
+
+    def to_disk(self, path: Union[str, Path], **kwargs):
+        return None
+
+    def from_disk(self, path: Union[str, Path], **kwargs):
+        return self
+
+
+@util.registry.misc("spacy.Vectors.v1")
+def create_mode_vectors() -> Callable[["Vocab"], BaseVectors]:
+    def vectors_factory(vocab: "Vocab") -> BaseVectors:
+        return Vectors(strings=vocab.strings)
+
+    return vectors_factory
+
+
+cdef class Vectors(BaseVectors):
     """Store, save and load word vectors.
 
     Vectors data is kept in the vectors.data attribute, which should be an
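A third-party package would mirror create_mode_vectors with its own entry in the misc registry; a sketch with a hypothetical name (a real plugin would return its own BaseVectors subclass rather than the default Vectors):

from typing import Callable

from spacy.util import registry
from spacy.vectors import BaseVectors, Vectors
from spacy.vocab import Vocab

@registry.misc("my_package.CustomVectors.v1")  # hypothetical name
def create_custom_vectors() -> Callable[[Vocab], BaseVectors]:
    def vectors_factory(vocab: Vocab) -> BaseVectors:
        return Vectors(strings=vocab.strings)

    return vectors_factory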

spacy/vocab.pyx

@@ -88,8 +88,9 @@ cdef class Vocab:
             return self._vectors
 
         def __set__(self, vectors):
-            for s in vectors.strings:
-                self.strings.add(s)
+            if hasattr(vectors, "strings"):
+                for s in vectors.strings:
+                    self.strings.add(s)
             self._vectors = vectors
             self._vectors.strings = self.strings
@@ -188,7 +189,7 @@ cdef class Vocab:
         lex = <LexemeC*>mem.alloc(1, sizeof(LexemeC))
         lex.orth = self.strings.add(string)
         lex.length = len(string)
-        if self.vectors is not None:
+        if self.vectors is not None and hasattr(self.vectors, "key2row"):
             lex.id = self.vectors.key2row.get(lex.orth, OOV_RANK)
         else:
             lex.id = OOV_RANK
@@ -284,12 +285,17 @@ cdef class Vocab:
     @property
     def vectors_length(self):
-        return self.vectors.shape[1]
+        if hasattr(self.vectors, "shape"):
+            return self.vectors.shape[1]
+        else:
+            return -1
 
     def reset_vectors(self, *, width=None, shape=None):
         """Drop the current vector table. Because all vectors must be the same
         width, you have to call this to change the size of the vectors.
         """
+        if not isinstance(self.vectors, Vectors):
+            raise ValueError(Errors.E849.format(action="reset", vectors_type=type(self.vectors)))
         if width is not None and shape is not None:
             raise ValueError(Errors.E065.format(width=width, shape=shape))
         elif shape is not None:
@@ -299,6 +305,8 @@ cdef class Vocab:
             self.vectors = Vectors(strings=self.strings, shape=(self.vectors.shape[0], width))
 
     def deduplicate_vectors(self):
+        if not isinstance(self.vectors, Vectors):
+            raise ValueError(Errors.E849.format(action="deduplicate", vectors_type=type(self.vectors)))
         if self.vectors.mode != VectorsMode.default:
             raise ValueError(Errors.E858.format(
                 mode=self.vectors.mode,
@@ -352,6 +360,8 @@ cdef class Vocab:
 
         DOCS: https://spacy.io/api/vocab#prune_vectors
         """
+        if not isinstance(self.vectors, Vectors):
+            raise ValueError(Errors.E849.format(action="prune", vectors_type=type(self.vectors)))
         if self.vectors.mode != VectorsMode.default:
             raise ValueError(Errors.E858.format(
                 mode=self.vectors.mode,
@@ -400,7 +410,10 @@ cdef class Vocab:
         orth = self.strings.add(orth)
         if self.has_vector(orth):
             return self.vectors[orth]
-        xp = get_array_module(self.vectors.data)
+        if isinstance(self.vectors, Vectors):
+            xp = get_array_module(self.vectors.data)
+        else:
+            xp = get_current_ops().xp
         vectors = xp.zeros((self.vectors_length,), dtype="f")
         return vectors
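Taken together, these guards make a non-Vectors implementation degrade predictably instead of crashing. A usage sketch (DummyVectors is illustrative only and assumes plain attribute assignment works on the Python subclass):

import spacy
from spacy.vectors import BaseVectors

class DummyVectors(BaseVectors):  # hypothetical minimal subclass
    pass

nlp = spacy.blank("en")
nlp.vocab.vectors = DummyVectors()

# No .shape attribute, so the vocab reports an unknown vector width.
assert nlp.vocab.vectors_length == -1

# Table-specific operations now raise E849 instead of failing obscurely.
try:
    nlp.vocab.prune_vectors(1000)
except ValueError as err:
    print(err)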