mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 17:24:41 +03:00
Override SpanGroups.setdefault to provide default SpanGroup (#10772)
* Fix mistake in SpanGroup API docs * Restrict SpanGroups.setdefault to SpanGroup only * Refactor to support default span iterable
This commit is contained in:
parent
d524f6415f
commit
b65d652881
|
@ -11,7 +11,7 @@ from spacy.lang.en import English
|
||||||
from spacy.lang.xx import MultiLanguage
|
from spacy.lang.xx import MultiLanguage
|
||||||
from spacy.language import Language
|
from spacy.language import Language
|
||||||
from spacy.lexeme import Lexeme
|
from spacy.lexeme import Lexeme
|
||||||
from spacy.tokens import Doc, Span, Token
|
from spacy.tokens import Doc, Span, SpanGroup, Token
|
||||||
from spacy.vocab import Vocab
|
from spacy.vocab import Vocab
|
||||||
|
|
||||||
from .test_underscore import clean_underscore # noqa: F401
|
from .test_underscore import clean_underscore # noqa: F401
|
||||||
|
@ -964,3 +964,13 @@ def test_doc_spans_copy(en_tokenizer):
|
||||||
assert weakref.ref(doc1) == doc1.spans.doc_ref
|
assert weakref.ref(doc1) == doc1.spans.doc_ref
|
||||||
doc2 = doc1.copy()
|
doc2 = doc1.copy()
|
||||||
assert weakref.ref(doc2) == doc2.spans.doc_ref
|
assert weakref.ref(doc2) == doc2.spans.doc_ref
|
||||||
|
|
||||||
|
|
||||||
|
def test_doc_spans_setdefault(en_tokenizer):
|
||||||
|
doc = en_tokenizer("Some text about Colombia and the Czech Republic")
|
||||||
|
doc.spans.setdefault("key1")
|
||||||
|
assert len(doc.spans["key1"]) == 0
|
||||||
|
doc.spans.setdefault("key2", default=[doc[0:1]])
|
||||||
|
assert len(doc.spans["key2"]) == 1
|
||||||
|
doc.spans.setdefault("key3", default=SpanGroup(doc, spans=[doc[0:1], doc[1:2]]))
|
||||||
|
assert len(doc.spans["key3"]) == 2
|
||||||
|
|
|
@ -43,6 +43,15 @@ class SpanGroups(UserDict):
|
||||||
doc = self._ensure_doc()
|
doc = self._ensure_doc()
|
||||||
return SpanGroups(doc).from_bytes(self.to_bytes())
|
return SpanGroups(doc).from_bytes(self.to_bytes())
|
||||||
|
|
||||||
|
def setdefault(self, key, default=None):
|
||||||
|
if not isinstance(default, SpanGroup):
|
||||||
|
if default is None:
|
||||||
|
spans = []
|
||||||
|
else:
|
||||||
|
spans = default
|
||||||
|
default = self._make_span_group(key, spans)
|
||||||
|
return super().setdefault(key, default=default)
|
||||||
|
|
||||||
def to_bytes(self) -> bytes:
|
def to_bytes(self) -> bytes:
|
||||||
# We don't need to serialize this as a dict, because the groups
|
# We don't need to serialize this as a dict, because the groups
|
||||||
# know their names.
|
# know their names.
|
||||||
|
|
|
@ -233,7 +233,7 @@ group.
|
||||||
> doc.spans["errors"] = []
|
> doc.spans["errors"] = []
|
||||||
> doc.spans["errors"].extend([doc[1:3], doc[0:1]])
|
> doc.spans["errors"].extend([doc[1:3], doc[0:1]])
|
||||||
> assert len(doc.spans["errors"]) == 2
|
> assert len(doc.spans["errors"]) == 2
|
||||||
> span_group = SpanGroup([doc[1:4], doc[0:3])
|
> span_group = SpanGroup(doc, spans=[doc[1:4], doc[0:3]])
|
||||||
> doc.spans["errors"].extend(span_group)
|
> doc.spans["errors"].extend(span_group)
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user