small UX fix for DocBin (#6167)

* add informative warning when messing up store_user_data DocBin flags

* add informative warning when messing up store_user_data DocBin flags

* cleanup test

* rename to patterns_path
This commit is contained in:
Sofie Van Landeghem 2020-10-02 15:43:32 +02:00 committed by GitHub
parent f0b30aedad
commit 09dcb75076
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 44 additions and 11 deletions

View File

@ -419,7 +419,7 @@ class Errors:
E164 = ("x is neither increasing nor decreasing: {}.")
E165 = ("Only one class present in y_true. ROC AUC score is not defined in "
"that case.")
E166 = ("Can only merge DocBins with the same pre-defined attributes.\n"
E166 = ("Can only merge DocBins with the same value for '{param}'.\n"
"Current DocBin: {current}\nOther DocBin: {other}")
E169 = ("Can't find module: {module}")
E170 = ("Cannot apply transition {name}: invalid for the current state.")

View File

@ -1,3 +1,6 @@
import pytest
from spacy.tokens.doc import Underscore
import spacy
from spacy.lang.en import English
from spacy.tokens import Doc, DocBin
@ -86,3 +89,20 @@ def test_serialize_doc_bin_unknown_spaces(en_vocab):
assert re_doc1.text == "that 's "
assert not re_doc2.has_unknown_spaces
assert re_doc2.text == "that's"
@pytest.mark.parametrize(
"writer_flag,reader_flag,reader_value", [(True, True, "bar"), (True, False, "bar"), (False, True, "nothing"), (False, False, "nothing")]
)
def test_serialize_custom_extension(en_vocab, writer_flag, reader_flag, reader_value):
"""Test that custom extensions are correctly serialized in DocBin."""
Doc.set_extension("foo", default="nothing")
doc = Doc(en_vocab, words=["hello", "world"])
doc._.foo = "bar"
doc_bin_1 = DocBin(store_user_data=writer_flag)
doc_bin_1.add(doc)
doc_bin_bytes = doc_bin_1.to_bytes()
doc_bin_2 = DocBin(store_user_data=reader_flag).from_bytes(doc_bin_bytes)
doc_2 = list(doc_bin_2.get_docs(en_vocab))[0]
assert doc_2._.foo == reader_value
Underscore.doc_extensions = {}

View File

@ -58,7 +58,7 @@ class DocBin:
attrs (Iterable[str]): List of attributes to serialize. 'orth' and
'spacy' are always serialized, so they're not required.
store_user_data (bool): Whether to include the `Doc.user_data`.
store_user_data (bool): Whether to write the `Doc.user_data` to bytes/file.
docs (Iterable[Doc]): Docs to add.
DOCS: https://nightly.spacy.io/api/docbin#init
@ -106,11 +106,12 @@ class DocBin:
self.strings.add(token.ent_type_)
self.strings.add(token.ent_kb_id_)
self.cats.append(doc.cats)
if self.store_user_data:
self.user_data.append(srsly.msgpack_dumps(doc.user_data))
self.user_data.append(srsly.msgpack_dumps(doc.user_data))
def get_docs(self, vocab: Vocab) -> Iterator[Doc]:
"""Recover Doc objects from the annotations, using the given vocab.
Note that the user data of each doc will be read (if available) and returned,
regardless of the setting of 'self.store_user_data'.
vocab (Vocab): The shared vocab.
YIELDS (Doc): The Doc objects.
@ -129,7 +130,7 @@ class DocBin:
doc = Doc(vocab, words=tokens[:, orth_col], spaces=spaces)
doc = doc.from_array(self.attrs, tokens)
doc.cats = self.cats[i]
if self.store_user_data:
if i < len(self.user_data) and self.user_data[i] is not None:
user_data = srsly.msgpack_loads(self.user_data[i], use_list=False)
doc.user_data.update(user_data)
yield doc
@ -137,21 +138,31 @@ class DocBin:
def merge(self, other: "DocBin") -> None:
"""Extend the annotations of this DocBin with the annotations from
another. Will raise an error if the pre-defined attrs of the two
DocBins don't match.
DocBins don't match, or if they differ in whether or not to store
user data.
other (DocBin): The DocBin to merge into the current bin.
DOCS: https://nightly.spacy.io/api/docbin#merge
"""
if self.attrs != other.attrs:
raise ValueError(Errors.E166.format(current=self.attrs, other=other.attrs))
raise ValueError(
Errors.E166.format(param="attrs", current=self.attrs, other=other.attrs)
)
if self.store_user_data != other.store_user_data:
raise ValueError(
Errors.E166.format(
param="store_user_data",
current=self.store_user_data,
other=other.store_user_data,
)
)
self.tokens.extend(other.tokens)
self.spaces.extend(other.spaces)
self.strings.update(other.strings)
self.cats.extend(other.cats)
self.flags.extend(other.flags)
if self.store_user_data:
self.user_data.extend(other.user_data)
self.user_data.extend(other.user_data)
def to_bytes(self) -> bytes:
"""Serialize the DocBin's annotations to a bytestring.
@ -200,8 +211,10 @@ class DocBin:
self.spaces = NumpyOps().unflatten(flat_spaces, lengths)
self.cats = msg["cats"]
self.flags = msg.get("flags", [{} for _ in lengths])
if self.store_user_data and "user_data" in msg:
if "user_data" in msg:
self.user_data = list(msg["user_data"])
else:
self.user_data = [None] * len(self)
for tokens in self.tokens:
assert len(tokens.shape) == 2, tokens.shape # this should never happen
return self

View File

@ -47,7 +47,7 @@ Create a `DocBin` object to hold serialized annotations.
| Argument | Description |
| ----------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `attrs` | List of attributes to serialize. `ORTH` (hash of token text) and `SPACY` (whether the token is followed by whitespace) are always serialized, so they're not required. Defaults to `("ORTH", "TAG", "HEAD", "DEP", "ENT_IOB", "ENT_TYPE", "ENT_KB_ID", "LEMMA", "MORPH", "POS")`. ~~Iterable[str]~~ |
| `store_user_data` | Whether to include the `Doc.user_data` and the values of custom extension attributes. Defaults to `False`. ~~bool~~ |
| `store_user_data` | Whether to write the `Doc.user_data` and the values of custom extension attributes to file/bytes. Defaults to `False`. ~~bool~~ |
| `docs` | `Doc` objects to add on initialization. ~~Iterable[Doc]~~ |
## DocBin.\_\len\_\_ {#len tag="method"}