* Update for thinc 5.0, including changing cost from int to weight_t, and updating the tagger and parser

This commit is contained in:
Matthew Honnibal 2016-01-30 14:31:12 +01:00
parent ea4ff94cde
commit 10877a7791
8 changed files with 89 additions and 71 deletions

View File

@ -12,6 +12,6 @@ cdef class ArcEager(TransitionSystem):
pass pass
cdef int push_cost(StateClass stcls, const GoldParseC* gold, int target) nogil cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) nogil
cdef int arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) nogil cdef weight_t arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) nogil

View File

@ -48,8 +48,8 @@ MOVE_NAMES[BREAK] = 'B'
# Helper functions for the arc-eager oracle # Helper functions for the arc-eager oracle
cdef int push_cost(StateClass stcls, const GoldParseC* gold, int target) nogil: cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
cdef int cost = 0 cdef weight_t cost = 0
cdef int i, S_i cdef int i, S_i
for i in range(stcls.stack_depth()): for i in range(stcls.stack_depth()):
S_i = stcls.S(i) S_i = stcls.S(i)
@ -61,8 +61,8 @@ cdef int push_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
return cost return cost
cdef int pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil: cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
cdef int cost = 0 cdef weight_t cost = 0
cdef int i, B_i cdef int i, B_i
for i in range(stcls.buffer_length()): for i in range(stcls.buffer_length()):
B_i = stcls.B(i) B_i = stcls.B(i)
@ -70,11 +70,12 @@ cdef int pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil:
cost += gold.heads[target] == B_i cost += gold.heads[target] == B_i
if gold.heads[B_i] == B_i or gold.heads[B_i] < target: if gold.heads[B_i] == B_i or gold.heads[B_i] < target:
break break
cost += Break.is_valid(stcls, -1) and Break.move_cost(stcls, gold) == 0 if Break.is_valid(stcls, -1) and Break.move_cost(stcls, gold) == 0:
cost += 1
return cost return cost
cdef int arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) nogil: cdef weight_t arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) nogil:
if arc_is_gold(gold, head, child): if arc_is_gold(gold, head, child):
return 0 return 0
elif stcls.H(child) == gold.heads[child]: elif stcls.H(child) == gold.heads[child]:
@ -123,15 +124,15 @@ cdef class Shift:
st.fast_forward() st.fast_forward()
@staticmethod @staticmethod
cdef int cost(StateClass st, const GoldParseC* gold, int label) nogil: cdef weight_t cost(StateClass st, const GoldParseC* gold, int label) nogil:
return Shift.move_cost(st, gold) + Shift.label_cost(st, gold, label) return Shift.move_cost(st, gold) + Shift.label_cost(st, gold, label)
@staticmethod @staticmethod
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil: cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
return push_cost(s, gold, s.B(0)) return push_cost(s, gold, s.B(0))
@staticmethod @staticmethod
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil: cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
return 0 return 0
@ -149,15 +150,15 @@ cdef class Reduce:
st.fast_forward() st.fast_forward()
@staticmethod @staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil: cdef weight_t cost(StateClass s, const GoldParseC* gold, int label) nogil:
return Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label) return Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label)
@staticmethod @staticmethod
cdef inline int move_cost(StateClass st, const GoldParseC* gold) nogil: cdef inline weight_t move_cost(StateClass st, const GoldParseC* gold) nogil:
return pop_cost(st, gold, st.S(0)) return pop_cost(st, gold, st.S(0))
@staticmethod @staticmethod
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil: cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
return 0 return 0
@ -173,12 +174,12 @@ cdef class LeftArc:
st.fast_forward() st.fast_forward()
@staticmethod @staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil: cdef weight_t cost(StateClass s, const GoldParseC* gold, int label) nogil:
return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label) return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label)
@staticmethod @staticmethod
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil: cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
cdef int cost = 0 cdef weight_t cost = 0
if arc_is_gold(gold, s.B(0), s.S(0)): if arc_is_gold(gold, s.B(0), s.S(0)):
return 0 return 0
else: else:
@ -190,7 +191,7 @@ cdef class LeftArc:
return pop_cost(s, gold, s.S(0)) + arc_cost(s, gold, s.B(0), s.S(0)) return pop_cost(s, gold, s.S(0)) + arc_cost(s, gold, s.B(0), s.S(0))
@staticmethod @staticmethod
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil: cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
return arc_is_gold(gold, s.B(0), s.S(0)) and not label_is_gold(gold, s.B(0), s.S(0), label) return arc_is_gold(gold, s.B(0), s.S(0)) and not label_is_gold(gold, s.B(0), s.S(0), label)
@ -206,11 +207,11 @@ cdef class RightArc:
st.fast_forward() st.fast_forward()
@staticmethod @staticmethod
cdef inline int cost(StateClass s, const GoldParseC* gold, int label) nogil: cdef inline weight_t cost(StateClass s, const GoldParseC* gold, int label) nogil:
return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label) return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label)
@staticmethod @staticmethod
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil: cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
if arc_is_gold(gold, s.S(0), s.B(0)): if arc_is_gold(gold, s.S(0), s.B(0)):
return 0 return 0
elif s.shifted[s.B(0)]: elif s.shifted[s.B(0)]:
@ -219,7 +220,7 @@ cdef class RightArc:
return push_cost(s, gold, s.B(0)) + arc_cost(s, gold, s.S(0), s.B(0)) return push_cost(s, gold, s.B(0)) + arc_cost(s, gold, s.S(0), s.B(0))
@staticmethod @staticmethod
cdef int label_cost(StateClass s, const GoldParseC* gold, int label) nogil: cdef weight_t label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
return arc_is_gold(gold, s.S(0), s.B(0)) and not label_is_gold(gold, s.S(0), s.B(0), label) return arc_is_gold(gold, s.S(0), s.B(0)) and not label_is_gold(gold, s.S(0), s.B(0), label)
@ -247,12 +248,12 @@ cdef class Break:
st.fast_forward() st.fast_forward()
@staticmethod @staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil: cdef weight_t cost(StateClass s, const GoldParseC* gold, int label) nogil:
return Break.move_cost(s, gold) + Break.label_cost(s, gold, label) return Break.move_cost(s, gold) + Break.label_cost(s, gold, label)
@staticmethod @staticmethod
cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil: cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
cdef int cost = 0 cdef weight_t cost = 0
cdef int i, j, S_i, B_i cdef int i, j, S_i, B_i
for i in range(s.stack_depth()): for i in range(s.stack_depth()):
S_i = s.S(i) S_i = s.S(i)
@ -270,7 +271,7 @@ cdef class Break:
return cost + 1 return cost + 1
@staticmethod @staticmethod
cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil: cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, int label) nogil:
return 0 return 0
cdef int _get_root(int word, const GoldParseC* gold) nogil: cdef int _get_root(int word, const GoldParseC* gold) nogil:
@ -404,12 +405,12 @@ cdef class ArcEager(TransitionSystem):
for i in range(self.n_moves): for i in range(self.n_moves):
output[i] = is_valid[self.c[i].move] output[i] = is_valid[self.c[i].move]
cdef int set_costs(self, int* is_valid, int* costs, cdef int set_costs(self, int* is_valid, weight_t* costs,
StateClass stcls, GoldParse gold) except -1: StateClass stcls, GoldParse gold) except -1:
cdef int i, move, label cdef int i, move, label
cdef label_cost_func_t[N_MOVES] label_cost_funcs cdef label_cost_func_t[N_MOVES] label_cost_funcs
cdef move_cost_func_t[N_MOVES] move_cost_funcs cdef move_cost_func_t[N_MOVES] move_cost_funcs
cdef int[N_MOVES] move_costs cdef weight_t[N_MOVES] move_costs
for i in range(N_MOVES): for i in range(N_MOVES):
move_costs[i] = -1 move_costs[i] = -1
move_cost_funcs[SHIFT] = Shift.move_cost move_cost_funcs[SHIFT] = Shift.move_cost

