diff --git a/spacy/pipeline/attributeruler.py b/spacy/pipeline/attributeruler.py
index d93afc642..4d1816238 100644
--- a/spacy/pipeline/attributeruler.py
+++ b/spacy/pipeline/attributeruler.py
@@ -78,7 +78,7 @@ class AttributeRuler(Pipe):
 
         DOCS: https://spacy.io/api/attributeruler#call
         """
-        matches = self.matcher(doc)
+        matches = sorted(self.matcher(doc))
         for match_id, start, end in matches:
             span = Span(doc, start, end, label=match_id)
             attrs = self.attrs[span.label]
@@ -230,10 +230,7 @@ class AttributeRuler(Pipe):
         """
         serialize = {}
         serialize["vocab"] = self.vocab.to_bytes
-        patterns = {k: self.matcher.get(k)[1] for k in range(len(self.attrs))}
-        serialize["patterns"] = lambda: srsly.msgpack_dumps(patterns)
-        serialize["attrs"] = lambda: srsly.msgpack_dumps(self.attrs)
-        serialize["indices"] = lambda: srsly.msgpack_dumps(self.indices)
+        serialize["patterns"] = lambda: srsly.msgpack_dumps(self.patterns)
         return util.to_bytes(serialize, exclude)
 
     def from_bytes(self, bytes_data: bytes, exclude: Iterable[str] = tuple()):
@@ -245,31 +242,15 @@ class AttributeRuler(Pipe):
 
         DOCS: https://spacy.io/api/attributeruler#from_bytes
         """
-        data = {"patterns": b""}
-
         def load_patterns(b):
-            data["patterns"] = srsly.msgpack_loads(b)
-
-        def load_attrs(b):
-            self.attrs = srsly.msgpack_loads(b)
-
-        def load_indices(b):
-            self.indices = srsly.msgpack_loads(b)
+            self.add_patterns(srsly.msgpack_loads(b))
 
         deserialize = {
             "vocab": lambda b: self.vocab.from_bytes(b),
             "patterns": load_patterns,
-            "attrs": load_attrs,
-            "indices": load_indices,
         }
         util.from_bytes(bytes_data, deserialize, exclude)
 
-        if data["patterns"]:
-            for key, pattern in data["patterns"].items():
-                self.matcher.add(key, pattern)
-            assert len(self.attrs) == len(data["patterns"])
-            assert len(self.indices) == len(data["patterns"])
-
         return self
 
     def to_disk(self, path: Union[Path, str], exclude: Iterable[str] = tuple()) -> None:
@@ -279,12 +260,9 @@ class AttributeRuler(Pipe):
         exclude (Iterable[str]): String names of serialization fields to exclude.
 
         DOCS: https://spacy.io/api/attributeruler#to_disk
         """
-        patterns = {k: self.matcher.get(k)[1] for k in range(len(self.attrs))}
         serialize = {
             "vocab": lambda p: self.vocab.to_disk(p),
-            "patterns": lambda p: srsly.write_msgpack(p, patterns),
-            "attrs": lambda p: srsly.write_msgpack(p, self.attrs),
-            "indices": lambda p: srsly.write_msgpack(p, self.indices),
+            "patterns": lambda p: srsly.write_msgpack(p, self.patterns),
         }
         util.to_disk(path, serialize, exclude)
@@ -297,31 +275,15 @@ class AttributeRuler(Pipe):
         exclude (Iterable[str]): String names of serialization fields to exclude.
 
         DOCS: https://spacy.io/api/attributeruler#from_disk
        """
-        data = {"patterns": b""}
-
         def load_patterns(p):
-            data["patterns"] = srsly.read_msgpack(p)
-
-        def load_attrs(p):
-            self.attrs = srsly.read_msgpack(p)
-
-        def load_indices(p):
-            self.indices = srsly.read_msgpack(p)
+            self.add_patterns(srsly.read_msgpack(p))
 
         deserialize = {
             "vocab": lambda p: self.vocab.from_disk(p),
             "patterns": load_patterns,
-            "attrs": load_attrs,
-            "indices": load_indices,
         }
         util.from_disk(path, deserialize, exclude)
 
-        if data["patterns"]:
-            for key, pattern in data["patterns"].items():
-                self.matcher.add(key, pattern)
-            assert len(self.attrs) == len(data["patterns"])
-            assert len(self.indices) == len(data["patterns"])
-
         return self
 
diff --git a/spacy/tests/pipeline/test_attributeruler.py b/spacy/tests/pipeline/test_attributeruler.py
index d9a492580..5fc6faaf0 100644
--- a/spacy/tests/pipeline/test_attributeruler.py
+++ b/spacy/tests/pipeline/test_attributeruler.py
@@ -112,6 +112,28 @@ def test_attributeruler_score(nlp, pattern_dicts):
     assert scores["morph_acc"] == pytest.approx(0.6)
 
 
+def test_attributeruler_rule_order(nlp):
+    a = AttributeRuler(nlp.vocab)
+    patterns = [
+        {
+            "patterns": [[{"TAG": "VBZ"}]],
+            "attrs": {"POS": "VERB"},
+        },
+        {
+            "patterns": [[{"TAG": "VBZ"}]],
+            "attrs": {"POS": "NOUN"},
+        },
+    ]
+    a.add_patterns(patterns)
+    doc = get_doc(
+        nlp.vocab,
+        words=["This", "is", "a", "test", "."],
+        tags=["DT", "VBZ", "DT", "NN", "."]
+    )
+    doc = a(doc)
+    assert doc[1].pos_ == "NOUN"
+
+
 def test_attributeruler_tag_map(nlp, tag_map):
     a = AttributeRuler(nlp.vocab)
     a.load_from_tag_map(tag_map)
@@ -215,6 +237,7 @@ def test_attributeruler_serialize(nlp, pattern_dicts):
     assert a.to_bytes() == a_reloaded.to_bytes()
     doc1 = a_reloaded(nlp.make_doc(text))
     numpy.array_equal(doc.to_array(attrs), doc1.to_array(attrs))
+    assert a.patterns == a_reloaded.patterns
 
     # disk roundtrip
     with make_tempdir() as tmp_dir:
@@ -223,3 +246,4 @@
         doc2 = nlp2(text)
         assert nlp2.get_pipe("attribute_ruler").to_bytes() == a.to_bytes()
         assert numpy.array_equal(doc.to_array(attrs), doc2.to_array(attrs))
+        assert a.patterns == nlp2.get_pipe("attribute_ruler").patterns
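
Not part of the patch, but for reviewers: a minimal sketch of the ordering guarantee the new sorted() call in __call__ buys, mirroring test_attributeruler_rule_order above. The blank English pipeline and the manual tag_ assignment are illustrative stand-ins for the test's nlp fixture and get_doc helper, and assume a spaCy v3 build with this patch applied.

    import spacy
    from spacy.pipeline import AttributeRuler

    nlp = spacy.blank("en")
    ruler = AttributeRuler(nlp.vocab)
    # Two rules match the same token; with matches sorted in __call__,
    # the later-added rule is applied last and wins.
    ruler.add_patterns([
        {"patterns": [[{"TAG": "VBZ"}]], "attrs": {"POS": "VERB"}},
        {"patterns": [[{"TAG": "VBZ"}]], "attrs": {"POS": "NOUN"}},
    ])

    doc = nlp.make_doc("This is a test.")
    for token, tag in zip(doc, ["DT", "VBZ", "DT", "NN", "."]):
        token.tag_ = tag  # stand-in for the test's get_doc(..., tags=...)
    doc = ruler(doc)
    assert doc[1].pos_ == "NOUN"  # the later rule overrode POS=VERB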
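Likewise illustrative (not in the diff): serialization now reduces to the single patterns property, with from_bytes feeding it back through add_patterns so the reloaded ruler rebuilds its matcher and internal attrs/indices tables instead of deserializing them as separate fields. Same assumptions as the sketch above.

    import spacy
    from spacy.pipeline import AttributeRuler

    nlp = spacy.blank("en")
    ruler = AttributeRuler(nlp.vocab)
    ruler.add_patterns([{"patterns": [[{"TAG": "NN"}]], "attrs": {"POS": "NOUN"}}])

    # from_bytes routes the stored patterns through add_patterns, so the
    # round trip is exact: same patterns, same serialized bytes.
    reloaded = AttributeRuler(nlp.vocab).from_bytes(ruler.to_bytes())
    assert reloaded.patterns == ruler.patterns
    assert reloaded.to_bytes() == ruler.to_bytes()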