* Refactor context extraction, and start breaking out gold standards into their own functions

This commit is contained in:
Matthew Honnibal 2014-11-09 15:43:07 +11:00
parent 602f993af9
commit f307eb2e36
4 changed files with 73 additions and 112 deletions

View File

@@ -60,9 +60,7 @@ cdef class Slots:
# Diff hunk (context.pxd): removed and added declarations are interleaved.
cdef int N_FIELDS
# removed in this commit — the old two-stage API (fill a Slots struct,
# then flatten it into an array):
cdef hash_t fill_slots(Slots s, int i, Tokens tokens) except 0
cdef int fill_flat(atom_t* context, Slots s) except -1
# added — single-pass replacement that writes the context array directly:
cdef int fill_context(atom_t* context, int i, Tokens tokens) except -1
cpdef Slots FIELD_IDS

View File

@@ -60,111 +60,59 @@ cdef void _number_token(Token t, int* n_fields):
n_fields[0] = i
# Diff hunk: the REMOVED fill_token and its ADDED replacement _fill_token
# are interleaved below (diff markers lost in this rendering).
# fill_token wrote feature values onto a Token struct; _fill_token writes
# the same values straight into the flat atom_t* context array, using the
# members of the ids Token `t` as array offsets (offsets are assigned by
# _number_token via FIELD_IDS — see module init below).
# removed version:
cdef int fill_token(Token t, Lexeme* lex, atom_t pos, atom_t ner):
t.sic = lex.sic
t.cluster = lex.cluster
# norm falls back to shape when prob == 0 or prob < -10 — presumably
# "word too rare / no probability estimate"; confirm against Lexeme docs.
t.norm = lex.norm if (lex.prob != 0 and lex.prob >= -10) else lex.shape
t.shape = lex.shape
t.asciied = lex.asciied
t.prefix = lex.prefix
t.suffix = lex.suffix
t.length = lex.length
# added version (same fallback logic, array-indexed writes):
cdef int _fill_token(atom_t* c, Token t, Lexeme* lex, atom_t pos, atom_t ner):
c[t.sic] = lex.sic
c[t.cluster] = lex.cluster
c[t.norm] = lex.norm if (lex.prob != 0 and lex.prob >= -10) else lex.shape
c[t.shape] = lex.shape
c[t.asciied] = lex.asciied
c[t.prefix] = lex.prefix
c[t.suffix] = lex.suffix
c[t.length] = lex.length
# removed:
t.postype = lex.postype
t.nertype = 0
t.sensetype = 0
# added:
c[t.postype] = lex.postype
c[t.nertype] = 0
c[t.sensetype] = 0
# removed: boolean orthographic flags unpacked from the lexeme flag bitfield
t.is_alpha = lex.flags & (1 << IS_ALPHA)
t.is_digit = lex.flags & (1 << IS_DIGIT)
t.is_lower = lex.flags & (1 << IS_LOWER)
t.is_punct = lex.flags & (1 << IS_PUNCT)
t.is_space = lex.flags & (1 << IS_SPACE)
t.is_title = lex.flags & (1 << IS_TITLE)
t.is_upper = lex.flags & (1 << IS_UPPER)
t.like_url = lex.flags & (1 << LIKE_URL)
t.like_number = lex.flags & (1 << LIKE_NUMBER)
t.oft_lower = lex.flags & (1 << OFT_LOWER)
t.oft_title = lex.flags & (1 << OFT_TITLE)
t.oft_upper = lex.flags & (1 << OFT_UPPER)
# added:
c[t.is_alpha] = lex.flags & (1 << IS_ALPHA)
c[t.is_digit] = lex.flags & (1 << IS_DIGIT)
c[t.is_lower] = lex.flags & (1 << IS_LOWER)
c[t.is_punct] = lex.flags & (1 << IS_PUNCT)
c[t.is_space] = lex.flags & (1 << IS_SPACE)
c[t.is_title] = lex.flags & (1 << IS_TITLE)
c[t.is_upper] = lex.flags & (1 << IS_UPPER)
c[t.like_url] = lex.flags & (1 << LIKE_URL)
c[t.like_number] = lex.flags & (1 << LIKE_NUMBER)
c[t.oft_lower] = lex.flags & (1 << OFT_LOWER)
c[t.oft_title] = lex.flags & (1 << OFT_TITLE)
c[t.oft_upper] = lex.flags & (1 << OFT_UPPER)
# removed: gazetteer-membership flags
t.in_males = lex.flags & (1 << IN_MALES)
t.in_females = lex.flags & (1 << IN_FEMALES)
t.in_surnames = lex.flags & (1 << IN_SURNAMES)
t.in_places = lex.flags & (1 << IN_PLACES)
t.in_games = lex.flags & (1 << IN_GAMES)
t.in_celebs = lex.flags & (1 << IN_CELEBS)
t.in_names = lex.flags & (1 << IN_NAMES)
# added:
c[t.in_males] = lex.flags & (1 << IN_MALES)
c[t.in_females] = lex.flags & (1 << IN_FEMALES)
c[t.in_surnames] = lex.flags & (1 << IN_SURNAMES)
c[t.in_places] = lex.flags & (1 << IN_PLACES)
c[t.in_games] = lex.flags & (1 << IN_GAMES)
c[t.in_celebs] = lex.flags & (1 << IN_CELEBS)
c[t.in_names] = lex.flags & (1 << IN_NAMES)
# removed: tag values supplied by the caller (pos/ner), sense zeroed
t.pos = pos
t.sense = 0
t.ner = ner
# added:
c[t.pos] = pos
c[t.sense] = 0
c[t.ner] = ner
# Removed in this commit: copied a Token of feature VALUES into the flat
# context array at the offsets given by a Token of field IDS. Made
# redundant by _fill_token, which writes the array in one pass.
# NOTE(review): this copies `is_ascii`, which fill_token never sets, and
# omits `is_space`, which fill_token does set — looks like the two
# functions drifted out of sync; confirm against the Token definition.
cdef int _flatten_token(atom_t* context, Token ids, Token vals) except -1:
context[ids.sic] = vals.sic
context[ids.cluster] = vals.cluster
context[ids.norm] = vals.norm
context[ids.shape] = vals.shape
context[ids.asciied] = vals.asciied
context[ids.prefix] = vals.prefix
context[ids.suffix] = vals.suffix
context[ids.length] = vals.length
context[ids.postype] = vals.postype
context[ids.nertype] = vals.nertype
context[ids.sensetype] = vals.sensetype
context[ids.is_alpha] = vals.is_alpha
context[ids.is_ascii] = vals.is_ascii
context[ids.is_digit] = vals.is_digit
context[ids.is_lower] = vals.is_lower
context[ids.is_punct] = vals.is_punct
context[ids.is_title] = vals.is_title
context[ids.is_upper] = vals.is_upper
context[ids.like_url] = vals.like_url
context[ids.like_number] = vals.like_number
context[ids.oft_lower] = vals.oft_lower
context[ids.oft_title] = vals.oft_title
context[ids.oft_upper] = vals.oft_upper
context[ids.in_males] = vals.in_males
context[ids.in_females] = vals.in_females
context[ids.in_surnames] = vals.in_surnames
context[ids.in_places] = vals.in_places
context[ids.in_games] = vals.in_games
context[ids.in_celebs] = vals.in_celebs
context[ids.in_names] = vals.in_names
context[ids.pos] = vals.pos
context[ids.sense] = vals.sense
context[ids.ner] = vals.ner
# Removed in this commit: populated a Slots struct (one Token per window
# position P4..N4 around index i), replaced by fill_context().
# NOTE(review): declared `hash_t ... except 0` but no return statement
# appears in this hunk — the hash computation is presumably elided by the
# diff rendering; an implicit 0 return would trigger the error sentinel.
cdef hash_t fill_slots(Slots s, int i, Tokens tokens) except 0:
fill_token(s.P4, tokens.lex[i-4], tokens.pos[i-4], tokens.ner[i-4])
fill_token(s.P3, tokens.lex[i-3], tokens.pos[i-3], tokens.ner[i-3])
fill_token(s.P2, tokens.lex[i-2], tokens.pos[i-2], tokens.ner[i-2])
fill_token(s.P1, tokens.lex[i-1], tokens.pos[i-1], tokens.ner[i-1])
fill_token(s.N0, tokens.lex[i], tokens.pos[i], tokens.ner[i])
fill_token(s.N1, tokens.lex[i+1], tokens.pos[i+1], tokens.ner[i+1])
fill_token(s.N2, tokens.lex[i+2], tokens.pos[i+2], tokens.ner[i+2])
fill_token(s.N3, tokens.lex[i+3], tokens.pos[i+3], tokens.ner[i+3])
fill_token(s.N4, tokens.lex[i+4], tokens.pos[i+4], tokens.ner[i+4])
# Added in this commit: single-pass replacement for fill_slots+fill_flat.
# Writes the lexical/POS/NER features of the 4-token window either side
# of position i directly into the flat context array; FIELD_IDS is a
# Slots whose Token members hold the array offsets for each window slot.
# NOTE(review): indexes tokens at i-4..i+4 with no bounds check visible
# here — presumably Tokens pads its arrays; confirm at the call sites.
cdef int fill_context(atom_t* context, int i, Tokens tokens) except -1:
_fill_token(context, FIELD_IDS.P4, tokens.lex[i-4], tokens.pos[i-4], tokens.ner[i-4])
_fill_token(context, FIELD_IDS.P3, tokens.lex[i-3], tokens.pos[i-3], tokens.ner[i-3])
_fill_token(context, FIELD_IDS.P2, tokens.lex[i-2], tokens.pos[i-2], tokens.ner[i-2])
_fill_token(context, FIELD_IDS.P1, tokens.lex[i-1], tokens.pos[i-1], tokens.ner[i-1])
_fill_token(context, FIELD_IDS.N0, tokens.lex[i], tokens.pos[i], tokens.ner[i])
_fill_token(context, FIELD_IDS.N1, tokens.lex[i+1], tokens.pos[i+1], tokens.ner[i+1])
_fill_token(context, FIELD_IDS.N2, tokens.lex[i+2], tokens.pos[i+2], tokens.ner[i+2])
_fill_token(context, FIELD_IDS.N3, tokens.lex[i+3], tokens.pos[i+3], tokens.ner[i+3])
_fill_token(context, FIELD_IDS.N4, tokens.lex[i+4], tokens.pos[i+4], tokens.ner[i+4])
return 1
cdef int fill_flat(atom_t* context, Slots s) except -1:
    # Copy every window slot of s (P4..N4) into the flat context array at
    # the offsets recorded in FIELD_IDS. (Removed by this commit in favour
    # of fill_context; indentation restored — it is stripped in this view.)
    _flatten_token(context, FIELD_IDS.P4, s.P4)
    _flatten_token(context, FIELD_IDS.P3, s.P3)
    _flatten_token(context, FIELD_IDS.P2, s.P2)
    _flatten_token(context, FIELD_IDS.P1, s.P1)
    _flatten_token(context, FIELD_IDS.N0, s.N0)
    _flatten_token(context, FIELD_IDS.N1, s.N1)
    _flatten_token(context, FIELD_IDS.N2, s.N2)
    # Bug fix: the original passed `s.N4` here, so the N3 field slots were
    # filled with N4's values and N3's features never saw their own token.
    _flatten_token(context, FIELD_IDS.N3, s.N3)
    _flatten_token(context, FIELD_IDS.N4, s.N4)
    # Explicit success value, matching fill_context's convention.
    return 1
