From 52bc927db9320d96dbe1de15a41501e267960b64 Mon Sep 17 00:00:00 2001 From: Paul O'Leary McCann Date: Tue, 15 Nov 2022 13:47:08 +0900 Subject: [PATCH] Add equality definition for vectors This re-uses the check from sourcing components. --- spacy/tests/vocab_vectors/test_vectors.py | 20 ++++++++++++++++++++ spacy/vectors.pyx | 10 ++++++++++ 2 files changed, 30 insertions(+) diff --git a/spacy/tests/vocab_vectors/test_vectors.py b/spacy/tests/vocab_vectors/test_vectors.py index dd2cfc596..70835816d 100644 --- a/spacy/tests/vocab_vectors/test_vectors.py +++ b/spacy/tests/vocab_vectors/test_vectors.py @@ -626,3 +626,23 @@ def test_floret_vectors(floret_vectors_vec_str, floret_vectors_hashvec_str): OPS.to_numpy(vocab_r[word].vector), decimal=6, ) + + +def test_equality(): + vectors1 = Vectors(shape=(10, 10)) + vectors2 = Vectors(shape=(10, 8)) + + assert vectors1 != vectors2 + + vectors2 = Vectors(shape=(10, 10)) + assert vectors1 == vectors2 + + vectors1.add("hello", row=2) + assert vectors1 != vectors2 + + vectors2.add("hello", row=2) + assert vectors1 == vectors2 + + vectors1.resize((5, 9)) + vectors2.resize((5, 9)) + assert vectors1 == vectors2 diff --git a/spacy/vectors.pyx b/spacy/vectors.pyx index 8300220c1..584b84715 100644 --- a/spacy/vectors.pyx +++ b/spacy/vectors.pyx @@ -243,6 +243,16 @@ cdef class Vectors: else: return key in self.key2row + + def __eq__(self, other): + return ( + self.shape == other.shape + and self.key2row == other.key2row + and self.to_bytes(exclude=["strings"]) + == other.to_bytes(exclude=["strings"]) + ) + + def resize(self, shape, inplace=False): """Resize the underlying vectors array. If inplace=True, the memory is reallocated. This may cause other references to the data to become