mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-12 01:02:23 +03:00
TEST: Doc.ents as SpanGroup
Overview of the changes required to support `SpanGroup` rather than `Tuple[Span]` for `Doc.ents`:

* implement slicing for `SpanGroup`
* return a `SpanGroup` for `SpanGroup + x` or `x + SpanGroup` rather than refusing to concatenate (currently without good error handling)

Side effects:

* when appending to `Doc.ents`, only `Iterable[Span]` is supported, not other formats such as raw entity tuples from matcher results; you can still assign mixed data in any of the currently supported formats to `Doc.ents`
* for the `Matcher` case, `as_spans` provides a good alternative
This commit is contained in:
parent
41b3a0d932
commit
31cd5141bd
|
@ -815,7 +815,7 @@ def test_doc_set_ents(en_tokenizer):
|
||||||
doc.set_ents([], blocked=[doc[1:2], doc[3:5]], default="unmodified")
|
doc.set_ents([], blocked=[doc[1:2], doc[3:5]], default="unmodified")
|
||||||
assert [t.ent_iob for t in doc] == [0, 3, 0, 3, 3]
|
assert [t.ent_iob for t in doc] == [0, 3, 0, 3, 3]
|
||||||
assert [t.ent_type for t in doc] == [0, 0, 0, 0, 0]
|
assert [t.ent_type for t in doc] == [0, 0, 0, 0, 0]
|
||||||
assert doc.ents == tuple()
|
assert len(doc.ents) == 0
|
||||||
|
|
||||||
# invalid IOB repaired after blocked
|
# invalid IOB repaired after blocked
|
||||||
doc.ents = [Span(doc, 3, 5, "ENT")]
|
doc.ents = [Span(doc, 3, 5, "ENT")]
|
||||||
|
@ -892,7 +892,7 @@ def test_doc_init_iob():
|
||||||
words = ["a", "b", "c", "d", "e"]
|
words = ["a", "b", "c", "d", "e"]
|
||||||
ents = ["O"] * len(words)
|
ents = ["O"] * len(words)
|
||||||
doc = Doc(Vocab(), words=words, ents=ents)
|
doc = Doc(Vocab(), words=words, ents=ents)
|
||||||
assert doc.ents == ()
|
assert len(doc.ents) == 0
|
||||||
|
|
||||||
ents = ["B-PERSON", "I-PERSON", "O", "I-PERSON", "I-PERSON"]
|
ents = ["B-PERSON", "I-PERSON", "O", "I-PERSON", "I-PERSON"]
|
||||||
doc = Doc(Vocab(), words=words, ents=ents)
|
doc = Doc(Vocab(), words=words, ents=ents)
|
||||||
|
|
|
@ -87,7 +87,10 @@ def test_issue118_prefix_reorder(en_tokenizer, patterns):
|
||||||
matcher.add("BostonCeltics", patterns)
|
matcher.add("BostonCeltics", patterns)
|
||||||
assert len(list(doc.ents)) == 0
|
assert len(list(doc.ents)) == 0
|
||||||
matches = [(ORG, start, end) for _, start, end in matcher(doc)]
|
matches = [(ORG, start, end) for _, start, end in matcher(doc)]
|
||||||
doc.ents += tuple(matches)[1:]
|
matches_spans = matcher(doc, as_spans=True)
|
||||||
|
for span in matches_spans:
|
||||||
|
span.label = ORG
|
||||||
|
doc.ents += matches_spans[1:]
|
||||||
assert matches == [(ORG, 9, 10), (ORG, 9, 11)]
|
assert matches == [(ORG, 9, 10), (ORG, 9, 11)]
|
||||||
ents = doc.ents
|
ents = doc.ents
|
||||||
assert len(ents) == 1
|
assert len(ents) == 1
|
||||||
|
@ -116,7 +119,8 @@ def test_issue242(en_tokenizer):
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
# One token can only be part of one entity, so test that the matches
|
# One token can only be part of one entity, so test that the matches
|
||||||
# can't be added as entities
|
# can't be added as entities
|
||||||
doc.ents += tuple(matches)
|
matches_spans = matcher(doc, as_spans=True)
|
||||||
|
doc.ents += tuple(matches_spans)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.issue(587)
|
@pytest.mark.issue(587)
|
||||||
|
|
|
@ -4,6 +4,7 @@ from cymem.cymem import Pool
|
||||||
from thinc.types import ArrayXd, Floats1d, Floats2d, Ints2d, Ragged
|
from thinc.types import ArrayXd, Floats1d, Floats2d, Ints2d, Ragged
|
||||||
from .span import Span
|
from .span import Span
|
||||||
from .token import Token
|
from .token import Token
|
||||||
|
from .span_group import SpanGroup
|
||||||
from .span_groups import SpanGroups
|
from .span_groups import SpanGroups
|
||||||
from .retokenizer import Retokenizer
|
from .retokenizer import Retokenizer
|
||||||
from ..lexeme import Lexeme
|
from ..lexeme import Lexeme
|
||||||
|
@ -120,7 +121,7 @@ class Doc:
|
||||||
def text(self) -> str: ...
|
def text(self) -> str: ...
|
||||||
@property
|
@property
|
||||||
def text_with_ws(self) -> str: ...
|
def text_with_ws(self) -> str: ...
|
||||||
ents: Tuple[Span]
|
ents: SpanGroup
|
||||||
def set_ents(
|
def set_ents(
|
||||||
self,
|
self,
|
||||||
entities: List[Span],
|
entities: List[Span],
|
||||||
|
|
|
@ -19,6 +19,7 @@ import warnings
|
||||||
|
|
||||||
from .span cimport Span
|
from .span cimport Span
|
||||||
from .token cimport MISSING_DEP
|
from .token cimport MISSING_DEP
|
||||||
|
from .span_group import SpanGroup
|
||||||
from .span_groups import SpanGroups
|
from .span_groups import SpanGroups
|
||||||
from .token cimport Token
|
from .token cimport Token
|
||||||
from ..lexeme cimport Lexeme, EMPTY_LEXEME
|
from ..lexeme cimport Lexeme, EMPTY_LEXEME
|
||||||
|
@ -743,7 +744,7 @@ cdef class Doc:
|
||||||
output.append(Span(self, start, self.length, label=label, kb_id=kb_id, span_id=ent_id))
|
output.append(Span(self, start, self.length, label=label, kb_id=kb_id, span_id=ent_id))
|
||||||
# remove empty-label spans
|
# remove empty-label spans
|
||||||
output = [o for o in output if o.label_ != ""]
|
output = [o for o in output if o.label_ != ""]
|
||||||
return tuple(output)
|
return SpanGroup(self, spans=output)
|
||||||
|
|
||||||
def __set__(self, ents):
|
def __set__(self, ents):
|
||||||
# TODO:
|
# TODO:
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Any, Dict, Iterable, Optional
|
from typing import Any, Dict, Iterable, Optional, Union, overload
|
||||||
from .doc import Doc
|
from .doc import Doc
|
||||||
from .span import Span
|
from .span import Span
|
||||||
|
|
||||||
|
@ -22,7 +22,10 @@ class SpanGroup:
|
||||||
def __len__(self) -> int: ...
|
def __len__(self) -> int: ...
|
||||||
def append(self, span: Span) -> None: ...
|
def append(self, span: Span) -> None: ...
|
||||||
def extend(self, spans: Iterable[Span]) -> None: ...
|
def extend(self, spans: Iterable[Span]) -> None: ...
|
||||||
|
@overload
|
||||||
def __getitem__(self, i: int) -> Span: ...
|
def __getitem__(self, i: int) -> Span: ...
|
||||||
|
@overload
|
||||||
|
def __getitem__(self, i: slice) -> SpanGroup: ...
|
||||||
def to_bytes(self) -> bytes: ...
|
def to_bytes(self) -> bytes: ...
|
||||||
def from_bytes(self, bytes_data: bytes) -> SpanGroup: ...
|
def from_bytes(self, bytes_data: bytes) -> SpanGroup: ...
|
||||||
def copy(self, doc: Optional[Doc] = ...) -> SpanGroup: ...
|
def copy(self, doc: Optional[Doc] = ...) -> SpanGroup: ...
|
||||||
|
|
|
@ -4,10 +4,13 @@ import struct
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import srsly
|
import srsly
|
||||||
|
|
||||||
from spacy.errors import Errors
|
|
||||||
from .span cimport Span
|
|
||||||
from libcpp.memory cimport make_shared
|
from libcpp.memory cimport make_shared
|
||||||
|
|
||||||
|
from .span cimport Span
|
||||||
|
|
||||||
|
from ..errors import Errors
|
||||||
|
from .. import util
|
||||||
|
|
||||||
|
|
||||||
cdef class SpanGroup:
|
cdef class SpanGroup:
|
||||||
"""A group of spans that all belong to the same Doc object. The group
|
"""A group of spans that all belong to the same Doc object. The group
|
||||||
|
@ -93,7 +96,7 @@ cdef class SpanGroup:
|
||||||
"""
|
"""
|
||||||
return self.c.size()
|
return self.c.size()
|
||||||
|
|
||||||
def __getitem__(self, int i) -> Span:
|
def __getitem__(self, object i) -> Span:
|
||||||
"""Get a span from the group. Note that a copy of the span is returned,
|
"""Get a span from the group. Note that a copy of the span is returned,
|
||||||
so if any changes are made to this span, they are not reflected in the
|
so if any changes are made to this span, they are not reflected in the
|
||||||
corresponding member of the span group.
|
corresponding member of the span group.
|
||||||
|
@ -103,6 +106,10 @@ cdef class SpanGroup:
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/spangroup#getitem
|
DOCS: https://spacy.io/api/spangroup#getitem
|
||||||
"""
|
"""
|
||||||
|
if isinstance(i, slice):
|
||||||
|
start, stop = util.normalize_slice(len(self), i.start, i.stop, i.step)
|
||||||
|
spans = [self[i] for i in range(start, stop)]
|
||||||
|
return SpanGroup(self.doc, spans=spans)
|
||||||
i = self._normalize_index(i)
|
i = self._normalize_index(i)
|
||||||
return Span.cinit(self.doc, self.c[i])
|
return Span.cinit(self.doc, self.c[i])
|
||||||
|
|
||||||
|
@ -155,8 +162,10 @@ cdef class SpanGroup:
|
||||||
"""
|
"""
|
||||||
# For Cython 0.x and __add__, you cannot rely on `self` as being `self`
|
# For Cython 0.x and __add__, you cannot rely on `self` as being `self`
|
||||||
# or being the right type, so both types need to be checked explicitly.
|
# or being the right type, so both types need to be checked explicitly.
|
||||||
if isinstance(self, SpanGroup) and isinstance(other, SpanGroup):
|
if isinstance(self, SpanGroup):
|
||||||
return self._concat(other)
|
return self._concat(other)
|
||||||
|
if isinstance(other, SpanGroup):
|
||||||
|
return other._concat(self)
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user