mirror of https://github.com/explosion/spaCy.git
synced 2025-06-05 13:43:24 +03:00

Support registered vectors

This commit is contained in:
parent 476a2e7a0a
commit d3b8b9162c
spacy/default_config.cfg
@@ -26,6 +26,9 @@ batch_size = 1000
 [nlp.tokenizer]
 @tokenizers = "spacy.Tokenizer.v1"
 
+[nlp.vectors]
+@misc = "spacy.Vectors.v1"
+
 # The pipeline components and their models
 [components]
 
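The new [nlp.vectors] block is resolved through spaCy's function registry, so a third-party implementation can be swapped in by name. A minimal sketch, not part of this commit: the registry name "my.CustomVectors.v1" and the MyVectors class are hypothetical (a full MyVectors sketch appears after the vectors.pyx hunk below).

    from spacy.util import registry

    @registry.misc("my.CustomVectors.v1")
    def create_custom_vectors():
        # Return a factory that builds the vectors object from the shared vocab.
        def vectors_factory(vocab):
            return MyVectors(strings=vocab.strings)  # MyVectors: hypothetical BaseVectors subclass
        return vectors_factory

    # referenced from the config as:
    #   [nlp.vectors]
    #   @misc = "my.CustomVectors.v1"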
spacy/errors.py
@@ -549,6 +549,7 @@ class Errors(metaclass=ErrorsWithCodes):
             "during training, make sure to include it in 'annotating components'")
 
     # New errors added in v3.x
+    E849 = ("Unable to {action} vectors for vectors of type {vectors_type}.")
     E850 = ("The PretrainVectors objective currently only supports default or "
             "floret vectors, not {mode} vectors.")
     E851 = ("The 'textcat' component labels should only have values of 0 or 1, "
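For reference, E849 takes two keyword arguments when formatted; a quick illustration (the type name is hypothetical):

    msg = ("Unable to {action} vectors for vectors of type {vectors_type}."
           .format(action="reset", vectors_type="MyVectors"))
    # -> 'Unable to reset vectors for vectors of type MyVectors.'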
spacy/language.py
@@ -20,6 +20,8 @@ import traceback
 
 from . import ty
 from .tokens.underscore import Underscore
+from .strings import StringStore
+from .vectors import BaseVectors
 from .vocab import Vocab, create_vocab
 from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis
 from .training import Example, validate_examples
@@ -134,6 +136,7 @@ class Language:
         max_length: int = 10**6,
         meta: Dict[str, Any] = {},
         create_tokenizer: Optional[Callable[["Language"], Callable[[str], Doc]]] = None,
+        create_vectors: Optional[Callable[["Vocab"], BaseVectors]] = None,
         batch_size: int = 1000,
         **kwargs,
     ) -> None:
@@ -174,6 +177,10 @@ class Language:
         if vocab is True:
             vectors_name = meta.get("vectors", {}).get("name")
             vocab = create_vocab(self.lang, self.Defaults, vectors_name=vectors_name)
+            if not create_vectors:
+                vectors_cfg = {"vectors": self._config["nlp"]["vectors"]}
+                create_vectors = registry.resolve(vectors_cfg)["vectors"]
+            vocab.vectors = create_vectors(vocab)
         else:
             if (self.lang and vocab.lang) and (self.lang != vocab.lang):
                 raise ValueError(Errors.E150.format(nlp=self.lang, vocab=vocab.lang))
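The registry.resolve call above turns the @misc reference from the config into the registered factory. Sketched in isolation, assuming only the built-in "spacy.Vectors.v1" entry:

    from spacy.util import registry

    vectors_cfg = {"vectors": {"@misc": "spacy.Vectors.v1"}}
    create_vectors = registry.resolve(vectors_cfg)["vectors"]
    # create_vectors is now a callable: (Vocab) -> BaseVectors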
@@ -1750,6 +1757,7 @@ class Language:
             filled["nlp"], validate=validate, schema=ConfigSchemaNlp
         )
         create_tokenizer = resolved_nlp["tokenizer"]
+        create_vectors = resolved_nlp["vectors"]
         before_creation = resolved_nlp["before_creation"]
         after_creation = resolved_nlp["after_creation"]
         after_pipeline_creation = resolved_nlp["after_pipeline_creation"]
@@ -1770,7 +1778,7 @@ class Language:
         # inside stuff like the spacy train function. If we loaded them here,
         # then we would load them twice at runtime: once when we make from config,
         # and then again when we load from disk.
-        nlp = lang_cls(vocab=vocab, create_tokenizer=create_tokenizer, meta=meta)
+        nlp = lang_cls(vocab=vocab, create_tokenizer=create_tokenizer, create_vectors=create_vectors, meta=meta)
         if after_creation is not None:
             nlp = after_creation(nlp)
         if not isinstance(nlp, cls):
spacy/ml/staticvectors.py
@@ -6,7 +6,7 @@ from thinc.api import Model, Ops, registry
 
 from ..tokens import Doc
 from ..errors import Errors
-from ..vectors import Mode
+from ..vectors import Vectors, Mode
 from ..vocab import Vocab
 
 
@@ -43,11 +43,14 @@ def forward(
     keys = model.ops.flatten([cast(Ints1d, doc.to_array(key_attr)) for doc in docs])
     vocab: Vocab = docs[0].vocab
     W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
-    if vocab.vectors.mode == Mode.default:
+    if isinstance(vocab.vectors, Vectors) and vocab.vectors.mode == Mode.default:
         V = model.ops.asarray(vocab.vectors.data)
         rows = vocab.vectors.find(keys=keys)
         V = model.ops.as_contig(V[rows])
-    elif vocab.vectors.mode == Mode.floret:
+    elif isinstance(vocab.vectors, Vectors) and vocab.vectors.mode == Mode.floret:
+        V = vocab.vectors.get_batch(keys)
+        V = model.ops.as_contig(V)
+    elif hasattr(vocab.vectors, "get_batch"):
         V = vocab.vectors.get_batch(keys)
         V = model.ops.as_contig(V)
     else:
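The new hasattr(vocab.vectors, "get_batch") branch means StaticVectors only needs a batch lookup from a non-Vectors implementation. A hedged sketch of such a provider; the class and its hashing scheme are invented for illustration:

    import numpy

    class HashedVectors:
        """Toy provider: one deterministic pseudo-random row per key."""
        def __init__(self, width=32):
            self.width = width

        def get_batch(self, keys):
            rows = [
                numpy.random.default_rng(int(key) % (2**32)).standard_normal(self.width)
                for key in keys
            ]
            return numpy.asarray(rows, dtype="float32")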
@@ -56,7 +59,7 @@ def forward(
         vectors_data = model.ops.gemm(V, W, trans2=True)
     except ValueError:
         raise RuntimeError(Errors.E896)
-    if vocab.vectors.mode == Mode.default:
+    if isinstance(vocab.vectors, Vectors) and vocab.vectors.mode == Mode.default:
         # Convert negative indices to 0-vectors
         # TODO: more options for UNK tokens
         vectors_data[rows < 0] = 0
spacy/schemas.py
@@ -375,6 +375,7 @@ class ConfigSchemaNlp(BaseModel):
    after_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after creation and before the pipeline is constructed")
    after_pipeline_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after the pipeline is constructed")
    batch_size: Optional[int] = Field(..., title="Default batch size")
+    vectors: Callable = Field(..., title="Vectors implementation")
    # fmt: on
 
    class Config:
spacy/vectors.pyx
@@ -1,3 +1,5 @@
+# cython: infer_types=True, profile=True, binding=True
+from typing import Callable
 cimport numpy as np
 from libc.stdint cimport uint32_t, uint64_t
 from cython.operator cimport dereference as deref
@@ -6,7 +8,8 @@ from murmurhash.mrmr cimport hash128_x64
 
 import functools
 import numpy
-from typing import cast
+from pathlib import Path
+from typing import cast, TYPE_CHECKING, Union
 import warnings
 from enum import Enum
 import srsly
@@ -21,6 +24,10 @@ from .errors import Errors, Warnings
 from . import util
 
 
+if TYPE_CHECKING:
+    from .vocab import Vocab  # noqa: F401
+
+
 def unpickle_vectors(bytes_data):
     return Vectors().from_bytes(bytes_data)
 
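The TYPE_CHECKING guard is presumably needed because vocab.pyx itself imports from vectors.pyx, so importing Vocab unconditionally here would be circular. The general shape of the pattern:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from .vocab import Vocab  # seen by type checkers only, never at runtime

    def needs_vocab(vocab: "Vocab") -> None:  # string annotation avoids the import
        ...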
@@ -34,7 +41,51 @@ class Mode(str, Enum):
         return list(cls.__members__.keys())
 
 
-cdef class Vectors:
+cdef class BaseVectors:
+    def __init__(self, *, strings=None):
+        # Make sure abstract BaseVectors is not instantiated.
+        if self.__class__ == BaseVectors:
+            raise TypeError(
+                Errors.E1046.format(cls_name=self.__class__.__name__)
+            )
+
+    def __getitem__(self, key):
+        raise NotImplementedError
+
+    def get_batch(self, keys):
+        raise NotImplementedError
+
+    @property
+    def vectors_length(self):
+        raise NotImplementedError
+
+    def add(self, key, *, vector=None):
+        raise NotImplementedError
+
+    # add dummy methods for to_bytes, from_bytes, to_disk and from_disk to
+    # allow serialization
+    def to_bytes(self, **kwargs):
+        return b""
+
+    def from_bytes(self, data: bytes, **kwargs):
+        return self
+
+    def to_disk(self, path: Union[str, Path], **kwargs):
+        return None
+
+    def from_disk(self, path: Union[str, Path], **kwargs):
+        return self
+
+
+@util.registry.misc("spacy.Vectors.v1")
+def create_mode_vectors() -> Callable[["Vocab"], BaseVectors]:
+    def vectors_factory(vocab: "Vocab") -> BaseVectors:
+        return Vectors(strings=vocab.strings)
+
+    return vectors_factory
+
+
+cdef class Vectors(BaseVectors):
     """Store, save and load word vectors.
 
     Vectors data is kept in the vectors.data attribute, which should be an
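Taken together, this is the abstract surface a registered implementation fills in. A minimal in-memory subclass as a sketch; the class name and storage scheme are illustrative, not from the commit:

    import numpy
    from spacy.vectors import BaseVectors

    class MyVectors(BaseVectors):
        def __init__(self, *, strings=None, data=None):
            super().__init__(strings=strings)
            self.strings = strings
            self.data = {} if data is None else data  # key (hash) -> 1D float32 array

        def __getitem__(self, key):
            return self.data[key]

        def get_batch(self, keys):
            return numpy.asarray([self.data[key] for key in keys], dtype="float32")

        @property
        def vectors_length(self):
            return next(iter(self.data.values())).shape[0] if self.data else 0

        def add(self, key, *, vector=None):
            self.data[key] = vector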
spacy/vocab.pyx
@@ -88,8 +88,9 @@ cdef class Vocab:
             return self._vectors
 
         def __set__(self, vectors):
-            for s in vectors.strings:
-                self.strings.add(s)
+            if hasattr(vectors, "strings"):
+                for s in vectors.strings:
+                    self.strings.add(s)
             self._vectors = vectors
             self._vectors.strings = self.strings
 
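Note that the unguarded context line self._vectors.strings = self.strings still runs for every implementation, so a custom vectors object must at least accept assignment to a strings attribute. Usage sketch, reusing the hypothetical MyVectors from above:

    import spacy

    nlp = spacy.blank("en")
    nlp.vocab.vectors = MyVectors(strings=nlp.vocab.strings)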
@@ -188,7 +189,7 @@ cdef class Vocab:
         lex = <LexemeC*>mem.alloc(1, sizeof(LexemeC))
         lex.orth = self.strings.add(string)
         lex.length = len(string)
-        if self.vectors is not None:
+        if self.vectors is not None and hasattr(self.vectors, "key2row"):
             lex.id = self.vectors.key2row.get(lex.orth, OOV_RANK)
         else:
             lex.id = OOV_RANK
@@ -284,12 +285,17 @@ cdef class Vocab:
 
     @property
     def vectors_length(self):
-        return self.vectors.shape[1]
+        if hasattr(self.vectors, "shape"):
+            return self.vectors.shape[1]
+        else:
+            return -1
 
     def reset_vectors(self, *, width=None, shape=None):
         """Drop the current vector table. Because all vectors must be the same
         width, you have to call this to change the size of the vectors.
         """
+        if not isinstance(self.vectors, Vectors):
+            raise ValueError(Errors.E849.format(action="reset", vectors_type=type(self.vectors)))
         if width is not None and shape is not None:
             raise ValueError(Errors.E065.format(width=width, shape=shape))
         elif shape is not None:
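With the new guard, table-mutating helpers fail fast on non-default implementations instead of crashing deeper on a missing attribute. For example, continuing the hypothetical setup above:

    nlp.vocab.vectors = MyVectors(strings=nlp.vocab.strings)
    nlp.vocab.reset_vectors(width=300)
    # ValueError: [E849] Unable to reset vectors for vectors of type <class 'MyVectors'>.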
@@ -299,6 +305,8 @@ cdef class Vocab:
         self.vectors = Vectors(strings=self.strings, shape=(self.vectors.shape[0], width))
 
     def deduplicate_vectors(self):
+        if not isinstance(self.vectors, Vectors):
+            raise ValueError(Errors.E849.format(action="deduplicate", vectors_type=type(self.vectors)))
         if self.vectors.mode != VectorsMode.default:
             raise ValueError(Errors.E858.format(
                 mode=self.vectors.mode,
@@ -352,6 +360,8 @@ cdef class Vocab:
 
         DOCS: https://spacy.io/api/vocab#prune_vectors
         """
+        if not isinstance(self.vectors, Vectors):
+            raise ValueError(Errors.E849.format(action="prune", vectors_type=type(self.vectors)))
         if self.vectors.mode != VectorsMode.default:
             raise ValueError(Errors.E858.format(
                 mode=self.vectors.mode,
@@ -400,7 +410,10 @@ cdef class Vocab:
         orth = self.strings.add(orth)
         if self.has_vector(orth):
             return self.vectors[orth]
-        xp = get_array_module(self.vectors.data)
+        if isinstance(self.vectors, Vectors):
+            xp = get_array_module(self.vectors.data)
+        else:
+            xp = get_current_ops().xp
         vectors = xp.zeros((self.vectors_length,), dtype="f")
         return vectors
 
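A sketch of the fallback allocation path above: when the vectors implementation exposes no .data array, the zero vector is allocated via the active Thinc backend instead. Standalone illustration; 300 is an arbitrary width.

    from thinc.api import get_current_ops

    ops = get_current_ops()
    xp = ops.xp  # numpy on CPU, cupy on GPU
    zero_vector = xp.zeros((300,), dtype="f")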