mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-05 12:50:20 +03:00
Implement fuzzy_compare config option for EntityRuler and SpanRuler
This commit is contained in:
parent
561adacb6e
commit
d1628df277
|
@ -1,4 +1,4 @@
|
|||
# cython: infer_types=True, profile=True
|
||||
# cython: binding=True, infer_types=True, profile=True
|
||||
from typing import List, Iterable
|
||||
|
||||
from libcpp.vector cimport vector
|
||||
|
@ -24,6 +24,7 @@ from ..schemas import validate_token_pattern
|
|||
from ..errors import Errors, MatchPatternError, Warnings
|
||||
from ..strings import get_string_id
|
||||
from ..attrs import IDS
|
||||
from ..util import registry
|
||||
|
||||
from .levenshtein import levenshtein
|
||||
|
||||
|
@ -31,6 +32,21 @@ from .levenshtein import levenshtein
|
|||
DEF PADDING = 5
|
||||
|
||||
|
||||
cpdef bint _default_fuzzy_compare(s1: str, s2: str, fuzzy: int = -1):
    """Decide whether two strings are close enough to count as a fuzzy match.

    s1 (str): The first string to compare.
    s2 (str): The second string to compare.
    fuzzy (int): The maximum edit distance requested by the pattern. The
        value -1 means that no explicit distance was given (plain FUZZY
        operator); a default cap of 5 is then used and the length-based
        limit is tightened by one extra edit.
    RETURNS (bint): True if the value returned by `levenshtein` for the two
        strings does not exceed the computed edit budget.
    """
    # Start from one less than the shorter string's length so that
    # completely different tokens are never accepted.
    max_edits = min(len(s1), len(s2)) - 1
    if fuzzy == -1:
        # FUZZY operator without an explicit distance: fall back to the
        # default cap and be more restrictive on the length-based limit.
        fuzzy = 5
        max_edits -= 1
    # Always allow at least one edit, even for very short strings.
    if max_edits <= 0:
        max_edits = 1
    # The requested distance caps the length-based budget.
    max_edits = min(fuzzy, max_edits)
    return levenshtein(s1, s2, max_edits) <= max_edits
|
||||
|
||||
|
||||
@registry.misc("spacy.fuzzy_compare.v1")
def make_fuzzy_compare():
    """Provide the default fuzzy comparison function.

    Registered in the misc registry as "spacy.fuzzy_compare.v1".

    RETURNS (Callable): The module's `_default_fuzzy_compare` function.
    """
    return _default_fuzzy_compare
