* Fix loading of gazetteer.json file

This commit is contained in:
Matthew Honnibal 2015-08-06 16:08:25 +02:00
parent 9c667b7f15
commit cd7d1682cd

View File

@ -101,21 +101,28 @@ cdef class Matcher:
def __init__(self, vocab, patterns):
self.mem = Pool()
self.patterns = <Pattern**>self.mem.alloc(len(patterns), sizeof(Pattern*))
for i, (entity_key, (etype, attrs, specs)) in enumerate(sorted(patterns.items())):
n_patterns = sum([len(specs) for etype, attrs, specs in patterns.values()])
self.patterns = <Pattern**>self.mem.alloc(n_patterns, sizeof(Pattern*))
cdef int i = 0
for entity_key, (etype, attrs, specs) in sorted(patterns.items()):
if isinstance(entity_key, basestring):
entity_key = vocab.strings[entity_key]
if isinstance(etype, basestring):
etype = vocab.strings[etype]
specs = _convert_strings(specs, vocab.strings)
self.patterns[i] = init_pattern(self.mem, specs, etype)
# TODO: Do something more clever about multiple patterns for single
# entity
for spec in specs:
spec = _convert_strings(spec, vocab.strings)
self.patterns[i] = init_pattern(self.mem, spec, etype)
i += 1
self.n_patterns = len(patterns)
@classmethod
def from_dir(cls, vocab, data_dir):
patterns_loc = path.join(data_dir, 'ner', 'patterns.json')
patterns_loc = path.join(data_dir, 'vocab', 'gazetteer.json')
if path.exists(patterns_loc):
patterns = json.loads(open(patterns_loc))
patterns_data = open(patterns_loc).read()
patterns = json.loads(patterns_data)
return cls(vocab, patterns)
else:
return cls(vocab, {})