mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-25 05:01:02 +03:00 
			
		
		
		
	* Add missing int value option to top-level pattern validation in Matcher
* Adjust existing tests accordingly
* Add new test for valid pattern `{"LENGTH": int}`
		
	
			
		
			
				
	
	
		
			846 lines
		
	
	
		
			32 KiB
		
	
	
	
		
			Cython
		
	
	
	
	
	
			
		
		
	
	
			846 lines
		
	
	
		
			32 KiB
		
	
	
	
		
			Cython
		
	
	
	
	
	
| # cython: infer_types=True
 | |
| # cython: profile=True
 | |
| from __future__ import unicode_literals
 | |
| 
 | |
| from libcpp.vector cimport vector
 | |
| from libc.stdint cimport int32_t
 | |
| from cymem.cymem cimport Pool
 | |
| from murmurhash.mrmr cimport hash64
 | |
| 
 | |
| import re
 | |
| import srsly
 | |
| 
 | |
| from ..typedefs cimport attr_t
 | |
| from ..structs cimport TokenC
 | |
| from ..vocab cimport Vocab
 | |
| from ..tokens.doc cimport Doc, get_token_attr
 | |
| from ..tokens.token cimport Token
 | |
| from ..attrs cimport ID, attr_id_t, NULL_ATTR, ORTH, POS, TAG, DEP, LEMMA
 | |
| 
 | |
| from ._schemas import TOKEN_PATTERN_SCHEMA
 | |
| from ..util import get_json_validator, validate_json
 | |
| from ..errors import Errors, MatchPatternError, Warnings, deprecation_warning
 | |
| from ..strings import get_string_id
 | |
| from ..attrs import IDS
 | |
| 
 | |
| 
 | |
# Compile-time constant (Cython `DEF`).
# NOTE(review): not referenced in the visible part of this file -- confirm
# whether it is still needed before relying on or removing it.
DEF PADDING = 5
 | |
| 
 | |
| 
 | |
