From a77c4c3465d12f70fc2436b6d3def414082d77a9 Mon Sep 17 00:00:00 2001
From: Adriane Boyd
Date: Thu, 2 Jul 2020 17:11:57 +0200
Subject: [PATCH] Add strings and ENT_KB_ID to Doc serialization (#5691)

* Add strings for all writeable Token attributes to `Doc.to/from_bytes()`.
* Add ENT_KB_ID to default attributes.
---
 spacy/tests/doc/test_doc_api.py |  6 ++++++
 spacy/tokens/doc.pyx            | 15 ++++++++++++++-
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/spacy/tests/doc/test_doc_api.py b/spacy/tests/doc/test_doc_api.py
index 6801d7844..388cd78fe 100644
--- a/spacy/tests/doc/test_doc_api.py
+++ b/spacy/tests/doc/test_doc_api.py
@@ -106,10 +106,16 @@ def test_doc_api_getitem(en_tokenizer):
 )
 def test_doc_api_serialize(en_tokenizer, text):
     tokens = en_tokenizer(text)
+    tokens[0].lemma_ = "lemma"
+    tokens[0].norm_ = "norm"
+    tokens[0].ent_kb_id_ = "ent_kb_id"
     new_tokens = Doc(tokens.vocab).from_bytes(tokens.to_bytes())
     assert tokens.text == new_tokens.text
     assert [t.text for t in tokens] == [t.text for t in new_tokens]
     assert [t.orth for t in tokens] == [t.orth for t in new_tokens]
+    assert new_tokens[0].lemma_ == "lemma"
+    assert new_tokens[0].norm_ == "norm"
+    assert new_tokens[0].ent_kb_id_ == "ent_kb_id"
 
     new_tokens = Doc(tokens.vocab).from_bytes(
         tokens.to_bytes(exclude=["tensor"]), exclude=["tensor"]
diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx
index 25a147208..5b03dc5d2 100644
--- a/spacy/tokens/doc.pyx
+++ b/spacy/tokens/doc.pyx
@@ -892,7 +892,7 @@ cdef class Doc:
 
         DOCS: https://spacy.io/api/doc#to_bytes
         """
-        array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE, ENT_ID, NORM]  # TODO: ENT_KB_ID ?
+        array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE, ENT_ID, NORM, ENT_KB_ID]
         if self.is_tagged:
             array_head.extend([TAG, POS])
         # If doc parsed add head and dep attribute
@@ -901,6 +901,14 @@ cdef class Doc:
         # Otherwise add sent_start
         else:
             array_head.append(SENT_START)
+        strings = set()
+        for token in self:
+            strings.add(token.tag_)
+            strings.add(token.lemma_)
+            strings.add(token.dep_)
+            strings.add(token.ent_type_)
+            strings.add(token.ent_kb_id_)
+            strings.add(token.norm_)
         # Msgpack doesn't distinguish between lists and tuples, which is
         # vexing for user data. As a best guess, we *know* that within
         # keys, we must have tuples. In values we just have to hope
@@ -912,6 +920,7 @@ cdef class Doc:
             "sentiment": lambda: self.sentiment,
             "tensor": lambda: self.tensor,
             "cats": lambda: self.cats,
+            "strings": lambda: list(strings),
         }
         for key in kwargs:
             if key in serializers or key in ("user_data", "user_data_keys", "user_data_values"):
@@ -942,6 +951,7 @@ cdef class Doc:
             "sentiment": lambda b: None,
             "tensor": lambda b: None,
             "cats": lambda b: None,
+            "strings": lambda b: None,
             "user_data_keys": lambda b: None,
             "user_data_values": lambda b: None,
         }
@@ -965,6 +975,9 @@ cdef class Doc:
             self.tensor = msg["tensor"]
         if "cats" not in exclude and "cats" in msg:
             self.cats = msg["cats"]
+        if "strings" not in exclude and "strings" in msg:
+            for s in msg["strings"]:
+                self.vocab.strings.add(s)
         start = 0
         cdef const LexemeC* lex
         cdef unicode orth_
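
Usage sketch (not part of the patch): the example below illustrates the behavior this
change enables, namely that string values set on writeable Token attributes survive a
`Doc.to_bytes()`/`from_bytes()` round trip even when the receiving Doc uses a fresh
Vocab that has never seen those strings. It assumes the spaCy 2.x API that this patch
targets; the words and the "Q64" knowledge-base identifier are made up for illustration.

    from spacy.vocab import Vocab
    from spacy.tokens import Doc

    doc = Doc(Vocab(), words=["Berlin", "calling"])
    doc[0].lemma_ = "berlin"
    doc[0].norm_ = "berlin"
    doc[0].ent_kb_id_ = "Q64"  # hypothetical knowledge-base identifier

    data = doc.to_bytes()

    # Deserialize into a completely new Vocab; the "strings" payload added by
    # this patch restores the string-table entries that the stored hashes point
    # to, so the string views below resolve correctly.
    new_doc = Doc(Vocab()).from_bytes(data)
    assert new_doc[0].lemma_ == "berlin"
    assert new_doc[0].norm_ == "berlin"
    assert new_doc[0].ent_kb_id_ == "Q64"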