diff --git a/spacy/matcher/matcher.pyx b/spacy/matcher/matcher.pyx index 60a76280f..27d19b8ea 100644 --- a/spacy/matcher/matcher.pyx +++ b/spacy/matcher/matcher.pyx @@ -1,4 +1,4 @@ -# cython: infer_types=True, profile=True +# cython: binding=True, infer_types=True, profile=True from typing import List, Iterable from libcpp.vector cimport vector @@ -24,6 +24,7 @@ from ..schemas import validate_token_pattern from ..errors import Errors, MatchPatternError, Warnings from ..strings import get_string_id from ..attrs import IDS +from ..util import registry from .levenshtein import levenshtein @@ -31,6 +32,21 @@ from .levenshtein import levenshtein DEF PADDING = 5 +cpdef bint _default_fuzzy_compare(s1: str, s2: str, fuzzy: int = -1): + distance = min(len(s1), len(s2)) + distance -= 1 # don't allow completely different tokens + if fuzzy == -1: # FUZZY operator with unspecified fuzzy + fuzzy = 5 # default max fuzzy + distance -= 1 # be more restrictive + distance = min(fuzzy, distance if distance > 0 else 1) + return levenshtein(s1, s2, distance) <= distance + + +@registry.misc("spacy.fuzzy_compare.v1") +def make_fuzzy_compare(): + return _default_fuzzy_compare + + cdef class Matcher: """Match sequences of tokens, based on pattern rules. @@ -1148,13 +1164,3 @@ def _get_extensions(spec, string_store, name2index): name2index[name] = len(name2index) attr_values.append((name2index[name], value)) return attr_values - - -cpdef bint _default_fuzzy_compare(s1: str, s2: str, fuzzy: int = -1): - distance = min(len(s1), len(s2)) - distance -= 1 # don't allow completely different tokens - if fuzzy == -1: # FUZZY operator with unspecified fuzzy - fuzzy = 5 # default max fuzzy - distance -= 1 # be more restrictive - distance = min(fuzzy, distance if distance > 0 else 1) - return levenshtein(s1, s2, distance) <= distance diff --git a/spacy/pipeline/entityruler.py b/spacy/pipeline/entityruler.py index 342f693d5..2adc1fed2 100644 --- a/spacy/pipeline/entityruler.py +++ b/spacy/pipeline/entityruler.py @@ -27,6 +27,7 @@ PatternType = Dict[str, Union[str, List[Dict[str, Any]]]] "overwrite_ents": False, "ent_id_sep": DEFAULT_ENT_ID_SEP, "scorer": {"@scorers": "spacy.entity_ruler_scorer.v1"}, + "fuzzy_compare": {"@misc": "spacy.fuzzy_compare.v1"}, }, default_score_weights={ "ents_f": 1.0, @@ -43,6 +44,7 @@ def make_entity_ruler( overwrite_ents: bool, ent_id_sep: str, scorer: Optional[Callable], + fuzzy_compare: Callable, ): return EntityRuler( nlp, @@ -52,6 +54,7 @@ def make_entity_ruler( overwrite_ents=overwrite_ents, ent_id_sep=ent_id_sep, scorer=scorer, + fuzzy_compare=fuzzy_compare, ) @@ -86,7 +89,7 @@ class EntityRuler(Pipe): ent_id_sep: str = DEFAULT_ENT_ID_SEP, patterns: Optional[List[PatternType]] = None, scorer: Optional[Callable] = entity_ruler_score, - fuzzy_compare: Optional[Callable] = _default_fuzzy_compare + fuzzy_compare: Callable = _default_fuzzy_compare, ) -> None: """Initialize the entity ruler. If patterns are supplied here, they need to be a list of dictionaries with a `"label"` and `"pattern"` @@ -109,6 +112,7 @@ class EntityRuler(Pipe): ent_id_sep (str): Separator used internally for entity IDs. scorer (Optional[Callable]): The scoring method. Defaults to spacy.scorer.get_ner_prf. + fuzzy_compare (Callable): The fuzzy comparison method. DOCS: https://spacy.io/api/entityruler#init """ @@ -119,7 +123,9 @@ class EntityRuler(Pipe): self.phrase_patterns = defaultdict(list) # type: ignore self._validate = validate self._fuzzy_compare = fuzzy_compare - self.matcher = Matcher(nlp.vocab, validate=validate, fuzzy_compare=fuzzy_compare) + self.matcher = Matcher( + nlp.vocab, validate=validate, fuzzy_compare=fuzzy_compare + ) self.phrase_matcher_attr = phrase_matcher_attr self.phrase_matcher = PhraseMatcher( nlp.vocab, attr=self.phrase_matcher_attr, validate=validate @@ -129,6 +135,7 @@ class EntityRuler(Pipe): if patterns is not None: self.add_patterns(patterns) self.scorer = scorer + self.fuzzy_compare = fuzzy_compare def __len__(self) -> int: """The number of all patterns added to the entity ruler.""" @@ -339,8 +346,9 @@ class EntityRuler(Pipe): self.token_patterns = defaultdict(list) self.phrase_patterns = defaultdict(list) self._ent_ids = defaultdict(tuple) - self.matcher = Matcher(self.nlp.vocab, validate=self._validate, - fuzzy_compare=self._fuzzy_compare) + self.matcher = Matcher( + self.nlp.vocab, validate=self._validate, fuzzy_compare=self._fuzzy_compare + ) self.phrase_matcher = PhraseMatcher( self.nlp.vocab, attr=self.phrase_matcher_attr, validate=self._validate ) @@ -434,7 +442,7 @@ class EntityRuler(Pipe): self.overwrite = cfg.get("overwrite", False) self.phrase_matcher_attr = cfg.get("phrase_matcher_attr", None) self.phrase_matcher = PhraseMatcher( - self.nlp.vocab, attr=self.phrase_matcher_attr + self.nlp.vocab, attr=self.phrase_matcher_attr, ) self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP) else: diff --git a/spacy/pipeline/span_ruler.py b/spacy/pipeline/span_ruler.py index 807a4ffe5..385e335d2 100644 --- a/spacy/pipeline/span_ruler.py +++ b/spacy/pipeline/span_ruler.py @@ -12,7 +12,7 @@ from ..errors import Errors, Warnings from ..util import ensure_path, SimpleFrozenList, registry from ..tokens import Doc, Span from ..scorer import Scorer -from ..matcher import Matcher, PhraseMatcher +from ..matcher import Matcher, PhraseMatcher, _default_fuzzy_compare from .. import util PatternType = Dict[str, Union[str, List[Dict[str, Any]]]] @@ -28,6 +28,7 @@ DEFAULT_SPANS_KEY = "ruler" "overwrite_ents": False, "scorer": {"@scorers": "spacy.entity_ruler_scorer.v1"}, "ent_id_sep": "__unused__", + "fuzzy_compare": {"@misc": "spacy.fuzzy_compare.v1"}, }, default_score_weights={ "ents_f": 1.0, @@ -44,6 +45,7 @@ def make_entity_ruler( overwrite_ents: bool, scorer: Optional[Callable], ent_id_sep: str, + fuzzy_compare: Callable, ): if overwrite_ents: ents_filter = prioritize_new_ents_filter @@ -60,6 +62,7 @@ def make_entity_ruler( validate=validate, overwrite=False, scorer=scorer, + fuzzy_compare=fuzzy_compare, ) @@ -78,6 +81,7 @@ def make_entity_ruler( "@scorers": "spacy.overlapping_labeled_spans_scorer.v1", "spans_key": DEFAULT_SPANS_KEY, }, + "fuzzy_compare": {"@misc": "spacy.fuzzy_compare.v1"}, }, default_score_weights={ f"spans_{DEFAULT_SPANS_KEY}_f": 1.0, @@ -97,6 +101,7 @@ def make_span_ruler( validate: bool, overwrite: bool, scorer: Optional[Callable], + fuzzy_compare: Callable, ): return SpanRuler( nlp, @@ -109,6 +114,7 @@ def make_span_ruler( validate=validate, overwrite=overwrite, scorer=scorer, + fuzzy_compare=fuzzy_compare, ) @@ -221,6 +227,7 @@ class SpanRuler(Pipe): scorer: Optional[Callable] = partial( overlapping_labeled_spans_score, spans_key=DEFAULT_SPANS_KEY ), + fuzzy_compare: Callable = _default_fuzzy_compare, ) -> None: """Initialize the span ruler. If patterns are supplied here, they need to be a list of dictionaries with a `"label"` and `"pattern"` @@ -253,6 +260,7 @@ class SpanRuler(Pipe): `annotate_ents` is set. Defaults to `True`. scorer (Optional[Callable]): The scoring method. Defaults to spacy.pipeline.span_ruler.overlapping_labeled_spans_score. + fuzzy_compare (Callable): The default fuzzy comparison method. DOCS: https://spacy.io/api/spanruler#init """ @@ -266,6 +274,7 @@ class SpanRuler(Pipe): self.spans_filter = spans_filter self.ents_filter = ents_filter self.scorer = scorer + self.fuzzy_compare = fuzzy_compare self._match_label_id_map: Dict[int, Dict[str, str]] = {} self.clear() @@ -451,7 +460,11 @@ class SpanRuler(Pipe): DOCS: https://spacy.io/api/spanruler#clear """ self._patterns: List[PatternType] = [] - self.matcher: Matcher = Matcher(self.nlp.vocab, validate=self.validate) + self.matcher: Matcher = Matcher( + self.nlp.vocab, + validate=self.validate, + fuzzy_compare=self.fuzzy_compare, + ) self.phrase_matcher: PhraseMatcher = PhraseMatcher( self.nlp.vocab, attr=self.phrase_matcher_attr, diff --git a/spacy/tests/pipeline/test_entity_ruler.py b/spacy/tests/pipeline/test_entity_ruler.py index 5525a10ed..de2451838 100644 --- a/spacy/tests/pipeline/test_entity_ruler.py +++ b/spacy/tests/pipeline/test_entity_ruler.py @@ -385,9 +385,7 @@ def test_entity_ruler_overlapping_spans(nlp, entity_ruler_factory): @pytest.mark.parametrize("entity_ruler_factory", ENTITY_RULERS) def test_entity_ruler_fuzzy_pipe(nlp, entity_ruler_factory): ruler = nlp.add_pipe(entity_ruler_factory, name="entity_ruler") - patterns = [ - {"label": "HELLO", "pattern": [{"LOWER": {"FUZZY": "hello"}}]} - ] + patterns = [{"label": "HELLO", "pattern": [{"LOWER": {"FUZZY": "hello"}}]}] ruler.add_patterns(patterns) doc = nlp("helloo") assert len(doc.ents) == 1 @@ -396,24 +394,28 @@ def test_entity_ruler_fuzzy_pipe(nlp, entity_ruler_factory): @pytest.mark.parametrize("entity_ruler_factory", ENTITY_RULERS) def test_entity_ruler_fuzzy(nlp, entity_ruler_factory): - patterns = [ - {"label": "HELLO", "pattern": [{"LOWER": {"FUZZY": "hello"}}]} - ] - ruler = EntityRuler(nlp, patterns=patterns) + ruler = nlp.add_pipe(entity_ruler_factory, name="entity_ruler") + patterns = [{"label": "HELLO", "pattern": [{"LOWER": {"FUZZY": "hello"}}]}] + ruler.add_patterns(patterns) doc = nlp("helloo") - ruler(doc) assert len(doc.ents) == 1 assert doc.ents[0].label_ == "HELLO" @pytest.mark.parametrize("entity_ruler_factory", ENTITY_RULERS) def test_entity_ruler_fuzzy_disabled(nlp, entity_ruler_factory): - patterns = [ - {"label": "HELLO", "pattern": [{"LOWER": {"FUZZY": "hello"}}]} - ] - ruler = EntityRuler(nlp, patterns=patterns, fuzzy_compare=lambda x, y, z: False) + @registry.misc("test_fuzzy_compare_disabled") + def make_test_fuzzy_compare_disabled(): + return lambda x, y, z: False + + ruler = nlp.add_pipe( + entity_ruler_factory, + name="entity_ruler", + config={"fuzzy_compare": {"@misc": "test_fuzzy_compare_disabled"}}, + ) + patterns = [{"label": "HELLO", "pattern": [{"LOWER": {"FUZZY": "hello"}}]}] + ruler.add_patterns(patterns) doc = nlp("helloo") - ruler(doc) assert len(doc.ents) == 0