cdef class Matcher:
    """Match sequences of tokens, based on pattern rules.

    DOCS: https://spacy.io/api/matcher
    USAGE: https://spacy.io/usage/rule-based-matching
    """

    def __init__(self, vocab, validate=False):
        """Create the Matcher.

        vocab (Vocab): The vocabulary object, which must be shared with the
            documents the matcher will operate on.
        validate (bool): If True, every pattern passed to `add` is checked
            against the token pattern JSON schema before it is compiled.
        RETURNS (Matcher): The newly constructed object.
        """
        self._extra_predicates = []
        self._patterns = {}
        self._callbacks = {}
        self._extensions = {}
        self._seen_attrs = set()
        self.vocab = vocab
        self.mem = Pool()
        if validate:
            self.validator = get_json_validator(TOKEN_PATTERN_SCHEMA)
        else:
            self.validator = None

    def __reduce__(self):
        # Only the vocab, the raw patterns and the callbacks are pickled; the
        # compiled patterns are rebuilt by re-adding them on unpickling. The
        # `validate` setting is not part of the pickled state.
        data = (self.vocab, self._patterns, self._callbacks)
        return (unpickle_matcher, data, None, None)

    def __len__(self):
        """Get the number of rules added to the matcher. Note that this only
        returns the number of rules (identical with the number of IDs), not the
        number of individual patterns.

        RETURNS (int): The number of rules.
        """
        return len(self._patterns)

    def __contains__(self, key):
        """Check whether the matcher contains rules for a match ID.

        key (unicode): The match ID.
        RETURNS (bool): Whether the matcher contains rules for this match ID.
        """
        return self._normalize_key(key) in self._patterns

    def add(self, key, on_match, *patterns):
        """Add a match-rule to the matcher. A match-rule consists of: an ID
        key, an on_match callback, and one or more patterns.

        If the key exists, the patterns are appended to the previous ones, and
        the previous on_match callback is replaced. The `on_match` callback
        will receive the arguments `(matcher, doc, i, matches)`. You can also
        set `on_match` to `None` to not perform any actions.

        A pattern consists of one or more `token_specs`, where a `token_spec`
        is a dictionary mapping attribute IDs to values, and optionally a
        quantifier operator under the key "op". The available quantifiers are:

        '!': Negate the pattern, by requiring it to match exactly 0 times.
        '?': Make the pattern optional, by allowing it to match 0 or 1 times.
        '+': Require the pattern to match 1 or more times.
        '*': Allow the pattern to match zero or more times.

        The + and * operators are usually interpreted "greedily", i.e. longer
        matches are returned where possible. However, if you specify two '+'
        and '*' patterns in a row and their matches overlap, the first
        operator will behave non-greedily. This quirk in the semantics makes
        the matcher more efficient, by avoiding the need for back-tracking.

        key (unicode): The match ID.
        on_match (callable): Callback executed on match.
        *patterns (list): List of token descriptions.
        RAISES (ValueError): If `on_match` is neither None nor callable, if a
            pattern is empty, or if a pattern cannot be compiled.
        RAISES (MatchPatternError): If validation is enabled and a pattern
            does not conform to the schema.
        """
        errors = {}
        if on_match is not None and not hasattr(on_match, "__call__"):
            raise ValueError(Errors.E171.format(arg_type=type(on_match)))
        for i, pattern in enumerate(patterns):
            if len(pattern) == 0:
                raise ValueError(Errors.E012.format(key=key))
            if self.validator:
                errors[i] = validate_json(pattern, self.validator)
        if any(err for err in errors.values()):
            raise MatchPatternError(key, errors)
        key = self._normalize_key(key)
        for pattern in patterns:
            try:
                specs = _preprocess_pattern(pattern, self.vocab.strings,
                    self._extensions, self._extra_predicates)
                self.patterns.push_back(init_pattern(self.mem, key, specs))
                for spec in specs:
                    for attr, _ in spec[1]:
                        self._seen_attrs.add(attr)
            # The tuple is required: `except OverflowError, AttributeError:`
            # would catch only OverflowError and bind the exception to the
            # name `AttributeError` instead of catching both types.
            except (OverflowError, AttributeError):
                raise ValueError(Errors.E154.format())
        self._patterns.setdefault(key, [])
        self._callbacks[key] = on_match
        self._patterns[key].extend(patterns)

    def remove(self, key):
        """Remove a rule from the matcher. A ValueError is raised if the key
        does not exist.

        key (unicode): The ID of the match rule.
        """
        norm_key = self._normalize_key(key)
        if norm_key not in self._patterns:
            raise ValueError(Errors.E175.format(key=key))
        self._patterns.pop(norm_key)
        self._callbacks.pop(norm_key)
        cdef int i = 0
        # Erase every compiled pattern belonging to this key. The index only
        # advances when nothing was erased, because erasing shifts the
        # following elements down into position i.
        while i < self.patterns.size():
            pattern_key = get_ent_id(self.patterns.at(i))
            if pattern_key == norm_key:
                self.patterns.erase(self.patterns.begin()+i)
            else:
                i += 1

    def has_key(self, key):
        """Check whether the matcher has a rule with a given key.

        key (string or int): The key to check.
        RETURNS (bool): Whether the matcher has the rule.
        """
        key = self._normalize_key(key)
        return key in self._patterns

    def get(self, key, default=None):
        """Retrieve the pattern stored for a key.

        key (unicode or int): The key to retrieve.
        default: Value returned when no rule is stored under the key.
        RETURNS (tuple): The rule, as an (on_match, patterns) tuple, or
            `default` if the key is unknown.
        """
        key = self._normalize_key(key)
        if key not in self._patterns:
            return default
        return (self._callbacks[key], self._patterns[key])

    def pipe(self, docs, batch_size=1000, n_threads=-1, return_matches=False,
             as_tuples=False):
        """Match a stream of documents, yielding them in turn.

        docs (iterable): A stream of documents.
        batch_size (int): Number of documents to accumulate into a working set.
            Currently not used by this implementation.
        n_threads (int): Deprecated. Any value other than -1 triggers a
            deprecation warning and is otherwise ignored.
        return_matches (bool): Yield the match lists along with the docs, making
            results (doc, matches) tuples.
        as_tuples (bool): Interpret the input stream as (doc, context) tuples,
            and yield (result, context) tuples out.
            If both return_matches and as_tuples are True, the output will
            be a sequence of ((doc, matches), context) tuples.
        YIELDS (Doc): Documents, in order.
        """
        if n_threads != -1:
            deprecation_warning(Warnings.W016)

        if as_tuples:
            for doc, context in docs:
                matches = self(doc)
                if return_matches:
                    yield ((doc, matches), context)
                else:
                    yield (doc, context)
        else:
            for doc in docs:
                matches = self(doc)
                if return_matches:
                    yield (doc, matches)
                else:
                    yield doc

    def __call__(self, Doc doc):
        """Find all token sequences matching the supplied pattern.

        doc (Doc): The document to match over.
        RETURNS (list): A list of `(key, start, end)` tuples,
            describing the matches. A match tuple describes a span
            `doc[start:end]`. The `key` is the integer match ID.
        RAISES (ValueError): If a pattern uses LEMMA, POS or TAG but the doc
            is not tagged, or uses DEP but the doc is not parsed.
        """
        if len(set([LEMMA, POS, TAG]) & self._seen_attrs) > 0 \
          and not doc.is_tagged:
            raise ValueError(Errors.E155.format())
        if DEP in self._seen_attrs and not doc.is_parsed:
            raise ValueError(Errors.E156.format())
        matches = find_matches(&self.patterns[0], self.patterns.size(), doc,
                               extensions=self._extensions,
                               predicates=self._extra_predicates)
        # Callbacks receive the index of their match plus the full match list.
        for i, (key, start, end) in enumerate(matches):
            on_match = self._callbacks.get(key, None)
            if on_match is not None:
                on_match(self, doc, i, matches)
        return matches

    def _normalize_key(self, key):
        # String keys are interned in the vocab's string store; anything else
        # is returned unchanged.
        if isinstance(key, basestring):
            return self.vocab.strings.add(key)
        else:
            return key
 | |
| 
 | |
| 
 | |
def unpickle_matcher(vocab, patterns, callbacks):
    """Rebuild a Matcher from the state produced by `Matcher.__reduce__`.

    vocab (Vocab): The vocabulary for the new matcher.
    patterns (dict): Match ID -> list of patterns.
    callbacks (dict): Match ID -> on_match callback (or None).
    RETURNS (Matcher): A matcher with every rule re-added.
    """
    restored = Matcher(vocab)
    for match_id in patterns:
        restored.add(match_id, callbacks.get(match_id), *patterns[match_id])
    return restored
 | |
