* Refactor context extraction, and start breaking out gold standards into their own functions

This commit is contained in:
Matthew Honnibal 2014-11-09 15:43:07 +11:00
parent 602f993af9
commit f307eb2e36
4 changed files with 73 additions and 112 deletions

View File

@@ -60,9 +60,7 @@ cdef class Slots:
cdef int N_FIELDS cdef int N_FIELDS
cdef hash_t fill_slots(Slots s, int i, Tokens tokens) except 0 cdef int fill_context(atom_t* context, int i, Tokens tokens) except -1
cdef int fill_flat(atom_t* context, Slots s) except -1
cpdef Slots FIELD_IDS cpdef Slots FIELD_IDS

View File

@@ -60,111 +60,59 @@ cdef void _number_token(Token t, int* n_fields):
n_fields[0] = i n_fields[0] = i
cdef int fill_token(Token t, Lexeme* lex, atom_t pos, atom_t ner): cdef int _fill_token(atom_t* c, Token t, Lexeme* lex, atom_t pos, atom_t ner):
t.sic = lex.sic c[t.sic] = lex.sic
t.cluster = lex.cluster c[t.cluster] = lex.cluster
t.norm = lex.norm if (lex.prob != 0 and lex.prob >= -10) else lex.shape c[t.norm] = lex.norm if (lex.prob != 0 and lex.prob >= -10) else lex.shape
t.shape = lex.shape c[t.shape] = lex.shape
t.asciied = lex.asciied c[t.asciied] = lex.asciied
t.prefix = lex.prefix c[t.prefix] = lex.prefix
t.suffix = lex.suffix c[t.suffix] = lex.suffix
t.length = lex.length c[t.length] = lex.length
t.postype = lex.postype c[t.postype] = lex.postype
t.nertype = 0 c[t.nertype] = 0
t.sensetype = 0 c[t.sensetype] = 0
t.is_alpha = lex.flags & (1 << IS_ALPHA) c[t.is_alpha] = lex.flags & (1 << IS_ALPHA)
t.is_digit = lex.flags & (1 << IS_DIGIT) c[t.is_digit] = lex.flags & (1 << IS_DIGIT)
t.is_lower = lex.flags & (1 << IS_LOWER) c[t.is_lower] = lex.flags & (1 << IS_LOWER)
t.is_punct = lex.flags & (1 << IS_PUNCT) c[t.is_punct] = lex.flags & (1 << IS_PUNCT)
t.is_space = lex.flags & (1 << IS_SPACE) c[t.is_space] = lex.flags & (1 << IS_SPACE)
t.is_title = lex.flags & (1 << IS_TITLE) c[t.is_title] = lex.flags & (1 << IS_TITLE)
t.is_upper = lex.flags & (1 << IS_UPPER) c[t.is_upper] = lex.flags & (1 << IS_UPPER)
t.like_url = lex.flags & (1 << LIKE_URL) c[t.like_url] = lex.flags & (1 << LIKE_URL)
t.like_number = lex.flags & (1 << LIKE_NUMBER) c[t.like_number] = lex.flags & (1 << LIKE_NUMBER)
t.oft_lower = lex.flags & (1 << OFT_LOWER) c[t.oft_lower] = lex.flags & (1 << OFT_LOWER)
t.oft_title = lex.flags & (1 << OFT_TITLE) c[t.oft_title] = lex.flags & (1 << OFT_TITLE)
t.oft_upper = lex.flags & (1 << OFT_UPPER) c[t.oft_upper] = lex.flags & (1 << OFT_UPPER)
t.in_males = lex.flags & (1 << IN_MALES) c[t.in_males] = lex.flags & (1 << IN_MALES)
t.in_females = lex.flags & (1 << IN_FEMALES) c[t.in_females] = lex.flags & (1 << IN_FEMALES)
t.in_surnames = lex.flags & (1 << IN_SURNAMES) c[t.in_surnames] = lex.flags & (1 << IN_SURNAMES)
t.in_places = lex.flags & (1 << IN_PLACES) c[t.in_places] = lex.flags & (1 << IN_PLACES)
t.in_games = lex.flags & (1 << IN_GAMES) c[t.in_games] = lex.flags & (1 << IN_GAMES)
t.in_celebs = lex.flags & (1 << IN_CELEBS) c[t.in_celebs] = lex.flags & (1 << IN_CELEBS)
t.in_names = lex.flags & (1 << IN_NAMES) c[t.in_names] = lex.flags & (1 << IN_NAMES)
t.pos = pos c[t.pos] = pos
t.sense = 0 c[t.sense] = 0
t.ner = ner c[t.ner] = ner
cdef int _flatten_token(atom_t* context, Token ids, Token vals) except -1: cdef int fill_context(atom_t* context, int i, Tokens tokens) except -1:
context[ids.sic] = vals.sic _fill_token(context, FIELD_IDS.P4, tokens.lex[i-4], tokens.pos[i-4], tokens.ner[i-4])
context[ids.cluster] = vals.cluster _fill_token(context, FIELD_IDS.P3, tokens.lex[i-3], tokens.pos[i-3], tokens.ner[i-3])
context[ids.norm] = vals.norm _fill_token(context, FIELD_IDS.P2, tokens.lex[i-2], tokens.pos[i-2], tokens.ner[i-2])
context[ids.shape] = vals.shape _fill_token(context, FIELD_IDS.P1, tokens.lex[i-1], tokens.pos[i-1], tokens.ner[i-1])
context[ids.asciied] = vals.asciied _fill_token(context, FIELD_IDS.N0, tokens.lex[i], tokens.pos[i], tokens.ner[i])
context[ids.prefix] = vals.prefix _fill_token(context, FIELD_IDS.N1, tokens.lex[i+1], tokens.pos[i+1], tokens.ner[i+1])
context[ids.suffix] = vals.suffix _fill_token(context, FIELD_IDS.N2, tokens.lex[i+2], tokens.pos[i+2], tokens.ner[i+2])
context[ids.length] = vals.length _fill_token(context, FIELD_IDS.N3, tokens.lex[i+3], tokens.pos[i+3], tokens.ner[i+3])
_fill_token(context, FIELD_IDS.N4, tokens.lex[i+4], tokens.pos[i+4], tokens.ner[i+4])
context[ids.postype] = vals.postype
context[ids.nertype] = vals.nertype
context[ids.sensetype] = vals.sensetype
context[ids.is_alpha] = vals.is_alpha
context[ids.is_ascii] = vals.is_ascii
context[ids.is_digit] = vals.is_digit
context[ids.is_lower] = vals.is_lower
context[ids.is_punct] = vals.is_punct
context[ids.is_title] = vals.is_title
context[ids.is_upper] = vals.is_upper
context[ids.like_url] = vals.like_url
context[ids.like_number] = vals.like_number
context[ids.oft_lower] = vals.oft_lower
context[ids.oft_title] = vals.oft_title
context[ids.oft_upper] = vals.oft_upper
context[ids.in_males] = vals.in_males
context[ids.in_females] = vals.in_females
context[ids.in_surnames] = vals.in_surnames
context[ids.in_places] = vals.in_places
context[ids.in_games] = vals.in_games
context[ids.in_celebs] = vals.in_celebs
context[ids.in_names] = vals.in_names
context[ids.pos] = vals.pos
context[ids.sense] = vals.sense
context[ids.ner] = vals.ner
cdef hash_t fill_slots(Slots s, int i, Tokens tokens) except 0:
fill_token(s.P4, tokens.lex[i-4], tokens.pos[i-4], tokens.ner[i-4])
fill_token(s.P3, tokens.lex[i-3], tokens.pos[i-3], tokens.ner[i-3])
fill_token(s.P2, tokens.lex[i-2], tokens.pos[i-2], tokens.ner[i-2])
fill_token(s.P1, tokens.lex[i-1], tokens.pos[i-1], tokens.ner[i-1])
fill_token(s.N0, tokens.lex[i], tokens.pos[i], tokens.ner[i])
fill_token(s.N1, tokens.lex[i+1], tokens.pos[i+1], tokens.ner[i+1])
fill_token(s.N2, tokens.lex[i+2], tokens.pos[i+2], tokens.ner[i+2])
fill_token(s.N3, tokens.lex[i+3], tokens.pos[i+3], tokens.ner[i+3])
fill_token(s.N4, tokens.lex[i+4], tokens.pos[i+4], tokens.ner[i+4])
return 1 return 1
cdef int fill_flat(atom_t* context, Slots s) except -1:
_flatten_token(context, FIELD_IDS.P4, s.P4)
_flatten_token(context, FIELD_IDS.P3, s.P3)
_flatten_token(context, FIELD_IDS.P2, s.P2)
_flatten_token(context, FIELD_IDS.P1, s.P1)
_flatten_token(context, FIELD_IDS.N0, s.N0)
_flatten_token(context, FIELD_IDS.N1, s.N1)
_flatten_token(context, FIELD_IDS.N2, s.N2)
_flatten_token(context, FIELD_IDS.N3, s.N4)
_flatten_token(context, FIELD_IDS.N4, s.N4)
N_FIELDS = 0 N_FIELDS = 0
FIELD_IDS = Slots() FIELD_IDS = Slots()
_number_token(FIELD_IDS.P4, &N_FIELDS) _number_token(FIELD_IDS.P4, &N_FIELDS)

