mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-11 17:56:30 +03:00
Support custom attributes for tokens and spans in json conversion (#11125)
* Add token and span custom attributes to to_json() * Change logic for to_json * Add functionality to from_json * Small adjustments * Move token/span attributes to new dict key * Fix test * Fix the same test but much better * Add backwards compatibility tests and adjust logic * Add test to check if attributes not set in underscore are not saved in the json * Add tests for json compatibility * Adjust test names * Fix tests and clean up code * Fix assert json tests * small adjustment * adjust naming and code readability * Adjust naming, added more tests and changed logic * Fix typo * Adjust errors, naming, and small test optimization * Fix byte tests * Fix bytes tests * Change naming and json structure * update schema * Update spacy/schemas.py Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Update spacy/tokens/doc.pyx Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Update spacy/tokens/doc.pyx Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Update spacy/schemas.py Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Update schema for underscore attributes * Adjust underscore schema * adjust schema tests Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
parent
7e75327893
commit
5afa98aabf
|
@ -389,7 +389,7 @@ class Errors(metaclass=ErrorsWithCodes):
|
|||
"consider using doc.spans instead.")
|
||||
E106 = ("Can't find `doc._.{attr}` attribute specified in the underscore "
|
||||
"settings: {opts}")
|
||||
E107 = ("Value of `doc._.{attr}` is not JSON-serializable: {value}")
|
||||
E107 = ("Value of custom attribute `{attr}` is not JSON-serializable: {value}")
|
||||
E109 = ("Component '{name}' could not be run. Did you forget to "
|
||||
"call `initialize()`?")
|
||||
E110 = ("Invalid displaCy render wrapper. Expected callable, got: {obj}")
|
||||
|
|
|
@ -514,6 +514,14 @@ class DocJSONSchema(BaseModel):
|
|||
tokens: List[Dict[StrictStr, Union[StrictStr, StrictInt]]] = Field(
|
||||
..., title="Token information - ID, start, annotations"
|
||||
)
|
||||
_: Optional[Dict[StrictStr, Any]] = Field(
|
||||
None, title="Any custom data stored in the document's _ attribute"
|
||||
underscore_doc: Optional[Dict[StrictStr, Any]] = Field(
|
||||
None,
|
||||
title="Any custom data stored in the document's _ attribute",
|
||||
alias="_",
|
||||
)
|
||||
underscore_token: Optional[Dict[StrictStr, Dict[StrictStr, Any]]] = Field(
|
||||
None, title="Any custom data stored in the token's _ attribute"
|
||||
)
|
||||
underscore_span: Optional[Dict[StrictStr, Dict[StrictStr, Any]]] = Field(
|
||||
None, title="Any custom data stored in the span's _ attribute"
|
||||
)
|
||||
|
|
|
@ -1,12 +1,15 @@
|
|||
import pytest
|
||||
import spacy
|
||||
from spacy import schemas
|
||||
from spacy.tokens import Doc, Span
|
||||
from spacy.tokens import Doc, Span, Token
|
||||
import srsly
|
||||
from .test_underscore import clean_underscore # noqa: F401
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def doc(en_vocab):
|
||||
words = ["c", "d", "e"]
|
||||
spaces = [True, True, True]
|
||||
pos = ["VERB", "NOUN", "NOUN"]
|
||||
tags = ["VBP", "NN", "NN"]
|
||||
heads = [0, 0, 1]
|
||||
|
@ -17,6 +20,7 @@ def doc(en_vocab):
|
|||
return Doc(
|
||||
en_vocab,
|
||||
words=words,
|
||||
spaces=spaces,
|
||||
pos=pos,
|
||||
tags=tags,
|
||||
heads=heads,
|
||||
|
@ -45,6 +49,47 @@ def doc_without_deps(en_vocab):
|
|||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def doc_json():
|
||||
return {
|
||||
"text": "c d e ",
|
||||
"ents": [{"start": 2, "end": 3, "label": "ORG"}],
|
||||
"sents": [{"start": 0, "end": 5}],
|
||||
"tokens": [
|
||||
{
|
||||
"id": 0,
|
||||
"start": 0,
|
||||
"end": 1,
|
||||
"tag": "VBP",
|
||||
"pos": "VERB",
|
||||
"morph": "Feat1=A",
|
||||
"dep": "ROOT",
|
||||
"head": 0,
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"start": 2,
|
||||
"end": 3,
|
||||
"tag": "NN",
|
||||
"pos": "NOUN",
|
||||
"morph": "Feat1=B",
|
||||
"dep": "dobj",
|
||||
"head": 0,
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"start": 4,
|
||||
"end": 5,
|
||||
"tag": "NN",
|
||||
"pos": "NOUN",
|
||||
"morph": "Feat1=A|Feat2=D",
|
||||
"dep": "dobj",
|
||||
"head": 1,
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def test_doc_to_json(doc):
|
||||
json_doc = doc.to_json()
|
||||
assert json_doc["text"] == "c d e "
|
||||
|
@ -56,7 +101,8 @@ def test_doc_to_json(doc):
|
|||
assert json_doc["ents"][0]["start"] == 2 # character offset!
|
||||
assert json_doc["ents"][0]["end"] == 3 # character offset!
|
||||
assert json_doc["ents"][0]["label"] == "ORG"
|
||||
assert not schemas.validate(schemas.DocJSONSchema, json_doc)
|
||||
assert len(schemas.validate(schemas.DocJSONSchema, json_doc)) == 0
|
||||
assert srsly.json_loads(srsly.json_dumps(json_doc)) == json_doc
|
||||
|
||||
|
||||
def test_doc_to_json_underscore(doc):
|
||||
|
@ -64,11 +110,96 @@ def test_doc_to_json_underscore(doc):
|
|||
Doc.set_extension("json_test2", default=False)
|
||||
doc._.json_test1 = "hello world"
|
||||
doc._.json_test2 = [1, 2, 3]
|
||||
|
||||
json_doc = doc.to_json(underscore=["json_test1", "json_test2"])
|
||||
assert "_" in json_doc
|
||||
assert json_doc["_"]["json_test1"] == "hello world"
|
||||
assert json_doc["_"]["json_test2"] == [1, 2, 3]
|
||||
assert not schemas.validate(schemas.DocJSONSchema, json_doc)
|
||||
assert len(schemas.validate(schemas.DocJSONSchema, json_doc)) == 0
|
||||
assert srsly.json_loads(srsly.json_dumps(json_doc)) == json_doc
|
||||
|
||||
|
||||
def test_doc_to_json_with_token_span_attributes(doc):
|
||||
Doc.set_extension("json_test1", default=False)
|
||||
Doc.set_extension("json_test2", default=False)
|
||||
Token.set_extension("token_test", default=False)
|
||||
Span.set_extension("span_test", default=False)
|
||||
|
||||
doc._.json_test1 = "hello world"
|
||||
doc._.json_test2 = [1, 2, 3]
|
||||
doc[0:1]._.span_test = "span_attribute"
|
||||
doc[0]._.token_test = 117
|
||||
doc.spans["span_group"] = [doc[0:1]]
|
||||
json_doc = doc.to_json(
|
||||
underscore=["json_test1", "json_test2", "token_test", "span_test"]
|
||||
)
|
||||
|
||||
assert "_" in json_doc
|
||||
assert json_doc["_"]["json_test1"] == "hello world"
|
||||
assert json_doc["_"]["json_test2"] == [1, 2, 3]
|
||||
assert "underscore_token" in json_doc
|
||||
assert "underscore_span" in json_doc
|
||||
assert json_doc["underscore_token"]["token_test"]["value"] == 117
|
||||
assert json_doc["underscore_span"]["span_test"]["value"] == "span_attribute"
|
||||
assert len(schemas.validate(schemas.DocJSONSchema, json_doc)) == 0
|
||||
assert srsly.json_loads(srsly.json_dumps(json_doc)) == json_doc
|
||||
|
||||
|
||||
def test_doc_to_json_with_custom_user_data(doc):
|
||||
Doc.set_extension("json_test", default=False)
|
||||
Token.set_extension("token_test", default=False)
|
||||
Span.set_extension("span_test", default=False)
|
||||
|
||||
doc._.json_test = "hello world"
|
||||
doc[0:1]._.span_test = "span_attribute"
|
||||
doc[0]._.token_test = 117
|
||||
json_doc = doc.to_json(underscore=["json_test", "token_test", "span_test"])
|
||||
doc.user_data["user_data_test"] = 10
|
||||
doc.user_data[("user_data_test2", True)] = 10
|
||||
|
||||
assert "_" in json_doc
|
||||
assert json_doc["_"]["json_test"] == "hello world"
|
||||
assert "underscore_token" in json_doc
|
||||
assert "underscore_span" in json_doc
|
||||
assert json_doc["underscore_token"]["token_test"]["value"] == 117
|
||||
assert json_doc["underscore_span"]["span_test"]["value"] == "span_attribute"
|
||||
assert len(schemas.validate(schemas.DocJSONSchema, json_doc)) == 0
|
||||
assert srsly.json_loads(srsly.json_dumps(json_doc)) == json_doc
|
||||
|
||||
|
||||
def test_doc_to_json_with_token_span_same_identifier(doc):
|
||||
Doc.set_extension("my_ext", default=False)
|
||||
Token.set_extension("my_ext", default=False)
|
||||
Span.set_extension("my_ext", default=False)
|
||||
|
||||
doc._.my_ext = "hello world"
|
||||
doc[0:1]._.my_ext = "span_attribute"
|
||||
doc[0]._.my_ext = 117
|
||||
json_doc = doc.to_json(underscore=["my_ext"])
|
||||
|
||||
assert "_" in json_doc
|
||||
assert json_doc["_"]["my_ext"] == "hello world"
|
||||
assert "underscore_token" in json_doc
|
||||
assert "underscore_span" in json_doc
|
||||
assert json_doc["underscore_token"]["my_ext"]["value"] == 117
|
||||
assert json_doc["underscore_span"]["my_ext"]["value"] == "span_attribute"
|
||||
assert len(schemas.validate(schemas.DocJSONSchema, json_doc)) == 0
|
||||
assert srsly.json_loads(srsly.json_dumps(json_doc)) == json_doc
|
||||
|
||||
|
||||
def test_doc_to_json_with_token_attributes_missing(doc):
|
||||
Token.set_extension("token_test", default=False)
|
||||
Span.set_extension("span_test", default=False)
|
||||
|
||||
doc[0:1]._.span_test = "span_attribute"
|
||||
doc[0]._.token_test = 117
|
||||
json_doc = doc.to_json(underscore=["span_test"])
|
||||
|
||||
assert "underscore_token" in json_doc
|
||||
assert "underscore_span" in json_doc
|
||||
assert json_doc["underscore_span"]["span_test"]["value"] == "span_attribute"
|
||||
assert "token_test" not in json_doc["underscore_token"]
|
||||
assert len(schemas.validate(schemas.DocJSONSchema, json_doc)) == 0
|
||||
|
||||
|
||||
def test_doc_to_json_underscore_error_attr(doc):
|
||||
|
@ -94,11 +225,29 @@ def test_doc_to_json_span(doc):
|
|||
assert len(json_doc["spans"]) == 1
|
||||
assert len(json_doc["spans"]["test"]) == 2
|
||||
assert json_doc["spans"]["test"][0]["start"] == 0
|
||||
assert not schemas.validate(schemas.DocJSONSchema, json_doc)
|
||||
assert len(schemas.validate(schemas.DocJSONSchema, json_doc)) == 0
|
||||
|
||||
|
||||
def test_json_to_doc(doc):
|
||||
new_doc = Doc(doc.vocab).from_json(doc.to_json(), validate=True)
|
||||
json_doc = doc.to_json()
|
||||
json_doc = srsly.json_loads(srsly.json_dumps(json_doc))
|
||||
new_doc = Doc(doc.vocab).from_json(json_doc, validate=True)
|
||||
assert new_doc.text == doc.text == "c d e "
|
||||
assert len(new_doc) == len(doc) == 3
|
||||
assert new_doc[0].pos == doc[0].pos
|
||||
assert new_doc[0].tag == doc[0].tag
|
||||
assert new_doc[0].dep == doc[0].dep
|
||||
assert new_doc[0].head.idx == doc[0].head.idx
|
||||
assert new_doc[0].lemma == doc[0].lemma
|
||||
assert len(new_doc.ents) == 1
|
||||
assert new_doc.ents[0].start == 1
|
||||
assert new_doc.ents[0].end == 2
|
||||
assert new_doc.ents[0].label_ == "ORG"
|
||||
assert doc.to_bytes() == new_doc.to_bytes()
|
||||
|
||||
|
||||
def test_json_to_doc_compat(doc, doc_json):
|
||||
new_doc = Doc(doc.vocab).from_json(doc_json, validate=True)
|
||||
new_tokens = [token for token in new_doc]
|
||||
assert new_doc.text == doc.text == "c d e "
|
||||
assert len(new_tokens) == len([token for token in doc]) == 3
|
||||
|
@ -114,11 +263,8 @@ def test_json_to_doc(doc):
|
|||
|
||||
|
||||
def test_json_to_doc_underscore(doc):
|
||||
if not Doc.has_extension("json_test1"):
|
||||
Doc.set_extension("json_test1", default=False)
|
||||
if not Doc.has_extension("json_test2"):
|
||||
Doc.set_extension("json_test2", default=False)
|
||||
|
||||
Doc.set_extension("json_test1", default=False)
|
||||
Doc.set_extension("json_test2", default=False)
|
||||
doc._.json_test1 = "hello world"
|
||||
doc._.json_test2 = [1, 2, 3]
|
||||
json_doc = doc.to_json(underscore=["json_test1", "json_test2"])
|
||||
|
@ -126,6 +272,34 @@ def test_json_to_doc_underscore(doc):
|
|||
assert all([new_doc.has_extension(f"json_test{i}") for i in range(1, 3)])
|
||||
assert new_doc._.json_test1 == "hello world"
|
||||
assert new_doc._.json_test2 == [1, 2, 3]
|
||||
assert doc.to_bytes() == new_doc.to_bytes()
|
||||
|
||||
|
||||
def test_json_to_doc_with_token_span_attributes(doc):
|
||||
Doc.set_extension("json_test1", default=False)
|
||||
Doc.set_extension("json_test2", default=False)
|
||||
Token.set_extension("token_test", default=False)
|
||||
Span.set_extension("span_test", default=False)
|
||||
doc._.json_test1 = "hello world"
|
||||
doc._.json_test2 = [1, 2, 3]
|
||||
doc[0:1]._.span_test = "span_attribute"
|
||||
doc[0]._.token_test = 117
|
||||
|
||||
json_doc = doc.to_json(
|
||||
underscore=["json_test1", "json_test2", "token_test", "span_test"]
|
||||
)
|
||||
json_doc = srsly.json_loads(srsly.json_dumps(json_doc))
|
||||
new_doc = Doc(doc.vocab).from_json(json_doc, validate=True)
|
||||
|
||||
assert all([new_doc.has_extension(f"json_test{i}") for i in range(1, 3)])
|
||||
assert new_doc._.json_test1 == "hello world"
|
||||
assert new_doc._.json_test2 == [1, 2, 3]
|
||||
assert new_doc[0]._.token_test == 117
|
||||
assert new_doc[0:1]._.span_test == "span_attribute"
|
||||
assert new_doc.user_data == doc.user_data
|
||||
assert new_doc.to_bytes(exclude=["user_data"]) == doc.to_bytes(
|
||||
exclude=["user_data"]
|
||||
)
|
||||
|
||||
|
||||
def test_json_to_doc_spans(doc):
|
||||
|
|
|
@ -1602,13 +1602,30 @@ cdef class Doc:
|
|||
ents.append(char_span)
|
||||
self.ents = ents
|
||||
|
||||
# Add custom attributes. Note that only Doc extensions are currently considered, Token and Span extensions are
|
||||
# not yet supported.
|
||||
# Add custom attributes for the whole Doc object.
|
||||
for attr in doc_json.get("_", {}):
|
||||
if not Doc.has_extension(attr):
|
||||
Doc.set_extension(attr)
|
||||
self._.set(attr, doc_json["_"][attr])
|
||||
|
||||
if doc_json.get("underscore_token", {}):
|
||||
for token_attr in doc_json["underscore_token"]:
|
||||
token_start = doc_json["underscore_token"][token_attr]["token_start"]
|
||||
value = doc_json["underscore_token"][token_attr]["value"]
|
||||
|
||||
if not Token.has_extension(token_attr):
|
||||
Token.set_extension(token_attr)
|
||||
self[token_start]._.set(token_attr, value)
|
||||
|
||||
if doc_json.get("underscore_span", {}):
|
||||
for span_attr in doc_json["underscore_span"]:
|
||||
token_start = doc_json["underscore_span"][span_attr]["token_start"]
|
||||
token_end = doc_json["underscore_span"][span_attr]["token_end"]
|
||||
value = doc_json["underscore_span"][span_attr]["value"]
|
||||
|
||||
if not Span.has_extension(span_attr):
|
||||
Span.set_extension(span_attr)
|
||||
self[token_start:token_end]._.set(span_attr, value)
|
||||
return self
|
||||
|
||||
def to_json(self, underscore=None):
|
||||
|
@ -1650,20 +1667,40 @@ cdef class Doc:
|
|||
for span_group in self.spans:
|
||||
data["spans"][span_group] = []
|
||||
for span in self.spans[span_group]:
|
||||
span_data = {
|
||||
"start": span.start_char, "end": span.end_char, "label": span.label_, "kb_id": span.kb_id_
|
||||
}
|
||||
span_data = {"start": span.start_char, "end": span.end_char, "label": span.label_, "kb_id": span.kb_id_}
|
||||
data["spans"][span_group].append(span_data)
|
||||
|
||||
if underscore:
|
||||
data["_"] = {}
|
||||
user_keys = set()
|
||||
if self.user_data:
|
||||
data["_"] = {}
|
||||
data["underscore_token"] = {}
|
||||
data["underscore_span"] = {}
|
||||
for data_key in self.user_data:
|
||||
if type(data_key) == tuple and len(data_key) >= 4 and data_key[0] == "._.":
|
||||
attr = data_key[1]
|
||||
start = data_key[2]
|
||||
end = data_key[3]
|
||||
if attr in underscore:
|
||||
user_keys.add(attr)
|
||||
value = self.user_data[data_key]
|
||||
if not srsly.is_json_serializable(value):
|
||||
raise ValueError(Errors.E107.format(attr=attr, value=repr(value)))
|
||||
# Check if doc attribute
|
||||
if start is None:
|
||||
data["_"][attr] = value
|
||||
# Check if token attribute
|
||||
elif end is None:
|
||||
if attr not in data["underscore_token"]:
|
||||
data["underscore_token"][attr] = {"token_start": start, "value": value}
|
||||
# Else span attribute
|
||||
else:
|
||||
if attr not in data["underscore_span"]:
|
||||
data["underscore_span"][attr] = {"token_start": start, "token_end": end, "value": value}
|
||||
|
||||
for attr in underscore:
|
||||
if not self.has_extension(attr):
|
||||
if attr not in user_keys:
|
||||
raise ValueError(Errors.E106.format(attr=attr, opts=underscore))
|
||||
value = self._.get(attr)
|
||||
if not srsly.is_json_serializable(value):
|
||||
raise ValueError(Errors.E107.format(attr=attr, value=repr(value)))
|
||||
data["_"][attr] = value
|
||||
return data
|
||||
|
||||
def to_utf8_array(self, int nr_char=-1):
|
||||
|
|
Loading…
Reference in New Issue
Block a user