mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-13 01:32:32 +03:00
Merge pull request #5995 from adrianeboyd/bugfix/attribute-ruler-bugfixes
This commit is contained in:
commit
f45095a666
|
@ -78,7 +78,7 @@ class AttributeRuler(Pipe):
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/attributeruler#call
|
DOCS: https://spacy.io/api/attributeruler#call
|
||||||
"""
|
"""
|
||||||
matches = self.matcher(doc)
|
matches = sorted(self.matcher(doc))
|
||||||
|
|
||||||
for match_id, start, end in matches:
|
for match_id, start, end in matches:
|
||||||
span = Span(doc, start, end, label=match_id)
|
span = Span(doc, start, end, label=match_id)
|
||||||
|
@ -230,10 +230,7 @@ class AttributeRuler(Pipe):
|
||||||
"""
|
"""
|
||||||
serialize = {}
|
serialize = {}
|
||||||
serialize["vocab"] = self.vocab.to_bytes
|
serialize["vocab"] = self.vocab.to_bytes
|
||||||
patterns = {k: self.matcher.get(k)[1] for k in range(len(self.attrs))}
|
serialize["patterns"] = lambda: srsly.msgpack_dumps(self.patterns)
|
||||||
serialize["patterns"] = lambda: srsly.msgpack_dumps(patterns)
|
|
||||||
serialize["attrs"] = lambda: srsly.msgpack_dumps(self.attrs)
|
|
||||||
serialize["indices"] = lambda: srsly.msgpack_dumps(self.indices)
|
|
||||||
return util.to_bytes(serialize, exclude)
|
return util.to_bytes(serialize, exclude)
|
||||||
|
|
||||||
def from_bytes(self, bytes_data: bytes, exclude: Iterable[str] = tuple()):
|
def from_bytes(self, bytes_data: bytes, exclude: Iterable[str] = tuple()):
|
||||||
|
@ -245,31 +242,15 @@ class AttributeRuler(Pipe):
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/attributeruler#from_bytes
|
DOCS: https://spacy.io/api/attributeruler#from_bytes
|
||||||
"""
|
"""
|
||||||
data = {"patterns": b""}
|
|
||||||
|
|
||||||
def load_patterns(b):
|
def load_patterns(b):
|
||||||
data["patterns"] = srsly.msgpack_loads(b)
|
self.add_patterns(srsly.msgpack_loads(b))
|
||||||
|
|
||||||
def load_attrs(b):
|
|
||||||
self.attrs = srsly.msgpack_loads(b)
|
|
||||||
|
|
||||||
def load_indices(b):
|
|
||||||
self.indices = srsly.msgpack_loads(b)
|
|
||||||
|
|
||||||
deserialize = {
|
deserialize = {
|
||||||
"vocab": lambda b: self.vocab.from_bytes(b),
|
"vocab": lambda b: self.vocab.from_bytes(b),
|
||||||
"patterns": load_patterns,
|
"patterns": load_patterns,
|
||||||
"attrs": load_attrs,
|
|
||||||
"indices": load_indices,
|
|
||||||
}
|
}
|
||||||
util.from_bytes(bytes_data, deserialize, exclude)
|
util.from_bytes(bytes_data, deserialize, exclude)
|
||||||
|
|
||||||
if data["patterns"]:
|
|
||||||
for key, pattern in data["patterns"].items():
|
|
||||||
self.matcher.add(key, pattern)
|
|
||||||
assert len(self.attrs) == len(data["patterns"])
|
|
||||||
assert len(self.indices) == len(data["patterns"])
|
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def to_disk(self, path: Union[Path, str], exclude: Iterable[str] = tuple()) -> None:
|
def to_disk(self, path: Union[Path, str], exclude: Iterable[str] = tuple()) -> None:
|
||||||
|
@ -279,12 +260,9 @@ class AttributeRuler(Pipe):
|
||||||
exclude (Iterable[str]): String names of serialization fields to exclude.
|
exclude (Iterable[str]): String names of serialization fields to exclude.
|
||||||
DOCS: https://spacy.io/api/attributeruler#to_disk
|
DOCS: https://spacy.io/api/attributeruler#to_disk
|
||||||
"""
|
"""
|
||||||
patterns = {k: self.matcher.get(k)[1] for k in range(len(self.attrs))}
|
|
||||||
serialize = {
|
serialize = {
|
||||||
"vocab": lambda p: self.vocab.to_disk(p),
|
"vocab": lambda p: self.vocab.to_disk(p),
|
||||||
"patterns": lambda p: srsly.write_msgpack(p, patterns),
|
"patterns": lambda p: srsly.write_msgpack(p, self.patterns),
|
||||||
"attrs": lambda p: srsly.write_msgpack(p, self.attrs),
|
|
||||||
"indices": lambda p: srsly.write_msgpack(p, self.indices),
|
|
||||||
}
|
}
|
||||||
util.to_disk(path, serialize, exclude)
|
util.to_disk(path, serialize, exclude)
|
||||||
|
|
||||||
|
@ -297,31 +275,15 @@ class AttributeRuler(Pipe):
|
||||||
exclude (Iterable[str]): String names of serialization fields to exclude.
|
exclude (Iterable[str]): String names of serialization fields to exclude.
|
||||||
DOCS: https://spacy.io/api/attributeruler#from_disk
|
DOCS: https://spacy.io/api/attributeruler#from_disk
|
||||||
"""
|
"""
|
||||||
data = {"patterns": b""}
|
|
||||||
|
|
||||||
def load_patterns(p):
|
def load_patterns(p):
|
||||||
data["patterns"] = srsly.read_msgpack(p)
|
self.add_patterns(srsly.read_msgpack(p))
|
||||||
|
|
||||||
def load_attrs(p):
|
|
||||||
self.attrs = srsly.read_msgpack(p)
|
|
||||||
|
|
||||||
def load_indices(p):
|
|
||||||
self.indices = srsly.read_msgpack(p)
|
|
||||||
|
|
||||||
deserialize = {
|
deserialize = {
|
||||||
"vocab": lambda p: self.vocab.from_disk(p),
|
"vocab": lambda p: self.vocab.from_disk(p),
|
||||||
"patterns": load_patterns,
|
"patterns": load_patterns,
|
||||||
"attrs": load_attrs,
|
|
||||||
"indices": load_indices,
|
|
||||||
}
|
}
|
||||||
util.from_disk(path, deserialize, exclude)
|
util.from_disk(path, deserialize, exclude)
|
||||||
|
|
||||||
if data["patterns"]:
|
|
||||||
for key, pattern in data["patterns"].items():
|
|
||||||
self.matcher.add(key, pattern)
|
|
||||||
assert len(self.attrs) == len(data["patterns"])
|
|
||||||
assert len(self.indices) == len(data["patterns"])
|
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -112,6 +112,28 @@ def test_attributeruler_score(nlp, pattern_dicts):
|
||||||
assert scores["morph_acc"] == pytest.approx(0.6)
|
assert scores["morph_acc"] == pytest.approx(0.6)
|
||||||
|
|
||||||
|
|
||||||
|
def test_attributeruler_rule_order(nlp):
|
||||||
|
a = AttributeRuler(nlp.vocab)
|
||||||
|
patterns = [
|
||||||
|
{
|
||||||
|
"patterns": [[{"TAG": "VBZ"}]],
|
||||||
|
"attrs": {"POS": "VERB"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"patterns": [[{"TAG": "VBZ"}]],
|
||||||
|
"attrs": {"POS": "NOUN"},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
a.add_patterns(patterns)
|
||||||
|
doc = get_doc(
|
||||||
|
nlp.vocab,
|
||||||
|
words=["This", "is", "a", "test", "."],
|
||||||
|
tags=["DT", "VBZ", "DT", "NN", "."]
|
||||||
|
)
|
||||||
|
doc = a(doc)
|
||||||
|
assert doc[1].pos_ == "NOUN"
|
||||||
|
|
||||||
|
|
||||||
def test_attributeruler_tag_map(nlp, tag_map):
|
def test_attributeruler_tag_map(nlp, tag_map):
|
||||||
a = AttributeRuler(nlp.vocab)
|
a = AttributeRuler(nlp.vocab)
|
||||||
a.load_from_tag_map(tag_map)
|
a.load_from_tag_map(tag_map)
|
||||||
|
@ -215,6 +237,7 @@ def test_attributeruler_serialize(nlp, pattern_dicts):
|
||||||
assert a.to_bytes() == a_reloaded.to_bytes()
|
assert a.to_bytes() == a_reloaded.to_bytes()
|
||||||
doc1 = a_reloaded(nlp.make_doc(text))
|
doc1 = a_reloaded(nlp.make_doc(text))
|
||||||
numpy.array_equal(doc.to_array(attrs), doc1.to_array(attrs))
|
numpy.array_equal(doc.to_array(attrs), doc1.to_array(attrs))
|
||||||
|
assert a.patterns == a_reloaded.patterns
|
||||||
|
|
||||||
# disk roundtrip
|
# disk roundtrip
|
||||||
with make_tempdir() as tmp_dir:
|
with make_tempdir() as tmp_dir:
|
||||||
|
@ -223,3 +246,4 @@ def test_attributeruler_serialize(nlp, pattern_dicts):
|
||||||
doc2 = nlp2(text)
|
doc2 = nlp2(text)
|
||||||
assert nlp2.get_pipe("attribute_ruler").to_bytes() == a.to_bytes()
|
assert nlp2.get_pipe("attribute_ruler").to_bytes() == a.to_bytes()
|
||||||
assert numpy.array_equal(doc.to_array(attrs), doc2.to_array(attrs))
|
assert numpy.array_equal(doc.to_array(attrs), doc2.to_array(attrs))
|
||||||
|
assert a.patterns == nlp2.get_pipe("attribute_ruler").patterns
|
||||||
|
|
Loading…
Reference in New Issue
Block a user