Return ArcEagerGoldParse from ArcEager

This commit is contained in:
Matthew Honnibal 2020-06-19 00:11:59 +02:00
parent 0c6f1f3891
commit 5ae9e3480d
2 changed files with 31 additions and 22 deletions

View File

@ -132,6 +132,15 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp
)
return gs
cdef class ArcEagerGoldParse:
    """Gold-standard parse state for one (parser state, example) pair.

    Wraps the C-level GoldParseStateC built by create_gold_state, together
    with the memory pool that was handed to create_gold_state.
    """
    cdef GoldParseStateC c
    # Extension types have no instance __dict__, so every attribute assigned
    # in __init__ must be declared here. Storing the pool on the instance
    # also keeps it referenced for as long as this object is alive.
    cdef Pool mem

    def __init__(self, StateClass stcls, Example example):
        """Build the gold state for `example` relative to the state `stcls`."""
        self.mem = Pool()
        self.c = create_gold_state(self.mem, stcls, example)
cdef int check_state_gold(char state_bits, char flag) nogil:
    # Non-zero when bit number `flag` is set in `state_bits`, zero otherwise.
    cdef char bit = 1
    return (bit << flag) & state_bits
@ -156,7 +165,6 @@ cdef int is_head_in_buffer(const GoldParseStateC* gold, int i) nogil:
cdef int is_head_unknown(const GoldParseStateC* gold, int i) nogil:
    # Non-zero when the HEAD_UNKNOWN flag is set in the state bits of entry i.
    cdef char bits = gold.state_bits[i]
    return check_state_gold(bits, HEAD_UNKNOWN)
# Helper functions for the arc-eager oracle
cdef weight_t push_cost(StateClass stcls, const void* _gold, int target) nogil:
@ -500,6 +508,14 @@ cdef class ArcEager(TransitionSystem):
def preprocess_gold(self, example):
    """Placeholder that unconditionally raises NotImplementedError."""
    raise NotImplementedError()
def init_gold_batch(self, examples):
    """Create initial states for a training batch, dropping finished ones.

    examples: sequence of objects exposing a `.predicted` doc.

    Returns a `(states, examples, n_steps)` tuple where `states` and
    `examples` are aligned and contain only the entries whose initial state
    is not already final, and `n_steps` is an upper bound on the number of
    transitions to run (four per buffered item, summed over the states).
    """
    states = self.init_batch([eg.predicted for eg in examples])
    # Indices of states that still have work to do; used to filter both
    # lists so they stay aligned.
    keeps = [i for i, state in enumerate(states) if not state.is_final()]
    states = [states[i] for i in keeps]
    examples = [examples[i] for i in keeps]
    # buffer_length() is presumably already an integer count (confirm against
    # StateClass); wrapping it in len() would raise TypeError.
    n_steps = sum(state.buffer_length() * 4 for state in states)
    return states, examples, n_steps
cdef Transition lookup_transition(self, object name_or_id) except *:
if isinstance(name_or_id, int):
return self.c[name_or_id]

View File

@ -268,15 +268,10 @@ cdef class Parser:
for multitask in self._multitasks:
multitask.update(examples, drop=drop, sgd=sgd)
set_dropout_rate(self.model, drop)
try:
states, golds, max_steps = self._init_gold_batch_no_cut(examples)
except AttributeError:
types = set([type(eg) for eg in examples])
raise ValueError(Errors.E978.format(name="Parser", method="update", types=types))
states_golds = [(s, g) for (s, g) in zip(states, golds)
if not s.is_final() and g is not None]
# Prepare the stepwise model, and get the callback for finishing the batch
model, backprop_tok2vec = self.model.begin_update([eg.doc for eg in examples])
model, backprop_tok2vec = self.model.begin_update(
[eg.predicted for eg in examples])
states, golds, max_steps = self.moves.init_gold_batch(examples)
all_states = list(states)
for _ in range(max_steps):
if not states_golds:
@ -287,12 +282,12 @@ cdef class Parser:
backprop(d_scores)
# Follow the predicted action
self.transition_states(states, scores)
states_golds = [eg for eg in states_golds if not eg[0].is_final()]
states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()]
backprop_tok2vec(golds)
if sgd is not None:
self.model.finish_update(sgd)
if set_annotations:
docs = [eg.doc for eg in examples]
docs = [eg.predicted for eg in examples]
self.set_annotations(docs, all_states)
return losses
@ -307,7 +302,7 @@ cdef class Parser:
return None
losses.setdefault(self.name, 0.)
docs = [eg.doc for eg in examples]
docs = [eg.predicted for eg in examples]
states = self.moves.init_batch(docs)
# This is pretty dirty, but the NER can resize itself in init_batch,
# if labels are missing. We therefore have to check whether we need to
@ -356,11 +351,7 @@ cdef class Parser:
queue.extend(node._layers)
return gradients
def _init_gold_batch_no_cut(self, examples):
    """Initialise parser states from the examples' predicted docs.

    Returns a `(states, examples)` pair; `examples` is passed through
    unchanged.
    """
    predicted_docs = [example.predicted for example in examples]
    return self.moves.init_batch(predicted_docs), examples
def get_batch_loss(self, states, examples, float[:, ::1] scores, losses):
def get_batch_loss(self, states, golds, float[:, ::1] scores, losses):
cdef StateClass state
cdef Example example
cdef Pool mem = Pool()
@ -375,10 +366,10 @@ cdef class Parser:
dtype='f', order='C')
c_d_scores = <float*>d_scores.data
unseen_classes = self.model.attrs["unseen_classes"]
for i, (state, eg) in enumerate(zip(states, examples)):
for i, (state, gold) in enumerate(zip(states, golds)):
memset(is_valid, 0, self.moves.n_moves * sizeof(int))
memset(costs, 0, self.moves.n_moves * sizeof(float))
self.moves.set_costs(is_valid, costs, state, eg)
self.moves.set_costs(is_valid, costs, state, gold)
for j in range(self.moves.n_moves):
if costs[j] <= 0.0 and j in unseen_classes:
unseen_classes.remove(j)
@ -403,9 +394,11 @@ cdef class Parser:
if not hasattr(get_examples, '__call__'):
gold_tuples = get_examples
get_examples = lambda: gold_tuples
actions = self.moves.get_actions(gold_parses=get_examples(),
actions = self.moves.get_actions(
examples=get_examples(),
min_freq=self.cfg['min_action_freq'],
learn_tokens=self.cfg["learn_tokens"])
learn_tokens=self.cfg["learn_tokens"]
)
for action, labels in self.moves.labels.items():
actions.setdefault(action, {})
for label, freq in labels.items():