spaCy/spacy/matcher.pyx
2015-09-06 17:53:12 +02:00

201 lines
6.1 KiB
Cython

from os import path
from .typedefs cimport attr_t
from .attrs cimport attr_id_t
from .structs cimport TokenC
from cymem.cymem cimport Pool
from libcpp.vector cimport vector
from .attrs cimport LENGTH, ENT_TYPE, ORTH, NORM, LEMMA, LOWER, SHAPE
from .attrs cimport FLAG13, FLAG14, FLAG15, FLAG16, FLAG17, FLAG18, FLAG19, FLAG20, FLAG21, FLAG22, FLAG23, FLAG24, FLAG25
from .tokens.doc cimport get_token_attr
from .tokens.doc cimport Doc
from .vocab cimport Vocab
from libcpp.vector cimport vector
try:
import ujson as json
except ImportError:
import json
# One (attribute, value) pair. In token patterns it is a constraint: match()
# compares get_token_attr(token, attr) against `value`. In a pattern's
# terminating record the same struct is reused to carry metadata instead
# (ENT_TYPE and LENGTH entries written by init_pattern).
cdef struct AttrValue:
    # Which token attribute this entry refers to.
    attr_id_t attr
    # The value associated with that attribute.
    attr_t value
# One step of a multi-token pattern. init_pattern lays these out as a
# contiguous array, one record per token, followed by a terminating record
# whose `length` is 0.
cdef struct Pattern:
    # Array of attribute/value entries for this step.
    AttrValue* spec
    # Number of entries match() checks in `spec`; 0 marks the terminating record.
    int length
cdef Pattern* init_pattern(Pool mem, object token_specs, attr_t entity_type) except NULL:
    """Build the C array of Pattern records for one multi-token pattern.

    mem: Pool that owns every allocation made here.
    token_specs: Sequence with one entry per token; each entry is a sequence
        of (attr, value) pairs that the token must satisfy.
    entity_type: Value stored in the terminating record and later reported
        by get_entity as the first element of a match.

    Returns a pointer to len(token_specs) + 1 records. The extra, final
    record has length 0 (so match() checks nothing on it and is_final() can
    detect it) and carries two metadata entries: spec[0] holds
    (ENT_TYPE, entity_type) and spec[1] holds (LENGTH, len(token_specs)).
    """
    cdef int n_tokens = len(token_specs)
    cdef Pattern* pattern = <Pattern*>mem.alloc(n_tokens + 1, sizeof(Pattern))
    cdef int i
    for i, spec in enumerate(token_specs):
        pattern[i].spec = <AttrValue*>mem.alloc(len(spec), sizeof(AttrValue))
        pattern[i].length = len(spec)
        for j, (attr, value) in enumerate(spec):
            pattern[i].spec[j].attr = attr
            pattern[i].spec[j].value = value
    # NOTE(review): a token spec with no attributes would get length 0 and be
    # indistinguishable from the terminating record -- confirm callers never
    # pass one.
    i = n_tokens
    # The terminating record stores TWO metadata entries, so it needs two
    # slots; allocating only one made the LENGTH write below run past the
    # end of the buffer.
    pattern[i].spec = <AttrValue*>mem.alloc(2, sizeof(AttrValue))
    pattern[i].spec[0].attr = ENT_TYPE
    pattern[i].spec[0].value = entity_type
    pattern[i].spec[1].attr = LENGTH
    pattern[i].spec[1].value = n_tokens
    # length stays 0 on purpose: it is what marks this record as the end.
    pattern[i].length = 0
    return pattern
cdef int match(const Pattern* pattern, const TokenC* token) except -1:
    """Check whether a token satisfies every constraint of one pattern step.

    pattern: The step to test; its first `length` spec entries are compared.
    token: The token whose attributes are read via get_token_attr.

    Returns True when every (attr, value) entry equals the token's value for
    that attribute, and False at the first mismatch. A step with length 0
    has nothing to check and therefore returns True.
    """
    cdef int i
    for i in range(pattern.length):
        if get_token_attr(token, pattern.spec[i].attr) != pattern.spec[i].value:
            return False
    return True
cdef int is_final(const Pattern* pattern) except -1:
    """Tell whether `pattern` is the last token step of its pattern array.

    Looks at the record stored immediately after `pattern`: a length of 0
    there identifies the terminating record, meaning no further token steps
    remain.
    """
    cdef const Pattern* following = pattern + 1
    return following.length == 0
cdef object get_entity(const Pattern* pattern, const TokenC* tokens, int i):
    """Build the result tuple for a pattern that has just matched.

    pattern: The step that matched last; the record right after it is read
        for the stored metadata.
    tokens: Not used by this function.
    i: Index of the token on which the final step matched.

    Returns (spec[0].value, start, end) taken from the following record,
    where end is i + 1 and start is end minus spec[1].value.
    """
    cdef const Pattern* terminator = pattern + 1
    cdef int end = i + 1
    return (terminator.spec[0].value, end - terminator.spec[1].value, end)
def _convert_strings(token_specs, string_store):
    """Normalise one pattern's token specs.

    token_specs: Iterable of dict-like objects (one per token) mapping an
        attribute to the required value.
    string_store: Object indexed with string values (string_store[value]) to
        replace them.

    For every (attr, value) item: a string attr is translated with
    map_attr_name, a string value is replaced by string_store[value], and a
    bool value is turned into an int.

    Returns a list with one list of (attr, value) tuples per token.
    """
    converted = []
    for spec in token_specs:
        token_attrs = []
        for attr, value in spec.items():
            if isinstance(attr, basestring):
                attr = map_attr_name(attr)
            if isinstance(value, basestring):
                value = string_store[value]
            if isinstance(value, bool):
                value = int(value)
            token_attrs.append((attr, value))
        converted.append(token_attrs)
    return converted
