Use fewer Vector internals (#9879)

* Use Vectors.shape rather than Vectors.data.shape

* Use Vectors.size rather than Vectors.data.size

* Add Vectors.to_ops to move data between different ops

* Add documentation for Vectors.to_ops
Daniël de Kok, 2022-01-18 17:14:35 +01:00 (committed by GitHub)
parent 4dfd559e55
commit 50d2a2c930
10 changed files with 40 additions and 20 deletions
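In short, call sites stop reaching into the raw `Vectors.data` array and go through the `Vectors` API instead. A minimal before/after sketch of the pattern this commit replaces (the `vectors` object here is a stand-in for `nlp.vocab.vectors`, not code taken from the diff):

```python
from thinc.api import get_current_ops

ops = get_current_ops()

# Before: reaching into the underlying array.
width = vectors.data.shape[1]
vectors.data = ops.asarray(vectors.data)

# After: using the Vectors API.
width = vectors.shape[1]  # Vectors.shape mirrors data.shape
vectors.to_ops(ops)       # moves the table to the given ops' backend
```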


@@ -1285,9 +1285,9 @@ class Language:
                 )
             except IOError:
                 raise IOError(Errors.E884.format(vectors=I["vectors"]))
-        if self.vocab.vectors.data.shape[1] >= 1:
+        if self.vocab.vectors.shape[1] >= 1:
             ops = get_current_ops()
-            self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
+            self.vocab.vectors.to_ops(ops)
         if hasattr(self.tokenizer, "initialize"):
             tok_settings = validate_init_settings(
                 self.tokenizer.initialize,  # type: ignore[union-attr]
@@ -1332,8 +1332,8 @@ class Language:
         DOCS: https://spacy.io/api/language#resume_training
         """
         ops = get_current_ops()
-        if self.vocab.vectors.data.shape[1] >= 1:
-            self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
+        if self.vocab.vectors.shape[1] >= 1:
+            self.vocab.vectors.to_ops(ops)
         for name, proc in self.pipeline:
             if hasattr(proc, "_rehearsal_model"):
                 proc._rehearsal_model = deepcopy(proc.model)  # type: ignore[attr-defined]


@@ -23,7 +23,7 @@ def create_pretrain_vectors(
     maxout_pieces: int, hidden_size: int, loss: str
 ) -> Callable[["Vocab", Model], Model]:
     def create_vectors_objective(vocab: "Vocab", tok2vec: Model) -> Model:
-        if vocab.vectors.data.shape[1] == 0:
+        if vocab.vectors.shape[1] == 0:
             raise ValueError(Errors.E875)
         model = build_cloze_multi_task_model(
             vocab, tok2vec, hidden_size=hidden_size, maxout_pieces=maxout_pieces
@@ -116,7 +116,7 @@ def build_multi_task_model(
 def build_cloze_multi_task_model(
     vocab: "Vocab", tok2vec: Model, maxout_pieces: int, hidden_size: int
 ) -> Model:
-    nO = vocab.vectors.data.shape[1]
+    nO = vocab.vectors.shape[1]
     output_layer = chain(
         cast(Model[List["Floats2d"], Floats2d], list2array()),
         Maxout(


@@ -94,7 +94,7 @@ def init(
     nM = model.get_dim("nM") if model.has_dim("nM") else None
     nO = model.get_dim("nO") if model.has_dim("nO") else None
     if X is not None and len(X):
-        nM = X[0].vocab.vectors.data.shape[1]
+        nM = X[0].vocab.vectors.shape[1]
     if Y is not None:
         nO = Y.data.shape[1]


@@ -421,7 +421,7 @@ def test_vector_is_oov():
 def test_init_vectors_unset():
     v = Vectors(shape=(10, 10))
     assert v.is_full is False
-    assert v.data.shape == (10, 10)
+    assert v.shape == (10, 10)

     with pytest.raises(ValueError):
         v = Vectors(shape=(10, 10), mode="floret")
@@ -514,7 +514,7 @@ def test_floret_vectors(floret_vectors_vec_str, floret_vectors_hashvec_str):
         # rows: 2 rows per ngram
         rows = OPS.xp.asarray(
             [
-                h % nlp.vocab.vectors.data.shape[0]
+                h % nlp.vocab.vectors.shape[0]
                 for ngram in ngrams
                 for h in nlp.vocab.vectors._get_ngram_hashes(ngram)
             ],
@@ -544,17 +544,17 @@ def test_floret_vectors(floret_vectors_vec_str, floret_vectors_hashvec_str):
         # an empty key returns 0s
         assert_equal(
             OPS.to_numpy(nlp.vocab[""].vector),
-            numpy.zeros((nlp.vocab.vectors.data.shape[0],)),
+            numpy.zeros((nlp.vocab.vectors.shape[0],)),
         )
         # an empty batch returns 0s
         assert_equal(
             OPS.to_numpy(nlp.vocab.vectors.get_batch([""])),
-            numpy.zeros((1, nlp.vocab.vectors.data.shape[0])),
+            numpy.zeros((1, nlp.vocab.vectors.shape[0])),
         )
         # an empty key within a batch returns 0s
         assert_equal(
             OPS.to_numpy(nlp.vocab.vectors.get_batch(["a", "", "b"])[1]),
-            numpy.zeros((nlp.vocab.vectors.data.shape[0],)),
+            numpy.zeros((nlp.vocab.vectors.shape[0],)),
         )
         # the loaded ngram vector table cannot be modified


@@ -616,7 +616,7 @@ cdef class Doc:
         """
         if "has_vector" in self.user_hooks:
             return self.user_hooks["has_vector"](self)
-        elif self.vocab.vectors.data.size:
+        elif self.vocab.vectors.size:
             return True
         elif self.tensor.size:
             return True
@@ -641,7 +641,7 @@ cdef class Doc:
         if not len(self):
             self._vector = xp.zeros((self.vocab.vectors_length,), dtype="f")
             return self._vector
-        elif self.vocab.vectors.data.size > 0:
+        elif self.vocab.vectors.size > 0:
             self._vector = sum(t.vector for t in self) / len(self)
             return self._vector
         elif self.tensor.size > 0:


@@ -497,7 +497,7 @@ cdef class Span:
         """
         if "has_vector" in self.doc.user_span_hooks:
             return self.doc.user_span_hooks["has_vector"](self)
-        elif self.vocab.vectors.data.size > 0:
+        elif self.vocab.vectors.size > 0:
             return any(token.has_vector for token in self)
         elif self.doc.tensor.size > 0:
             return True


@@ -164,7 +164,7 @@ def load_vectors_into_model(
         len(vectors_nlp.vocab.vectors.keys()) == 0
         and vectors_nlp.vocab.vectors.mode != VectorsMode.floret
     ) or (
-        vectors_nlp.vocab.vectors.data.shape[0] == 0
+        vectors_nlp.vocab.vectors.shape[0] == 0
         and vectors_nlp.vocab.vectors.mode == VectorsMode.floret
     ):
         logger.warning(Warnings.W112.format(name=name))


@@ -10,7 +10,7 @@ from typing import cast
 import warnings
 from enum import Enum
 import srsly
-from thinc.api import get_array_module, get_current_ops
+from thinc.api import Ops, get_array_module, get_current_ops
 from thinc.backends import get_array_ops
 from thinc.types import Floats2d
@@ -146,7 +146,7 @@ cdef class Vectors:
         DOCS: https://spacy.io/api/vectors#size
         """
-        return self.data.shape[0] * self.data.shape[1]
+        return self.data.size

     @property
     def is_full(self):
@@ -517,6 +517,9 @@ cdef class Vectors:
             for i in range(len(queries)) ], dtype="uint64")
         return (keys, best_rows, scores)

+    def to_ops(self, ops: Ops):
+        self.data = ops.asarray(self.data)
+
     def _get_cfg(self):
         if self.mode == Mode.default:
             return {
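The new `to_ops` method is a thin wrapper around `ops.asarray`, which copies the table to whichever backend the given `Ops` instance targets. A hedged round-trip sketch under that assumption (`NumpyOps` and `get_current_ops` are existing Thinc exports; the table shape and mode here are illustrative):

```python
from thinc.api import NumpyOps, get_current_ops
from spacy.vectors import Vectors

vectors = Vectors(shape=(10, 300))

# Move the table to the currently active backend (GPU ops if one is active)...
vectors.to_ops(get_current_ops())
# ...and back to plain numpy; the move does not change the table's shape.
vectors.to_ops(NumpyOps())
assert vectors.shape == (10, 300)
```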


@@ -283,7 +283,7 @@ cdef class Vocab:
     @property
     def vectors_length(self):
-        return self.vectors.data.shape[1]
+        return self.vectors.shape[1]

     def reset_vectors(self, *, width=None, shape=None):
         """Drop the current vector table. Because all vectors must be the same
@@ -294,7 +294,7 @@ cdef class Vocab:
         elif shape is not None:
             self.vectors = Vectors(strings=self.strings, shape=shape)
         else:
-            width = width if width is not None else self.vectors.data.shape[1]
+            width = width if width is not None else self.vectors.shape[1]
             self.vectors = Vectors(strings=self.strings, shape=(self.vectors.shape[0], width))

     def prune_vectors(self, nr_row, batch_size=1024):


@@ -371,6 +371,23 @@ Get the vectors for the provided keys efficiently as a batch.
 | ------ | --------------------------------------- |
 | `keys` | The keys. ~~Iterable[Union[int, str]]~~ |

+## Vectors.to_ops {#to_ops tag="method"}
+
+Change the embedding matrix to use different Thinc ops.
+
+> #### Example
+>
+> ```python
+> from thinc.api import NumpyOps
+>
+> vectors.to_ops(NumpyOps())
+> ```
+
+| Name  | Description                                               |
+| ----- | --------------------------------------------------------- |
+| `ops` | The Thinc ops to switch the embedding matrix to. ~~Ops~~  |
+
 ## Vectors.to_disk {#to_disk tag="method"}

 Save the current state to a directory.