View File

@ -158,7 +158,7 @@ cdef class Missing:
pass pass
@staticmethod @staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil: cdef weight_t cost(StateClass s, const GoldParseC* gold, int label) nogil:
return 9000 return 9000
@ -195,7 +195,7 @@ cdef class Begin:
st.pop() st.pop()
@staticmethod @staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil: cdef weight_t cost(StateClass s, const GoldParseC* gold, int label) nogil:
cdef int g_act = gold.ner[s.B(0)].move cdef int g_act = gold.ner[s.B(0)].move
cdef int g_tag = gold.ner[s.B(0)].label cdef int g_tag = gold.ner[s.B(0)].label
@ -236,7 +236,7 @@ cdef class In:
st.pop() st.pop()
@staticmethod @staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil: cdef weight_t cost(StateClass s, const GoldParseC* gold, int label) nogil:
move = IN move = IN
cdef int next_act = gold.ner[s.B(1)].move if s.B(0) < s.length else OUT cdef int next_act = gold.ner[s.B(1)].move if s.B(0) < s.length else OUT
cdef int g_act = gold.ner[s.B(0)].move cdef int g_act = gold.ner[s.B(0)].move
@ -279,7 +279,7 @@ cdef class Last:
st.pop() st.pop()
@staticmethod @staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil: cdef weight_t cost(StateClass s, const GoldParseC* gold, int label) nogil:
move = LAST move = LAST
cdef int g_act = gold.ner[s.B(0)].move cdef int g_act = gold.ner[s.B(0)].move
@ -329,7 +329,7 @@ cdef class Unit:
st.pop() st.pop()
@staticmethod @staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil: cdef weight_t cost(StateClass s, const GoldParseC* gold, int label) nogil:
cdef int g_act = gold.ner[s.B(0)].move cdef int g_act = gold.ner[s.B(0)].move
cdef int g_tag = gold.ner[s.B(0)].label cdef int g_tag = gold.ner[s.B(0)].label
@ -363,7 +363,7 @@ cdef class Out:
st.pop() st.pop()
@staticmethod @staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil: cdef weight_t cost(StateClass s, const GoldParseC* gold, int label) nogil:
cdef int g_act = gold.ner[s.B(0)].move cdef int g_act = gold.ner[s.B(0)].move
cdef int g_tag = gold.ner[s.B(0)].label cdef int g_tag = gold.ner[s.B(0)].label

