From a46933a8fe8d594c4c48bcd51c1464e815d0b4c4 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Thu, 16 Mar 2017 11:58:20 -0500 Subject: [PATCH] Clean up FTRL parsing stuff. --- spacy/syntax/parser.pyx | 20 +++++++++++--------- spacy/tagger.pyx | 4 ++-- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index 944bc4442..c94d4ebee 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -52,7 +52,7 @@ from ._parse_features cimport fill_context from .stateclass cimport StateClass from ._state cimport StateC - +USE_FTRL = False DEBUG = False def set_debug(val): global DEBUG @@ -86,14 +86,14 @@ cdef class ParserModel(AveragedPerceptron): guess = eg.guess if guess == best or best == -1: return 0.0 - for feat in eg.c.features[:eg.c.nr_feat]: - self.update_weight_ftrl(feat.key, guess, feat.value * eg.c.costs[guess]) - self.update_weight_ftrl(feat.key, best, -feat.value * eg.c.costs[guess]) - #for clas in [guess, best]: - # loss += (-eg.c.costs[clas] - eg.c.scores[clas]) ** 2 - # d_loss = eg.c.scores[clas] - -eg.c.costs[clas] - # for feat in eg.c.features[:eg.c.nr_feat]: - # self.update_weight_ftrl(feat.key, clas, feat.value * d_loss) + if USE_FTRL: + for feat in eg.c.features[:eg.c.nr_feat]: + self.update_weight_ftrl(feat.key, guess, feat.value * eg.c.costs[guess]) + self.update_weight_ftrl(feat.key, best, -feat.value * eg.c.costs[guess]) + else: + for feat in eg.c.features[:eg.c.nr_feat]: + self.update_weight(feat.key, guess, feat.value * eg.c.costs[guess]) + self.update_weight(feat.key, best, -feat.value * eg.c.costs[guess]) return eg.c.costs[guess] def update_from_histories(self, TransitionSystem moves, Doc doc, histories, weight_t min_grad=0.0): @@ -324,6 +324,8 @@ cdef class Parser: eg.fill_scores(0, eg.c.nr_class) eg.fill_costs(0, eg.c.nr_class) eg.fill_is_valid(1, eg.c.nr_class) + + self.moves.finalize_state(stcls.c) return loss def step_through(self, Doc doc): diff --git a/spacy/tagger.pyx b/spacy/tagger.pyx index 6f034f3de..1f6b587c5 100644 --- a/spacy/tagger.pyx +++ b/spacy/tagger.pyx @@ -76,8 +76,8 @@ cdef class TaggerModel(AveragedPerceptron): best = VecVec.arg_max_if_zero(eg.c.scores, eg.c.costs, eg.c.nr_class) if guess != best: for feat in eg.c.features[:eg.c.nr_feat]: - self.update_weight_ftrl(feat.key, best, -feat.value) - self.update_weight_ftrl(feat.key, guess, feat.value) + self.update_weight(feat.key, best, -feat.value) + self.update_weight(feat.key, guess, feat.value) cdef void set_featuresC(self, ExampleC* eg, const TokenC* tokens, int i) except *: _fill_from_token(&eg.atoms[P2_orth], &tokens[i-2])