From 49e2de900e7d59c5691cc0e49c295d2d9957898c Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Mon, 10 Apr 2017 11:37:04 +0200
Subject: [PATCH] Add costs property to StepwiseState, to show which moves are gold.

---
Review notes (this section is ignored by `git am`):
- StepwiseState.__init__ now passes self.gold to preprocess_gold(); the
  original passed the `gold` argument, which is None when the caller omits it.
- `if gold:` became `if gold is not None:` so the choice does not depend on
  the truthiness of the GoldParse object.
- The index line below is carried over from the original patch; its
  post-image blob id no longer matches the revised content.

 spacy/syntax/parser.pyx | 34 +++++++++++++++++++++++++++++++---
 1 file changed, 31 insertions(+), 3 deletions(-)

diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx
index 123ae03da..344ac5568 100644
--- a/spacy/syntax/parser.pyx
+++ b/spacy/syntax/parser.pyx
@@ -334,15 +334,16 @@ cdef class Parser:
         self.moves.finalize_state(stcls.c)
         return loss
 
-    def step_through(self, Doc doc):
+    def step_through(self, Doc doc, GoldParse gold=None):
         """Set up a stepwise state, to introspect and control the transition sequence.
 
         Arguments:
             doc (Doc): The document to step through.
+            gold (GoldParse): Optional gold parse, handed on to the StepwiseState.
         Returns (StepwiseState):
             A state object, to step through the annotation process.
         """
-        return StepwiseState(self, doc)
+        return StepwiseState(self, doc, gold=gold)
 
     def from_transition_sequence(self, Doc doc, sequence):
         """Control the annotations on a document by specifying a transition sequence
@@ -367,13 +368,21 @@ cdef class StepwiseState:
     cdef readonly StateClass stcls
     cdef readonly Example eg
     cdef readonly Doc doc
+    cdef readonly GoldParse gold
     cdef readonly Parser parser
 
-    def __init__(self, Parser parser, Doc doc):
+    def __init__(self, Parser parser, Doc doc, GoldParse gold=None):
         self.parser = parser
         self.doc = doc
+        # Fall back to a GoldParse built from the doc alone when none is given.
+        if gold is not None:
+            self.gold = gold
+        else:
+            self.gold = GoldParse(doc)
         self.stcls = StateClass.init(doc.c, doc.length)
         self.parser.moves.initialize_state(self.stcls.c)
+        # Use self.gold rather than the argument: the argument may be None.
+        self.parser.moves.preprocess_gold(self.gold)
         self.eg = Example(
             nr_class=self.parser.moves.n_moves,
             nr_atom=CONTEXT_SIZE,
@@ -406,6 +415,25 @@ cdef class StepwiseState:
         return [self.doc.vocab.strings[self.stcls.c._sent[i].dep]
                 for i in range(self.stcls.c.length)]
 
+    @property
+    def costs(self):
+        """Find the action-costs for the current state.
+
+        Returns (dict): The name of each currently valid move, mapped to the
+            cost recorded for it after scoring against self.gold.
+        """
+        self.parser.moves.set_costs(self.eg.c.is_valid, self.eg.c.costs,
+                self.stcls, self.gold)
+        costs = {}
+        for i in range(self.parser.moves.n_moves):
+            # Moves not flagged valid for this state are left out.
+            if not self.eg.c.is_valid[i]:
+                continue
+            transition = self.parser.moves.c[i]
+            name = self.parser.moves.move_name(transition.move, transition.label)
+            costs[name] = self.eg.c.costs[i]
+        return costs
+
     def predict(self):
         self.eg.reset()
         self.eg.c.nr_feat = self.parser.model.set_featuresC(self.eg.c.atoms, self.eg.c.features,