mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 17:24:41 +03:00
Span richcmp fix (#9956)
* Corrected Span's __richcmp__ implementation to take end, label and kb_id in consideration * Updated test * Updated test * Removed formatting from a test for readability sake * Use same tuples for all comparisons Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
parent
e8a047a8d4
commit
47ea6704f1
|
@ -573,6 +573,55 @@ def test_span_with_vectors(doc):
|
|||
doc.vocab.vectors = prev_vectors
|
||||
|
||||
|
||||
# fmt: off
|
||||
def test_span_comparison(doc):
|
||||
|
||||
# Identical start, end, only differ in label and kb_id
|
||||
assert Span(doc, 0, 3) == Span(doc, 0, 3)
|
||||
assert Span(doc, 0, 3, "LABEL") == Span(doc, 0, 3, "LABEL")
|
||||
assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") == Span(doc, 0, 3, "LABEL", kb_id="KB_ID")
|
||||
|
||||
assert Span(doc, 0, 3) != Span(doc, 0, 3, "LABEL")
|
||||
assert Span(doc, 0, 3) != Span(doc, 0, 3, "LABEL", kb_id="KB_ID")
|
||||
assert Span(doc, 0, 3, "LABEL") != Span(doc, 0, 3, "LABEL", kb_id="KB_ID")
|
||||
|
||||
assert Span(doc, 0, 3) <= Span(doc, 0, 3) and Span(doc, 0, 3) >= Span(doc, 0, 3)
|
||||
assert Span(doc, 0, 3, "LABEL") <= Span(doc, 0, 3, "LABEL") and Span(doc, 0, 3, "LABEL") >= Span(doc, 0, 3, "LABEL")
|
||||
assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") <= Span(doc, 0, 3, "LABEL", kb_id="KB_ID")
|
||||
assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") >= Span(doc, 0, 3, "LABEL", kb_id="KB_ID")
|
||||
|
||||
assert (Span(doc, 0, 3) < Span(doc, 0, 3, "", kb_id="KB_ID") < Span(doc, 0, 3, "LABEL") < Span(doc, 0, 3, "LABEL", kb_id="KB_ID"))
|
||||
assert (Span(doc, 0, 3) <= Span(doc, 0, 3, "", kb_id="KB_ID") <= Span(doc, 0, 3, "LABEL") <= Span(doc, 0, 3, "LABEL", kb_id="KB_ID"))
|
||||
|
||||
assert (Span(doc, 0, 3, "LABEL", kb_id="KB_ID") > Span(doc, 0, 3, "LABEL") > Span(doc, 0, 3, "", kb_id="KB_ID") > Span(doc, 0, 3))
|
||||
assert (Span(doc, 0, 3, "LABEL", kb_id="KB_ID") >= Span(doc, 0, 3, "LABEL") >= Span(doc, 0, 3, "", kb_id="KB_ID") >= Span(doc, 0, 3))
|
||||
|
||||
# Different end
|
||||
assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") < Span(doc, 0, 4, "LABEL", kb_id="KB_ID")
|
||||
|
||||
assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") < Span(doc, 0, 4)
|
||||
assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") <= Span(doc, 0, 4)
|
||||
assert Span(doc, 0, 4) > Span(doc, 0, 3, "LABEL", kb_id="KB_ID")
|
||||
assert Span(doc, 0, 4) >= Span(doc, 0, 3, "LABEL", kb_id="KB_ID")
|
||||
|
||||
# Different start
|
||||
assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") != Span(doc, 1, 3, "LABEL", kb_id="KB_ID")
|
||||
|
||||
assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") < Span(doc, 1, 3)
|
||||
assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") <= Span(doc, 1, 3)
|
||||
assert Span(doc, 1, 3) > Span(doc, 0, 3, "LABEL", kb_id="KB_ID")
|
||||
assert Span(doc, 1, 3) >= Span(doc, 0, 3, "LABEL", kb_id="KB_ID")
|
||||
|
||||
# Different start & different end
|
||||
assert Span(doc, 0, 4, "LABEL", kb_id="KB_ID") != Span(doc, 1, 3, "LABEL", kb_id="KB_ID")
|
||||
|
||||
assert Span(doc, 0, 4, "LABEL", kb_id="KB_ID") < Span(doc, 1, 3)
|
||||
assert Span(doc, 0, 4, "LABEL", kb_id="KB_ID") <= Span(doc, 1, 3)
|
||||
assert Span(doc, 1, 3) > Span(doc, 0, 4, "LABEL", kb_id="KB_ID")
|
||||
assert Span(doc, 1, 3) >= Span(doc, 0, 4, "LABEL", kb_id="KB_ID")
|
||||
# fmt: on
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"start,end,expected_sentences,expected_sentences_with_hook",
|
||||
[
|
||||
|
|
|
@ -126,38 +126,26 @@ cdef class Span:
|
|||
return False
|
||||
else:
|
||||
return True
|
||||
self_tuple = (self.c.start_char, self.c.end_char, self.c.label, self.c.kb_id, self.doc)
|
||||
other_tuple = (other.c.start_char, other.c.end_char, other.c.label, other.c.kb_id, other.doc)
|
||||
# <
|
||||
if op == 0:
|
||||
return self.c.start_char < other.c.start_char
|
||||
return self_tuple < other_tuple
|
||||
# <=
|
||||
elif op == 1:
|
||||
return self.c.start_char <= other.c.start_char
|
||||
return self_tuple <= other_tuple
|
||||
# ==
|
||||
elif op == 2:
|
||||
# Do the cheap comparisons first
|
||||
return (
|
||||
(self.c.start_char == other.c.start_char) and \
|
||||
(self.c.end_char == other.c.end_char) and \
|
||||
(self.c.label == other.c.label) and \
|
||||
(self.c.kb_id == other.c.kb_id) and \
|
||||
(self.doc == other.doc)
|
||||
)
|
||||
return self_tuple == other_tuple
|
||||
# !=
|
||||
elif op == 3:
|
||||
# Do the cheap comparisons first
|
||||
return not (
|
||||
(self.c.start_char == other.c.start_char) and \
|
||||
(self.c.end_char == other.c.end_char) and \
|
||||
(self.c.label == other.c.label) and \
|
||||
(self.c.kb_id == other.c.kb_id) and \
|
||||
(self.doc == other.doc)
|
||||
)
|
||||
return self_tuple != other_tuple
|
||||
# >
|
||||
elif op == 4:
|
||||
return self.c.start_char > other.c.start_char
|
||||
return self_tuple > other_tuple
|
||||
# >=
|
||||
elif op == 5:
|
||||
return self.c.start_char >= other.c.start_char
|
||||
return self_tuple >= other_tuple
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.doc, self.c.start_char, self.c.end_char, self.c.label, self.c.kb_id))
|
||||
|
|
Loading…
Reference in New Issue
Block a user