* Work on gazetteer matching

This commit is contained in:
Matthew Honnibal 2015-08-06 14:33:21 +02:00
parent 9c1724ecae
commit 5737115e1e
2 changed files with 62 additions and 4 deletions

View File

@ -11,6 +11,7 @@ from ..syntax.arc_eager import ArcEager
from ..syntax.ner import BiluoPushDown
from ..syntax.parser import ParserFactory
from ..serialize.bits import BitArray
from ..matcher import Matcher
from ..tokens import Doc
from ..multi_words import RegexMerger
@ -75,6 +76,7 @@ class English(object):
Tagger=EnPosTagger,
Parser=ParserFactory(ParserTransitionSystem),
Entity=ParserFactory(EntityTransitionSystem),
Matcher=Matcher.from_dir,
Packer=None,
load_vectors=True
):
@ -113,6 +115,10 @@ class English(object):
self.entity = Entity(self.vocab.strings, path.join(data_dir, 'ner'))
else:
self.entity = None
if Matcher:
self.matcher = Matcher(self.vocab, data_dir)
else:
self.matcher = None
if Packer:
self.packer = Packer(self.vocab, data_dir)
else:
@ -143,6 +149,8 @@ class English(object):
tokens = self.tokenizer(text)
if self.tagger and tag:
self.tagger(tokens)
if self.matcher and entity:
self.matcher(tokens)
if self.parser and parse:
self.parser(tokens)
if self.entity and entity:

View File

@ -1,3 +1,5 @@
from os import path
from .typedefs cimport attr_t
from .attrs cimport attr_id_t
from .structs cimport TokenC
@ -5,11 +7,16 @@ from .structs cimport TokenC
from cymem.cymem cimport Pool
from libcpp.vector cimport vector
from .attrs cimport LENGTH, ENT_TYPE
from .attrs cimport LENGTH, ENT_TYPE, ORTH, NORM, LEMMA, LOWER, SHAPE
from .tokens.doc cimport get_token_attr
from .tokens.doc cimport Doc
from .vocab cimport Vocab
try:
import ujson as json
except ImportError:
import json
cdef struct AttrValue:
attr_id_t attr
@ -58,18 +65,61 @@ cdef object get_entity(const Pattern* pattern, const TokenC* tokens, int i):
return (pattern.spec[0].value, i - pattern.spec[1].value, i)
def _convert_strings(token_specs, string_store):
    """Resolve string attribute names and string values in token specs.

    Each spec in `token_specs` is a mapping of attribute -> value. The result
    is a list with one inner list per spec, holding `(attr, value)` tuples:
    string attribute names are passed through `map_attr_name`, and string
    values are looked up in `string_store`. Non-string items are kept as-is.
    """
    def _resolve(name, val):
        # Map the attribute name first, then the value, one pair at a time.
        if isinstance(name, basestring):
            name = map_attr_name(name)
        if isinstance(val, basestring):
            val = string_store[val]
        return (name, val)

    return [
        [_resolve(name, val) for name, val in token_spec.items()]
        for token_spec in token_specs
    ]
def map_attr_name(attr):
    """Map an attribute name to the corresponding attribute ID constant.

    The lookup is case-insensitive: `attr` is upper-cased before comparison.
    Supported names are ORTH, LEMMA, LOWER, SHAPE and NORM.

    Raises:
        Exception: if `attr` is not one of the supported names.
    """
    attr = attr.upper()
    if attr == 'ORTH':
        return ORTH
    elif attr == 'LEMMA':
        return LEMMA
    elif attr == 'LOWER':
        return LOWER
    elif attr == 'SHAPE':
        # Was misspelled 'SHAOE', so "shape" specs always hit the error below.
        return SHAPE
    elif attr == 'NORM':
        return NORM
    else:
        raise Exception("TODO: Finish supporting attr mapping %s" % attr)
cdef class Matcher:
cdef Pool mem
cdef Pattern** patterns
cdef readonly int n_patterns
# Build the pattern array for this Matcher from `patterns`.
# NOTE(review): this span contains two `__init__` signatures and two pattern
# loops back to back. It looks like the removed and added lines of a diff
# shown without +/- markers -- confirm which lines are current before editing.
def __init__(self, patterns):
def __init__(self, vocab, patterns):
# Pool used for the pattern array below and handed to init_pattern.
self.mem = Pool()
# One Pattern* slot per entry in `patterns`.
self.patterns = <Pattern**>self.mem.alloc(len(patterns), sizeof(Pattern*))
# This loop treats `patterns` as a sequence of (token_specs, entity_type) pairs.
for i, (token_specs, entity_type) in enumerate(patterns):
self.patterns[i] = init_pattern(self.mem, token_specs, entity_type)
# This loop treats `patterns` as a mapping of
# entity_key -> (etype, attrs, specs), visited in sorted key order.
for i, (entity_key, (etype, attrs, specs)) in enumerate(sorted(patterns.items())):
# String keys and entity types are looked up in vocab.strings.
# NOTE(review): `entity_key` and `attrs` are not used again in this span.
if isinstance(entity_key, basestring):
entity_key = vocab.strings[entity_key]
if isinstance(etype, basestring):
etype = vocab.strings[etype]
# Convert string attribute names/values inside each token spec.
specs = _convert_strings(specs, vocab.strings)
self.patterns[i] = init_pattern(self.mem, specs, etype)
self.n_patterns = len(patterns)
@classmethod
def from_dir(cls, vocab, data_dir):
    """Construct an instance from `<data_dir>/ner/patterns.json`.

    If the patterns file exists, its JSON content is parsed and passed to
    the constructor; otherwise the instance is built with an empty dict.

    Args:
        vocab: Passed through unchanged as the constructor's first argument.
        data_dir: Directory containing the optional `ner/patterns.json`.

    Returns:
        The result of `cls(vocab, patterns)`.
    """
    patterns_loc = path.join(data_dir, 'ner', 'patterns.json')
    if not path.exists(patterns_loc):
        return cls(vocab, {})
    # json.loads() expects a string, not a file object; parse the open file
    # with json.load() and make sure the handle is closed afterwards.
    with open(patterns_loc) as file_:
        patterns = json.load(file_)
    return cls(vocab, patterns)
def __call__(self, Doc doc):
cdef vector[Pattern*] partials
cdef int n_partials = 0