| 
 | |
| 
 | |
| 
 | |
cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None,
        predicates=tuple()):
    """Find matches in a doc, with a compiled array of patterns. Matches are
    returned as a list of (id, start, end) tuples.

    To augment the compiled patterns, we optionally also take two Python
    collections.

    The "predicates" sequence contains callables that are invoked with a
    token and return a boolean value (or None). It's mostly used for regular
    expressions.

    The "extensions" dict maps extension attribute names to a column index.
    The value of each listed extension attribute is prefetched for every
    token before matching starts. If omitted, no extension attributes are
    prefetched.

    patterns (TokenPatternC**): The compiled patterns.
    n (int): Number of entries in `patterns`.
    doc (Doc): The document to match over.
    RETURNS (list): De-duplicated (pattern_id, start, end) tuples.
    """
    cdef vector[PatternStateC] states
    cdef vector[MatchC] matches
    cdef int i, j, nr_extra_attr
    cdef int nr_predicate = len(predicates)
    cdef Pool mem = Pool()
    if extensions is None:
        # None is the declared default; treat it as "no extension attributes"
        # rather than failing on `.items()` in the prefetch loop below.
        extensions = {}
    # One cache cell per (token, predicate) pair.
    predicate_cache = <char*>mem.alloc(doc.length * nr_predicate, sizeof(char))
    if len(extensions) >= 1:
        nr_extra_attr = max(extensions.values()) + 1
        extra_attr_values = <attr_t*>mem.alloc(doc.length * nr_extra_attr, sizeof(attr_t))
    else:
        nr_extra_attr = 0
        extra_attr_values = <attr_t*>mem.alloc(doc.length, sizeof(attr_t))
    # Prefetch the extension attribute values, one row per token. String
    # values are replaced by their ID in the string store.
    for i, token in enumerate(doc):
        for name, index in extensions.items():
            value = token._.get(name)
            if isinstance(value, basestring):
                value = token.vocab.strings[value]
            extra_attr_values[i * nr_extra_attr + index] = value
    # Main loop
    for i in range(doc.length):
        # Every token may start a new match of every pattern.
        for j in range(n):
            states.push_back(PatternStateC(patterns[j], i, 0))
        transition_states(states, matches, predicate_cache,
            doc[i], extra_attr_values, predicates)
        # Move both per-token windows on to the next token's row.
        extra_attr_values += nr_extra_attr
        predicate_cache += nr_predicate
    # Handle matches that end in 0-width patterns
    finish_states(matches, states)
    output = []
    seen = set()
    for i in range(matches.size()):
        match = (
            matches[i].pattern_id,
            matches[i].start,
            matches[i].start+matches[i].length
        )
        # We need to deduplicate, because we could otherwise arrive at the same
        # match through two paths, e.g. .?.? matching 'a'. Are we matching the
        # first .?, or the second .? -- it doesn't matter, it's just one match.
        if match not in seen:
            output.append(match)
            seen.add(match)
    return output
 | |
| 
 | |
| 
 | |
cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& matches,
                            char* cached_py_predicates,
        Token token, const attr_t* extra_attrs, py_predicates) except *:
    """Advance every open state by one token.

    For each state the action returned by `get_action` decides whether the
    state is dropped, kept (compacted towards the front of `states`) and/or
    produces entries in `matches`. States spawned while handling
    RETRY_EXTEND / RETRY_ADVANCE are collected separately and appended to
    `states` after the surviving ones.

    states: Open states; modified in place.
    matches: Output vector that completed matches are appended to.
    cached_py_predicates: Predicate result cache passed on to
        `update_predicate_cache` and `get_action`.
    token (Token): The token being consumed.
    extra_attrs: Extension attribute values passed on to `get_action`.
    py_predicates: Python predicate callables.
    """
    cdef int q = 0
    cdef vector[PatternStateC] new_states
    # NOTE(review): nr_predicate is not used anywhere below.
    cdef int nr_predicate = len(py_predicates)
    for i in range(states.size()):
        # Make sure the Python predicates this spec refers to are evaluated
        # before get_action reads the cache.
        if states[i].pattern.nr_py >= 1:
            update_predicate_cache(cached_py_predicates,
                states[i].pattern, token, py_predicates)
        action = get_action(states[i], token.c, extra_attrs,
                            cached_py_predicates)
        if action == REJECT:
            continue
        # Keep only a subset of states (the active ones). Index q is the
        # states which are still alive. If we reject a state, we overwrite
        # it in the states list, because q doesn't advance.
        state = states[i]
        states[q] = state
        # The RETRY* actions re-test the same token against the next spec of
        # the pattern; `state` keeps the values from before the retries.
        while action in (RETRY, RETRY_ADVANCE, RETRY_EXTEND):
            if action == RETRY_EXTEND:
                # This handles the 'extend'
                new_states.push_back(
                    PatternStateC(pattern=states[q].pattern, start=state.start,
                                  length=state.length+1))
            if action == RETRY_ADVANCE:
                # This handles the 'advance'
                new_states.push_back(
                    PatternStateC(pattern=states[q].pattern+1, start=state.start,
                                  length=state.length+1))
            states[q].pattern += 1
            if states[q].pattern.nr_py != 0:
                update_predicate_cache(cached_py_predicates,
                    states[q].pattern, token, py_predicates)
            action = get_action(states[q], token.c, extra_attrs,
                                cached_py_predicates)
        if action == REJECT:
            # q is not advanced, so this slot is overwritten or truncated.
            pass
        elif action == ADVANCE:
            states[q].pattern += 1
            states[q].length += 1
            q += 1
        else:
            ent_id = get_ent_id(state.pattern)
            if action == MATCH:
                matches.push_back(
                    MatchC(pattern_id=ent_id, start=state.start,
                            length=state.length+1))
            elif action == MATCH_DOUBLE:
                # push match without last token if length > 0
                if state.length > 0:
                    matches.push_back(
                        MatchC(pattern_id=ent_id, start=state.start,
                                length=state.length))
                # push match with last token
                matches.push_back(
                    MatchC(pattern_id=ent_id, start=state.start,
                            length=state.length+1))
            elif action == MATCH_REJECT:
                # Match that does not include the current token.
                matches.push_back(
                    MatchC(pattern_id=ent_id, start=state.start,
                            length=state.length))
            elif action == MATCH_EXTEND:
                # Emit the match so far and keep the state alive, one token
                # longer.
                matches.push_back(
                    MatchC(pattern_id=ent_id, start=state.start,
                           length=state.length))
                states[q].length += 1
                q += 1
    # Drop everything past the surviving states, then add the spawned ones.
    states.resize(q)
    for i in range(new_states.size()):
        states.push_back(new_states[i])
 | |
