allow fuzzy_compare override from EntityRuler

This commit is contained in:
Kevin Humphreys 2022-11-28 16:33:28 -08:00
parent 5088949cf8
commit e029616f53
3 changed files with 43 additions and 5 deletions

View File

@ -1,6 +1,6 @@
from .matcher import Matcher from .matcher import Matcher, _default_fuzzy_compare
from .phrasematcher import PhraseMatcher from .phrasematcher import PhraseMatcher
from .dependencymatcher import DependencyMatcher from .dependencymatcher import DependencyMatcher
from .levenshtein import levenshtein from .levenshtein import levenshtein
__all__ = ["Matcher", "PhraseMatcher", "DependencyMatcher", "levenshtein"] __all__ = ["Matcher", "PhraseMatcher", "DependencyMatcher", "levenshtein", "_default_fuzzy_compare"]

View File

@ -10,7 +10,7 @@ from ..language import Language
from ..errors import Errors, Warnings from ..errors import Errors, Warnings
from ..util import ensure_path, to_disk, from_disk, SimpleFrozenList, registry from ..util import ensure_path, to_disk, from_disk, SimpleFrozenList, registry
from ..tokens import Doc, Span from ..tokens import Doc, Span
from ..matcher import Matcher, PhraseMatcher from ..matcher import Matcher, PhraseMatcher, _default_fuzzy_compare
from ..scorer import get_ner_prf from ..scorer import get_ner_prf
@ -86,6 +86,7 @@ class EntityRuler(Pipe):
ent_id_sep: str = DEFAULT_ENT_ID_SEP, ent_id_sep: str = DEFAULT_ENT_ID_SEP,
patterns: Optional[List[PatternType]] = None, patterns: Optional[List[PatternType]] = None,
scorer: Optional[Callable] = entity_ruler_score, scorer: Optional[Callable] = entity_ruler_score,
fuzzy_compare: Optional[Callable] = _default_fuzzy_compare
) -> None: ) -> None:
"""Initialize the entity ruler. If patterns are supplied here, they """Initialize the entity ruler. If patterns are supplied here, they
need to be a list of dictionaries with a `"label"` and `"pattern"` need to be a list of dictionaries with a `"label"` and `"pattern"`
@ -117,7 +118,8 @@ class EntityRuler(Pipe):
self.token_patterns = defaultdict(list) # type: ignore self.token_patterns = defaultdict(list) # type: ignore
self.phrase_patterns = defaultdict(list) # type: ignore self.phrase_patterns = defaultdict(list) # type: ignore
self._validate = validate self._validate = validate
self.matcher = Matcher(nlp.vocab, validate=validate) self._fuzzy_compare = fuzzy_compare
self.matcher = Matcher(nlp.vocab, validate=validate, fuzzy_compare=fuzzy_compare)
self.phrase_matcher_attr = phrase_matcher_attr self.phrase_matcher_attr = phrase_matcher_attr
self.phrase_matcher = PhraseMatcher( self.phrase_matcher = PhraseMatcher(
nlp.vocab, attr=self.phrase_matcher_attr, validate=validate nlp.vocab, attr=self.phrase_matcher_attr, validate=validate
@ -337,7 +339,8 @@ class EntityRuler(Pipe):
self.token_patterns = defaultdict(list) self.token_patterns = defaultdict(list)
self.phrase_patterns = defaultdict(list) self.phrase_patterns = defaultdict(list)
self._ent_ids = defaultdict(tuple) self._ent_ids = defaultdict(tuple)
self.matcher = Matcher(self.nlp.vocab, validate=self._validate) self.matcher = Matcher(self.nlp.vocab, validate=self._validate,
fuzzy_compare=self._fuzzy_compare)
self.phrase_matcher = PhraseMatcher( self.phrase_matcher = PhraseMatcher(
self.nlp.vocab, attr=self.phrase_matcher_attr, validate=self._validate self.nlp.vocab, attr=self.phrase_matcher_attr, validate=self._validate
) )

View File

@ -382,6 +382,41 @@ def test_entity_ruler_overlapping_spans(nlp, entity_ruler_factory):
assert doc.ents[0].label_ == "FOOBAR" assert doc.ents[0].label_ == "FOOBAR"
@pytest.mark.parametrize("entity_ruler_factory", ENTITY_RULERS)
def test_entity_ruler_fuzzy_pipe(nlp, entity_ruler_factory):
    """Check a FUZZY pattern on a ruler that was added through ``nlp.add_pipe``.

    The ruler is created from the parametrized factory name, receives a
    single ``LOWER``/``FUZZY`` pattern for "hello", and the text "helloo" is
    then run through the pipeline. Exactly one entity labelled "HELLO" is
    expected on the resulting doc.
    """
    fuzzy_hello = {"label": "HELLO", "pattern": [{"LOWER": {"FUZZY": "hello"}}]}
    ruler = nlp.add_pipe(entity_ruler_factory, name="entity_ruler")
    ruler.add_patterns([fuzzy_hello])
    # The input differs from the pattern text by one trailing character.
    ents = nlp("helloo").ents
    assert len(ents) == 1
    assert ents[0].label_ == "HELLO"
@pytest.mark.parametrize("entity_ruler_factory", ENTITY_RULERS)
def test_entity_ruler_fuzzy(nlp, entity_ruler_factory):
    """Check a FUZZY pattern on an ``EntityRuler`` constructed directly.

    The ruler is built with the patterns passed to its constructor and is
    applied to a doc by calling it, rather than being added to the pipeline.
    Exactly one entity labelled "HELLO" is expected for the text "helloo".

    NOTE(review): ``entity_ruler_factory`` is not referenced in the body, so
    every parametrized run constructs the same ``EntityRuler`` -- confirm
    whether the parametrization is intended here.
    """
    token_spec = [{"LOWER": {"FUZZY": "hello"}}]
    ruler = EntityRuler(nlp, patterns=[{"label": "HELLO", "pattern": token_spec}])
    doc = nlp("helloo")
    # Apply the ruler by hand; the entities are then read back from `doc`.
    ruler(doc)
    found = doc.ents
    assert len(found) == 1
    assert found[0].label_ == "HELLO"
@pytest.mark.parametrize("entity_ruler_factory", ENTITY_RULERS)
def test_entity_ruler_fuzzy_disabled(nlp, entity_ruler_factory):
    """Check that a custom ``fuzzy_compare`` callable overrides fuzzy matching.

    The ruler is constructed with a ``fuzzy_compare`` callable that always
    returns ``False``; with the same pattern and text as the other fuzzy
    tests, no entities are expected on the doc.

    NOTE(review): ``entity_ruler_factory`` is not referenced in the body, so
    every parametrized run constructs the same ``EntityRuler`` -- confirm
    whether the parametrization is intended here.
    """

    def never_match(x, y, z):
        # Three positional arguments, same arity as the original inline lambda.
        return False

    token_spec = [{"LOWER": {"FUZZY": "hello"}}]
    ruler = EntityRuler(
        nlp,
        patterns=[{"label": "HELLO", "pattern": token_spec}],
        fuzzy_compare=never_match,
    )
    doc = nlp("helloo")
    ruler(doc)
    assert len(doc.ents) == 0
@pytest.mark.parametrize("n_process", [1, 2]) @pytest.mark.parametrize("n_process", [1, 2])
@pytest.mark.parametrize("entity_ruler_factory", ENTITY_RULERS) @pytest.mark.parametrize("entity_ruler_factory", ENTITY_RULERS)
def test_entity_ruler_multiprocessing(nlp, n_process, entity_ruler_factory): def test_entity_ruler_multiprocessing(nlp, n_process, entity_ruler_factory):