* Update for thinc 5.0, including changing cost from int to weight_t, and updating the tagger and parser

This commit is contained in:
Matthew Honnibal 2016-01-30 14:31:12 +01:00
parent ea4ff94cde
commit 10877a7791
8 changed files with 89 additions and 71 deletions

View File

@ -12,6 +12,6 @@ cdef class ArcEager(TransitionSystem):
pass
cdef int push_cost(StateClass stcls, const GoldParseC* gold, int target) nogil
cdef int arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) nogil
cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) nogil
cdef weight_t arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) nogil

View File

@ -48,8 +48,8 @@ MOVE_NAMES[BREAK] = 'B'
# Helper functions for the arc-eager oracle
cdef int push_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
cdef int cost = 0
cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
cdef weight_t cost = 0
cdef int i, S_i
for i in range(stcls.stack_depth()):
S_i = stcls.S(i)
@ -61,8 +61,8 @@ cdef int push_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
return cost
cdef int pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
cdef int cost = 0
cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
cdef weight_t cost = 0
cdef int i, B_i
for i in range(stcls.buffer_length()):
B_i = stcls.B(i)
@ -70,11 +70,12 @@ cdef int pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
cost += gold.heads[target] == B_i
if gold.heads[B_i] == B_i or gold.heads[B_i] < target:
break
cost += Break.is_valid(stcls, -1) and Break.move_cost(stcls, gold) == 0
if Break.is_valid(stcls, -1) and Break.move_cost(stcls, gold) == 0:
cost += 1
return cost
cdef int arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) nogil:
cdef weight_t arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) nogil:
if arc_is_gold(gold, head, child):
return 0
elif stcls.H(child) == gold.heads[child]:
@ -123,15 +124,15 @@ cdef class Shift:
st.fast_forward()
@staticmethod
cdef int cost(StateClass st, const GoldParseC* gold, int label) nogil:
cdef weight_t cost(StateClass st, const GoldParseC* gold, int label) nogil:
return Shift.move_cost(st, gold) + Shift.label_cost(st, gold, label)
@staticmethod
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
return push_cost(s, gold, s.B(0))
@staticmethod
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
return 0
@ -149,15 +150,15 @@ cdef class Reduce:
st.fast_forward()
@staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
cdef weight_t cost(StateClass s, const GoldParseC* gold, int label) nogil:
return Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label)
@staticmethod
cdef inline int move_cost(StateClass st, const GoldParseC* gold) nogil:
cdef inline weight_t move_cost(StateClass st, const GoldParseC* gold) nogil:
return pop_cost(st, gold, st.S(0))
@staticmethod
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
return 0
@ -173,12 +174,12 @@ cdef class LeftArc:
st.fast_forward()
@staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
cdef weight_t cost(StateClass s, const GoldParseC* gold, int label) nogil:
return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label)
@staticmethod
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
cdef int cost = 0
cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
cdef weight_t cost = 0
if arc_is_gold(gold, s.B(0), s.S(0)):
return 0
else:
@ -190,7 +191,7 @@ cdef class LeftArc:
return pop_cost(s, gold, s.S(0)) + arc_cost(s, gold, s.B(0), s.S(0))
@staticmethod
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
return arc_is_gold(gold, s.B(0), s.S(0)) and not label_is_gold(gold, s.B(0), s.S(0), label)
@ -206,11 +207,11 @@ cdef class RightArc:
st.fast_forward()
@staticmethod
cdef inline int cost(StateClass s, const GoldParseC* gold, int label) nogil:
cdef inline weight_t cost(StateClass s, const GoldParseC* gold, int label) nogil:
return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label)
@staticmethod
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
if arc_is_gold(gold, s.S(0), s.B(0)):
return 0
elif s.shifted[s.B(0)]:
@ -219,7 +220,7 @@ cdef class RightArc:
return push_cost(s, gold, s.B(0)) + arc_cost(s, gold, s.S(0), s.B(0))
@staticmethod
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
cdef weight_t label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
return arc_is_gold(gold, s.S(0), s.B(0)) and not label_is_gold(gold, s.S(0), s.B(0), label)
@ -247,12 +248,12 @@ cdef class Break:
st.fast_forward()
@staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
cdef weight_t cost(StateClass s, const GoldParseC* gold, int label) nogil:
return Break.move_cost(s, gold) + Break.label_cost(s, gold, label)
@staticmethod
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil:
cdef int cost = 0
cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
cdef weight_t cost = 0
cdef int i, j, S_i, B_i
for i in range(s.stack_depth()):
S_i = s.S(i)
@ -270,7 +271,7 @@ cdef class Break:
return cost + 1
@staticmethod
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
return 0
cdef int _get_root(int word, const GoldParseC* gold) nogil:
@ -404,12 +405,12 @@ cdef class ArcEager(TransitionSystem):
for i in range(self.n_moves):
output[i] = is_valid[self.c[i].move]
cdef int set_costs(self, int* is_valid, int* costs,
cdef int set_costs(self, int* is_valid, weight_t* costs,
StateClass stcls, GoldParse gold) except -1:
cdef int i, move, label
cdef label_cost_func_t[N_MOVES] label_cost_funcs
cdef move_cost_func_t[N_MOVES] move_cost_funcs
cdef int[N_MOVES] move_costs
cdef weight_t[N_MOVES] move_costs
for i in range(N_MOVES):
move_costs[i] = -1
move_cost_funcs[SHIFT] = Shift.move_cost

View File

@ -158,7 +158,7 @@ cdef class Missing:
pass
@staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
cdef weight_t cost(StateClass s, const GoldParseC* gold, int label) nogil:
return 9000
@ -195,7 +195,7 @@ cdef class Begin:
st.pop()
@staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
cdef weight_t cost(StateClass s, const GoldParseC* gold, int label) nogil:
cdef int g_act = gold.ner[s.B(0)].move
cdef int g_tag = gold.ner[s.B(0)].label
@ -236,7 +236,7 @@ cdef class In:
st.pop()
@staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
cdef weight_t cost(StateClass s, const GoldParseC* gold, int label) nogil:
move = IN
cdef int next_act = gold.ner[s.B(1)].move if s.B(0) < s.length else OUT
cdef int g_act = gold.ner[s.B(0)].move
@ -279,7 +279,7 @@ cdef class Last:
st.pop()
@staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
cdef weight_t cost(StateClass s, const GoldParseC* gold, int label) nogil:
move = LAST
cdef int g_act = gold.ner[s.B(0)].move
@ -329,7 +329,7 @@ cdef class Unit:
st.pop()
@staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
cdef weight_t cost(StateClass s, const GoldParseC* gold, int label) nogil:
cdef int g_act = gold.ner[s.B(0)].move
cdef int g_tag = gold.ner[s.B(0)].label
@ -363,7 +363,7 @@ cdef class Out:
st.pop()
@staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
cdef weight_t cost(StateClass s, const GoldParseC* gold, int label) nogil:
cdef int g_act = gold.ner[s.B(0)].move
cdef int g_tag = gold.ner[s.B(0)].label

View File

@ -1,6 +1,6 @@
from thinc.search cimport Beam
from thinc.api cimport AveragedPerceptron
from thinc.api cimport Example, ExampleC
from thinc.linear.avgtron cimport AveragedPerceptron
from thinc.extra.eg cimport Example
from thinc.structs cimport ExampleC
from .stateclass cimport StateClass
from .arc_eager cimport TransitionSystem
@ -9,7 +9,7 @@ from ..structs cimport TokenC
cdef class ParserModel(AveragedPerceptron):
cdef void set_features(self, ExampleC* eg, StateClass stcls) except *
cdef void set_featuresC(self, ExampleC* eg, StateClass stcls) except *
cdef class Parser:

View File

@ -19,7 +19,8 @@ import sys
from cymem.cymem cimport Pool, Address
from murmurhash.mrmr cimport hash64
from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t
from thinc.features cimport ConjunctionExtracter
from thinc.linear.avgtron cimport AveragedPerceptron
from thinc.linalg cimport VecVec
from util import Config
@ -64,7 +65,7 @@ def ParserFactory(transition_system):
cdef class ParserModel(AveragedPerceptron):
cdef void set_features(self, ExampleC* eg, StateClass stcls) except *:
cdef void set_featuresC(self, ExampleC* eg, StateClass stcls) except *:
fill_context(eg.atoms, stcls)
eg.nr_feat = self.extracter.set_features(eg.features, eg.atoms)
@ -83,8 +84,7 @@ cdef class Parser:
cfg = Config.read(model_dir, 'config')
moves = transition_system(strings, cfg.labels)
templates = get_templates(cfg.features)
model = ParserModel(moves.n_moves,
ConjunctionExtracter(CONTEXT_SIZE, templates))
model = ParserModel(templates)
if path.exists(path.join(model_dir, 'model')):
model.load(path.join(model_dir, 'model'))
return cls(strings, moves, model)
@ -104,14 +104,19 @@ cdef class Parser:
self.moves.initialize_state(stcls)
cdef Pool mem = Pool()
cdef ExampleC eg = self.model.allocate(mem)
cdef Example eg = Example(
nr_class=self.moves.n_moves,
nr_atom=CONTEXT_SIZE,
nr_feat=self.model.nr_feat)
while not stcls.is_final():
self.model.set_features(&eg, stcls)
self.moves.set_valid(eg.is_valid, stcls)
self.model.set_prediction(&eg)
self.model.set_featuresC(&eg.c, stcls)
self.moves.set_valid(eg.c.is_valid, stcls)
self.model.set_scoresC(eg.c.scores, eg.c.features, eg.c.nr_feat)
action = self.moves.c[eg.guess]
if not eg.is_valid[eg.guess]:
guess = VecVec.arg_max_if_true(eg.c.scores, eg.c.is_valid, eg.c.nr_class)
action = self.moves.c[guess]
if not eg.is_valid[guess]:
raise ValueError(
"Illegal action: %s" % self.moves.move_name(action.move, action.label)
)
@ -119,6 +124,7 @@ cdef class Parser:
action.do(stcls, action.label)
# Check for KeyboardInterrupt etc. Untested
PyErr_CheckSignals()
eg.reset_classes(eg.nr_class)
self.moves.finalize_state(stcls)
tokens.set_parse(stcls._sent)
@ -127,18 +133,23 @@ cdef class Parser:
cdef StateClass stcls = StateClass.init(tokens.c, tokens.length)
self.moves.initialize_state(stcls)
cdef Pool mem = Pool()
cdef ExampleC eg = self.model.allocate(mem)
cdef Example eg = Example(
nr_class=self.moves.n_moves,
nr_atom=CONTEXT_SIZE,
nr_feat=self.model.nr_feat)
cdef weight_t loss = 0
cdef Transition action
while not stcls.is_final():
self.model.set_features(&eg, stcls)
self.moves.set_costs(eg.is_valid, eg.costs, stcls, gold)
self.model.set_prediction(&eg)
self.model.update(&eg)
self.model.set_featuresC(&eg.c, stcls)
self.moves.set_costs(eg.c.is_valid, eg.c.costs, stcls, gold)
self.model.set_scoresC(eg.c.scores, eg.c.features, eg.c.nr_feat)
self.model.updateC(&eg.c)
guess = VecVec.arg_max_if_true(eg.c.scores, eg.c.is_valid, eg.c.nr_class)
action = self.moves.c[eg.guess]
action.do(stcls, action.label)
loss += eg.costs[eg.guess]
eg.reset_classes(eg.nr_class)
return loss
def step_through(self, Doc doc):
@ -147,11 +158,6 @@ cdef class Parser:
def add_label(self, label):
for action in self.moves.action_types:
self.moves.add_action(action, label)
# This seems pretty dangerous. However, thinc uses sparse vectors for
# classes, so it doesn't need to have the classes pre-specified. Things
# get dicey if people have an Example class around, which is being reused.
self.model.nr_class = self.moves.n_moves
cdef class StepwiseState:
@ -165,8 +171,10 @@ cdef class StepwiseState:
self.doc = doc
self.stcls = StateClass.init(doc.c, doc.length)
self.parser.moves.initialize_state(self.stcls)
self.eg = Example(self.parser.model.nr_class, CONTEXT_SIZE,
self.parser.model.nr_templ, self.parser.model.nr_embed)
self.eg = Example(
nr_class=self.parser.moves.n_moves,
nr_atom=CONTEXT_SIZE,
nr_feat=self.parser.model.nr_feat)
def __enter__(self):
return self
@ -196,11 +204,13 @@ cdef class StepwiseState:
for i in range(self.stcls.length)]
def predict(self):
self.parser.model.set_features(&self.eg.c, self.stcls)
self.eg.reset()
self.parser.model.set_featuresC(&self.eg.c, self.stcls)
self.parser.moves.set_valid(self.eg.c.is_valid, self.stcls)
self.parser.model.set_prediction(&self.eg.c)
self.parser.model.set_scoresC(self.eg.c.scores,
self.eg.c.features, self.eg.c.nr_feat)
action = self.parser.moves.c[self.eg.c.guess]
cdef Transition action = self.parser.moves.c[self.eg.guess]
return self.parser.moves.move_name(action.move, action.label)
def transition(self, action_name):

