mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 09:57:26 +03:00 
			
		
		
		
	Add equality definition for vectors (#11806)
* Add equality definition for vectors

  This re-uses the check from sourcing components.

* Use the equality check
* Format

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
		
							parent
							
								
									caa9efad59
								
							
						
					
					
						commit
						c0c54e44bc
					
				| 
						 | 
					@ -706,13 +706,7 @@ class Language:
 | 
				
			||||||
        # Check source type
 | 
					        # Check source type
 | 
				
			||||||
        if not isinstance(source, Language):
 | 
					        if not isinstance(source, Language):
 | 
				
			||||||
            raise ValueError(Errors.E945.format(name=source_name, source=type(source)))
 | 
					            raise ValueError(Errors.E945.format(name=source_name, source=type(source)))
 | 
				
			||||||
        # Check vectors, with faster checks first
 | 
					        if self.vocab.vectors != source.vocab.vectors:
 | 
				
			||||||
        if (
 | 
					 | 
				
			||||||
            self.vocab.vectors.shape != source.vocab.vectors.shape
 | 
					 | 
				
			||||||
            or self.vocab.vectors.key2row != source.vocab.vectors.key2row
 | 
					 | 
				
			||||||
            or self.vocab.vectors.to_bytes(exclude=["strings"])
 | 
					 | 
				
			||||||
            != source.vocab.vectors.to_bytes(exclude=["strings"])
 | 
					 | 
				
			||||||
        ):
 | 
					 | 
				
			||||||
            warnings.warn(Warnings.W113.format(name=source_name))
 | 
					            warnings.warn(Warnings.W113.format(name=source_name))
 | 
				
			||||||
        if source_name not in source.component_names:
 | 
					        if source_name not in source.component_names:
 | 
				
			||||||
            raise KeyError(
 | 
					            raise KeyError(
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -626,3 +626,23 @@ def test_floret_vectors(floret_vectors_vec_str, floret_vectors_hashvec_str):
 | 
				
			||||||
            OPS.to_numpy(vocab_r[word].vector),
 | 
					            OPS.to_numpy(vocab_r[word].vector),
 | 
				
			||||||
            decimal=6,
 | 
					            decimal=6,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_equality():
 | 
				
			||||||
 | 
					    vectors1 = Vectors(shape=(10, 10))
 | 
				
			||||||
 | 
					    vectors2 = Vectors(shape=(10, 8))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert vectors1 != vectors2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    vectors2 = Vectors(shape=(10, 10))
 | 
				
			||||||
 | 
					    assert vectors1 == vectors2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    vectors1.add("hello", row=2)
 | 
				
			||||||
 | 
					    assert vectors1 != vectors2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    vectors2.add("hello", row=2)
 | 
				
			||||||
 | 
					    assert vectors1 == vectors2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    vectors1.resize((5, 9))
 | 
				
			||||||
 | 
					    vectors2.resize((5, 9))
 | 
				
			||||||
 | 
					    assert vectors1 == vectors2
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -243,6 +243,15 @@ cdef class Vectors:
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            return key in self.key2row
 | 
					            return key in self.key2row
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __eq__(self, other):
 | 
				
			||||||
 | 
					        # Check for equality, with faster checks first
 | 
				
			||||||
 | 
					        return (
 | 
				
			||||||
 | 
					                self.shape == other.shape
 | 
				
			||||||
 | 
					                and self.key2row == other.key2row
 | 
				
			||||||
 | 
					                and self.to_bytes(exclude=["strings"])
 | 
				
			||||||
 | 
					                  == other.to_bytes(exclude=["strings"])
 | 
				
			||||||
 | 
					               )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def resize(self, shape, inplace=False):
 | 
					    def resize(self, shape, inplace=False):
 | 
				
			||||||
        """Resize the underlying vectors array. If inplace=True, the memory
 | 
					        """Resize the underlying vectors array. If inplace=True, the memory
 | 
				
			||||||
        is reallocated. This may cause other references to the data to become
 | 
					        is reallocated. This may cause other references to the data to become
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user