# cython: infer_types=True, profile=True from typing import List from collections import defaultdict from libc.stdint cimport uintptr_t from preshed.maps cimport map_init, map_set, map_get, map_clear, map_iter import warnings from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA, MORPH from ..attrs import IDS from ..structs cimport TokenC from ..tokens.token cimport Token from ..tokens.span cimport Span from ..typedefs cimport attr_t from ..schemas import TokenPattern from ..errors import Errors, Warnings cdef class PhraseMatcher: """Efficiently match large terminology lists. While the `Matcher` matches sequences based on lists of token descriptions, the `PhraseMatcher` accepts match patterns in the form of `Doc` objects. DOCS: https://spacy.io/api/phrasematcher USAGE: https://spacy.io/usage/rule-based-matching#phrasematcher Adapted from FlashText: https://github.com/vi3k6i5/flashtext MIT License (see `LICENSE`) Copyright (c) 2017 Vikash Singh (vikash.duliajan@gmail.com) """ def __init__(self, Vocab vocab, attr="ORTH", validate=False): """Initialize the PhraseMatcher. vocab (Vocab): The shared vocabulary. attr (int / str): Token attribute to match on. validate (bool): Perform additional validation when patterns are added. DOCS: https://spacy.io/api/phrasematcher#init """ self.vocab = vocab self._callbacks = {} self._docs = defaultdict(set) self._validate = validate self.mem = Pool() self.c_map = self.mem.alloc(1, sizeof(MapStruct)) self._terminal_hash = 826361138722620965 map_init(self.mem, self.c_map, 8) if isinstance(attr, (int, long)): self.attr = attr else: if attr is None: attr = "ORTH" attr = attr.upper() if attr == "TEXT": attr = "ORTH" if attr == "IS_SENT_START": attr = "SENT_START" if attr.lower() not in TokenPattern().dict(): raise ValueError(Errors.E152.format(attr=attr)) self.attr = IDS.get(attr) def __len__(self): """Get the number of match IDs added to the matcher. RETURNS (int): The number of rules. DOCS: https://spacy.io/api/phrasematcher#len """ return len(self._callbacks) def __contains__(self, key): """Check whether the matcher contains rules for a match ID. key (str): The match ID. RETURNS (bool): Whether the matcher contains rules for this match ID. DOCS: https://spacy.io/api/phrasematcher#contains """ return key in self._callbacks def __reduce__(self): data = (self.vocab, self._docs, self._callbacks, self.attr) return (unpickle_matcher, data, None, None) def remove(self, key): """Remove a rule from the matcher by match ID. A KeyError is raised if the key does not exist. key (str): The match ID. DOCS: https://spacy.io/api/phrasematcher#remove """ if key not in self._docs: raise KeyError(key) cdef MapStruct* current_node cdef MapStruct* terminal_map cdef MapStruct* node_pointer cdef void* result cdef key_t terminal_key cdef void* value cdef int c_i = 0 cdef vector[MapStruct*] path_nodes cdef vector[key_t] path_keys cdef key_t key_to_remove for keyword in sorted(self._docs[key], key=lambda x: len(x), reverse=True): current_node = self.c_map path_nodes.clear() path_keys.clear() for token in keyword: result = map_get(current_node, token) if result: path_nodes.push_back(current_node) path_keys.push_back(token) current_node = result else: # if token is not found, break out of the loop current_node = NULL break path_nodes.push_back(current_node) path_keys.push_back(self._terminal_hash) # remove the tokens from trie node if there are no other # keywords with them result = map_get(current_node, self._terminal_hash) if current_node != NULL and result: terminal_map = result terminal_keys = [] c_i = 0 while map_iter(terminal_map, &c_i, &terminal_key, &value): terminal_keys.append(self.vocab.strings[terminal_key]) # if this is the only remaining key, remove unnecessary paths if terminal_keys == [key]: while not path_nodes.empty(): node_pointer = path_nodes.back() path_nodes.pop_back() key_to_remove = path_keys.back() path_keys.pop_back() result = map_get(node_pointer, key_to_remove) if node_pointer.filled == 1: map_clear(node_pointer, key_to_remove) self.mem.free(result) else: # more than one key means more than 1 path, # delete not required path and keep the others map_clear(node_pointer, key_to_remove) self.mem.free(result) break # otherwise simply remove the key else: result = map_get(current_node, self._terminal_hash) if result: map_clear(result, self.vocab.strings[key]) del self._callbacks[key] del self._docs[key] def _add_from_arrays(self, key, specs, *, on_match=None): """Add a preprocessed list of specs, with an optional callback. key (str): The match ID. specs (List[List[int]]): A list of lists of hashes to match. on_match (callable): Callback executed on match. """ cdef MapStruct* current_node cdef MapStruct* internal_node cdef void* result self._callbacks[key] = on_match for spec in specs: self._docs[key].add(tuple(spec)) current_node = self.c_map for token in spec: if token == self._terminal_hash: warnings.warn(Warnings.W021) break result = map_get(current_node, token) if not result: internal_node = self.mem.alloc(1, sizeof(MapStruct)) map_init(self.mem, internal_node, 8) map_set(self.mem, current_node, token, internal_node) result = internal_node current_node = result result = map_get(current_node, self._terminal_hash) if not result: internal_node = self.mem.alloc(1, sizeof(MapStruct)) map_init(self.mem, internal_node, 8) map_set(self.mem, current_node, self._terminal_hash, internal_node) result = internal_node map_set(self.mem, result, self.vocab.strings[key], NULL) def add(self, key, docs, *, on_match=None): """Add a match-rule to the phrase-matcher. A match-rule consists of: an ID key, a list of one or more patterns, and (optionally) an on_match callback. key (str): The match ID. docs (list): List of `Doc` objects representing match patterns. on_match (callable): Callback executed on match. If any of the input Docs are invalid, no internal state will be updated. DOCS: https://spacy.io/api/phrasematcher#add """ if isinstance(docs, Doc): raise ValueError(Errors.E179.format(key=key)) if docs is None or not isinstance(docs, List): raise ValueError(Errors.E948.format(name="PhraseMatcher", arg_type=type(docs))) if on_match is not None and not hasattr(on_match, "__call__"): raise ValueError(Errors.E171.format(name="PhraseMatcher", arg_type=type(on_match))) _ = self.vocab[key] specs = [] for doc in docs: if len(doc) == 0: continue if not isinstance(doc, Doc): raise ValueError(Errors.E4000.format(type=type(doc))) attrs = (TAG, POS, MORPH, LEMMA, DEP) has_annotation = {attr: doc.has_annotation(attr) for attr in attrs} for attr in attrs: if self.attr == attr and not has_annotation[attr]: if attr == TAG: pipe = "tagger" elif attr in (POS, MORPH): pipe = "morphologizer or tagger+attribute_ruler" elif attr == LEMMA: pipe = "lemmatizer" elif attr == DEP: pipe = "parser" error_msg = Errors.E155.format(pipe=pipe, attr=self.vocab.strings.as_string(attr)) raise ValueError(error_msg) if self._validate and any(has_annotation.values()) \ and self.attr not in attrs: string_attr = self.vocab.strings[self.attr] warnings.warn(Warnings.W012.format(key=key, attr=string_attr)) specs.append(self._convert_to_array(doc)) self._add_from_arrays(key, specs, on_match=on_match) def __call__(self, object doclike, *, as_spans=False): """Find all sequences matching the supplied patterns on the `Doc`. doclike (Doc or Span): The document to match over. as_spans (bool): Return Span objects with labels instead of (match_id, start, end) tuples. RETURNS (list): A list of `(match_id, start, end)` tuples, describing the matches. A match tuple describes a span `doc[start:end]`. The `match_id` is an integer. If as_spans is set to True, a list of Span objects is returned. DOCS: https://spacy.io/api/phrasematcher#call """ matches = [] if doclike is None or len(doclike) == 0: # if doc is empty or None just return empty list return matches if isinstance(doclike, Doc): doc = doclike start_idx = 0 end_idx = len(doc) elif isinstance(doclike, Span): doc = doclike.doc start_idx = doclike.start end_idx = doclike.end else: raise ValueError(Errors.E195.format(good="Doc or Span", got=type(doclike).__name__)) cdef vector[SpanC] c_matches self.find_matches(doc, start_idx, end_idx, &c_matches) for i in range(c_matches.size()): matches.append((c_matches[i].label, c_matches[i].start, c_matches[i].end)) for i, (ent_id, start, end) in enumerate(matches): on_match = self._callbacks.get(self.vocab.strings[ent_id]) if on_match is not None: on_match(self, doc, i, matches) if as_spans: return [Span(doc, start, end, label=key) for key, start, end in matches] else: return matches cdef void find_matches(self, Doc doc, int start_idx, int end_idx, vector[SpanC] *matches) nogil: cdef MapStruct* current_node = self.c_map cdef int start = 0 cdef int idx = start_idx cdef int idy = start_idx cdef key_t key cdef void* value cdef int i = 0 cdef SpanC ms cdef void* result while idx < end_idx: start = idx token = Token.get_struct_attr(&doc.c[idx], self.attr) # look for sequences from this position result = map_get(current_node, token) if result: current_node = result idy = idx + 1 while idy < end_idx: result = map_get(current_node, self._terminal_hash) if result: i = 0 while map_iter(result, &i, &key, &value): ms = make_spanstruct(key, start, idy) matches.push_back(ms) inner_token = Token.get_struct_attr(&doc.c[idy], self.attr) result = map_get(current_node, inner_token) if result: current_node = result idy += 1 else: break else: # end of doc reached result = map_get(current_node, self._terminal_hash) if result: i = 0 while map_iter(result, &i, &key, &value): ms = make_spanstruct(key, start, idy) matches.push_back(ms) current_node = self.c_map idx += 1 def pipe(self, stream, batch_size=1000, return_matches=False, as_tuples=False): """Match a stream of documents, yielding them in turn. Deprecated as of spaCy v3.0. """ warnings.warn(Warnings.W105.format(matcher="PhraseMatcher"), DeprecationWarning) if as_tuples: for doc, context in stream: matches = self(doc) if return_matches: yield ((doc, matches), context) else: yield (doc, context) else: for doc in stream: matches = self(doc) if return_matches: yield (doc, matches) else: yield doc def _convert_to_array(self, Doc doc): return [Token.get_struct_attr(&doc.c[i], self.attr) for i in range(len(doc))] def unpickle_matcher(vocab, docs, callbacks, attr): matcher = PhraseMatcher(vocab, attr=attr) for key, specs in docs.items(): callback = callbacks.get(key, None) matcher._add_from_arrays(key, specs, on_match=callback) return matcher cdef SpanC make_spanstruct(attr_t label, int start, int end) nogil: cdef SpanC spanc spanc.label = label spanc.start = start spanc.end = end return spanc