Removal of formatting changes

Source-Shen 2022-07-05 16:54:10 +08:00
parent 57ec153587
commit 1e1acf640a


@@ -25,8 +25,10 @@ from ..errors import Errors, MatchPatternError, Warnings
 from ..strings import get_string_id
 from ..attrs import IDS
 
+
 DEF PADDING = 5
 
+
 cdef class Matcher:
     """Match sequences of tokens, based on pattern rules.
@@ -71,7 +73,7 @@ cdef class Matcher:
         """
         return self.has_key(key)
 
-    def add(self, key, patterns, *, on_match=None, greedy: str = None):
+    def add(self, key, patterns, *, on_match=None, greedy: str=None):
        """Add a match-rule to the matcher. A match-rule consists of: an ID
        key, an on_match callback, and one or more patterns.
@@ -160,7 +162,7 @@ cdef class Matcher:
         while i < self.patterns.size():
             pattern_key = get_ent_id(self.patterns.at(i))
             if pattern_key == norm_key:
-                self.patterns.erase(self.patterns.begin() + i)
+                self.patterns.erase(self.patterns.begin()+i)
             else:
                 i += 1
@@ -253,8 +255,7 @@ cdef class Matcher:
             matches = []
         else:
             matches = find_matches(&self.patterns[0], self.patterns.size(), doclike, length,
-                                   extensions=self._extensions, predicates=self._extra_predicates,
-                                   with_alignments=with_alignments)
+                                   extensions=self._extensions, predicates=self._extra_predicates, with_alignments=with_alignments)
         final_matches = []
         pairs_by_id = {}
         # For each key, either add all matches, or only the filtered,
@@ -268,21 +269,21 @@ cdef class Matcher:
                 pairs_by_id[key] = pairs
             else:
                 final_matches.append((key, *match))
-        matched = <char *> tmp_pool.alloc(length, sizeof(char))
-        empty = <char *> tmp_pool.alloc(length, sizeof(char))
+        matched = <char*> tmp_pool.alloc(length, sizeof(char))
+        empty = <char*> tmp_pool.alloc(length, sizeof(char))
         for key, pairs in pairs_by_id.items():
             memset(matched, 0, length * sizeof(matched[0]))
             span_filter = self._filter.get(key)
             if span_filter == "FIRST":
                 sorted_pairs = sorted(pairs, key=lambda x: (x[0], -x[1]), reverse=False)  # sort by start
             elif span_filter == "LONGEST":
-                sorted_pairs = sorted(pairs, key=lambda x: (x[1] - x[0], -x[0]), reverse=True)  # reverse sort by length
+                sorted_pairs = sorted(pairs, key=lambda x: (x[1]-x[0], -x[0]), reverse=True)  # reverse sort by length
             else:
                 raise ValueError(Errors.E947.format(expected=["FIRST", "LONGEST"], arg=span_filter))
             for match in sorted_pairs:
                 start, end = match[:2]
                 assert 0 <= start < end  # Defend against segfaults
-                span_len = end - start
+                span_len = end-start
                 # If no tokens in the span have matched
                 if memcmp(&matched[start], &empty[start], span_len * sizeof(matched[0])) == 0:
                     final_matches.append((key, *match))
@@ -302,9 +303,9 @@ cdef class Matcher:
             final_results = []
             for key, start, end, alignments in final_matches:
                 sorted_alignments = sorted(alignments, key=lambda x: (x['length'], x['token_idx']), reverse=False)
-                alignments = [0] * (end - start)
+                alignments = [0] * (end-start)
                 for align in sorted_alignments:
-                    if align['length'] >= end - start:
+                    if align['length'] >= end-start:
                         continue
                     # Since alignments are sorted in order of (length, token_idx)
                     # this overwrites smaller token_idx when they have same length.
@@ -326,6 +327,7 @@ cdef class Matcher:
         else:
             return key
 
+
 def unpickle_matcher(vocab, patterns, callbacks):
     matcher = Matcher(vocab)
     for key, pattern in patterns.items():
@@ -333,8 +335,7 @@ def unpickle_matcher(vocab, patterns, callbacks):
         matcher.add(key, pattern, on_match=callback)
     return matcher
 
 
-cdef find_matches(TokenPatternC** patterns, int n, object doclike, int length, extensions=None, predicates=tuple(),
-                  bint with_alignments=0):
+cdef find_matches(TokenPatternC** patterns, int n, object doclike, int length, extensions=None, predicates=tuple(), bint with_alignments=0):
     """Find matches in a doc, with a compiled array of patterns. Matches are
     returned as a list of (id, start, end) tuples or (id, start, end, alignments) tuples (if with_alignments != 0)
@@ -358,13 +359,13 @@ cdef find_matches(TokenPatternC** patterns, int n, object doclike, int length, e
         # avoid any processing or mem alloc if the document is empty
         return output
     if len(predicates) > 0:
-        predicate_cache = <int8_t *> mem.alloc(length * len(predicates), sizeof(int8_t))
+        predicate_cache = <int8_t*> mem.alloc(length * len(predicates), sizeof(int8_t))
     if extensions is not None and len(extensions) >= 1:
         nr_extra_attr = max(extensions.values()) + 1
-        extra_attr_values = <attr_t *> mem.alloc(length * nr_extra_attr, sizeof(attr_t))
+        extra_attr_values = <attr_t*> mem.alloc(length * nr_extra_attr, sizeof(attr_t))
     else:
         nr_extra_attr = 0
-        extra_attr_values = <attr_t *> mem.alloc(length, sizeof(attr_t))
+        extra_attr_values = <attr_t*> mem.alloc(length, sizeof(attr_t))
     for i, token in enumerate(doclike):
         for name, index in extensions.items():
             value = token._.get(name)
@@ -378,8 +379,8 @@ cdef find_matches(TokenPatternC** patterns, int n, object doclike, int length, e
             states.push_back(PatternStateC(patterns[j], i, 0))
         if with_alignments != 0:
             align_states.resize(states.size())
-        transition_states(states, matches, align_states, align_matches, predicate_cache, doclike[i], extra_attr_values,
-                          predicates, with_alignments)
+        transition_states(states, matches, align_states, align_matches, predicate_cache,
+                          doclike[i], extra_attr_values, predicates, with_alignments)
         extra_attr_values += nr_extra_attr
         predicate_cache += len(predicates)
     # Handle matches that end in 0-width patterns
@@ -389,7 +390,7 @@ cdef find_matches(TokenPatternC** patterns, int n, object doclike, int length, e
         match = (
             matches[i].pattern_id,
             matches[i].start,
-            matches[i].start + matches[i].length
+            matches[i].start+matches[i].length
         )
         # We need to deduplicate, because we could otherwise arrive at the same
         # match through two paths, e.g. .?.? matching 'a'. Are we matching the
@@ -404,19 +405,21 @@ cdef find_matches(TokenPatternC** patterns, int n, object doclike, int length, e
             seen.add(match)
     return output
 
+
 cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& matches,
-                            vector[vector[MatchAlignmentC]]& align_states,
-                            vector[vector[MatchAlignmentC]]& align_matches,
-                            int8_t * cached_py_predicates, Token token,
-                            const attr_t * extra_attrs, py_predicates, bint with_alignments) except *:
+                            vector[vector[MatchAlignmentC]]& align_states, vector[vector[MatchAlignmentC]]& align_matches,
+                            int8_t* cached_py_predicates,
+                            Token token, const attr_t* extra_attrs, py_predicates, bint with_alignments) except *:
     cdef int q = 0
     cdef vector[PatternStateC] new_states
     cdef vector[vector[MatchAlignmentC]] align_new_states
     cdef int nr_predicate = len(py_predicates)
     for i in range(states.size()):
         if states[i].pattern.nr_py >= 1:
-            update_predicate_cache(cached_py_predicates, states[i].pattern, token, py_predicates)
-        action = get_action(states[i], token.c, extra_attrs, cached_py_predicates)
+            update_predicate_cache(cached_py_predicates,
+                                   states[i].pattern, token, py_predicates)
+        action = get_action(states[i], token.c, extra_attrs,
+                            cached_py_predicates)
         if action == REJECT:
             continue
         # Keep only a subset of states (the active ones). Index q is the
@@ -437,19 +440,23 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
             if action in [RETRY_EXTEND, RETRY_OR_EXTEND]:
                 # This handles the 'extend'
                 new_states.push_back(
-                    PatternStateC(pattern=states[q].pattern, start=state.start, length=state.length + 1))
+                    PatternStateC(pattern=states[q].pattern, start=state.start,
+                                  length=state.length + 1))
                 if with_alignments != 0:
                     align_new_states.push_back(align_states[q])
             if action == RETRY_ADVANCE:
                 # This handles the 'advance'
                 new_states.push_back(
-                    PatternStateC(pattern=states[q].pattern + 1, start=state.start, length=state.length + 1))
+                    PatternStateC(pattern=states[q].pattern + 1, start=state.start,
+                                  length=state.length + 1))
                 if with_alignments != 0:
                     align_new_states.push_back(align_states[q])
             states[q].pattern += 1
             if states[q].pattern.nr_py != 0:
-                update_predicate_cache(cached_py_predicates, states[q].pattern, token, py_predicates)
-            next_action = get_action(states[q], token.c, extra_attrs, cached_py_predicates)
+                update_predicate_cache(cached_py_predicates,
+                                       states[q].pattern, token, py_predicates)
+            next_action = get_action(states[q], token.c, extra_attrs,
+                                     cached_py_predicates)
             # To account for *? and +?
             if get_quantifier(state) == ZERO_MINUS:
                 next_action = cast_to_non_greedy_action(action, next_action, new_states, align_new_states,
@@ -470,37 +477,49 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
         else:
             ent_id = get_ent_id(state.pattern)
         if action == MATCH:
-            matches.push_back(MatchC(pattern_id=ent_id, start=state.start, length=state.length + 1))
+            matches.push_back(
+                MatchC(pattern_id=ent_id, start=state.start,
+                       length=state.length + 1))
             # `align_matches` always corresponds to `matches` 1:1
             if with_alignments != 0:
                 align_matches.push_back(align_states[q])
         elif action == MATCH_DOUBLE:
             # push match without last token if length > 0
             if state.length > 0:
-                matches.push_back(MatchC(pattern_id=ent_id, start=state.start, length=state.length))
+                matches.push_back(
+                    MatchC(pattern_id=ent_id, start=state.start,
+                           length=state.length))
                 # MATCH_DOUBLE emits matches twice,
                 # add one more to align_matches in order to keep 1:1 relationship
                 if with_alignments != 0:
                     align_matches.push_back(align_states[q])
             # push match with last token
-            matches.push_back(MatchC(pattern_id=ent_id, start=state.start, length=state.length + 1))
+            matches.push_back(
+                MatchC(pattern_id=ent_id, start=state.start,
+                       length=state.length + 1))
             # `align_matches` always corresponds to `matches` 1:1
             if with_alignments != 0:
                 align_matches.push_back(align_states[q])
         elif action == MATCH_REJECT:
-            matches.push_back(MatchC(pattern_id=ent_id, start=state.start, length=state.length))
+            matches.push_back(
+                MatchC(pattern_id=ent_id, start=state.start,
+                       length=state.length))
             # `align_matches` always corresponds to `matches` 1:1
             if with_alignments != 0:
                 align_matches.push_back(align_states[q])
         elif action == MATCH_EXTEND:
-            matches.push_back(MatchC(pattern_id=ent_id, start=state.start, length=state.length))
+            matches.push_back(
+                MatchC(pattern_id=ent_id, start=state.start,
+                       length=state.length))
             # `align_matches` always corresponds to `matches` 1:1
             if with_alignments != 0:
                 align_matches.push_back(align_states[q])
             states[q].length += 1
             q += 1
         elif action == MATCH_ADVANCE:
-            matches.push_back(MatchC(pattern_id=ent_id, start=state.start, length=state.length + 1))
+            matches.push_back(
+                MatchC(pattern_id=ent_id, start=state.start,
+                       length=state.length + 1))
             # `align_matches` always corresponds to `matches` 1:1
             if with_alignments != 0:
                 align_matches.push_back(align_states[q])
@@ -516,8 +535,8 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
         for i in range(align_new_states.size()):
             align_states.push_back(align_new_states[i])
 
-cdef int update_predicate_cache(int8_t * cache,
-        const TokenPatternC * pattern, Token token, predicates) except -1:
+cdef int update_predicate_cache(int8_t* cache,
+        const TokenPatternC* pattern, Token token, predicates) except -1:
     # If the state references any extra predicates, check whether they match.
     # These are cached, so that we don't call these potentially expensive
     # Python functions more than we need to.
@@ -535,6 +554,7 @@ cdef int update_predicate_cache(int8_t * cache,
     else:
         raise ValueError(Errors.E125.format(value=result))
 
+
 cdef void finish_states(vector[MatchC]& matches, vector[PatternStateC]& states,
                         vector[vector[MatchAlignmentC]]& align_matches,
                         vector[vector[MatchAlignmentC]]& align_states,
@@ -565,8 +585,9 @@ cdef void finish_states(vector[MatchC]& matches, vector[PatternStateC]& states,
         else:
             state.pattern += 1
 
-cdef action_t get_action(PatternStateC state, const TokenC * token, const attr_t * extra_attrs,
-                         const int8_t * predicate_matches) nogil:
+cdef action_t get_action(PatternStateC state,
+        const TokenC* token, const attr_t* extra_attrs,
+        const int8_t* predicate_matches) nogil:
     """We need to consider:
     a) Does the token match the specification? [Yes, No]
     b) What's the quantifier? [1, 0+, 0-, ?]
@@ -656,6 +677,7 @@ cdef action_t get_action(PatternStateC state, const TokenC * token, const attr_t
             # is_non_greedy_plus() verifies that the current state's pattern is +?
             # has_star_tail() verifies the remaining pattern tokens are either * or *?,
             # so that it is valid for the current match to exist.
+            # TODO if this impacts the performance, "ONE_MINUS" could be created
             return MATCH_ADVANCE
         elif is_match and not is_final:
             # Yes, non-final: 0100
@@ -712,8 +734,8 @@ cdef action_t get_action(PatternStateC state, const TokenC * token, const attr_t
     return RETRY
 
 cdef int8_t get_is_match(PatternStateC state,
-        const TokenC * token, const attr_t * extra_attrs,
-        const int8_t * predicate_matches) nogil:
+        const TokenC* token, const attr_t* extra_attrs,
+        const int8_t* predicate_matches) nogil:
     for i in range(state.pattern.nr_py):
         if predicate_matches[state.pattern.py_predicates[i]] == -1:
             return 0
@@ -823,26 +845,26 @@ cdef inline int8_t has_non_greedy_tail(PatternStateC state) nogil:
             return 0
     return 1
 
-cdef TokenPatternC * init_pattern(Pool mem, attr_t entity_id, object token_specs) except NULL:
-    pattern = <TokenPatternC *> mem.alloc(len(token_specs) + 1, sizeof(TokenPatternC))
+cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs) except NULL:
+    pattern = <TokenPatternC*> mem.alloc(len(token_specs) + 1, sizeof(TokenPatternC))
     cdef int i, index
     for i, (quantifier, spec, extensions, predicates, token_idx) in enumerate(token_specs):
         pattern[i].quantifier = quantifier
         # Ensure attrs refers to a null pointer if nr_attr == 0
         if len(spec) > 0:
-            pattern[i].attrs = <AttrValueC *> mem.alloc(len(spec), sizeof(AttrValueC))
+            pattern[i].attrs = <AttrValueC*> mem.alloc(len(spec), sizeof(AttrValueC))
         pattern[i].nr_attr = len(spec)
         for j, (attr, value) in enumerate(spec):
             pattern[i].attrs[j].attr = attr
             pattern[i].attrs[j].value = value
         if len(extensions) > 0:
-            pattern[i].extra_attrs = <IndexValueC *> mem.alloc(len(extensions), sizeof(IndexValueC))
+            pattern[i].extra_attrs = <IndexValueC*> mem.alloc(len(extensions), sizeof(IndexValueC))
         for j, (index, value) in enumerate(extensions):
             pattern[i].extra_attrs[j].index = index
             pattern[i].extra_attrs[j].value = value
         pattern[i].nr_extra_attr = len(extensions)
         if len(predicates) > 0:
-            pattern[i].py_predicates = <int32_t *> mem.alloc(len(predicates), sizeof(int32_t))
+            pattern[i].py_predicates = <int32_t*> mem.alloc(len(predicates), sizeof(int32_t))
         for j, index in enumerate(predicates):
             pattern[i].py_predicates[j] = index
         pattern[i].nr_py = len(predicates)
@@ -852,7 +874,7 @@ cdef TokenPatternC * init_pattern(Pool mem, attr_t entity_id, object token_specs
     # Use quantifier to identify final ID pattern node (rather than previous
    # uninitialized quantifier == 0/ZERO + nr_attr == 0 + non-zero-length attrs)
    pattern[i].quantifier = FINAL_ID
-    pattern[i].attrs = <AttrValueC *> mem.alloc(1, sizeof(AttrValueC))
+    pattern[i].attrs = <AttrValueC*> mem.alloc(1, sizeof(AttrValueC))
     pattern[i].attrs[0].attr = ID
     pattern[i].attrs[0].value = entity_id
     pattern[i].nr_attr = 1
@@ -862,7 +884,7 @@ cdef TokenPatternC * init_pattern(Pool mem, attr_t entity_id, object token_specs
     return pattern
 
 
-cdef attr_t get_ent_id(const TokenPatternC * pattern) nogil:
+cdef attr_t get_ent_id(const TokenPatternC* pattern) nogil:
     while pattern.quantifier != FINAL_ID:
         pattern += 1
     id_attr = pattern[0].attrs[0]
@@ -1097,7 +1119,7 @@ def _get_extra_predicates(spec, extra_predicates, vocab):
 def _get_extension_extra_predicates(spec, extra_predicates, predicate_types,
                                     seen_predicates):
     output = []
     for attr, value in spec.items():
         if isinstance(value, dict):
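Not part of the commit, but for orientation: a minimal usage sketch of the public API whose internals this diff touches. The match key, pattern, and example text below are invented for illustration; greedy="LONGEST" exercises the span_filter == "LONGEST" branch shown above, and with_alignments=True the alignment tuples threaded through find_matches and transition_states.

# Hypothetical usage sketch of spaCy's Matcher (not from the diff).
import spacy
from spacy.matcher import Matcher

nlp = spacy.blank("en")
matcher = Matcher(nlp.vocab)

# "very"* "good" yields overlapping candidates ("good", "very good", ...);
# greedy="LONGEST" keeps only the longest span per starting position.
pattern = [{"LOWER": "very", "OP": "*"}, {"LOWER": "good"}]
matcher.add("GOOD_PHRASE", [pattern], greedy="LONGEST")

doc = nlp("This is a very very good example.")
for match_id, start, end in matcher(doc):
    print(nlp.vocab.strings[match_id], doc[start:end].text)

# with_alignments=True returns (match_id, start, end, alignments) tuples,
# where alignments maps each matched token back to a pattern token index.
for match_id, start, end, alignments in matcher(doc, with_alignments=True):
    print(start, end, alignments)

With greedy left unset, all overlapping candidates are returned; "FIRST" instead keeps the earliest-starting span, per the sort keys in the filtering code above.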