mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-26 05:31:15 +03:00 
			
		
		
		
	* allow phrasematcher to link one match to multiple original patterns * small fix for defining ent_id in the matcher (anti-ghost prevention) * cleanup * formatting
		
			
				
	
	
		
			281 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Cython
		
	
	
	
	
	
			
		
		
	
	
			281 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Cython
		
	
	
	
	
	
| # cython: infer_types=True
 | |
| # cython: profile=True
 | |
| from __future__ import unicode_literals
 | |
| 
 | |
| from libcpp.vector cimport vector
 | |
| from cymem.cymem cimport Pool
 | |
| from murmurhash.mrmr cimport hash64
 | |
| from preshed.maps cimport PreshMap
 | |
| 
 | |
| from .matcher cimport Matcher
 | |
| from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA, attr_id_t
 | |
| from ..vocab cimport Vocab
 | |
| from ..tokens.doc cimport Doc, get_token_attr
 | |
| from ..typedefs cimport attr_t, hash_t
 | |
| 
 | |
| from ._schemas import TOKEN_PATTERN_SCHEMA
 | |
| from ..errors import Errors, Warnings, deprecation_warning, user_warning
 | |
| from ..attrs import FLAG61 as U_ENT
 | |
| from ..attrs import FLAG60 as B2_ENT
 | |
| from ..attrs import FLAG59 as B3_ENT
 | |
| from ..attrs import FLAG58 as B4_ENT
 | |
| from ..attrs import FLAG43 as L2_ENT
 | |
| from ..attrs import FLAG42 as L3_ENT
 | |
| from ..attrs import FLAG41 as L4_ENT
 | |
| from ..attrs import FLAG42 as I3_ENT
 | |
| from ..attrs import FLAG41 as I4_ENT
 | |
| 
 | |
| 
 | |
| cdef class PhraseMatcher:
 | |
|     """Efficiently match large terminology lists. While the `Matcher` matches
 | |
|     sequences based on lists of token descriptions, the `PhraseMatcher` accepts
 | |
|     match patterns in the form of `Doc` objects.
 | |
| 
 | |
|     DOCS: https://spacy.io/api/phrasematcher
 | |
|     USAGE: https://spacy.io/usage/rule-based-matching#phrasematcher
 | |
|     """
 | |
|     cdef Pool mem
 | |
|     cdef Vocab vocab
 | |
|     cdef Matcher matcher
 | |
|     cdef PreshMap phrase_ids
 | |
|     cdef vector[hash_vec] ent_id_matrix
 | |
|     cdef int max_length
 | |
|     cdef attr_id_t attr
 | |
|     cdef public object _callbacks
 | |
|     cdef public object _patterns
 | |
|     cdef public object _docs
 | |
|     cdef public object _validate
 | |
| 
 | |
|     def __init__(self, Vocab vocab, max_length=0, attr="ORTH", validate=False):
 | |
|         """Initialize the PhraseMatcher.
 | |
| 
 | |
|         vocab (Vocab): The shared vocabulary.
 | |
|         attr (int / unicode): Token attribute to match on.
 | |
|         validate (bool): Perform additional validation when patterns are added.
 | |
|         RETURNS (PhraseMatcher): The newly constructed object.
 | |
| 
 | |
|         DOCS: https://spacy.io/api/phrasematcher#init
 | |
|         """
 | |
|         if max_length != 0:
 | |
|             deprecation_warning(Warnings.W010)
 | |
|         self.mem = Pool()
 | |
|         self.max_length = max_length
 | |
|         self.vocab = vocab
 | |
|         self.matcher = Matcher(self.vocab, validate=False)
 | |
|         if isinstance(attr, long):
 | |
|             self.attr = attr
 | |
|         else:
 | |
|             attr = attr.upper()
 | |
|             if attr == "TEXT":
 | |
|                 attr = "ORTH"
 | |
|             if attr not in TOKEN_PATTERN_SCHEMA["items"]["properties"]:
 | |
|                 raise ValueError(Errors.E152.format(attr=attr))
 | |
|             self.attr = self.vocab.strings[attr]
 | |
|         self.phrase_ids = PreshMap()
 | |
|         abstract_patterns = [
 | |
|             [{U_ENT: True}],
 | |
|             [{B2_ENT: True}, {L2_ENT: True}],
 | |
|             [{B3_ENT: True}, {I3_ENT: True}, {L3_ENT: True}],
 | |
|             [{B4_ENT: True}, {I4_ENT: True}, {I4_ENT: True, "OP": "+"}, {L4_ENT: True}],
 | |
|         ]
 | |
|         self.matcher.add("Candidate", None, *abstract_patterns)
 | |
|         self._callbacks = {}
 | |
|         self._docs = {}
 | |
|         self._validate = validate
 | |
| 
 | |
|     def __len__(self):
 | |
