Support registered vectors (#12492)

* Support registered vectors

* Format

* Auto-fill [nlp] on load from config and from bytes/disk

* Only auto-fill [nlp]

* Undo all changes to Language.from_disk

* Expand BaseVectors

These methods are needed in various places for training and vector
similarity.

* isort

* More linting

* Only fill [nlp.vectors]

* Update spacy/vocab.pyx

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Revert changes to test related to auto-filling [nlp]

* Add vectors registry

* Rephrase error about vocab methods for vectors

* Switch to dummy implementation for BaseVectors.to_ops

* Add initial draft of docs

* Remove example from BaseVectors docs

* Apply suggestions from code review

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Update website/docs/api/basevectors.mdx

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Fix type and lint bpemb example

* Update website/docs/api/basevectors.mdx

---------

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
Adriane Boyd 2023-08-01 15:46:08 +02:00 committed by GitHub
parent 9ffa5d8a15
commit 0fe43f40f1
12 changed files with 425 additions and 16 deletions

View File: spacy/default_config.cfg

@@ -26,6 +26,9 @@ batch_size = 1000
 [nlp.tokenizer]
 @tokenizers = "spacy.Tokenizer.v1"
+
+[nlp.vectors]
+@vectors = "spacy.Vectors.v1"
 
 # The pipeline components and their models
 [components]

View File: spacy/errors.py

@@ -553,6 +553,8 @@ class Errors(metaclass=ErrorsWithCodes):
             "during training, make sure to include it in 'annotating components'")
 
     # New errors added in v3.x
+    E849 = ("The vocab only supports {method} for vectors of type "
+            "spacy.vectors.Vectors, not {vectors_type}.")
     E850 = ("The PretrainVectors objective currently only supports default or "
             "floret vectors, not {mode} vectors.")
     E851 = ("The 'textcat' component labels should only have values of 0 or 1, "

View File: spacy/language.py

@@ -65,6 +65,7 @@ from .util import (
     registry,
     warn_if_jupyter_cupy,
 )
+from .vectors import BaseVectors
 from .vocab import Vocab, create_vocab
 
 PipeCallable = Callable[[Doc], Doc]
@@ -158,6 +159,7 @@ class Language:
         max_length: int = 10**6,
         meta: Dict[str, Any] = {},
         create_tokenizer: Optional[Callable[["Language"], Callable[[str], Doc]]] = None,
+        create_vectors: Optional[Callable[["Vocab"], BaseVectors]] = None,
         batch_size: int = 1000,
         **kwargs,
     ) -> None:
@@ -198,6 +200,10 @@ class Language:
         if vocab is True:
             vectors_name = meta.get("vectors", {}).get("name")
             vocab = create_vocab(self.lang, self.Defaults, vectors_name=vectors_name)
+            if not create_vectors:
+                vectors_cfg = {"vectors": self._config["nlp"]["vectors"]}
+                create_vectors = registry.resolve(vectors_cfg)["vectors"]
+            vocab.vectors = create_vectors(vocab)
         else:
             if (self.lang and vocab.lang) and (self.lang != vocab.lang):
                 raise ValueError(Errors.E150.format(nlp=self.lang, vocab=vocab.lang))
@@ -1765,6 +1771,10 @@ class Language:
         ).merge(config)
         if "nlp" not in config:
             raise ValueError(Errors.E985.format(config=config))
+        # fill in [nlp.vectors] if not present (as a narrower alternative to
+        # auto-filling [nlp] from the default config)
+        if "vectors" not in config["nlp"]:
+            config["nlp"]["vectors"] = {"@vectors": "spacy.Vectors.v1"}
         config_lang = config["nlp"].get("lang")
         if config_lang is not None and config_lang != cls.lang:
             raise ValueError(
@@ -1796,6 +1806,7 @@ class Language:
             filled["nlp"], validate=validate, schema=ConfigSchemaNlp
         )
         create_tokenizer = resolved_nlp["tokenizer"]
+        create_vectors = resolved_nlp["vectors"]
         before_creation = resolved_nlp["before_creation"]
         after_creation = resolved_nlp["after_creation"]
         after_pipeline_creation = resolved_nlp["after_pipeline_creation"]
@@ -1816,7 +1827,12 @@ class Language:
         # inside stuff like the spacy train function. If we loaded them here,
         # then we would load them twice at runtime: once when we make from config,
         # and then again when we load from disk.
-        nlp = lang_cls(vocab=vocab, create_tokenizer=create_tokenizer, meta=meta)
+        nlp = lang_cls(
+            vocab=vocab,
+            create_tokenizer=create_tokenizer,
+            create_vectors=create_vectors,
+            meta=meta,
+        )
         if after_creation is not None:
             nlp = after_creation(nlp)
         if not isinstance(nlp, cls):

View File: spacy/ml/staticvectors.py

@@ -9,7 +9,7 @@ from thinc.util import partial
 from ..attrs import ORTH
 from ..errors import Errors, Warnings
 from ..tokens import Doc
-from ..vectors import Mode
+from ..vectors import Mode, Vectors
 from ..vocab import Vocab
@@ -48,11 +48,14 @@ def forward(
     key_attr: int = getattr(vocab.vectors, "attr", ORTH)
     keys = model.ops.flatten([cast(Ints1d, doc.to_array(key_attr)) for doc in docs])
     W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
-    if vocab.vectors.mode == Mode.default:
+    if isinstance(vocab.vectors, Vectors) and vocab.vectors.mode == Mode.default:
         V = model.ops.asarray(vocab.vectors.data)
         rows = vocab.vectors.find(keys=keys)
         V = model.ops.as_contig(V[rows])
-    elif vocab.vectors.mode == Mode.floret:
+    elif isinstance(vocab.vectors, Vectors) and vocab.vectors.mode == Mode.floret:
+        V = vocab.vectors.get_batch(keys)
+        V = model.ops.as_contig(V)
+    elif hasattr(vocab.vectors, "get_batch"):
         V = vocab.vectors.get_batch(keys)
         V = model.ops.as_contig(V)
     else:
@@ -61,7 +64,7 @@ def forward(
         vectors_data = model.ops.gemm(V, W, trans2=True)
     except ValueError:
         raise RuntimeError(Errors.E896)
-    if vocab.vectors.mode == Mode.default:
+    if isinstance(vocab.vectors, Vectors) and vocab.vectors.mode == Mode.default:
         # Convert negative indices to 0-vectors
         # TODO: more options for UNK tokens
         vectors_data[rows < 0] = 0

View File: spacy/schemas.py

@@ -397,6 +397,7 @@ class ConfigSchemaNlp(BaseModel):
     after_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after creation and before the pipeline is constructed")
     after_pipeline_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after the pipeline is constructed")
     batch_size: Optional[int] = Field(..., title="Default batch size")
+    vectors: Callable = Field(..., title="Vectors implementation")
     # fmt: on
 
     class Config:

View File: spacy/util.py

@@ -118,6 +118,7 @@ class registry(thinc.registry):
     augmenters = catalogue.create("spacy", "augmenters", entry_points=True)
     loggers = catalogue.create("spacy", "loggers", entry_points=True)
     scorers = catalogue.create("spacy", "scorers", entry_points=True)
+    vectors = catalogue.create("spacy", "vectors", entry_points=True)
     # These are factories registered via third-party packages and the
     # spacy_factories entry point. This registry only exists so we can easily
     # load them via the entry points. The "true" factories are added via the

View File: spacy/vectors.pyx

@@ -1,3 +1,6 @@
 # cython: infer_types=True, profile=True, binding=True
+from typing import Callable
+
 from cython.operator cimport dereference as deref
 from libc.stdint cimport uint32_t, uint64_t
 from libcpp.set cimport set as cppset
@@ -5,7 +8,8 @@ from murmurhash.mrmr cimport hash128_x64
 
 import warnings
 from enum import Enum
-from typing import cast
+from pathlib import Path
+from typing import TYPE_CHECKING, Union, cast
 
 import numpy
 import srsly
@@ -21,6 +25,9 @@ from .attrs import IDS
 from .errors import Errors, Warnings
 from .strings import get_string_id
 
+if TYPE_CHECKING:
+    from .vocab import Vocab  # noqa: F401 # no-cython-lint
+
 
 def unpickle_vectors(bytes_data):
     return Vectors().from_bytes(bytes_data)
@@ -35,7 +42,71 @@ class Mode(str, Enum):
         return list(cls.__members__.keys())
 
 
-cdef class Vectors:
+cdef class BaseVectors:
+    def __init__(self, *, strings=None):
+        # Make sure abstract BaseVectors is not instantiated.
+        if self.__class__ == BaseVectors:
+            raise TypeError(
+                Errors.E1046.format(cls_name=self.__class__.__name__)
+            )
+
+    def __getitem__(self, key):
+        raise NotImplementedError
+
+    def __contains__(self, key):
+        raise NotImplementedError
+
+    def is_full(self):
+        raise NotImplementedError
+
+    def get_batch(self, keys):
+        raise NotImplementedError
+
+    @property
+    def shape(self):
+        raise NotImplementedError
+
+    def __len__(self):
+        raise NotImplementedError
+
+    @property
+    def vectors_length(self):
+        raise NotImplementedError
+
+    @property
+    def size(self):
+        raise NotImplementedError
+
+    def add(self, key, *, vector=None):
+        raise NotImplementedError
+
+    def to_ops(self, ops: Ops):
+        pass
+
+    # add dummy methods for to_bytes, from_bytes, to_disk and from_disk to
+    # allow serialization
+    def to_bytes(self, **kwargs):
+        return b""
+
+    def from_bytes(self, data: bytes, **kwargs):
+        return self
+
+    def to_disk(self, path: Union[str, Path], **kwargs):
+        return None
+
+    def from_disk(self, path: Union[str, Path], **kwargs):
+        return self
+
+
+@util.registry.vectors("spacy.Vectors.v1")
+def create_mode_vectors() -> Callable[["Vocab"], BaseVectors]:
+    def vectors_factory(vocab: "Vocab") -> BaseVectors:
+        return Vectors(strings=vocab.strings)
+
+    return vectors_factory
+
+
+cdef class Vectors(BaseVectors):
     """Store, save and load word vectors.
 
     Vectors data is kept in the vectors.data attribute, which should be an

View File: spacy/vocab.pyx

@@ -94,8 +94,9 @@ cdef class Vocab:
             return self._vectors
 
         def __set__(self, vectors):
-            for s in vectors.strings:
-                self.strings.add(s)
+            if hasattr(vectors, "strings"):
+                for s in vectors.strings:
+                    self.strings.add(s)
             self._vectors = vectors
             self._vectors.strings = self.strings
@@ -193,7 +194,7 @@ cdef class Vocab:
         lex = <LexemeC*>mem.alloc(1, sizeof(LexemeC))
         lex.orth = self.strings.add(string)
         lex.length = len(string)
-        if self.vectors is not None:
+        if self.vectors is not None and hasattr(self.vectors, "key2row"):
             lex.id = self.vectors.key2row.get(lex.orth, OOV_RANK)
         else:
             lex.id = OOV_RANK
@@ -289,12 +290,17 @@ cdef class Vocab:
 
     @property
     def vectors_length(self):
-        return self.vectors.shape[1]
+        if hasattr(self.vectors, "shape"):
+            return self.vectors.shape[1]
+        else:
+            return -1
 
     def reset_vectors(self, *, width=None, shape=None):
         """Drop the current vector table. Because all vectors must be the same
         width, you have to call this to change the size of the vectors.
         """
+        if not isinstance(self.vectors, Vectors):
+            raise ValueError(Errors.E849.format(method="reset_vectors", vectors_type=type(self.vectors)))
         if width is not None and shape is not None:
             raise ValueError(Errors.E065.format(width=width, shape=shape))
         elif shape is not None:
@@ -304,6 +310,8 @@ cdef class Vocab:
             self.vectors = Vectors(strings=self.strings, shape=(self.vectors.shape[0], width))
 
     def deduplicate_vectors(self):
+        if not isinstance(self.vectors, Vectors):
+            raise ValueError(Errors.E849.format(method="deduplicate_vectors", vectors_type=type(self.vectors)))
         if self.vectors.mode != VectorsMode.default:
             raise ValueError(Errors.E858.format(
                 mode=self.vectors.mode,
@@ -357,6 +365,8 @@ cdef class Vocab:
 
         DOCS: https://spacy.io/api/vocab#prune_vectors
         """
+        if not isinstance(self.vectors, Vectors):
+            raise ValueError(Errors.E849.format(method="prune_vectors", vectors_type=type(self.vectors)))
         if self.vectors.mode != VectorsMode.default:
             raise ValueError(Errors.E858.format(
                 mode=self.vectors.mode,

View File: website/docs/api/basevectors.mdx

@@ -0,0 +1,143 @@
---
title: BaseVectors
teaser: Abstract class for word vectors
tag: class
source: spacy/vectors.pyx
version: 3.7
---

`BaseVectors` is an abstract class to support the development of custom vectors
implementations.

For use in training with [`StaticVectors`](/api/architectures#staticvectors),
`get_batch` must be implemented. For improved performance, use efficient
batching in `get_batch` and implement `to_ops` to copy the vector data to the
current device. See an example custom implementation for
[BPEmb subword embeddings](/usage/embeddings-transformers#custom-vectors).
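
As a rough orientation, a subclass typically overrides the lookup and metadata
methods documented below and registers a factory that creates it from a
`Vocab`. The skeleton below is an illustrative sketch only (the `DummyVectors`
name and its registered string are hypothetical, not part of spaCy); it returns
a zero vector for every key:

```python
# Illustrative sketch only: a do-nothing vectors class with hypothetical names.
# A real implementation would look vectors up in an external table or model.
from typing import Callable

import numpy
from spacy.util import registry
from spacy.vectors import BaseVectors
from spacy.vocab import Vocab


class DummyVectors(BaseVectors):
    def __init__(self, *, strings=None, dim: int = 300):
        self.strings = strings
        self.dim = dim

    def __contains__(self, key):
        return True

    def __getitem__(self, key):
        # every key maps to the same zero vector
        return numpy.zeros((self.dim,), dtype="f")

    def get_batch(self, keys):
        # batched lookup used by StaticVectors during training
        return numpy.zeros((len(keys), self.dim), dtype="f")

    @property
    def shape(self):
        return (0, self.dim)

    def __len__(self):
        return 0

    @property
    def vectors_length(self):
        return self.dim

    @property
    def size(self):
        return 0

    def is_full(self):
        return True

    def add(self, key, *, vector=None):
        return -1  # keys cannot be added


@registry.vectors("DummyVectors.v1")
def create_dummy_vectors(dim: int = 300) -> Callable[[Vocab], DummyVectors]:
    def dummy_vectors_factory(vocab: Vocab) -> DummyVectors:
        return DummyVectors(strings=vocab.strings, dim=dim)

    return dummy_vectors_factory
```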
## BaseVectors.\_\_init\_\_ {id="init",tag="method"}

Create a new vector store.

| Name           | Description |
| -------------- | --------------------------------------------------------------------------------------------------------------------- |
| _keyword-only_ | |
| `strings`      | The string store. A new string store is created if one is not provided. Defaults to `None`. ~~Optional[StringStore]~~ |

## BaseVectors.\_\_getitem\_\_ {id="getitem",tag="method"}

Get a vector by key. If the key is not found in the table, a `KeyError` should
be raised.

| Name        | Description |
| ----------- | ---------------------------------------------------------------- |
| `key`       | The key to get the vector for. ~~Union[int, str]~~ |
| **RETURNS** | The vector for the key. ~~numpy.ndarray[ndim=1, dtype=float32]~~ |

## BaseVectors.\_\_len\_\_ {id="len",tag="method"}

Return the number of vectors in the table.

| Name        | Description |
| ----------- | ------------------------------------------- |
| **RETURNS** | The number of vectors in the table. ~~int~~ |

## BaseVectors.\_\_contains\_\_ {id="contains",tag="method"}

Check whether there is a vector entry for the given key.

| Name        | Description |
| ----------- | -------------------------------------------- |
| `key`       | The key to check. ~~int~~ |
| **RETURNS** | Whether the key has a vector entry. ~~bool~~ |

## BaseVectors.add {id="add",tag="method"}

Add a key to the table, if possible. If no keys can be added, return `-1`.

| Name        | Description |
| ----------- | ----------------------------------------------------------------------------------- |
| `key`       | The key to add. ~~Union[str, int]~~ |
| **RETURNS** | The row the vector was added to, or `-1` if the operation is not supported. ~~int~~ |

## BaseVectors.shape {id="shape",tag="property"}

Get the `(rows, dims)` tuple of number of rows and number of dimensions in the
vector table.

| Name        | Description |
| ----------- | ------------------------------------------ |
| **RETURNS** | A `(rows, dims)` pair. ~~Tuple[int, int]~~ |

## BaseVectors.size {id="size",tag="property"}

The vector size, i.e. `rows * dims`.

| Name        | Description |
| ----------- | ------------------------ |
| **RETURNS** | The vector size. ~~int~~ |

## BaseVectors.is_full {id="is_full",tag="property"}

Whether the vectors table is full and no slots are available for new keys.

| Name        | Description |
| ----------- | ------------------------------------------- |
| **RETURNS** | Whether the vectors table is full. ~~bool~~ |

## BaseVectors.get_batch {id="get_batch",tag="method",version="3.2"}

Get the vectors for the provided keys efficiently as a batch. Required to use
the vectors with [`StaticVectors`](/api/architectures#StaticVectors) for
training.

| Name   | Description |
| ------ | --------------------------------------- |
| `keys` | The keys. ~~Iterable[Union[int, str]]~~ |

## BaseVectors.to_ops {id="to_ops",tag="method"}

Dummy method. Implement this to change the embedding matrix to use different
Thinc ops.

| Name  | Description |
| ----- | -------------------------------------------------------- |
| `ops` | The Thinc ops to switch the embedding matrix to. ~~Ops~~ |

## BaseVectors.to_disk {id="to_disk",tag="method"}

Dummy method to allow serialization. Implement to save vector data with the
pipeline.

| Name   | Description |
| ------ | ------------------------------------------------------------------------------------------------------------------------------------------ |
| `path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. ~~Union[str, Path]~~ |

## BaseVectors.from_disk {id="from_disk",tag="method"}

Dummy method to allow serialization. Implement to load vector data from a saved
pipeline.

| Name        | Description |
| ----------- | ----------------------------------------------------------------------------------------------- |
| `path`      | A path to a directory. Paths may be either strings or `Path`-like objects. ~~Union[str, Path]~~ |
| **RETURNS** | The modified vectors object. ~~BaseVectors~~ |

## BaseVectors.to_bytes {id="to_bytes",tag="method"}

Dummy method to allow serialization. Implement to serialize vector data to a
binary string.

| Name        | Description |
| ----------- | ---------------------------------------------------- |
| **RETURNS** | The serialized form of the vectors object. ~~bytes~~ |

## BaseVectors.from_bytes {id="from_bytes",tag="method"}

Dummy method to allow serialization. Implement to load vector data from a
binary string.

| Name        | Description |
| ----------- | ----------------------------------- |
| `data`      | The data to load from. ~~bytes~~ |
| **RETURNS** | The vectors object. ~~BaseVectors~~ |

View File: website/docs/api/vectors.mdx

@@ -297,10 +297,9 @@ The vector size, i.e. `rows * dims`.
 
 ## Vectors.is_full {id="is_full",tag="property"}
 
-Whether the vectors table is full and has no slots are available for new keys.
-If a table is full, it can be resized using
-[`Vectors.resize`](/api/vectors#resize). In `floret` mode, the table is always
-full and cannot be resized.
+Whether the vectors table is full and no slots are available for new keys. If a
+table is full, it can be resized using [`Vectors.resize`](/api/vectors#resize).
+In `floret` mode, the table is always full and cannot be resized.
 
 > #### Example
 >
@@ -441,7 +440,7 @@ Load state from a binary string.
 
 > #### Example
 >
 > ```python
-> fron spacy.vectors import Vectors
+> from spacy.vectors import Vectors
 > vectors_bytes = vectors.to_bytes()
 > new_vectors = Vectors(StringStore())
 > new_vectors.from_bytes(vectors_bytes)

View File: website/docs/usage/embeddings-transformers.mdx

@@ -632,6 +632,165 @@ def MyCustomVectors(
    )
```

#### Creating a custom vectors implementation {id="custom-vectors",version="3.7"}

You can specify a custom registered vectors class under `[nlp.vectors]` in order
to use static vectors in formats other than the ones supported by
[`Vectors`](/api/vectors). Extend the abstract [`BaseVectors`](/api/basevectors)
class to implement your custom vectors.

As an example, the following `BPEmbVectors` class implements support for
[BPEmb subword embeddings](https://bpemb.h-its.org/):
```python
# requires: pip install bpemb
import warnings
from pathlib import Path
from typing import Callable, Optional, cast

from bpemb import BPEmb
from thinc.api import Ops, get_current_ops
from thinc.backends import get_array_ops
from thinc.types import Floats2d

from spacy.strings import StringStore
from spacy.util import registry
from spacy.vectors import BaseVectors
from spacy.vocab import Vocab


class BPEmbVectors(BaseVectors):
    def __init__(
        self,
        *,
        strings: Optional[StringStore] = None,
        lang: Optional[str] = None,
        vs: Optional[int] = None,
        dim: Optional[int] = None,
        cache_dir: Optional[Path] = None,
        encode_extra_options: Optional[str] = None,
        model_file: Optional[Path] = None,
        emb_file: Optional[Path] = None,
    ):
        kwargs = {}
        if lang is not None:
            kwargs["lang"] = lang
        if vs is not None:
            kwargs["vs"] = vs
        if dim is not None:
            kwargs["dim"] = dim
        if cache_dir is not None:
            kwargs["cache_dir"] = cache_dir
        if encode_extra_options is not None:
            kwargs["encode_extra_options"] = encode_extra_options
        if model_file is not None:
            kwargs["model_file"] = model_file
        if emb_file is not None:
            kwargs["emb_file"] = emb_file
        self.bpemb = BPEmb(**kwargs)
        self.strings = strings
        self.name = repr(self.bpemb)
        self.n_keys = -1
        self.mode = "BPEmb"
        self.to_ops(get_current_ops())

    def __contains__(self, key):
        return True

    def is_full(self):
        return True

    def add(self, key, *, vector=None, row=None):
        warnings.warn(
            (
                "Skipping BPEmbVectors.add: the bpemb vector table cannot be "
                "modified. Vectors are calculated from bytepieces."
            )
        )
        return -1

    def __getitem__(self, key):
        return self.get_batch([key])[0]

    def get_batch(self, keys):
        keys = [self.strings.as_string(key) for key in keys]
        bp_ids = self.bpemb.encode_ids(keys)
        ops = get_array_ops(self.bpemb.emb.vectors)
        indices = ops.asarray(ops.xp.hstack(bp_ids), dtype="int32")
        lengths = ops.asarray([len(x) for x in bp_ids], dtype="int32")
        vecs = ops.reduce_mean(cast(Floats2d, self.bpemb.emb.vectors[indices]), lengths)
        return vecs

    @property
    def shape(self):
        return self.bpemb.vectors.shape

    def __len__(self):
        return self.shape[0]

    @property
    def vectors_length(self):
        return self.shape[1]

    @property
    def size(self):
        return self.bpemb.vectors.size

    def to_ops(self, ops: Ops):
        self.bpemb.emb.vectors = ops.asarray(self.bpemb.emb.vectors)


@registry.vectors("BPEmbVectors.v1")
def create_bpemb_vectors(
    lang: Optional[str] = "multi",
    vs: Optional[int] = None,
    dim: Optional[int] = None,
    cache_dir: Optional[Path] = None,
    encode_extra_options: Optional[str] = None,
    model_file: Optional[Path] = None,
    emb_file: Optional[Path] = None,
) -> Callable[[Vocab], BPEmbVectors]:
    def bpemb_vectors_factory(vocab: Vocab) -> BPEmbVectors:
        return BPEmbVectors(
            strings=vocab.strings,
            lang=lang,
            vs=vs,
            dim=dim,
            cache_dir=cache_dir,
            encode_extra_options=encode_extra_options,
            model_file=model_file,
            emb_file=emb_file,
        )

    return bpemb_vectors_factory
```
<Infobox variant="warning">
Note that the serialization methods are not implemented, so the embeddings are
loaded from your local cache or downloaded by `BPEmb` each time the pipeline is
loaded.
</Infobox>
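
If you do want the settings to travel with a saved pipeline, one possible
approach (an untested sketch, not part of the example above) is to persist only
the `BPEmb` constructor settings and recreate the table on load. This assumes
`__init__` is extended to keep the kwargs around, e.g.
`self._init_kwargs = kwargs`, with any `Path` values converted to strings so
they are JSON-serializable:

```python
# Untested sketch: methods to add to BPEmbVectors, assuming __init__ also
# stores the constructor settings in self._init_kwargs (JSON-serializable).
import srsly

def to_disk(self, path, **kwargs):
    # save only the settings; BPEmb re-reads or re-downloads the actual data
    srsly.write_json(path, self._init_kwargs)

def from_disk(self, path, **kwargs):
    self.bpemb = BPEmb(**srsly.read_json(path))
    self.name = repr(self.bpemb)
    self.to_ops(get_current_ops())
    return self
```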
To use this in your pipeline, specify this registered function under
`[nlp.vectors]` in your config:
```ini
[nlp.vectors]
@vectors = "BPEmbVectors.v1"
lang = "en"
```
Or specify it when creating a blank pipeline:
```python
nlp = spacy.blank("en", config={"nlp.vectors": {"@vectors": "BPEmbVectors.v1", "lang": "en"}})
```
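
If everything is wired up, the custom vectors back the usual vector API. As a
quick sanity check (this assumes `bpemb` is installed and will download model
data on first use):

```python
import spacy

nlp = spacy.blank("en", config={"nlp.vectors": {"@vectors": "BPEmbVectors.v1", "lang": "en"}})
doc = nlp("This is a sentence.")
# token vectors are computed on the fly from BPEmb subword embeddings
print(doc[0].vector.shape)
```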
Remember to include this code with `--code` when using
[`spacy train`](/api/cli#train) and [`spacy package`](/api/cli#package).
## Pretraining {id="pretraining"}

The [`spacy pretrain`](/api/cli#pretrain) command lets you initialize your

View File: website/meta/sidebar.json

@@ -131,6 +131,7 @@
       "label": "Other",
       "items": [
         { "text": "Attributes", "url": "/api/attributes" },
+        { "text": "BaseVectors", "url": "/api/basevectors" },
         { "text": "Corpus", "url": "/api/corpus" },
         { "text": "InMemoryLookupKB", "url": "/api/inmemorylookupkb" },
         { "text": "KnowledgeBase", "url": "/api/kb" },