Update Matcher API

This commit is contained in:
Matthew Honnibal 2017-05-20 13:54:53 +02:00
parent 8c9b3d5ad7
commit ce9234f593
2 changed files with 99 additions and 136 deletions

View File

@ -87,7 +87,7 @@ ctypedef TokenPatternC* TokenPatternC_ptr
ctypedef pair[int, TokenPatternC_ptr] StateC
cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, attr_t label,
cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id,
object token_specs) except NULL:
pattern = <TokenPatternC*>mem.alloc(len(token_specs) + 1, sizeof(TokenPatternC))
cdef int i
@ -99,15 +99,21 @@ cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, attr_t label,
pattern[i].attrs[j].attr = attr
pattern[i].attrs[j].value = value
i = len(token_specs)
pattern[i].attrs = <AttrValueC*>mem.alloc(3, sizeof(AttrValueC))
pattern[i].attrs = <AttrValueC*>mem.alloc(2, sizeof(AttrValueC))
pattern[i].attrs[0].attr = ID
pattern[i].attrs[0].value = entity_id
pattern[i].attrs[1].attr = ENT_TYPE
pattern[i].attrs[1].value = label
pattern[i].nr_attr = 0
return pattern
cdef attr_t get_pattern_key(const TokenPatternC* pattern) except 0:
while pattern.nr_attr != 0:
pattern += 1
id_attr = pattern[0].attrs[0]
assert id_attr.attr == ID
return id_attr.value
cdef int get_action(const TokenPatternC* pattern, const TokenC* token) nogil:
for attr in pattern.attrs[:pattern.nr_attr]:
if get_token_attr(token, attr.attr) != attr.value:
@ -175,12 +181,11 @@ cdef class Matcher:
cdef public object _callbacks
cdef public object _acceptors
def __init__(self, vocab, patterns={}):
def __init__(self, vocab):
"""Create the Matcher.
vocab (Vocab): The vocabulary object, which must be shared with the
documents the matcher will operate on.
patterns (dict): Patterns to add to the matcher.
RETURNS (Matcher): The newly constructed object.
"""
self._patterns = {}
@ -189,123 +194,105 @@ cdef class Matcher:
self._callbacks = {}
self.vocab = vocab
self.mem = Pool()
for entity_key, (etype, attrs, specs) in sorted(patterns.items()):
self.add_entity(entity_key, attrs)
for spec in specs:
self.add_pattern(entity_key, spec, label=etype)
def __reduce__(self):
return (self.__class__, (self.vocab, self._patterns), None, None)
property n_patterns:
def __get__(self): return self.patterns.size()
def __len__(self):
return len(self._patterns)
def add_entity(self, entity_key, attrs=None, if_exists='raise',
acceptor=None, on_match=None):
# TODO: replace with new Matcher.add()
"""Add an entity to the matcher.
def __contains__(self, key):
return len(self._patterns)
entity_key (unicode or int): An ID for the entity.
attrs (dict): Attributes to associate with the `Matcher`.
if_exists (unicode): `'raise'`, `'ignore'` or `'update'`. Controls what
happens if the entity ID already exists. Defaults to `'raise'`.
acceptor (function): Callback function to filter matches of the entity.
on_match (function): Callback function to act on matches of the entity.
def add(self, key, on_match, *patterns):
"""Add a match-rule to the matcher.
A match-rule consists of: an ID key, an on_match callback,
and one or more patterns. If the key exists, the patterns
are appended to the previous ones, and the previous on_match
callback is replaced.
The on_match callback will receive the arguments
(matcher, doc, i, matches). Note that if no `on_match`
callback is specified, the document will not be modified.
A pattern consists of one or more token_specs,
where a token_spec is a dictionary mapping
attribute IDs to values. Token descriptors can also
include quantifiers. There are currently important
known problems with the quantifiers --- see the docs.
"""
if if_exists not in ('raise', 'ignore', 'update'):
raise ValueError(
"Unexpected value for if_exists: %s.\n"
"Expected one of: ['raise', 'ignore', 'update']" % if_exists)
if attrs is None:
attrs = {}
entity_key = self.normalize_entity_key(entity_key)
if self.has_entity(entity_key):
if if_exists == 'raise':
raise KeyError(
"Tried to add entity %s. Entity exists, and if_exists='raise'.\n"
"Set if_exists='ignore' or if_exists='update', or check with "
"matcher.has_entity()")
elif if_exists == 'ignore':
return
self._entities[entity_key] = dict(attrs)
self._patterns.setdefault(entity_key, [])
self._acceptors[entity_key] = acceptor
self._callbacks[entity_key] = on_match
def add_pattern(self, entity_key, token_specs, label=""):
# TODO: replace with new Matcher.add()
"""Add a pattern to the matcher.
entity_key (unicode): An ID for the entity.
token_specs (list): Description of the pattern to be matched.
label (unicode): Label to assign to the matched pattern. Defaults to `""`.
"""
token_specs = list(token_specs)
if len(token_specs) == 0:
for pattern in patterns:
if len(pattern) == 0:
msg = ("Cannot add pattern for zero tokens to matcher.\n"
"entity_key: {entity_key}\n"
"label: {label}")
raise ValueError(msg.format(entity_key=entity_key, label=label))
entity_key = self.normalize_entity_key(entity_key)
if not self.has_entity(entity_key):
self.add_entity(entity_key)
if isinstance(label, basestring):
label = self.vocab.strings[label]
elif label is None:
label = 0
spec = _convert_strings(token_specs, self.vocab.strings)
"key: {key}\n")
raise ValueError(msg.format(key=key))
key = self._normalize_key(key)
self._patterns.setdefault(key, [])
self._callbacks[key] = on_match
self.patterns.push_back(init_pattern(self.mem, entity_key, label, spec))
self._patterns[entity_key].append((label, token_specs))
for pattern in patterns:
specs = _convert_strings(pattern, self.vocab.strings)
self.patterns.push_back(init_pattern(self.mem, key, specs))
self._patterns[key].append(specs)
def add(self, entity_key, label, attrs, specs, acceptor=None, on_match=None):
# TODO: replace with new Matcher.add()
self.add_entity(entity_key, attrs=attrs, if_exists='update',
acceptor=acceptor, on_match=on_match)
for spec in specs:
self.add_pattern(entity_key, spec, label=label)
def remove(self, key):
"""Remove a rule from the matcher.
def normalize_entity_key(self, entity_key):
if isinstance(entity_key, basestring):
return self.vocab.strings[entity_key]
else:
return entity_key
def has_entity(self, entity_key):
# TODO: deprecate
"""Check whether the matcher has an entity.
entity_key (string or int): The entity key to check.
RETURNS (bool): Whether the matcher has the entity.
A KeyError is raised if the key does not exist.
"""
entity_key = self.normalize_entity_key(entity_key)
return entity_key in self._entities
def get_entity(self, entity_key):
# TODO: deprecate
"""Retrieve the attributes stored for an entity.
entity_key (unicode or int): The entity to retrieve.
RETURNS (dict): The entity attributes if present, otherwise None.
"""
entity_key = self.normalize_entity_key(entity_key)
if entity_key in self._entities:
return self._entities[entity_key]
key = self._normalize_key(key)
self._patterns.pop(key)
self._callbacks.pop(key)
cdef int i = 0
while i < self.patterns.size():
pattern_key = get_pattern_key(self.patterns.at(i))
if pattern_key == key:
self.patterns.erase(self.patterns.begin()+i)
else:
return None
i += 1
def __call__(self, Doc doc, acceptor=None):
def has_key(self, key):
"""Check whether the matcher has a rule with a given key.
key (string or int): The key to check.
RETURNS (bool): Whether the matcher has the rule.
"""
key = self._normalize_key(key)
return key in self._patterns
def get(self, key, default=None):
"""Retrieve the pattern stored for an entity.
key (unicode or int): The key to retrieve.
RETURNS (tuple): The rule, as an (on_match, patterns) tuple.
"""
key = self._normalize_key(key)
if key not in self._patterns:
return default
return (self._callbacks[key], self._patterns[key])
def pipe(self, docs, batch_size=1000, n_threads=2):
"""Match a stream of documents, yielding them in turn.
docs (iterable): A stream of documents.
batch_size (int): The number of documents to accumulate into a working set.
n_threads (int): The number of threads with which to work on the buffer
in parallel, if the `Matcher` implementation supports multi-threading.
YIELDS (Doc): Documents, in order.
"""
for doc in docs:
self(doc)
yield doc
def __call__(self, Doc doc):
"""Find all token sequences matching the supplied patterns on the `Doc`.
doc (Doc): The document to match over.
RETURNS (list): A list of `(entity_key, label_id, start, end)` tuples,
RETURNS (list): A list of `(key, label_id, start, end)` tuples,
describing the matches. A match tuple describes a span
`doc[start:end]`. The `label_id` and `entity_key` are both integers.
`doc[start:end]`. The `label_id` and `key` are both integers.
"""
if acceptor is not None:
raise ValueError(
"acceptor keyword argument to Matcher deprecated. Specify acceptor "
"functions when you add patterns instead.")
cdef vector[StateC] partials
cdef int n_partials = 0
cdef int q = 0
@ -343,13 +330,7 @@ cdef class Matcher:
end = token_i+1
ent_id = state.second[1].attrs[0].value
label = state.second[1].attrs[1].value
acceptor = self._acceptors.get(ent_id)
if acceptor is None:
matches.append((ent_id, label, start, end))
else:
match = acceptor(doc, ent_id, label, start, end)
if match:
matches.append(match)
matches.append((ent_id, start, end))
partials.resize(q)
# Check whether we open any new patterns on this token
for pattern in self.patterns:
@ -374,13 +355,7 @@ cdef class Matcher:
end = token_i+1
ent_id = pattern[1].attrs[0].value
label = pattern[1].attrs[1].value
acceptor = self._acceptors.get(ent_id)
if acceptor is None:
matches.append((ent_id, label, start, end))
else:
match = acceptor(doc, ent_id, label, start, end)
if match:
matches.append(match)
matches.append((ent_id, start, end))
# Look for open patterns that are actually satisfied
for state in partials:
while state.second.quantifier in (ZERO, ZERO_PLUS):
@ -390,13 +365,7 @@ cdef class Matcher:
end = len(doc)
ent_id = state.second.attrs[0].value
label = state.second.attrs[0].value
acceptor = self._acceptors.get(ent_id)
if acceptor is None:
matches.append((ent_id, label, start, end))
else:
match = acceptor(doc, ent_id, label, start, end)
if match:
matches.append(match)
matches.append((ent_id, start, end))
for i, (ent_id, label, start, end) in enumerate(matches):
on_match = self._callbacks.get(ent_id)
if on_match is not None:
@ -404,18 +373,11 @@ cdef class Matcher:
# TODO: only return (match_id, start, end)
return matches
def pipe(self, docs, batch_size=1000, n_threads=2):
"""Match a stream of documents, yielding them in turn.
docs (iterable): A stream of documents.
batch_size (int): The number of documents to accumulate into a working set.
n_threads (int): The number of threads with which to work on the buffer
in parallel, if the `Matcher` implementation supports multi-threading.
YIELDS (Doc): Documents, in order.
"""
for doc in docs:
self(doc)
yield doc
def _normalize_key(self, key):
if isinstance(key, basestring):
return self.vocab.strings[key]
else:
return key
def get_bilou(length):

View File

@ -116,8 +116,9 @@ p Match a stream of documents, yielding them in turn.
+tag method
p
| Add one or more patterns to the matcher, along with a callback function
| to handle the matches. The callback function will receive the arguments
| Add a rule to the matcher, consisting of an ID key, one or more patterns, and
| a callback function to act on the matches.
| The callback function will receive the arguments
| #[code matcher], #[code doc], #[code i] and #[code matches].
+aside-code("Example").