* Work on gazetteer matching

This commit is contained in:
Matthew Honnibal 2015-08-06 14:33:21 +02:00
parent 9c1724ecae
commit 5737115e1e
2 changed files with 62 additions and 4 deletions

View File

@@ -11,6 +11,7 @@ from ..syntax.arc_eager import ArcEager
 from ..syntax.ner import BiluoPushDown
 from ..syntax.parser import ParserFactory
 from ..serialize.bits import BitArray
+from ..matcher import Matcher
 from ..tokens import Doc
 from ..multi_words import RegexMerger
@@ -75,6 +76,7 @@ class English(object):
         Tagger=EnPosTagger,
         Parser=ParserFactory(ParserTransitionSystem),
         Entity=ParserFactory(EntityTransitionSystem),
+        Matcher=Matcher.from_dir,
         Packer=None,
         load_vectors=True
     ):
@@ -113,6 +115,10 @@ class English(object):
             self.entity = Entity(self.vocab.strings, path.join(data_dir, 'ner'))
         else:
             self.entity = None
+        if Matcher:
+            self.matcher = Matcher(self.vocab, data_dir)
+        else:
+            self.matcher = None
         if Packer:
             self.packer = Packer(self.vocab, data_dir)
         else:
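Note that Matcher slots into the same optional-component convention as Packer: the constructor receives a factory (here the Matcher.from_dir classmethod), calls it with (self.vocab, data_dir), and a falsy value disables the component. A minimal sketch of both paths, assuming English can be constructed with its default data directory:

    # Default: Matcher.from_dir(self.vocab, data_dir) loads the gazetteer
    # patterns from the data directory (see the second file below).
    nlp = English()

    # A falsy Matcher leaves self.matcher as None, so the pipeline skips
    # gazetteer matching entirely.
    nlp_no_gazetteer = English(Matcher=None)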
@@ -143,6 +149,8 @@ class English(object):
         tokens = self.tokenizer(text)
         if self.tagger and tag:
             self.tagger(tokens)
+        if self.matcher and entity:
+            self.matcher(tokens)
         if self.parser and parse:
             self.parser(tokens)
         if self.entity and entity:
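With the hook above, gazetteer matching is gated by the same entity flag as the statistical NER and runs right after tagging, so rule-based matches are in place before the entity recogniser sees the tokens. A minimal usage sketch, assuming an installed spaCy 0.x data directory (the example text is invented):

    from spacy.en import English

    nlp = English()
    # tag/parse/entity are the flags tested in __call__ above.
    tokens = nlp(u'I like the Boston Celtics', tag=True, parse=False, entity=True)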

View File

@@ -1,3 +1,5 @@
+from os import path
+
 from .typedefs cimport attr_t
 from .attrs cimport attr_id_t
 from .structs cimport TokenC
@ -5,11 +7,16 @@ from .structs cimport TokenC
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
from libcpp.vector cimport vector from libcpp.vector cimport vector
from .attrs cimport LENGTH, ENT_TYPE from .attrs cimport LENGTH, ENT_TYPE, ORTH, NORM, LEMMA, LOWER, SHAPE
from .tokens.doc cimport get_token_attr from .tokens.doc cimport get_token_attr
from .tokens.doc cimport Doc from .tokens.doc cimport Doc
from .vocab cimport Vocab from .vocab cimport Vocab
try:
import ujson as json
except ImportError:
import json
cdef struct AttrValue: cdef struct AttrValue:
attr_id_t attr attr_id_t attr
@@ -58,18 +65,61 @@ cdef object get_entity(const Pattern* pattern, const TokenC* tokens, int i):
     return (pattern.spec[0].value, i - pattern.spec[1].value, i)
+
+def _convert_strings(token_specs, string_store):
+    converted = []
+    for spec in token_specs:
+        converted.append([])
+        for attr, value in spec.items():
+            if isinstance(attr, basestring):
+                attr = map_attr_name(attr)
+            if isinstance(value, basestring):
+                value = string_store[value]
+            converted[-1].append((attr, value))
+    return converted
+
+
+def map_attr_name(attr):
+    attr = attr.upper()
+    if attr == 'ORTH':
+        return ORTH
+    elif attr == 'LEMMA':
+        return LEMMA
+    elif attr == 'LOWER':
+        return LOWER
+    elif attr == 'SHAPE':
+        return SHAPE
+    elif attr == 'NORM':
+        return NORM
+    else:
+        raise Exception("TODO: Finish supporting attr mapping %s" % attr)
+
+
 cdef class Matcher:
     cdef Pool mem
     cdef Pattern** patterns
     cdef readonly int n_patterns

-    def __init__(self, patterns):
+    def __init__(self, vocab, patterns):
         self.mem = Pool()
         self.patterns = <Pattern**>self.mem.alloc(len(patterns), sizeof(Pattern*))
-        for i, (token_specs, entity_type) in enumerate(patterns):
-            self.patterns[i] = init_pattern(self.mem, token_specs, entity_type)
+        for i, (entity_key, (etype, attrs, specs)) in enumerate(sorted(patterns.items())):
+            if isinstance(entity_key, basestring):
+                entity_key = vocab.strings[entity_key]
+            if isinstance(etype, basestring):
+                etype = vocab.strings[etype]
+            specs = _convert_strings(specs, vocab.strings)
+            self.patterns[i] = init_pattern(self.mem, specs, etype)
         self.n_patterns = len(patterns)
+
+    @classmethod
+    def from_dir(cls, vocab, data_dir):
+        patterns_loc = path.join(data_dir, 'ner', 'patterns.json')
+        if path.exists(patterns_loc):
+            patterns = json.load(open(patterns_loc))
+            return cls(vocab, patterns)
+        else:
+            return cls(vocab, {})
+
     def __call__(self, Doc doc):
         cdef vector[Pattern*] partials
         cdef int n_partials = 0
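
For reference, from_dir expects ner/patterns.json to map each entity key to an (etype, attrs, specs) triple, where specs is a list of per-token attribute dicts; _convert_strings then turns string attribute names into attribute IDs via map_attr_name and interns string values in the StringStore. A hypothetical patterns file and the equivalent direct construction (the entity key and token values here are invented for illustration):

    # ner/patterns.json -- hypothetical contents:
    # {"BostonCeltics": ["ORG", {}, [{"lower": "boston"}, {"lower": "celtics"}]]}

    # Equivalent in-memory construction: 'lower' is mapped to the LOWER
    # attribute ID and u'boston'/u'celtics' are interned in vocab.strings
    # before init_pattern is called.
    patterns = {
        u'BostonCeltics': (u'ORG', {}, [{'lower': u'boston'}, {'lower': u'celtics'}]),
    }
    matcher = Matcher(nlp.vocab, patterns)
    matcher(doc)  # scan a Doc for gazetteer matches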