mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-13 13:17:06 +03:00
e8420ab2b7
* Preserve flags in EntityRuler The EntityRuler (explosion/spaCy#3526) does not preserve overwrite flags (or `ent_id_sep`) when serialized. This commit adds support for serialization/deserialization preserving overwrite and ent_id_sep flags. * add signed contributor agreement * flake8 cleanup mostly blank line issues. * mark test from the issue as needing a model The test from the issue needs some language model for serialization but the test wasn't originally marked correctly. * remove unneeded model loading The model didn't need to be loaded, and I replaced it with a change that doesn't require it (using existing fixtures) * change tempdir handling to be compatible with python 2.7 * Adds code to handle items saved before this change. This code changes how the save files are handled and how the bytes are stored as well. This code adds checks to dispatch correctly if it encounters bytes or files saved in the old format (and tests for those cases). * use util function for tempdir management Updated after PR comments: this code now uses the make_tempdir function from util instead of doing it by hand.
83 lines
2.9 KiB
Python
83 lines
2.9 KiB
Python
# coding: utf8
|
|
from __future__ import unicode_literals
|
|
|
|
import pytest
|
|
from spacy.tokens import Span
|
|
from spacy.language import Language
|
|
from spacy.pipeline import EntityRuler
|
|
from spacy import load
|
|
import srsly
|
|
from ..util import make_tempdir
|
|
|
|
@pytest.fixture
def patterns():
    """Return the shared list of EntityRuler patterns used by these tests.

    The list mixes phrase patterns (plain strings) with token patterns
    (lists of attribute dicts); one entry also carries an "id".
    """
    hello_phrase = {"label": "HELLO", "pattern": "hello world"}
    bye_tokens = {"label": "BYE", "pattern": [{"LOWER": "bye"}, {"LOWER": "bye"}]}
    hello_token = {"label": "HELLO", "pattern": [{"ORTH": "HELLO"}]}
    complex_tokens = {"label": "COMPLEX", "pattern": [{"ORTH": "foo", "OP": "*"}]}
    tech_org_phrase = {"label": "TECH_ORG", "pattern": "Apple", "id": "a1"}
    return [hello_phrase, bye_tokens, hello_token, complex_tokens, tech_org_phrase]
|
|
|
|
|
|
@pytest.fixture
def add_ent():
    """Return a pipeline component that sets ``doc.ents`` to a single ORG span."""

    def add_ent_component(doc):
        # Resolve the label through the doc's own string store, then replace
        # the doc's entities with one Span(doc, 0, 3) carrying that label.
        org_label = doc.vocab.strings["ORG"]
        doc.ents = [Span(doc, 0, 3, label=org_label)]
        return doc

    return add_ent_component
|
|
|
|
def test_entity_ruler_existing_overwrite_serialize_bytes(patterns, en_vocab):
    """Check a to_bytes/from_bytes round trip keeps patterns and flags."""
    nlp = Language(vocab=en_vocab)
    original = EntityRuler(nlp, patterns=patterns, overwrite_ents=True)
    serialized = original.to_bytes()

    # Sanity-check the source ruler before comparing against the copy.
    assert len(original) == len(patterns)
    assert len(original.labels) == 4
    assert original.overwrite

    # A fresh ruler is created without the overwrite flag; it must be
    # restored from the serialized payload.
    restored = EntityRuler(nlp).from_bytes(serialized)
    assert len(restored) == len(original)
    assert len(restored.labels) == 4
    assert restored.overwrite == original.overwrite
    assert restored.ent_id_sep == original.ent_id_sep
|
|
|
|
|
|
def test_entity_ruler_existing_bytes_old_format_safe(patterns, en_vocab):
    """Check from_bytes still accepts a payload holding only the patterns."""
    nlp = Language(vocab=en_vocab)
    original = EntityRuler(nlp, patterns=patterns, overwrite_ents=True)

    # Old-style payload: just the msgpack-dumped pattern list, no flags.
    legacy_payload = srsly.msgpack_dumps(original.patterns)

    restored = EntityRuler(nlp).from_bytes(legacy_payload)
    assert len(restored) == len(original)
    assert restored.patterns == original.patterns
    # The payload carries no flags, so the restored ruler's overwrite value
    # is expected to differ from the source ruler's.
    assert original.overwrite is not restored.overwrite
|
|
|
|
|
|
def test_entity_ruler_from_disk_old_format_safe(patterns, en_vocab):
    """Check from_disk still accepts a bare JSONL file of patterns."""
    nlp = Language(vocab=en_vocab)
    original = EntityRuler(nlp, patterns=patterns, overwrite_ents=True)

    with make_tempdir() as tmpdir:
        # Old-style save: only the patterns, written as a single .jsonl file.
        legacy_path = tmpdir / "entity_ruler.jsonl"
        srsly.write_jsonl(legacy_path, original.patterns)

        restored = EntityRuler(nlp).from_disk(legacy_path)
        assert restored.patterns == original.patterns
        assert len(restored) == len(original)
        # The file carries no flags, so the restored ruler's overwrite value
        # is expected to differ from the source ruler's.
        assert original.overwrite is not restored.overwrite
|
|
|
|
|
|
def test_entity_ruler_in_pipeline_from_issue(patterns, en_vocab):
    """Check an EntityRuler in a saved pipeline reloads with patterns and flag."""
    expected_patterns = [{"label": "ORG", "pattern": "Apple"}]

    nlp = Language(vocab=en_vocab)
    ruler = EntityRuler(nlp, overwrite_ents=True)
    ruler.add_patterns([{"label": "ORG", "pattern": "Apple"}])
    nlp.add_pipe(ruler)

    with make_tempdir() as tmpdir:
        nlp.to_disk(tmpdir)

        # The ruler was the last component added; saving must not alter it.
        saved_component = nlp.pipeline[-1][-1]
        assert saved_component.patterns == expected_patterns
        assert saved_component.overwrite is True

        # Reload the whole pipeline from disk and inspect its last component.
        nlp2 = load(tmpdir)
        loaded_component = nlp2.pipeline[-1][-1]
        assert loaded_component.patterns == expected_patterns
        assert loaded_component.overwrite is True
|