mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-25 00:34:20 +03:00
Override SpanGroups.setdefault to provide default SpanGroup (#10772)
* Fix mistake in SpanGroup API docs * Restrict SpanGroups.setdefault to SpanGroup only * Refactor to support default span iterable
This commit is contained in:
parent
d524f6415f
commit
b65d652881
|
@ -11,7 +11,7 @@ from spacy.lang.en import English
|
|||
from spacy.lang.xx import MultiLanguage
|
||||
from spacy.language import Language
|
||||
from spacy.lexeme import Lexeme
|
||||
from spacy.tokens import Doc, Span, Token
|
||||
from spacy.tokens import Doc, Span, SpanGroup, Token
|
||||
from spacy.vocab import Vocab
|
||||
|
||||
from .test_underscore import clean_underscore # noqa: F401
|
||||
|
@ -964,3 +964,13 @@ def test_doc_spans_copy(en_tokenizer):
|
|||
assert weakref.ref(doc1) == doc1.spans.doc_ref
|
||||
doc2 = doc1.copy()
|
||||
assert weakref.ref(doc2) == doc2.spans.doc_ref
|
||||
|
||||
|
||||
def test_doc_spans_setdefault(en_tokenizer):
|
||||
doc = en_tokenizer("Some text about Colombia and the Czech Republic")
|
||||
doc.spans.setdefault("key1")
|
||||
assert len(doc.spans["key1"]) == 0
|
||||
doc.spans.setdefault("key2", default=[doc[0:1]])
|
||||
assert len(doc.spans["key2"]) == 1
|
||||
doc.spans.setdefault("key3", default=SpanGroup(doc, spans=[doc[0:1], doc[1:2]]))
|
||||
assert len(doc.spans["key3"]) == 2
|
||||
|
|
|
@ -43,6 +43,15 @@ class SpanGroups(UserDict):
|
|||
doc = self._ensure_doc()
|
||||
return SpanGroups(doc).from_bytes(self.to_bytes())
|
||||
|
||||
def setdefault(self, key, default=None):
|
||||
if not isinstance(default, SpanGroup):
|
||||
if default is None:
|
||||
spans = []
|
||||
else:
|
||||
spans = default
|
||||
default = self._make_span_group(key, spans)
|
||||
return super().setdefault(key, default=default)
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
# We don't need to serialize this as a dict, because the groups
|
||||
# know their names.
|
||||
|
|
|
@ -233,7 +233,7 @@ group.
|
|||
> doc.spans["errors"] = []
|
||||
> doc.spans["errors"].extend([doc[1:3], doc[0:1]])
|
||||
> assert len(doc.spans["errors"]) == 2
|
||||
> span_group = SpanGroup([doc[1:4], doc[0:3])
|
||||
> span_group = SpanGroup(doc, spans=[doc[1:4], doc[0:3]])
|
||||
> doc.spans["errors"].extend(span_group)
|
||||
> ```
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user