mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-23 12:14:27 +03:00 
			
		
		
		
	Update Matcher API
This commit is contained in:
		
							parent
							
								
									8c9b3d5ad7
								
							
						
					
					
						commit
						ce9234f593
					
				|  | @ -87,7 +87,7 @@ ctypedef TokenPatternC* TokenPatternC_ptr | |||
| ctypedef pair[int, TokenPatternC_ptr] StateC | ||||
| 
 | ||||
| 
 | ||||
| cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, attr_t label, | ||||
| cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, | ||||
|                                  object token_specs) except NULL: | ||||
|     pattern = <TokenPatternC*>mem.alloc(len(token_specs) + 1, sizeof(TokenPatternC)) | ||||
|     cdef int i | ||||
|  | @ -99,15 +99,21 @@ cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, attr_t label, | |||
|             pattern[i].attrs[j].attr = attr | ||||
|             pattern[i].attrs[j].value = value | ||||
|     i = len(token_specs) | ||||
|     pattern[i].attrs = <AttrValueC*>mem.alloc(3, sizeof(AttrValueC)) | ||||
|     pattern[i].attrs = <AttrValueC*>mem.alloc(2, sizeof(AttrValueC)) | ||||
|     pattern[i].attrs[0].attr = ID | ||||
|     pattern[i].attrs[0].value = entity_id | ||||
|     pattern[i].attrs[1].attr = ENT_TYPE | ||||
|     pattern[i].attrs[1].value = label | ||||
|     pattern[i].nr_attr = 0 | ||||
|     return pattern | ||||
| 
 | ||||
| 
 | ||||
| cdef attr_t get_pattern_key(const TokenPatternC* pattern) except 0: | ||||
|     while pattern.nr_attr != 0: | ||||
|         pattern += 1 | ||||
|     id_attr = pattern[0].attrs[0] | ||||
|     assert id_attr.attr == ID | ||||
|     return id_attr.value | ||||
| 
 | ||||
| 
 | ||||
| cdef int get_action(const TokenPatternC* pattern, const TokenC* token) nogil: | ||||
|     for attr in pattern.attrs[:pattern.nr_attr]: | ||||
|         if get_token_attr(token, attr.attr) != attr.value: | ||||
|  | @ -175,12 +181,11 @@ cdef class Matcher: | |||
|     cdef public object _callbacks | ||||
|     cdef public object _acceptors | ||||
| 
 | ||||
|     def __init__(self, vocab, patterns={}): | ||||
|     def __init__(self, vocab): | ||||
|         """Create the Matcher. | ||||
| 
 | ||||
|         vocab (Vocab): The vocabulary object, which must be shared with the | ||||
|             documents the matcher will operate on. | ||||
|         patterns (dict): Patterns to add to the matcher. | ||||
|         RETURNS (Matcher): The newly constructed object. | ||||
|         """ | ||||
|         self._patterns = {} | ||||
|  | @ -189,123 +194,105 @@ cdef class Matcher: | |||
|         self._callbacks = {} | ||||
|         self.vocab = vocab | ||||
|         self.mem = Pool() | ||||
|         for entity_key, (etype, attrs, specs) in sorted(patterns.items()): | ||||
|             self.add_entity(entity_key, attrs) | ||||
|             for spec in specs: | ||||
|                 self.add_pattern(entity_key, spec, label=etype) | ||||
| 
 | ||||
|     def __reduce__(self): | ||||
|         return (self.__class__, (self.vocab, self._patterns), None, None) | ||||
| 
 | ||||
|     property n_patterns: | ||||
|         def __get__(self): return self.patterns.size() | ||||
|     def __len__(self): | ||||
|         return len(self._patterns) | ||||
| 
 | ||||
|     def add_entity(self, entity_key, attrs=None, if_exists='raise', | ||||
|                    acceptor=None, on_match=None): | ||||
|         # TODO: replace with new Matcher.add() | ||||
|         """Add an entity to the matcher. | ||||
|     def __contains__(self, key): | ||||
|         return len(self._patterns) | ||||
| 
 | ||||
|         entity_key (unicode or int): An ID for the entity. | ||||
|         attrs (dict): Attributes to associate with the `Matcher`. | ||||
|         if_exists (unicode): `'raise'`, `'ignore'` or `'update'`. Controls what | ||||
|             happens if the entity ID already exists. Defaults to `'raise'`. | ||||
|         acceptor (function): Callback function to filter matches of the entity. | ||||
|         on_match (function): Callback function to act on matches of the entity. | ||||
|     def add(self, key, on_match, *patterns): | ||||
|         """Add a match-rule to the matcher. | ||||
| 
 | ||||
