From 8674b17651e1da154057208b575c19c7dd25dd81 Mon Sep 17 00:00:00 2001
From: Adriane Boyd
Date: Fri, 28 Aug 2020 20:42:26 +0200
Subject: [PATCH] Serialize AttributeRuler.patterns

Serialize `AttributeRuler.patterns` instead of the individual lists to
simplify the serialized data and so that patterns are reloaded exactly as
they were originally provided (preserving `_attrs_unnormed`).
---
 spacy/pipeline/attributeruler.py            | 46 ++-------------------
 spacy/tests/pipeline/test_attributeruler.py |  2 +
 2 files changed, 6 insertions(+), 42 deletions(-)

diff --git a/spacy/pipeline/attributeruler.py b/spacy/pipeline/attributeruler.py
index d93afc642..8e48dce57 100644
--- a/spacy/pipeline/attributeruler.py
+++ b/spacy/pipeline/attributeruler.py
@@ -230,10 +230,7 @@ class AttributeRuler(Pipe):
         """
         serialize = {}
         serialize["vocab"] = self.vocab.to_bytes
-        patterns = {k: self.matcher.get(k)[1] for k in range(len(self.attrs))}
-        serialize["patterns"] = lambda: srsly.msgpack_dumps(patterns)
-        serialize["attrs"] = lambda: srsly.msgpack_dumps(self.attrs)
-        serialize["indices"] = lambda: srsly.msgpack_dumps(self.indices)
+        serialize["patterns"] = lambda: srsly.msgpack_dumps(self.patterns)
         return util.to_bytes(serialize, exclude)
 
     def from_bytes(self, bytes_data: bytes, exclude: Iterable[str] = tuple()):
@@ -245,31 +242,15 @@ class AttributeRuler(Pipe):
 
         DOCS: https://spacy.io/api/attributeruler#from_bytes
         """
-        data = {"patterns": b""}
 
         def load_patterns(b):
-            data["patterns"] = srsly.msgpack_loads(b)
-
-        def load_attrs(b):
-            self.attrs = srsly.msgpack_loads(b)
-
-        def load_indices(b):
-            self.indices = srsly.msgpack_loads(b)
+            self.add_patterns(srsly.msgpack_loads(b))
 
         deserialize = {
             "vocab": lambda b: self.vocab.from_bytes(b),
             "patterns": load_patterns,
-            "attrs": load_attrs,
-            "indices": load_indices,
         }
         util.from_bytes(bytes_data, deserialize, exclude)
-
-        if data["patterns"]:
-            for key, pattern in data["patterns"].items():
-                self.matcher.add(key, pattern)
-            assert len(self.attrs) == len(data["patterns"])
-            assert len(self.indices) == len(data["patterns"])
-
         return self
 
     def to_disk(self, path: Union[Path, str], exclude: Iterable[str] = tuple()) -> None:
@@ -279,12 +260,9 @@ class AttributeRuler(Pipe):
         exclude (Iterable[str]): String names of serialization fields to exclude.
 
         DOCS: https://spacy.io/api/attributeruler#to_disk
         """
-        patterns = {k: self.matcher.get(k)[1] for k in range(len(self.attrs))}
         serialize = {
             "vocab": lambda p: self.vocab.to_disk(p),
-            "patterns": lambda p: srsly.write_msgpack(p, patterns),
-            "attrs": lambda p: srsly.write_msgpack(p, self.attrs),
-            "indices": lambda p: srsly.write_msgpack(p, self.indices),
+            "patterns": lambda p: srsly.write_msgpack(p, self.patterns),
         }
         util.to_disk(path, serialize, exclude)
@@ -297,31 +275,15 @@ class AttributeRuler(Pipe):
         exclude (Iterable[str]): String names of serialization fields to exclude.
 
         DOCS: https://spacy.io/api/attributeruler#from_disk
         """
-        data = {"patterns": b""}
 
         def load_patterns(p):
-            data["patterns"] = srsly.read_msgpack(p)
-
-        def load_attrs(p):
-            self.attrs = srsly.read_msgpack(p)
-
-        def load_indices(p):
-            self.indices = srsly.read_msgpack(p)
+            self.add_patterns(srsly.read_msgpack(p))
 
         deserialize = {
             "vocab": lambda p: self.vocab.from_disk(p),
             "patterns": load_patterns,
-            "attrs": load_attrs,
-            "indices": load_indices,
         }
         util.from_disk(path, deserialize, exclude)
-
-        if data["patterns"]:
-            for key, pattern in data["patterns"].items():
-                self.matcher.add(key, pattern)
-            assert len(self.attrs) == len(data["patterns"])
-            assert len(self.indices) == len(data["patterns"])
-
         return self
 
diff --git a/spacy/tests/pipeline/test_attributeruler.py b/spacy/tests/pipeline/test_attributeruler.py
index d9a492580..4c7cdb4e7 100644
--- a/spacy/tests/pipeline/test_attributeruler.py
+++ b/spacy/tests/pipeline/test_attributeruler.py
@@ -215,6 +215,7 @@ def test_attributeruler_serialize(nlp, pattern_dicts):
     assert a.to_bytes() == a_reloaded.to_bytes()
     doc1 = a_reloaded(nlp.make_doc(text))
     numpy.array_equal(doc.to_array(attrs), doc1.to_array(attrs))
+    assert a.patterns == a_reloaded.patterns
 
     # disk roundtrip
     with make_tempdir() as tmp_dir:
@@ -223,3 +224,4 @@ def test_attributeruler_serialize(nlp, pattern_dicts):
         doc2 = nlp2(text)
         assert nlp2.get_pipe("attribute_ruler").to_bytes() == a.to_bytes()
         assert numpy.array_equal(doc.to_array(attrs), doc2.to_array(attrs))
+        assert a.patterns == nlp2.get_pipe("attribute_ruler").patterns
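
For reference, a minimal round-trip sketch of the behavior the new test asserts.
It assumes the spaCy v3 API where the component is registered under the factory
name "attribute_ruler" and exposes add()/patterns/to_bytes()/from_bytes(); the
token pattern and attrs values are illustrative only, not part of this patch:

import spacy

# Assumes spaCy v3 with the "attribute_ruler" factory registered.
nlp = spacy.blank("en")
ruler = nlp.add_pipe("attribute_ruler")
# Illustrative rule: assign LEMMA "the" to tokens matching the pattern.
ruler.add(patterns=[[{"ORTH": "the"}]], attrs={"LEMMA": "the"})

# With this change, to_bytes() serializes `ruler.patterns` as a single msgpack
# blob and from_bytes() replays it through add_patterns(), so the reloaded
# component reports the patterns exactly as they were originally provided.
reloaded = spacy.blank("en").add_pipe("attribute_ruler")
reloaded.from_bytes(ruler.to_bytes())
assert ruler.patterns == reloaded.patterns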