View File

@ -1,6 +1,6 @@
from thinc.search cimport Beam from thinc.linear.avgtron cimport AveragedPerceptron
from thinc.api cimport AveragedPerceptron from thinc.extra.eg cimport Example
from thinc.api cimport Example, ExampleC from thinc.structs cimport ExampleC
from .stateclass cimport StateClass from .stateclass cimport StateClass
from .arc_eager cimport TransitionSystem from .arc_eager cimport TransitionSystem
@ -9,7 +9,7 @@ from ..structs cimport TokenC
cdef class ParserModel(AveragedPerceptron): cdef class ParserModel(AveragedPerceptron):
cdef void set_features(self, ExampleC* eg, StateClass stcls) except * cdef void set_featuresC(self, ExampleC* eg, StateClass stcls) except *
cdef class Parser: cdef class Parser:

View File

@ -19,7 +19,8 @@ import sys
from cymem.cymem cimport Pool, Address from cymem.cymem cimport Pool, Address
from murmurhash.mrmr cimport hash64 from murmurhash.mrmr cimport hash64
from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t
from thinc.features cimport ConjunctionExtracter from thinc.linear.avgtron cimport AveragedPerceptron
from thinc.linalg cimport VecVec
from util import Config from util import Config
@ -64,7 +65,7 @@ def ParserFactory(transition_system):
cdef class ParserModel(AveragedPerceptron): cdef class ParserModel(AveragedPerceptron):
cdef void set_features(self, ExampleC* eg, StateClass stcls) except *: cdef void set_featuresC(self, ExampleC* eg, StateClass stcls) except *:
fill_context(eg.atoms, stcls) fill_context(eg.atoms, stcls)
eg.nr_feat = self.extracter.set_features(eg.features, eg.atoms) eg.nr_feat = self.extracter.set_features(eg.features, eg.atoms)
@ -83,8 +84,7 @@ cdef class Parser:
cfg = Config.read(model_dir, 'config') cfg = Config.read(model_dir, 'config')
moves = transition_system(strings, cfg.labels) moves = transition_system(strings, cfg.labels)
templates = get_templates(cfg.features) templates = get_templates(cfg.features)
model = ParserModel(moves.n_moves, model = ParserModel(templates)
ConjunctionExtracter(CONTEXT_SIZE, templates))
if path.exists(path.join(model_dir, 'model')): if path.exists(path.join(model_dir, 'model')):
model.load(path.join(model_dir, 'model')) model.load(path.join(model_dir, 'model'))
return cls(strings, moves, model) return cls(strings, moves, model)
@ -104,14 +104,19 @@ cdef class Parser:
self.moves.initialize_state(stcls) self.moves.initialize_state(stcls)
cdef Pool mem = Pool() cdef Pool mem = Pool()
cdef ExampleC eg = self.model.allocate(mem) cdef Example eg = Example(
nr_class=self.moves.n_moves,
nr_atom=CONTEXT_SIZE,
nr_feat=self.model.nr_feat)
while not stcls.is_final(): while not stcls.is_final():
self.model.set_features(&eg, stcls) self.model.set_featuresC(&eg.c, stcls)
self.moves.set_valid(eg.is_valid, stcls) self.moves.set_valid(eg.c.is_valid, stcls)
self.model.set_prediction(&eg) self.model.set_scoresC(eg.c.scores, eg.c.features, eg.c.nr_feat)
action = self.moves.c[eg.guess] guess = VecVec.arg_max_if_true(eg.c.scores, eg.c.is_valid, eg.c.nr_class)
if not eg.is_valid[eg.guess]:
action = self.moves.c[guess]
if not eg.is_valid[guess]:
raise ValueError( raise ValueError(
"Illegal action: %s" % self.moves.move_name(action.move, action.label) "Illegal action: %s" % self.moves.move_name(action.move, action.label)
) )
@ -119,6 +124,7 @@ cdef class Parser:
action.do(stcls, action.label) action.do(stcls, action.label)
# Check for KeyboardInterrupt etc. Untested # Check for KeyboardInterrupt etc. Untested
PyErr_CheckSignals() PyErr_CheckSignals()
eg.reset_classes(eg.nr_class)
self.moves.finalize_state(stcls) self.moves.finalize_state(stcls)
tokens.set_parse(stcls._sent) tokens.set_parse(stcls._sent)
@ -127,18 +133,23 @@ cdef class Parser:
cdef StateClass stcls = StateClass.init(tokens.c, tokens.length) cdef StateClass stcls = StateClass.init(tokens.c, tokens.length)
self.moves.initialize_state(stcls) self.moves.initialize_state(stcls)
cdef Pool mem = Pool() cdef Pool mem = Pool()
cdef ExampleC eg = self.model.allocate(mem) cdef Example eg = Example(
nr_class=self.moves.n_moves,
nr_atom=CONTEXT_SIZE,
nr_feat=self.model.nr_feat)
cdef weight_t loss = 0 cdef weight_t loss = 0
cdef Transition action cdef Transition action
while not stcls.is_final(): while not stcls.is_final():
self.model.set_features(&eg, stcls) self.model.set_featuresC(&eg.c, stcls)
self.moves.set_costs(eg.is_valid, eg.costs, stcls, gold) self.moves.set_costs(eg.c.is_valid, eg.c.costs, stcls, gold)
self.model.set_prediction(&eg) self.model.set_scoresC(eg.c.scores, eg.c.features, eg.c.nr_feat)
self.model.update(&eg) self.model.updateC(&eg.c)
guess = VecVec.arg_max_if_true(eg.c.scores, eg.c.is_valid, eg.c.nr_class)
action = self.moves.c[eg.guess] action = self.moves.c[eg.guess]
action.do(stcls, action.label) action.do(stcls, action.label)
loss += eg.costs[eg.guess] loss += eg.costs[eg.guess]
eg.reset_classes(eg.nr_class)
return loss return loss
def step_through(self, Doc doc): def step_through(self, Doc doc):
@ -147,11 +158,6 @@ cdef class Parser:
def add_label(self, label): def add_label(self, label):
for action in self.moves.action_types: for action in self.moves.action_types:
self.moves.add_action(action, label) self.moves.add_action(action, label)
# This seems pretty dangerous. However, thinc uses sparse vectors for
# classes, so it doesn't need to have the classes pre-specified. Things
# get dicey if people have an Example class around, which is being reused.
self.model.nr_class = self.moves.n_moves
cdef class StepwiseState: cdef class StepwiseState:
@ -165,8 +171,10 @@ cdef class StepwiseState:
self.doc = doc self.doc = doc
self.stcls = StateClass.init(doc.c, doc.length) self.stcls = StateClass.init(doc.c, doc.length)
self.parser.moves.initialize_state(self.stcls) self.parser.moves.initialize_state(self.stcls)
self.eg = Example(self.parser.model.nr_class, CONTEXT_SIZE, self.eg = Example(
self.parser.model.nr_templ, self.parser.model.nr_embed) nr_class=self.parser.moves.n_moves,
nr_atom=CONTEXT_SIZE,
nr_feat=self.parser.model.nr_feat)
def __enter__(self): def __enter__(self):
return self return self
@ -196,11 +204,13 @@ cdef class StepwiseState:
for i in range(self.stcls.length)] for i in range(self.stcls.length)]
def predict(self): def predict(self):
self.parser.model.set_features(&self.eg.c, self.stcls) self.eg.reset()
self.parser.model.set_featuresC(&self.eg.c, self.stcls)
self.parser.moves.set_valid(self.eg.c.is_valid, self.stcls) self.parser.moves.set_valid(self.eg.c.is_valid, self.stcls)
self.parser.model.set_prediction(&self.eg.c) self.parser.model.set_scoresC(self.eg.c.scores,
self.eg.c.features, self.eg.c.nr_feat)
action = self.parser.moves.c[self.eg.c.guess] cdef Transition action = self.parser.moves.c[self.eg.guess]
return self.parser.moves.move_name(action.move, action.label) return self.parser.moves.move_name(action.move, action.label)
def transition(self, action_name): def transition(self, action_name):

