Fix matcher/.

Raphael Mitsch 2023-07-03 12:20:04 +02:00
parent 8f59eeb772
commit e7cf6c7d9b
3 changed files with 151 additions and 98 deletions

spacy/matcher/dependencymatcher.pyx

@@ -108,7 +108,7 @@ cdef class DependencyMatcher:
         key (str): The match ID.
         RETURNS (bool): Whether the matcher contains rules for this match ID.
         """
-        return self.has_key(key)
+        return self.has_key(key)  # no-cython-lint: W601
 
     def _validate_input(self, pattern, key):
         idx = 0
@@ -264,7 +264,7 @@ cdef class DependencyMatcher:
 
     def remove(self, key):
         key = self._normalize_key(key)
-        if not key in self._patterns:
+        if key not in self._patterns:
             raise ValueError(Errors.E175.format(key=key))
         self._patterns.pop(key)
         self._raw_patterns.pop(key)
@@ -382,7 +382,7 @@ cdef class DependencyMatcher:
             return []
         return [doc[node].head]
 
-    def _gov(self,doc,node):
+    def _gov(self, doc, node):
         return list(doc[node].children)
 
     def _dep_chain(self, doc, node):
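
For context, the `__contains__` and `remove` hunks above are exercised through the public DependencyMatcher API. A minimal sketch, with an illustrative key and pattern that are not part of this commit:

import spacy
from spacy.matcher import DependencyMatcher

nlp = spacy.blank("en")
matcher = DependencyMatcher(nlp.vocab)

# A single-anchor pattern: match any token whose POS is VERB.
pattern = [{"RIGHT_ID": "verb", "RIGHT_ATTRS": {"POS": "VERB"}}]
matcher.add("VERB_RULE", [pattern])

# `in` delegates to has_key(), as in the first hunk above.
assert "VERB_RULE" in matcher

# remove() raises E175 for unknown keys; its membership test is now the
# idiomatic `key not in self._patterns`.
matcher.remove("VERB_RULE")
assert "VERB_RULE" not in matcher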

spacy/matcher/matcher.pyx

@@ -19,10 +19,8 @@ from ..attrs cimport (
     LEMMA,
     MORPH,
     NULL_ATTR,
-    ORTH,
     POS,
     TAG,
-    attr_id_t,
 )
 from ..structs cimport TokenC
 from ..tokens.doc cimport Doc, get_token_attr_for_matcher
@@ -30,13 +28,11 @@
 from ..tokens.morphanalysis cimport MorphAnalysis
 from ..tokens.span cimport Span
 from ..tokens.token cimport Token
 from ..typedefs cimport attr_t
-from ..vocab cimport Vocab
 
 from ..attrs import IDS
 from ..errors import Errors, MatchPatternError, Warnings
 from ..schemas import validate_token_pattern
 from ..strings import get_string_id
-from ..util import registry
 from .levenshtein import levenshtein_compare
 
 DEF PADDING = 5
@@ -87,9 +83,9 @@ cdef class Matcher:
         key (str): The match ID.
         RETURNS (bool): Whether the matcher contains rules for this match ID.
         """
-        return self.has_key(key)
+        return self.has_key(key)  # no-cython-lint: W601
 
-    def add(self, key, patterns, *, on_match=None, greedy: str=None):
+    def add(self, key, patterns, *, on_match=None, greedy: str = None):
         """Add a match-rule to the matcher. A match-rule consists of: an ID
         key, an on_match callback, and one or more patterns.
@@ -143,8 +139,13 @@ cdef class Matcher:
         key = self._normalize_key(key)
         for pattern in patterns:
             try:
-                specs = _preprocess_pattern(pattern, self.vocab,
-                                            self._extensions, self._extra_predicates, self._fuzzy_compare)
+                specs = _preprocess_pattern(
+                    pattern,
+                    self.vocab,
+                    self._extensions,
+                    self._extra_predicates,
+                    self._fuzzy_compare
+                )
                 self.patterns.push_back(init_pattern(self.mem, key, specs))
                 for spec in specs:
                     for attr, _ in spec[1]:
@@ -168,7 +169,7 @@ cdef class Matcher:
         key (str): The ID of the match rule.
         """
         norm_key = self._normalize_key(key)
-        if not norm_key in self._patterns:
+        if norm_key not in self._patterns:
             raise ValueError(Errors.E175.format(key=key))
         self._patterns.pop(norm_key)
         self._callbacks.pop(norm_key)
@@ -268,8 +269,15 @@ cdef class Matcher:
         if self.patterns.empty():
             matches = []
         else:
-            matches = find_matches(&self.patterns[0], self.patterns.size(), doclike, length,
-                                   extensions=self._extensions, predicates=self._extra_predicates, with_alignments=with_alignments)
+            matches = find_matches(
+                &self.patterns[0],
+                self.patterns.size(),
+                doclike,
+                length,
+                extensions=self._extensions,
+                predicates=self._extra_predicates,
+                with_alignments=with_alignments
+            )
         final_matches = []
         pairs_by_id = {}
         # For each key, either add all matches, or only the filtered,
@@ -366,7 +374,6 @@ cdef find_matches(TokenPatternC** patterns, int n, object doclike, int length, e
     cdef vector[MatchC] matches
     cdef vector[vector[MatchAlignmentC]] align_states
     cdef vector[vector[MatchAlignmentC]] align_matches
-    cdef PatternStateC state
     cdef int i, j, nr_extra_attr
     cdef Pool mem = Pool()
     output = []
@@ -388,14 +395,22 @@ cdef find_matches(TokenPatternC** patterns, int n, object doclike, int length, e
                 value = token.vocab.strings[value]
             extra_attr_values[i * nr_extra_attr + index] = value
     # Main loop
-    cdef int nr_predicate = len(predicates)
     for i in range(length):
         for j in range(n):
             states.push_back(PatternStateC(patterns[j], i, 0))
         if with_alignments != 0:
             align_states.resize(states.size())
-        transition_states(states, matches, align_states, align_matches, predicate_cache,
-                          doclike[i], extra_attr_values, predicates, with_alignments)
+        transition_states(
+            states,
+            matches,
+            align_states,
+            align_matches,
+            predicate_cache,
+            doclike[i],
+            extra_attr_values,
+            predicates,
+            with_alignments
+        )
         extra_attr_values += nr_extra_attr
         predicate_cache += len(predicates)
     # Handle matches that end in 0-width patterns
@@ -421,18 +436,28 @@ cdef find_matches(TokenPatternC** patterns, int n, object doclike, int length, e
     return output
 
 
-cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& matches,
-                            vector[vector[MatchAlignmentC]]& align_states, vector[vector[MatchAlignmentC]]& align_matches,
-                            int8_t* cached_py_predicates,
-                            Token token, const attr_t* extra_attrs, py_predicates, bint with_alignments) except *:
+cdef void transition_states(
+    vector[PatternStateC]& states,
+    vector[MatchC]& matches,
+    vector[vector[MatchAlignmentC]]& align_states,
+    vector[vector[MatchAlignmentC]]& align_matches,
+    int8_t* cached_py_predicates,
+    Token token,
+    const attr_t* extra_attrs,
+    py_predicates,
+    bint with_alignments
+) except *:
     cdef int q = 0
     cdef vector[PatternStateC] new_states
    cdef vector[vector[MatchAlignmentC]] align_new_states
-    cdef int nr_predicate = len(py_predicates)
     for i in range(states.size()):
         if states[i].pattern.nr_py >= 1:
-            update_predicate_cache(cached_py_predicates,
-                                   states[i].pattern, token, py_predicates)
+            update_predicate_cache(
+                cached_py_predicates,
+                states[i].pattern,
+                token,
+                py_predicates
+            )
         action = get_action(states[i], token.c, extra_attrs,
                             cached_py_predicates)
         if action == REJECT:
@@ -468,8 +493,12 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
             align_new_states.push_back(align_states[q])
         states[q].pattern += 1
         if states[q].pattern.nr_py != 0:
-            update_predicate_cache(cached_py_predicates,
-                                   states[q].pattern, token, py_predicates)
+            update_predicate_cache(
+                cached_py_predicates,
+                states[q].pattern,
+                token,
+                py_predicates
+            )
         action = get_action(states[q], token.c, extra_attrs,
                             cached_py_predicates)
         # Update alignment before the transition of current state
@@ -485,8 +514,12 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
         ent_id = get_ent_id(state.pattern)
         if action == MATCH:
             matches.push_back(
-                MatchC(pattern_id=ent_id, start=state.start,
-                       length=state.length+1))
+                MatchC(
+                    pattern_id=ent_id,
+                    start=state.start,
+                    length=state.length+1
+                )
+            )
             # `align_matches` always corresponds to `matches` 1:1
             if with_alignments != 0:
                 align_matches.push_back(align_states[q])
@@ -494,23 +527,35 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
             # push match without last token if length > 0
             if state.length > 0:
                 matches.push_back(
-                    MatchC(pattern_id=ent_id, start=state.start,
-                           length=state.length))
+                    MatchC(
+                        pattern_id=ent_id,
+                        start=state.start,
+                        length=state.length
+                    )
+                )
                 # MATCH_DOUBLE emits matches twice,
                 # add one more to align_matches in order to keep 1:1 relationship
                 if with_alignments != 0:
                     align_matches.push_back(align_states[q])
             # push match with last token
             matches.push_back(
-                MatchC(pattern_id=ent_id, start=state.start,
-                       length=state.length+1))
+                MatchC(
+                    pattern_id=ent_id,
+                    start=state.start,
+                    length=state.length + 1
+                )
+            )
             # `align_matches` always corresponds to `matches` 1:1
             if with_alignments != 0:
                 align_matches.push_back(align_states[q])
         elif action == MATCH_REJECT:
             matches.push_back(
-                MatchC(pattern_id=ent_id, start=state.start,
-                       length=state.length))
+                MatchC(
+                    pattern_id=ent_id,
+                    start=state.start,
+                    length=state.length
+                )
+            )
             # `align_matches` always corresponds to `matches` 1:1
             if with_alignments != 0:
                 align_matches.push_back(align_states[q])
@@ -533,8 +578,12 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
         align_states.push_back(align_new_states[i])
 
 
-cdef int update_predicate_cache(int8_t* cache,
-                                const TokenPatternC* pattern, Token token, predicates) except -1:
+cdef int update_predicate_cache(
+    int8_t* cache,
+    const TokenPatternC* pattern,
+    Token token,
+    predicates
+) except -1:
     # If the state references any extra predicates, check whether they match.
     # These are cached, so that we don't call these potentially expensive
     # Python functions more than we need to.
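
The caching that update_predicate_cache() manages is driven by dict-valued token attributes (REGEX, IN, FUZZY and friends), which compile to Python predicate functions rather than C-level attribute checks. A minimal sketch of a pattern that goes through this predicate path; the key and regex are illustrative:

import spacy
from spacy.matcher import Matcher

nlp = spacy.blank("en")
matcher = Matcher(nlp.vocab)

# A dict-valued attribute such as {"REGEX": ...} becomes an extra Python
# predicate; its result is memoized in the int8_t cache above, so the
# predicate runs at most once per token during a document pass.
matcher.add("NUMBER_LIKE", [[{"TEXT": {"REGEX": r"^\d+(\.\d+)?$"}}]])

doc = nlp("It costs 12.50 now.")
print([doc[start:end].text for _, start, end in matcher(doc)])  # ['12.50']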
@@ -580,10 +629,12 @@ cdef void finish_states(vector[MatchC]& matches, vector[PatternStateC]& states,
         else:
             state.pattern += 1
 
-cdef action_t get_action(PatternStateC state,
-                         const TokenC* token, const attr_t* extra_attrs,
-                         const int8_t* predicate_matches) nogil:
+cdef action_t get_action(
+    PatternStateC state,
+    const TokenC * token,
+    const attr_t * extra_attrs,
+    const int8_t * predicate_matches
+) nogil:
     """We need to consider:
     a) Does the token match the specification? [Yes, No]
     b) What's the quantifier? [1, 0+, ?]
@@ -693,9 +744,12 @@ cdef action_t get_action(PatternStateC state,
     return RETRY
 
 
-cdef int8_t get_is_match(PatternStateC state,
-                         const TokenC* token, const attr_t* extra_attrs,
-                         const int8_t* predicate_matches) nogil:
+cdef int8_t get_is_match(
+    PatternStateC state,
+    const TokenC* token,
+    const attr_t* extra_attrs,
+    const int8_t* predicate_matches
+) nogil:
     for i in range(state.pattern.nr_py):
         if predicate_matches[state.pattern.py_predicates[i]] == -1:
             return 0
@@ -1101,8 +1155,9 @@ def _get_extra_predicates_dict(attr, value_dict, vocab, predicate_types,
     return output
 
 
-def _get_extension_extra_predicates(spec, extra_predicates, predicate_types,
-                                    seen_predicates):
+def _get_extension_extra_predicates(
+    spec, extra_predicates, predicate_types, seen_predicates
+):
     output = []
     for attr, value in spec.items():
         if isinstance(value, dict):
@@ -1131,7 +1186,7 @@ def _get_operators(spec):
         return (ONE,)
     elif spec["OP"] in lookup:
         return lookup[spec["OP"]]
-    #Min_max {n,m}
+    # Min_max {n,m}
     elif spec["OP"].startswith("{") and spec["OP"].endswith("}"):
         # {n} --> {n,n} exactly n ONE,(n)
         # {n,m}--> {n,m} min of n, max of m ONE,(n),ZERO_ONE,(m)
@@ -1142,8 +1197,8 @@ def _get_operators(spec):
         min_max = min_max if "," in min_max else f"{min_max},{min_max}"
         n, m = min_max.split(",")
 
-        #1. Either n or m is a blank string and the other is numeric -->isdigit
-        #2. Both are numeric and n <= m
+        # 1. Either n or m is a blank string and the other is numeric -->isdigit
+        # 2. Both are numeric and n <= m
         if (not n.isdecimal() and not m.isdecimal()) or (n.isdecimal() and m.isdecimal() and int(n) > int(m)):
             keys = ", ".join(lookup.keys()) + ", {n}, {n,m}, {n,}, {,m} where n and m are integers and n <= m "
             raise ValueError(Errors.E011.format(op=spec["OP"], opts=keys))
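
The last two hunks touch _get_operators(), which expands the {n,m} quantifier forms into ONE/ZERO_ONE operator sequences as the comments describe. A usage sketch of the corresponding public API; the key, pattern, and greedy choice are illustrative, and the {n,m} OP forms assume a spaCy release that supports them:

import spacy
from spacy.matcher import Matcher

nlp = spacy.blank("en")
matcher = Matcher(nlp.vocab)

# "{1,3}" expands to ONE, ZERO_ONE, ZERO_ONE per _get_operators();
# greedy="LONGEST" keeps only the longest of overlapping matches.
pattern = [{"LOWER": "very", "OP": "{1,3}"}, {"LOWER": "good"}]
matcher.add("VERY_GOOD", [pattern], greedy="LONGEST")

doc = nlp("This is very very good.")
for match_id, start, end in matcher(doc):
    print(nlp.vocab.strings[match_id], doc[start:end].text)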

spacy/matcher/phrasematcher.pyx

@@ -1,14 +1,12 @@
 # cython: infer_types=True, profile=True
-from libc.stdint cimport uintptr_t
 from preshed.maps cimport map_clear, map_get, map_init, map_iter, map_set
 
 import warnings
 
-from ..attrs cimport DEP, LEMMA, MORPH, ORTH, POS, TAG
+from ..attrs cimport DEP, LEMMA, MORPH, POS, TAG
 from ..attrs import IDS
-from ..structs cimport TokenC
 from ..tokens.span cimport Span
 from ..tokens.token cimport Token
 from ..typedefs cimport attr_t