mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-10 09:16:31 +03:00
830 lines
31 KiB
Cython
830 lines
31 KiB
Cython
# cython: infer_types=True
|
|
# cython: profile=True
|
|
from __future__ import unicode_literals
|
|
|
|
from libcpp.vector cimport vector
|
|
from libc.stdint cimport int32_t
|
|
from cymem.cymem cimport Pool
|
|
from murmurhash.mrmr cimport hash64
|
|
|
|
import re
|
|
import srsly
|
|
|
|
from ..typedefs cimport attr_t
|
|
from ..structs cimport TokenC
|
|
from ..vocab cimport Vocab
|
|
from ..tokens.doc cimport Doc, get_token_attr
|
|
from ..tokens.token cimport Token
|
|
from ..attrs cimport ID, attr_id_t, NULL_ATTR, ORTH, POS, TAG, DEP, LEMMA
|
|
|
|
from ._schemas import TOKEN_PATTERN_SCHEMA
|
|
from ..util import get_json_validator, validate_json
|
|
from ..errors import Errors, MatchPatternError, Warnings, deprecation_warning
|
|
from ..strings import get_string_id
|
|
from ..attrs import IDS
|
|
|
|
|
|
DEF PADDING = 5
|
|
|
|
|
|
cdef class Matcher:
|
|
"""Match sequences of tokens, based on pattern rules.
|
|
|
|
DOCS: https://spacy.io/api/matcher
|
|
USAGE: https://spacy.io/usage/rule-based-matching
|
|
"""
|
|
|
|
def __init__(self, vocab, validate=False):
|
|
"""Create the Matcher.
|
|
|
|
vocab (Vocab): The vocabulary object, which must be shared with the
|
|
documents the matcher will operate on.
|
|
RETURNS (Matcher): The newly constructed object.
|
|
"""
|
|
self._extra_predicates = []
|
|
self._patterns = {}
|
|
self._callbacks = {}
|
|
self._extensions = {}
|
|
self._seen_attrs = set()
|
|
self.vocab = vocab
|
|
self.mem = Pool()
|
|
if validate:
|
|
self.validator = get_json_validator(TOKEN_PATTERN_SCHEMA)
|
|
else:
|
|
self.validator = None
|
|
|
|
def __reduce__(self):
|
|
data = (self.vocab, self._patterns, self._callbacks)
|
|
return (unpickle_matcher, data, None, None)
|
|
|
|
def __len__(self):
|
|
"""Get the number of rules added to the matcher. Note that this only
|
|
returns the number of rules (identical with the number of IDs), not the
|
|
number of individual patterns.
|
|
|
|
RETURNS (int): The number of rules.
|
|
"""
|
|
return len(self._patterns)
|
|
|
|
def __contains__(self, key):
|
|
"""Check whether the matcher contains rules for a match ID.
|
|
|
|
key (unicode): The match ID.
|
|
RETURNS (bool): Whether the matcher contains rules for this match ID.
|
|
"""
|
|
return self._normalize_key(key) in self._patterns
|
|
|
|
def add(self, key, on_match, *patterns):
|
|
"""Add a match-rule to the matcher. A match-rule consists of: an ID
|
|
key, an on_match callback, and one or more patterns.
|
|
|
|
If the key exists, the patterns are appended to the previous ones, and
|
|
the previous on_match callback is replaced. The `on_match` callback
|
|
will receive the arguments `(matcher, doc, i, matches)`. You can also
|
|
set `on_match` to `None` to not perform any actions.
|
|
|
|
A pattern consists of one or more `token_specs`, where a `token_spec`
|
|
is a dictionary mapping attribute IDs to values, and optionally a
|
|
quantifier operator under the key "op". The available quantifiers are:
|
|
|
|
'!': Negate the pattern, by requiring it to match exactly 0 times.
|
|
'?': Make the pattern optional, by allowing it to match 0 or 1 times.
|
|
'+': Require the pattern to match 1 or more times.
|
|
'*': Allow the pattern to zero or more times.
|
|
|
|
The + and * operators are usually interpretted "greedily", i.e. longer
|
|
matches are returned where possible. However, if you specify two '+'
|
|
and '*' patterns in a row and their matches overlap, the first
|
|
operator will behave non-greedily. This quirk in the semantics makes
|
|
the matcher more efficient, by avoiding the need for back-tracking.
|
|
|
|
key (unicode): The match ID.
|
|
on_match (callable): Callback executed on match.
|
|
*patterns (list): List of token descriptions.
|
|
"""
|
|
errors = {}
|
|
if on_match is not None and not hasattr(on_match, "__call__"):
|
|
raise ValueError(Errors.E171.format(arg_type=type(on_match)))
|
|
for i, pattern in enumerate(patterns):
|
|
if len(pattern) == 0:
|
|
raise ValueError(Errors.E012.format(key=key))
|
|
if self.validator:
|
|
errors[i] = validate_json(pattern, self.validator)
|
|
if any(err for err in errors.values()):
|
|
raise MatchPatternError(key, errors)
|
|
key = self._normalize_key(key)
|
|
for pattern in patterns:
|
|
try:
|
|
specs = _preprocess_pattern(pattern, self.vocab.strings,
|
|
self._extensions, self._extra_predicates)
|
|
self.patterns.push_back(init_pattern(self.mem, key, specs))
|
|
for spec in specs:
|
|
for attr, _ in spec[1]:
|
|
self._seen_attrs.add(attr)
|
|
except OverflowError, AttributeError:
|
|
raise ValueError(Errors.E154.format())
|
|
self._patterns.setdefault(key, [])
|
|
self._callbacks[key] = on_match
|
|
self._patterns[key].extend(patterns)
|
|
|
|
def remove(self, key):
|
|
"""Remove a rule from the matcher. A KeyError is raised if the key does
|
|
not exist.
|
|
|
|
key (unicode): The ID of the match rule.
|
|
"""
|
|
key = self._normalize_key(key)
|
|
self._patterns.pop(key)
|
|
self._callbacks.pop(key)
|
|
cdef int i = 0
|
|
while i < self.patterns.size():
|
|
pattern_key = get_pattern_key(self.patterns.at(i))
|
|
if pattern_key == key:
|
|
self.patterns.erase(self.patterns.begin()+i)
|
|
else:
|
|
i += 1
|
|
|
|
def has_key(self, key):
|
|
"""Check whether the matcher has a rule with a given key.
|
|
|
|
key (string or int): The key to check.
|
|
RETURNS (bool): Whether the matcher has the rule.
|
|
"""
|
|
key = self._normalize_key(key)
|
|
return key in self._patterns
|
|
|
|
def get(self, key, default=None):
|
|
"""Retrieve the pattern stored for a key.
|
|
|
|
key (unicode or int): The key to retrieve.
|
|
RETURNS (tuple): The rule, as an (on_match, patterns) tuple.
|
|
"""
|
|
key = self._normalize_key(key)
|
|
if key not in self._patterns:
|
|
return default
|
|
return (self._callbacks[key], self._patterns[key])
|
|
|
|
def pipe(self, docs, batch_size=1000, n_threads=-1, return_matches=False,
|
|
as_tuples=False):
|
|
"""Match a stream of documents, yielding them in turn.
|
|
|
|
docs (iterable): A stream of documents.
|
|
batch_size (int): Number of documents to accumulate into a working set.
|
|
return_matches (bool): Yield the match lists along with the docs, making
|
|
results (doc, matches) tuples.
|
|
as_tuples (bool): Interpret the input stream as (doc, context) tuples,
|
|
and yield (result, context) tuples out.
|
|
If both return_matches and as_tuples are True, the output will
|
|
be a sequence of ((doc, matches), context) tuples.
|
|
YIELDS (Doc): Documents, in order.
|
|
"""
|
|
if n_threads != -1:
|
|
deprecation_warning(Warnings.W016)
|
|
|
|
if as_tuples:
|
|
for doc, context in docs:
|
|
matches = self(doc)
|
|
if return_matches:
|
|
yield ((doc, matches), context)
|
|
else:
|
|
yield (doc, context)
|
|
else:
|
|
for doc in docs:
|
|
matches = self(doc)
|
|
if return_matches:
|
|
yield (doc, matches)
|
|
else:
|
|
yield doc
|
|
|
|
def __call__(self, Doc doc):
|
|
"""Find all token sequences matching the supplied pattern.
|
|
|
|
doc (Doc): The document to match over.
|
|
RETURNS (list): A list of `(key, start, end)` tuples,
|
|
describing the matches. A match tuple describes a span
|
|
`doc[start:end]`. The `label_id` and `key` are both integers.
|
|
"""
|
|
if len(set([LEMMA, POS, TAG]) & self._seen_attrs) > 0 \
|
|
and not doc.is_tagged:
|
|
raise ValueError(Errors.E155.format())
|
|
if DEP in self._seen_attrs and not doc.is_parsed:
|
|
raise ValueError(Errors.E156.format())
|
|
matches = find_matches(&self.patterns[0], self.patterns.size(), doc,
|
|
extensions=self._extensions,
|
|
predicates=self._extra_predicates)
|
|
for i, (key, start, end) in enumerate(matches):
|
|
on_match = self._callbacks.get(key, None)
|
|
if on_match is not None:
|
|
on_match(self, doc, i, matches)
|
|
return matches
|
|
|
|
def _normalize_key(self, key):
|
|
if isinstance(key, basestring):
|
|
return self.vocab.strings.add(key)
|
|
else:
|
|
return key
|
|
|
|
|
|
def unpickle_matcher(vocab, patterns, callbacks):
|
|
matcher = Matcher(vocab)
|
|
for key, specs in patterns.items():
|
|
callback = callbacks.get(key, None)
|
|
matcher.add(key, callback, *specs)
|
|
return matcher
|
|
|
|
|
|
|
|
cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None,
|
|
predicates=tuple()):
|
|
"""Find matches in a doc, with a compiled array of patterns. Matches are
|
|
returned as a list of (id, start, end) tuples.
|
|
|
|
To augment the compiled patterns, we optionally also take two Python lists.
|
|
|
|
The "predicates" list contains functions that take a Python list and return a
|
|
boolean value. It's mostly used for regular expressions.
|
|
|
|
The "extra_getters" list contains functions that take a Python list and return
|
|
an attr ID. It's mostly used for extension attributes.
|
|
"""
|
|
cdef vector[PatternStateC] states
|
|
cdef vector[MatchC] matches
|
|
cdef PatternStateC state
|
|
cdef int i, j, nr_extra_attr
|
|
cdef Pool mem = Pool()
|
|
predicate_cache = <char*>mem.alloc(doc.length * len(predicates), sizeof(char))
|
|
if extensions is not None and len(extensions) >= 1:
|
|
nr_extra_attr = max(extensions.values()) + 1
|
|
extra_attr_values = <attr_t*>mem.alloc(doc.length * nr_extra_attr, sizeof(attr_t))
|
|
else:
|
|
nr_extra_attr = 0
|
|
extra_attr_values = <attr_t*>mem.alloc(doc.length, sizeof(attr_t))
|
|
for i, token in enumerate(doc):
|
|
for name, index in extensions.items():
|
|
value = token._.get(name)
|
|
if isinstance(value, basestring):
|
|
value = token.vocab.strings[value]
|
|
extra_attr_values[i * nr_extra_attr + index] = value
|
|
# Main loop
|
|
cdef int nr_predicate = len(predicates)
|
|
for i in range(doc.length):
|
|
for j in range(n):
|
|
states.push_back(PatternStateC(patterns[j], i, 0))
|
|
transition_states(states, matches, predicate_cache,
|
|
doc[i], extra_attr_values, predicates)
|
|
extra_attr_values += nr_extra_attr
|
|
predicate_cache += len(predicates)
|
|
# Handle matches that end in 0-width patterns
|
|
finish_states(matches, states)
|
|
output = []
|
|
seen = set()
|
|
for i in range(matches.size()):
|
|
match = (
|
|
matches[i].pattern_id,
|
|
matches[i].start,
|
|
matches[i].start+matches[i].length
|
|
)
|
|
# We need to deduplicate, because we could otherwise arrive at the same
|
|
# match through two paths, e.g. .?.? matching 'a'. Are we matching the
|
|
# first .?, or the second .? -- it doesn't matter, it's just one match.
|
|
if match not in seen:
|
|
output.append(match)
|
|
seen.add(match)
|
|
return output
|
|
|
|
|
|
cdef attr_t get_ent_id(const TokenPatternC* pattern) nogil:
|
|
# There have been a few bugs here.
|
|
# The code was originally designed to always have pattern[1].attrs.value
|
|
# be the ent_id when we get to the end of a pattern. However, Issue #2671
|
|
# showed this wasn't the case when we had a reject-and-continue before a
|
|
# match.
|
|
# The patch to #2671 was wrong though, which came up in #3839.
|
|
while pattern.attrs.attr != ID:
|
|
pattern += 1
|
|
return pattern.attrs.value
|
|
|
|
|
|
cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& matches,
|
|
char* cached_py_predicates,
|
|
Token token, const attr_t* extra_attrs, py_predicates) except *:
|
|
cdef int q = 0
|
|
cdef vector[PatternStateC] new_states
|
|
cdef int nr_predicate = len(py_predicates)
|
|
for i in range(states.size()):
|
|
if states[i].pattern.nr_py >= 1:
|
|
update_predicate_cache(cached_py_predicates,
|
|
states[i].pattern, token, py_predicates)
|
|
action = get_action(states[i], token.c, extra_attrs,
|
|
cached_py_predicates)
|
|
if action == REJECT:
|
|
continue
|
|
# Keep only a subset of states (the active ones). Index q is the
|
|
# states which are still alive. If we reject a state, we overwrite
|
|
# it in the states list, because q doesn't advance.
|
|
state = states[i]
|
|
states[q] = state
|
|
while action in (RETRY, RETRY_ADVANCE, RETRY_EXTEND):
|
|
if action == RETRY_EXTEND:
|
|
# This handles the 'extend'
|
|
new_states.push_back(
|
|
PatternStateC(pattern=states[q].pattern, start=state.start,
|
|
length=state.length+1))
|
|
if action == RETRY_ADVANCE:
|
|
# This handles the 'advance'
|
|
new_states.push_back(
|
|
PatternStateC(pattern=states[q].pattern+1, start=state.start,
|
|
length=state.length+1))
|
|
states[q].pattern += 1
|
|
if states[q].pattern.nr_py != 0:
|
|
update_predicate_cache(cached_py_predicates,
|
|
states[q].pattern, token, py_predicates)
|
|
action = get_action(states[q], token.c, extra_attrs,
|
|
cached_py_predicates)
|
|
if action == REJECT:
|
|
pass
|
|
elif action == ADVANCE:
|
|
states[q].pattern += 1
|
|
states[q].length += 1
|
|
q += 1
|
|
else:
|
|
ent_id = get_ent_id(state.pattern)
|
|
if action == MATCH:
|
|
matches.push_back(
|
|
MatchC(pattern_id=ent_id, start=state.start,
|
|
length=state.length+1))
|
|
elif action == MATCH_DOUBLE:
|
|
# push match without last token if length > 0
|
|
if state.length > 0:
|
|
matches.push_back(
|
|
MatchC(pattern_id=ent_id, start=state.start,
|
|
length=state.length))
|
|
# push match with last token
|
|
matches.push_back(
|
|
MatchC(pattern_id=ent_id, start=state.start,
|
|
length=state.length+1))
|
|
elif action == MATCH_REJECT:
|
|
matches.push_back(
|
|
MatchC(pattern_id=ent_id, start=state.start,
|
|
length=state.length))
|
|
elif action == MATCH_EXTEND:
|
|
matches.push_back(
|
|
MatchC(pattern_id=ent_id, start=state.start,
|
|
length=state.length))
|
|
states[q].length += 1
|
|
q += 1
|
|
states.resize(q)
|
|
for i in range(new_states.size()):
|
|
states.push_back(new_states[i])
|
|
|
|
|
|
cdef int update_predicate_cache(char* cache,
|
|
const TokenPatternC* pattern, Token token, predicates) except -1:
|
|
# If the state references any extra predicates, check whether they match.
|
|
# These are cached, so that we don't call these potentially expensive
|
|
# Python functions more than we need to.
|
|
for i in range(pattern.nr_py):
|
|
index = pattern.py_predicates[i]
|
|
if cache[index] == 0:
|
|
predicate = predicates[index]
|
|
result = predicate(token)
|
|
if result is True:
|
|
cache[index] = 1
|
|
elif result is False:
|
|
cache[index] = -1
|
|
elif result is None:
|
|
pass
|
|
else:
|
|
raise ValueError(Errors.E125.format(value=result))
|
|
|
|
|
|
cdef void finish_states(vector[MatchC]& matches, vector[PatternStateC]& states) except *:
|
|
"""Handle states that end in zero-width patterns."""
|
|
cdef PatternStateC state
|
|
for i in range(states.size()):
|
|
state = states[i]
|
|
while get_quantifier(state) in (ZERO_PLUS, ZERO_ONE):
|
|
is_final = get_is_final(state)
|
|
if is_final:
|
|
ent_id = get_ent_id(state.pattern)
|
|
matches.push_back(
|
|
MatchC(pattern_id=ent_id, start=state.start, length=state.length))
|
|
break
|
|
else:
|
|
state.pattern += 1
|
|
|
|
|
|
cdef action_t get_action(PatternStateC state,
|
|
const TokenC* token, const attr_t* extra_attrs,
|
|
const char* predicate_matches) nogil:
|
|
"""We need to consider:
|
|
a) Does the token match the specification? [Yes, No]
|
|
b) What's the quantifier? [1, 0+, ?]
|
|
c) Is this the last specification? [final, non-final]
|
|
|
|
We can transition in the following ways:
|
|
a) Do we emit a match?
|
|
b) Do we add a state with (next state, next token)?
|
|
c) Do we add a state with (next state, same token)?
|
|
d) Do we add a state with (same state, next token)?
|
|
|
|
We'll code the actions as boolean strings, so 0000 means no to all 4,
|
|
1000 means match but no states added, etc.
|
|
|
|
1:
|
|
Yes, final:
|
|
1000
|
|
Yes, non-final:
|
|
0100
|
|
No, final:
|
|
0000
|
|
No, non-final
|
|
0000
|
|
0+:
|
|
Yes, final:
|
|
1001
|
|
Yes, non-final:
|
|
0011
|
|
No, final:
|
|
1000 (note: Don't include last token!)
|
|
No, non-final:
|
|
0010
|
|
?:
|
|
Yes, final:
|
|
1000
|
|
Yes, non-final:
|
|
0100
|
|
No, final:
|
|
1000 (note: Don't include last token!)
|
|
No, non-final:
|
|
0010
|
|
|
|
Possible combinations: 1000, 0100, 0000, 1001, 0110, 0011, 0010,
|
|
|
|
We'll name the bits "match", "advance", "retry", "extend"
|
|
REJECT = 0000
|
|
MATCH = 1000
|
|
ADVANCE = 0100
|
|
RETRY = 0010
|
|
MATCH_EXTEND = 1001
|
|
RETRY_ADVANCE = 0110
|
|
RETRY_EXTEND = 0011
|
|
MATCH_REJECT = 2000 # Match, but don't include last token
|
|
MATCH_DOUBLE = 3000 # Match both with and without last token
|
|
|
|
Problem: If a quantifier is matching, we're adding a lot of open partials
|
|
"""
|
|
cdef char is_match
|
|
is_match = get_is_match(state, token, extra_attrs, predicate_matches)
|
|
quantifier = get_quantifier(state)
|
|
is_final = get_is_final(state)
|
|
if quantifier == ZERO:
|
|
is_match = not is_match
|
|
quantifier = ONE
|
|
if quantifier == ONE:
|
|
if is_match and is_final:
|
|
# Yes, final: 1000
|
|
return MATCH
|
|
elif is_match and not is_final:
|
|
# Yes, non-final: 0100
|
|
return ADVANCE
|
|
elif not is_match and is_final:
|
|
# No, final: 0000
|
|
return REJECT
|
|
else:
|
|
return REJECT
|
|
elif quantifier == ZERO_PLUS:
|
|
if is_match and is_final:
|
|
# Yes, final: 1001
|
|
return MATCH_EXTEND
|
|
elif is_match and not is_final:
|
|
# Yes, non-final: 0011
|
|
return RETRY_EXTEND
|
|
elif not is_match and is_final:
|
|
# No, final 2000 (note: Don't include last token!)
|
|
return MATCH_REJECT
|
|
else:
|
|
# No, non-final 0010
|
|
return RETRY
|
|
elif quantifier == ZERO_ONE:
|
|
if is_match and is_final:
|
|
# Yes, final: 3000
|
|
# To cater for a pattern ending in "?", we need to add
|
|
# a match both with and without the last token
|
|
return MATCH_DOUBLE
|
|
elif is_match and not is_final:
|
|
# Yes, non-final: 0110
|
|
# We need both branches here, consider a pair like:
|
|
# pattern: .?b string: b
|
|
# If we 'ADVANCE' on the .?, we miss the match.
|
|
return RETRY_ADVANCE
|
|
elif not is_match and is_final:
|
|
# No, final 2000 (note: Don't include last token!)
|
|
return MATCH_REJECT
|
|
else:
|
|
# No, non-final 0010
|
|
return RETRY
|
|
|
|
|
|
cdef char get_is_match(PatternStateC state,
|
|
const TokenC* token, const attr_t* extra_attrs,
|
|
const char* predicate_matches) nogil:
|
|
for i in range(state.pattern.nr_py):
|
|
if predicate_matches[state.pattern.py_predicates[i]] == -1:
|
|
return 0
|
|
spec = state.pattern
|
|
for attr in spec.attrs[:spec.nr_attr]:
|
|
if get_token_attr(token, attr.attr) != attr.value:
|
|
return 0
|
|
for i in range(spec.nr_extra_attr):
|
|
if spec.extra_attrs[i].value != extra_attrs[spec.extra_attrs[i].index]:
|
|
return 0
|
|
return True
|
|
|
|
|
|
cdef char get_is_final(PatternStateC state) nogil:
|
|
if state.pattern[1].attrs[0].attr == ID and state.pattern[1].nr_attr == 0:
|
|
return 1
|
|
else:
|
|
return 0
|
|
|
|
|
|
cdef char get_quantifier(PatternStateC state) nogil:
|
|
return state.pattern.quantifier
|
|
|
|
|
|
cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs) except NULL:
|
|
pattern = <TokenPatternC*>mem.alloc(len(token_specs) + 1, sizeof(TokenPatternC))
|
|
cdef int i, index
|
|
for i, (quantifier, spec, extensions, predicates) in enumerate(token_specs):
|
|
pattern[i].quantifier = quantifier
|
|
pattern[i].attrs = <AttrValueC*>mem.alloc(len(spec), sizeof(AttrValueC))
|
|
pattern[i].nr_attr = len(spec)
|
|
for j, (attr, value) in enumerate(spec):
|
|
pattern[i].attrs[j].attr = attr
|
|
pattern[i].attrs[j].value = value
|
|
pattern[i].extra_attrs = <IndexValueC*>mem.alloc(len(extensions), sizeof(IndexValueC))
|
|
for j, (index, value) in enumerate(extensions):
|
|
pattern[i].extra_attrs[j].index = index
|
|
pattern[i].extra_attrs[j].value = value
|
|
pattern[i].nr_extra_attr = len(extensions)
|
|
pattern[i].py_predicates = <int32_t*>mem.alloc(len(predicates), sizeof(int32_t))
|
|
for j, index in enumerate(predicates):
|
|
pattern[i].py_predicates[j] = index
|
|
pattern[i].nr_py = len(predicates)
|
|
pattern[i].key = hash64(pattern[i].attrs, pattern[i].nr_attr * sizeof(AttrValueC), 0)
|
|
i = len(token_specs)
|
|
pattern[i].attrs = <AttrValueC*>mem.alloc(2, sizeof(AttrValueC))
|
|
pattern[i].attrs[0].attr = ID
|
|
pattern[i].attrs[0].value = entity_id
|
|
pattern[i].nr_attr = 0
|
|
pattern[i].nr_extra_attr = 0
|
|
pattern[i].nr_py = 0
|
|
return pattern
|
|
|
|
|
|
cdef attr_t get_pattern_key(const TokenPatternC* pattern) nogil:
|
|
while pattern.nr_attr != 0 or pattern.nr_extra_attr != 0 or pattern.nr_py != 0:
|
|
pattern += 1
|
|
id_attr = pattern[0].attrs[0]
|
|
if id_attr.attr != ID:
|
|
with gil:
|
|
raise ValueError(Errors.E074.format(attr=ID, bad_attr=id_attr.attr))
|
|
return id_attr.value
|
|
|
|
|
|
def _preprocess_pattern(token_specs, string_store, extensions_table, extra_predicates):
|
|
"""This function interprets the pattern, converting the various bits of
|
|
syntactic sugar before we compile it into a struct with init_pattern.
|
|
|
|
We need to split the pattern up into three parts:
|
|
* Normal attribute/value pairs, which are stored on either the token or lexeme,
|
|
can be handled directly.
|
|
* Extension attributes are handled specially, as we need to prefetch the
|
|
values from Python for the doc before we begin matching.
|
|
* Extra predicates also call Python functions, so we have to create the
|
|
functions and store them. So we store these specially as well.
|
|
* Extension attributes that have extra predicates are stored within the
|
|
extra_predicates.
|
|
"""
|
|
tokens = []
|
|
for spec in token_specs:
|
|
if not spec:
|
|
# Signifier for 'any token'
|
|
tokens.append((ONE, [(NULL_ATTR, 0)], [], []))
|
|
continue
|
|
if not isinstance(spec, dict):
|
|
raise ValueError(Errors.E154.format())
|
|
ops = _get_operators(spec)
|
|
attr_values = _get_attr_values(spec, string_store)
|
|
extensions = _get_extensions(spec, string_store, extensions_table)
|
|
predicates = _get_extra_predicates(spec, extra_predicates)
|
|
for op in ops:
|
|
tokens.append((op, list(attr_values), list(extensions), list(predicates)))
|
|
return tokens
|
|
|
|
|
|
def _get_attr_values(spec, string_store):
|
|
attr_values = []
|
|
for attr, value in spec.items():
|
|
if isinstance(attr, basestring):
|
|
attr = attr.upper()
|
|
if attr == '_':
|
|
continue
|
|
elif attr == "OP":
|
|
continue
|
|
if attr == "TEXT":
|
|
attr = "ORTH"
|
|
if attr not in TOKEN_PATTERN_SCHEMA["items"]["properties"]:
|
|
raise ValueError(Errors.E152.format(attr=attr))
|
|
attr = IDS.get(attr)
|
|
if isinstance(value, basestring):
|
|
value = string_store.add(value)
|
|
elif isinstance(value, bool):
|
|
value = int(value)
|
|
elif isinstance(value, dict):
|
|
continue
|
|
else:
|
|
raise ValueError(Errors.E153.format(vtype=type(value).__name__))
|
|
if attr is not None:
|
|
attr_values.append((attr, value))
|
|
else:
|
|
# should be caught above using TOKEN_PATTERN_SCHEMA
|
|
raise ValueError(Errors.E152.format(attr=attr))
|
|
return attr_values
|
|
|
|
|
|
# These predicate helper classes are used to match the REGEX, IN, >= etc
|
|
# extensions to the matcher introduced in #3173.
|
|
|
|
class _RegexPredicate(object):
|
|
operators = ("REGEX",)
|
|
|
|
def __init__(self, i, attr, value, predicate, is_extension=False):
|
|
self.i = i
|
|
self.attr = attr
|
|
self.value = re.compile(value)
|
|
self.predicate = predicate
|
|
self.is_extension = is_extension
|
|
self.key = (attr, self.predicate, srsly.json_dumps(value, sort_keys=True))
|
|
if self.predicate not in self.operators:
|
|
raise ValueError(Errors.E126.format(good=self.operators, bad=self.predicate))
|
|
|
|
def __call__(self, Token token):
|
|
if self.is_extension:
|
|
value = token._.get(self.attr)
|
|
else:
|
|
value = token.vocab.strings[get_token_attr(token.c, self.attr)]
|
|
return bool(self.value.search(value))
|
|
|
|
|
|
class _SetMemberPredicate(object):
|
|
operators = ("IN", "NOT_IN")
|
|
|
|
def __init__(self, i, attr, value, predicate, is_extension=False):
|
|
self.i = i
|
|
self.attr = attr
|
|
self.value = set(get_string_id(v) for v in value)
|
|
self.predicate = predicate
|
|
self.is_extension = is_extension
|
|
self.key = (attr, self.predicate, srsly.json_dumps(value, sort_keys=True))
|
|
if self.predicate not in self.operators:
|
|
raise ValueError(Errors.E126.format(good=self.operators, bad=self.predicate))
|
|
|
|
def __call__(self, Token token):
|
|
if self.is_extension:
|
|
value = get_string_id(token._.get(self.attr))
|
|
else:
|
|
value = get_token_attr(token.c, self.attr)
|
|
if self.predicate == "IN":
|
|
return value in self.value
|
|
else:
|
|
return value not in self.value
|
|
|
|
def __repr__(self):
|
|
return repr(("SetMemberPredicate", self.i, self.attr, self.value, self.predicate))
|
|
|
|
|
|
class _ComparisonPredicate(object):
|
|
operators = ("==", "!=", ">=", "<=", ">", "<")
|
|
|
|
def __init__(self, i, attr, value, predicate, is_extension=False):
|
|
self.i = i
|
|
self.attr = attr
|
|
self.value = value
|
|
self.predicate = predicate
|
|
self.is_extension = is_extension
|
|
self.key = (attr, self.predicate, srsly.json_dumps(value, sort_keys=True))
|
|
if self.predicate not in self.operators:
|
|
raise ValueError(Errors.E126.format(good=self.operators, bad=self.predicate))
|
|
|
|
def __call__(self, Token token):
|
|
if self.is_extension:
|
|
value = token._.get(self.attr)
|
|
else:
|
|
value = get_token_attr(token.c, self.attr)
|
|
if self.predicate == "==":
|
|
return value == self.value
|
|
if self.predicate == "!=":
|
|
return value != self.value
|
|
elif self.predicate == ">=":
|
|
return value >= self.value
|
|
elif self.predicate == "<=":
|
|
return value <= self.value
|
|
elif self.predicate == ">":
|
|
return value > self.value
|
|
elif self.predicate == "<":
|
|
return value < self.value
|
|
|
|
|
|
def _get_extra_predicates(spec, extra_predicates):
|
|
predicate_types = {
|
|
"REGEX": _RegexPredicate,
|
|
"IN": _SetMemberPredicate,
|
|
"NOT_IN": _SetMemberPredicate,
|
|
"==": _ComparisonPredicate,
|
|
">=": _ComparisonPredicate,
|
|
"<=": _ComparisonPredicate,
|
|
">": _ComparisonPredicate,
|
|
"<": _ComparisonPredicate,
|
|
}
|
|
seen_predicates = {pred.key: pred.i for pred in extra_predicates}
|
|
output = []
|
|
for attr, value in spec.items():
|
|
if isinstance(attr, basestring):
|
|
if attr == "_":
|
|
output.extend(
|
|
_get_extension_extra_predicates(
|
|
value, extra_predicates, predicate_types,
|
|
seen_predicates))
|
|
continue
|
|
elif attr.upper() == "OP":
|
|
continue
|
|
if attr.upper() == "TEXT":
|
|
attr = "ORTH"
|
|
attr = IDS.get(attr.upper())
|
|
if isinstance(value, dict):
|
|
for type_, cls in predicate_types.items():
|
|
if type_ in value:
|
|
predicate = cls(len(extra_predicates), attr, value[type_], type_)
|
|
# Don't create a redundant predicates.
|
|
# This helps with efficiency, as we're caching the results.
|
|
if predicate.key in seen_predicates:
|
|
output.append(seen_predicates[predicate.key])
|
|
else:
|
|
extra_predicates.append(predicate)
|
|
output.append(predicate.i)
|
|
seen_predicates[predicate.key] = predicate.i
|
|
return output
|
|
|
|
|
|
def _get_extension_extra_predicates(spec, extra_predicates, predicate_types,
|
|
seen_predicates):
|
|
output = []
|
|
for attr, value in spec.items():
|
|
if isinstance(value, dict):
|
|
for type_, cls in predicate_types.items():
|
|
if type_ in value:
|
|
key = (attr, type_, srsly.json_dumps(value[type_], sort_keys=True))
|
|
if key in seen_predicates:
|
|
output.append(seen_predicates[key])
|
|
else:
|
|
predicate = cls(len(extra_predicates), attr, value[type_], type_,
|
|
is_extension=True)
|
|
extra_predicates.append(predicate)
|
|
output.append(predicate.i)
|
|
seen_predicates[key] = predicate.i
|
|
return output
|
|
|
|
|
|
def _get_operators(spec):
|
|
# Support 'syntactic sugar' operator '+', as combination of ONE, ZERO_PLUS
|
|
lookup = {"*": (ZERO_PLUS,), "+": (ONE, ZERO_PLUS),
|
|
"?": (ZERO_ONE,), "1": (ONE,), "!": (ZERO,)}
|
|
# Fix casing
|
|
spec = {key.upper(): values for key, values in spec.items()
|
|
if isinstance(key, basestring)}
|
|
if "OP" not in spec:
|
|
return (ONE,)
|
|
elif spec["OP"] in lookup:
|
|
return lookup[spec["OP"]]
|
|
else:
|
|
keys = ", ".join(lookup.keys())
|
|
raise ValueError(Errors.E011.format(op=spec["OP"], opts=keys))
|
|
|
|
|
|
def _get_extensions(spec, string_store, name2index):
|
|
attr_values = []
|
|
if not isinstance(spec.get("_", {}), dict):
|
|
raise ValueError(Errors.E154.format())
|
|
for name, value in spec.get("_", {}).items():
|
|
if isinstance(value, dict):
|
|
# Handle predicates (e.g. "IN", in the extra_predicates, not here.
|
|
continue
|
|
if isinstance(value, basestring):
|
|
value = string_store.add(value)
|
|
if name not in name2index:
|
|
name2index[name] = len(name2index)
|
|
attr_values.append((name2index[name], value))
|
|
return attr_values
|