Add get_beam_parse method in ArcEager, for Prodigy

This commit is contained in:
Matthew Honnibal 2018-02-15 21:03:16 +01:00
parent 3e541de440
commit 59b7cf9db8

View File

@ -390,6 +390,22 @@ cdef class ArcEager(TransitionSystem):
gold.c.labels[i] = self.strings.add(label) gold.c.labels[i] = self.strings.add(label)
return gold return gold
def get_beam_parses(self, Beam beam):
parses = []
probs = beam.probs
for i in range(beam.size):
state = <StateC*>beam.at(i)
if state.is_final():
self.finalize_state(state)
prob = probs[i]
parse = []
for j in range(state.length):
head = state.H(j)
label = self.strings[state._sent[j].dep]
parse.append((head, j, label))
parses.append((prob, parse))
return parses
cdef Transition lookup_transition(self, object name) except *: cdef Transition lookup_transition(self, object name) except *:
if '-' in name: if '-' in name:
move_str, label_str = name.split('-', 1) move_str, label_str = name.split('-', 1)