Lots of updates to Matcher, to make entity handling sane.

This commit is contained in:
Matthew Honnibal 2016-10-17 15:23:31 +02:00
parent 7fd98fc91c
commit 6cbdc94959

View File

@ -92,8 +92,8 @@ ctypedef TokenPatternC* TokenPatternC_ptr
ctypedef pair[int, TokenPatternC_ptr] StateC
cdef TokenPatternC* init_pattern(Pool mem, object token_specs, attr_t entity_id,
attr_t entity_type) except NULL:
cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, attr_t label,
object token_specs) except NULL:
pattern = <TokenPatternC*>mem.alloc(len(token_specs) + 1, sizeof(TokenPatternC))
cdef int i
for i, (quantifier, spec) in enumerate(token_specs):
@ -108,7 +108,7 @@ cdef TokenPatternC* init_pattern(Pool mem, object token_specs, attr_t entity_id,
pattern[i].attrs[0].attr = ID
pattern[i].attrs[0].value = entity_id
pattern[i].attrs[1].attr = ENT_TYPE
pattern[i].attrs[1].value = entity_type
pattern[i].attrs[1].value = label
pattern[i].nr_attr = 0
return pattern
@ -163,37 +163,14 @@ def _convert_strings(token_specs, string_store):
return tokens
def get_bilou(length):
if length == 1:
return [U_ENT]
elif length == 2:
return [B2_ENT, L2_ENT]
elif length == 3:
return [B3_ENT, I3_ENT, L3_ENT]
elif length == 4:
return [B4_ENT, I4_ENT, I4_ENT, L4_ENT]
elif length == 5:
return [B5_ENT, I5_ENT, I5_ENT, I5_ENT, L5_ENT]
elif length == 6:
return [B6_ENT, I6_ENT, I6_ENT, I6_ENT, I6_ENT, L6_ENT]
elif length == 7:
return [B7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, L7_ENT]
elif length == 8:
return [B8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, L8_ENT]
elif length == 9:
return [B9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, L9_ENT]
elif length == 10:
return [B10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT,
I10_ENT, I10_ENT, L10_ENT]
else:
raise ValueError("Max length currently 10 for phrase matching")
cdef class Matcher:
cdef Pool mem
cdef vector[TokenPatternC*] patterns
cdef readonly Vocab vocab
cdef public object _patterns
cdef public object _entities
cdef public object _callbacks
cdef public object _acceptors
@classmethod
def load(cls, path, vocab):
@ -205,12 +182,17 @@ cdef class Matcher:
return cls(vocab, patterns)
def __init__(self, vocab, patterns={}):
self._patterns = dict(patterns) # Make sure we own the object
self._patterns = {}
self._entities = {}
self._acceptors = {}
self._callbacks = {}
self.vocab = vocab
self.mem = Pool()
self.vocab = vocab
for entity_key, (etype, attrs, specs) in sorted(self._patterns.items()):
self.add(entity_key, etype, attrs, specs)
for entity_key, (etype, attrs, specs) in sorted(patterns.items()):
self.add_entity(entity_key, attrs)
for spec in specs:
self.add_pattern(entity_key, spec, label=etype)
def __reduce__(self):
return (self.__class__, (self.vocab, self._patterns), None, None)
@ -218,21 +200,67 @@ cdef class Matcher:
property n_patterns:
def __get__(self): return self.patterns.size()
def add(self, entity_key, etype, attrs, specs):
self._patterns[entity_key] = (etype, dict(attrs), list(specs))
if isinstance(entity_key, basestring):
entity_key = self.vocab.strings[entity_key]
if isinstance(etype, basestring):
etype = self.vocab.strings[etype]
elif etype is None:
etype = -1
# TODO: Do something more clever about multiple patterns for single
# entity
def add_entity(self, entity_key, attrs=None, if_exists='raise',
acceptor=None, on_match=None):
if if_exists not in ('raise', 'ignore', 'update'):
raise ValueError(
"Unexpected value for if_exists: %s.\n"
"Expected one of: ['raise', 'ignore', 'update']" % if_exists)
if attrs is None:
attrs = {}
entity_key = self.normalize_entity_key(entity_key)
if self.has_entity(entity_key):
if if_exists == 'raise':
raise KeyError(
"Tried to add entity %s. Entity exists, and if_exists='raise'.\n"
"Set if_exists='ignore' or if_exists='update', or check with "
"matcher.has_entity()")
elif if_exists == 'ignore':
return
self._entities[entity_key] = dict(attrs)
self._patterns.setdefault(entity_key, [])
self._acceptors[entity_key] = acceptor
self._callbacks[entity_key] = on_match
def add_pattern(self, entity_key, token_specs, label=""):
entity_key = self.normalize_entity_key(entity_key)
if not self.has_entity(entity_key):
self.add_entity(entity_key)
if isinstance(label, basestring):
label = self.vocab.strings[label]
spec = _convert_strings(token_specs, self.vocab.strings)
self.patterns.push_back(init_pattern(self.mem, entity_key, label, spec))
self._patterns[entity_key].append((label, token_specs))
def add(self, entity_key, label, attrs, specs, acceptor=None, on_match=None):
self.add_entity(entity_key, attrs=attrs, if_exists='update',
acceptor=acceptor, on_match=on_match)
for spec in specs:
spec = _convert_strings(spec, self.vocab.strings)
self.patterns.push_back(init_pattern(self.mem, spec, entity_key, etype))
self.add_pattern(entity_key, spec, label=label)
def normalize_entity_key(self, entity_key):
if isinstance(entity_key, basestring):
return self.vocab.strings[entity_key]
else:
return entity_key
def has_entity(self, entity_key):
entity_key = self.normalize_entity_key(entity_key)
return entity_key in self._entities
def get_entity(self, entity_key):
entity_key = self.normalize_entity_key(entity_key)
if entity_key in self._entities:
return self._entities[entity_key]
else:
return None
def __call__(self, Doc doc, acceptor=None):
if acceptor is not None:
raise ValueError(
"acceptor keyword argument to Matcher deprecated. Specify acceptor "
"functions when you add patterns instead.")
cdef vector[StateC] partials
cdef int n_partials = 0
cdef int q = 0
@ -267,7 +295,11 @@ cdef class Matcher:
end = token_i+1
ent_id = state.second[1].attrs[0].value
label = state.second[1].attrs[1].value
if acceptor is None or acceptor(doc, ent_id, label, start, end):
acceptor = self._acceptors.get(ent_id)
if acceptor is not None:
match = acceptor(doc, ent_id, label, start, end)
if match:
ent_id, label, start, end = match
matches.append((ent_id, label, start, end))
partials.resize(q)
# Check whether we open any new patterns on this token
@ -293,6 +325,10 @@ cdef class Matcher:
label = pattern[1].attrs[1].value
if acceptor is None or acceptor(doc, ent_id, label, start, end):
matches.append((ent_id, label, start, end))
for i, (ent_id, label, start, end) in enumerate(matches):
on_match = self._callbacks.get(ent_id)
if on_match is not None:
on_match(self, doc, i, matches)
return matches
def pipe(self, docs, batch_size=1000, n_threads=2):
@ -301,6 +337,32 @@ cdef class Matcher:
yield doc
def get_bilou(length):
if length == 1:
return [U_ENT]
elif length == 2:
return [B2_ENT, L2_ENT]
elif length == 3:
return [B3_ENT, I3_ENT, L3_ENT]
elif length == 4:
return [B4_ENT, I4_ENT, I4_ENT, L4_ENT]
elif length == 5:
return [B5_ENT, I5_ENT, I5_ENT, I5_ENT, L5_ENT]
elif length == 6:
return [B6_ENT, I6_ENT, I6_ENT, I6_ENT, I6_ENT, L6_ENT]
elif length == 7:
return [B7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, L7_ENT]
elif length == 8:
return [B8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, L8_ENT]
elif length == 9:
return [B9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, L9_ENT]
elif length == 10:
return [B10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT,
I10_ENT, I10_ENT, L10_ENT]
else:
raise ValueError("Max length currently 10 for phrase matching")
cdef class PhraseMatcher:
cdef Pool mem
cdef Vocab vocab