View File

@ -17,13 +17,13 @@ cdef struct Transition:
weight_t score weight_t score
bint (*is_valid)(StateClass state, int label) nogil bint (*is_valid)(StateClass state, int label) nogil
int (*get_cost)(StateClass state, const GoldParseC* gold, int label) nogil weight_t (*get_cost)(StateClass state, const GoldParseC* gold, int label) nogil
int (*do)(StateClass state, int label) nogil int (*do)(StateClass state, int label) nogil
ctypedef int (*get_cost_func_t)(StateClass state, const GoldParseC* gold, int label) nogil ctypedef weight_t (*get_cost_func_t)(StateClass state, const GoldParseC* gold, int label) nogil
ctypedef int (*move_cost_func_t)(StateClass state, const GoldParseC* gold) nogil ctypedef weight_t (*move_cost_func_t)(StateClass state, const GoldParseC* gold) nogil
ctypedef int (*label_cost_func_t)(StateClass state, const GoldParseC* gold, int label) nogil ctypedef weight_t (*label_cost_func_t)(StateClass state, const GoldParseC* gold, int label) nogil
ctypedef int (*do_func_t)(StateClass state, int label) nogil ctypedef int (*do_func_t)(StateClass state, int label) nogil
@ -48,5 +48,5 @@ cdef class TransitionSystem:
cdef int set_valid(self, int* output, StateClass state) nogil cdef int set_valid(self, int* output, StateClass state) nogil
cdef int set_costs(self, int* is_valid, int* costs, cdef int set_costs(self, int* is_valid, weight_t* costs,
StateClass state, GoldParse gold) except -1 StateClass state, GoldParse gold) except -1