View File

@@ -28,8 +28,7 @@ cdef class Tagger:
cpdef readonly list tag_names cpdef readonly list tag_names
cdef class_t _guess cdef class_t _guess
cdef atom_t* _context_flat cdef atom_t* _context
cdef Slots _context_slots
cdef feat_t* _feats cdef feat_t* _feats
cdef weight_t* _values cdef weight_t* _values
cdef weight_t* _scores cdef weight_t* _scores

View File

@@ -11,8 +11,7 @@ import json
import cython import cython
from .context cimport fill_slots from .context cimport fill_context
from .context cimport fill_flat
from .context cimport N_FIELDS from .context cimport N_FIELDS
from thinc.features cimport ConjFeat from thinc.features cimport ConjFeat
@@ -35,19 +34,24 @@ def setup_model_dir(tag_type, tag_names, templates, model_dir):
def train(train_sents, model_dir, nr_iter=10): def train(train_sents, model_dir, nr_iter=10):
cdef Tokens tokens
tagger = Tagger(model_dir) tagger = Tagger(model_dir)
for _ in range(nr_iter): for _ in range(nr_iter):
n_corr = 0 n_corr = 0
total = 0 total = 0
for tokens, golds in train_sents: for tokens, golds in train_sents:
assert len(tokens) == len(golds), [t.string for t in tokens] assert len(tokens) == len(golds), [t.string for t in tokens]
for i, gold in enumerate(golds): for i in range(tokens.length):
if tagger.tag_type == POS:
gold = _get_gold_pos(i, golds, tokens.pos)
elif tagger.tag_type == ENTITY:
gold = _get_gold_ner(i, golds, tokens.ner)
guess = tagger.predict(i, tokens) guess = tagger.predict(i, tokens)
tokens.set_tag(i, tagger.tag_type, guess) tokens.set_tag(i, tagger.tag_type, guess)
if gold != NULL_TAG: if gold is not None:
tagger.tell_answer([gold]) tagger.tell_answer(gold)
total += 1 total += 1
n_corr += guess == gold n_corr += guess in gold
#print('%s\t%d\t%d' % (tokens[i].string, guess, gold)) #print('%s\t%d\t%d' % (tokens[i].string, guess, gold))
print('%.4f' % ((n_corr / total) * 100)) print('%.4f' % ((n_corr / total) * 100))
random.shuffle(train_sents) random.shuffle(train_sents)
@@ -55,6 +59,20 @@ def train(train_sents, model_dir, nr_iter=10):
tagger.model.dump(path.join(model_dir, 'model')) tagger.model.dump(path.join(model_dir, 'model'))
cdef object _get_gold_pos(i, golds, int* pred):
if golds[i] == 0:
return None
else:
return [golds[i]]
cdef object _get_gold_ner(i, golds, int* ner):
if golds[i] == 0:
return None
else:
return [golds[i]]
def evaluate(tagger, sents): def evaluate(tagger, sents):
n_corr = 0 n_corr = 0
total = 0 total = 0
@@ -83,8 +101,7 @@ cdef class Tagger:
if path.exists(path.join(model_dir, 'model')): if path.exists(path.join(model_dir, 'model')):
self.model.load(path.join(model_dir, 'model')) self.model.load(path.join(model_dir, 'model'))
self._context_flat = <atom_t*>self.mem.alloc(N_FIELDS, sizeof(atom_t)) self._context = <atom_t*>self.mem.alloc(N_FIELDS, sizeof(atom_t))
self._context_slots = Slots()
self._feats = <feat_t*>self.mem.alloc(self.extractor.n+1, sizeof(feat_t)) self._feats = <feat_t*>self.mem.alloc(self.extractor.n+1, sizeof(feat_t))
self._values = <weight_t*>self.mem.alloc(self.extractor.n+1, sizeof(weight_t)) self._values = <weight_t*>self.mem.alloc(self.extractor.n+1, sizeof(weight_t))
self._scores = <weight_t*>self.mem.alloc(self.model.nr_class, sizeof(weight_t)) self._scores = <weight_t*>self.mem.alloc(self.model.nr_class, sizeof(weight_t))
@@ -110,9 +127,8 @@ cdef class Tagger:
>>> tag = EN.pos_tagger.predict(0, tokens) >>> tag = EN.pos_tagger.predict(0, tokens)
>>> assert tag == EN.pos_tagger.tag_id('DT') == 5 >>> assert tag == EN.pos_tagger.tag_id('DT') == 5
""" """
cdef hash_t hashed = fill_slots(self._context_slots, i, tokens) fill_context(self._context, i, tokens)
fill_flat(self._context_flat, self._context_slots) self.extractor.extract(self._feats, self._values, self._context, NULL)
self.extractor.extract(self._feats, self._values, self._context_flat, NULL)
self._guess = self.model.score(self._scores, self._feats, self._values) self._guess = self.model.score(self._scores, self._feats, self._values)
return self._guess return self._guess