| 
 | |
| 
 | |
cdef int update_predicate_cache(char* cache,
        const TokenPatternC* pattern, Token token, predicates) except -1:
    """Evaluate the Python predicates referenced by `pattern` for `token`.

    Results are memoised in `cache`: 1 for True, -1 for False. A cell that is
    still 0 has no stored result; a predicate returning None leaves it at 0.
    Caching avoids calling these potentially expensive Python functions more
    often than needed.

    RAISES (ValueError): If a predicate returns anything other than True,
        False or None.
    """
    for slot in range(pattern.nr_py):
        predicate_id = pattern.py_predicates[slot]
        if cache[predicate_id] != 0:
            # A result is already stored for this predicate.
            continue
        outcome = predicates[predicate_id](token)
        if outcome is True:
            cache[predicate_id] = 1
        elif outcome is False:
            cache[predicate_id] = -1
        elif outcome is not None:
            raise ValueError(Errors.E125.format(value=outcome))
 | |
| 
 | |
| 
 | |
cdef void finish_states(vector[MatchC]& matches, vector[PatternStateC]& states) except *:
    """Emit matches for states whose remaining specs are all zero-width.

    Starting from each state's current spec, specs with a ZERO_PLUS or
    ZERO_ONE quantifier are stepped over; if a final one is reached, a match
    covering the tokens consumed so far is appended to `matches`.
    """
    cdef PatternStateC state
    for idx in range(states.size()):
        state = states[idx]
        while True:
            quantifier = get_quantifier(state)
            if quantifier != ZERO_PLUS and quantifier != ZERO_ONE:
                break
            if not get_is_final(state):
                # Step over this optional spec and look at the next one.
                state.pattern += 1
                continue
            matches.push_back(
                MatchC(pattern_id=get_ent_id(state.pattern), start=state.start,
                       length=state.length))
            break
 | |
| 
 | |
| 
 | |
cdef action_t get_action(PatternStateC state,
        const TokenC* token, const attr_t* extra_attrs,
        const char* predicate_matches) nogil:
    """Decide how a state transitions on the given token.

    We need to consider:
    a) Does the token match the specification? [Yes, No]
    b) What's the quantifier? [1, 0+, ?]
    c) Is this the last specification? [final, non-final]

    We can transition in the following ways:
    a) Do we emit a match?
    b) Do we add a state with (next state, next token)?
    c) Do we add a state with (next state, same token)?
    d) Do we add a state with (same state, next token)?

    We'll code the actions as boolean strings, so 0000 means no to all 4,
    1000 means match but no states added, etc.

    1:
      Yes, final:
        1000
      Yes, non-final:
        0100
      No, final:
        0000
      No, non-final
        0000
    0+:
      Yes, final:
        1001
      Yes, non-final:
        0011
      No, final:
        2000 (note: Don't include last token!)
      No, non-final:
        0010
    ?:
      Yes, final:
        3000 (match both with and without the last token)
      Yes, non-final:
        0110
      No, final:
        2000 (note: Don't include last token!)
      No, non-final:
        0010

    The ZERO quantifier is handled by inverting the match result and then
    applying the rules for 1.

    Possible combinations: 1000, 0100, 0000, 1001, 0110, 0011, 0010, 2000,
    3000

    We'll name the bits "match", "advance", "retry", "extend"
    REJECT = 0000
    MATCH = 1000
    ADVANCE = 0100
    RETRY = 0010
    MATCH_EXTEND = 1001
    RETRY_ADVANCE = 0110
    RETRY_EXTEND = 0011
    MATCH_REJECT = 2000 # Match, but don't include last token
    MATCH_DOUBLE = 3000 # Match both with and without last token

    Problem: If a quantifier is matching, we're adding a lot of open partials
    """
    cdef char is_match
    is_match = get_is_match(state, token, extra_attrs, predicate_matches)
    quantifier = get_quantifier(state)
    is_final = get_is_final(state)
    if quantifier == ZERO:
        # Negation: succeed exactly when the token does NOT match, then
        # proceed as for a plain single-token spec.
        is_match = not is_match
        quantifier = ONE
    # NOTE(review): there is no explicit return for a quantifier other than
    # ONE, ZERO_PLUS or ZERO_ONE -- confirm no other value can reach here.
    if quantifier == ONE:
      if is_match and is_final:
          # Yes, final: 1000
          return MATCH
      elif is_match and not is_final:
          # Yes, non-final: 0100
          return ADVANCE
      elif not is_match and is_final:
          # No, final: 0000
          return REJECT
      else:
          # No, non-final: 0000
          return REJECT
    elif quantifier == ZERO_PLUS:
      if is_match and is_final:
          # Yes, final: 1001
          return MATCH_EXTEND
      elif is_match and not is_final:
          # Yes, non-final: 0011
          return RETRY_EXTEND
      elif not is_match and is_final:
          # No, final 2000 (note: Don't include last token!)
          return MATCH_REJECT
      else:
          # No, non-final 0010
          return RETRY
    elif quantifier == ZERO_ONE:
      if is_match and is_final:
          # Yes, final: 3000
          # To cater for a pattern ending in "?", we need to add
          # a match both with and without the last token
          return MATCH_DOUBLE
      elif is_match and not is_final:
          # Yes, non-final: 0110
          # We need both branches here, consider a pair like:
          # pattern: .?b string: b
          # If we 'ADVANCE' on the .?, we miss the match.
          return RETRY_ADVANCE
      elif not is_match and is_final:
          # No, final 2000 (note: Don't include last token!)
          return MATCH_REJECT
      else:
          # No, non-final 0010
          return RETRY
 | |