def map_attr_name(attr):
    """Translate an attribute name (case-insensitive) into its attribute ID.

    attr: Name such as 'orth', 'LEMMA' or 'flag13'; it is upper-cased before
        the lookup.

    Returns the matching constant among ORTH, LEMMA, LOWER, SHAPE, NORM and
    FLAG13..FLAG25.

    Raises Exception for any other name.
    """
    name = attr.upper()
    known = {
        'ORTH': ORTH,
        'LEMMA': LEMMA,
        'LOWER': LOWER,
        'SHAPE': SHAPE,
        'NORM': NORM,
        'FLAG13': FLAG13,
        'FLAG14': FLAG14,
        'FLAG15': FLAG15,
        'FLAG16': FLAG16,
        'FLAG17': FLAG17,
        'FLAG18': FLAG18,
        'FLAG19': FLAG19,
        'FLAG20': FLAG20,
        'FLAG21': FLAG21,
        'FLAG22': FLAG22,
        'FLAG23': FLAG23,
        'FLAG24': FLAG24,
        'FLAG25': FLAG25,
    }
    if name not in known:
        raise Exception("TODO: Finish supporting attr mapping %s" % name)
    return known[name]
cdef class Matcher:
    """Match sequences of tokens against attribute-based patterns.

    Each pattern is stored as a C array of Pattern records (one per token
    plus a terminating record) allocated from `mem`.
    """
    # Owns the memory of every compiled pattern.
    cdef Pool mem
    # Pointer to the first record of each compiled pattern.
    cdef vector[Pattern*] patterns
    # Vocabulary whose string store is used to convert string keys/values.
    cdef readonly Vocab vocab

    def __init__(self, vocab, patterns):
        """Create a matcher and register an initial set of patterns.

        vocab: Vocabulary used for string conversion.
        patterns: Mapping of entity_key -> (etype, attrs, specs); entries are
            added in sorted key order via add().
        """
        self.mem = Pool()
        self.vocab = vocab
        for entity_key, (etype, attrs, specs) in sorted(patterns.items()):
            self.add(entity_key, etype, attrs, specs)

    @classmethod
    def from_dir(cls, data_dir, Vocab vocab):
        """Build a matcher from <data_dir>/vocab/gazetteer.json.

        When that file does not exist, a matcher with no patterns is
        returned.
        """
        patterns_loc = path.join(data_dir, 'vocab', 'gazetteer.json')
        if path.exists(patterns_loc):
            # Use a context manager so the file handle is closed promptly
            # instead of being left to the garbage collector.
            with open(patterns_loc) as file_:
                patterns_data = file_.read()
            patterns = json.loads(patterns_data)
            return cls(vocab, patterns)
        else:
            return cls(vocab, {})

    property n_patterns:
        # Number of compiled patterns currently registered.
        def __get__(self): return self.patterns.size()

    def add(self, entity_key, etype, attrs, specs):
        """Register every pattern in `specs` under one entity type.

        entity_key: Entity identifier; a string is converted through the
            vocab's string store. It is not otherwise used here.
        etype: Entity type reported for matches; a string is converted
            through the string store and None becomes -1.
        attrs: Accepted for interface compatibility; not used here.
        specs: Iterable of patterns, each a non-empty sequence of per-token
            attribute dicts.
        """
        if isinstance(entity_key, basestring):
            entity_key = self.vocab.strings[entity_key]
        if isinstance(etype, basestring):
            etype = self.vocab.strings[etype]
        elif etype is None:
            etype = -1
        # TODO: Do something more clever about multiple patterns for single
        # entity
        for spec in specs:
            assert len(spec) >= 1
            spec = _convert_strings(spec, self.vocab.strings)
            self.patterns.push_back(init_pattern(self.mem, spec, etype))

    def __call__(self, Doc doc):
        """Find all pattern matches in `doc`.

        Returns a list of (entity type, start, end) tuples, where `end` is
        one past the index of the last matched token. As a side effect,
        doc.ents is reassigned to its previous entities followed by the new
        matches.
        """
        cdef vector[Pattern*] partials
        cdef int n_partials = 0
        cdef int q = 0
        cdef int i, token_i
        cdef const TokenC* token
        cdef Pattern* state
        matches = []
        for token_i in range(doc.length):
            token = &doc.data[token_i]
            # Advance patterns that are already part-way matched. Survivors
            # are compacted to the front of `partials`; `q` counts them.
            q = 0
            for i in range(partials.size()):
                state = partials.at(i)
                if match(state, token):
                    if is_final(state):
                        matches.append(get_entity(state, token, token_i))
                    else:
                        partials[q] = state + 1
                        q += 1
            # Drop every partial that failed or completed on this token.
            partials.resize(q)
            # Try to start each pattern afresh at this token.
            for i in range(self.n_patterns):
                state = self.patterns[i]
                if match(state, token):
                    if is_final(state):
                        matches.append(get_entity(state, token, token_i))
                    else:
                        partials.push_back(state + 1)
        doc.ents = [(e.label, e.start, e.end) for e in doc.ents] + matches
        return matches