|         A match-rule consists of: an ID key, an on_match callback, | ||||
|         and one or more patterns. If the key exists, the patterns | ||||
|         are appended to the previous ones, and the previous on_match | ||||
|         callback is replaced. | ||||
| 
 | ||||
|         The on_match callback will receive the arguments | ||||
|         (matcher, doc, i, matches). Note that if no `on_match` | ||||
|         callback is specified, the document will not be modified. | ||||
| 
 | ||||
|         A pattern consists of one or more token_specs, | ||||
|         where a token_spec is a dictionary mapping | ||||
|         attribute IDs to values. Token descriptors can also | ||||
|         include quantifiers. There are currently important | ||||
|         known problems with the quantifiers --- see the docs. | ||||
|         """ | ||||
|         if if_exists not in ('raise', 'ignore', 'update'): | ||||
|             raise ValueError( | ||||
|                 "Unexpected value for if_exists: %s.\n" | ||||
|                 "Expected one of: ['raise', 'ignore', 'update']" % if_exists) | ||||
|         if attrs is None: | ||||
|             attrs = {} | ||||
|         entity_key = self.normalize_entity_key(entity_key) | ||||
|         if self.has_entity(entity_key): | ||||
|             if if_exists == 'raise': | ||||
|                 raise KeyError( | ||||
|                     "Tried to add entity %s. Entity exists, and if_exists='raise'.\n" | ||||
|                     "Set if_exists='ignore' or if_exists='update', or check with " | ||||
|                     "matcher.has_entity()") | ||||
|             elif if_exists == 'ignore': | ||||
|                 return | ||||
|         self._entities[entity_key] = dict(attrs) | ||||
|         self._patterns.setdefault(entity_key, []) | ||||
|         self._acceptors[entity_key] = acceptor | ||||
|         self._callbacks[entity_key] = on_match | ||||
|         for pattern in patterns: | ||||
|             if len(pattern) == 0: | ||||
|                 msg = ("Cannot add pattern for zero tokens to matcher.\n" | ||||
|                        "key: {key}\n") | ||||
|                 raise ValueError(msg.format(key=key)) | ||||
|         key = self._normalize_key(key) | ||||
|         self._patterns.setdefault(key, []) | ||||
|         self._callbacks[key] = on_match | ||||
| 
 | ||||
|     def add_pattern(self, entity_key, token_specs, label=""): | ||||
|         # TODO: replace with new Matcher.add() | ||||
|         """Add a pattern to the matcher. | ||||
|         for pattern in patterns: | ||||
|             specs = _convert_strings(pattern, self.vocab.strings) | ||||
|             self.patterns.push_back(init_pattern(self.mem, key, specs)) | ||||
|             self._patterns[key].append(specs) | ||||
| 
 | ||||
|         entity_key (unicode): An ID for the entity. | ||||
|         token_specs (list): Description of the pattern to be matched. | ||||
|         label (unicode): Label to assign to the matched pattern. Defaults to `""`. | ||||
|     def remove(self, key): | ||||
|         """Remove a rule from the matcher. | ||||
| 
 | ||||
|         A KeyError is raised if the key does not exist. | ||||
|         """ | ||||
|         token_specs = list(token_specs) | ||||
|         if len(token_specs) == 0: | ||||
|             msg = ("Cannot add pattern for zero tokens to matcher.\n" | ||||
|                    "entity_key: {entity_key}\n" | ||||
|                    "label: {label}") | ||||
|             raise ValueError(msg.format(entity_key=entity_key, label=label)) | ||||
|         entity_key = self.normalize_entity_key(entity_key) | ||||
|         if not self.has_entity(entity_key): | ||||
|             self.add_entity(entity_key) | ||||
|         if isinstance(label, basestring): | ||||
|             label = self.vocab.strings[label] | ||||
|         elif label is None: | ||||
|             label = 0 | ||||
|         spec = _convert_strings(token_specs, self.vocab.strings) | ||||
|         key = self._normalize_key(key) | ||||
|         self._patterns.pop(key) | ||||
|         self._callbacks.pop(key) | ||||
|         cdef int i = 0 | ||||
|         while i < self.patterns.size(): | ||||
|             pattern_key = get_pattern_key(self.patterns.at(i)) | ||||
|             if pattern_key == key: | ||||
|                 self.patterns.erase(self.patterns.begin()+i) | ||||
|             else: | ||||
|                 i += 1 | ||||
| 
 | ||||