| 
 | |
| 
 | |
cdef char get_is_match(PatternStateC state,
        const TokenC* token, const attr_t* extra_attrs,
        const char* predicate_matches) nogil:
    """Check whether `token` satisfies the state's current token spec.

    Returns 0 as soon as a cached Python predicate result is -1, a token
    attribute differs from the required value, or an extension attribute
    differs from the required value; returns 1 otherwise.
    """
    spec = state.pattern
    # Cached results of the Python predicates (-1 means "did not match").
    for i in range(spec.nr_py):
        if predicate_matches[spec.py_predicates[i]] == -1:
            return 0
    # Regular token attributes.
    for i in range(spec.nr_attr):
        if get_token_attr(token, spec.attrs[i].attr) != spec.attrs[i].value:
            return 0
    # Prefetched extension attribute values.
    for i in range(spec.nr_extra_attr):
        if extra_attrs[spec.extra_attrs[i].index] != spec.extra_attrs[i].value:
            return 0
    return 1
 | |
| 
 | |
| 
 | |
cdef char get_is_final(PatternStateC state) nogil:
    """Return 1 if the state's current spec is the last one of its pattern.

    The entry following the last spec is recognised by having `nr_attr == 0`
    together with a non-NULL `attrs` pointer whose first element carries the
    ID attribute.
    """
    if state.pattern[1].nr_attr != 0 or state.pattern[1].attrs == NULL:
        return 0
    id_attr = state.pattern[1].attrs[0]
    if id_attr.attr != ID:
        with gil:
            raise ValueError(Errors.E074.format(attr=ID, bad_attr=id_attr.attr))
    return 1
 | |
| 
 | |
| 
 | |
cdef char get_quantifier(PatternStateC state) nogil:
    """Return the quantifier of the state's current token spec."""
    return state.pattern[0].quantifier
 | |
| 
 | |
| 
 | |
cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs) except NULL:
    """Compile preprocessed token specs into a TokenPatternC array.

    mem (Pool): Memory pool that all allocations are made from.
    entity_id (attr_t): The match ID, stored in the terminal entry.
    token_specs (list): (quantifier, attrs, extensions, predicates) tuples.
    RETURNS (TokenPatternC*): `len(token_specs) + 1` entries; the last one is
        the terminal entry holding the match ID.
    """
    cdef int i, index
    nr_token = len(token_specs)
    # One extra slot for the terminal entry.
    pattern = <TokenPatternC*>mem.alloc(nr_token + 1, sizeof(TokenPatternC))
    for i, (quantifier, spec, extensions, predicates) in enumerate(token_specs):
        nr_attr = len(spec)
        nr_extension = len(extensions)
        nr_predicate = len(predicates)
        pattern[i].quantifier = quantifier
        # Ensure attrs refers to a null pointer if nr_attr == 0
        if nr_attr > 0:
            pattern[i].attrs = <AttrValueC*>mem.alloc(nr_attr, sizeof(AttrValueC))
        pattern[i].nr_attr = nr_attr
        for j, (attr, value) in enumerate(spec):
            pattern[i].attrs[j].attr = attr
            pattern[i].attrs[j].value = value
        # Extension attributes: (column index, required value) pairs.
        pattern[i].extra_attrs = <IndexValueC*>mem.alloc(nr_extension, sizeof(IndexValueC))
        pattern[i].nr_extra_attr = nr_extension
        for j, (index, value) in enumerate(extensions):
            pattern[i].extra_attrs[j].index = index
            pattern[i].extra_attrs[j].value = value
        # Indices into the list of Python predicates.
        pattern[i].py_predicates = <int32_t*>mem.alloc(nr_predicate, sizeof(int32_t))
        pattern[i].nr_py = nr_predicate
        for j, index in enumerate(predicates):
            pattern[i].py_predicates[j] = index
        pattern[i].key = hash64(pattern[i].attrs, pattern[i].nr_attr * sizeof(AttrValueC), 0)
    i = nr_token
    # Even though here, nr_attr == 0, we're storing the ID value in attrs[0]
    # (bug-prone, tread carefully!)
    pattern[i].attrs = <AttrValueC*>mem.alloc(2, sizeof(AttrValueC))
    pattern[i].attrs[0].attr = ID
    pattern[i].attrs[0].value = entity_id
    pattern[i].nr_attr = 0
    pattern[i].nr_extra_attr = 0
    pattern[i].nr_py = 0
    return pattern
 | |
