mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
small UX fix for DocBin (#6167)
* add informative warning when messing up store_user_data DocBin flags * add informative warning when messing up store_user_data DocBin flags * cleanup test * rename to patterns_path
This commit is contained in:
parent
f0b30aedad
commit
09dcb75076
|
@ -419,7 +419,7 @@ class Errors:
|
|||
E164 = ("x is neither increasing nor decreasing: {}.")
|
||||
E165 = ("Only one class present in y_true. ROC AUC score is not defined in "
|
||||
"that case.")
|
||||
E166 = ("Can only merge DocBins with the same pre-defined attributes.\n"
|
||||
E166 = ("Can only merge DocBins with the same value for '{param}'.\n"
|
||||
"Current DocBin: {current}\nOther DocBin: {other}")
|
||||
E169 = ("Can't find module: {module}")
|
||||
E170 = ("Cannot apply transition {name}: invalid for the current state.")
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
import pytest
|
||||
from spacy.tokens.doc import Underscore
|
||||
|
||||
import spacy
|
||||
from spacy.lang.en import English
|
||||
from spacy.tokens import Doc, DocBin
|
||||
|
@ -86,3 +89,20 @@ def test_serialize_doc_bin_unknown_spaces(en_vocab):
|
|||
assert re_doc1.text == "that 's "
|
||||
assert not re_doc2.has_unknown_spaces
|
||||
assert re_doc2.text == "that's"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"writer_flag,reader_flag,reader_value", [(True, True, "bar"), (True, False, "bar"), (False, True, "nothing"), (False, False, "nothing")]
|
||||
)
|
||||
def test_serialize_custom_extension(en_vocab, writer_flag, reader_flag, reader_value):
|
||||
"""Test that custom extensions are correctly serialized in DocBin."""
|
||||
Doc.set_extension("foo", default="nothing")
|
||||
doc = Doc(en_vocab, words=["hello", "world"])
|
||||
doc._.foo = "bar"
|
||||
doc_bin_1 = DocBin(store_user_data=writer_flag)
|
||||
doc_bin_1.add(doc)
|
||||
doc_bin_bytes = doc_bin_1.to_bytes()
|
||||
doc_bin_2 = DocBin(store_user_data=reader_flag).from_bytes(doc_bin_bytes)
|
||||
doc_2 = list(doc_bin_2.get_docs(en_vocab))[0]
|
||||
assert doc_2._.foo == reader_value
|
||||
Underscore.doc_extensions = {}
|
||||
|
|
|
@ -58,7 +58,7 @@ class DocBin:
|
|||
|
||||
attrs (Iterable[str]): List of attributes to serialize. 'orth' and
|
||||
'spacy' are always serialized, so they're not required.
|
||||
store_user_data (bool): Whether to include the `Doc.user_data`.
|
||||
store_user_data (bool): Whether to write the `Doc.user_data` to bytes/file.
|
||||
docs (Iterable[Doc]): Docs to add.
|
||||
|
||||
DOCS: https://nightly.spacy.io/api/docbin#init
|
||||
|
@ -106,11 +106,12 @@ class DocBin:
|
|||
self.strings.add(token.ent_type_)
|
||||
self.strings.add(token.ent_kb_id_)
|
||||
self.cats.append(doc.cats)
|
||||
if self.store_user_data:
|
||||
self.user_data.append(srsly.msgpack_dumps(doc.user_data))
|
||||
|
||||
def get_docs(self, vocab: Vocab) -> Iterator[Doc]:
|
||||
"""Recover Doc objects from the annotations, using the given vocab.
|
||||
Note that the user data of each doc will be read (if available) and returned,
|
||||
regardless of the setting of 'self.store_user_data'.
|
||||
|
||||
vocab (Vocab): The shared vocab.
|
||||
YIELDS (Doc): The Doc objects.
|
||||
|
@ -129,7 +130,7 @@ class DocBin:
|
|||
doc = Doc(vocab, words=tokens[:, orth_col], spaces=spaces)
|
||||
doc = doc.from_array(self.attrs, tokens)
|
||||
doc.cats = self.cats[i]
|
||||
if self.store_user_data:
|
||||
if i < len(self.user_data) and self.user_data[i] is not None:
|
||||
user_data = srsly.msgpack_loads(self.user_data[i], use_list=False)
|
||||
doc.user_data.update(user_data)
|
||||
yield doc
|
||||
|
@ -137,20 +138,30 @@ class DocBin:
|
|||
def merge(self, other: "DocBin") -> None:
|
||||
"""Extend the annotations of this DocBin with the annotations from
|
||||
another. Will raise an error if the pre-defined attrs of the two
|
||||
DocBins don't match.
|
||||
DocBins don't match, or if they differ in whether or not to store
|
||||
user data.
|
||||
|
||||
other (DocBin): The DocBin to merge into the current bin.
|
||||
|
||||
DOCS: https://nightly.spacy.io/api/docbin#merge
|
||||
"""
|
||||
if self.attrs != other.attrs:
|
||||
raise ValueError(Errors.E166.format(current=self.attrs, other=other.attrs))
|
||||
raise ValueError(
|
||||
Errors.E166.format(param="attrs", current=self.attrs, other=other.attrs)
|
||||
)
|
||||
if self.store_user_data != other.store_user_data:
|
||||
raise ValueError(
|
||||
Errors.E166.format(
|
||||
param="store_user_data",
|
||||
current=self.store_user_data,
|
||||
other=other.store_user_data,
|
||||
)
|
||||
)
|
||||
self.tokens.extend(other.tokens)
|
||||
self.spaces.extend(other.spaces)
|
||||
self.strings.update(other.strings)
|
||||
self.cats.extend(other.cats)
|
||||
self.flags.extend(other.flags)
|
||||
if self.store_user_data:
|
||||
self.user_data.extend(other.user_data)
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
|
@ -200,8 +211,10 @@ class DocBin:
|
|||
self.spaces = NumpyOps().unflatten(flat_spaces, lengths)
|
||||
self.cats = msg["cats"]
|
||||
self.flags = msg.get("flags", [{} for _ in lengths])
|
||||
if self.store_user_data and "user_data" in msg:
|
||||
if "user_data" in msg:
|
||||
self.user_data = list(msg["user_data"])
|
||||
else:
|
||||
self.user_data = [None] * len(self)
|
||||
for tokens in self.tokens:
|
||||
assert len(tokens.shape) == 2, tokens.shape # this should never happen
|
||||
return self
|
||||
|
|
|
@ -47,7 +47,7 @@ Create a `DocBin` object to hold serialized annotations.
|
|||
| Argument | Description |
|
||||
| ----------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `attrs` | List of attributes to serialize. `ORTH` (hash of token text) and `SPACY` (whether the token is followed by whitespace) are always serialized, so they're not required. Defaults to `("ORTH", "TAG", "HEAD", "DEP", "ENT_IOB", "ENT_TYPE", "ENT_KB_ID", "LEMMA", "MORPH", "POS")`. ~~Iterable[str]~~ |
|
||||
| `store_user_data` | Whether to include the `Doc.user_data` and the values of custom extension attributes. Defaults to `False`. ~~bool~~ |
|
||||
| `store_user_data` | Whether to write the `Doc.user_data` and the values of custom extension attributes to file/bytes. Defaults to `False`. ~~bool~~ |
|
||||
| `docs` | `Doc` objects to add on initialization. ~~Iterable[Doc]~~ |
|
||||
|
||||
## DocBin.\_\len\_\_ {#len tag="method"}
|
||||
|
|
Loading…
Reference in New Issue
Block a user