|         self.patterns.push_back(init_pattern(self.mem, entity_key, label, spec)) | ||||
|         self._patterns[entity_key].append((label, token_specs)) | ||||
|     def has_key(self, key): | ||||
|         """Check whether the matcher has a rule with a given key. | ||||
| 
 | ||||
|     def add(self, entity_key, label, attrs, specs, acceptor=None, on_match=None): | ||||
|         # TODO: replace with new Matcher.add() | ||||
|         self.add_entity(entity_key, attrs=attrs, if_exists='update', | ||||
|                         acceptor=acceptor, on_match=on_match) | ||||
|         for spec in specs: | ||||
|             self.add_pattern(entity_key, spec, label=label) | ||||
| 
 | ||||
|     def normalize_entity_key(self, entity_key): | ||||
|         if isinstance(entity_key, basestring): | ||||
|             return self.vocab.strings[entity_key] | ||||
|         else: | ||||
|             return entity_key | ||||
| 
 | ||||
|     def has_entity(self, entity_key): | ||||
|         # TODO: deprecate | ||||
|         """Check whether the matcher has an entity. | ||||
| 
 | ||||
|         entity_key (string or int): The entity key to check. | ||||
|         RETURNS (bool): Whether the matcher has the entity. | ||||
|         key (string or int): The key to check. | ||||
|         RETURNS (bool): Whether the matcher has the rule. | ||||
|         """ | ||||
|         entity_key = self.normalize_entity_key(entity_key) | ||||
|         return entity_key in self._entities | ||||
|         key = self._normalize_key(key) | ||||
|         return key in self._patterns | ||||
| 
 | ||||
|     def get_entity(self, entity_key): | ||||
|         # TODO: deprecate | ||||
|         """Retrieve the attributes stored for an entity. | ||||
|     def get(self, key, default=None): | ||||
|         """Retrieve the pattern stored for an entity. | ||||
| 
 | ||||
|         entity_key (unicode or int): The entity to retrieve. | ||||
|         RETURNS (dict): The entity attributes if present, otherwise None. | ||||
|         key (unicode or int): The key to retrieve. | ||||
|         RETURNS (tuple): The rule, as an (on_match, patterns) tuple. | ||||
|         """ | ||||
|         entity_key = self.normalize_entity_key(entity_key) | ||||
|         if entity_key in self._entities: | ||||
|             return self._entities[entity_key] | ||||
|         else: | ||||
|             return None | ||||
|         key = self._normalize_key(key) | ||||
|         if key not in self._patterns: | ||||
|             return default | ||||
|         return (self._callbacks[key], self._patterns[key]) | ||||
| 
 | ||||
|     def __call__(self, Doc doc, acceptor=None): | ||||
|     def pipe(self, docs, batch_size=1000, n_threads=2): | ||||
|         """Match a stream of documents, yielding them in turn. | ||||
| 
 | ||||
|         docs (iterable): A stream of documents. | ||||
|         batch_size (int): The number of documents to accumulate into a working set. | ||||
|         n_threads (int): The number of threads with which to work on the buffer | ||||
|             in parallel, if the `Matcher` implementation supports multi-threading. | ||||
|         YIELDS (Doc): Documents, in order. | ||||
|         """ | ||||
|         for doc in docs: | ||||
|             self(doc) | ||||
|             yield doc | ||||
| 
 | ||||