View File

@ -71,7 +71,7 @@ cdef class TransitionSystem:
for i in range(self.n_moves): for i in range(self.n_moves):
is_valid[i] = self.c[i].is_valid(stcls, self.c[i].label) is_valid[i] = self.c[i].is_valid(stcls, self.c[i].label)
cdef int set_costs(self, int* is_valid, int* costs, cdef int set_costs(self, int* is_valid, weight_t* costs,
StateClass stcls, GoldParse gold) except -1: StateClass stcls, GoldParse gold) except -1:
cdef int i cdef int i
self.set_valid(is_valid, stcls) self.set_valid(is_valid, stcls)

View File

@ -72,6 +72,7 @@ cpdef enum:
cdef class TaggerModel(AveragedPerceptron): cdef class TaggerModel(AveragedPerceptron):
cdef void set_featuresC(self, ExampleC* eg, const TokenC* tokens, int i) except *: cdef void set_featuresC(self, ExampleC* eg, const TokenC* tokens, int i) except *:
_fill_from_token(&eg.atoms[P2_orth], &tokens[i-2]) _fill_from_token(&eg.atoms[P2_orth], &tokens[i-2])
_fill_from_token(&eg.atoms[P1_orth], &tokens[i-1]) _fill_from_token(&eg.atoms[P1_orth], &tokens[i-1])
_fill_from_token(&eg.atoms[W_orth], &tokens[i]) _fill_from_token(&eg.atoms[W_orth], &tokens[i])
@ -198,7 +199,9 @@ cdef class Tagger:
cdef Pool mem = Pool() cdef Pool mem = Pool()
cdef int i, tag cdef int i, tag
cdef Example eg = Example(self.vocab.morphology.n_tags) cdef Example eg = Example(nr_atom=N_CONTEXT_FIELDS,
nr_class=self.vocab.morphology.n_tags,
nr_feat=self.model.nr_feat)
for i in range(tokens.length): for i in range(tokens.length):
if tokens.c[i].pos == 0: if tokens.c[i].pos == 0:
self.model.set_featuresC(&eg.c, tokens.c, i) self.model.set_featuresC(&eg.c, tokens.c, i)
@ -206,6 +209,7 @@ cdef class Tagger:
eg.c.features, eg.c.nr_feat) eg.c.features, eg.c.nr_feat)
guess = VecVec.arg_max_if_true(eg.c.scores, eg.c.is_valid, eg.c.nr_class) guess = VecVec.arg_max_if_true(eg.c.scores, eg.c.is_valid, eg.c.nr_class)
self.vocab.morphology.assign_tag(&tokens.c[i], guess) self.vocab.morphology.assign_tag(&tokens.c[i], guess)
eg.reset_classes(eg.c.nr_class)
tokens.is_tagged = True tokens.is_tagged = True
tokens._py_tokens = [None] * tokens.length tokens._py_tokens = [None] * tokens.length
@ -214,7 +218,10 @@ cdef class Tagger:
golds = [self.tag_names.index(g) if g is not None else -1 for g in gold_tag_strs] golds = [self.tag_names.index(g) if g is not None else -1 for g in gold_tag_strs]
cdef int correct = 0 cdef int correct = 0
cdef Pool mem = Pool() cdef Pool mem = Pool()
cdef Example eg = Example(self.vocab.morphology.n_tags) cdef Example eg = Example(
nr_atom=N_CONTEXT_FIELDS,
nr_class=self.vocab.morphology.n_tags,
nr_feat=self.model.nr_feat)
for i in range(tokens.length): for i in range(tokens.length):
self.model.set_featuresC(&eg.c, tokens.c, i) self.model.set_featuresC(&eg.c, tokens.c, i)
eg.set_label(golds[i]) eg.set_label(golds[i])
@ -227,7 +234,7 @@ cdef class Tagger:
correct += eg.cost == 0 correct += eg.cost == 0
self.freqs[TAG][tokens.c[i].tag] += 1 self.freqs[TAG][tokens.c[i].tag] += 1
eg.wipe(tuple()) eg.reset_classes(eg.c.nr_class)
tokens.is_tagged = True tokens.is_tagged = True
tokens._py_tokens = [None] * tokens.length tokens._py_tokens = [None] * tokens.length
return correct return correct