Add strings and ENT_KB_ID to Doc serialization (#5691)

* Add strings for all writeable Token attributes to `Doc.to/from_bytes()`.
* Add ENT_KB_ID to default attributes.
This commit is contained in:
Adriane Boyd 2020-07-02 17:11:57 +02:00 committed by GitHub
parent 971826a96d
commit a77c4c3465
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 1 deletions

View File

@ -106,10 +106,16 @@ def test_doc_api_getitem(en_tokenizer):
) )
def test_doc_api_serialize(en_tokenizer, text): def test_doc_api_serialize(en_tokenizer, text):
tokens = en_tokenizer(text) tokens = en_tokenizer(text)
tokens[0].lemma_ = "lemma"
tokens[0].norm_ = "norm"
tokens[0].ent_kb_id_ = "ent_kb_id"
new_tokens = Doc(tokens.vocab).from_bytes(tokens.to_bytes()) new_tokens = Doc(tokens.vocab).from_bytes(tokens.to_bytes())
assert tokens.text == new_tokens.text assert tokens.text == new_tokens.text
assert [t.text for t in tokens] == [t.text for t in new_tokens] assert [t.text for t in tokens] == [t.text for t in new_tokens]
assert [t.orth for t in tokens] == [t.orth for t in new_tokens] assert [t.orth for t in tokens] == [t.orth for t in new_tokens]
assert new_tokens[0].lemma_ == "lemma"
assert new_tokens[0].norm_ == "norm"
assert new_tokens[0].ent_kb_id_ == "ent_kb_id"
new_tokens = Doc(tokens.vocab).from_bytes( new_tokens = Doc(tokens.vocab).from_bytes(
tokens.to_bytes(exclude=["tensor"]), exclude=["tensor"] tokens.to_bytes(exclude=["tensor"]), exclude=["tensor"]

View File

@ -892,7 +892,7 @@ cdef class Doc:
DOCS: https://spacy.io/api/doc#to_bytes DOCS: https://spacy.io/api/doc#to_bytes
""" """
array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE, ENT_ID, NORM] # TODO: ENT_KB_ID ? array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE, ENT_ID, NORM, ENT_KB_ID]
if self.is_tagged: if self.is_tagged:
array_head.extend([TAG, POS]) array_head.extend([TAG, POS])
# If doc parsed add head and dep attribute # If doc parsed add head and dep attribute
@ -901,6 +901,14 @@ cdef class Doc:
# Otherwise add sent_start # Otherwise add sent_start
else: else:
array_head.append(SENT_START) array_head.append(SENT_START)
strings = set()
for token in self:
strings.add(token.tag_)
strings.add(token.lemma_)
strings.add(token.dep_)
strings.add(token.ent_type_)
strings.add(token.ent_kb_id_)
strings.add(token.norm_)
# Msgpack doesn't distinguish between lists and tuples, which is # Msgpack doesn't distinguish between lists and tuples, which is
# vexing for user data. As a best guess, we *know* that within # vexing for user data. As a best guess, we *know* that within
# keys, we must have tuples. In values we just have to hope # keys, we must have tuples. In values we just have to hope
@ -912,6 +920,7 @@ cdef class Doc:
"sentiment": lambda: self.sentiment, "sentiment": lambda: self.sentiment,
"tensor": lambda: self.tensor, "tensor": lambda: self.tensor,
"cats": lambda: self.cats, "cats": lambda: self.cats,
"strings": lambda: list(strings),
} }
for key in kwargs: for key in kwargs:
if key in serializers or key in ("user_data", "user_data_keys", "user_data_values"): if key in serializers or key in ("user_data", "user_data_keys", "user_data_values"):
@ -942,6 +951,7 @@ cdef class Doc:
"sentiment": lambda b: None, "sentiment": lambda b: None,
"tensor": lambda b: None, "tensor": lambda b: None,
"cats": lambda b: None, "cats": lambda b: None,
"strings": lambda b: None,
"user_data_keys": lambda b: None, "user_data_keys": lambda b: None,
"user_data_values": lambda b: None, "user_data_values": lambda b: None,
} }
@ -965,6 +975,9 @@ cdef class Doc:
self.tensor = msg["tensor"] self.tensor = msg["tensor"]
if "cats" not in exclude and "cats" in msg: if "cats" not in exclude and "cats" in msg:
self.cats = msg["cats"] self.cats = msg["cats"]
if "strings" not in exclude and "strings" in msg:
for s in msg["strings"]:
self.vocab.strings.add(s)
start = 0 start = 0
cdef const LexemeC* lex cdef const LexemeC* lex
cdef unicode orth_ cdef unicode orth_