|         """Get the number of rules added to the matcher. Note that this only
 | |
|         returns the number of rules (identical with the number of IDs), not the
 | |
|         number of individual patterns.
 | |
| 
 | |
|         RETURNS (int): The number of rules.
 | |
| 
 | |
|         DOCS: https://spacy.io/api/phrasematcher#len
 | |
|         """
 | |
|         return len(self._docs)
 | |
| 
 | |
|     def __contains__(self, key):
 | |
|         """Check whether the matcher contains rules for a match ID.
 | |
| 
 | |
|         key (unicode): The match ID.
 | |
|         RETURNS (bool): Whether the matcher contains rules for this match ID.
 | |
| 
 | |
|         DOCS: https://spacy.io/api/phrasematcher#contains
 | |
|         """
 | |
|         cdef hash_t ent_id = self.matcher._normalize_key(key)
 | |
|         return ent_id in self._callbacks
 | |
| 
 | |
|     def __reduce__(self):
 | |
|         data = (self.vocab, self._docs, self._callbacks)
 | |
|         return (unpickle_matcher, data, None, None)
 | |
| 
 | |
|     def add(self, key, on_match, *docs):
 | |
|         """Add a match-rule to the phrase-matcher. A match-rule consists of: an ID
 | |
|         key, an on_match callback, and one or more patterns.
 | |
| 
 | |
|         key (unicode): The match ID.
 | |
|         on_match (callable): Callback executed on match.
 | |
|         *docs (Doc): `Doc` objects representing match patterns.
 | |
| 
 | |
|         DOCS: https://spacy.io/api/phrasematcher#add
 | |
|         """
 | |
|         cdef Doc doc
 | |
|         cdef hash_t ent_id = self.matcher._normalize_key(key)
 | |
|         self._callbacks[ent_id] = on_match
 | |
|         self._docs[ent_id] = docs
 | |
|         cdef int length
 | |
|         cdef int i
 | |
|         cdef hash_t phrase_hash
 | |
|         cdef Pool mem = Pool()
 | |
|         for doc in docs:
 | |
|             length = doc.length
 | |
|             if length == 0:
 | |
|                 continue
 | |
|             if self.attr in (POS, TAG, LEMMA) and not doc.is_tagged:
 | |
|                 raise ValueError(Errors.E155.format())
 | |
|             if self.attr == DEP and not doc.is_parsed:
 | |
|                 raise ValueError(Errors.E156.format())
 | |
|             if self._validate and (doc.is_tagged or doc.is_parsed) \
 | |
|               and self.attr not in (DEP, POS, TAG, LEMMA):
 | |
|                 string_attr = self.vocab.strings[self.attr]
 | |
|                 user_warning(Warnings.W012.format(key=key, attr=string_attr))
 | |
|             tags = get_biluo(length)
 | |
|             phrase_key = <attr_t*>mem.alloc(length, sizeof(attr_t))
 | |
|             for i, tag in enumerate(tags):
 | |
|                 attr_value = self.get_lex_value(doc, i)
 | |
|                 lexeme = self.vocab[attr_value]
 | |
|                 lexeme.set_flag(tag, True)
 | |
|                 phrase_key[i] = lexeme.orth
 | |
|             phrase_hash = hash64(phrase_key, length * sizeof(attr_t), 0)
 | |
| 
 | |
|             if phrase_hash in self.phrase_ids:
 | |
|                 phrase_index = self.phrase_ids[phrase_hash]
 | |
|                 ent_id_list = self.ent_id_matrix[phrase_index]
 | |
|                 ent_id_list.append(ent_id)
 | |
|                 self.ent_id_matrix[phrase_index] = ent_id_list
 | |
| 
 | |
|             else:
 | |
|                 ent_id_list = hash_vec(1)
 | |
|                 ent_id_list[0] = ent_id
 | |
|                 new_index = self.ent_id_matrix.size()
 | |
|                 if new_index == 0:
 | |
|                     # PreshMaps can not contain 0 as value, so storing a dummy at 0
 | |
|                     self.ent_id_matrix.push_back(hash_vec(0))
 | |
|                     new_index = 1
 | |
|                 self.ent_id_matrix.push_back(ent_id_list)
 | |
|                 self.phrase_ids.set(phrase_hash,  <void*>new_index)
 | |
| 
 | |
|     def __call__(self, Doc doc):
 | |
|         """Find all sequences matching the supplied patterns on the `Doc`.
 | |
| 
 | |
|         doc (Doc): The document to match over.
 | |
|         RETURNS (list): A list of `(key, start, end)` tuples,
 | |
|             describing the matches. A match tuple describes a span
 | |
|             `doc[start:end]`. The `label_id` and `key` are both integers.
 | |
| 
 | |
|         DOCS: https://spacy.io/api/phrasematcher#call
 | |
|         """
 | |
|         matches = []
 | |
|         if self.attr == ORTH:
 | |
|             match_doc = doc
 | |
|         else:
 | |
|             # If we're not matching on the ORTH, match_doc will be a Doc whose
 | |
|             # token.orth values are the attribute values we're matching on,
 | |
