mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-30 15:37:29 +03:00 
			
		
		
		
	TEST: Doc.ents as SpanGroup
Overview of the changes required to support `SpanGroup` rather than `Tuple[Span]`:

* implement slicing for `SpanGroup`
* return a `SpanGroup` for `SpanGroup + x` or `x + SpanGroup` rather than refusing to concatenate (currently without good error handling)

Side effects:

* when appending to `Doc.ents`, only `Iterable[Span]` is supported rather than other formats such as raw entity tuples from matcher results, but you can still assign mixed data in any of the currently supported formats to `Doc.ents`
* for the `Matcher` case, `as_spans` provides a good alternative
This commit is contained in:
		
							parent
							
								
									41b3a0d932
								
							
						
					
					
						commit
						31cd5141bd
					
				|  | @ -815,7 +815,7 @@ def test_doc_set_ents(en_tokenizer): | ||||||
|     doc.set_ents([], blocked=[doc[1:2], doc[3:5]], default="unmodified") |     doc.set_ents([], blocked=[doc[1:2], doc[3:5]], default="unmodified") | ||||||
|     assert [t.ent_iob for t in doc] == [0, 3, 0, 3, 3] |     assert [t.ent_iob for t in doc] == [0, 3, 0, 3, 3] | ||||||
|     assert [t.ent_type for t in doc] == [0, 0, 0, 0, 0] |     assert [t.ent_type for t in doc] == [0, 0, 0, 0, 0] | ||||||
|     assert doc.ents == tuple() |     assert len(doc.ents) == 0 | ||||||
| 
 | 
 | ||||||
|     # invalid IOB repaired after blocked |     # invalid IOB repaired after blocked | ||||||
|     doc.ents = [Span(doc, 3, 5, "ENT")] |     doc.ents = [Span(doc, 3, 5, "ENT")] | ||||||
|  | @ -892,7 +892,7 @@ def test_doc_init_iob(): | ||||||
|     words = ["a", "b", "c", "d", "e"] |     words = ["a", "b", "c", "d", "e"] | ||||||
|     ents = ["O"] * len(words) |     ents = ["O"] * len(words) | ||||||
|     doc = Doc(Vocab(), words=words, ents=ents) |     doc = Doc(Vocab(), words=words, ents=ents) | ||||||
|     assert doc.ents == () |     assert len(doc.ents) == 0 | ||||||
| 
 | 
 | ||||||
