From 09dcb75076e39eca904e54c21e22c25491a82a02 Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Fri, 2 Oct 2020 15:43:32 +0200 Subject: [PATCH] small UX fix for DocBin (#6167) * add informative warning when messing up store_user_data DocBin flags * add informative warning when messing up store_user_data DocBin flags * cleanup test * rename to patterns_path --- spacy/errors.py | 2 +- spacy/tests/serialize/test_serialize_doc.py | 20 +++++++++++++ spacy/tokens/_serialize.py | 31 +++++++++++++++------ website/docs/api/docbin.md | 2 +- 4 files changed, 44 insertions(+), 11 deletions(-) diff --git a/spacy/errors.py b/spacy/errors.py index 4edd1cbae..dbb25479d 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -419,7 +419,7 @@ class Errors: E164 = ("x is neither increasing nor decreasing: {}.") E165 = ("Only one class present in y_true. ROC AUC score is not defined in " "that case.") - E166 = ("Can only merge DocBins with the same pre-defined attributes.\n" + E166 = ("Can only merge DocBins with the same value for '{param}'.\n" "Current DocBin: {current}\nOther DocBin: {other}") E169 = ("Can't find module: {module}") E170 = ("Cannot apply transition {name}: invalid for the current state.") diff --git a/spacy/tests/serialize/test_serialize_doc.py b/spacy/tests/serialize/test_serialize_doc.py index 4a976fc02..8b6adb83b 100644 --- a/spacy/tests/serialize/test_serialize_doc.py +++ b/spacy/tests/serialize/test_serialize_doc.py @@ -1,3 +1,6 @@ +import pytest +from spacy.tokens.doc import Underscore + import spacy from spacy.lang.en import English from spacy.tokens import Doc, DocBin @@ -86,3 +89,20 @@ def test_serialize_doc_bin_unknown_spaces(en_vocab): assert re_doc1.text == "that 's " assert not re_doc2.has_unknown_spaces assert re_doc2.text == "that's" + + +@pytest.mark.parametrize( + "writer_flag,reader_flag,reader_value", [(True, True, "bar"), (True, False, "bar"), (False, True, "nothing"), (False, False, "nothing")] +) +def test_serialize_custom_extension(en_vocab, writer_flag, reader_flag, reader_value): + """Test that custom extensions are correctly serialized in DocBin.""" + Doc.set_extension("foo", default="nothing") + doc = Doc(en_vocab, words=["hello", "world"]) + doc._.foo = "bar" + doc_bin_1 = DocBin(store_user_data=writer_flag) + doc_bin_1.add(doc) + doc_bin_bytes = doc_bin_1.to_bytes() + doc_bin_2 = DocBin(store_user_data=reader_flag).from_bytes(doc_bin_bytes) + doc_2 = list(doc_bin_2.get_docs(en_vocab))[0] + assert doc_2._.foo == reader_value + Underscore.doc_extensions = {} diff --git a/spacy/tokens/_serialize.py b/spacy/tokens/_serialize.py index ed283a86b..11eb75821 100644 --- a/spacy/tokens/_serialize.py +++ b/spacy/tokens/_serialize.py @@ -58,7 +58,7 @@ class DocBin: attrs (Iterable[str]): List of attributes to serialize. 'orth' and 'spacy' are always serialized, so they're not required. - store_user_data (bool): Whether to include the `Doc.user_data`. + store_user_data (bool): Whether to write the `Doc.user_data` to bytes/file. docs (Iterable[Doc]): Docs to add. DOCS: https://nightly.spacy.io/api/docbin#init @@ -106,11 +106,12 @@ class DocBin: self.strings.add(token.ent_type_) self.strings.add(token.ent_kb_id_) self.cats.append(doc.cats) - if self.store_user_data: - self.user_data.append(srsly.msgpack_dumps(doc.user_data)) + self.user_data.append(srsly.msgpack_dumps(doc.user_data)) def get_docs(self, vocab: Vocab) -> Iterator[Doc]: """Recover Doc objects from the annotations, using the given vocab. + Note that the user data of each doc will be read (if available) and returned, + regardless of the setting of 'self.store_user_data'. vocab (Vocab): The shared vocab. YIELDS (Doc): The Doc objects. @@ -129,7 +130,7 @@ class DocBin: doc = Doc(vocab, words=tokens[:, orth_col], spaces=spaces) doc = doc.from_array(self.attrs, tokens) doc.cats = self.cats[i] - if self.store_user_data: + if i < len(self.user_data) and self.user_data[i] is not None: user_data = srsly.msgpack_loads(self.user_data[i], use_list=False) doc.user_data.update(user_data) yield doc @@ -137,21 +138,31 @@ class DocBin: def merge(self, other: "DocBin") -> None: """Extend the annotations of this DocBin with the annotations from another. Will raise an error if the pre-defined attrs of the two - DocBins don't match. + DocBins don't match, or if they differ in whether or not to store + user data. other (DocBin): The DocBin to merge into the current bin. DOCS: https://nightly.spacy.io/api/docbin#merge """ if self.attrs != other.attrs: - raise ValueError(Errors.E166.format(current=self.attrs, other=other.attrs)) + raise ValueError( + Errors.E166.format(param="attrs", current=self.attrs, other=other.attrs) + ) + if self.store_user_data != other.store_user_data: + raise ValueError( + Errors.E166.format( + param="store_user_data", + current=self.store_user_data, + other=other.store_user_data, + ) + ) self.tokens.extend(other.tokens) self.spaces.extend(other.spaces) self.strings.update(other.strings) self.cats.extend(other.cats) self.flags.extend(other.flags) - if self.store_user_data: - self.user_data.extend(other.user_data) + self.user_data.extend(other.user_data) def to_bytes(self) -> bytes: """Serialize the DocBin's annotations to a bytestring. @@ -200,8 +211,10 @@ class DocBin: self.spaces = NumpyOps().unflatten(flat_spaces, lengths) self.cats = msg["cats"] self.flags = msg.get("flags", [{} for _ in lengths]) - if self.store_user_data and "user_data" in msg: + if "user_data" in msg: self.user_data = list(msg["user_data"]) + else: + self.user_data = [None] * len(self) for tokens in self.tokens: assert len(tokens.shape) == 2, tokens.shape # this should never happen return self diff --git a/website/docs/api/docbin.md b/website/docs/api/docbin.md index 03aff2f6e..3625ed790 100644 --- a/website/docs/api/docbin.md +++ b/website/docs/api/docbin.md @@ -47,7 +47,7 @@ Create a `DocBin` object to hold serialized annotations. | Argument | Description | | ----------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | `attrs` | List of attributes to serialize. `ORTH` (hash of token text) and `SPACY` (whether the token is followed by whitespace) are always serialized, so they're not required. Defaults to `("ORTH", "TAG", "HEAD", "DEP", "ENT_IOB", "ENT_TYPE", "ENT_KB_ID", "LEMMA", "MORPH", "POS")`. ~~Iterable[str]~~ | -| `store_user_data` | Whether to include the `Doc.user_data` and the values of custom extension attributes. Defaults to `False`. ~~bool~~ | +| `store_user_data` | Whether to write the `Doc.user_data` and the values of custom extension attributes to file/bytes. Defaults to `False`. ~~bool~~ | | `docs` | `Doc` objects to add on initialization. ~~Iterable[Doc]~~ | ## DocBin.\_\len\_\_ {#len tag="method"}