Implement fuzzy_compare config option for EntityRuler and SpanRuler

This commit is contained in:
Adriane Boyd 2022-11-29 11:34:03 +01:00
parent 561adacb6e
commit d1628df277
4 changed files with 60 additions and 31 deletions

View File

@ -1,4 +1,4 @@
# cython: infer_types=True, profile=True
# cython: binding=True, infer_types=True, profile=True
from typing import List, Iterable
from libcpp.vector cimport vector
@ -24,6 +24,7 @@ from ..schemas import validate_token_pattern
from ..errors import Errors, MatchPatternError, Warnings
from ..strings import get_string_id
from ..attrs import IDS
from ..util import registry
from .levenshtein import levenshtein
@ -31,6 +32,21 @@ from .levenshtein import levenshtein
DEF PADDING = 5
cpdef bint _default_fuzzy_compare(s1: str, s2: str, fuzzy: int = -1):
distance = min(len(s1), len(s2))
distance -= 1 # don't allow completely different tokens
if fuzzy == -1: # FUZZY operator with unspecified fuzzy
fuzzy = 5 # default max fuzzy
distance -= 1 # be more restrictive
distance = min(fuzzy, distance if distance > 0 else 1)
return levenshtein(s1, s2, distance) <= distance
@registry.misc("spacy.fuzzy_compare.v1")
def make_fuzzy_compare():
return _default_fuzzy_compare
cdef class Matcher:
"""Match sequences of tokens, based on pattern rules.
@ -1148,13 +1164,3 @@ def _get_extensions(spec, string_store, name2index):
name2index[name] = len(name2index)
attr_values.append((name2index[name], value))
return attr_values
cpdef bint _default_fuzzy_compare(s1: str, s2: str, fuzzy: int = -1):
distance = min(len(s1), len(s2))
distance -= 1 # don't allow completely different tokens
if fuzzy == -1: # FUZZY operator with unspecified fuzzy
fuzzy = 5 # default max fuzzy
distance -= 1 # be more restrictive
distance = min(fuzzy, distance if distance > 0 else 1)
return levenshtein(s1, s2, distance) <= distance

View File

@ -27,6 +27,7 @@ PatternType = Dict[str, Union[str, List[Dict[str, Any]]]]
"overwrite_ents": False,
"ent_id_sep": DEFAULT_ENT_ID_SEP,
"scorer": {"@scorers": "spacy.entity_ruler_scorer.v1"},
"fuzzy_compare": {"@misc": "spacy.fuzzy_compare.v1"},
},
default_score_weights={
"ents_f": 1.0,
@ -43,6 +44,7 @@ def make_entity_ruler(
overwrite_ents: bool,
ent_id_sep: str,
scorer: Optional[Callable],
fuzzy_compare: Callable,
):
return EntityRuler(
nlp,
@ -52,6 +54,7 @@ def make_entity_ruler(
overwrite_ents=overwrite_ents,
ent_id_sep=ent_id_sep,
scorer=scorer,
fuzzy_compare=fuzzy_compare,
)
@ -86,7 +89,7 @@ class EntityRuler(Pipe):
ent_id_sep: str = DEFAULT_ENT_ID_SEP,
patterns: Optional[List[PatternType]] = None,
scorer: Optional[Callable] = entity_ruler_score,
fuzzy_compare: Optional[Callable] = _default_fuzzy_compare
fuzzy_compare: Callable = _default_fuzzy_compare,
) -> None:
"""Initialize the entity ruler. If patterns are supplied here, they
need to be a list of dictionaries with a `"label"` and `"pattern"`
@ -109,6 +112,7 @@ class EntityRuler(Pipe):
ent_id_sep (str): Separator used internally for entity IDs.
scorer (Optional[Callable]): The scoring method. Defaults to
spacy.scorer.get_ner_prf.
fuzzy_compare (Callable): The fuzzy comparison method.
DOCS: https://spacy.io/api/entityruler#init
"""
@ -119,7 +123,9 @@ class EntityRuler(Pipe):
self.phrase_patterns = defaultdict(list) # type: ignore
self._validate = validate
self._fuzzy_compare = fuzzy_compare
self.matcher = Matcher(nlp.vocab, validate=validate, fuzzy_compare=fuzzy_compare)
self.matcher = Matcher(
nlp.vocab, validate=validate, fuzzy_compare=fuzzy_compare
)
self.phrase_matcher_attr = phrase_matcher_attr
self.phrase_matcher = PhraseMatcher(
nlp.vocab, attr=self.phrase_matcher_attr, validate=validate
@ -129,6 +135,7 @@ class EntityRuler(Pipe):
if patterns is not None:
self.add_patterns(patterns)
self.scorer = scorer
self.fuzzy_compare = fuzzy_compare
def __len__(self) -> int:
"""The number of all patterns added to the entity ruler."""
@ -339,8 +346,9 @@ class EntityRuler(Pipe):
self.token_patterns = defaultdict(list)
self.phrase_patterns = defaultdict(list)
self._ent_ids = defaultdict(tuple)
self.matcher = Matcher(self.nlp.vocab, validate=self._validate,
fuzzy_compare=self._fuzzy_compare)
self.matcher = Matcher(
self.nlp.vocab, validate=self._validate, fuzzy_compare=self._fuzzy_compare
)
self.phrase_matcher = PhraseMatcher(
self.nlp.vocab, attr=self.phrase_matcher_attr, validate=self._validate
)
@ -434,7 +442,7 @@ class EntityRuler(Pipe):
self.overwrite = cfg.get("overwrite", False)
self.phrase_matcher_attr = cfg.get("phrase_matcher_attr", None)
self.phrase_matcher = PhraseMatcher(
self.nlp.vocab, attr=self.phrase_matcher_attr
self.nlp.vocab, attr=self.phrase_matcher_attr,
)
self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP)
else:

View File

@ -12,7 +12,7 @@ from ..errors import Errors, Warnings
from ..util import ensure_path, SimpleFrozenList, registry
from ..tokens import Doc, Span
from ..scorer import Scorer
from ..matcher import Matcher, PhraseMatcher
from ..matcher import Matcher, PhraseMatcher, _default_fuzzy_compare
from .. import util
PatternType = Dict[str, Union[str, List[Dict[str, Any]]]]
@ -28,6 +28,7 @@ DEFAULT_SPANS_KEY = "ruler"
"overwrite_ents": False,
"scorer": {"@scorers": "spacy.entity_ruler_scorer.v1"},
"ent_id_sep": "__unused__",
"fuzzy_compare": {"@misc": "spacy.fuzzy_compare.v1"},
},
default_score_weights={
"ents_f": 1.0,
@ -44,6 +45,7 @@ def make_entity_ruler(
overwrite_ents: bool,
scorer: Optional[Callable],
ent_id_sep: str,
fuzzy_compare: Callable,
):
if overwrite_ents:
ents_filter = prioritize_new_ents_filter
@ -60,6 +62,7 @@ def make_entity_ruler(
validate=validate,
overwrite=False,
scorer=scorer,
fuzzy_compare=fuzzy_compare,
)
@ -78,6 +81,7 @@ def make_entity_ruler(
"@scorers": "spacy.overlapping_labeled_spans_scorer.v1",
"spans_key": DEFAULT_SPANS_KEY,
},
"fuzzy_compare": {"@misc": "spacy.fuzzy_compare.v1"},
},
default_score_weights={
f"spans_{DEFAULT_SPANS_KEY}_f": 1.0,
@ -97,6 +101,7 @@ def make_span_ruler(
validate: bool,
overwrite: bool,
scorer: Optional[Callable],
fuzzy_compare: Callable,
):
return SpanRuler(
nlp,
@ -109,6 +114,7 @@ def make_span_ruler(
validate=validate,
overwrite=overwrite,
scorer=scorer,
fuzzy_compare=fuzzy_compare,
)
@ -221,6 +227,7 @@ class SpanRuler(Pipe):
scorer: Optional[Callable] = partial(
overlapping_labeled_spans_score, spans_key=DEFAULT_SPANS_KEY
),
fuzzy_compare: Callable = _default_fuzzy_compare,
) -> None:
"""Initialize the span ruler. If patterns are supplied here, they
need to be a list of dictionaries with a `"label"` and `"pattern"`
@ -253,6 +260,7 @@ class SpanRuler(Pipe):
`annotate_ents` is set. Defaults to `True`.
scorer (Optional[Callable]): The scoring method. Defaults to
spacy.pipeline.span_ruler.overlapping_labeled_spans_score.
fuzzy_compare (Callable): The default fuzzy comparison method.
DOCS: https://spacy.io/api/spanruler#init
"""
@ -266,6 +274,7 @@ class SpanRuler(Pipe):
self.spans_filter = spans_filter
self.ents_filter = ents_filter
self.scorer = scorer
self.fuzzy_compare = fuzzy_compare
self._match_label_id_map: Dict[int, Dict[str, str]] = {}
self.clear()
@ -451,7 +460,11 @@ class SpanRuler(Pipe):
DOCS: https://spacy.io/api/spanruler#clear
"""
self._patterns: List[PatternType] = []
self.matcher: Matcher = Matcher(self.nlp.vocab, validate=self.validate)
self.matcher: Matcher = Matcher(
self.nlp.vocab,
validate=self.validate,
fuzzy_compare=self.fuzzy_compare,
)
self.phrase_matcher: PhraseMatcher = PhraseMatcher(
self.nlp.vocab,
attr=self.phrase_matcher_attr,

View File

@ -385,9 +385,7 @@ def test_entity_ruler_overlapping_spans(nlp, entity_ruler_factory):
@pytest.mark.parametrize("entity_ruler_factory", ENTITY_RULERS)
def test_entity_ruler_fuzzy_pipe(nlp, entity_ruler_factory):
ruler = nlp.add_pipe(entity_ruler_factory, name="entity_ruler")
patterns = [
{"label": "HELLO", "pattern": [{"LOWER": {"FUZZY": "hello"}}]}
]
patterns = [{"label": "HELLO", "pattern": [{"LOWER": {"FUZZY": "hello"}}]}]
ruler.add_patterns(patterns)
doc = nlp("helloo")
assert len(doc.ents) == 1
@ -396,24 +394,28 @@ def test_entity_ruler_fuzzy_pipe(nlp, entity_ruler_factory):
@pytest.mark.parametrize("entity_ruler_factory", ENTITY_RULERS)
def test_entity_ruler_fuzzy(nlp, entity_ruler_factory):
patterns = [
{"label": "HELLO", "pattern": [{"LOWER": {"FUZZY": "hello"}}]}
]
ruler = EntityRuler(nlp, patterns=patterns)
ruler = nlp.add_pipe(entity_ruler_factory, name="entity_ruler")
patterns = [{"label": "HELLO", "pattern": [{"LOWER": {"FUZZY": "hello"}}]}]
ruler.add_patterns(patterns)
doc = nlp("helloo")
ruler(doc)
assert len(doc.ents) == 1
assert doc.ents[0].label_ == "HELLO"
@pytest.mark.parametrize("entity_ruler_factory", ENTITY_RULERS)
def test_entity_ruler_fuzzy_disabled(nlp, entity_ruler_factory):
patterns = [
{"label": "HELLO", "pattern": [{"LOWER": {"FUZZY": "hello"}}]}
]
ruler = EntityRuler(nlp, patterns=patterns, fuzzy_compare=lambda x, y, z: False)
@registry.misc("test_fuzzy_compare_disabled")
def make_test_fuzzy_compare_disabled():
return lambda x, y, z: False
ruler = nlp.add_pipe(
entity_ruler_factory,
name="entity_ruler",
config={"fuzzy_compare": {"@misc": "test_fuzzy_compare_disabled"}},
)
patterns = [{"label": "HELLO", "pattern": [{"LOWER": {"FUZZY": "hello"}}]}]
ruler.add_patterns(patterns)
doc = nlp("helloo")
ruler(doc)
assert len(doc.ents) == 0