Make .similarity() return 1.0 if all orth attrs match

This commit is contained in:
Matthew Honnibal 2018-01-15 16:29:48 +01:00
parent 82135d85b7
commit ccb51a9f36
6 changed files with 56 additions and 0 deletions

View File

@ -112,6 +112,14 @@ cdef class Lexeme:
`Span`, `Token` and `Lexeme` objects.
RETURNS (float): A scalar similarity score. Higher is more similar.
"""
# Return 1.0 similarity for matches
if hasattr(other, 'orth'):
if self.c.orth == other.orth:
return 1.0
elif hasattr(other, '__len__') and len(other) == 1 \
and hasattr(other[0], 'orth'):
if self.c.orth == other[0].orth:
return 1.0
if self.vector_norm == 0 or other.vector_norm == 0:
return 0.0
return (numpy.dot(self.vector, other.vector) /

View File

@ -217,6 +217,16 @@ def test_doc_api_has_vector():
doc = Doc(vocab, words=['kitten'])
assert doc.has_vector
def test_doc_api_similarity_match():
doc = Doc(Vocab(), words=['a'])
assert doc.similarity(doc[0]) == 1.0
assert doc.similarity(doc.vocab['a']) == 1.0
doc2 = Doc(doc.vocab, words=['a', 'b', 'c'])
assert doc.similarity(doc2[:1]) == 1.0
assert doc.similarity(doc2) == 0.0
def test_lowest_common_ancestor(en_tokenizer):
tokens = en_tokenizer('the lazy dog slept')
doc = get_doc(tokens.vocab, [t.text for t in tokens], heads=[2, 1, 1, 0])
@ -225,6 +235,7 @@ def test_lowest_common_ancestor(en_tokenizer):
assert(lca[0, 1] == 2)
assert(lca[1, 2] == 2)
def test_parse_tree(en_tokenizer):
"""Tests doc.print_tree() method."""
text = 'I like New York in Autumn.'

View File

@ -3,6 +3,8 @@ from __future__ import unicode_literals
from ..util import get_doc
from ...attrs import ORTH, LENGTH
from ...tokens import Doc
from ...vocab import Vocab
import pytest
@ -66,6 +68,15 @@ def test_spans_lca_matrix(en_tokenizer):
assert(lca[1, 1] == 1)
def test_span_similarity_match():
doc = Doc(Vocab(), words=['a', 'b', 'a', 'b'])
span1 = doc[:2]
span2 = doc[2:]
assert span1.similarity(span2) == 1.0
assert span1.similarity(doc) == 0.0
assert span1[:1].similarity(doc.vocab['a']) == 1.0
def test_spans_default_sentiment(en_tokenizer):
"""Test span.sentiment property's default averaging behaviour"""
text = "good stuff bad stuff"

View File

@ -295,6 +295,17 @@ cdef class Doc:
"""
if 'similarity' in self.user_hooks:
return self.user_hooks['similarity'](self, other)
if isinstance(other, (Lexeme, Token)) and self.length == 1:
if self.c[0].lex.orth == other.orth:
return 1.0
elif isinstance(other, (Span, Doc)):
if len(self) == len(other):
for i in range(self.length):
if self[i].orth != other[i].orth:
break
else:
return 1.0
if self.vector_norm == 0 or other.vector_norm == 0:
return 0.0
return numpy.dot(self.vector, other.vector) / (self.vector_norm * other.vector_norm)

View File

@ -184,6 +184,15 @@ cdef class Span:
"""
if 'similarity' in self.doc.user_span_hooks:
self.doc.user_span_hooks['similarity'](self, other)
if len(self) == 1 and hasattr(other, 'orth'):
if self[0].orth == other.orth:
return 1.0
elif hasattr(other, '__len__') and len(self) == len(other):
for i in range(len(self)):
if self[i].orth != getattr(other[i], 'orth', None):
break
else:
return 1.0
if self.vector_norm == 0.0 or other.vector_norm == 0.0:
return 0.0
return numpy.dot(self.vector, other.vector) / (self.vector_norm * other.vector_norm)

View File

@ -149,6 +149,12 @@ cdef class Token:
"""
if 'similarity' in self.doc.user_token_hooks:
return self.doc.user_token_hooks['similarity'](self)
if hasattr(other, '__len__') and len(other) == 1:
if self.c.lex.orth == getattr(other[0], 'orth', None):
return 1.0
elif hasattr(other, 'orth'):
if self.c.lex.orth == other.orth:
return 1.0
if self.vector_norm == 0 or other.vector_norm == 0:
return 0.0
return (numpy.dot(self.vector, other.vector) /