# Module-level initialisation: FIELD_IDS is the Slots of field offsets;
# _number_token assigns each Token member a unique index into the flat
# context array, counting into N_FIELDS.
N_FIELDS = 0
FIELD_IDS = Slots()
# NOTE(review): only the P4 call is visible in this hunk — calls for the
# remaining slots (P3..N4) presumably follow outside this view.
_number_token(FIELD_IDS.P4, &N_FIELDS)

View File

@@ -28,8 +28,7 @@ cdef class Tagger:
# Diff hunk (tagger .pxd, interior of cdef class Tagger): removed and
# added attribute declarations are interleaved below.
cpdef readonly list tag_names
cdef class_t _guess
# removed in this commit — the two-stage context buffers:
cdef atom_t* _context_flat
cdef Slots _context_slots
# added — single flat context array replaces both:
cdef atom_t* _context
cdef feat_t* _feats
cdef weight_t* _values
cdef weight_t* _scores

View File

@@ -11,8 +11,7 @@ import json
import cython
from .context cimport fill_slots
from .context cimport fill_flat
from .context cimport fill_context
from .context cimport N_FIELDS
from thinc.features cimport ConjFeat
@@ -35,19 +34,24 @@ def setup_model_dir(tag_type, tag_names, templates, model_dir):
def train(train_sents, model_dir, nr_iter=10):
# Diff hunk: removed and added training-loop lines are interleaved below.
# Old code iterated the gold tags directly; new code looks the gold up
# per position via the _get_gold_* helpers, which return a list of
# acceptable tags or None.
cdef Tokens tokens
tagger = Tagger(model_dir)
for _ in range(nr_iter):
n_corr = 0
total = 0
for tokens, golds in train_sents:
assert len(tokens) == len(golds), [t.string for t in tokens]
# removed:
for i, gold in enumerate(golds):
# added:
for i in range(tokens.length):
if tagger.tag_type == POS:
gold = _get_gold_pos(i, golds, tokens.pos)
elif tagger.tag_type == ENTITY:
gold = _get_gold_ner(i, golds, tokens.ner)
guess = tagger.predict(i, tokens)
tokens.set_tag(i, tagger.tag_type, guess)
# removed:
if gold != NULL_TAG:
tagger.tell_answer([gold])
# added:
if gold is not None:
tagger.tell_answer(gold)
total += 1
# removed:
n_corr += guess == gold
# added:
# NOTE(review): if this line sits OUTSIDE the `gold is not None`
# guard, `guess in gold` raises TypeError when gold is None —
# indentation is lost in this view, confirm the nesting.
n_corr += guess in gold
#print('%s\t%d\t%d' % (tokens[i].string, guess, gold))
print('%.4f' % ((n_corr / total) * 100))
random.shuffle(train_sents)
@@ -55,6 +59,20 @@ def train(train_sents, model_dir, nr_iter=10):
tagger.model.dump(path.join(model_dir, 'model'))
# Added in this commit: return the acceptable gold POS tags for token i
# as a single-element list, or None when no gold tag is recorded (a gold
# value of 0 is treated as "missing").
# NOTE(review): `pred` is unused here — presumably reserved for deriving
# golds from predicted tags later; confirm before removing.
cdef object _get_gold_pos(i, golds, int* pred):
if golds[i] == 0:
return None
else:
return [golds[i]]
# Added in this commit: NER counterpart of _get_gold_pos — returns the
# acceptable gold NER tags for token i as a list, or None when missing.
# NOTE(review): body is currently identical to _get_gold_pos and the
# `ner` pointer is unused — looks like a stub to be specialised later.
cdef object _get_gold_ner(i, golds, int* ner):
if golds[i] == 0:
return None
else:
return [golds[i]]
def evaluate(tagger, sents):
n_corr = 0
total = 0
@@ -83,8 +101,7 @@ cdef class Tagger:
if path.exists(path.join(model_dir, 'model')):
self.model.load(path.join(model_dir, 'model'))
self._context_flat = <atom_t*>self.mem.alloc(N_FIELDS, sizeof(atom_t))
self._context_slots = Slots()
self._context = <atom_t*>self.mem.alloc(N_FIELDS, sizeof(atom_t))
self._feats = <feat_t*>self.mem.alloc(self.extractor.n+1, sizeof(feat_t))
self._values = <weight_t*>self.mem.alloc(self.extractor.n+1, sizeof(weight_t))
self._scores = <weight_t*>self.mem.alloc(self.model.nr_class, sizeof(weight_t))
@@ -110,9 +127,8 @@ cdef class Tagger:
>>> tag = EN.pos_tagger.predict(0, tokens)
>>> assert tag == EN.pos_tagger.tag_id('DT') == 5
"""
cdef hash_t hashed = fill_slots(self._context_slots, i, tokens)
fill_flat(self._context_flat, self._context_slots)
self.extractor.extract(self._feats, self._values, self._context_flat, NULL)
fill_context(self._context, i, tokens)
self.extractor.extract(self._feats, self._values, self._context, NULL)
self._guess = self.model.score(self._scores, self._feats, self._values)
return self._guess