|     ents = ["B-PERSON", "I-PERSON", "O", "I-PERSON", "I-PERSON"] |     ents = ["B-PERSON", "I-PERSON", "O", "I-PERSON", "I-PERSON"] | ||||||
|     doc = Doc(Vocab(), words=words, ents=ents) |     doc = Doc(Vocab(), words=words, ents=ents) | ||||||
|  |  | ||||||
|  | @ -87,7 +87,10 @@ def test_issue118_prefix_reorder(en_tokenizer, patterns): | ||||||
|     matcher.add("BostonCeltics", patterns) |     matcher.add("BostonCeltics", patterns) | ||||||
|     assert len(list(doc.ents)) == 0 |     assert len(list(doc.ents)) == 0 | ||||||
|     matches = [(ORG, start, end) for _, start, end in matcher(doc)] |     matches = [(ORG, start, end) for _, start, end in matcher(doc)] | ||||||
|     doc.ents += tuple(matches)[1:] |     matches_spans = matcher(doc, as_spans=True) | ||||||
|  |     for span in matches_spans: | ||||||
|  |         span.label = ORG | ||||||
|  |     doc.ents += matches_spans[1:] | ||||||
|     assert matches == [(ORG, 9, 10), (ORG, 9, 11)] |     assert matches == [(ORG, 9, 10), (ORG, 9, 11)] | ||||||
|     ents = doc.ents |     ents = doc.ents | ||||||
|     assert len(ents) == 1 |     assert len(ents) == 1 | ||||||
|  | @ -116,7 +119,8 @@ def test_issue242(en_tokenizer): | ||||||
|     with pytest.raises(ValueError): |     with pytest.raises(ValueError): | ||||||
|         # One token can only be part of one entity, so test that the matches |         # One token can only be part of one entity, so test that the matches | ||||||
|         # can't be added as entities |         # can't be added as entities | ||||||
|         doc.ents += tuple(matches) |         matches_spans = matcher(doc, as_spans=True) | ||||||
|  |         doc.ents += tuple(matches_spans) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.issue(587) | @pytest.mark.issue(587) | ||||||
|  |  | ||||||
|  | @ -4,6 +4,7 @@ from cymem.cymem import Pool | ||||||
| from thinc.types import ArrayXd, Floats1d, Floats2d, Ints2d, Ragged | from thinc.types import ArrayXd, Floats1d, Floats2d, Ints2d, Ragged | ||||||
| from .span import Span | from .span import Span | ||||||
| from .token import Token | from .token import Token | ||||||
|  | from .span_group import SpanGroup | ||||||
| from .span_groups import SpanGroups | from .span_groups import SpanGroups | ||||||
| from .retokenizer import Retokenizer | from .retokenizer import Retokenizer | ||||||
| from ..lexeme import Lexeme | from ..lexeme import Lexeme | ||||||
|  | @ -120,7 +121,7 @@ class Doc: | ||||||
|     def text(self) -> str: ... |     def text(self) -> str: ... | ||||||
|     @property |     @property | ||||||
|     def text_with_ws(self) -> str: ... |     def text_with_ws(self) -> str: ... | ||||||
|     ents: Tuple[Span] |     ents: SpanGroup | ||||||
|     def set_ents( |     def set_ents( | ||||||
|         self, |         self, | ||||||
|         entities: List[Span], |         entities: List[Span], | ||||||
|  |  | ||||||
|  | @ -19,6 +19,7 @@ import warnings | ||||||
| 
 | 
 | ||||||
| from .span cimport Span | from .span cimport Span | ||||||
| from .token cimport MISSING_DEP | from .token cimport MISSING_DEP | ||||||
|  | from .span_group import SpanGroup | ||||||
| from .span_groups import SpanGroups | from .span_groups import SpanGroups | ||||||
| from .token cimport Token | from .token cimport Token | ||||||
| from ..lexeme cimport Lexeme, EMPTY_LEXEME | from ..lexeme cimport Lexeme, EMPTY_LEXEME | ||||||
|  | @ -743,7 +744,7 @@ cdef class Doc: | ||||||
|                 output.append(Span(self, start, self.length, label=label, kb_id=kb_id, span_id=ent_id)) |                 output.append(Span(self, start, self.length, label=label, kb_id=kb_id, span_id=ent_id)) | ||||||
|             # remove empty-label spans |             # remove empty-label spans | ||||||
|             output = [o for o in output if o.label_ != ""] |             output = [o for o in output if o.label_ != ""] | ||||||
|             return tuple(output) |             return SpanGroup(self, spans=output) | ||||||
| 
 | 
 | ||||||
|         def __set__(self, ents): |         def __set__(self, ents): | ||||||
|             # TODO: |             # TODO: | ||||||
|  |  | ||||||
|  | @ -1,4 +1,4 @@ | ||||||
| from typing import Any, Dict, Iterable, Optional | from typing import Any, Dict, Iterable, Optional, Union, overload | ||||||
| from .doc import Doc | from .doc import Doc | ||||||
| from .span import Span | from .span import Span | ||||||
| 
 | 
 | ||||||
|  | @ -22,7 +22,10 @@ class SpanGroup: | ||||||
|     def __len__(self) -> int: ... |     def __len__(self) -> int: ... | ||||||
|     def append(self, span: Span) -> None: ... |     def append(self, span: Span) -> None: ... | ||||||
|     def extend(self, spans: Iterable[Span]) -> None: ... |     def extend(self, spans: Iterable[Span]) -> None: ... | ||||||
|  |     @overload | ||||||
|     def __getitem__(self, i: int) -> Span: ... |     def __getitem__(self, i: int) -> Span: ... | ||||||
|  |     @overload | ||||||
|  |     def __getitem__(self, i: slice) -> SpanGroup: ... | ||||||
|     def to_bytes(self) -> bytes: ... |     def to_bytes(self) -> bytes: ... | ||||||
|     def from_bytes(self, bytes_data: bytes) -> SpanGroup: ... |     def from_bytes(self, bytes_data: bytes) -> SpanGroup: ... | ||||||
|     def copy(self, doc: Optional[Doc] = ...) -> SpanGroup: ... |     def copy(self, doc: Optional[Doc] = ...) -> SpanGroup: ... | ||||||
|  |  | ||||||
|  | @ -4,10 +4,13 @@ import struct | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
| import srsly | import srsly | ||||||
| 
 | 
 | ||||||
| from spacy.errors import Errors |  | ||||||
| from .span cimport Span |  | ||||||
| from libcpp.memory cimport make_shared | from libcpp.memory cimport make_shared | ||||||
| 
 | 
 | ||||||
|  | from .span cimport Span | ||||||
|  | 
 | ||||||
|  | from ..errors import Errors | ||||||
|  | from .. import util | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| cdef class SpanGroup: | cdef class SpanGroup: | ||||||
|     """A group of spans that all belong to the same Doc object. The group |     """A group of spans that all belong to the same Doc object. The group | ||||||
|  | @ -93,7 +96,7 @@ cdef class SpanGroup: | ||||||
|         """ |         """ | ||||||
|         return self.c.size() |         return self.c.size() | ||||||
| 
 | 
 | ||||||
|     def __getitem__(self, int i) -> Span: |     def __getitem__(self, object i) -> Span: | ||||||
|         """Get a span from the group. Note that a copy of the span is returned, |         """Get a span from the group. Note that a copy of the span is returned, | ||||||
|         so if any changes are made to this span, they are not reflected in the |         so if any changes are made to this span, they are not reflected in the | ||||||
|         corresponding member of the span group. |         corresponding member of the span group. | ||||||
|  | @ -103,6 +106,10 @@ cdef class SpanGroup: | ||||||
| 
 | 
 | ||||||
|         DOCS: https://spacy.io/api/spangroup#getitem |         DOCS: https://spacy.io/api/spangroup#getitem | ||||||
|         """ |         """ | ||||||
|  |         if isinstance(i, slice): | ||||||
|  |             start, stop = util.normalize_slice(len(self), i.start, i.stop, i.step) | ||||||
|  |             spans = [self[i] for i in range(start, stop)] | ||||||
|  |             return SpanGroup(self.doc, spans=spans) | ||||||
|         i = self._normalize_index(i) |         i = self._normalize_index(i) | ||||||
|         return Span.cinit(self.doc, self.c[i]) |         return Span.cinit(self.doc, self.c[i]) | ||||||
| 
 | 
 | ||||||
|  | @ -155,8 +162,10 @@ cdef class SpanGroup: | ||||||
|         """ |         """ | ||||||
|         # For Cython 0.x and __add__, you cannot rely on `self` as being `self` |         # For Cython 0.x and __add__, you cannot rely on `self` as being `self` | ||||||
|         # or being the right type, so both types need to be checked explicitly. |         # or being the right type, so both types need to be checked explicitly. | ||||||
|         if isinstance(self, SpanGroup) and isinstance(other, SpanGroup): |         if isinstance(self, SpanGroup): | ||||||
|             return self._concat(other) |             return self._concat(other) | ||||||
|  |         if isinstance(other, SpanGroup): | ||||||
|  |             return other._concat(self) | ||||||
|         return NotImplemented |         return NotImplemented | ||||||
| 
 | 
 | ||||||
|     def __iter__(self): |     def __iter__(self): | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user