Fix matcher/.

This commit is contained in:
Raphael Mitsch 2023-07-03 12:20:04 +02:00
parent 8f59eeb772
commit e7cf6c7d9b
3 changed files with 151 additions and 98 deletions

View File

@ -108,7 +108,7 @@ cdef class DependencyMatcher:
key (str): The match ID.
RETURNS (bool): Whether the matcher contains rules for this match ID.
"""
return self.has_key(key)
return self.has_key(key) # no-cython-lint: W601
def _validate_input(self, pattern, key):
idx = 0
@ -264,7 +264,7 @@ cdef class DependencyMatcher:
def remove(self, key):
key = self._normalize_key(key)
if not key in self._patterns:
if key not in self._patterns:
raise ValueError(Errors.E175.format(key=key))
self._patterns.pop(key)
self._raw_patterns.pop(key)

View File

@ -19,10 +19,8 @@ from ..attrs cimport (
LEMMA,
MORPH,
NULL_ATTR,
ORTH,
POS,
TAG,
attr_id_t,
)
from ..structs cimport TokenC
from ..tokens.doc cimport Doc, get_token_attr_for_matcher
@ -30,13 +28,11 @@ from ..tokens.morphanalysis cimport MorphAnalysis
from ..tokens.span cimport Span
from ..tokens.token cimport Token
from ..typedefs cimport attr_t
from ..vocab cimport Vocab
from ..attrs import IDS
from ..errors import Errors, MatchPatternError, Warnings
from ..schemas import validate_token_pattern
from ..strings import get_string_id
from ..util import registry
from .levenshtein import levenshtein_compare
DEF PADDING = 5
@ -87,7 +83,7 @@ cdef class Matcher:
key (str): The match ID.
RETURNS (bool): Whether the matcher contains rules for this match ID.
"""
return self.has_key(key)
return self.has_key(key) # no-cython-lint: W601
def add(self, key, patterns, *, on_match=None, greedy: str = None):
"""Add a match-rule to the matcher. A match-rule consists of: an ID
@ -143,8 +139,13 @@ cdef class Matcher:
key = self._normalize_key(key)
for pattern in patterns:
try:
specs = _preprocess_pattern(pattern, self.vocab,
self._extensions, self._extra_predicates, self._fuzzy_compare)
specs = _preprocess_pattern(
pattern,
self.vocab,
self._extensions,
self._extra_predicates,
self._fuzzy_compare
)
self.patterns.push_back(init_pattern(self.mem, key, specs))
for spec in specs:
for attr, _ in spec[1]:
@ -168,7 +169,7 @@ cdef class Matcher:
key (str): The ID of the match rule.
"""
norm_key = self._normalize_key(key)
if not norm_key in self._patterns:
if norm_key not in self._patterns:
raise ValueError(Errors.E175.format(key=key))
self._patterns.pop(norm_key)
self._callbacks.pop(norm_key)
@ -268,8 +269,15 @@ cdef class Matcher:
if self.patterns.empty():
matches = []
else:
matches = find_matches(&self.patterns[0], self.patterns.size(), doclike, length,
extensions=self._extensions, predicates=self._extra_predicates, with_alignments=with_alignments)
matches = find_matches(
&self.patterns[0],
self.patterns.size(),
doclike,
length,
extensions=self._extensions,
predicates=self._extra_predicates,
with_alignments=with_alignments
)
final_matches = []
pairs_by_id = {}
# For each key, either add all matches, or only the filtered,
@ -366,7 +374,6 @@ cdef find_matches(TokenPatternC** patterns, int n, object doclike, int length, e
cdef vector[MatchC] matches
cdef vector[vector[MatchAlignmentC]] align_states
cdef vector[vector[MatchAlignmentC]] align_matches
cdef PatternStateC state
cdef int i, j, nr_extra_attr
cdef Pool mem = Pool()
output = []
@ -388,14 +395,22 @@ cdef find_matches(TokenPatternC** patterns, int n, object doclike, int length, e
value = token.vocab.strings[value]
extra_attr_values[i * nr_extra_attr + index] = value
# Main loop
cdef int nr_predicate = len(predicates)
for i in range(length):
for j in range(n):
states.push_back(PatternStateC(patterns[j], i, 0))
if with_alignments != 0:
align_states.resize(states.size())
transition_states(states, matches, align_states, align_matches, predicate_cache,
doclike[i], extra_attr_values, predicates, with_alignments)
transition_states(
states,
matches,
align_states,
align_matches,
predicate_cache,
doclike[i],
extra_attr_values,
predicates,
with_alignments
)
extra_attr_values += nr_extra_attr
predicate_cache += len(predicates)
# Handle matches that end in 0-width patterns
@ -421,18 +436,28 @@ cdef find_matches(TokenPatternC** patterns, int n, object doclike, int length, e
return output
cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& matches,
vector[vector[MatchAlignmentC]]& align_states, vector[vector[MatchAlignmentC]]& align_matches,
cdef void transition_states(
vector[PatternStateC]& states,
vector[MatchC]& matches,
vector[vector[MatchAlignmentC]]& align_states,
vector[vector[MatchAlignmentC]]& align_matches,
int8_t* cached_py_predicates,
Token token, const attr_t* extra_attrs, py_predicates, bint with_alignments) except *:
Token token,
const attr_t* extra_attrs,
py_predicates,
bint with_alignments
) except *:
cdef int q = 0
cdef vector[PatternStateC] new_states
cdef vector[vector[MatchAlignmentC]] align_new_states
cdef int nr_predicate = len(py_predicates)
for i in range(states.size()):
if states[i].pattern.nr_py >= 1:
update_predicate_cache(cached_py_predicates,
states[i].pattern, token, py_predicates)
update_predicate_cache(
cached_py_predicates,
states[i].pattern,
token,
py_predicates
)
action = get_action(states[i], token.c, extra_attrs,
cached_py_predicates)
if action == REJECT:
@ -468,8 +493,12 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
align_new_states.push_back(align_states[q])
states[q].pattern += 1
if states[q].pattern.nr_py != 0:
update_predicate_cache(cached_py_predicates,
states[q].pattern, token, py_predicates)
update_predicate_cache(
cached_py_predicates,
states[q].pattern,
token,
py_predicates
)
action = get_action(states[q], token.c, extra_attrs,
cached_py_predicates)
# Update alignment before the transition of current state
@ -485,8 +514,12 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
ent_id = get_ent_id(state.pattern)
if action == MATCH:
matches.push_back(
MatchC(pattern_id=ent_id, start=state.start,
length=state.length+1))
MatchC(
pattern_id=ent_id,
start=state.start,
length=state.length+1
)
)
# `align_matches` always corresponds to `matches` 1:1
if with_alignments != 0:
align_matches.push_back(align_states[q])
@ -494,23 +527,35 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
# push match without last token if length > 0
if state.length > 0:
matches.push_back(
MatchC(pattern_id=ent_id, start=state.start,
length=state.length))
MatchC(
pattern_id=ent_id,
start=state.start,
length=state.length
)
)
# MATCH_DOUBLE emits matches twice,
# add one more to align_matches in order to keep 1:1 relationship
if with_alignments != 0:
align_matches.push_back(align_states[q])
# push match with last token
matches.push_back(
MatchC(pattern_id=ent_id, start=state.start,
length=state.length+1))
MatchC(
pattern_id=ent_id,
start=state.start,
length=state.length + 1
)
)
# `align_matches` always corresponds to `matches` 1:1
if with_alignments != 0:
align_matches.push_back(align_states[q])
elif action == MATCH_REJECT:
matches.push_back(
MatchC(pattern_id=ent_id, start=state.start,
length=state.length))
MatchC(
pattern_id=ent_id,
start=state.start,
length=state.length
)
)
# `align_matches` always corresponds to `matches` 1:1
if with_alignments != 0:
align_matches.push_back(align_states[q])
@ -533,8 +578,12 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
align_states.push_back(align_new_states[i])
cdef int update_predicate_cache(int8_t* cache,
const TokenPatternC* pattern, Token token, predicates) except -1:
cdef int update_predicate_cache(
int8_t* cache,
const TokenPatternC* pattern,
Token token,
predicates
) except -1:
# If the state references any extra predicates, check whether they match.
# These are cached, so that we don't call these potentially expensive
# Python functions more than we need to.
@ -580,10 +629,12 @@ cdef void finish_states(vector[MatchC]& matches, vector[PatternStateC]& states,
else:
state.pattern += 1
cdef action_t get_action(PatternStateC state,
const TokenC* token, const attr_t* extra_attrs,
const int8_t* predicate_matches) nogil:
cdef action_t get_action(
PatternStateC state,
const TokenC * token,
const attr_t * extra_attrs,
const int8_t * predicate_matches
) nogil:
"""We need to consider:
a) Does the token match the specification? [Yes, No]
b) What's the quantifier? [1, 0+, ?]
@ -693,9 +744,12 @@ cdef action_t get_action(PatternStateC state,
return RETRY
cdef int8_t get_is_match(PatternStateC state,
const TokenC* token, const attr_t* extra_attrs,
const int8_t* predicate_matches) nogil:
cdef int8_t get_is_match(
PatternStateC state,
const TokenC* token,
const attr_t* extra_attrs,
const int8_t* predicate_matches
) nogil:
for i in range(state.pattern.nr_py):
if predicate_matches[state.pattern.py_predicates[i]] == -1:
return 0
@ -1101,8 +1155,9 @@ def _get_extra_predicates_dict(attr, value_dict, vocab, predicate_types,
return output
def _get_extension_extra_predicates(spec, extra_predicates, predicate_types,
seen_predicates):
def _get_extension_extra_predicates(
spec, extra_predicates, predicate_types, seen_predicates
):
output = []
for attr, value in spec.items():
if isinstance(value, dict):

View File

@ -1,14 +1,12 @@
# cython: infer_types=True, profile=True
from libc.stdint cimport uintptr_t
from preshed.maps cimport map_clear, map_get, map_init, map_iter, map_set
import warnings
from ..attrs cimport DEP, LEMMA, MORPH, ORTH, POS, TAG
from ..attrs cimport DEP, LEMMA, MORPH, POS, TAG
from ..attrs import IDS
from ..structs cimport TokenC
from ..tokens.span cimport Span
from ..tokens.token cimport Token
from ..typedefs cimport attr_t