Fix matcher/.

Raphael Mitsch 2023-07-03 12:20:04 +02:00
parent 8f59eeb772
commit e7cf6c7d9b
3 changed files with 151 additions and 98 deletions

spacy/matcher/dependencymatcher.pyx

@@ -108,7 +108,7 @@ cdef class DependencyMatcher:
key (str): The match ID.
RETURNS (bool): Whether the matcher contains rules for this match ID.
"""
return self.has_key(key)
return self.has_key(key) # no-cython-lint: W601
def _validate_input(self, pattern, key):
idx = 0
@@ -264,7 +264,7 @@ cdef class DependencyMatcher:
def remove(self, key):
key = self._normalize_key(key)
if not key in self._patterns:
if key not in self._patterns:
raise ValueError(Errors.E175.format(key=key))
self._patterns.pop(key)
self._raw_patterns.pop(key)
@@ -382,7 +382,7 @@ cdef class DependencyMatcher:
return []
return [doc[node].head]
def _gov(self,doc,node):
def _gov(self, doc, node):
return list(doc[node].children)
def _dep_chain(self, doc, node):

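The `remove` change above is purely stylistic (`key not in` over `not key in`); the lookup and the E175 error are unchanged. A minimal sketch of the affected paths via the public spaCy API (the pattern itself is illustrative, not part of this commit):

    import spacy
    from spacy.matcher import DependencyMatcher

    nlp = spacy.blank("en")
    matcher = DependencyMatcher(nlp.vocab)
    # Single-node dependency pattern anchored on the token "founded".
    pattern = [{"RIGHT_ID": "anchor", "RIGHT_ATTRS": {"ORTH": "founded"}}]
    matcher.add("FOUNDED", [pattern])
    assert "FOUNDED" in matcher  # __contains__ -> has_key, as patched above
    matcher.remove("FOUNDED")
    # Removing an unknown key takes the `key not in self._patterns` branch
    # and raises ValueError(Errors.E175):
    # matcher.remove("FOUNDED")  # would raise ValueError
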
spacy/matcher/matcher.pyx

@@ -19,10 +19,8 @@ from ..attrs cimport (
LEMMA,
MORPH,
NULL_ATTR,
ORTH,
POS,
TAG,
attr_id_t,
)
from ..structs cimport TokenC
from ..tokens.doc cimport Doc, get_token_attr_for_matcher
@@ -30,13 +28,11 @@ from ..tokens.morphanalysis cimport MorphAnalysis
from ..tokens.span cimport Span
from ..tokens.token cimport Token
from ..typedefs cimport attr_t
from ..vocab cimport Vocab
from ..attrs import IDS
from ..errors import Errors, MatchPatternError, Warnings
from ..schemas import validate_token_pattern
from ..strings import get_string_id
from ..util import registry
from .levenshtein import levenshtein_compare
DEF PADDING = 5
@@ -87,9 +83,9 @@ cdef class Matcher:
key (str): The match ID.
RETURNS (bool): Whether the matcher contains rules for this match ID.
"""
return self.has_key(key)
return self.has_key(key) # no-cython-lint: W601
def add(self, key, patterns, *, on_match=None, greedy: str=None):
def add(self, key, patterns, *, on_match=None, greedy: str = None):
"""Add a match-rule to the matcher. A match-rule consists of: an ID
key, an on_match callback, and one or more patterns.
@@ -143,8 +139,13 @@ cdef class Matcher:
key = self._normalize_key(key)
for pattern in patterns:
try:
specs = _preprocess_pattern(pattern, self.vocab,
self._extensions, self._extra_predicates, self._fuzzy_compare)
specs = _preprocess_pattern(
pattern,
self.vocab,
self._extensions,
self._extra_predicates,
self._fuzzy_compare
)
self.patterns.push_back(init_pattern(self.mem, key, specs))
for spec in specs:
for attr, _ in spec[1]:
@@ -168,7 +169,7 @@ cdef class Matcher:
key (str): The ID of the match rule.
"""
norm_key = self._normalize_key(key)
if not norm_key in self._patterns:
if norm_key not in self._patterns:
raise ValueError(Errors.E175.format(key=key))
self._patterns.pop(norm_key)
self._callbacks.pop(norm_key)
@@ -268,8 +269,15 @@ cdef class Matcher:
if self.patterns.empty():
matches = []
else:
matches = find_matches(&self.patterns[0], self.patterns.size(), doclike, length,
extensions=self._extensions, predicates=self._extra_predicates, with_alignments=with_alignments)
matches = find_matches(
&self.patterns[0],
self.patterns.size(),
doclike,
length,
extensions=self._extensions,
predicates=self._extra_predicates,
with_alignments=with_alignments
)
final_matches = []
pairs_by_id = {}
# For each key, either add all matches, or only the filtered,
@@ -289,9 +297,9 @@ cdef class Matcher:
memset(matched, 0, length * sizeof(matched[0]))
span_filter = self._filter.get(key)
if span_filter == "FIRST":
sorted_pairs = sorted(pairs, key=lambda x: (x[0], -x[1]), reverse=False) # sort by start
sorted_pairs = sorted(pairs, key=lambda x: (x[0], -x[1]), reverse=False) # sort by start
elif span_filter == "LONGEST":
sorted_pairs = sorted(pairs, key=lambda x: (x[1]-x[0], -x[0]), reverse=True) # reverse sort by length
sorted_pairs = sorted(pairs, key=lambda x: (x[1]-x[0], -x[0]), reverse=True) # reverse sort by length
else:
raise ValueError(Errors.E947.format(expected=["FIRST", "LONGEST"], arg=span_filter))
for match in sorted_pairs:
@@ -366,7 +374,6 @@ cdef find_matches(TokenPatternC** patterns, int n, object doclike, int length, e
cdef vector[MatchC] matches
cdef vector[vector[MatchAlignmentC]] align_states
cdef vector[vector[MatchAlignmentC]] align_matches
cdef PatternStateC state
cdef int i, j, nr_extra_attr
cdef Pool mem = Pool()
output = []
@@ -388,14 +395,22 @@ cdef find_matches(TokenPatternC** patterns, int n, object doclike, int length, e
value = token.vocab.strings[value]
extra_attr_values[i * nr_extra_attr + index] = value
# Main loop
cdef int nr_predicate = len(predicates)
for i in range(length):
for j in range(n):
states.push_back(PatternStateC(patterns[j], i, 0))
if with_alignments != 0:
align_states.resize(states.size())
transition_states(states, matches, align_states, align_matches, predicate_cache,
doclike[i], extra_attr_values, predicates, with_alignments)
transition_states(
states,
matches,
align_states,
align_matches,
predicate_cache,
doclike[i],
extra_attr_values,
predicates,
with_alignments
)
extra_attr_values += nr_extra_attr
predicate_cache += len(predicates)
# Handle matches that end in 0-width patterns
@@ -421,18 +436,28 @@ cdef find_matches(TokenPatternC** patterns, int n, object doclike, int length, e
return output
cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& matches,
vector[vector[MatchAlignmentC]]& align_states, vector[vector[MatchAlignmentC]]& align_matches,
int8_t* cached_py_predicates,
Token token, const attr_t* extra_attrs, py_predicates, bint with_alignments) except *:
cdef void transition_states(
vector[PatternStateC]& states,
vector[MatchC]& matches,
vector[vector[MatchAlignmentC]]& align_states,
vector[vector[MatchAlignmentC]]& align_matches,
int8_t* cached_py_predicates,
Token token,
const attr_t* extra_attrs,
py_predicates,
bint with_alignments
) except *:
cdef int q = 0
cdef vector[PatternStateC] new_states
cdef vector[vector[MatchAlignmentC]] align_new_states
cdef int nr_predicate = len(py_predicates)
for i in range(states.size()):
if states[i].pattern.nr_py >= 1:
update_predicate_cache(cached_py_predicates,
states[i].pattern, token, py_predicates)
update_predicate_cache(
cached_py_predicates,
states[i].pattern,
token,
py_predicates
)
action = get_action(states[i], token.c, extra_attrs,
cached_py_predicates)
if action == REJECT:
@@ -468,8 +493,12 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
align_new_states.push_back(align_states[q])
states[q].pattern += 1
if states[q].pattern.nr_py != 0:
update_predicate_cache(cached_py_predicates,
states[q].pattern, token, py_predicates)
update_predicate_cache(
cached_py_predicates,
states[q].pattern,
token,
py_predicates
)
action = get_action(states[q], token.c, extra_attrs,
cached_py_predicates)
# Update alignment before the transition of current state
@@ -485,8 +514,12 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
ent_id = get_ent_id(state.pattern)
if action == MATCH:
matches.push_back(
MatchC(pattern_id=ent_id, start=state.start,
length=state.length+1))
MatchC(
pattern_id=ent_id,
start=state.start,
length=state.length+1
)
)
# `align_matches` always corresponds to `matches` 1:1
if with_alignments != 0:
align_matches.push_back(align_states[q])
@@ -494,23 +527,35 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
# push match without last token if length > 0
if state.length > 0:
matches.push_back(
MatchC(pattern_id=ent_id, start=state.start,
length=state.length))
MatchC(
pattern_id=ent_id,
start=state.start,
length=state.length
)
)
# MATCH_DOUBLE emits matches twice,
# add one more to align_matches in order to keep 1:1 relationship
if with_alignments != 0:
align_matches.push_back(align_states[q])
# push match with last token
matches.push_back(
MatchC(pattern_id=ent_id, start=state.start,
length=state.length+1))
MatchC(
pattern_id=ent_id,
start=state.start,
length=state.length + 1
)
)
# `align_matches` always corresponds to `matches` 1:1
if with_alignments != 0:
align_matches.push_back(align_states[q])
elif action == MATCH_REJECT:
matches.push_back(
MatchC(pattern_id=ent_id, start=state.start,
length=state.length))
MatchC(
pattern_id=ent_id,
start=state.start,
length=state.length
)
)
# `align_matches` always corresponds to `matches` 1:1
if with_alignments != 0:
align_matches.push_back(align_states[q])
@@ -533,8 +578,12 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
align_states.push_back(align_new_states[i])
cdef int update_predicate_cache(int8_t* cache,
const TokenPatternC* pattern, Token token, predicates) except -1:
cdef int update_predicate_cache(
int8_t* cache,
const TokenPatternC* pattern,
Token token,
predicates
) except -1:
# If the state references any extra predicates, check whether they match.
# These are cached, so that we don't call these potentially expensive
# Python functions more than we need to.
@@ -580,10 +629,12 @@ cdef void finish_states(vector[MatchC]& matches, vector[PatternStateC]& states,
else:
state.pattern += 1
cdef action_t get_action(PatternStateC state,
const TokenC* token, const attr_t* extra_attrs,
const int8_t* predicate_matches) nogil:
cdef action_t get_action(
PatternStateC state,
const TokenC * token,
const attr_t * extra_attrs,
const int8_t * predicate_matches
) nogil:
"""We need to consider:
a) Does the token match the specification? [Yes, No]
b) What's the quantifier? [1, 0+, ?]
@@ -649,53 +700,56 @@ cdef action_t get_action(PatternStateC state,
is_match = not is_match
quantifier = ONE
if quantifier == ONE:
if is_match and is_final:
# Yes, final: 1000
return MATCH
elif is_match and not is_final:
# Yes, non-final: 0100
return ADVANCE
elif not is_match and is_final:
# No, final: 0000
return REJECT
else:
return REJECT
if is_match and is_final:
# Yes, final: 1000
return MATCH
elif is_match and not is_final:
# Yes, non-final: 0100
return ADVANCE
elif not is_match and is_final:
# No, final: 0000
return REJECT
else:
return REJECT
elif quantifier == ZERO_PLUS:
if is_match and is_final:
# Yes, final: 1001
return MATCH_EXTEND
elif is_match and not is_final:
# Yes, non-final: 0011
return RETRY_EXTEND
elif not is_match and is_final:
# No, final 2000 (note: Don't include last token!)
return MATCH_REJECT
else:
# No, non-final 0010
return RETRY
if is_match and is_final:
# Yes, final: 1001
return MATCH_EXTEND
elif is_match and not is_final:
# Yes, non-final: 0011
return RETRY_EXTEND
elif not is_match and is_final:
# No, final 2000 (note: Don't include last token!)
return MATCH_REJECT
else:
# No, non-final 0010
return RETRY
elif quantifier == ZERO_ONE:
if is_match and is_final:
# Yes, final: 3000
# To cater for a pattern ending in "?", we need to add
# a match both with and without the last token
return MATCH_DOUBLE
elif is_match and not is_final:
# Yes, non-final: 0110
# We need both branches here, consider a pair like:
# pattern: .?b string: b
# If we 'ADVANCE' on the .?, we miss the match.
return RETRY_ADVANCE
elif not is_match and is_final:
# No, final 2000 (note: Don't include last token!)
return MATCH_REJECT
else:
# No, non-final 0010
return RETRY
if is_match and is_final:
# Yes, final: 3000
# To cater for a pattern ending in "?", we need to add
# a match both with and without the last token
return MATCH_DOUBLE
elif is_match and not is_final:
# Yes, non-final: 0110
# We need both branches here, consider a pair like:
# pattern: .?b string: b
# If we 'ADVANCE' on the .?, we miss the match.
return RETRY_ADVANCE
elif not is_match and is_final:
# No, final 2000 (note: Don't include last token!)
return MATCH_REJECT
else:
# No, non-final 0010
return RETRY
cdef int8_t get_is_match(PatternStateC state,
const TokenC* token, const attr_t* extra_attrs,
const int8_t* predicate_matches) nogil:
cdef int8_t get_is_match(
PatternStateC state,
const TokenC* token,
const attr_t* extra_attrs,
const int8_t* predicate_matches
) nogil:
for i in range(state.pattern.nr_py):
if predicate_matches[state.pattern.py_predicates[i]] == -1:
return 0
@@ -860,7 +914,7 @@ class _FuzzyPredicate:
self.is_extension = is_extension
if self.predicate not in self.operators:
raise ValueError(Errors.E126.format(good=self.operators, bad=self.predicate))
fuzz = self.predicate[len("FUZZY"):] # number after prefix
fuzz = self.predicate[len("FUZZY"):] # number after prefix
self.fuzzy = int(fuzz) if fuzz else -1
self.fuzzy_compare = fuzzy_compare
self.key = _predicate_cache_key(self.attr, self.predicate, value, fuzzy=self.fuzzy)
@@ -1082,7 +1136,7 @@ def _get_extra_predicates_dict(attr, value_dict, vocab, predicate_types,
elif cls == _FuzzyPredicate:
if isinstance(value, dict):
# add predicates inside fuzzy operator
fuzz = type_[len("FUZZY"):] # number after prefix
fuzz = type_[len("FUZZY"):] # number after prefix
fuzzy_val = int(fuzz) if fuzz else -1
output.extend(_get_extra_predicates_dict(attr, value, vocab, predicate_types,
extra_predicates, seen_predicates,
@@ -1101,8 +1155,9 @@ def _get_extra_predicates_dict(attr, value_dict, vocab, predicate_types,
return output
def _get_extension_extra_predicates(spec, extra_predicates, predicate_types,
seen_predicates):
def _get_extension_extra_predicates(
spec, extra_predicates, predicate_types, seen_predicates
):
output = []
for attr, value in spec.items():
if isinstance(value, dict):
@@ -1131,7 +1186,7 @@ def _get_operators(spec):
return (ONE,)
elif spec["OP"] in lookup:
return lookup[spec["OP"]]
#Min_max {n,m}
# Min_max {n,m}
elif spec["OP"].startswith("{") and spec["OP"].endswith("}"):
# {n} --> {n,n} exactly n ONE,(n)
# {n,m}--> {n,m} min of n, max of m ONE,(n),ZERO_ONE,(m)
@@ -1142,8 +1197,8 @@ def _get_operators(spec):
min_max = min_max if "," in min_max else f"{min_max},{min_max}"
n, m = min_max.split(",")
#1. Either n or m is a blank string and the other is numeric -->isdigit
#2. Both are numeric and n <= m
# 1. Either n or m is a blank string and the other is numeric -->isdigit
# 2. Both are numeric and n <= m
if (not n.isdecimal() and not m.isdecimal()) or (n.isdecimal() and m.isdecimal() and int(n) > int(m)):
keys = ", ".join(lookup.keys()) + ", {n}, {n,m}, {n,}, {,m} where n and m are integers and n <= m "
raise ValueError(Errors.E011.format(op=spec["OP"], opts=keys))

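Beyond reindentation and line wrapping, the hunks above touch three user-visible mechanisms: the FIRST/LONGEST greedy span filter, the FUZZY predicates, and the {n,m} min/max operators. A short usage sketch against the public Matcher API (token patterns and texts are illustrative only, not taken from this commit):

    import spacy
    from spacy.matcher import Matcher

    nlp = spacy.blank("en")
    matcher = Matcher(nlp.vocab)
    # greedy="LONGEST" routes overlapping matches for this key through the
    # `span_filter == "LONGEST"` branch above.
    matcher.add(
        "VERY_NICE",
        [[{"ORTH": "very", "OP": "{1,3}"}, {"ORTH": "nice"}]],  # {n,m} handled in _get_operators
        greedy="LONGEST",
    )
    # FUZZY1 allows an edit distance of at most 1 (see _FuzzyPredicate).
    matcher.add("HELLO", [[{"LOWER": {"FUZZY1": "hello"}}]])

    doc = nlp("very very nice, helo")
    for match_id, start, end in matcher(doc):
        print(nlp.vocab.strings[match_id], doc[start:end].text)
    # -> VERY_NICE "very very nice", HELLO "helo"
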
spacy/matcher/phrasematcher.pyx

@@ -1,14 +1,12 @@
# cython: infer_types=True, profile=True
from libc.stdint cimport uintptr_t
from preshed.maps cimport map_clear, map_get, map_init, map_iter, map_set
import warnings
from ..attrs cimport DEP, LEMMA, MORPH, ORTH, POS, TAG
from ..attrs cimport DEP, LEMMA, MORPH, POS, TAG
from ..attrs import IDS
from ..structs cimport TokenC
from ..tokens.span cimport Span
from ..tokens.token cimport Token
from ..typedefs cimport attr_t
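The only change in this file is import pruning (the ORTH cimport was presumably unused); PhraseMatcher behavior is untouched. For orientation, a minimal sketch of its public API:

    import spacy
    from spacy.matcher import PhraseMatcher

    nlp = spacy.blank("en")
    # attr="LOWER" matches case-insensitively against the LOWER attribute.
    matcher = PhraseMatcher(nlp.vocab, attr="LOWER")
    matcher.add("GREETING", [nlp("good morning"), nlp("good evening")])
    doc = nlp("She said Good Morning.")
    for match_id, start, end in matcher(doc):
        print(nlp.vocab.strings[match_id], doc[start:end].text)  # -> GREETING Good Morning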