spaCy/spacy/tests/pipeline/test_entity_ruler.py
Adriane Boyd f94168a41e
Backport bugfixes from v3.1.0 to v3.0 (#8739)
* Fix scoring normalization (#7629)

* fix scoring normalization

* score weights by total sum instead of per component

* cleanup

* more cleanup

* Use a context manager when reading model (fix #7036) (#8244)

* Fix other open calls without context managers (#8245)

* Don't add duplicate patterns all the time in EntityRuler (fix #8216) (#8246)

* Don't add duplicate patterns (fix #8216)

* Refactor EntityRuler init

This simplifies the EntityRuler init code. This is helpful as prep for
allowing the EntityRuler to reset itself.

* Make EntityRuler.clear reset matchers

Includes a new test for this.

* Tidy PhraseMatcher instantiation

Since the attr can now safely be None, the guarding `if` is no longer
required here.

Also renamed the `_validate` attr. Maybe it's not needed?

* Fix NER test

* Add test to make sure patterns aren't increasing

* Move test to regression tests

* Exclude generated .cpp files from package (#8271)

* Fix non-deterministic deduplication in Greek lemmatizer (#8421)

* Fix setting empty entities in Example.from_dict (#8426)

* Filter W036 for entity ruler, etc. (#8424)

* Preserve paths.vectors/initialize.vectors setting in quickstart template

* Various fixes for spans in Docs.from_docs (#8487)

* Fix spans offsets if a doc ends in a single space and no space is
  inserted
* Also include spans key in merged doc for empty spans lists

* Fix duplicate spacy package CLI opts (#8551)

Use `-c` for `--code` and not additionally for `--create-meta`, in line
with the docs.

* Raise an error for textcat with <2 labels (#8584)

* Raise an error for textcat with <2 labels

Raise an error if initializing a `textcat` component without at least
two labels.

* Add similar note to docs

* Update positive_label description in API docs

* Add Macedonian models to website (#8637)

* Fix Azerbaijani init, extend lang init tests (#8656)

* Extend langs in initialize tests

* Fix az init

* Fix ru/uk lemmatizer mp with spawn (#8657)

Use an instance variable instead of a class variable for the morphological
analyzer so that multiprocessing with spawn is possible.

* Use 0-vector for OOV lexemes (#8639)

* Set version to v3.0.7

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
Co-authored-by: Paul O'Leary McCann <polm@dampfkraft.com>
2021-07-19 09:20:40 +02:00

241 lines
7.9 KiB
Python

import pytest
from spacy import registry
from spacy.tokens import Span
from spacy.language import Language
from spacy.pipeline import EntityRuler
from spacy.errors import MatchPatternError
from thinc.api import NumpyOps, get_current_ops
@pytest.fixture
def nlp():
    """Provide a blank ``Language`` object for each test that requests it."""
    blank_pipeline = Language()
    return blank_pipeline
@pytest.fixture
@registry.misc("entity_ruler_patterns")
def patterns():
    """Return the EntityRuler patterns shared by the tests in this module.

    The list mixes phrase patterns (plain strings) and token patterns (lists
    of token dicts) across four labels; the two TECH_ORG patterns carry ids.
    The function is also registered in the ``misc`` registry as
    ``"entity_ruler_patterns"`` so a config can refer to it with
    ``{"@misc": "entity_ruler_patterns"}``.
    """
    shared_patterns = [
        {"label": "HELLO", "pattern": "hello world"},
        {"label": "BYE", "pattern": [{"LOWER": "bye"}, {"LOWER": "bye"}]},
        {"label": "HELLO", "pattern": [{"ORTH": "HELLO"}]},
        {"label": "COMPLEX", "pattern": [{"ORTH": "foo", "OP": "*"}]},
        {"label": "TECH_ORG", "pattern": "Apple", "id": "a1"},
        {"label": "TECH_ORG", "pattern": "Microsoft", "id": "a2"},
    ]
    return shared_patterns
@Language.component("add_ent")
def add_ent_component(doc):
    """Pipeline component that replaces ``doc.ents`` with a single ORG span.

    The span covers tokens 0 up to (not including) 3; the docs these tests
    feed through it all have more than three tokens.
    """
    org_span = Span(doc, 0, 3, label="ORG")
    doc.ents = [org_span]
    return doc
def test_entity_ruler_init(nlp, patterns):
    """Patterns given to the constructor or via add_patterns are applied."""
    # A standalone ruler built straight from the constructor.
    standalone = EntityRuler(nlp, patterns=patterns)
    assert len(standalone) == len(patterns)
    assert len(standalone.labels) == 4
    for label in ("HELLO", "BYE"):
        assert label in standalone
    # A ruler added to the pipeline and populated afterwards.
    piped = nlp.add_pipe("entity_ruler")
    piped.add_patterns(patterns)
    parsed = nlp("hello world bye bye")
    assert len(parsed.ents) == 2
    assert [ent.label_ for ent in parsed.ents] == ["HELLO", "BYE"]
def test_entity_ruler_no_patterns_warns(nlp):
    """A ruler without patterns warns when applied and adds no entities."""
    empty_ruler = EntityRuler(nlp)
    assert (len(empty_ruler), len(empty_ruler.labels)) == (0, 0)
    nlp.add_pipe("entity_ruler")
    assert nlp.pipe_names == ["entity_ruler"]
    with pytest.warns(UserWarning):
        parsed = nlp("hello world bye bye")
    assert not parsed.ents
def test_entity_ruler_init_patterns(nlp, patterns):
    """Patterns can be loaded at initialization, directly or from the registry."""

    def check_first_two_labels():
        # The first two entities of the sample text come from the HELLO and
        # BYE patterns.
        parsed = nlp("hello world bye bye")
        first, second = parsed.ents[0], parsed.ents[1]
        assert first.label_ == "HELLO"
        assert second.label_ == "BYE"

    # Patterns passed explicitly to ``initialize``.
    ruler = nlp.add_pipe("entity_ruler")
    assert len(ruler.labels) == 0
    ruler.initialize(lambda: [], patterns=patterns)
    assert len(ruler.labels) == 4
    check_first_two_labels()
    nlp.remove_pipe("entity_ruler")

    # Patterns resolved from the misc registry through the [initialize] config.
    nlp.config["initialize"]["components"]["entity_ruler"] = {
        "patterns": {"@misc": "entity_ruler_patterns"}
    }
    ruler = nlp.add_pipe("entity_ruler")
    assert len(ruler.labels) == 0
    nlp.initialize()
    assert len(ruler.labels) == 4
    check_first_two_labels()
def test_entity_ruler_init_clear(nlp, patterns):
    """Calling ``initialize`` without patterns drops previously added ones."""

    def no_examples():
        return []

    er = nlp.add_pipe("entity_ruler")
    er.add_patterns(patterns)
    assert len(er.labels) == 4
    er.initialize(no_examples)
    assert len(er.labels) == 0
def test_entity_ruler_clear(nlp, patterns):
    """Test that ``EntityRuler.clear`` removes all patterns and resets matchers.

    After clearing, the ruler reports no labels and no patterns. It warns when
    applied because it is empty, and no longer produces the entity it matched
    before the reset.
    """
    ruler = nlp.add_pipe("entity_ruler")
    ruler.add_patterns(patterns)
    assert len(ruler.labels) == 4
    doc = nlp("hello world")
    assert len(doc.ents) == 1
    ruler.clear()
    assert len(ruler.labels) == 0
    # The pattern count must be reset too, not just the label set.
    assert len(ruler) == 0
    # An empty ruler emits a UserWarning when it is applied to a doc.
    with pytest.warns(UserWarning):
        doc = nlp("hello world")
    # The matchers were reset, so the old "hello world" match is gone.
    assert len(doc.ents) == 0
def test_entity_ruler_existing(nlp, patterns):
    """Without overwrite_ents, entities set by an earlier component are kept."""
    er = nlp.add_pipe("entity_ruler")
    er.add_patterns(patterns)
    # "add_ent" runs first and claims the first three tokens as ORG.
    nlp.add_pipe("add_ent", before="entity_ruler")
    parsed = nlp("OH HELLO WORLD bye bye")
    assert [ent.label_ for ent in parsed.ents] == ["ORG", "BYE"]
def test_entity_ruler_existing_overwrite(nlp, patterns):
    """With overwrite_ents=True, ruler matches replace earlier entities."""
    er = nlp.add_pipe("entity_ruler", config={"overwrite_ents": True})
    er.add_patterns(patterns)
    nlp.add_pipe("add_ent", before="entity_ruler")
    parsed = nlp("OH HELLO WORLD bye bye")
    assert len(parsed.ents) == 2
    hello, bye = parsed.ents
    assert (hello.label_, hello.text) == ("HELLO", "HELLO")
    assert bye.label_ == "BYE"
def test_entity_ruler_existing_complex(nlp, patterns):
    """With overwrite_ents=True, the COMPLEX ("OP": "*") pattern wins.

    "foo foo" becomes one two-token COMPLEX entity that replaces the ORG
    span set by "add_ent".
    """
    overwrite_cfg = {"overwrite_ents": True}
    er = nlp.add_pipe("entity_ruler", config=overwrite_cfg)
    er.add_patterns(patterns)
    nlp.add_pipe("add_ent", before="entity_ruler")
    parsed = nlp("foo foo bye bye")
    label_and_size = [(ent.label_, len(ent)) for ent in parsed.ents]
    assert label_and_size == [("COMPLEX", 2), ("BYE", 2)]
def test_entity_ruler_entity_id(nlp, patterns):
    """A pattern's "id" is exposed on the matched entity as ``ent_id_``."""
    er = nlp.add_pipe("entity_ruler", config={"overwrite_ents": True})
    er.add_patterns(patterns)
    parsed = nlp("Apple is a technology company")
    assert len(parsed.ents) == 1
    (apple,) = parsed.ents
    assert (apple.label_, apple.ent_id_) == ("TECH_ORG", "a1")
def test_entity_ruler_cfg_ent_id_sep(nlp, patterns):
    """A custom ent_id_sep is used to key labels that carry an id."""
    er = nlp.add_pipe(
        "entity_ruler", config={"overwrite_ents": True, "ent_id_sep": "**"}
    )
    er.add_patterns(patterns)
    # The TECH_ORG/"a1" phrase pattern is keyed as label + sep + id.
    assert "TECH_ORG**a1" in er.phrase_patterns
    parsed = nlp("Apple is a technology company")
    assert len(parsed.ents) == 1
    (apple,) = parsed.ents
    assert apple.label_ == "TECH_ORG"
    assert apple.ent_id_ == "a1"
def test_entity_ruler_serialize_bytes(nlp, patterns):
    """A to_bytes/from_bytes round trip preserves patterns and labels."""
    source = EntityRuler(nlp, patterns=patterns)
    assert len(source) == len(patterns)
    assert len(source.labels) == 4
    payload = source.to_bytes()

    target = EntityRuler(nlp)
    assert len(target) == 0
    assert len(target.labels) == 0
    target = target.from_bytes(payload)
    assert len(target) == len(patterns)
    assert len(target.labels) == 4
    assert len(target.patterns) == len(source.patterns)
    assert all(pat in target.patterns for pat in source.patterns)
    assert sorted(target.labels) == sorted(source.labels)
def test_entity_ruler_serialize_phrase_matcher_attr_bytes(nlp, patterns):
    """phrase_matcher_attr survives a to_bytes/from_bytes round trip."""
    source = EntityRuler(nlp, phrase_matcher_attr="LOWER", patterns=patterns)
    assert len(source) == len(patterns)
    assert len(source.labels) == 4
    payload = source.to_bytes()

    # A fresh ruler starts empty and with no phrase matcher attribute set.
    target = EntityRuler(nlp)
    assert len(target) == 0
    assert len(target.labels) == 0
    assert target.phrase_matcher_attr is None

    target = target.from_bytes(payload)
    assert len(target) == len(patterns)
    assert len(target.labels) == 4
    assert target.phrase_matcher_attr == "LOWER"
def test_entity_ruler_validate(nlp):
    """Invalid token patterns are rejected, with MatchPatternError when validating."""
    plain_ruler = EntityRuler(nlp)
    strict_ruler = EntityRuler(nlp, validate=True)
    good = {"label": "HELLO", "pattern": [{"LOWER": "HELLO"}]}
    bad = {"label": "HELLO", "pattern": [{"ASDF": "HELLO"}]}
    # Without validation, the unknown "ASDF" attribute still raises a ValueError.
    with pytest.raises(ValueError):
        plain_ruler.add_patterns([bad])
    # A well-formed pattern is accepted by the validating ruler.
    strict_ruler.add_patterns([good])
    # With validation enabled, a MatchPatternError is raised instead.
    with pytest.raises(MatchPatternError):
        strict_ruler.add_patterns([bad])
def test_entity_ruler_properties(nlp, patterns):
    """The labels and ent_ids properties reflect the loaded patterns."""
    er = EntityRuler(nlp, patterns=patterns, overwrite_ents=True)
    expected_labels = sorted(["HELLO", "BYE", "COMPLEX", "TECH_ORG"])
    assert sorted(er.labels) == expected_labels
    assert sorted(er.ent_ids) == ["a1", "a2"]
def test_entity_ruler_overlapping_spans(nlp):
    """Overlapping phrase matches yield a single entity (FOOBAR here)."""
    er = EntityRuler(nlp)
    er.add_patterns(
        [
            {"label": "FOOBAR", "pattern": "foo bar"},
            {"label": "BARBAZ", "pattern": "bar baz"},
        ]
    )
    # Apply the ruler directly to a tokenized doc, bypassing the pipeline.
    parsed = er(nlp.make_doc("foo bar baz"))
    assert [ent.label_ for ent in parsed.ents] == ["FOOBAR"]
@pytest.mark.parametrize("n_process", [1, 2])
def test_entity_ruler_multiprocessing(nlp, n_process):
    """Entity ids set by the ruler survive ``nlp.pipe`` with ``n_process`` workers.

    The multi-process case only runs with NumpyOps as the current ops;
    otherwise it is skipped explicitly rather than passing silently.
    """
    # ``get_current_ops`` must be *called*. ``isinstance`` on the function
    # object itself is always False, which silently disabled the
    # n_process=2 case.
    if n_process > 1 and not isinstance(get_current_ops(), NumpyOps):
        pytest.skip("multiprocessing is only tested with NumpyOps")
    texts = ["I enjoy eating Pizza Hut pizza."]
    patterns = [{"label": "FASTFOOD", "pattern": "Pizza Hut", "id": "1234"}]
    ruler = nlp.add_pipe("entity_ruler")
    ruler.add_patterns(patterns)
    # Honor the parametrized worker count instead of a hard-coded 2.
    for doc in nlp.pipe(texts, n_process=n_process):
        # Guard against a vacuous pass when no entity is found at all.
        assert len(doc.ents) == 1
        for ent in doc.ents:
            assert ent.ent_id_ == "1234"