spaCy/spacy/tests/regression/test_issue3526.py
Joshua Smith 2eb925bd05 Added an argument to EntityRuler constructor to pass attrs to… (#3919)
* Preserve flags in EntityRuler

The EntityRuler (explosion/spaCy#3526) does not preserve
overwrite flags (or `ent_id_sep`) when serialized.  This
commit adds support for serialization/deserialization preserving
overwrite and ent_id_sep flags.

* add signed contributor agreement

* flake8 cleanup

mostly blank line issues.

* mark test from the issue as needing a model

The test from the issue needs some language model for serialization
but the test wasn't originally marked correctly.

* Adds `phrase_matcher_attr` to allow args to PhraseMatcher

This is an added arg to pass to the `PhraseMatcher`. For example,
this allows creation of a case insensitive phrase matcher when the
`EntityRuler` is created.  References explosion/spaCy#3822

* remove unneeded model loading

The model didn't need to be loaded, and I replaced it with
a change that doesn't require it (using existing fixtures)

* updated docstring for new argument

* updated docs to reflect new argument to the EntityRuler constructor

* change tempdir handling to be compatible with python 2.7

* return conflicted code to entityruler

Some stuff got cut out because of merge conflicts, this
returns that code for the phrase_matcher_attr.

* fixed typo in the code added back after conflicts

* flake8 compliance

When I deconflicted the branch there were some flake8 issues
introduced. This resolves the spacing problems.

* test changes:  attempts to fix flaky test in python3.5

These tests seem to be a little flaky in 3.5, so I changed the check to avoid
the comparisons that seem to fail sometimes.
2019-07-09 20:09:17 +02:00

87 lines
3.0 KiB
Python

# coding: utf8
from __future__ import unicode_literals
import pytest
from spacy.tokens import Span
from spacy.language import Language
from spacy.pipeline import EntityRuler
from spacy import load
import srsly
from ..util import make_tempdir
@pytest.fixture
def patterns():
    """Provide the EntityRuler patterns shared by the tests in this module.

    The list holds five entries spread over four distinct labels: two phrase
    (string) patterns, three token-based patterns, and one entry that also
    carries an ``id``.
    """
    # Token-level patterns are named separately so the final list stays flat.
    bye_tokens = [{"LOWER": "bye"}, {"LOWER": "bye"}]
    hello_tokens = [{"ORTH": "HELLO"}]
    complex_tokens = [{"ORTH": "foo", "OP": "*"}]
    return [
        {"label": "HELLO", "pattern": "hello world"},
        {"label": "BYE", "pattern": bye_tokens},
        {"label": "HELLO", "pattern": hello_tokens},
        {"label": "COMPLEX", "pattern": complex_tokens},
        {"label": "TECH_ORG", "pattern": "Apple", "id": "a1"},
    ]
@pytest.fixture
def add_ent():
    """Provide a pipeline-style callable that tags tokens 0-3 as ``ORG``.

    The returned function takes a doc, replaces ``doc.ents`` with a single
    span over ``doc[0:3]`` labelled ``ORG``, and returns the same doc.
    """

    def add_ent_component(doc):
        # Resolve the label through the doc's own string store.
        org_label = doc.vocab.strings["ORG"]
        org_span = Span(doc, 0, 3, label=org_label)
        doc.ents = [org_span]
        return doc

    return add_ent_component
def test_entity_ruler_existing_overwrite_serialize_bytes(patterns, en_vocab):
    """Round-trip an EntityRuler through ``to_bytes`` / ``from_bytes``.

    The restored ruler must report the same pattern count, the same number of
    labels, and the same ``overwrite`` and ``ent_id_sep`` values as the ruler
    it was serialized from.
    """
    expected_label_count = 4
    pipeline = Language(vocab=en_vocab)
    original = EntityRuler(pipeline, patterns=patterns, overwrite_ents=True)
    serialized = original.to_bytes()

    # Sanity-check the source ruler before looking at the restored copy.
    assert len(original) == len(patterns)
    assert len(original.labels) == expected_label_count
    assert original.overwrite

    # Deserialize into a freshly constructed ruler (no patterns, no flags).
    restored = EntityRuler(pipeline).from_bytes(serialized)
    assert len(restored) == len(original)
    assert len(restored.labels) == expected_label_count
    assert restored.overwrite == original.overwrite
    assert restored.ent_id_sep == original.ent_id_sep
def test_entity_ruler_existing_bytes_old_format_safe(patterns, en_vocab):
    """Load bytes that contain only the msgpack-dumped pattern list.

    The payload carries the patterns but no ``overwrite`` setting, so the
    loaded ruler must hold every pattern while its ``overwrite`` value differs
    from the source ruler's.
    """
    pipeline = Language(vocab=en_vocab)
    original = EntityRuler(pipeline, patterns=patterns, overwrite_ents=True)
    # Patterns-only payload: just the list, without any ruler settings.
    patterns_only_bytes = srsly.msgpack_dumps(original.patterns)

    restored = EntityRuler(pipeline).from_bytes(patterns_only_bytes)
    assert len(restored) == len(original)
    restored_patterns = restored.patterns
    assert all(entry in restored_patterns for entry in original.patterns)
    assert restored.overwrite is not original.overwrite
def test_entity_ruler_from_disk_old_format_safe(patterns, en_vocab):
    """Load a ruler from a JSONL file that contains only the patterns.

    The file is written with ``srsly.write_jsonl`` and holds no ``overwrite``
    setting, so the loaded ruler must hold every pattern while its
    ``overwrite`` value differs from the source ruler's.
    """
    pipeline = Language(vocab=en_vocab)
    original = EntityRuler(pipeline, patterns=patterns, overwrite_ents=True)
    with make_tempdir() as tmpdir:
        jsonl_path = tmpdir / "entity_ruler.jsonl"
        srsly.write_jsonl(jsonl_path, original.patterns)
        # Load while the temporary directory still exists.
        restored = EntityRuler(pipeline).from_disk(jsonl_path)
        restored_patterns = restored.patterns
        assert all(entry in restored_patterns for entry in original.patterns)
        assert len(restored) == len(original)
        assert restored.overwrite is not original.overwrite
def test_entity_ruler_in_pipeline_from_issue(patterns, en_vocab):
    """Save a pipeline ending in an EntityRuler and load it back.

    After ``nlp.to_disk`` and ``load``, the last pipeline component must
    still expose the single ``ORG`` pattern and ``overwrite`` set to True.
    """
    expected_patterns = [{"label": "ORG", "pattern": "Apple"}]
    nlp = Language(vocab=en_vocab)
    ruler = EntityRuler(nlp, overwrite_ents=True)
    ruler.add_patterns([{"label": "ORG", "pattern": "Apple"}])
    nlp.add_pipe(ruler)
    with make_tempdir() as tmpdir:
        nlp.to_disk(tmpdir)
        # nlp.pipeline[-1][-1] is the last element of the last pipeline entry.
        saved_component = nlp.pipeline[-1][-1]
        assert saved_component.patterns == expected_patterns
        assert saved_component.overwrite is True
        # Reload while the temporary directory still exists.
        nlp2 = load(tmpdir)
        loaded_component = nlp2.pipeline[-1][-1]
        assert loaded_component.patterns == expected_patterns
        assert loaded_component.overwrite is True