# cython: infer_types=True, profile=True from collections import defaultdict from typing import List from preshed.maps cimport map_clear, map_get, map_init, map_iter, map_set import warnings from ..attrs cimport DEP, LEMMA, MORPH, POS, TAG from ..attrs import IDS from ..tokens.span cimport Span from ..tokens.token cimport Token from ..typedefs cimport attr_t from ..errors import Errors, Warnings from ..schemas import TokenPattern cdef class PhraseMatcher: """Efficiently match large terminology lists. While the `Matcher` matches sequences based on lists of token descriptions, the `PhraseMatcher` accepts match patterns in the form of `Doc` objects. DOCS: https://spacy.io/api/phrasematcher USAGE: https://spacy.io/usage/rule-based-matching#phrasematcher Adapted from FlashText: https://github.com/vi3k6i5/flashtext MIT License (see `LICENSE`) Copyright (c) 2017 Vikash Singh (vikash.duliajan@gmail.com) """ def __init__(self, Vocab vocab, attr="ORTH", validate=False): """Initialize the PhraseMatcher. vocab (Vocab): The shared vocabulary. attr (int / str): Token attribute to match on. validate (bool): Perform additional validation when patterns are added. DOCS: https://spacy.io/api/phrasematcher#init """ self.vocab = vocab self._callbacks = {} self._docs = defaultdict(set) self._validate = validate self.mem = Pool() self.c_map = self.mem.alloc(1, sizeof(MapStruct)) self._terminal_hash = 826361138722620965 map_init(self.mem, self.c_map, 8) if isinstance(attr, (int, long)): self.attr = attr else: if attr is None: attr = "ORTH" attr = attr.upper() if attr == "TEXT": attr = "ORTH" if attr == "IS_SENT_START": attr = "SENT_START" if attr.lower() not in TokenPattern().dict(): raise ValueError(Errors.E152.format(attr=attr)) self.attr = IDS.get(attr) def __len__(self): """Get the number of match IDs added to the matcher. RETURNS (int): The number of rules. DOCS: https://spacy.io/api/phrasematcher#len """ return len(self._callbacks) def __contains__(self, key): """Check whether the matcher contains rules for a match ID. key (str): The match ID. RETURNS (bool): Whether the matcher contains rules for this match ID. DOCS: https://spacy.io/api/phrasematcher#contains """ return key in self._callbacks def __reduce__(self): data = (self.vocab, self._docs, self._callbacks, self.attr) return (unpickle_matcher, data, None, None) def remove(self, key): """Remove a rule from the matcher by match ID. A KeyError is raised if the key does not exist. key (str): The match ID. DOCS: https://spacy.io/api/phrasematcher#remove """ if key not in self._docs: raise KeyError(key) cdef MapStruct* current_node cdef MapStruct* terminal_map cdef MapStruct* node_pointer cdef void* result cdef key_t terminal_key cdef void* value cdef int c_i = 0 cdef vector[MapStruct*] path_nodes cdef vector[key_t] path_keys cdef key_t key_to_remove for keyword in sorted(self._docs[key], key=lambda x: len(x), reverse=True): current_node = self.c_map path_nodes.clear() path_keys.clear() for token in keyword: result = map_get(current_node, token) if result: path_nodes.push_back(current_node) path_keys.push_back(token) current_node = result else: # if token is not found, break out of the loop current_node = NULL break path_nodes.push_back(current_node) path_keys.push_back(self._terminal_hash) # remove the tokens from trie node if there are no other # keywords with them result = map_get(current_node, self._terminal_hash) if current_node != NULL and result: terminal_map = result terminal_keys = [] c_i = 0 while map_iter(terminal_map, &c_i, &terminal_key, &value): terminal_keys.append(self.vocab.strings[terminal_key]) # if this is the only remaining key, remove unnecessary paths if terminal_keys == [key]: while not path_nodes.empty(): node_pointer = path_nodes.back() path_nodes.pop_back() key_to_remove = path_keys.back() path_keys.pop_back() result = map_get(node_pointer, key_to_remove) if node_pointer.filled == 1: map_clear(node_pointer, key_to_remove) self.mem.free(result) else: # more than one key means more than 1 path, # delete not required path and keep the others map_clear(node_pointer, key_to_remove) self.mem.free(result) break # otherwise simply remove the key else: result = map_get(current_node, self._terminal_hash) if result: map_clear(result, self.vocab.strings[key]) del self._callbacks[key] del self._docs[key] def _add_from_arrays(self, key, specs, *, on_match=None): """Add a preprocessed list of specs, with an optional callback. key (str): The match ID. specs (List[List[int]]): A list of lists of hashes to match. on_match (callable): Callback executed on match. """ cdef MapStruct* current_node cdef MapStruct* internal_node cdef void* result self._callbacks[key] = on_match for spec in specs: self._docs[key].add(tuple(spec)) current_node = self.c_map for token in spec: if token == self._terminal_hash: warnings.warn(Warnings.W021) break result = map_get(current_node, token) if not result: internal_node = self.mem.alloc(1, sizeof(MapStruct)) map_init(self.mem, internal_node, 8) map_set(self.mem, current_node, token, internal_node) result = internal_node current_node = result result = map_get(current_node, self._terminal_hash) if not result: internal_node = self.mem.alloc(1, sizeof(MapStruct)) map_init(self.mem, internal_node, 8) map_set(self.mem, current_node, self._terminal_hash, internal_node) result = internal_node map_set(self.mem, result, self.vocab.strings[key], NULL) def add(self, key, docs, *, on_match=None): """Add a match-rule to the phrase-matcher. A match-rule consists of: an ID key, a list of one or more patterns, and (optionally) an on_match callback. key (str): The match ID. docs (list): List of `Doc` objects representing match patterns. on_match (callable): Callback executed on match. If any of the input Docs are invalid, no internal state will be updated. DOCS: https://spacy.io/api/phrasematcher#add """ if isinstance(docs, Doc): raise ValueError(Errors.E179.format(key=key)) if docs is None or not isinstance(docs, List): raise ValueError(Errors.E948.format(name="PhraseMatcher", arg_type=type(docs))) if on_match is not None and not hasattr(on_match, "__call__"): raise ValueError(Errors.E171.format(name="PhraseMatcher", arg_type=type(on_match))) _ = self.vocab[key] specs = [] for doc in docs: if len(doc) == 0: continue if not isinstance(doc, Doc): raise ValueError(Errors.E4000.format(type=type(doc))) attrs = (TAG, POS, MORPH, LEMMA, DEP) has_annotation = {attr: doc.has_annotation(attr) for attr in attrs} for attr in attrs: if self.attr == attr and not has_annotation[attr]: if attr == TAG: pipe = "tagger" elif attr in (POS, MORPH): pipe = "morphologizer or tagger+attribute_ruler" elif attr == LEMMA: pipe = "lemmatizer" elif attr == DEP: pipe = "parser" error_msg = Errors.E155.format(pipe=pipe, attr=self.vocab.strings.as_string(attr)) raise ValueError(error_msg) if self._validate and any(has_annotation.values()) \ and self.attr not in attrs: string_attr = self.vocab.strings[self.attr] warnings.warn(Warnings.W012.format(key=key, attr=string_attr)) specs.append(self._convert_to_array(doc)) self._add_from_arrays(key, specs, on_match=on_match) def __call__(self, object doclike, *, as_spans=False): """Find all sequences matching the supplied patterns on the `Doc`. doclike (Doc or Span): The document to match over. as_spans (bool): Return Span objects with labels instead of (match_id, start, end) tuples. RETURNS (list): A list of `(match_id, start, end)` tuples, describing the matches. A match tuple describes a span `doc[start:end]`. The `match_id` is an integer. If as_spans is set to True, a list of Span objects is returned. DOCS: https://spacy.io/api/phrasematcher#call """ matches = [] if doclike is None or len(doclike) == 0: # if doc is empty or None just return empty list return matches if isinstance(doclike, Doc): doc = doclike start_idx = 0 end_idx = len(doc) elif isinstance(doclike, Span): doc = doclike.doc start_idx = doclike.start end_idx = doclike.end else: raise ValueError(Errors.E195.format(good="Doc or Span", got=type(doclike).__name__)) cdef vector[SpanC] c_matches self.find_matches(doc, start_idx, end_idx, &c_matches) for i in range(c_matches.size()): matches.append((c_matches[i].label, c_matches[i].start, c_matches[i].end)) for i, (ent_id, start, end) in enumerate(matches): on_match = self._callbacks.get(self.vocab.strings[ent_id]) if on_match is not None: on_match(self, doc, i, matches) if as_spans: return [Span(doc, start, end, label=key) for key, start, end in matches] else: return matches cdef void find_matches(self, Doc doc, int start_idx, int end_idx, vector[SpanC] *matches) nogil: cdef MapStruct* current_node = self.c_map cdef int start = 0 cdef int idx = start_idx cdef int idy = start_idx cdef key_t key cdef void* value cdef int i = 0 cdef SpanC ms cdef void* result while idx < end_idx: start = idx token = Token.get_struct_attr(&doc.c[idx], self.attr) # look for sequences from this position result = map_get(current_node, token) if result: current_node = result idy = idx + 1 while idy < end_idx: result = map_get(current_node, self._terminal_hash) if result: i = 0 while map_iter(result, &i, &key, &value): ms = make_spanstruct(key, start, idy) matches.push_back(ms) inner_token = Token.get_struct_attr(&doc.c[idy], self.attr) result = map_get(current_node, inner_token) if result: current_node = result idy += 1 else: break else: # end of doc reached result = map_get(current_node, self._terminal_hash) if result: i = 0 while map_iter(result, &i, &key, &value): ms = make_spanstruct(key, start, idy) matches.push_back(ms) current_node = self.c_map idx += 1 def pipe(self, stream, batch_size=1000, return_matches=False, as_tuples=False): """Match a stream of documents, yielding them in turn. Deprecated as of spaCy v3.0. """ warnings.warn(Warnings.W105.format(matcher="PhraseMatcher"), DeprecationWarning) if as_tuples: for doc, context in stream: matches = self(doc) if return_matches: yield ((doc, matches), context) else: yield (doc, context) else: for doc in stream: matches = self(doc) if return_matches: yield (doc, matches) else: yield doc def _convert_to_array(self, Doc doc): return [Token.get_struct_attr(&doc.c[i], self.attr) for i in range(len(doc))] def unpickle_matcher(vocab, docs, callbacks, attr): matcher = PhraseMatcher(vocab, attr=attr) for key, specs in docs.items(): callback = callbacks.get(key, None) matcher._add_from_arrays(key, specs, on_match=callback) return matcher cdef SpanC make_spanstruct(attr_t label, int start, int end) nogil: cdef SpanC spanc spanc.label = label spanc.start = start spanc.end = end return spanc