Mirror of https://github.com/explosion/spaCy.git, synced 2025-02-11 09:00:36 +03:00
Return ArcEagerGoldParse from ArcEager
This commit is contained in:
parent
0c6f1f3891
commit
5ae9e3480d
@@ -132,6 +132,15 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp
    )
    return gs


cdef class ArcEagerGoldParse:
    cdef GoldParseStateC c

    def __init__(self, StateClass stcls, Example example):
        self.mem = Pool()
        self.c = create_gold_state(self.mem, stcls, example)


cdef int check_state_gold(char state_bits, char flag) nogil:
    cdef char one = 1
    return state_bits & (one << flag)

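For readers skimming the diff, here is a minimal pure-Python sketch of the bit-flag scheme that check_state_gold implements. The flag positions below are assumed for illustration only; they are not the constants defined in the module.

HEAD_UNKNOWN = 0   # assumed bit position, for illustration only
IS_SENT_START = 1  # assumed bit position, for illustration only

def set_flag(state_bits, flag):
    # Set the bit at position `flag` in the packed per-token state byte.
    return state_bits | (1 << flag)

def check_flag(state_bits, flag):
    # Non-zero exactly when the bit at position `flag` is set,
    # matching `state_bits & (one << flag)` in the Cython helper above.
    return state_bits & (1 << flag)

bits = set_flag(0, HEAD_UNKNOWN)
assert check_flag(bits, HEAD_UNKNOWN)
assert not check_flag(bits, IS_SENT_START)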
@@ -156,7 +165,6 @@ cdef int is_head_in_buffer(const GoldParseStateC* gold, int i) nogil:
cdef int is_head_unknown(const GoldParseStateC* gold, int i) nogil:
    return check_state_gold(gold.state_bits[i], HEAD_UNKNOWN)


# Helper functions for the arc-eager oracle

cdef weight_t push_cost(StateClass stcls, const void* _gold, int target) nogil:

@@ -500,6 +508,14 @@ cdef class ArcEager(TransitionSystem):
    def preprocess_gold(self, example):
        raise NotImplementedError

    def init_gold_batch(self, examples):
        states = self.init_batch([eg.predicted for eg in examples])
        keeps = [i for i, s in enumerate(states) if not s.is_final()]
        states = [states[i] for i in keeps]
        examples = [examples[i] for i in keeps]
        n_steps = sum([s.buffer_length() * 4 for s in states])
        return states, examples, n_steps

    cdef Transition lookup_transition(self, object name_or_id) except *:
        if isinstance(name_or_id, int):
            return self.c[name_or_id]

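The new ArcEager.init_gold_batch above follows a filter-and-budget pattern: drop states that are already final, keep the matching examples, and estimate a step budget from the remaining buffer lengths. A rough runnable sketch of that pattern, using a toy state class rather than spaCy's StateClass:

class ToyState:
    # Stand-in for StateClass: tracks how many words remain in the buffer.
    def __init__(self, n_words):
        self.n_words = n_words
    def is_final(self):
        return self.n_words == 0
    def buffer_length(self):
        return self.n_words

def init_gold_batch(states, examples):
    keeps = [i for i, s in enumerate(states) if not s.is_final()]
    states = [states[i] for i in keeps]
    examples = [examples[i] for i in keeps]
    # Roughly four transitions per remaining buffer token, as in the diff.
    n_steps = sum(s.buffer_length() * 4 for s in states)
    return states, examples, n_steps

states, examples, n_steps = init_gold_batch(
    [ToyState(3), ToyState(0), ToyState(2)], ["eg1", "eg2", "eg3"])
assert examples == ["eg1", "eg3"] and n_steps == 20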
@@ -268,15 +268,10 @@ cdef class Parser:
        for multitask in self._multitasks:
            multitask.update(examples, drop=drop, sgd=sgd)
        set_dropout_rate(self.model, drop)
        try:
            states, golds, max_steps = self._init_gold_batch_no_cut(examples)
        except AttributeError:
            types = set([type(eg) for eg in examples])
            raise ValueError(Errors.E978.format(name="Parser", method="update", types=types))
        states_golds = [(s, g) for (s, g) in zip(states, golds)
                        if not s.is_final() and g is not None]
        # Prepare the stepwise model, and get the callback for finishing the batch
        model, backprop_tok2vec = self.model.begin_update([eg.doc for eg in examples])
        model, backprop_tok2vec = self.model.begin_update(
            [eg.predicted for eg in examples])
        states, golds, max_steps = self.moves.init_gold_batch(examples)
        all_states = list(states)
        for _ in range(max_steps):
            if not states_golds:

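The eg.doc to eg.predicted switch in this hunk reflects the newer Example API, where each training example pairs the Doc the model annotates (predicted) with a Doc carrying the gold annotations (reference). A small usage sketch, assuming a spaCy v3-style install:

import spacy
from spacy.training import Example

nlp = spacy.blank("en")
doc = nlp.make_doc("She ate the pizza")
example = Example.from_dict(
    doc, {"heads": [1, 1, 3, 1], "deps": ["nsubj", "ROOT", "det", "dobj"]})
print(example.predicted)  # the Doc the parser will annotate
print(example.reference)  # the Doc carrying the gold-standard heads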
@@ -287,12 +282,12 @@ cdef class Parser:
            backprop(d_scores)
            # Follow the predicted action
            self.transition_states(states, scores)
            states_golds = [eg for eg in states_golds if not eg[0].is_final()]
            states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()]
        backprop_tok2vec(golds)
        if sgd is not None:
            self.model.finish_update(sgd)
        if set_annotations:
            docs = [eg.doc for eg in examples]
            docs = [eg.predicted for eg in examples]
            self.set_annotations(docs, all_states)
        return losses

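The surrounding update loop repeats a simple cycle: score the unfinished states, backpropagate, advance each state by the predicted action, then drop (state, gold) pairs whose state has become final. Below is a toy, runnable version of just that bookkeeping, with stand-in classes instead of spaCy's parser internals:

class ToyParseState:
    # Stand-in for a parser state: becomes "final" after a fixed number of steps.
    def __init__(self, steps_left):
        self.steps_left = steps_left
    def is_final(self):
        return self.steps_left == 0
    def step(self):
        self.steps_left -= 1

def run_steps(states, golds, max_steps):
    states_golds = [(s, g) for s, g in zip(states, golds) if not s.is_final()]
    for _ in range(max_steps):
        if not states_golds:
            break
        for state, _gold in states_golds:
            # In the real parser this is transition_states() following the
            # model's highest-scoring valid action; here we just count down.
            state.step()
        states_golds = [(s, g) for (s, g) in states_golds if not s.is_final()]
    return states_golds

assert run_steps([ToyParseState(2), ToyParseState(0)], ["g1", "g2"], max_steps=10) == []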
@@ -307,7 +302,7 @@ cdef class Parser:
            return None
        losses.setdefault(self.name, 0.)

        docs = [eg.doc for eg in examples]
        docs = [eg.predicted for eg in examples]
        states = self.moves.init_batch(docs)
        # This is pretty dirty, but the NER can resize itself in init_batch,
        # if labels are missing. We therefore have to check whether we need to

@@ -356,11 +351,7 @@ cdef class Parser:
            queue.extend(node._layers)
        return gradients

    def _init_gold_batch_no_cut(self, examples):
        states = self.moves.init_batch([eg.predicted for eg in examples])
        return states, examples

    def get_batch_loss(self, states, examples, float[:, ::1] scores, losses):
    def get_batch_loss(self, states, golds, float[:, ::1] scores, losses):
        cdef StateClass state
        cdef Example example
        cdef Pool mem = Pool()

@@ -375,10 +366,10 @@ cdef class Parser:
                                  dtype='f', order='C')
        c_d_scores = <float*>d_scores.data
        unseen_classes = self.model.attrs["unseen_classes"]
        for i, (state, eg) in enumerate(zip(states, examples)):
        for i, (state, gold) in enumerate(zip(states, golds)):
            memset(is_valid, 0, self.moves.n_moves * sizeof(int))
            memset(costs, 0, self.moves.n_moves * sizeof(float))
            self.moves.set_costs(is_valid, costs, state, eg)
            self.moves.set_costs(is_valid, costs, state, gold)
            for j in range(self.moves.n_moves):
                if costs[j] <= 0.0 and j in unseen_classes:
                    unseen_classes.remove(j)

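get_batch_loss fills a per-state is_valid mask and costs array via self.moves.set_costs and turns them into a gradient over the model's scores (d_scores). The NumPy sketch below shows one way such arrays can be converted into a gradient, using a simplified softmax-over-valid-moves rule; it illustrates the idea and is not spaCy's exact computation:

import numpy

def batch_loss_sketch(scores, costs, is_valid):
    # scores, costs: (n_states, n_moves) float arrays; is_valid: 0/1 ints.
    valid = is_valid.astype(bool)
    # Softmax over valid moves only.
    masked = numpy.where(valid, scores, -numpy.inf)
    exp = numpy.exp(masked - masked.max(axis=1, keepdims=True))
    probs = exp / exp.sum(axis=1, keepdims=True)
    # Treat zero-cost valid moves as the target distribution.
    target = numpy.where(valid & (costs <= 0), 1.0, 0.0)
    target /= target.sum(axis=1, keepdims=True)
    return probs - target  # gradient of the scores, i.e. d_scores

scores = numpy.array([[2.0, 0.5, -1.0]])
costs = numpy.array([[0.0, 1.0, 2.0]])
is_valid = numpy.array([[1, 1, 0]])
print(batch_loss_sketch(scores, costs, is_valid))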
@@ -403,9 +394,11 @@ cdef class Parser:
        if not hasattr(get_examples, '__call__'):
            gold_tuples = get_examples
            get_examples = lambda: gold_tuples
        actions = self.moves.get_actions(gold_parses=get_examples(),
                                         min_freq=self.cfg['min_action_freq'],
                                         learn_tokens=self.cfg["learn_tokens"])
        actions = self.moves.get_actions(
            examples=get_examples(),
            min_freq=self.cfg['min_action_freq'],
            learn_tokens=self.cfg["learn_tokens"]
        )
        for action, labels in self.moves.labels.items():
            actions.setdefault(action, {})
            for label, freq in labels.items():