switch to FUZZYn predicates

use Levenshtein distance.
remove fuzzy param.
remove rapidfuzz_capi.
This commit is contained in:
Kevin Humphreys 2022-08-29 18:10:42 +02:00
parent ecd0455acd
commit 43948f731b
9 changed files with 93 additions and 176 deletions

View File

@ -8,6 +8,5 @@ requires = [
"thinc>=8.1.0,<8.2.0",
"pathy",
"numpy>=1.15.0",
"rapidfuzz_capi>=1.0.5,<2.0.0",
]
build-backend = "setuptools.build_meta"

View File

@ -19,7 +19,6 @@ pydantic>=1.7.4,!=1.8,!=1.8.1,<1.10.0
jinja2
langcodes>=3.2.0,<4.0.0
rapidfuzz>=2.4.0,<3.0.0
rapidfuzz_capi>=1.0.5,<2.0.0
# Official Python utilities
setuptools
packaging>=20.0

View File

@ -34,7 +34,6 @@ python_requires = >=3.6
setup_requires =
cython>=0.25,<3.0
numpy>=1.15.0
rapidfuzz_capi>=1.0.5,<2.0.0
# We also need our Cython packages here to compile against
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
@ -60,7 +59,6 @@ install_requires =
pydantic>=1.7.4,!=1.8,!=1.8.1,<1.10.0
jinja2
rapidfuzz>=2.4.0,<3.0.0
rapidfuzz_capi>=1.0.5,<2.0.0
# Official Python utilities
setuptools
packaging>=20.0

View File

@ -71,8 +71,6 @@ cdef class Matcher:
cdef vector[TokenPatternC*] patterns
cdef readonly Vocab vocab
cdef public object validate
cdef public object fuzzy
cdef public object fuzzy_attrs
cdef public object _patterns
cdef public object _callbacks
cdef public object _filter

View File

@ -5,8 +5,7 @@ from ..vocab import Vocab
from ..tokens import Doc, Span
class Matcher:
def __init__(self, vocab: Vocab, validate: bool = ...,
fuzzy: float = ..., fuzzy_attrs: list = ...) -> None: ...
def __init__(self, vocab: Vocab, validate: bool = ...) -> None: ...
def __reduce__(self) -> Any: ...
def __len__(self) -> int: ...
def __contains__(self, key: str) -> bool: ...

View File