| 
 | |
| 
 | |
cdef attr_t get_ent_id(const TokenPatternC* pattern) nogil:
    # Return the match ID of the pattern that `pattern` points into, by
    # walking forward to the terminal entry and reading its ID attribute.
    #
    # History: there have been a few bugs here. There used to be two
    # functions, get_ent_id and get_pattern_key, that tried to do the same
    # thing. They are now unified to try to solve the "ghost match" problem.
    # The previous implementation is preserved below for reference while it
    # is unclear whether the heisenbug in the matcher is resolved.
    #
    #     cdef attr_t get_ent_id(const TokenPatternC* pattern) nogil:
    #         # The code was originally designed to always have pattern[1].attrs.value
    #         # be the ent_id when we get to the end of a pattern. However, Issue #2671
    #         # showed this wasn't the case when we had a reject-and-continue before a
    #         # match.
    #         # The patch to #2671 was wrong though, which came up in #3839.
    #         while pattern.attrs.attr != ID:
    #             pattern += 1
    #         return pattern.attrs.value
    #
    # The terminal entry is the first one with no attrs, no extension attrs,
    # no Python predicates and the ZERO quantifier.
    while not (pattern.nr_attr == 0 and pattern.nr_extra_attr == 0
               and pattern.nr_py == 0 and pattern.quantifier == ZERO):
        pattern += 1
    id_attr = pattern[0].attrs[0]
    if id_attr.attr == ID:
        return id_attr.value
    with gil:
        raise ValueError(Errors.E074.format(attr=ID, bad_attr=id_attr.attr))
 | |
| 
 | |
| 
 | |
def _preprocess_pattern(token_specs, string_store, extensions_table, extra_predicates):
    """Interpret a pattern and expand its syntactic sugar, producing the
    tuples that init_pattern compiles into a struct.

    Each token spec is split up into these parts:
    * Normal attribute/value pairs, which are stored on either the token or
        lexeme, and can be handled directly.
    * Extension attributes, which are handled specially, as the values need
        to be prefetched from Python for the doc before matching begins.
    * Extra predicates, which also call Python functions, so the functions
        have to be created and stored. These are stored specially as well.
    * Extension attributes that have extra predicates, which are stored
        within the extra_predicates.

    RETURNS (list): One (op, attr_values, extensions, predicates) tuple per
        operator returned by _get_operators for each token spec.
    RAISES (ValueError): If a non-empty token spec is not a dict.
    """
    compiled = []
    for spec in token_specs:
        if not spec:
            # Signifier for 'any token'
            compiled.append((ONE, [(NULL_ATTR, 0)], [], []))
        elif not isinstance(spec, dict):
            raise ValueError(Errors.E154.format())
        else:
            ops = _get_operators(spec)
            attr_values = _get_attr_values(spec, string_store)
            extensions = _get_extensions(spec, string_store, extensions_table)
            predicates = _get_extra_predicates(spec, extra_predicates)
            # Every operator gets its own copies of the three lists.
            compiled.extend([
                (op, list(attr_values), list(extensions), list(predicates))
                for op in ops
            ])
    return compiled
 | |
| 
 | |
| 
 | |
def _get_attr_values(spec, string_store):
    """Collect the plain attribute/value pairs of a token spec.

    String attribute names are upper-cased, "TEXT" is treated as "ORTH", and
    the names are converted to attribute IDs. The "_" and "OP" keys are
    skipped, as are dict values. String values are added to the string store
    and replaced by their ID, bools are converted to int, and ints are used
    as they are.

    spec (dict): The token spec.
    string_store (StringStore): The string store to add string values to.
    RETURNS (list): (attr_id, value) tuples.
    RAISES (ValueError): If an attribute name is unknown or a value has an
        unsupported type.
    """
    attr_values = []
    for attr, value in spec.items():
        # Keep the name as written in the spec for error reporting.
        attr_name = attr
        if isinstance(attr, basestring):
            attr = attr.upper()
            if attr == '_':
                continue
            elif attr == "OP":
                continue
            if attr == "TEXT":
                attr = "ORTH"
            if attr not in TOKEN_PATTERN_SCHEMA["items"]["properties"]:
                raise ValueError(Errors.E152.format(attr=attr))
            attr_name = attr
            attr = IDS.get(attr)
        if isinstance(value, basestring):
            value = string_store.add(value)
        elif isinstance(value, bool):
            # Must be tested before int, since bool is a subclass of int.
            value = int(value)
        elif isinstance(value, int):
            # Plain integers (e.g. {"LENGTH": 5}) are matched as they are.
            # Skipping them here would silently drop the constraint.
            pass
        elif isinstance(value, dict):
            # Dict values are not plain attribute constraints.
            continue
        else:
            raise ValueError(Errors.E153.format(vtype=type(value).__name__))
        if attr is not None:
            attr_values.append((attr, value))
        else:
            # should be caught above using TOKEN_PATTERN_SCHEMA
            raise ValueError(Errors.E152.format(attr=attr_name))
    return attr_values
 | |
