From d3b8b9162caee3777eeeb385684abb71ab0d5dcd Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Fri, 31 Mar 2023 18:25:55 +0200 Subject: [PATCH] Support registered vectors --- spacy/default_config.cfg | 3 +++ spacy/errors.py | 1 + spacy/language.py | 10 ++++++- spacy/ml/staticvectors.py | 11 +++++--- spacy/schemas.py | 1 + spacy/vectors.pyx | 55 +++++++++++++++++++++++++++++++++++++-- spacy/vocab.pyx | 23 ++++++++++++---- 7 files changed, 92 insertions(+), 12 deletions(-) diff --git a/spacy/default_config.cfg b/spacy/default_config.cfg index 694fb732f..812b89165 100644 --- a/spacy/default_config.cfg +++ b/spacy/default_config.cfg @@ -26,6 +26,9 @@ batch_size = 1000 [nlp.tokenizer] @tokenizers = "spacy.Tokenizer.v1" +[nlp.vectors] +@misc = "spacy.Vectors.v1" + # The pipeline components and their models [components] diff --git a/spacy/errors.py b/spacy/errors.py index 40cfa8d92..91d9925c7 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -549,6 +549,7 @@ class Errors(metaclass=ErrorsWithCodes): "during training, make sure to include it in 'annotating components'") # New errors added in v3.x + E849 = ("Unable to {action} vectors for vectors of type {vectors_type}.") E850 = ("The PretrainVectors objective currently only supports default or " "floret vectors, not {mode} vectors.") E851 = ("The 'textcat' component labels should only have values of 0 or 1, " diff --git a/spacy/language.py b/spacy/language.py index 9fdcf6328..936eb7367 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -20,6 +20,8 @@ import traceback from . import ty from .tokens.underscore import Underscore +from .strings import StringStore +from .vectors import BaseVectors from .vocab import Vocab, create_vocab from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis from .training import Example, validate_examples @@ -134,6 +136,7 @@ class Language: max_length: int = 10**6, meta: Dict[str, Any] = {}, create_tokenizer: Optional[Callable[["Language"], Callable[[str], Doc]]] = None, + create_vectors: Optional[Callable[["Vocab"], BaseVectors]] = None, batch_size: int = 1000, **kwargs, ) -> None: @@ -174,6 +177,10 @@ class Language: if vocab is True: vectors_name = meta.get("vectors", {}).get("name") vocab = create_vocab(self.lang, self.Defaults, vectors_name=vectors_name) + if not create_vectors: + vectors_cfg = {"vectors": self._config["nlp"]["vectors"]} + create_vectors = registry.resolve(vectors_cfg)["vectors"] + vocab.vectors = create_vectors(vocab) else: if (self.lang and vocab.lang) and (self.lang != vocab.lang): raise ValueError(Errors.E150.format(nlp=self.lang, vocab=vocab.lang)) @@ -1750,6 +1757,7 @@ class Language: filled["nlp"], validate=validate, schema=ConfigSchemaNlp ) create_tokenizer = resolved_nlp["tokenizer"] + create_vectors = resolved_nlp["vectors"] before_creation = resolved_nlp["before_creation"] after_creation = resolved_nlp["after_creation"] after_pipeline_creation = resolved_nlp["after_pipeline_creation"] @@ -1770,7 +1778,7 @@ class Language: # inside stuff like the spacy train function. If we loaded them here, # then we would load them twice at runtime: once when we make from config, # and then again when we load from disk. 
- nlp = lang_cls(vocab=vocab, create_tokenizer=create_tokenizer, meta=meta) + nlp = lang_cls(vocab=vocab, create_tokenizer=create_tokenizer, create_vectors=create_vectors, meta=meta) if after_creation is not None: nlp = after_creation(nlp) if not isinstance(nlp, cls): diff --git a/spacy/ml/staticvectors.py b/spacy/ml/staticvectors.py index 04cfe912d..004de2914 100644 --- a/spacy/ml/staticvectors.py +++ b/spacy/ml/staticvectors.py @@ -6,7 +6,7 @@ from thinc.api import Model, Ops, registry from ..tokens import Doc from ..errors import Errors -from ..vectors import Mode +from ..vectors import Vectors, Mode from ..vocab import Vocab @@ -43,11 +43,14 @@ def forward( keys = model.ops.flatten([cast(Ints1d, doc.to_array(key_attr)) for doc in docs]) vocab: Vocab = docs[0].vocab W = cast(Floats2d, model.ops.as_contig(model.get_param("W"))) - if vocab.vectors.mode == Mode.default: + if isinstance(vocab.vectors, Vectors) and vocab.vectors.mode == Mode.default: V = model.ops.asarray(vocab.vectors.data) rows = vocab.vectors.find(keys=keys) V = model.ops.as_contig(V[rows]) - elif vocab.vectors.mode == Mode.floret: + elif isinstance(vocab.vectors, Vectors) and vocab.vectors.mode == Mode.floret: + V = vocab.vectors.get_batch(keys) + V = model.ops.as_contig(V) + elif hasattr(vocab.vectors, "get_batch"): V = vocab.vectors.get_batch(keys) V = model.ops.as_contig(V) else: @@ -56,7 +59,7 @@ def forward( vectors_data = model.ops.gemm(V, W, trans2=True) except ValueError: raise RuntimeError(Errors.E896) - if vocab.vectors.mode == Mode.default: + if isinstance(vocab.vectors, Vectors) and vocab.vectors.mode == Mode.default: # Convert negative indices to 0-vectors # TODO: more options for UNK tokens vectors_data[rows < 0] = 0 diff --git a/spacy/schemas.py b/spacy/schemas.py index 140592dcd..d5353d1e0 100644 --- a/spacy/schemas.py +++ b/spacy/schemas.py @@ -375,6 +375,7 @@ class ConfigSchemaNlp(BaseModel): after_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after creation and before the pipeline is constructed") after_pipeline_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after the pipeline is constructed") batch_size: Optional[int] = Field(..., title="Default batch size") + vectors: Callable = Field(..., title="Vectors implementation") # fmt: on class Config: diff --git a/spacy/vectors.pyx b/spacy/vectors.pyx index be0f6db09..2f978e733 100644 --- a/spacy/vectors.pyx +++ b/spacy/vectors.pyx @@ -1,3 +1,5 @@ +# cython: infer_types=True, profile=True, binding=True +from typing import Callable cimport numpy as np from libc.stdint cimport uint32_t, uint64_t from cython.operator cimport dereference as deref @@ -6,7 +8,8 @@ from murmurhash.mrmr cimport hash128_x64 import functools import numpy -from typing import cast +from pathlib import Path +from typing import cast, TYPE_CHECKING, Union import warnings from enum import Enum import srsly @@ -21,6 +24,10 @@ from .errors import Errors, Warnings from . import util +if TYPE_CHECKING: + from .vocab import Vocab # noqa: F401 + + def unpickle_vectors(bytes_data): return Vectors().from_bytes(bytes_data) @@ -34,7 +41,51 @@ class Mode(str, Enum): return list(cls.__members__.keys()) -cdef class Vectors: +cdef class BaseVectors: + def __init__(self, *, strings=None): + # Make sure abstract BaseVectors is not instantiated. 
+        if self.__class__ == BaseVectors:
+            raise TypeError(
+                Errors.E1046.format(cls_name=self.__class__.__name__)
+            )
+
+    def __getitem__(self, key):
+        raise NotImplementedError
+
+    def get_batch(self, keys):
+        raise NotImplementedError
+
+    @property
+    def vectors_length(self):
+        raise NotImplementedError
+
+    def add(self, key, *, vector=None):
+        raise NotImplementedError
+
+    # Add dummy methods for to_bytes, from_bytes, to_disk and from_disk to
+    # allow serialization.
+    def to_bytes(self, **kwargs):
+        return b""
+
+    def from_bytes(self, data: bytes, **kwargs):
+        return self
+
+    def to_disk(self, path: Union[str, Path], **kwargs):
+        return None
+
+    def from_disk(self, path: Union[str, Path], **kwargs):
+        return self
+
+
+@util.registry.misc("spacy.Vectors.v1")
+def create_mode_vectors() -> Callable[["Vocab"], BaseVectors]:
+    def vectors_factory(vocab: "Vocab") -> BaseVectors:
+        return Vectors(strings=vocab.strings)
+
+    return vectors_factory
+
+
+cdef class Vectors(BaseVectors):
     """Store, save and load word vectors.
 
     Vectors data is kept in the vectors.data attribute, which should be an
diff --git a/spacy/vocab.pyx b/spacy/vocab.pyx
index 27f8e5f98..5e9fe7946 100644
--- a/spacy/vocab.pyx
+++ b/spacy/vocab.pyx
@@ -88,8 +88,9 @@ cdef class Vocab:
             return self._vectors
 
         def __set__(self, vectors):
-            for s in vectors.strings:
-                self.strings.add(s)
+            if hasattr(vectors, "strings"):
+                for s in vectors.strings:
+                    self.strings.add(s)
             self._vectors = vectors
             self._vectors.strings = self.strings
 
@@ -188,7 +189,7 @@ cdef class Vocab:
         lex = <LexemeC*>mem.alloc(1, sizeof(LexemeC))
         lex.orth = self.strings.add(string)
         lex.length = len(string)
-        if self.vectors is not None:
+        if self.vectors is not None and hasattr(self.vectors, "key2row"):
             lex.id = self.vectors.key2row.get(lex.orth, OOV_RANK)
         else:
             lex.id = OOV_RANK
@@ -284,12 +285,17 @@ cdef class Vocab:
 
     @property
     def vectors_length(self):
-        return self.vectors.shape[1]
+        if hasattr(self.vectors, "shape"):
+            return self.vectors.shape[1]
+        else:
+            return -1
 
     def reset_vectors(self, *, width=None, shape=None):
         """Drop the current vector table. Because all vectors must be the
         same width, you have to call this to change the size of the vectors.
         """
+        if not isinstance(self.vectors, Vectors):
+            raise ValueError(Errors.E849.format(action="reset", vectors_type=type(self.vectors)))
         if width is not None and shape is not None:
             raise ValueError(Errors.E065.format(width=width, shape=shape))
         elif shape is not None:
@@ -299,6 +305,8 @@ cdef class Vocab:
             self.vectors = Vectors(strings=self.strings, shape=(self.vectors.shape[0], width))
 
     def deduplicate_vectors(self):
+        if not isinstance(self.vectors, Vectors):
+            raise ValueError(Errors.E849.format(action="deduplicate", vectors_type=type(self.vectors)))
         if self.vectors.mode != VectorsMode.default:
             raise ValueError(Errors.E858.format(
                 mode=self.vectors.mode,
@@ -352,6 +360,8 @@ cdef class Vocab:
 
         DOCS: https://spacy.io/api/vocab#prune_vectors
         """
+        if not isinstance(self.vectors, Vectors):
+            raise ValueError(Errors.E849.format(action="prune", vectors_type=type(self.vectors)))
         if self.vectors.mode != VectorsMode.default:
             raise ValueError(Errors.E858.format(
                 mode=self.vectors.mode,
@@ -400,7 +410,10 @@ cdef class Vocab:
         orth = self.strings.add(orth)
         if self.has_vector(orth):
             return self.vectors[orth]
-        xp = get_array_module(self.vectors.data)
+        if isinstance(self.vectors, Vectors):
+            xp = get_array_module(self.vectors.data)
+        else:
+            xp = get_current_ops().xp
         vectors = xp.zeros((self.vectors_length,), dtype="f")
         return vectors
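
Usage note: with the pieces above, a third-party vectors implementation only
needs to subclass BaseVectors, override the lookup methods, and register a
factory that Language.__init__ resolves through the [nlp.vectors] block. A
minimal sketch follows; the HashVectors class, the "HashVectors.v1" registry
name, and the hash-modulo lookup are illustrative assumptions, not part of
this patch. Only the BaseVectors methods and the config wiring come from the
diff above.

    # A minimal sketch against the BaseVectors API added in this patch.
    # HashVectors, "HashVectors.v1" and the hash-modulo lookup are hypothetical.
    from typing import Callable

    import numpy

    from spacy.util import registry
    from spacy.vectors import BaseVectors
    from spacy.vocab import Vocab


    class HashVectors(BaseVectors):
        """Toy vectors that map each key hash onto a row of a fixed random table."""

        def __init__(self, *, strings=None, width=300, n_rows=10000):
            self.strings = strings
            rng = numpy.random.default_rng(0)
            self.data = rng.standard_normal((n_rows, width)).astype("f")

        @property
        def shape(self):
            # Vocab.vectors_length falls back to -1 when .shape is missing.
            return self.data.shape

        @property
        def vectors_length(self):
            return self.data.shape[1]

        def __getitem__(self, key):
            # Keys arriving from the vocab are uint64 ORTH hashes (lex.orth).
            return self.data[int(key) % self.data.shape[0]]

        def get_batch(self, keys):
            # StaticVectors.forward passes an array of ORTH hashes here.
            rows = numpy.asarray(keys, dtype="uint64") % self.data.shape[0]
            return self.data[rows]


    @registry.misc("HashVectors.v1")
    def create_hash_vectors(width: int = 300) -> Callable[[Vocab], BaseVectors]:
        def vectors_factory(vocab: Vocab) -> BaseVectors:
            return HashVectors(strings=vocab.strings, width=width)

        return vectors_factory

In the config, the default entry is then swapped for the custom factory, and
any extra keys in the block are passed through to the factory when
Language.__init__ calls registry.resolve:

    [nlp.vectors]
    @misc = "HashVectors.v1"
    width = 300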