diff --git a/spacy/matcher.pyx b/spacy/matcher.pyx
index c5bf70ce2..bdd9fce29 100644
--- a/spacy/matcher.pyx
+++ b/spacy/matcher.pyx
@@ -87,7 +87,7 @@ ctypedef TokenPatternC* TokenPatternC_ptr
 ctypedef pair[int, TokenPatternC_ptr] StateC
 
 
-cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, attr_t label,
+cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id,
                                  object token_specs) except NULL:
     pattern = mem.alloc(len(token_specs) + 1, sizeof(TokenPatternC))
     cdef int i
@@ -99,15 +99,21 @@ cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, attr_t label,
             pattern[i].attrs[j].attr = attr
             pattern[i].attrs[j].value = value
     i = len(token_specs)
-    pattern[i].attrs = mem.alloc(3, sizeof(AttrValueC))
+    pattern[i].attrs = mem.alloc(2, sizeof(AttrValueC))
     pattern[i].attrs[0].attr = ID
     pattern[i].attrs[0].value = entity_id
-    pattern[i].attrs[1].attr = ENT_TYPE
-    pattern[i].attrs[1].value = label
     pattern[i].nr_attr = 0
     return pattern
 
 
+cdef attr_t get_pattern_key(const TokenPatternC* pattern) except 0:
+    while pattern.nr_attr != 0:
+        pattern += 1
+    id_attr = pattern[0].attrs[0]
+    assert id_attr.attr == ID
+    return id_attr.value
+
+
 cdef int get_action(const TokenPatternC* pattern, const TokenC* token) nogil:
     for attr in pattern.attrs[:pattern.nr_attr]:
         if get_token_attr(token, attr.attr) != attr.value:
@@ -175,12 +181,11 @@ cdef class Matcher:
     cdef public object _callbacks
     cdef public object _acceptors
 
-    def __init__(self, vocab, patterns={}):
+    def __init__(self, vocab):
         """Create the Matcher.
 
         vocab (Vocab): The vocabulary object, which must be shared with the
             documents the matcher will operate on.
-        patterns (dict): Patterns to add to the matcher.
         RETURNS (Matcher): The newly constructed object.
         """
         self._patterns = {}
@@ -189,123 +194,105 @@ cdef class Matcher:
         self._callbacks = {}
         self.vocab = vocab
         self.mem = Pool()
-        for entity_key, (etype, attrs, specs) in sorted(patterns.items()):
-            self.add_entity(entity_key, attrs)
-            for spec in specs:
-                self.add_pattern(entity_key, spec, label=etype)
 
     def __reduce__(self):
         return (self.__class__, (self.vocab, self._patterns), None, None)
 
-    property n_patterns:
-        def __get__(self): return self.patterns.size()
+    def __len__(self):
+        return len(self._patterns)
 
-    def add_entity(self, entity_key, attrs=None, if_exists='raise',
-                   acceptor=None, on_match=None):
-        # TODO: replace with new Matcher.add()
-        """Add an entity to the matcher.
+    def __contains__(self, key):
+        return self._normalize_key(key) in self._patterns
 
-        entity_key (unicode or int): An ID for the entity.
-        attrs (dict): Attributes to associate with the `Matcher`.
-        if_exists (unicode): `'raise'`, `'ignore'` or `'update'`. Controls what
-            happens if the entity ID already exists. Defaults to `'raise'`.
-        acceptor (function): Callback function to filter matches of the entity.
-        on_match (function): Callback function to act on matches of the entity.
+    def add(self, key, on_match, *patterns):
+        """Add a match-rule to the matcher.
+
+        A match-rule consists of: an ID key, an on_match callback,
+        and one or more patterns. If the key exists, the patterns
+        are appended to the previous ones, and the previous on_match
+        callback is replaced.
+
+        The on_match callback will receive the arguments
+        (matcher, doc, i, matches). Note that if no `on_match`
+        callback is specified, the document will not be modified.
+
+        A pattern consists of one or more token_specs,
+        where a token_spec is a dictionary mapping
+        attribute IDs to values. Token descriptors can also
+        include quantifiers.
+        There are currently important known problems with the quantifiers
+        --- see the docs.
         """
-        if if_exists not in ('raise', 'ignore', 'update'):
-            raise ValueError(
-                "Unexpected value for if_exists: %s.\n"
-                "Expected one of: ['raise', 'ignore', 'update']" % if_exists)
-        if attrs is None:
-            attrs = {}
-        entity_key = self.normalize_entity_key(entity_key)
-        if self.has_entity(entity_key):
-            if if_exists == 'raise':
-                raise KeyError(
-                    "Tried to add entity %s. Entity exists, and if_exists='raise'.\n"
-                    "Set if_exists='ignore' or if_exists='update', or check with "
-                    "matcher.has_entity()")
-            elif if_exists == 'ignore':
-                return
-        self._entities[entity_key] = dict(attrs)
-        self._patterns.setdefault(entity_key, [])
-        self._acceptors[entity_key] = acceptor
-        self._callbacks[entity_key] = on_match
+        for pattern in patterns:
+            if len(pattern) == 0:
+                msg = ("Cannot add pattern for zero tokens to matcher.\n"
+                       "key: {key}\n")
+                raise ValueError(msg.format(key=key))
+        key = self._normalize_key(key)
+        self._patterns.setdefault(key, [])
+        self._callbacks[key] = on_match
 
-    def add_pattern(self, entity_key, token_specs, label=""):
-        # TODO: replace with new Matcher.add()
-        """Add a pattern to the matcher.
+        for pattern in patterns:
+            specs = _convert_strings(pattern, self.vocab.strings)
+            self.patterns.push_back(init_pattern(self.mem, key, specs))
+            self._patterns[key].append(specs)
 
-        entity_key (unicode): An ID for the entity.
-        token_specs (list): Description of the pattern to be matched.
-        label (unicode): Label to assign to the matched pattern. Defaults to `""`.
+    def remove(self, key):
+        """Remove a rule from the matcher.
+
+        A KeyError is raised if the key does not exist.
         """
-        token_specs = list(token_specs)
-        if len(token_specs) == 0:
-            msg = ("Cannot add pattern for zero tokens to matcher.\n"
-                   "entity_key: {entity_key}\n"
-                   "label: {label}")
-            raise ValueError(msg.format(entity_key=entity_key, label=label))
-        entity_key = self.normalize_entity_key(entity_key)
-        if not self.has_entity(entity_key):
-            self.add_entity(entity_key)
-        if isinstance(label, basestring):
-            label = self.vocab.strings[label]
-        elif label is None:
-            label = 0
-        spec = _convert_strings(token_specs, self.vocab.strings)
+        key = self._normalize_key(key)
+        self._patterns.pop(key)
+        self._callbacks.pop(key)
+        cdef int i = 0
+        while i < self.patterns.size():
+            pattern_key = get_pattern_key(self.patterns.at(i))
+            if pattern_key == key:
+                self.patterns.erase(self.patterns.begin()+i)
+            else:
+                i += 1
 
-        self.patterns.push_back(init_pattern(self.mem, entity_key, label, spec))
-        self._patterns[entity_key].append((label, token_specs))
+    def has_key(self, key):
+        """Check whether the matcher has a rule with a given key.
 
-    def add(self, entity_key, label, attrs, specs, acceptor=None, on_match=None):
-        # TODO: replace with new Matcher.add()
-        self.add_entity(entity_key, attrs=attrs, if_exists='update',
-                        acceptor=acceptor, on_match=on_match)
-        for spec in specs:
-            self.add_pattern(entity_key, spec, label=label)
-
-    def normalize_entity_key(self, entity_key):
-        if isinstance(entity_key, basestring):
-            return self.vocab.strings[entity_key]
-        else:
-            return entity_key
-
-    def has_entity(self, entity_key):
-        # TODO: deprecate
-        """Check whether the matcher has an entity.
-
-        entity_key (string or int): The entity key to check.
-        RETURNS (bool): Whether the matcher has the entity.
+        key (string or int): The key to check.
+        RETURNS (bool): Whether the matcher has the rule.
""" - entity_key = self.normalize_entity_key(entity_key) - return entity_key in self._entities + key = self._normalize_key(key) + return key in self._patterns - def get_entity(self, entity_key): - # TODO: deprecate - """Retrieve the attributes stored for an entity. + def get(self, key, default=None): + """Retrieve the pattern stored for an entity. - entity_key (unicode or int): The entity to retrieve. - RETURNS (dict): The entity attributes if present, otherwise None. + key (unicode or int): The key to retrieve. + RETURNS (tuple): The rule, as an (on_match, patterns) tuple. """ - entity_key = self.normalize_entity_key(entity_key) - if entity_key in self._entities: - return self._entities[entity_key] - else: - return None + key = self._normalize_key(key) + if key not in self._patterns: + return default + return (self._callbacks[key], self._patterns[key]) - def __call__(self, Doc doc, acceptor=None): + def pipe(self, docs, batch_size=1000, n_threads=2): + """Match a stream of documents, yielding them in turn. + + docs (iterable): A stream of documents. + batch_size (int): The number of documents to accumulate into a working set. + n_threads (int): The number of threads with which to work on the buffer + in parallel, if the `Matcher` implementation supports multi-threading. + YIELDS (Doc): Documents, in order. + """ + for doc in docs: + self(doc) + yield doc + + def __call__(self, Doc doc): """Find all token sequences matching the supplied patterns on the `Doc`. doc (Doc): The document to match over. - RETURNS (list): A list of `(entity_key, label_id, start, end)` tuples, + RETURNS (list): A list of `(key, label_id, start, end)` tuples, describing the matches. A match tuple describes a span - `doc[start:end]`. The `label_id` and `entity_key` are both integers. + `doc[start:end]`. The `label_id` and `key` are both integers. """ - if acceptor is not None: - raise ValueError( - "acceptor keyword argument to Matcher deprecated. 
-                "acceptor keyword argument to Matcher deprecated. Specify acceptor "
-                "functions when you add patterns instead.")
         cdef vector[StateC] partials
         cdef int n_partials = 0
         cdef int q = 0
@@ -343,13 +330,7 @@ cdef class Matcher:
                     end = token_i+1
                     ent_id = state.second[1].attrs[0].value
                     label = state.second[1].attrs[1].value
-                    acceptor = self._acceptors.get(ent_id)
-                    if acceptor is None:
-                        matches.append((ent_id, label, start, end))
-                    else:
-                        match = acceptor(doc, ent_id, label, start, end)
-                        if match:
-                            matches.append(match)
+                    matches.append((ent_id, start, end))
             partials.resize(q)
             # Check whether we open any new patterns on this token
            for pattern in self.patterns:
@@ -374,13 +355,7 @@ cdef class Matcher:
                     end = token_i+1
                     ent_id = pattern[1].attrs[0].value
                     label = pattern[1].attrs[1].value
-                    acceptor = self._acceptors.get(ent_id)
-                    if acceptor is None:
-                        matches.append((ent_id, label, start, end))
-                    else:
-                        match = acceptor(doc, ent_id, label, start, end)
-                        if match:
-                            matches.append(match)
+                    matches.append((ent_id, start, end))
         # Look for open patterns that are actually satisfied
         for state in partials:
             while state.second.quantifier in (ZERO, ZERO_PLUS):
@@ -390,13 +365,7 @@ cdef class Matcher:
                 end = len(doc)
                 ent_id = state.second.attrs[0].value
                 label = state.second.attrs[0].value
-                acceptor = self._acceptors.get(ent_id)
-                if acceptor is None:
-                    matches.append((ent_id, label, start, end))
-                else:
-                    match = acceptor(doc, ent_id, label, start, end)
-                    if match:
-                        matches.append(match)
+                matches.append((ent_id, start, end))
-        for i, (ent_id, label, start, end) in enumerate(matches):
+        for i, (ent_id, start, end) in enumerate(matches):
             on_match = self._callbacks.get(ent_id)
             if on_match is not None:
@@ -404,18 +373,11 @@ cdef class Matcher:
         # TODO: only return (match_id, start, end)
         return matches
 
-    def pipe(self, docs, batch_size=1000, n_threads=2):
-        """Match a stream of documents, yielding them in turn.
-
-        docs (iterable): A stream of documents.
-        batch_size (int): The number of documents to accumulate into a working set.
-        n_threads (int): The number of threads with which to work on the buffer
-            in parallel, if the `Matcher` implementation supports multi-threading.
-        YIELDS (Doc): Documents, in order.
-        """
-        for doc in docs:
-            self(doc)
-            yield doc
+    def _normalize_key(self, key):
+        if isinstance(key, basestring):
+            return self.vocab.strings[key]
+        else:
+            return key
 
 
 def get_bilou(length):
diff --git a/website/docs/api/matcher.jade b/website/docs/api/matcher.jade
index 245f32eec..a2764e309 100644
--- a/website/docs/api/matcher.jade
+++ b/website/docs/api/matcher.jade
@@ -116,8 +116,9 @@ p Match a stream of documents, yielding them in turn.
 +tag method
 
 p
-    |  Add one or more patterns to the matcher, along with a callback function
-    |  to handle the matches. The callback function will receive the arguments
+    |  Add a rule to the matcher, consisting of an ID key, one or more patterns, and
+    |  a callback function to act on the matches.
+    |  The callback function will receive the arguments
     |  #[code matcher], #[code doc], #[code i] and #[code matches].
 
 +aside-code("Example").
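For reference, a minimal usage sketch of the rule-based API this patch introduces. The `'HelloWorld'` key, the callback body, and the example patterns are illustrative only and are not part of the diff; the sketch assumes the English model is installed, and that `__call__` returns the `(key, start, end)` tuples produced by the new code.

```python
import spacy
from spacy.matcher import Matcher
from spacy.attrs import LOWER, IS_PUNCT

nlp = spacy.load('en')          # assumes the 'en' model is installed
matcher = Matcher(nlp.vocab)

def on_match(matcher, doc, i, matches):
    # Callback signature per the new API: (matcher, doc, i, matches).
    key, start, end = matches[i]
    print('Matched:', doc[start:end].text)

# One rule key with two alternative patterns; each pattern is a list of
# token_spec dicts mapping attribute IDs to values.
matcher.add('HelloWorld', on_match,
            [{LOWER: 'hello'}, {IS_PUNCT: True}, {LOWER: 'world'}],
            [{LOWER: 'hello'}, {LOWER: 'world'}])

doc = nlp(u'Hello, world! Hello world!')
matches = matcher(doc)          # [(key, start, end), ...]
```

The key is interned through `vocab.strings`, so `len(matcher)`, `'HelloWorld' in matcher` and `matcher.get('HelloWorld')` all refer to the same rule.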