mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 16:07:41 +03:00 
			
		
		
		
	Span richcmp fix (#9956)
* Corrected Span's __richcmp__ implementation to take end, label and kb_id into consideration
* Updated test
* Updated test
* Removed formatting from a test for readability's sake
* Use same tuples for all comparisons

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
		
							parent
							
								
									e8a047a8d4
								
							
						
					
					
						commit
						47ea6704f1
					
				|  | @ -573,6 +573,55 @@ def test_span_with_vectors(doc): | |||
|     doc.vocab.vectors = prev_vectors | ||||
| 
 | ||||
| 
 | ||||
| # fmt: off | ||||
| def test_span_comparison(doc): | ||||
| 
 | ||||
|     # Identical start, end, only differ in label and kb_id | ||||
|     assert Span(doc, 0, 3) == Span(doc, 0, 3) | ||||
|     assert Span(doc, 0, 3, "LABEL") == Span(doc, 0, 3, "LABEL") | ||||
|     assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") == Span(doc, 0, 3, "LABEL", kb_id="KB_ID") | ||||
| 
 | ||||
|     assert Span(doc, 0, 3) != Span(doc, 0, 3, "LABEL") | ||||
|     assert Span(doc, 0, 3) != Span(doc, 0, 3, "LABEL", kb_id="KB_ID") | ||||
|     assert Span(doc, 0, 3, "LABEL") != Span(doc, 0, 3, "LABEL", kb_id="KB_ID") | ||||
| 
 | ||||
|     assert Span(doc, 0, 3) <= Span(doc, 0, 3) and Span(doc, 0, 3) >= Span(doc, 0, 3) | ||||
|     assert Span(doc, 0, 3, "LABEL") <= Span(doc, 0, 3, "LABEL") and Span(doc, 0, 3, "LABEL") >= Span(doc, 0, 3, "LABEL") | ||||
|     assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") <= Span(doc, 0, 3, "LABEL", kb_id="KB_ID") | ||||
|     assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") >= Span(doc, 0, 3, "LABEL", kb_id="KB_ID") | ||||
| 
 | ||||
|     assert (Span(doc, 0, 3) < Span(doc, 0, 3, "", kb_id="KB_ID") < Span(doc, 0, 3, "LABEL") < Span(doc, 0, 3, "LABEL", kb_id="KB_ID")) | ||||
|     assert (Span(doc, 0, 3) <= Span(doc, 0, 3, "", kb_id="KB_ID") <= Span(doc, 0, 3, "LABEL") <= Span(doc, 0, 3, "LABEL", kb_id="KB_ID")) | ||||
| 
 | ||||
|     assert (Span(doc, 0, 3, "LABEL", kb_id="KB_ID") > Span(doc, 0, 3, "LABEL") > Span(doc, 0, 3, "", kb_id="KB_ID") > Span(doc, 0, 3)) | ||||
|     assert (Span(doc, 0, 3, "LABEL", kb_id="KB_ID") >= Span(doc, 0, 3, "LABEL") >= Span(doc, 0, 3, "", kb_id="KB_ID") >= Span(doc, 0, 3)) | ||||
| 
 | ||||
|     # Different end | ||||
|     assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") < Span(doc, 0, 4, "LABEL", kb_id="KB_ID") | ||||
| 
 | ||||
|     assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") < Span(doc, 0, 4) | ||||
|     assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") <= Span(doc, 0, 4) | ||||
|     assert Span(doc, 0, 4) > Span(doc, 0, 3, "LABEL", kb_id="KB_ID") | ||||
|     assert Span(doc, 0, 4) >= Span(doc, 0, 3, "LABEL", kb_id="KB_ID") | ||||
| 
 | ||||
|     # Different start | ||||
|     assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") != Span(doc, 1, 3, "LABEL", kb_id="KB_ID") | ||||
| 
 | ||||
|     assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") < Span(doc, 1, 3) | ||||
|     assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") <= Span(doc, 1, 3) | ||||
|     assert Span(doc, 1, 3) > Span(doc, 0, 3, "LABEL", kb_id="KB_ID") | ||||
|     assert Span(doc, 1, 3) >= Span(doc, 0, 3, "LABEL", kb_id="KB_ID") | ||||
| 
 | ||||
|     # Different start & different end | ||||
|     assert Span(doc, 0, 4, "LABEL", kb_id="KB_ID") != Span(doc, 1, 3, "LABEL", kb_id="KB_ID") | ||||
| 
 | ||||
|     assert Span(doc, 0, 4, "LABEL", kb_id="KB_ID") < Span(doc, 1, 3) | ||||
|     assert Span(doc, 0, 4, "LABEL", kb_id="KB_ID") <= Span(doc, 1, 3) | ||||
|     assert Span(doc, 1, 3) > Span(doc, 0, 4, "LABEL", kb_id="KB_ID") | ||||
|     assert Span(doc, 1, 3) >= Span(doc, 0, 4, "LABEL", kb_id="KB_ID") | ||||
| # fmt: on | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize( | ||||
|     "start,end,expected_sentences,expected_sentences_with_hook", | ||||
|     [ | ||||
|  |  | |||
|  | @ -126,38 +126,26 @@ cdef class Span: | |||
|                 return False | ||||
|             else: | ||||
|                 return True | ||||
|         self_tuple = (self.c.start_char, self.c.end_char, self.c.label, self.c.kb_id, self.doc) | ||||
|         other_tuple = (other.c.start_char, other.c.end_char, other.c.label, other.c.kb_id, other.doc) | ||||
|         # < | ||||
|         if op == 0: | ||||
|             return self.c.start_char < other.c.start_char | ||||
|             return self_tuple < other_tuple | ||||
|         # <= | ||||
|         elif op == 1: | ||||
|             return self.c.start_char <= other.c.start_char | ||||
|             return self_tuple <= other_tuple | ||||
|         # == | ||||
|         elif op == 2: | ||||
|             # Do the cheap comparisons first | ||||
|             return ( | ||||
|                 (self.c.start_char == other.c.start_char) and \ | ||||
|                 (self.c.end_char == other.c.end_char) and \ | ||||
|                 (self.c.label == other.c.label) and \ | ||||
|                 (self.c.kb_id == other.c.kb_id) and \ | ||||
|                 (self.doc == other.doc) | ||||
|             ) | ||||
|             return self_tuple == other_tuple | ||||
|         # != | ||||
|         elif op == 3: | ||||
|             # Do the cheap comparisons first | ||||
|             return not ( | ||||
|                 (self.c.start_char == other.c.start_char) and \ | ||||
|                 (self.c.end_char == other.c.end_char) and \ | ||||
|                 (self.c.label == other.c.label) and \ | ||||
|                 (self.c.kb_id == other.c.kb_id) and \ | ||||
|                 (self.doc == other.doc) | ||||
|             ) | ||||
|             return self_tuple != other_tuple | ||||
|         # > | ||||
|         elif op == 4: | ||||
|             return self.c.start_char > other.c.start_char | ||||
|             return self_tuple > other_tuple | ||||
|         # >= | ||||
|         elif op == 5: | ||||
|             return self.c.start_char >= other.c.start_char | ||||
|             return self_tuple >= other_tuple | ||||
| 
 | ||||
|     def __hash__(self): | ||||
|         return hash((self.doc, self.c.start_char, self.c.end_char, self.c.label, self.c.kb_id)) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user