|     def __call__(self, Doc doc): | ||||
|         """Find all token sequences matching the supplied patterns on the `Doc`. | ||||
| 
 | ||||
|         doc (Doc): The document to match over. | ||||
|         RETURNS (list): A list of `(entity_key, label_id, start, end)` tuples, | ||||
|         RETURNS (list): A list of `(key, label_id, start, end)` tuples, | ||||
|             describing the matches. A match tuple describes a span | ||||
|             `doc[start:end]`. The `label_id` and `entity_key` are both integers. | ||||
|             `doc[start:end]`. The `label_id` and `key` are both integers. | ||||
|         """ | ||||
|         if acceptor is not None: | ||||
|             raise ValueError( | ||||
|                 "acceptor keyword argument to Matcher deprecated. Specify acceptor " | ||||
|                 "functions when you add patterns instead.") | ||||
|         cdef vector[StateC] partials | ||||
|         cdef int n_partials = 0 | ||||
|         cdef int q = 0 | ||||
|  | @ -343,13 +330,7 @@ cdef class Matcher: | |||
|                     end = token_i+1 | ||||
|                     ent_id = state.second[1].attrs[0].value | ||||
|                     label = state.second[1].attrs[1].value | ||||
|                     acceptor = self._acceptors.get(ent_id) | ||||
|                     if acceptor is None: | ||||
|                         matches.append((ent_id, label, start, end)) | ||||
|                     else: | ||||
|                         match = acceptor(doc, ent_id, label, start, end) | ||||
|                         if match: | ||||
|                             matches.append(match) | ||||
|                     matches.append((ent_id, start, end)) | ||||
|             partials.resize(q) | ||||
|             # Check whether we open any new patterns on this token | ||||
|             for pattern in self.patterns: | ||||
|  | @ -374,13 +355,7 @@ cdef class Matcher: | |||
|                     end = token_i+1 | ||||
|                     ent_id = pattern[1].attrs[0].value | ||||
|                     label = pattern[1].attrs[1].value | ||||
|                     acceptor = self._acceptors.get(ent_id) | ||||
|                     if acceptor is None: | ||||
|                         matches.append((ent_id, label, start, end)) | ||||
|                     else: | ||||
|                         match = acceptor(doc, ent_id, label, start, end) | ||||
|                         if match: | ||||
|                             matches.append(match) | ||||
|                     matches.append((ent_id, start, end)) | ||||
|         # Look for open patterns that are actually satisfied | ||||
|         for state in partials: | ||||
|             while state.second.quantifier in (ZERO, ZERO_PLUS): | ||||
|  | @ -390,13 +365,7 @@ cdef class Matcher: | |||
|                     end = len(doc) | ||||
|                     ent_id = state.second.attrs[0].value | ||||
|                     label = state.second.attrs[0].value | ||||
|                     acceptor = self._acceptors.get(ent_id) | ||||
|                     if acceptor is None: | ||||
|                         matches.append((ent_id, label, start, end)) | ||||
|                     else: | ||||
|                         match = acceptor(doc, ent_id, label, start, end) | ||||
|                         if match: | ||||
|                             matches.append(match) | ||||
|                     matches.append((ent_id, start, end)) | ||||
|         for i, (ent_id, label, start, end) in enumerate(matches): | ||||
|             on_match = self._callbacks.get(ent_id) | ||||
|             if on_match is not None: | ||||
|  | @ -404,18 +373,11 @@ cdef class Matcher: | |||
|         # TODO: only return (match_id, start, end) | ||||
|         return matches | ||||
| 
 | ||||
|     def pipe(self, docs, batch_size=1000, n_threads=2): | ||||
|         """Match a stream of documents, yielding them in turn. | ||||
| 
 | ||||
|         docs (iterable): A stream of documents. | ||||
|         batch_size (int): The number of documents to accumulate into a working set. | ||||
|         n_threads (int): The number of threads with which to work on the buffer | ||||
|             in parallel, if the `Matcher` implementation supports multi-threading. | ||||
|         YIELDS (Doc): Documents, in order. | ||||
|         """ | ||||
|         for doc in docs: | ||||
|             self(doc) | ||||
|             yield doc | ||||
|     def _normalize_key(self, key): | ||||
|         if isinstance(key, basestring): | ||||
|             return self.vocab.strings[key] | ||||
|         else: | ||||
|             return key | ||||
| 
 | ||||
| 
 | ||||
| def get_bilou(length): | ||||
|  |  | |||
|  | @ -116,8 +116,9 @@ p Match a stream of documents, yielding them in turn. | |||
|     +tag method | ||||
| 
 | ||||
| p | ||||
|     |  Add one or more patterns to the matcher, along with a callback function | ||||
|     |  to handle the matches. The callback function will receive the arguments | ||||
|     |  Add a rule to the matcher, consisting of an ID key, one or more patterns, and | ||||
|     |  a callback function to act on the matches. | ||||
|     |  The callback function will receive the arguments | ||||
|     |  #[code matcher], #[code doc], #[code i] and #[code matches]. | ||||
| 
 | ||||
| +aside-code("Example"). | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user