Add costs property to StepwiseState, to show which moves are gold.

This commit is contained in:
Matthew Honnibal 2017-04-10 11:37:04 +02:00
parent 8bd7e641ab
commit 49e2de900e

View File

@ -334,15 +334,16 @@ cdef class Parser:
self.moves.finalize_state(stcls.c)
return loss
def step_through(self, Doc doc):
def step_through(self, Doc doc, GoldParse gold=None):
"""Set up a stepwise state, to introspect and control the transition sequence.
Arguments:
doc (Doc): The document to step through.
gold (GoldParse): Optional gold parse
Returns (StepwiseState):
A state object, to step through the annotation process.
"""
return StepwiseState(self, doc)
return StepwiseState(self, doc, gold=gold)
def from_transition_sequence(self, Doc doc, sequence):
"""Control the annotations on a document by specifying a transition sequence
@ -367,13 +368,19 @@ cdef class StepwiseState:
cdef readonly StateClass stcls
cdef readonly Example eg
cdef readonly Doc doc
cdef readonly GoldParse gold
cdef readonly Parser parser
def __init__(self, Parser parser, Doc doc):
def __init__(self, Parser parser, Doc doc, GoldParse gold=None):
self.parser = parser
self.doc = doc
if gold:
self.gold = gold
else:
self.gold = GoldParse(doc)
self.stcls = StateClass.init(doc.c, doc.length)
self.parser.moves.initialize_state(self.stcls.c)
self.parser.moves.preprocess_gold(gold)
self.eg = Example(
nr_class=self.parser.moves.n_moves,
nr_atom=CONTEXT_SIZE,
@ -406,6 +413,20 @@ cdef class StepwiseState:
return [self.doc.vocab.strings[self.stcls.c._sent[i].dep]
for i in range(self.stcls.c.length)]
@property
def costs(self):
'''Find the action-costs for the current state'''
self.parser.moves.set_costs(self.eg.c.is_valid, self.eg.c.costs,
self.stcls, self.gold)
costs = {}
for i in range(self.parser.moves.n_moves):
if not self.eg.c.is_valid[i]:
continue
transition = self.parser.moves.c[i]
name = self.parser.moves.move_name(transition.move, transition.label)
costs[name] = self.eg.c.costs[i]
return costs
def predict(self):
self.eg.reset()
self.eg.c.nr_feat = self.parser.model.set_featuresC(self.eg.c.atoms, self.eg.c.features,