Add a configurable `fuzzy_compare` callable to EntityRuler.

The callable is accepted by `EntityRuler.__init__` (defaulting to
`_default_fuzzy_compare`), stored as `self._fuzzy_compare`, and forwarded to
the token `Matcher` both where it is first created and where the ruler later
re-creates it, so a custom comparator survives the rebuild.

Review notes:
- `test_entity_ruler_fuzzy` and `test_entity_ruler_fuzzy_disabled` construct
  `EntityRuler` directly, so they are no longer parametrized over the unused
  `entity_ruler_factory` (it only made each test run twice, identically).
- The new `fuzzy_compare` parameter carries a trailing comma like its
  neighbours in the signature.
- Line structure and leading whitespace were restored from the hunk counts;
  the post-image blob ids on the `index` lines of entityruler.py and
  test_entity_ruler.py predate these review edits.

diff --git a/spacy/matcher/__init__.py b/spacy/matcher/__init__.py
index a4f164847..2fb347da4 100644
--- a/spacy/matcher/__init__.py
+++ b/spacy/matcher/__init__.py
@@ -1,6 +1,6 @@
-from .matcher import Matcher
+from .matcher import Matcher, _default_fuzzy_compare
 from .phrasematcher import PhraseMatcher
 from .dependencymatcher import DependencyMatcher
 from .levenshtein import levenshtein
 
-__all__ = ["Matcher", "PhraseMatcher", "DependencyMatcher", "levenshtein"]
+__all__ = ["Matcher", "PhraseMatcher", "DependencyMatcher", "levenshtein", "_default_fuzzy_compare"]
diff --git a/spacy/pipeline/entityruler.py b/spacy/pipeline/entityruler.py
index 8154a077d..342f693d5 100644
--- a/spacy/pipeline/entityruler.py
+++ b/spacy/pipeline/entityruler.py
@@ -10,7 +10,7 @@ from ..language import Language
 from ..errors import Errors, Warnings
 from ..util import ensure_path, to_disk, from_disk, SimpleFrozenList, registry
 from ..tokens import Doc, Span
-from ..matcher import Matcher, PhraseMatcher
+from ..matcher import Matcher, PhraseMatcher, _default_fuzzy_compare
 from ..scorer import get_ner_prf
 
 
@@ -86,6 +86,7 @@ class EntityRuler(Pipe):
         ent_id_sep: str = DEFAULT_ENT_ID_SEP,
         patterns: Optional[List[PatternType]] = None,
         scorer: Optional[Callable] = entity_ruler_score,
+        fuzzy_compare: Optional[Callable] = _default_fuzzy_compare,
     ) -> None:
         """Initialize the entity ruler. If patterns are supplied here, they
         need to be a list of dictionaries with a `"label"` and `"pattern"`
@@ -117,7 +118,8 @@ class EntityRuler(Pipe):
         self.token_patterns = defaultdict(list)  # type: ignore
         self.phrase_patterns = defaultdict(list)  # type: ignore
         self._validate = validate
-        self.matcher = Matcher(nlp.vocab, validate=validate)
+        self._fuzzy_compare = fuzzy_compare
+        self.matcher = Matcher(nlp.vocab, validate=validate, fuzzy_compare=fuzzy_compare)
         self.phrase_matcher_attr = phrase_matcher_attr
         self.phrase_matcher = PhraseMatcher(
             nlp.vocab, attr=self.phrase_matcher_attr, validate=validate
@@ -337,7 +339,8 @@ class EntityRuler(Pipe):
         self.token_patterns = defaultdict(list)
         self.phrase_patterns = defaultdict(list)
         self._ent_ids = defaultdict(tuple)
-        self.matcher = Matcher(self.nlp.vocab, validate=self._validate)
+        self.matcher = Matcher(self.nlp.vocab, validate=self._validate,
+                               fuzzy_compare=self._fuzzy_compare)
         self.phrase_matcher = PhraseMatcher(
             self.nlp.vocab, attr=self.phrase_matcher_attr, validate=self._validate
         )
diff --git a/spacy/tests/pipeline/test_entity_ruler.py b/spacy/tests/pipeline/test_entity_ruler.py
index 6851e2a7c..5525a10ed 100644
--- a/spacy/tests/pipeline/test_entity_ruler.py
+++ b/spacy/tests/pipeline/test_entity_ruler.py
@@ -382,6 +382,42 @@ def test_entity_ruler_overlapping_spans(nlp, entity_ruler_factory):
     assert doc.ents[0].label_ == "FOOBAR"
 
 
+@pytest.mark.parametrize("entity_ruler_factory", ENTITY_RULERS)
+def test_entity_ruler_fuzzy_pipe(nlp, entity_ruler_factory):
+    """A FUZZY pattern added to a pipeline ruler tags "helloo" as HELLO."""
+    ruler = nlp.add_pipe(entity_ruler_factory, name="entity_ruler")
+    patterns = [
+        {"label": "HELLO", "pattern": [{"LOWER": {"FUZZY": "hello"}}]}
+    ]
+    ruler.add_patterns(patterns)
+    doc = nlp("helloo")
+    assert len(doc.ents) == 1
+    assert doc.ents[0].label_ == "HELLO"
+
+
+def test_entity_ruler_fuzzy(nlp):
+    """A directly constructed EntityRuler fuzzy-matches with the default comparator."""
+    patterns = [
+        {"label": "HELLO", "pattern": [{"LOWER": {"FUZZY": "hello"}}]}
+    ]
+    ruler = EntityRuler(nlp, patterns=patterns)
+    doc = nlp("helloo")
+    ruler(doc)
+    assert len(doc.ents) == 1
+    assert doc.ents[0].label_ == "HELLO"
+
+
+def test_entity_ruler_fuzzy_disabled(nlp):
+    """A `fuzzy_compare` that always returns False yields no fuzzy entities."""
+    patterns = [
+        {"label": "HELLO", "pattern": [{"LOWER": {"FUZZY": "hello"}}]}
+    ]
+    ruler = EntityRuler(nlp, patterns=patterns, fuzzy_compare=lambda x, y, z: False)
+    doc = nlp("helloo")
+    ruler(doc)
+    assert len(doc.ents) == 0
+
+
 @pytest.mark.parametrize("n_process", [1, 2])
 @pytest.mark.parametrize("entity_ruler_factory", ENTITY_RULERS)
 def test_entity_ruler_multiprocessing(nlp, n_process, entity_ruler_factory):