mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
Add vector deduplication (#10551)
* Add vector deduplication * Add `Vocab.deduplicate_vectors()` * Always run deduplication in `spacy init vectors` * Clean up a few vector-related error messages and docs examples * Always unique with numpy * Fix types
This commit is contained in:
parent
9966e08f32
commit
f98b41c390
|
@ -528,7 +528,7 @@ class Errors(metaclass=ErrorsWithCodes):
|
||||||
E858 = ("The {mode} vector table does not support this operation. "
|
E858 = ("The {mode} vector table does not support this operation. "
|
||||||
"{alternative}")
|
"{alternative}")
|
||||||
E859 = ("The floret vector table cannot be modified.")
|
E859 = ("The floret vector table cannot be modified.")
|
||||||
E860 = ("Can't truncate fasttext-bloom vectors.")
|
E860 = ("Can't truncate floret vectors.")
|
||||||
E861 = ("No 'keys' should be provided when initializing floret vectors "
|
E861 = ("No 'keys' should be provided when initializing floret vectors "
|
||||||
"with 'minn' and 'maxn'.")
|
"with 'minn' and 'maxn'.")
|
||||||
E862 = ("'hash_count' must be between 1-4 for floret vectors.")
|
E862 = ("'hash_count' must be between 1-4 for floret vectors.")
|
||||||
|
|
|
@ -455,6 +455,39 @@ def test_vectors_get_batch():
|
||||||
assert_equal(OPS.to_numpy(vecs), OPS.to_numpy(v.get_batch(words)))
|
assert_equal(OPS.to_numpy(vecs), OPS.to_numpy(v.get_batch(words)))
|
||||||
|
|
||||||
|
|
||||||
|
def test_vectors_deduplicate():
    """Vocab.deduplicate_vectors() collapses identical rows and remaps keys."""
    rows = [[1, 1], [2, 2], [3, 4], [1, 1], [3, 4]]
    data = OPS.asarray(rows, dtype="f")
    v = Vectors(data=data, keys=["a1", "b1", "c1", "a2", "c2"])
    vocab = Vocab()
    vocab.vectors = v

    def row_of(key):
        # resolved against the vocab's current vector table on every call
        return vocab.vectors.key2row[v.strings[key]]

    duplicated = (("a1", "a2"), ("c1", "c2"))
    # before: keys with identical vectors still point at separate rows
    for first, second in duplicated:
        assert row_of(first) != row_of(second)
    vocab.deduplicate_vectors()
    # only the three distinct vectors remain
    n_rows = vocab.vectors.shape[0]
    assert n_rows == 3
    # the stored table equals its own uniqued form
    table = OPS.to_numpy(vocab.vectors.data)
    assert_equal(numpy.unique(table, axis=0), table)
    # after: keys with identical vectors share a row
    for first, second in duplicated:
        assert row_of(first) == row_of(second)
    # a second pass must leave the serialized vocab untouched
    serialized = vocab.to_bytes()
    vocab.deduplicate_vectors()
    assert serialized == vocab.to_bytes()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def floret_vectors_hashvec_str():
|
def floret_vectors_hashvec_str():
|
||||||
"""The full hashvec table from floret with the settings:
|
"""The full hashvec table from floret with the settings:
|
||||||
|
|
|
@ -213,6 +213,7 @@ def convert_vectors(
|
||||||
for lex in nlp.vocab:
|
for lex in nlp.vocab:
|
||||||
if lex.rank and lex.rank != OOV_RANK:
|
if lex.rank and lex.rank != OOV_RANK:
|
||||||
nlp.vocab.vectors.add(lex.orth, row=lex.rank) # type: ignore[attr-defined]
|
nlp.vocab.vectors.add(lex.orth, row=lex.rank) # type: ignore[attr-defined]
|
||||||
|
nlp.vocab.deduplicate_vectors()
|
||||||
else:
|
else:
|
||||||
if vectors_loc:
|
if vectors_loc:
|
||||||
logger.info(f"Reading vectors from {vectors_loc}")
|
logger.info(f"Reading vectors from {vectors_loc}")
|
||||||
|
@ -239,6 +240,7 @@ def convert_vectors(
|
||||||
nlp.vocab.vectors = Vectors(
|
nlp.vocab.vectors = Vectors(
|
||||||
strings=nlp.vocab.strings, data=vectors_data, keys=vector_keys
|
strings=nlp.vocab.strings, data=vectors_data, keys=vector_keys
|
||||||
)
|
)
|
||||||
|
nlp.vocab.deduplicate_vectors()
|
||||||
if name is None:
|
if name is None:
|
||||||
# TODO: Is this correct? Does this matter?
|
# TODO: Is this correct? Does this matter?
|
||||||
nlp.vocab.vectors.name = f"{nlp.meta['lang']}_{nlp.meta['name']}.vectors"
|
nlp.vocab.vectors.name = f"{nlp.meta['lang']}_{nlp.meta['name']}.vectors"
|
||||||
|
|
|
@ -46,6 +46,7 @@ class Vocab:
|
||||||
def reset_vectors(
|
def reset_vectors(
|
||||||
self, *, width: Optional[int] = ..., shape: Optional[int] = ...
|
self, *, width: Optional[int] = ..., shape: Optional[int] = ...
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
# Removes duplicate rows from the current vector table in place; returns nothing.
def deduplicate_vectors(self) -> None: ...
|
||||||
def prune_vectors(self, nr_row: int, batch_size: int = ...) -> Dict[str, float]: ...
|
def prune_vectors(self, nr_row: int, batch_size: int = ...) -> Dict[str, float]: ...
|
||||||
def get_vector(
|
def get_vector(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
# cython: profile=True
|
# cython: profile=True
|
||||||
from libc.string cimport memcpy
|
from libc.string cimport memcpy
|
||||||
|
|
||||||
|
import numpy
|
||||||
import srsly
|
import srsly
|
||||||
from thinc.api import get_array_module, get_current_ops
|
from thinc.api import get_array_module, get_current_ops
|
||||||
import functools
|
import functools
|
||||||
|
@ -297,6 +298,33 @@ cdef class Vocab:
|
||||||
width = width if width is not None else self.vectors.shape[1]
|
width = width if width is not None else self.vectors.shape[1]
|
||||||
self.vectors = Vectors(strings=self.strings, shape=(self.vectors.shape[0], width))
|
self.vectors = Vectors(strings=self.strings, shape=(self.vectors.shape[0], width))
|
||||||
|
|
||||||
|
def deduplicate_vectors(self):
    """Remove duplicate rows from the vector table, keeping every key
    mapped to a row that holds the same vector as before.

    Only rows referenced by at least one key are considered. The table is
    replaced by a new `Vectors` object built from the unique rows, and all
    keys are re-added against their new row index.

    RAISES (ValueError): E858 if the vector table is not in default mode.
    """
    if self.vectors.mode != VectorsMode.default:
        raise ValueError(Errors.E858.format(
            mode=self.vectors.mode,
            alternative=""
        ))
    ops = get_current_ops()
    xp = get_array_module(self.vectors.data)
    used_rows = sorted(set(self.vectors.key2row.values()))
    if not used_rows:
        # no keys means there is nothing to remap
        return
    filled = xp.asarray(used_rows)
    # deduplicate data; `inverse[i]` is the new row index of `used_rows[i]`
    unique_rows, inverse = numpy.unique(
        ops.to_numpy(self.vectors.data[filled]), axis=0, return_inverse=True
    )
    data = ops.asarray(unique_rows)
    if data.shape == self.vectors.data.shape:
        # nothing to deduplicate
        return
    # Take the remapping from numpy.unique itself rather than matching raw
    # row bytes, so rows it treats as equal always resolve to the row it
    # kept. Flatten because the shape of `inverse` has varied across numpy
    # releases when `axis` is given.
    new_row_by_old = {
        old: int(new) for old, new in zip(used_rows, inverse.reshape(-1))
    }
    key2row = {
        key: new_row_by_old[row]
        for key, row in self.vectors.key2row.items()
    }
    # replace vectors with deduplicated version
    self.vectors = Vectors(strings=self.strings, data=data, name=self.vectors.name)
    for key, row in key2row.items():
        self.vectors.add(key, row=row)
|
||||||
|
|
||||||
def prune_vectors(self, nr_row, batch_size=1024):
|
def prune_vectors(self, nr_row, batch_size=1024):
|
||||||
"""Reduce the current vector table to `nr_row` unique entries. Words
|
"""Reduce the current vector table to `nr_row` unique entries. Words
|
||||||
mapped to the discarded vectors will be remapped to the closest vector
|
mapped to the discarded vectors will be remapped to the closest vector
|
||||||
|
@ -325,7 +353,10 @@ cdef class Vocab:
|
||||||
DOCS: https://spacy.io/api/vocab#prune_vectors
|
DOCS: https://spacy.io/api/vocab#prune_vectors
|
||||||
"""
|
"""
|
||||||
if self.vectors.mode != VectorsMode.default:
|
if self.vectors.mode != VectorsMode.default:
|
||||||
raise ValueError(Errors.E866)
|
raise ValueError(Errors.E858.format(
|
||||||
|
mode=self.vectors.mode,
|
||||||
|
alternative=""
|
||||||
|
))
|
||||||
ops = get_current_ops()
|
ops = get_current_ops()
|
||||||
xp = get_array_module(self.vectors.data)
|
xp = get_array_module(self.vectors.data)
|
||||||
# Make sure all vectors are in the vocab
|
# Make sure all vectors are in the vocab
|
||||||
|
|
|
@ -156,7 +156,7 @@ cosines are calculated in minibatches to reduce memory usage.
|
||||||
>
|
>
|
||||||
> ```python
|
> ```python
|
||||||
> nlp.vocab.prune_vectors(10000)
|
> nlp.vocab.prune_vectors(10000)
|
||||||
> assert len(nlp.vocab.vectors) <= 1000
|
> assert len(nlp.vocab.vectors) <= 10000
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
|
@ -165,6 +165,17 @@ cosines are calculated in minibatches to reduce memory usage.
|
||||||
| `batch_size` | Batch of vectors for calculating the similarities. Larger batch sizes might be faster, while temporarily requiring more memory. ~~int~~ |
|
| `batch_size` | Batch of vectors for calculating the similarities. Larger batch sizes might be faster, while temporarily requiring more memory. ~~int~~ |
|
||||||
| **RETURNS** | A dictionary keyed by removed words mapped to `(string, score)` tuples, where `string` is the entry the removed word was mapped to, and `score` the similarity score between the two words. ~~Dict[str, Tuple[str, float]]~~ |
|
| **RETURNS** | A dictionary keyed by removed words mapped to `(string, score)` tuples, where `string` is the entry the removed word was mapped to, and `score` the similarity score between the two words. ~~Dict[str, Tuple[str, float]]~~ |
|
||||||
|
|
||||||
|
## Vocab.deduplicate_vectors {#deduplicate_vectors tag="method" new="3.3"}
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> nlp.vocab.deduplicate_vectors()
|
||||||
|
> ```
|
||||||
|
|
||||||
|
Remove any duplicate rows from the current vector table, maintaining the
|
||||||
|
mappings for all words in the vectors.
|
||||||
|
|
||||||
## Vocab.get_vector {#get_vector tag="method" new="2"}
|
## Vocab.get_vector {#get_vector tag="method" new="2"}
|
||||||
|
|
||||||
Retrieve a vector for a word in the vocabulary. Words can be looked up by string
|
Retrieve a vector for a word in the vocabulary. Words can be looked up by string
|
||||||
|
@ -179,7 +190,7 @@ or hash value. If the current vectors do not contain an entry for the word, a
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
| ----------------------------------- | ---------------------------------------------------------------------------------------------------------------------- |
|
| ----------- | ---------------------------------------------------------------------------------------------------------------------- |
|
||||||
| `orth` | The hash value of a word, or its unicode string. ~~Union[int, str]~~ |
|
| `orth` | The hash value of a word, or its unicode string. ~~Union[int, str]~~ |
|
||||||
| **RETURNS** | A word vector. Size and shape are determined by the `Vocab.vectors` instance. ~~numpy.ndarray[ndim=1, dtype=float32]~~ |
|
| **RETURNS** | A word vector. Size and shape are determined by the `Vocab.vectors` instance. ~~numpy.ndarray[ndim=1, dtype=float32]~~ |
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user