From 47ea6704f1045ee3a04ac7ffbfedba01d944e233 Mon Sep 17 00:00:00 2001 From: Natalia Rodnova <4512370+nrodnova@users.noreply.github.com> Date: Mon, 17 Jan 2022 03:17:49 -0700 Subject: [PATCH] Span richcmp fix (#9956) * Corrected Span's __richcmp__ implementation to take end, label and kb_id in consideration * Updated test * Updated test * Removed formatting from a test for readability sake * Use same tuples for all comparisons Co-authored-by: Adriane Boyd --- spacy/tests/doc/test_span.py | 49 ++++++++++++++++++++++++++++++++++++ spacy/tokens/span.pyx | 28 ++++++--------------- 2 files changed, 57 insertions(+), 20 deletions(-) diff --git a/spacy/tests/doc/test_span.py b/spacy/tests/doc/test_span.py index 10aba5b94..bdf34c1c1 100644 --- a/spacy/tests/doc/test_span.py +++ b/spacy/tests/doc/test_span.py @@ -573,6 +573,55 @@ def test_span_with_vectors(doc): doc.vocab.vectors = prev_vectors +# fmt: off +def test_span_comparison(doc): + + # Identical start, end, only differ in label and kb_id + assert Span(doc, 0, 3) == Span(doc, 0, 3) + assert Span(doc, 0, 3, "LABEL") == Span(doc, 0, 3, "LABEL") + assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") == Span(doc, 0, 3, "LABEL", kb_id="KB_ID") + + assert Span(doc, 0, 3) != Span(doc, 0, 3, "LABEL") + assert Span(doc, 0, 3) != Span(doc, 0, 3, "LABEL", kb_id="KB_ID") + assert Span(doc, 0, 3, "LABEL") != Span(doc, 0, 3, "LABEL", kb_id="KB_ID") + + assert Span(doc, 0, 3) <= Span(doc, 0, 3) and Span(doc, 0, 3) >= Span(doc, 0, 3) + assert Span(doc, 0, 3, "LABEL") <= Span(doc, 0, 3, "LABEL") and Span(doc, 0, 3, "LABEL") >= Span(doc, 0, 3, "LABEL") + assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") <= Span(doc, 0, 3, "LABEL", kb_id="KB_ID") + assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") >= Span(doc, 0, 3, "LABEL", kb_id="KB_ID") + + assert (Span(doc, 0, 3) < Span(doc, 0, 3, "", kb_id="KB_ID") < Span(doc, 0, 3, "LABEL") < Span(doc, 0, 3, "LABEL", kb_id="KB_ID")) + assert (Span(doc, 0, 3) <= Span(doc, 0, 3, "", kb_id="KB_ID") <= Span(doc, 0, 3, "LABEL") <= Span(doc, 0, 3, "LABEL", kb_id="KB_ID")) + + assert (Span(doc, 0, 3, "LABEL", kb_id="KB_ID") > Span(doc, 0, 3, "LABEL") > Span(doc, 0, 3, "", kb_id="KB_ID") > Span(doc, 0, 3)) + assert (Span(doc, 0, 3, "LABEL", kb_id="KB_ID") >= Span(doc, 0, 3, "LABEL") >= Span(doc, 0, 3, "", kb_id="KB_ID") >= Span(doc, 0, 3)) + + # Different end + assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") < Span(doc, 0, 4, "LABEL", kb_id="KB_ID") + + assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") < Span(doc, 0, 4) + assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") <= Span(doc, 0, 4) + assert Span(doc, 0, 4) > Span(doc, 0, 3, "LABEL", kb_id="KB_ID") + assert Span(doc, 0, 4) >= Span(doc, 0, 3, "LABEL", kb_id="KB_ID") + + # Different start + assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") != Span(doc, 1, 3, "LABEL", kb_id="KB_ID") + + assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") < Span(doc, 1, 3) + assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") <= Span(doc, 1, 3) + assert Span(doc, 1, 3) > Span(doc, 0, 3, "LABEL", kb_id="KB_ID") + assert Span(doc, 1, 3) >= Span(doc, 0, 3, "LABEL", kb_id="KB_ID") + + # Different start & different end + assert Span(doc, 0, 4, "LABEL", kb_id="KB_ID") != Span(doc, 1, 3, "LABEL", kb_id="KB_ID") + + assert Span(doc, 0, 4, "LABEL", kb_id="KB_ID") < Span(doc, 1, 3) + assert Span(doc, 0, 4, "LABEL", kb_id="KB_ID") <= Span(doc, 1, 3) + assert Span(doc, 1, 3) > Span(doc, 0, 4, "LABEL", kb_id="KB_ID") + assert Span(doc, 1, 3) >= Span(doc, 0, 4, "LABEL", kb_id="KB_ID") +# fmt: on + + @pytest.mark.parametrize( "start,end,expected_sentences,expected_sentences_with_hook", [ diff --git a/spacy/tokens/span.pyx b/spacy/tokens/span.pyx index cd02cab36..5484b25fd 100644 --- a/spacy/tokens/span.pyx +++ b/spacy/tokens/span.pyx @@ -126,38 +126,26 @@ cdef class Span: return False else: return True + self_tuple = (self.c.start_char, self.c.end_char, self.c.label, self.c.kb_id, self.doc) + other_tuple = (other.c.start_char, other.c.end_char, other.c.label, other.c.kb_id, other.doc) # < if op == 0: - return self.c.start_char < other.c.start_char + return self_tuple < other_tuple # <= elif op == 1: - return self.c.start_char <= other.c.start_char + return self_tuple <= other_tuple # == elif op == 2: - # Do the cheap comparisons first - return ( - (self.c.start_char == other.c.start_char) and \ - (self.c.end_char == other.c.end_char) and \ - (self.c.label == other.c.label) and \ - (self.c.kb_id == other.c.kb_id) and \ - (self.doc == other.doc) - ) + return self_tuple == other_tuple # != elif op == 3: - # Do the cheap comparisons first - return not ( - (self.c.start_char == other.c.start_char) and \ - (self.c.end_char == other.c.end_char) and \ - (self.c.label == other.c.label) and \ - (self.c.kb_id == other.c.kb_id) and \ - (self.doc == other.doc) - ) + return self_tuple != other_tuple # > elif op == 4: - return self.c.start_char > other.c.start_char + return self_tuple > other_tuple # >= elif op == 5: - return self.c.start_char >= other.c.start_char + return self_tuple >= other_tuple def __hash__(self): return hash((self.doc, self.c.start_char, self.c.end_char, self.c.label, self.c.kb_id))