@ -10,7 +10,7 @@ from murmurhash.mrmr cimport hash64
import re
import srsly
import warnings
from rapidfuzz import fuzz_cpp
from rapidfuzz.distance import Levenshtein
from ..typedefs cimport attr_t
from ..structs cimport TokenC
@ -37,7 +37,7 @@ cdef class Matcher:
USAGE: https://spacy.io/usage/rule-based-matching
"""
def __init__(self, vocab, validate=True, fuzzy=None, fuzzy_attrs=None):
def __init__(self, vocab, validate=True):
"""Create the Matcher.
vocab (Vocab): The vocabulary object, which must be shared with the
@ -52,8 +52,6 @@ cdef class Matcher:
self.vocab = vocab
self.mem = Pool()
self.validate = validate
self.fuzzy = fuzzy if fuzzy is not None else 0
self.fuzzy_attrs = [IDS.get(attr) for attr in fuzzy_attrs] if fuzzy_attrs else []
def __reduce__(self):
data = (self.vocab, self._patterns, self._callbacks)
@ -131,8 +129,7 @@ cdef class Matcher:
for pattern in patterns:
try:
specs = _preprocess_pattern(pattern, self.vocab,
self._extensions, self._extra_predicates,
self.fuzzy, self.fuzzy_attrs)
self._extensions, self._extra_predicates)
self.patterns.push_back(init_pattern(self.mem, key, specs))
for spec in specs:
for attr, _ in spec[1]:
@ -257,8 +254,7 @@ cdef class Matcher:
matches = []
else:
matches = find_matches(&self.patterns[0], self.patterns.size(), doclike, length,
extensions=self._extensions, predicates=self._extra_predicates, with_alignments=with_alignments,
fuzzy=self.fuzzy, fuzzy_attrs=self.fuzzy_attrs)
extensions=self._extensions, predicates=self._extra_predicates, with_alignments=with_alignments)
final_matches = []
pairs_by_id = {}
# For each key, either add all matches, or only the filtered,
@ -339,8 +335,7 @@ def unpickle_matcher(vocab, patterns, callbacks):
return matcher
cdef find_matches(TokenPatternC** patterns, int n, object doclike, int length, extensions=None, predicates=tuple(), bint with_alignments=0,
float fuzzy=0, list fuzzy_attrs=[]):
cdef find_matches(TokenPatternC** patterns, int n, object doclike, int length, extensions=None, predicates=tuple(), bint with_alignments=0):
"""Find matches in a doc, with a compiled array of patterns. Matches are
returned as a list of (id, start, end) tuples or (id, start, end, alignments) tuples (if with_alignments != 0)
@ -359,8 +354,6 @@ cdef find_matches(TokenPatternC** patterns, int n, object doclike, int length, e
cdef PatternStateC state
cdef int i, j, nr_extra_attr
cdef Pool mem = Pool()
cdef int8_t* fuzzy_attrs_array
cdef int n_fuzzy_attrs = len(fuzzy_attrs)
output = []
if length == 0:
@ -380,10 +373,6 @@ cdef find_matches(TokenPatternC** patterns, int n, object doclike, int length, e
if isinstance(value, str):
value = token.vocab.strings[value]
extra_attr_values[i * nr_extra_attr + index] = value
if n_fuzzy_attrs > 0:
fuzzy_attrs_array = <int8_t*>mem.alloc(n_fuzzy_attrs, sizeof(int8_t))
for i in range(n_fuzzy_attrs):
fuzzy_attrs_array[i] = fuzzy_attrs[i]
# Main loop
cdef int nr_predicate = len(predicates)
for i in range(length):
@ -392,8 +381,7 @@ cdef find_matches(TokenPatternC** patterns, int n, object doclike, int length, e
if with_alignments != 0:
align_states.resize(states.size())
transition_states(states, matches, align_states, align_matches, predicate_cache,
doclike[i], extra_attr_values, predicates, with_alignments,
fuzzy, fuzzy_attrs_array, n_fuzzy_attrs)
doclike[i], extra_attr_values, predicates, with_alignments)
extra_attr_values += nr_extra_attr
predicate_cache += len(predicates)
# Handle matches that end in 0-width patterns
@ -422,8 +410,7 @@ cdef find_matches(TokenPatternC** patterns, int n, object doclike, int length, e
cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& matches,
vector[vector[MatchAlignmentC]]& align_states, vector[vector[MatchAlignmentC]]& align_matches,
int8_t* cached_py_predicates,
Token token, const attr_t* extra_attrs, py_predicates, bint with_alignments,
float fuzzy, int8_t* fuzzy_attrs, int n_fuzzy_attrs) except *:
Token token, const attr_t* extra_attrs, py_predicates, bint with_alignments) except *:
cdef int q = 0
cdef vector[PatternStateC] new_states
cdef vector[vector[MatchAlignmentC]] align_new_states
@ -433,8 +420,7 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
update_predicate_cache(cached_py_predicates,
states[i].pattern, token, py_predicates)
action = get_action(states[i], token, extra_attrs,
cached_py_predicates,
fuzzy, fuzzy_attrs, n_fuzzy_attrs)
cached_py_predicates)
if action == REJECT:
continue
# Keep only a subset of states (the active ones). Index q is the
@ -471,8 +457,7 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
update_predicate_cache(cached_py_predicates,
states[q].pattern, token, py_predicates)
action = get_action(states[q], token, extra_attrs,
cached_py_predicates,
fuzzy, fuzzy_attrs, n_fuzzy_attrs)
cached_py_predicates)
# Update alignment before the transition of current state
if with_alignments != 0:
align_states[q].push_back(MatchAlignmentC(states[q].pattern.token_idx, states[q].length))
@ -584,8 +569,7 @@ cdef void finish_states(vector[MatchC]& matches, vector[PatternStateC]& states,
cdef action_t get_action(PatternStateC state,
Token token, const attr_t* extra_attrs,
const int8_t* predicate_matches,
float fuzzy, int8_t* fuzzy_attrs, int n_fuzzy_attrs) nogil:
const int8_t* predicate_matches) nogil:
"""We need to consider:
a) Does the token match the specification? [Yes, No]
b) What's the quantifier? [1, 0+, ?]
@ -644,8 +628,7 @@ cdef action_t get_action(PatternStateC state,
Problem: If a quantifier is matching, we're adding a lot of open partials
"""
cdef int8_t is_match
is_match = get_is_match(state, token, extra_attrs, predicate_matches,
fuzzy, fuzzy_attrs, n_fuzzy_attrs)
is_match = get_is_match(state, token, extra_attrs, predicate_matches)
quantifier = get_quantifier(state)
is_final = get_is_final(state)
if quantifier == ZERO:
@ -698,8 +681,7 @@ cdef action_t get_action(PatternStateC state,
cdef int8_t get_is_match(PatternStateC state,
Token token, const attr_t* extra_attrs,
const int8_t* predicate_matches,
float fuzzy, int8_t* fuzzy_attrs, int n_fuzzy_attrs) nogil:
const int8_t* predicate_matches) nogil:
for i in range(state.pattern.nr_py):
if predicate_matches[state.pattern.py_predicates[i]] == -1:
return 0
@ -708,22 +690,9 @@ cdef int8_t get_is_match(PatternStateC state,
for attr in spec.attrs[:spec.nr_attr]:
token_attr_value = get_token_attr_for_matcher(token.c, attr.attr)
if token_attr_value != attr.value:
if fuzzy:
fuzzy_match = False
for i in range(n_fuzzy_attrs):
if attr.attr == fuzzy_attrs[i]:
with gil:
if fuzz_cpp.ratio(token.vocab.strings[token_attr_value],
token.vocab.strings[attr.value]) >= fuzzy:
fuzzy_match = True
break
if not fuzzy_match:
return 0
else:
return 0
return 0
for i in range(spec.nr_extra_attr):
if spec.extra_attrs[i].value != extra_attrs[spec.extra_attrs[i].index]:
# TODO: fuzzy match
return 0
return True
@ -788,8 +757,7 @@ cdef attr_t get_ent_id(const TokenPatternC* pattern) nogil:
return id_attr.value
def _preprocess_pattern(token_specs, vocab, extensions_table, extra_predicates,
fuzzy, fuzzy_attrs):
def _preprocess_pattern(token_specs, vocab, extensions_table, extra_predicates):
"""This function interprets the pattern, converting the various bits of
syntactic sugar before we compile it into a struct with init_pattern.
@ -816,7 +784,7 @@ def _preprocess_pattern(token_specs, vocab, extensions_table, extra_predicates,
ops = _get_operators(spec)
attr_values = _get_attr_values(spec, string_store)
extensions = _get_extensions(spec, string_store, extensions_table)
predicates = _get_extra_predicates(spec, extra_predicates, vocab, fuzzy, fuzzy_attrs)
predicates = _get_extra_predicates(spec, extra_predicates, vocab)
for op in ops:
tokens.append((op, list(attr_values), list(extensions), list(predicates), token_idx))
return tokens
@ -862,31 +830,31 @@ def _get_attr_values(spec, string_store):
# extensions to the matcher introduced in #3173.
class _FuzzyPredicate:
operators = ("FUZZY",)
operators = ("FUZZY1", "FUZZY2", "FUZZY3", "FUZZY4", "FUZZY5")
def __init__(self, i, attr, value, predicate, is_extension=False, vocab=None, fuzzy=None):
def __init__(self, i, attr, value, predicate, is_extension=False, vocab=None, distance=None):
self.i = i
self.attr = attr
self.value = value
self.predicate = predicate
self.is_extension = is_extension
self.fuzzy = fuzzy
self.key = (attr, self.predicate, srsly.json_dumps(value, sort_keys=True))
if self.predicate not in self.operators:
raise ValueError(Errors.E126.format(good=self.operators, bad=self.predicate))
self.distance = int(self.predicate[len('FUZZY'):]) # number after prefix
def __call__(self, Token token):
if self.is_extension:
value = token._.get(self.attr)
else:
value = token.vocab.strings[get_token_attr_for_matcher(token.c, self.attr)]
return bool(self.fuzzy and fuzz_cpp.ratio(self.value, value) >= self.fuzzy)
return bool(Levenshtein.distance(self.value, value) <= self.distance)
class _RegexPredicate:
operators = ("REGEX",)
def __init__(self, i, attr, value, predicate, is_extension=False, vocab=None, fuzzy=None):
def __init__(self, i, attr, value, predicate, is_extension=False, vocab=None, distance=None):
self.i = i
self.attr = attr
self.value = re.compile(value)
@ -907,22 +875,22 @@ class _RegexPredicate:
class _SetPredicate:
operators = ("IN", "NOT_IN", "IS_SUBSET", "IS_SUPERSET", "INTERSECTS")
def __init__(self, i, attr, value, predicate, is_extension=False, vocab=None, fuzzy=None):
def __init__(self, i, attr, value, predicate, is_extension=False, vocab=None, distance=None):
self.i = i
self.attr = attr
self.vocab = vocab
self.distance = distance
if self.attr == MORPH:
# normalize morph strings
self.value = set(self.vocab.morphology.add(v) for v in value)
else:
if fuzzy:
if self.distance:
# add to string store
self.value = set(self.vocab.strings.add(v) for v in value)
else:
self.value = set(get_string_id(v) for v in value)
self.predicate = predicate
self.is_extension = is_extension
self.fuzzy = fuzzy
self.key = (attr, self.predicate, srsly.json_dumps(value, sort_keys=True))
if self.predicate not in self.operators:
raise ValueError(Errors.E126.format(good=self.operators, bad=self.predicate))
@ -946,19 +914,19 @@ class _SetPredicate:
if self.predicate == "IN":
if value in self.value:
return True
elif self.fuzzy:
elif self.distance:
for v in self.value:
if fuzz_cpp.ratio(self.vocab.strings[value],
self.vocab.strings[v]) >= self.fuzzy:
if Levenshtein.distance(self.vocab.strings[value],
self.vocab.strings[v]) <= self.distance:
return True
return False
elif self.predicate == "NOT_IN":
if value in self.value:
return False
elif self.fuzzy:
elif self.distance:
for v in self.value:
if fuzz_cpp.ratio(self.vocab.strings[value],
self.vocab.strings[v]) >= self.fuzzy:
if Levenshtein.distance(self.vocab.strings[value],
self.vocab.strings[v]) <= self.distance:
return False
return True
elif self.predicate == "IS_SUBSET":
@ -975,7 +943,7 @@ class _SetPredicate:
class _ComparisonPredicate:
operators = ("==", "!=", ">=", "<=", ">", "<")
def __init__(self, i, attr, value, predicate, is_extension=False, vocab=None, fuzzy=None):
def __init__(self, i, attr, value, predicate, is_extension=False, vocab=None, distance=None):
self.i = i
self.attr = attr
self.value = value
@ -1004,7 +972,7 @@ class _ComparisonPredicate:
return value < self.value
def _get_extra_predicates(spec, extra_predicates, vocab, fuzzy, fuzzy_attrs):
def _get_extra_predicates(spec, extra_predicates, vocab):
predicate_types = {
"REGEX": _RegexPredicate,
"IN": _SetPredicate,
@ -1018,7 +986,11 @@ def _get_extra_predicates(spec, extra_predicates, vocab, fuzzy, fuzzy_attrs):
"<=": _ComparisonPredicate,
">": _ComparisonPredicate,
"<": _ComparisonPredicate,
"FUZZY": _FuzzyPredicate,
"FUZZY1": _FuzzyPredicate,
"FUZZY2": _FuzzyPredicate,
"FUZZY3": _FuzzyPredicate,
"FUZZY4": _FuzzyPredicate,
"FUZZY5": _FuzzyPredicate,
}
seen_predicates = {pred.key: pred.i for pred in extra_predicates}
output = []
@ -1037,33 +1009,30 @@ def _get_extra_predicates(spec, extra_predicates, vocab, fuzzy, fuzzy_attrs):
attr = IDS.get(attr.upper())
if isinstance(value, dict):
fuzzy_match = attr in fuzzy_attrs # fuzzy match enabled for this attr
output.extend(_get_extra_predicates_dict(attr, value, vocab, fuzzy, fuzzy_match,
predicate_types,
output.extend(_get_extra_predicates_dict(attr, value, vocab, predicate_types,
extra_predicates, seen_predicates))
return output
def _get_extra_predicates_dict(attr, value_dict, vocab, fuzzy, fuzzy_match,
predicate_types, extra_predicates, seen_predicates):
def _get_extra_predicates_dict(attr, value_dict, vocab, predicate_types,
extra_predicates, seen_predicates, distance=None):
output = []
for type_, value in value_dict.items():
type_ = type_.upper()
if type_ == 'FUZZY':
fuzzy_match = True # explicit fuzzy match
if isinstance(value, dict):
# add predicates inside fuzzy operator
output.extend(_get_extra_predicates_dict(attr, value, vocab, fuzzy, fuzzy_match,
predicate_types,
extra_predicates, seen_predicates))
continue
cls = predicate_types.get(type_)
if cls is None:
warnings.warn(Warnings.W035.format(pattern=value_dict))
# ignore unrecognized predicate type
continue
predicate = cls(len(extra_predicates), attr, value, type_, vocab=vocab,
fuzzy=fuzzy if fuzzy_match else 0)
elif cls == _FuzzyPredicate:
distance = int(type_[len("FUZZY"):]) # number after prefix
if isinstance(value, dict):
# add predicates inside fuzzy operator
output.extend(_get_extra_predicates_dict(attr, value, vocab, predicate_types,
extra_predicates, seen_predicates,
distance=distance))
continue
predicate = cls(len(extra_predicates), attr, value, type_, vocab=vocab, distance=distance)
# Don't create a redundant predicates.
# This helps with efficiency, as we're caching the results.
if predicate.key in seen_predicates:

View File

@ -26,7 +26,6 @@ PatternType = Dict[str, Union[str, List[Dict[str, Any]]]]
"phrase_matcher_attr": None,
"validate": False,
"overwrite_ents": False,
"fuzzy": 0.0,
"ent_id_sep": DEFAULT_ENT_ID_SEP,
"scorer": {"@scorers": "spacy.entity_ruler_scorer.v1"},
},
@ -43,7 +42,6 @@ def make_entity_ruler(
phrase_matcher_attr: Optional[Union[int, str]],
validate: bool,
overwrite_ents: bool,
fuzzy: float,
ent_id_sep: str,
scorer: Optional[Callable],
):
@ -53,7 +51,6 @@ def make_entity_ruler(
phrase_matcher_attr=phrase_matcher_attr,
validate=validate,
overwrite_ents=overwrite_ents,
fuzzy=fuzzy,
ent_id_sep=ent_id_sep,
scorer=scorer,
)
@ -87,7 +84,6 @@ class EntityRuler(Pipe):
phrase_matcher_attr: Optional[Union[int, str]] = None,
validate: bool = False,
overwrite_ents: bool = False,
fuzzy: float = 0,
ent_id_sep: str = DEFAULT_ENT_ID_SEP,
patterns: Optional[List[PatternType]] = None,
scorer: Optional[Callable] = entity_ruler_score,
@ -122,8 +118,7 @@ class EntityRuler(Pipe):
self.token_patterns = defaultdict(list) # type: ignore
self.phrase_patterns = defaultdict(list) # type: ignore
self._validate = validate
self.fuzzy = fuzzy
self.matcher = Matcher(nlp.vocab, validate=validate, fuzzy=self.fuzzy)
self.matcher = Matcher(nlp.vocab, validate=validate)
self.phrase_matcher_attr = phrase_matcher_attr
self.phrase_matcher = PhraseMatcher(
nlp.vocab, attr=self.phrase_matcher_attr, validate=validate
@ -343,7 +338,7 @@ class EntityRuler(Pipe):
self.token_patterns = defaultdict(list)
self.phrase_patterns = defaultdict(list)
self._ent_ids = defaultdict(tuple)
self.matcher = Matcher(self.nlp.vocab, validate=self._validate, fuzzy=self.fuzzy)
self.matcher = Matcher(self.nlp.vocab, validate=self._validate)
self.phrase_matcher = PhraseMatcher(
self.nlp.vocab, attr=self.phrase_matcher_attr, validate=self._validate
)

View File

@ -157,12 +157,16 @@ def validate_token_pattern(obj: list) -> List[str]:
class TokenPatternString(BaseModel):
REGEX: Optional[StrictStr] = Field(None, alias="regex")
FUZZY: Union[StrictStr, "TokenPatternString"] = Field(None, alias="fuzzy")
IN: Optional[List[StrictStr]] = Field(None, alias="in")
NOT_IN: Optional[List[StrictStr]] = Field(None, alias="not_in")
IS_SUBSET: Optional[List[StrictStr]] = Field(None, alias="is_subset")
IS_SUPERSET: Optional[List[StrictStr]] = Field(None, alias="is_superset")
INTERSECTS: Optional[List[StrictStr]] = Field(None, alias="intersects")
FUZZY1: Union[StrictStr, "TokenPatternString"] = Field(None, alias="fuzzy1")
FUZZY2: Union[StrictStr, "TokenPatternString"] = Field(None, alias="fuzzy2")
FUZZY3: Union[StrictStr, "TokenPatternString"] = Field(None, alias="fuzzy3")
FUZZY4: Union[StrictStr, "TokenPatternString"] = Field(None, alias="fuzzy4")
FUZZY5: Union[StrictStr, "TokenPatternString"] = Field(None, alias="fuzzy5")
class Config:
extra = "forbid"
@ -177,7 +181,6 @@ class TokenPatternString(BaseModel):
class TokenPatternNumber(BaseModel):
REGEX: Optional[StrictStr] = Field(None, alias="regex")
FUZZY: Optional[StrictStr] = Field(None, alias="fuzzy")
IN: Optional[List[StrictInt]] = Field(None, alias="in")
NOT_IN: Optional[List[StrictInt]] = Field(None, alias="not_in")
IS_SUBSET: Optional[List[StrictInt]] = Field(None, alias="is_subset")
@ -189,6 +192,11 @@ class TokenPatternNumber(BaseModel):
LEQ: Union[StrictInt, StrictFloat] = Field(None, alias="<=")
GT: Union[StrictInt, StrictFloat] = Field(None, alias=">")
LT: Union[StrictInt, StrictFloat] = Field(None, alias="<")
FUZZY1: Optional[StrictStr] = Field(None, alias="fuzzy1")
FUZZY2: Optional[StrictStr] = Field(None, alias="fuzzy2")
FUZZY3: Optional[StrictStr] = Field(None, alias="fuzzy3")
FUZZY4: Optional[StrictStr] = Field(None, alias="fuzzy4")
FUZZY5: Optional[StrictStr] = Field(None, alias="fuzzy5")
class Config:
extra = "forbid"

View File

@ -6,16 +6,15 @@ from spacy.tokens import Doc, Token, Span
from ..doc.test_underscore import clean_underscore # noqa: F401
matcher_rules = {
"JS": [[{"ORTH": "JavaScript"}]],
"GoogleNow": [[{"ORTH": "Google"}, {"ORTH": "Now"}]],
"Java": [[{"LOWER": "java"}]],
}
@pytest.fixture
def matcher(en_vocab):
rules = {
"JS": [[{"ORTH": "JavaScript"}]],
"GoogleNow": [[{"ORTH": "Google"}, {"ORTH": "Now"}]],
"Java": [[{"LOWER": "java"}]],
}
matcher = Matcher(en_vocab)
for key, patterns in matcher_rules.items():
for key, patterns in rules.items():
matcher.add(key, patterns)
return matcher
@ -119,98 +118,51 @@ def test_matcher_match_multi(matcher):
]
# fuzzy matches on specific attributes
def test_matcher_match_fuzz_all(en_vocab):
matcher = Matcher(en_vocab, fuzzy=80, fuzzy_attrs=["ORTH", "LOWER"])
for key, patterns in matcher_rules.items():
matcher.add(key, patterns)
words = ["I", "like", "Goggle", "Now", "and", "Jav", "but", "not", "JvvaScrpt"]
doc = Doc(matcher.vocab, words=words)
assert matcher(doc) == [
(doc.vocab.strings["GoogleNow"], 2, 4),
(doc.vocab.strings["Java"], 5, 6),
(doc.vocab.strings["JS"], 8, 9),
]
def test_matcher_match_fuzz_all_lower(en_vocab):
matcher = Matcher(en_vocab, fuzzy=80, fuzzy_attrs=["LOWER"])
for key, patterns in matcher_rules.items():
matcher.add(key, patterns)
words = ["I", "like", "Goggle", "Now", "and", "Jav", "but", "not", "JvvaScrpt"]
doc = Doc(matcher.vocab, words=words)
assert matcher(doc) == [
(doc.vocab.strings["Java"], 5, 6),
]
def test_matcher_match_fuzz_some(en_vocab):
matcher = Matcher(en_vocab, fuzzy=85, fuzzy_attrs=["ORTH", "LOWER"])
for key, patterns in matcher_rules.items():
matcher.add(key, patterns)
words = ["I", "like", "Goggle", "Now", "and", "Jav", "but", "not", "JvvaScrpt"]
doc = Doc(matcher.vocab, words=words)
assert matcher(doc) == [
(doc.vocab.strings["Java"], 5, 6),
]
def test_matcher_match_fuzz_none(en_vocab):
matcher = Matcher(en_vocab, fuzzy=90, fuzzy_attrs=["ORTH", "LOWER"])
for key, patterns in matcher_rules.items():
matcher.add(key, patterns)
words = ["I", "like", "Goggle", "Now", "and", "Jav", "but", "not", "JvvaScrpt"]
doc = Doc(matcher.vocab, words=words)
assert matcher(doc) == []
# fuzzy matches on specific tokens
def test_matcher_match_fuzz_pred1(en_vocab):
def test_matcher_match_fuzzy1(en_vocab):
rules = {
"JS": [[{"ORTH": "JavaScript"}]],
"GoogleNow": [[{"ORTH": {"FUZZY": "Google"}}, {"ORTH": "Now"}]],
"GoogleNow": [[{"ORTH": {"FUZZY1": "Google"}}, {"ORTH": "Now"}]],
"Java": [[{"LOWER": "java"}]],
}
matcher = Matcher(en_vocab, fuzzy=80)
matcher = Matcher(en_vocab)
for key, patterns in rules.items():
matcher.add(key, patterns)
words = ["I", "like", "Goggle", "Now", "and", "Jav", "but", "not", "JvvaScrpt"]
words = ["They", "like", "Goggle", "Now", "and", "Jav", "but", "not", "JvvaScrpt"]
doc = Doc(matcher.vocab, words=words)
assert matcher(doc) == [
(doc.vocab.strings["GoogleNow"], 2, 4),
]
def test_matcher_match_fuzz_pred2(en_vocab):
def test_matcher_match_fuzzy2(en_vocab):
rules = {
"JS": [[{"ORTH": "JavaScript"}]],
"GoogleNow": [[{"ORTH": "Google"}, {"ORTH": "Now"}]],
"Java": [[{"LOWER": {"FUZZY": "java"}}]],
"Java": [[{"LOWER": {"FUZZY1": "java"}}]],
}
matcher = Matcher(en_vocab, fuzzy=80)
matcher = Matcher(en_vocab)
for key, patterns in rules.items():
matcher.add(key, patterns)
words = ["I", "like", "Goggle", "Now", "and", "Jav", "but", "not", "JvvaScrpt"]
words = ["They", "like", "Goggle", "Now", "and", "Jav", "but", "not", "JvvaScrpt"]
doc = Doc(matcher.vocab, words=words)
assert matcher(doc) == [
(doc.vocab.strings["Java"], 5, 6),
]
def test_matcher_match_fuzz_preds(en_vocab):
def test_matcher_match_fuzzy3(en_vocab):
rules = {
"JS": [[{"ORTH": {"FUZZY": "JavaScript"}}]],
"GoogleNow": [[{"ORTH": {"FUZZY": "Google"}}, {"ORTH": "Now"}]],
"Java": [[{"LOWER": {"FUZZY": "java"}}]],
"JS": [[{"ORTH": {"FUZZY2": "JavaScript"}}]],
"GoogleNow": [[{"ORTH": {"FUZZY1": "Google"}}, {"ORTH": "Now"}]],
"Java": [[{"LOWER": {"FUZZY1": "java"}}]],
}
matcher = Matcher(en_vocab, fuzzy=80)
matcher = Matcher(en_vocab)
for key, patterns in rules.items():
matcher.add(key, patterns)
words = ["I", "like", "Goggle", "Now", "and", "Jav", "but", "not", "JvvaScrpt"]
words = ["They", "like", "Goggle", "Now", "and", "Jav", "but", "not", "JvvaScrpt"]
doc = Doc(matcher.vocab, words=words)
assert matcher(doc) == [
(doc.vocab.strings["GoogleNow"], 2, 4),
@ -218,45 +170,45 @@ def test_matcher_match_fuzz_preds(en_vocab):
(doc.vocab.strings["JS"], 8, 9),
]
def test_matcher_match_fuzz_pred_in_set(en_vocab):
def test_matcher_match_fuzzy_set1(en_vocab):
rules = {
"GoogleNow": [[{"ORTH": {"FUZZY": {"IN": ["Google", "No"]}}, "OP": "+"}]]
"GoogleNow": [[{"ORTH": {"FUZZY2": {"IN": ["Google", "No"]}}, "OP": "+"}]]
}
matcher = Matcher(en_vocab, fuzzy=80)
matcher = Matcher(en_vocab)
for key, patterns in rules.items():
matcher.add(key, patterns, greedy="LONGEST")
words = ["I", "like", "Goggle", "Now"]
words = ["They", "like", "Goggle", "Now"]
doc = Doc(matcher.vocab, words=words)
assert matcher(doc) == [
(doc.vocab.strings["GoogleNow"], 2, 4),
]
def test_matcher_match_fuzz_pred_not_in_set(en_vocab):
def test_matcher_match_fuzzy_set2(en_vocab):
rules = {
"GoogleNow": [[{"ORTH": {"FUZZY": {"NOT_IN": ["Google", "No"]}}, "OP": "+"}]],
"GoogleNow": [[{"ORTH": {"FUZZY2": {"NOT_IN": ["Google", "No"]}}, "OP": "+"}]],
}
matcher = Matcher(en_vocab, fuzzy=80)
matcher = Matcher(en_vocab)
for key, patterns in rules.items():
matcher.add(key, patterns, greedy="LONGEST")
words = ["I", "like", "Goggle", "Now"]
words = ["They", "like", "Goggle", "Now"]
doc = Doc(matcher.vocab, words=words)
assert matcher(doc) == [
(doc.vocab.strings["GoogleNow"], 0, 2),
]
def test_matcher_match_fuzz_pred_in_set_with_exclude(en_vocab):
def test_matcher_match_fuzzy_set3(en_vocab):
rules = {
"GoogleNow": [[{"ORTH": {"FUZZY": {"IN": ["Google", "No"]},
"GoogleNow": [[{"ORTH": {"FUZZY1": {"IN": ["Google", "No"]},
"NOT_IN": ["Goggle"]},
"OP": "+"}]]
}
matcher = Matcher(en_vocab, fuzzy=80)
matcher = Matcher(en_vocab)
for key, patterns in rules.items():
matcher.add(key, patterns, greedy="LONGEST")
words = ["I", "like", "Goggle", "Now"]
words = ["They", "like", "Goggle", "Now"]
doc = Doc(matcher.vocab, words=words)
assert matcher(doc) == [
(doc.vocab.strings["GoogleNow"], 3, 4),