| 
 | |
| 
 | |
| # These predicate helper classes are used to match the REGEX, IN, >= etc
 | |
| # extensions to the matcher introduced in #3173.
 | |
| 
 | |
class _RegexPredicate(object):
    """Predicate for the "REGEX" operator: searches an attribute's string
    value with a compiled regular expression.
    """
    operators = ("REGEX",)

    def __init__(self, i, attr, value, predicate, is_extension=False):
        """
        i (int): Stored on the instance as `self.i`.
        attr: The token attribute ID, or the extension attribute name if
            `is_extension` is True.
        value (unicode): The regular expression.
        predicate (unicode): The operator; must be one of `operators`.
        is_extension (bool): Whether `attr` names an extension attribute.
        """
        self.i = i
        self.attr = attr
        self.is_extension = is_extension
        self.predicate = predicate
        self.value = re.compile(value)
        # The key identifies the (attribute, operator, argument) combination.
        self.key = (attr, predicate, srsly.json_dumps(value, sort_keys=True))
        if predicate not in self.operators:
            raise ValueError(Errors.E126.format(good=self.operators, bad=self.predicate))

    def __call__(self, Token token):
        if self.is_extension:
            text = token._.get(self.attr)
        else:
            # The token attribute is an ID; look up its string.
            text = token.vocab.strings[get_token_attr(token.c, self.attr)]
        return self.value.search(text) is not None
 | |
| 
 | |
| 
 | |
class _SetMemberPredicate(object):
    """Predicate that tests whether a token attribute is (not) in a set.

    The candidate values are converted with `get_string_id` on construction,
    and membership is tested against the converted set.

    i (int): Identifier stored on the predicate and shown in its repr.
    attr: Attribute ID, or the extension attribute name if `is_extension`.
    value (iterable): The candidate values.
    predicate (unicode): The operator name, "IN" or "NOT_IN".
    is_extension (bool): Whether `attr` refers to a custom `token._` attribute.
    RAISES (ValueError): E126 if `predicate` is not a supported operator.
    """
    operators = ("IN", "NOT_IN")

    def __init__(self, i, attr, value, predicate, is_extension=False):
        self.i = i
        self.attr = attr
        self.value = {get_string_id(member) for member in value}
        self.predicate = predicate
        self.is_extension = is_extension
        serialized = srsly.json_dumps(value, sort_keys=True)
        self.key = (attr, predicate, serialized)
        if predicate not in self.operators:
            raise ValueError(Errors.E126.format(good=self.operators, bad=predicate))

    def __call__(self, Token token):
        """Return the membership test result for the given token."""
        if self.is_extension:
            candidate = get_string_id(token._.get(self.attr))
        else:
            candidate = get_token_attr(token.c, self.attr)
        is_member = candidate in self.value
        return is_member if self.predicate == "IN" else not is_member

    def __repr__(self):
        summary = ("SetMemberPredicate", self.i, self.attr, self.value, self.predicate)
        return repr(summary)
 | |
| 
 | |
| 
 | |
class _ComparisonPredicate(object):
    """Predicate that compares a token attribute against a fixed value.

    i (int): Identifier stored on the predicate; `_get_extra_predicates`
        passes the predicate's position in the shared predicate list.
    attr: Attribute ID, or the extension attribute name if `is_extension`.
    value: The right-hand side of the comparison.
    predicate (unicode): One of "==", "!=", ">=", "<=", ">" or "<".
    is_extension (bool): Whether `attr` refers to a custom `token._` attribute.
    RAISES (ValueError): E126 if `predicate` is not a supported operator.
    """
    operators = ("==", "!=", ">=", "<=", ">", "<")

    def __init__(self, i, attr, value, predicate, is_extension=False):
        self.i = i
        self.attr = attr
        self.value = value
        self.predicate = predicate
        self.is_extension = is_extension
        # Hashable identity of the predicate, used to spot duplicates
        self.key = (attr, self.predicate, srsly.json_dumps(value, sort_keys=True))
        if self.predicate not in self.operators:
            raise ValueError(Errors.E126.format(good=self.operators, bad=self.predicate))

    def __call__(self, Token token):
        """Return the result of `<token attribute> <operator> <value>`."""
        if self.is_extension:
            value = token._.get(self.attr)
        else:
            value = get_token_attr(token.c, self.attr)
        if self.predicate == "==":
            return value == self.value
        elif self.predicate == "!=":
            return value != self.value
        elif self.predicate == ">=":
            return value >= self.value
        elif self.predicate == "<=":
            return value <= self.value
        elif self.predicate == ">":
            return value > self.value
        elif self.predicate == "<":
            return value < self.value

    def __repr__(self):
        return repr(("ComparisonPredicate", self.i, self.attr, self.value, self.predicate))
 | |
| 
 | |
| 
 | |