|
||||
|
||||
|
||||
cdef class Matcher:
|
||||
"""Match sequences of tokens, based on pattern rules.
|
||||
|
||||
|
@ -1148,13 +1164,3 @@ def _get_extensions(spec, string_store, name2index):
|
|||
name2index[name] = len(name2index)
|
||||
attr_values.append((name2index[name], value))
|
||||
return attr_values
|
||||
|
||||
|
||||
cpdef bint _default_fuzzy_compare(s1: str, s2: str, fuzzy: int = -1):
    """Decide whether two strings are close enough to count as a fuzzy match.

    s1 (str): The first string to compare.
    s2 (str): The second string to compare.
    fuzzy (int): The maximum edit distance requested by the pattern. The
        value -1 means that no explicit distance was given (plain FUZZY
        operator); a default cap of 5 is then used and the length-based
        limit is tightened by one extra edit.
    RETURNS (bint): True if the value returned by `levenshtein` for the two
        strings does not exceed the computed edit budget.
    """
    # Start from one less than the shorter string's length so that
    # completely different tokens are never accepted.
    max_edits = min(len(s1), len(s2)) - 1
    if fuzzy == -1:
        # FUZZY operator without an explicit distance: fall back to the
        # default cap and be more restrictive on the length-based limit.
        fuzzy = 5
        max_edits -= 1
    # Always allow at least one edit, even for very short strings.
    if max_edits <= 0:
        max_edits = 1
    # The requested distance caps the length-based budget.
    max_edits = min(fuzzy, max_edits)
    return levenshtein(s1, s2, max_edits) <= max_edits
|
||||
|
|
|
@ -27,6 +27,7 @@ PatternType = Dict[str, Union[str, List[Dict[str, Any]]]]
|
|||
"overwrite_ents": False,
|
||||
"ent_id_sep": DEFAULT_ENT_ID_SEP,
|
||||
"scorer": {"@scorers": "spacy.entity_ruler_scorer.v1"},
|
||||
"fuzzy_compare": {"@misc": "spacy.fuzzy_compare.v1"},
|
||||
},
|
||||
default_score_weights={
|
||||
"ents_f": 1.0,
|
||||
|
@ -43,6 +44,7 @@ def make_entity_ruler(
|
|||
overwrite_ents: bool,
|
||||
ent_id_sep: str,
|
||||
scorer: Optional[Callable],
|
||||
fuzzy_compare: Callable,
|
||||
):
|
||||
return EntityRuler(
|
||||
nlp,
|
||||
|
@ -52,6 +54,7 @@ def make_entity_ruler(
|
|||
overwrite_ents=overwrite_ents,
|
||||
ent_id_sep=ent_id_sep,
|
||||
scorer=scorer,
|
||||
fuzzy_compare=fuzzy_compare,
|
||||
)
|
||||
|
||||
|
||||
|
@ -86,7 +89,7 @@ class EntityRuler(Pipe):
|
|||
ent_id_sep: str = DEFAULT_ENT_ID_SEP,
|
||||
patterns: Optional[List[PatternType]] = None,
|
||||
scorer: Optional[Callable] = entity_ruler_score,
|
||||
fuzzy_compare: Optional[Callable] = _default_fuzzy_compare
|
||||
fuzzy_compare: Callable = _default_fuzzy_compare,
|
||||
) -> None:
|
||||
"""Initialize the entity ruler. If patterns are supplied here, they
|
||||
need to be a list of dictionaries with a `"label"` and `"pattern"`
|
||||
|
@ -109,6 +112,7 @@ class EntityRuler(Pipe):
|
|||
ent_id_sep (str): Separator used internally for entity IDs.
|
||||
scorer (Optional[Callable]): The scoring method. Defaults to
|
||||
spacy.scorer.get_ner_prf.
|
||||
fuzzy_compare (Callable): The fuzzy comparison method.
|
||||
|
||||
DOCS: https://spacy.io/api/entityruler#init
|
||||
"""
|
||||
|
@ -119,7 +123,9 @@ class EntityRuler(Pipe):
|
|||
self.phrase_patterns = defaultdict(list) # type: ignore
|
||||
self._validate = validate
|
||||
self._fuzzy_compare = fuzzy_compare
|
||||
self.matcher = Matcher(nlp.vocab, validate=validate, fuzzy_compare=fuzzy_compare)
|
||||
self.matcher = Matcher(
|
||||
nlp.vocab, validate=validate, fuzzy_compare=fuzzy_compare
|
||||
)
|
||||
self.phrase_matcher_attr = phrase_matcher_attr
|
||||
self.phrase_matcher = PhraseMatcher(
|
||||
nlp.vocab, attr=self.phrase_matcher_attr, validate=validate
|
||||
|
@ -129,6 +135,7 @@ class EntityRuler(Pipe):
|
|||
if patterns is not None:
|
||||
self.add_patterns(patterns)
|
||||
self.scorer = scorer
|
||||
self.fuzzy_compare = fuzzy_compare
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""The number of all patterns added to the entity ruler."""
|
||||
|
@ -339,8 +346,9 @@ class EntityRuler(Pipe):
|
|||
self.token_patterns = defaultdict(list)
|
||||
self.phrase_patterns = defaultdict(list)
|
||||
self._ent_ids = defaultdict(tuple)
|
||||
self.matcher = Matcher(self.nlp.vocab, validate=self._validate,
|
||||
fuzzy_compare=self._fuzzy_compare)
|
||||
self.matcher = Matcher(
|
||||
self.nlp.vocab, validate=self._validate, fuzzy_compare=self._fuzzy_compare
|
||||
)
|
||||
self.phrase_matcher = PhraseMatcher(
|
||||
self.nlp.vocab, attr=self.phrase_matcher_attr, validate=self._validate
|
||||
)
|
||||
|
@ -434,7 +442,7 @@ class EntityRuler(Pipe):
|
|||
self.overwrite = cfg.get("overwrite", False)
|
||||
self.phrase_matcher_attr = cfg.get("phrase_matcher_attr", None)
|
||||
self.phrase_matcher = PhraseMatcher(
|
||||
self.nlp.vocab, attr=self.phrase_matcher_attr
|
||||
self.nlp.vocab, attr=self.phrase_matcher_attr,
|
||||
)
|
||||
self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP)
|
||||
else:
|
||||
|
|
|
@ -12,7 +12,7 @@ from ..errors import Errors, Warnings
|
|||
from ..util import ensure_path, SimpleFrozenList, registry
|
||||
from ..tokens import Doc, Span
|
||||
from ..scorer import Scorer
|
||||
from ..matcher import Matcher, PhraseMatcher
|
||||
from ..matcher import Matcher, PhraseMatcher, _default_fuzzy_compare
|
||||
from .. import util
|
||||
|
||||
PatternType = Dict[str, Union[str, List[Dict[str, Any]]]]
|
||||
|
@ -28,6 +28,7 @@ DEFAULT_SPANS_KEY = "ruler"
|
|||
"overwrite_ents": False,
|
||||
"scorer": {"@scorers": "spacy.entity_ruler_scorer.v1"},
|
||||
"ent_id_sep": "__unused__",
|
||||
"fuzzy_compare": {"@misc": "spacy.fuzzy_compare.v1"},
|
||||
},
|
||||
default_score_weights={
|
||||
"ents_f": 1.0,
|
||||
|
@ -44,6 +45,7 @@ def make_entity_ruler(
|
|||
overwrite_ents: bool,
|
||||
scorer: Optional[Callable],
|
||||
ent_id_sep: str,
|
||||
fuzzy_compare: Callable,
|
||||
):
|
||||
if overwrite_ents:
|
||||
ents_filter = prioritize_new_ents_filter
|
||||
|
@ -60,6 +62,7 @@ def make_entity_ruler(
|
|||
validate=validate,
|
||||
overwrite=False,
|
||||
scorer=scorer,
|
||||
fuzzy_compare=fuzzy_compare,
|
||||
)
|
||||
|
||||
|
||||
|
@ -78,6 +81,7 @@ def make_entity_ruler(
|
|||
"@scorers": "spacy.overlapping_labeled_spans_scorer.v1",
|
||||
"spans_key": DEFAULT_SPANS_KEY,
|
||||
},
|
||||
"fuzzy_compare": {"@misc": "spacy.fuzzy_compare.v1"},
|
||||
},
|
||||
default_score_weights={
|
||||
f"spans_{DEFAULT_SPANS_KEY}_f": 1.0,
|
||||
|
@ -97,6 +101,7 @@ def make_span_ruler(
|
|||
validate: bool,
|
||||
overwrite: bool,
|
||||
scorer: Optional[Callable],
|
||||
fuzzy_compare: Callable,
|
||||
):
|
||||
return SpanRuler(
|
||||
nlp,
|
||||
|
@ -109,6 +114,7 @@ def make_span_ruler(
|
|||
validate=validate,
|
||||
overwrite=overwrite,
|
||||
scorer=scorer,
|
||||
fuzzy_compare=fuzzy_compare,
|
||||
)
|
||||
|
||||
|
||||
|
@ -221,6 +227,7 @@ class SpanRuler(Pipe):
|
|||
scorer: Optional[Callable] = partial(
|
||||
overlapping_labeled_spans_score, spans_key=DEFAULT_SPANS_KEY
|
||||
),
|
||||
fuzzy_compare: Callable = _default_fuzzy_compare,
|
||||
) -> None:
|
||||
"""Initialize the span ruler. If patterns are supplied here, they
|
||||
need to be a list of dictionaries with a `"label"` and `"pattern"`
|
||||
|
@ -253,6 +260,7 @@ class SpanRuler(Pipe):
|
|||
`annotate_ents` is set. Defaults to `True`.
|
||||
scorer (Optional[Callable]): The scoring method. Defaults to
|
||||
spacy.pipeline.span_ruler.overlapping_labeled_spans_score.
|
||||
fuzzy_compare (Callable): The default fuzzy comparison method.
|
||||
|
||||
DOCS: https://spacy.io/api/spanruler#init
|
||||
"""
|
||||
|
@ -266,6 +274,7 @@ class SpanRuler(Pipe):
|
|||
self.spans_filter = spans_filter
|
||||
self.ents_filter = ents_filter
|
||||
self.scorer = scorer
|
||||
self.fuzzy_compare = fuzzy_compare
|
||||
self._match_label_id_map: Dict[int, Dict[str, str]] = {}
|
||||
self.clear()
|
||||
|
||||
|
@ -451,7 +460,11 @@ class SpanRuler(Pipe):
|
|||
DOCS: https://spacy.io/api/spanruler#clear
|
||||
"""
|
||||
self._patterns: List[PatternType] = []
|
||||
self.matcher: Matcher = Matcher(self.nlp.vocab, validate=self.validate)
|
||||
self.matcher: Matcher = Matcher(
|
||||
self.nlp.vocab,
|
||||
validate=self.validate,
|
||||
fuzzy_compare=self.fuzzy_compare,
|
||||
)
|
||||
self.phrase_matcher: PhraseMatcher = PhraseMatcher(
|
||||
self.nlp.vocab,
|
||||
attr=self.phrase_matcher_attr,
|
||||
|
|
|
@ -385,9 +385,7 @@ def test_entity_ruler_overlapping_spans(nlp, entity_ruler_factory):
|
|||
@pytest.mark.parametrize("entity_ruler_factory", ENTITY_RULERS)
|
||||
def test_entity_ruler_fuzzy_pipe(nlp, entity_ruler_factory):
|
||||
ruler = nlp.add_pipe(entity_ruler_factory, name="entity_ruler")
|
||||
patterns = [
|
||||
{"label": "HELLO", "pattern": [{"LOWER": {"FUZZY": "hello"}}]}
|
||||
]
|
||||
patterns = [{"label": "HELLO", "pattern": [{"LOWER": {"FUZZY": "hello"}}]}]
|
||||
ruler.add_patterns(patterns)
|
||||
doc = nlp("helloo")
|
||||
assert len(doc.ents) == 1
|
||||
|
@ -396,24 +394,28 @@ def test_entity_ruler_fuzzy_pipe(nlp, entity_ruler_factory):
|
|||
|
||||
@pytest.mark.parametrize("entity_ruler_factory", ENTITY_RULERS)
|
||||
def test_entity_ruler_fuzzy(nlp, entity_ruler_factory):
|
||||
patterns = [
|
||||
{"label": "HELLO", "pattern": [{"LOWER": {"FUZZY": "hello"}}]}
|
||||
]
|
||||
ruler = EntityRuler(nlp, patterns=patterns)
|
||||
ruler = nlp.add_pipe(entity_ruler_factory, name="entity_ruler")
|
||||
patterns = [{"label": "HELLO", "pattern": [{"LOWER": {"FUZZY": "hello"}}]}]
|
||||
ruler.add_patterns(patterns)
|
||||
doc = nlp("helloo")
|
||||
ruler(doc)
|
||||
assert len(doc.ents) == 1
|
||||
assert doc.ents[0].label_ == "HELLO"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity_ruler_factory", ENTITY_RULERS)
|
||||
def test_entity_ruler_fuzzy_disabled(nlp, entity_ruler_factory):
|
||||
patterns = [
|
||||
{"label": "HELLO", "pattern": [{"LOWER": {"FUZZY": "hello"}}]}
|
||||
]
|
||||
ruler = EntityRuler(nlp, patterns=patterns, fuzzy_compare=lambda x, y, z: False)
|
||||
@registry.misc("test_fuzzy_compare_disabled")
|
||||
def make_test_fuzzy_compare_disabled():
|
||||
return lambda x, y, z: False
|
||||
|
||||
ruler = nlp.add_pipe(
|
||||
entity_ruler_factory,
|
||||
name="entity_ruler",
|
||||
config={"fuzzy_compare": {"@misc": "test_fuzzy_compare_disabled"}},
|
||||
)
|
||||
patterns = [{"label": "HELLO", "pattern": [{"LOWER": {"FUZZY": "hello"}}]}]
|
||||
ruler.add_patterns(patterns)
|
||||
doc = nlp("helloo")
|
||||
ruler(doc)
|
||||
assert len(doc.ents) == 0
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user