mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 10:46:29 +03:00
Lots of updates to Matcher, to make entity handling sane.
This commit is contained in:
parent
7fd98fc91c
commit
6cbdc94959
|
@ -92,8 +92,8 @@ ctypedef TokenPatternC* TokenPatternC_ptr
|
||||||
ctypedef pair[int, TokenPatternC_ptr] StateC
|
ctypedef pair[int, TokenPatternC_ptr] StateC
|
||||||
|
|
||||||
|
|
||||||
cdef TokenPatternC* init_pattern(Pool mem, object token_specs, attr_t entity_id,
|
cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, attr_t label,
|
||||||
attr_t entity_type) except NULL:
|
object token_specs) except NULL:
|
||||||
pattern = <TokenPatternC*>mem.alloc(len(token_specs) + 1, sizeof(TokenPatternC))
|
pattern = <TokenPatternC*>mem.alloc(len(token_specs) + 1, sizeof(TokenPatternC))
|
||||||
cdef int i
|
cdef int i
|
||||||
for i, (quantifier, spec) in enumerate(token_specs):
|
for i, (quantifier, spec) in enumerate(token_specs):
|
||||||
|
@ -108,7 +108,7 @@ cdef TokenPatternC* init_pattern(Pool mem, object token_specs, attr_t entity_id,
|
||||||
pattern[i].attrs[0].attr = ID
|
pattern[i].attrs[0].attr = ID
|
||||||
pattern[i].attrs[0].value = entity_id
|
pattern[i].attrs[0].value = entity_id
|
||||||
pattern[i].attrs[1].attr = ENT_TYPE
|
pattern[i].attrs[1].attr = ENT_TYPE
|
||||||
pattern[i].attrs[1].value = entity_type
|
pattern[i].attrs[1].value = label
|
||||||
pattern[i].nr_attr = 0
|
pattern[i].nr_attr = 0
|
||||||
return pattern
|
return pattern
|
||||||
|
|
||||||
|
@ -163,37 +163,14 @@ def _convert_strings(token_specs, string_store):
|
||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
def get_bilou(length):
|
|
||||||
if length == 1:
|
|
||||||
return [U_ENT]
|
|
||||||
elif length == 2:
|
|
||||||
return [B2_ENT, L2_ENT]
|
|
||||||
elif length == 3:
|
|
||||||
return [B3_ENT, I3_ENT, L3_ENT]
|
|
||||||
elif length == 4:
|
|
||||||
return [B4_ENT, I4_ENT, I4_ENT, L4_ENT]
|
|
||||||
elif length == 5:
|
|
||||||
return [B5_ENT, I5_ENT, I5_ENT, I5_ENT, L5_ENT]
|
|
||||||
elif length == 6:
|
|
||||||
return [B6_ENT, I6_ENT, I6_ENT, I6_ENT, I6_ENT, L6_ENT]
|
|
||||||
elif length == 7:
|
|
||||||
return [B7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, L7_ENT]
|
|
||||||
elif length == 8:
|
|
||||||
return [B8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, L8_ENT]
|
|
||||||
elif length == 9:
|
|
||||||
return [B9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, L9_ENT]
|
|
||||||
elif length == 10:
|
|
||||||
return [B10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT,
|
|
||||||
I10_ENT, I10_ENT, L10_ENT]
|
|
||||||
else:
|
|
||||||
raise ValueError("Max length currently 10 for phrase matching")
|
|
||||||
|
|
||||||
|
|
||||||
cdef class Matcher:
|
cdef class Matcher:
|
||||||
cdef Pool mem
|
cdef Pool mem
|
||||||
cdef vector[TokenPatternC*] patterns
|
cdef vector[TokenPatternC*] patterns
|
||||||
cdef readonly Vocab vocab
|
cdef readonly Vocab vocab
|
||||||
cdef public object _patterns
|
cdef public object _patterns
|
||||||
|
cdef public object _entities
|
||||||
|
cdef public object _callbacks
|
||||||
|
cdef public object _acceptors
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, path, vocab):
|
def load(cls, path, vocab):
|
||||||
|
@ -205,12 +182,17 @@ cdef class Matcher:
|
||||||
return cls(vocab, patterns)
|
return cls(vocab, patterns)
|
||||||
|
|
||||||
def __init__(self, vocab, patterns={}):
|
def __init__(self, vocab, patterns={}):
|
||||||
self._patterns = dict(patterns) # Make sure we own the object
|
self._patterns = {}
|
||||||
|
self._entities = {}
|
||||||
|
self._acceptors = {}
|
||||||
|
self._callbacks = {}
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
self.mem = Pool()
|
self.mem = Pool()
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
for entity_key, (etype, attrs, specs) in sorted(self._patterns.items()):
|
for entity_key, (etype, attrs, specs) in sorted(patterns.items()):
|
||||||
self.add(entity_key, etype, attrs, specs)
|
self.add_entity(entity_key, attrs)
|
||||||
|
for spec in specs:
|
||||||
|
self.add_pattern(entity_key, spec, label=etype)
|
||||||
|
|
||||||
def __reduce__(self):
|
def __reduce__(self):
|
||||||
return (self.__class__, (self.vocab, self._patterns), None, None)
|
return (self.__class__, (self.vocab, self._patterns), None, None)
|
||||||
|
@ -218,21 +200,67 @@ cdef class Matcher:
|
||||||
property n_patterns:
|
property n_patterns:
|
||||||
def __get__(self): return self.patterns.size()
|
def __get__(self): return self.patterns.size()
|
||||||
|
|
||||||
def add(self, entity_key, etype, attrs, specs):
|
def add_entity(self, entity_key, attrs=None, if_exists='raise',
|
||||||
self._patterns[entity_key] = (etype, dict(attrs), list(specs))
|
acceptor=None, on_match=None):
|
||||||
if isinstance(entity_key, basestring):
|
if if_exists not in ('raise', 'ignore', 'update'):
|
||||||
entity_key = self.vocab.strings[entity_key]
|
raise ValueError(
|
||||||
if isinstance(etype, basestring):
|
"Unexpected value for if_exists: %s.\n"
|
||||||
etype = self.vocab.strings[etype]
|
"Expected one of: ['raise', 'ignore', 'update']" % if_exists)
|
||||||
elif etype is None:
|
if attrs is None:
|
||||||
etype = -1
|
attrs = {}
|
||||||
# TODO: Do something more clever about multiple patterns for single
|
entity_key = self.normalize_entity_key(entity_key)
|
||||||
# entity
|
if self.has_entity(entity_key):
|
||||||
|
if if_exists == 'raise':
|
||||||
|
raise KeyError(
|
||||||
|
"Tried to add entity %s. Entity exists, and if_exists='raise'.\n"
|
||||||
|
"Set if_exists='ignore' or if_exists='update', or check with "
|
||||||
|
"matcher.has_entity()")
|
||||||
|
elif if_exists == 'ignore':
|
||||||
|
return
|
||||||
|
self._entities[entity_key] = dict(attrs)
|
||||||
|
self._patterns.setdefault(entity_key, [])
|
||||||
|
self._acceptors[entity_key] = acceptor
|
||||||
|
self._callbacks[entity_key] = on_match
|
||||||
|
|
||||||
|
def add_pattern(self, entity_key, token_specs, label=""):
|
||||||
|
entity_key = self.normalize_entity_key(entity_key)
|
||||||
|
if not self.has_entity(entity_key):
|
||||||
|
self.add_entity(entity_key)
|
||||||
|
if isinstance(label, basestring):
|
||||||
|
label = self.vocab.strings[label]
|
||||||
|
|
||||||
|
spec = _convert_strings(token_specs, self.vocab.strings)
|
||||||
|
self.patterns.push_back(init_pattern(self.mem, entity_key, label, spec))
|
||||||
|
self._patterns[entity_key].append((label, token_specs))
|
||||||
|
|
||||||
|
def add(self, entity_key, label, attrs, specs, acceptor=None, on_match=None):
|
||||||
|
self.add_entity(entity_key, attrs=attrs, if_exists='update',
|
||||||
|
acceptor=acceptor, on_match=on_match)
|
||||||
for spec in specs:
|
for spec in specs:
|
||||||
spec = _convert_strings(spec, self.vocab.strings)
|
self.add_pattern(entity_key, spec, label=label)
|
||||||
self.patterns.push_back(init_pattern(self.mem, spec, entity_key, etype))
|
|
||||||
|
def normalize_entity_key(self, entity_key):
|
||||||
|
if isinstance(entity_key, basestring):
|
||||||
|
return self.vocab.strings[entity_key]
|
||||||
|
else:
|
||||||
|
return entity_key
|
||||||
|
|
||||||
|
def has_entity(self, entity_key):
|
||||||
|
entity_key = self.normalize_entity_key(entity_key)
|
||||||
|
return entity_key in self._entities
|
||||||
|
|
||||||
|
def get_entity(self, entity_key):
|
||||||
|
entity_key = self.normalize_entity_key(entity_key)
|
||||||
|
if entity_key in self._entities:
|
||||||
|
return self._entities[entity_key]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
def __call__(self, Doc doc, acceptor=None):
|
def __call__(self, Doc doc, acceptor=None):
|
||||||
|
if acceptor is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"acceptor keyword argument to Matcher deprecated. Specify acceptor "
|
||||||
|
"functions when you add patterns instead.")
|
||||||
cdef vector[StateC] partials
|
cdef vector[StateC] partials
|
||||||
cdef int n_partials = 0
|
cdef int n_partials = 0
|
||||||
cdef int q = 0
|
cdef int q = 0
|
||||||
|
@ -267,8 +295,12 @@ cdef class Matcher:
|
||||||
end = token_i+1
|
end = token_i+1
|
||||||
ent_id = state.second[1].attrs[0].value
|
ent_id = state.second[1].attrs[0].value
|
||||||
label = state.second[1].attrs[1].value
|
label = state.second[1].attrs[1].value
|
||||||
if acceptor is None or acceptor(doc, ent_id, label, start, end):
|
acceptor = self._acceptors.get(ent_id)
|
||||||
matches.append((ent_id, label, start, end))
|
if acceptor is not None:
|
||||||
|
match = acceptor(doc, ent_id, label, start, end)
|
||||||
|
if match:
|
||||||
|
ent_id, label, start, end = match
|
||||||
|
matches.append((ent_id, label, start, end))
|
||||||
partials.resize(q)
|
partials.resize(q)
|
||||||
# Check whether we open any new patterns on this token
|
# Check whether we open any new patterns on this token
|
||||||
for pattern in self.patterns:
|
for pattern in self.patterns:
|
||||||
|
@ -293,6 +325,10 @@ cdef class Matcher:
|
||||||
label = pattern[1].attrs[1].value
|
label = pattern[1].attrs[1].value
|
||||||
if acceptor is None or acceptor(doc, ent_id, label, start, end):
|
if acceptor is None or acceptor(doc, ent_id, label, start, end):
|
||||||
matches.append((ent_id, label, start, end))
|
matches.append((ent_id, label, start, end))
|
||||||
|
for i, (ent_id, label, start, end) in enumerate(matches):
|
||||||
|
on_match = self._callbacks.get(ent_id)
|
||||||
|
if on_match is not None:
|
||||||
|
on_match(self, doc, i, matches)
|
||||||
return matches
|
return matches
|
||||||
|
|
||||||
def pipe(self, docs, batch_size=1000, n_threads=2):
|
def pipe(self, docs, batch_size=1000, n_threads=2):
|
||||||
|
@ -301,6 +337,32 @@ cdef class Matcher:
|
||||||
yield doc
|
yield doc
|
||||||
|
|
||||||
|
|
||||||
|
def get_bilou(length):
|
||||||
|
if length == 1:
|
||||||
|
return [U_ENT]
|
||||||
|
elif length == 2:
|
||||||
|
return [B2_ENT, L2_ENT]
|
||||||
|
elif length == 3:
|
||||||
|
return [B3_ENT, I3_ENT, L3_ENT]
|
||||||
|
elif length == 4:
|
||||||
|
return [B4_ENT, I4_ENT, I4_ENT, L4_ENT]
|
||||||
|
elif length == 5:
|
||||||
|
return [B5_ENT, I5_ENT, I5_ENT, I5_ENT, L5_ENT]
|
||||||
|
elif length == 6:
|
||||||
|
return [B6_ENT, I6_ENT, I6_ENT, I6_ENT, I6_ENT, L6_ENT]
|
||||||
|
elif length == 7:
|
||||||
|
return [B7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, L7_ENT]
|
||||||
|
elif length == 8:
|
||||||
|
return [B8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, L8_ENT]
|
||||||
|
elif length == 9:
|
||||||
|
return [B9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, L9_ENT]
|
||||||
|
elif length == 10:
|
||||||
|
return [B10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT,
|
||||||
|
I10_ENT, I10_ENT, L10_ENT]
|
||||||
|
else:
|
||||||
|
raise ValueError("Max length currently 10 for phrase matching")
|
||||||
|
|
||||||
|
|
||||||
cdef class PhraseMatcher:
|
cdef class PhraseMatcher:
|
||||||
cdef Pool mem
|
cdef Pool mem
|
||||||
cdef Vocab vocab
|
cdef Vocab vocab
|
||||||
|
|
Loading…
Reference in New Issue
Block a user