def _get_extra_predicates(spec, extra_predicates):
    """Create the predicates (REGEX, IN, ==, ...) used by one token pattern.

    New predicates are appended to `extra_predicates` in place. A predicate
    whose key matches an existing one is not added again; the existing
    predicate's index is reused.

    spec (dict): A single token pattern.
    extra_predicates (list): Predicates created so far; extended in place.
    RETURNS (list): The `i` values of the predicates this pattern refers to.
    """
    predicate_types = {
        "REGEX": _RegexPredicate,
        "IN": _SetMemberPredicate,
        "NOT_IN": _SetMemberPredicate,
        "==": _ComparisonPredicate,
        # "!=" is accepted by _ComparisonPredicate, so it has to be
        # dispatched too; otherwise the constraint is silently ignored.
        "!=": _ComparisonPredicate,
        ">=": _ComparisonPredicate,
        "<=": _ComparisonPredicate,
        ">": _ComparisonPredicate,
        "<": _ComparisonPredicate,
    }
    seen_predicates = {pred.key: pred.i for pred in extra_predicates}
    output = []
    for attr, value in spec.items():
        if isinstance(attr, basestring):
            if attr == "_":
                # Custom extension attributes are nested one level deeper
                output.extend(
                    _get_extension_extra_predicates(
                        value, extra_predicates, predicate_types,
                        seen_predicates))
                continue
            elif attr.upper() == "OP":
                continue
            if attr.upper() == "TEXT":
                attr = "ORTH"
            attr = IDS.get(attr.upper())
        if isinstance(value, dict):
            for type_, cls in predicate_types.items():
                if type_ in value:
                    predicate = cls(len(extra_predicates), attr, value[type_], type_)
                    # Don't create redundant predicates.
                    # This helps with efficiency, as we're caching the results.
                    if predicate.key in seen_predicates:
                        output.append(seen_predicates[predicate.key])
                    else:
                        extra_predicates.append(predicate)
                        output.append(predicate.i)
                        seen_predicates[predicate.key] = predicate.i
    return output
 | |
| 
 | |
| 
 | |
def _get_extension_extra_predicates(spec, extra_predicates, predicate_types,
                                    seen_predicates):
    """Create the predicates declared on custom extension attributes.

    spec (dict): The value of a pattern's "_" entry, keyed by extension name.
    extra_predicates (list): Predicates created so far; extended in place.
    predicate_types (dict): Maps operator names to predicate classes.
    seen_predicates (dict): Maps predicate keys to indices; updated in place.
    RETURNS (list): The indices of the predicates this spec refers to.
    """
    indices = []
    for ext_name, constraint in spec.items():
        # Only dict values can carry operators such as "IN" or ">="
        if not isinstance(constraint, dict):
            continue
        for op_name, predicate_cls in predicate_types.items():
            if op_name not in constraint:
                continue
            operand = constraint[op_name]
            cache_key = (ext_name, op_name, srsly.json_dumps(operand, sort_keys=True))
            if cache_key in seen_predicates:
                # An identical predicate already exists: reuse its index
                indices.append(seen_predicates[cache_key])
                continue
            created = predicate_cls(len(extra_predicates), ext_name, operand,
                                    op_name, is_extension=True)
            extra_predicates.append(created)
            indices.append(created.i)
            seen_predicates[cache_key] = created.i
    return indices
 | |
| 
 | |
| 
 | |
def _get_operators(spec):
    """Translate the "OP" entry of a token pattern into quantifier values.

    spec (dict): A single token pattern; string keys are matched
        case-insensitively.
    RETURNS (tuple): The quantifier(s) for the token; `(ONE,)` if the
        pattern has no "OP" entry.
    RAISES (ValueError): E011 if the "OP" value is not a known operator.
    """
    # Support 'syntactic sugar' operator '+', as combination of ONE, ZERO_PLUS
    lookup = {
        "*": (ZERO_PLUS,),
        "+": (ONE, ZERO_PLUS),
        "?": (ZERO_ONE,),
        "1": (ONE,),
        "!": (ZERO,),
    }
    # Fix casing; non-string keys can't be the operator and are dropped
    normalized = {}
    for key, values in spec.items():
        if isinstance(key, basestring):
            normalized[key.upper()] = values
    if "OP" not in normalized:
        return (ONE,)
    op = normalized["OP"]
    if op not in lookup:
        keys = ", ".join(lookup.keys())
        raise ValueError(Errors.E011.format(op=op, opts=keys))
    return lookup[op]
 | |
| 
 | |
| 
 | |
def _get_extensions(spec, string_store, name2index):
    """Collect the plain extension-attribute constraints of a token pattern.

    spec (dict): A single token pattern; only its "_" entry is read.
    string_store: Object whose `add` method is used to intern string values.
    name2index (dict): Maps extension names to indices. Names not seen
        before are added in place, using the current size as the new index.
    RETURNS (list): `(index, value)` tuples, one per plain constraint.
    RAISES (ValueError): E154 if the "_" entry is not a dict.
    """
    extensions = spec.get("_", {})
    if not isinstance(extensions, dict):
        raise ValueError(Errors.E154.format())
    attr_values = []
    for name, value in extensions.items():
        if isinstance(value, dict):
            # Predicates (e.g. "IN") are handled in the extra_predicates,
            # not here.
            continue
        if isinstance(value, basestring):
            value = string_store.add(value)
        index = name2index.setdefault(name, len(name2index))
        attr_values.append((index, value))
    return attr_values
 |