Add costs property to StepwiseState, to show which moves are gold.

This commit is contained in:
Matthew Honnibal 2017-04-10 11:37:04 +02:00
parent 8bd7e641ab
commit 49e2de900e

View File

@ -334,15 +334,16 @@ cdef class Parser:
self.moves.finalize_state(stcls.c) self.moves.finalize_state(stcls.c)
return loss return loss
def step_through(self, Doc doc): def step_through(self, Doc doc, GoldParse gold=None):
"""Set up a stepwise state, to introspect and control the transition sequence. """Set up a stepwise state, to introspect and control the transition sequence.
Arguments: Arguments:
doc (Doc): The document to step through. doc (Doc): The document to step through.
gold (GoldParse): Optional gold parse
Returns (StepwiseState): Returns (StepwiseState):
A state object, to step through the annotation process. A state object, to step through the annotation process.
""" """
return StepwiseState(self, doc) return StepwiseState(self, doc, gold=gold)
def from_transition_sequence(self, Doc doc, sequence): def from_transition_sequence(self, Doc doc, sequence):
"""Control the annotations on a document by specifying a transition sequence """Control the annotations on a document by specifying a transition sequence
@ -367,13 +368,19 @@ cdef class StepwiseState:
cdef readonly StateClass stcls cdef readonly StateClass stcls
cdef readonly Example eg cdef readonly Example eg
cdef readonly Doc doc cdef readonly Doc doc
cdef readonly GoldParse gold
cdef readonly Parser parser cdef readonly Parser parser
def __init__(self, Parser parser, Doc doc): def __init__(self, Parser parser, Doc doc, GoldParse gold=None):
self.parser = parser self.parser = parser
self.doc = doc self.doc = doc
if gold:
self.gold = gold
else:
self.gold = GoldParse(doc)
self.stcls = StateClass.init(doc.c, doc.length) self.stcls = StateClass.init(doc.c, doc.length)
self.parser.moves.initialize_state(self.stcls.c) self.parser.moves.initialize_state(self.stcls.c)
self.parser.moves.preprocess_gold(gold)
self.eg = Example( self.eg = Example(
nr_class=self.parser.moves.n_moves, nr_class=self.parser.moves.n_moves,
nr_atom=CONTEXT_SIZE, nr_atom=CONTEXT_SIZE,
@ -406,6 +413,20 @@ cdef class StepwiseState:
return [self.doc.vocab.strings[self.stcls.c._sent[i].dep] return [self.doc.vocab.strings[self.stcls.c._sent[i].dep]
for i in range(self.stcls.c.length)] for i in range(self.stcls.c.length)]
@property
def costs(self):
'''Find the action-costs for the current state'''
self.parser.moves.set_costs(self.eg.c.is_valid, self.eg.c.costs,
self.stcls, self.gold)
costs = {}
for i in range(self.parser.moves.n_moves):
if not self.eg.c.is_valid[i]:
continue
transition = self.parser.moves.c[i]
name = self.parser.moves.move_name(transition.move, transition.label)
costs[name] = self.eg.c.costs[i]
return costs
def predict(self): def predict(self):
self.eg.reset() self.eg.reset()
self.eg.c.nr_feat = self.parser.model.set_featuresC(self.eg.c.atoms, self.eg.c.features, self.eg.c.nr_feat = self.parser.model.set_featuresC(self.eg.c.atoms, self.eg.c.features,