mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 09:14:32 +03:00
small UX fix for DocBin (#6167)
* add informative warning when messing up store_user_data DocBin flags
* add informative warning when messing up store_user_data DocBin flags
* cleanup test
* rename to patterns_path
This commit is contained in:
parent
f0b30aedad
commit
09dcb75076
|
@ -419,7 +419,7 @@ class Errors:
|
||||||
E164 = ("x is neither increasing nor decreasing: {}.")
|
E164 = ("x is neither increasing nor decreasing: {}.")
|
||||||
E165 = ("Only one class present in y_true. ROC AUC score is not defined in "
|
E165 = ("Only one class present in y_true. ROC AUC score is not defined in "
|
||||||
"that case.")
|
"that case.")
|
||||||
E166 = ("Can only merge DocBins with the same pre-defined attributes.\n"
|
E166 = ("Can only merge DocBins with the same value for '{param}'.\n"
|
||||||
"Current DocBin: {current}\nOther DocBin: {other}")
|
"Current DocBin: {current}\nOther DocBin: {other}")
|
||||||
E169 = ("Can't find module: {module}")
|
E169 = ("Can't find module: {module}")
|
||||||
E170 = ("Cannot apply transition {name}: invalid for the current state.")
|
E170 = ("Cannot apply transition {name}: invalid for the current state.")
|
||||||
|
|
|
@ -1,3 +1,6 @@
|
||||||
|
import pytest
|
||||||
|
from spacy.tokens.doc import Underscore
|
||||||
|
|
||||||
import spacy
|
import spacy
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
from spacy.tokens import Doc, DocBin
|
from spacy.tokens import Doc, DocBin
|
||||||
|
@ -86,3 +89,20 @@ def test_serialize_doc_bin_unknown_spaces(en_vocab):
|
||||||
assert re_doc1.text == "that 's "
|
assert re_doc1.text == "that 's "
|
||||||
assert not re_doc2.has_unknown_spaces
|
assert not re_doc2.has_unknown_spaces
|
||||||
assert re_doc2.text == "that's"
|
assert re_doc2.text == "that's"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"writer_flag,reader_flag,reader_value", [(True, True, "bar"), (True, False, "bar"), (False, True, "nothing"), (False, False, "nothing")]
|
||||||
|
)
|
||||||
|
def test_serialize_custom_extension(en_vocab, writer_flag, reader_flag, reader_value):
|
||||||
|
"""Test that custom extensions are correctly serialized in DocBin."""
|
||||||
|
Doc.set_extension("foo", default="nothing")
|
||||||
|
doc = Doc(en_vocab, words=["hello", "world"])
|
||||||
|
doc._.foo = "bar"
|
||||||
|
doc_bin_1 = DocBin(store_user_data=writer_flag)
|
||||||
|
doc_bin_1.add(doc)
|
||||||
|
doc_bin_bytes = doc_bin_1.to_bytes()
|
||||||
|
doc_bin_2 = DocBin(store_user_data=reader_flag).from_bytes(doc_bin_bytes)
|
||||||
|
doc_2 = list(doc_bin_2.get_docs(en_vocab))[0]
|
||||||
|
assert doc_2._.foo == reader_value
|
||||||
|
Underscore.doc_extensions = {}
|
||||||
|
|
|
@ -58,7 +58,7 @@ class DocBin:
|
||||||
|
|
||||||
attrs (Iterable[str]): List of attributes to serialize. 'orth' and
|
attrs (Iterable[str]): List of attributes to serialize. 'orth' and
|
||||||
'spacy' are always serialized, so they're not required.
|
'spacy' are always serialized, so they're not required.
|
||||||
store_user_data (bool): Whether to include the `Doc.user_data`.
|
store_user_data (bool): Whether to write the `Doc.user_data` to bytes/file.
|
||||||
docs (Iterable[Doc]): Docs to add.
|
docs (Iterable[Doc]): Docs to add.
|
||||||
|
|
||||||
DOCS: https://nightly.spacy.io/api/docbin#init
|
DOCS: https://nightly.spacy.io/api/docbin#init
|
||||||
|
@ -106,11 +106,12 @@ class DocBin:
|
||||||
self.strings.add(token.ent_type_)
|
self.strings.add(token.ent_type_)
|
||||||
self.strings.add(token.ent_kb_id_)
|
self.strings.add(token.ent_kb_id_)
|
||||||
self.cats.append(doc.cats)
|
self.cats.append(doc.cats)
|
||||||
if self.store_user_data:
|
|
||||||
self.user_data.append(srsly.msgpack_dumps(doc.user_data))
|
self.user_data.append(srsly.msgpack_dumps(doc.user_data))
|
||||||
|
|
||||||
def get_docs(self, vocab: Vocab) -> Iterator[Doc]:
|
def get_docs(self, vocab: Vocab) -> Iterator[Doc]:
|
||||||
"""Recover Doc objects from the annotations, using the given vocab.
|
"""Recover Doc objects from the annotations, using the given vocab.
|
||||||
|
Note that the user data of each doc will be read (if available) and returned,
|
||||||
|
regardless of the setting of 'self.store_user_data'.
|
||||||
|
|
||||||
vocab (Vocab): The shared vocab.
|
vocab (Vocab): The shared vocab.
|
||||||
YIELDS (Doc): The Doc objects.
|
YIELDS (Doc): The Doc objects.
|
||||||
|
@ -129,7 +130,7 @@ class DocBin:
|
||||||
doc = Doc(vocab, words=tokens[:, orth_col], spaces=spaces)
|
doc = Doc(vocab, words=tokens[:, orth_col], spaces=spaces)
|
||||||
doc = doc.from_array(self.attrs, tokens)
|
doc = doc.from_array(self.attrs, tokens)
|
||||||
doc.cats = self.cats[i]
|
doc.cats = self.cats[i]
|
||||||
if self.store_user_data:
|
if i < len(self.user_data) and self.user_data[i] is not None:
|
||||||
user_data = srsly.msgpack_loads(self.user_data[i], use_list=False)
|
user_data = srsly.msgpack_loads(self.user_data[i], use_list=False)
|
||||||
doc.user_data.update(user_data)
|
doc.user_data.update(user_data)
|
||||||
yield doc
|
yield doc
|
||||||
|
@ -137,20 +138,30 @@ class DocBin:
|
||||||
def merge(self, other: "DocBin") -> None:
|
def merge(self, other: "DocBin") -> None:
|
||||||
"""Extend the annotations of this DocBin with the annotations from
|
"""Extend the annotations of this DocBin with the annotations from
|
||||||
another. Will raise an error if the pre-defined attrs of the two
|
another. Will raise an error if the pre-defined attrs of the two
|
||||||
DocBins don't match.
|
DocBins don't match, or if they differ in whether or not to store
|
||||||
|
user data.
|
||||||
|
|
||||||
other (DocBin): The DocBin to merge into the current bin.
|
other (DocBin): The DocBin to merge into the current bin.
|
||||||
|
|
||||||
DOCS: https://nightly.spacy.io/api/docbin#merge
|
DOCS: https://nightly.spacy.io/api/docbin#merge
|
||||||
"""
|
"""
|
||||||
if self.attrs != other.attrs:
|
if self.attrs != other.attrs:
|
||||||
raise ValueError(Errors.E166.format(current=self.attrs, other=other.attrs))
|
raise ValueError(
|
||||||
|
Errors.E166.format(param="attrs", current=self.attrs, other=other.attrs)
|
||||||
|
)
|
||||||
|
if self.store_user_data != other.store_user_data:
|
||||||
|
raise ValueError(
|
||||||
|
Errors.E166.format(
|
||||||
|
param="store_user_data",
|
||||||
|
current=self.store_user_data,
|
||||||
|
other=other.store_user_data,
|
||||||
|
)
|
||||||
|
)
|
||||||
self.tokens.extend(other.tokens)
|
self.tokens.extend(other.tokens)
|
||||||
self.spaces.extend(other.spaces)
|
self.spaces.extend(other.spaces)
|
||||||
self.strings.update(other.strings)
|
self.strings.update(other.strings)
|
||||||
self.cats.extend(other.cats)
|
self.cats.extend(other.cats)
|
||||||
self.flags.extend(other.flags)
|
self.flags.extend(other.flags)
|
||||||
if self.store_user_data:
|
|
||||||
self.user_data.extend(other.user_data)
|
self.user_data.extend(other.user_data)
|
||||||
|
|
||||||
def to_bytes(self) -> bytes:
|
def to_bytes(self) -> bytes:
|
||||||
|
@ -200,8 +211,10 @@ class DocBin:
|
||||||
self.spaces = NumpyOps().unflatten(flat_spaces, lengths)
|
self.spaces = NumpyOps().unflatten(flat_spaces, lengths)
|
||||||
self.cats = msg["cats"]
|
self.cats = msg["cats"]
|
||||||
self.flags = msg.get("flags", [{} for _ in lengths])
|
self.flags = msg.get("flags", [{} for _ in lengths])
|
||||||
if self.store_user_data and "user_data" in msg:
|
if "user_data" in msg:
|
||||||
self.user_data = list(msg["user_data"])
|
self.user_data = list(msg["user_data"])
|
||||||
|
else:
|
||||||
|
self.user_data = [None] * len(self)
|
||||||
for tokens in self.tokens:
|
for tokens in self.tokens:
|
||||||
assert len(tokens.shape) == 2, tokens.shape # this should never happen
|
assert len(tokens.shape) == 2, tokens.shape # this should never happen
|
||||||
return self
|
return self
|
||||||
|
|
|
@ -47,7 +47,7 @@ Create a `DocBin` object to hold serialized annotations.
|
||||||
| Argument | Description |
|
| Argument | Description |
|
||||||
| ----------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
| ----------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| `attrs` | List of attributes to serialize. `ORTH` (hash of token text) and `SPACY` (whether the token is followed by whitespace) are always serialized, so they're not required. Defaults to `("ORTH", "TAG", "HEAD", "DEP", "ENT_IOB", "ENT_TYPE", "ENT_KB_ID", "LEMMA", "MORPH", "POS")`. ~~Iterable[str]~~ |
|
| `attrs` | List of attributes to serialize. `ORTH` (hash of token text) and `SPACY` (whether the token is followed by whitespace) are always serialized, so they're not required. Defaults to `("ORTH", "TAG", "HEAD", "DEP", "ENT_IOB", "ENT_TYPE", "ENT_KB_ID", "LEMMA", "MORPH", "POS")`. ~~Iterable[str]~~ |
|
||||||
| `store_user_data` | Whether to include the `Doc.user_data` and the values of custom extension attributes. Defaults to `False`. ~~bool~~ |
|
| `store_user_data` | Whether to write the `Doc.user_data` and the values of custom extension attributes to file/bytes. Defaults to `False`. ~~bool~~ |
|
||||||
| `docs` | `Doc` objects to add on initialization. ~~Iterable[Doc]~~ |
|
| `docs` | `Doc` objects to add on initialization. ~~Iterable[Doc]~~ |
|
||||||
|
|
||||||
## DocBin.\_\_len\_\_ {#len tag="method"}
|
## DocBin.\_\_len\_\_ {#len tag="method"}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user