|             # e.g. Doc(nlp.vocab, words=[token.pos_ for token in doc])
 | |
|             words = [self.get_lex_value(doc, i) for i in range(len(doc))]
 | |
|             match_doc = Doc(self.vocab, words=words)
 | |
|         for _, start, end in self.matcher(match_doc):
 | |
|             ent_ids = self.accept_match(match_doc, start, end)
 | |
|             if ent_ids is not None:
 | |
|                 for ent_id in ent_ids:
 | |
|                     matches.append((ent_id, start, end))
 | |
|         for i, (ent_id, start, end) in enumerate(matches):
 | |
|             on_match = self._callbacks.get(ent_id)
 | |
|             if on_match is not None:
 | |
|                 on_match(self, doc, i, matches)
 | |
|         return matches
 | |
| 
 | |
|     def pipe(self, stream, batch_size=1000, n_threads=-1, return_matches=False,
 | |
|              as_tuples=False):
 | |
|         """Match a stream of documents, yielding them in turn.
 | |
| 
 | |
|         docs (iterable): A stream of documents.
 | |
|         batch_size (int): Number of documents to accumulate into a working set.
 | |
|         return_matches (bool): Yield the match lists along with the docs, making
 | |
|             results (doc, matches) tuples.
 | |
|         as_tuples (bool): Interpret the input stream as (doc, context) tuples,
 | |
|             and yield (result, context) tuples out.
 | |
|             If both return_matches and as_tuples are True, the output will
 | |
|             be a sequence of ((doc, matches), context) tuples.
 | |
|         YIELDS (Doc): Documents, in order.
 | |
| 
 | |
|         DOCS: https://spacy.io/api/phrasematcher#pipe
 | |
|         """
 | |
|         if n_threads != -1:
 | |
|             deprecation_warning(Warnings.W016)
 | |
|         if as_tuples:
 | |
|             for doc, context in stream:
 | |
|                 matches = self(doc)
 | |
|                 if return_matches:
 | |
|                     yield ((doc, matches), context)
 | |
|                 else:
 | |
|                     yield (doc, context)
 | |
|         else:
 | |
|             for doc in stream:
 | |
|                 matches = self(doc)
 | |
|                 if return_matches:
 | |
|                     yield (doc, matches)
 | |
|                 else:
 | |
|                     yield doc
 | |
| 
 | |
|     def accept_match(self, Doc doc, int start, int end):
 | |
|         cdef int i, j
 | |
|         cdef Pool mem = Pool()
 | |
|         phrase_key = <attr_t*>mem.alloc(end-start, sizeof(attr_t))
 | |
|         for i, j in enumerate(range(start, end)):
 | |
|             phrase_key[i] = doc.c[j].lex.orth
 | |
|         cdef hash_t key = hash64(phrase_key, (end-start) * sizeof(attr_t), 0)
 | |
| 
 | |
|         ent_index = <hash_t>self.phrase_ids.get(key)
 | |
|         if ent_index == 0:
 | |
|             return None
 | |
|         return self.ent_id_matrix[ent_index]
 | |
| 
 | |
|     def get_lex_value(self, Doc doc, int i):
 | |
|         if self.attr == ORTH:
 | |
|             # Return the regular orth value of the lexeme
 | |
|             return doc.c[i].lex.orth
 | |
|         # Get the attribute value instead, e.g. token.pos
 | |
|         attr_value = get_token_attr(&doc.c[i], self.attr)
 | |
|         if attr_value in (0, 1):
 | |
|             # Value is boolean, convert to string
 | |
|             string_attr_value = str(attr_value)
 | |
|         else:
 | |
|             string_attr_value = self.vocab.strings[attr_value]
 | |
|         string_attr_name = self.vocab.strings[self.attr]
 | |
|         # Concatenate the attr name and value to not pollute lexeme space
 | |
|         # e.g. 'POS-VERB' instead of just 'VERB', which could otherwise
 | |
|         # create false positive matches
 | |
|         return "matcher:{}-{}".format(string_attr_name, string_attr_value)
 | |
| 
 | |
| 
 | |
| def get_biluo(length):
 | |
|     if length == 0:
 | |
|         raise ValueError(Errors.E127)
 | |
|     elif length == 1:
 | |
|         return [U_ENT]
 | |
|     elif length == 2:
 | |
|         return [B2_ENT, L2_ENT]
 | |
|     elif length == 3:
 | |
|         return [B3_ENT, I3_ENT, L3_ENT]
 | |
|     else:
 | |
|         return [B4_ENT, I4_ENT] + [I4_ENT] * (length-3) + [L4_ENT]
 | |
| 
 | |
| 
 | |
| def unpickle_matcher(vocab, docs, callbacks):
 | |
|     matcher = PhraseMatcher(vocab)
 | |
|     for key, specs in docs.items():
 | |
|         callback = callbacks.get(key, None)
 | |
|         matcher.add(key, callback, *specs)
 | |
|     return matcher
 |