View File

@ -17,13 +17,13 @@ cdef struct Transition:
weight_t score
bint (*is_valid)(StateClass state, int label) nogil
int (*get_cost)(StateClass state, const GoldParseC* gold, int label) nogil
weight_t (*get_cost)(StateClass state, const GoldParseC* gold, int label) nogil
int (*do)(StateClass state, int label) nogil
ctypedef int (*get_cost_func_t)(StateClass state, const GoldParseC* gold, int label) nogil
ctypedef int (*move_cost_func_t)(StateClass state, const GoldParseC* gold) nogil
ctypedef int (*label_cost_func_t)(StateClass state, const GoldParseC* gold, int label) nogil
ctypedef weight_t (*get_cost_func_t)(StateClass state, const GoldParseC* gold, int label) nogil
ctypedef weight_t (*move_cost_func_t)(StateClass state, const GoldParseC* gold) nogil
ctypedef weight_t (*label_cost_func_t)(StateClass state, const GoldParseC* gold, int label) nogil
ctypedef int (*do_func_t)(StateClass state, int label) nogil
@ -48,5 +48,5 @@ cdef class TransitionSystem:
cdef int set_valid(self, int* output, StateClass state) nogil
cdef int set_costs(self, int* is_valid, int* costs,
cdef int set_costs(self, int* is_valid, weight_t* costs,
StateClass state, GoldParse gold) except -1

View File

@ -71,7 +71,7 @@ cdef class TransitionSystem:
for i in range(self.n_moves):
is_valid[i] = self.c[i].is_valid(stcls, self.c[i].label)
cdef int set_costs(self, int* is_valid, int* costs,
cdef int set_costs(self, int* is_valid, weight_t* costs,
StateClass stcls, GoldParse gold) except -1:
cdef int i
self.set_valid(is_valid, stcls)

View File

@ -72,6 +72,7 @@ cpdef enum:
cdef class TaggerModel(AveragedPerceptron):
cdef void set_featuresC(self, ExampleC* eg, const TokenC* tokens, int i) except *:
_fill_from_token(&eg.atoms[P2_orth], &tokens[i-2])
_fill_from_token(&eg.atoms[P1_orth], &tokens[i-1])
_fill_from_token(&eg.atoms[W_orth], &tokens[i])
@ -198,7 +199,9 @@ cdef class Tagger:
cdef Pool mem = Pool()
cdef int i, tag
cdef Example eg = Example(self.vocab.morphology.n_tags)
cdef Example eg = Example(nr_atom=N_CONTEXT_FIELDS,
nr_class=self.vocab.morphology.n_tags,
nr_feat=self.model.nr_feat)
for i in range(tokens.length):
if tokens.c[i].pos == 0:
self.model.set_featuresC(&eg.c, tokens.c, i)
@ -206,6 +209,7 @@ cdef class Tagger:
eg.c.features, eg.c.nr_feat)
guess = VecVec.arg_max_if_true(eg.c.scores, eg.c.is_valid, eg.c.nr_class)
self.vocab.morphology.assign_tag(&tokens.c[i], guess)
eg.reset_classes(eg.c.nr_class)
tokens.is_tagged = True
tokens._py_tokens = [None] * tokens.length
@ -214,7 +218,10 @@ cdef class Tagger:
golds = [self.tag_names.index(g) if g is not None else -1 for g in gold_tag_strs]
cdef int correct = 0
cdef Pool mem = Pool()
cdef Example eg = Example(self.vocab.morphology.n_tags)
cdef Example eg = Example(
nr_atom=N_CONTEXT_FIELDS,
nr_class=self.vocab.morphology.n_tags,
nr_feat=self.model.nr_feat)
for i in range(tokens.length):
self.model.set_featuresC(&eg.c, tokens.c, i)
eg.set_label(golds[i])
@ -227,7 +234,7 @@ cdef class Tagger:
correct += eg.cost == 0
self.freqs[TAG][tokens.c[i].tag] += 1
eg.wipe(tuple())
eg.reset_classes(eg.c.nr_class)
tokens.is_tagged = True
tokens._py_tokens = [None] * tokens.length
return correct