2020-12-13 04:08:32 +03:00
|
|
|
# cython: infer_types=True
|
|
|
|
# cython: profile=True
|
|
|
|
import numpy
|
2023-06-14 18:48:41 +03:00
|
|
|
|
2020-12-13 04:08:32 +03:00
|
|
|
from thinc.extra.search cimport Beam
|
2023-06-14 18:48:41 +03:00
|
|
|
|
2020-12-13 04:08:32 +03:00
|
|
|
from thinc.extra.search import MaxViolation
|
2023-06-14 18:48:41 +03:00
|
|
|
|
2020-12-13 04:08:32 +03:00
|
|
|
from thinc.extra.search cimport MaxViolation
|
|
|
|
|
2023-07-19 13:03:31 +03:00
|
|
|
from ...typedefs cimport class_t
|
2023-06-14 18:48:41 +03:00
|
|
|
from .transition_system cimport Transition, TransitionSystem
|
|
|
|
|
2020-12-13 04:08:32 +03:00
|
|
|
from ...errors import Errors
|
2023-06-14 18:48:41 +03:00
|
|
|
|
2020-12-13 04:08:32 +03:00
|
|
|
from .stateclass cimport StateC, StateClass
|
|
|
|
|
|
|
|
|
|
|
|
# These are passed as callbacks to thinc.search.Beam
|
|
|
|
cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
|
|
|
|
dest = <StateC*>_dest
|
|
|
|
src = <StateC*>_src
|
|
|
|
moves = <const Transition*>_moves
|
|
|
|
dest.clone(src)
|
|
|
|
moves[clas].do(dest, moves[clas].label)
|
|
|
|
|
|
|
|
|
|
|
|
cdef int check_final_state(void* _state, void* extra_args) except -1:
|
|
|
|
state = <StateC*>_state
|
|
|
|
return state.is_final()
|
|
|
|
|
|
|
|
|
|
|
|
cdef class BeamBatch(object):
|
|
|
|
cdef public TransitionSystem moves
|
|
|
|
cdef public object states
|
|
|
|
cdef public object docs
|
|
|
|
cdef public object golds
|
|
|
|
cdef public object beams
|
|
|
|
|
|
|
|
def __init__(self, TransitionSystem moves, states, golds,
|
|
|
|
int width, float density=0.):
|
|
|
|
cdef StateClass state
|
|
|
|
self.moves = moves
|
|
|
|
self.states = states
|
|
|
|
self.docs = [state.doc for state in states]
|
|
|
|
self.golds = golds
|
|
|
|
self.beams = []
|
|
|
|
cdef Beam beam
|
|
|
|
cdef StateC* st
|
|
|
|
for state in states:
|
|
|
|
beam = Beam(self.moves.n_moves, width, min_density=density)
|
|
|
|
beam.initialize(self.moves.init_beam_state,
|
|
|
|
self.moves.del_beam_state, state.c.length,
|
|
|
|
<void*>state.c._sent)
|
|
|
|
for i in range(beam.width):
|
|
|
|
st = <StateC*>beam.at(i)
|
|
|
|
st.offset = state.c.offset
|
|
|
|
beam.check_done(check_final_state, NULL)
|
|
|
|
self.beams.append(beam)
|
|
|
|
|
|
|
|
@property
|
|
|
|
def is_done(self):
|
|
|
|
return all(b.is_done for b in self.beams)
|
|
|
|
|
|
|
|
def __getitem__(self, i):
|
|
|
|
return self.beams[i]
|
|
|
|
|
|
|
|
def __len__(self):
|
|
|
|
return len(self.beams)
|
|
|
|
|
|
|
|
def get_states(self):
|
|
|
|
cdef Beam beam
|
|
|
|
cdef StateC* state
|
|
|
|
cdef StateClass stcls
|
|
|
|
states = []
|
|
|
|
for beam, doc in zip(self, self.docs):
|
|
|
|
for i in range(beam.size):
|
|
|
|
state = <StateC*>beam.at(i)
|
|
|
|
stcls = StateClass.borrow(state, doc)
|
|
|
|
states.append(stcls)
|
|
|
|
return states
|
|
|
|
|
|
|
|
def get_unfinished_states(self):
|
|
|
|
return [st for st in self.get_states() if not st.is_final()]
|
|
|
|
|
|
|
|
def advance(self, float[:, ::1] scores, follow_gold=False):
|
|
|
|
cdef Beam beam
|
|
|
|
cdef int nr_class = scores.shape[1]
|
|
|
|
cdef const float* c_scores = &scores[0, 0]
|
|
|
|
docs = self.docs
|
|
|
|
for i, beam in enumerate(self):
|
|
|
|
if not beam.is_done:
|
|
|
|
nr_state = self._set_scores(beam, c_scores, nr_class)
|
|
|
|
assert nr_state
|
|
|
|
if self.golds is not None:
|
|
|
|
self._set_costs(
|
|
|
|
beam,
|
|
|
|
docs[i],
|
|
|
|
self.golds[i],
|
|
|
|
follow_gold=follow_gold
|
|
|
|
)
|
|
|
|
c_scores += nr_state * nr_class
|
|
|
|
beam.advance(transition_state, NULL, <void*>self.moves.c)
|
|
|
|
beam.check_done(check_final_state, NULL)
|
|
|
|
|
|
|
|
cdef int _set_scores(self, Beam beam, const float* scores, int nr_class) except -1:
|
|
|
|
cdef int nr_state = 0
|
|
|
|
for i in range(beam.size):
|
|
|
|
state = <StateC*>beam.at(i)
|
|
|
|
if not state.is_final():
|
|
|
|
for j in range(nr_class):
|
|
|
|
beam.scores[i][j] = scores[nr_state * nr_class + j]
|
|
|
|
self.moves.set_valid(beam.is_valid[i], state)
|
|
|
|
nr_state += 1
|
|
|
|
else:
|
|
|
|
for j in range(beam.nr_class):
|
|
|
|
beam.scores[i][j] = 0
|
|
|
|
beam.costs[i][j] = 0
|
|
|
|
return nr_state
|
|
|
|
|
|
|
|
def _set_costs(self, Beam beam, doc, gold, int follow_gold=False):
|
|
|
|
cdef const StateC* state
|
|
|
|
for i in range(beam.size):
|
|
|
|
state = <const StateC*>beam.at(i)
|
|
|
|
if state.is_final():
|
|
|
|
for j in range(beam.nr_class):
|
|
|
|
beam.is_valid[i][j] = 0
|
|
|
|
beam.costs[i][j] = 9000
|
|
|
|
else:
|
|
|
|
self.moves.set_costs(beam.is_valid[i], beam.costs[i],
|
|
|
|
state, gold)
|
|
|
|
if follow_gold:
|
|
|
|
min_cost = 0
|
|
|
|
for j in range(beam.nr_class):
|
|
|
|
if beam.is_valid[i][j] and beam.costs[i][j] < min_cost:
|
|
|
|
min_cost = beam.costs[i][j]
|
|
|
|
for j in range(beam.nr_class):
|
|
|
|
if beam.costs[i][j] > min_cost:
|
|
|
|
beam.is_valid[i][j] = 0
|
|
|
|
|
|
|
|
|
|
|
|
def update_beam(TransitionSystem moves, states, golds, model, int width, beam_density=0.0):
|
|
|
|
cdef MaxViolation violn
|
|
|
|
pbeam = BeamBatch(moves, states, golds, width=width, density=beam_density)
|
|
|
|
gbeam = BeamBatch(moves, states, golds, width=width, density=0.0)
|
|
|
|
beam_maps = []
|
|
|
|
backprops = []
|
|
|
|
violns = [MaxViolation() for _ in range(len(states))]
|
|
|
|
dones = [False for _ in states]
|
|
|
|
while not pbeam.is_done or not gbeam.is_done:
|
|
|
|
# The beam maps let us find the right row in the flattened scores
|
|
|
|
# array for each state. States are identified by (example id,
|
|
|
|
# history). We keep a different beam map for each step (since we'll
|
|
|
|
# have a flat scores array for each step). The beam map will let us
|
|
|
|
# take the per-state losses, and compute the gradient for each (step,
|
|
|
|
# state, class).
|
|
|
|
# Gather all states from the two beams in a list. Some stats may occur
|
|
|
|
# in both beams. To figure out which beam each state belonged to,
|
|
|
|
# we keep two lists of indices, p_indices and g_indices
|
|
|
|
states, p_indices, g_indices, beam_map = get_unique_states(pbeam, gbeam)
|
|
|
|
beam_maps.append(beam_map)
|
|
|
|
if not states:
|
|
|
|
break
|
|
|
|
# Now that we have our flat list of states, feed them through the model
|
|
|
|
scores, bp_scores = model.begin_update(states)
|
|
|
|
assert scores.size != 0
|
|
|
|
# Store the callbacks for the backward pass
|
|
|
|
backprops.append(bp_scores)
|
|
|
|
# Unpack the scores for the two beams. The indices arrays
|
|
|
|
# tell us which example and state the scores-row refers to.
|
|
|
|
# Now advance the states in the beams. The gold beam is constrained to
|
|
|
|
# to follow only gold analyses.
|
|
|
|
if not pbeam.is_done:
|
|
|
|
pbeam.advance(model.ops.as_contig(scores[p_indices]))
|
|
|
|
if not gbeam.is_done:
|
|
|
|
gbeam.advance(model.ops.as_contig(scores[g_indices]), follow_gold=True)
|
|
|
|
# Track the "maximum violation", to use in the update.
|
|
|
|
for i, violn in enumerate(violns):
|
|
|
|
if not dones[i]:
|
|
|
|
violn.check_crf(pbeam[i], gbeam[i])
|
|
|
|
if pbeam[i].is_done and gbeam[i].is_done:
|
|
|
|
dones[i] = True
|
|
|
|
histories = []
|
|
|
|
grads = []
|
|
|
|
for violn in violns:
|
|
|
|
if violn.p_hist:
|
|
|
|
histories.append(violn.p_hist + violn.g_hist)
|
|
|
|
d_loss = [d_l * violn.cost for d_l in violn.p_probs + violn.g_probs]
|
|
|
|
grads.append(d_loss)
|
|
|
|
else:
|
|
|
|
histories.append([])
|
|
|
|
grads.append([])
|
|
|
|
loss = 0.0
|
|
|
|
states_d_scores = get_gradient(moves.n_moves, beam_maps, histories, grads)
|
|
|
|
for i, (d_scores, bp_scores) in enumerate(zip(states_d_scores, backprops)):
|
|
|
|
loss += (d_scores**2).mean()
|
|
|
|
bp_scores(d_scores)
|
|
|
|
return loss
|
|
|
|
|
|
|
|
|
|
|
|
def collect_states(beams, docs):
|
|
|
|
cdef StateClass state
|
|
|
|
cdef Beam beam
|
|
|
|
states = []
|
|
|
|
for state_or_beam, doc in zip(beams, docs):
|
|
|
|
if isinstance(state_or_beam, StateClass):
|
|
|
|
states.append(state_or_beam)
|
|
|
|
else:
|
|
|
|
beam = state_or_beam
|
|
|
|
state = StateClass.borrow(<StateC*>beam.at(0), doc)
|
|
|
|
states.append(state)
|
|
|
|
return states
|
|
|
|
|
|
|
|
|
|
|
|
def get_unique_states(pbeams, gbeams):
|
|
|
|
seen = {}
|
|
|
|
states = []
|
|
|
|
p_indices = []
|
|
|
|
g_indices = []
|
|
|
|
beam_map = {}
|
|
|
|
docs = pbeams.docs
|
|
|
|
cdef Beam pbeam, gbeam
|
|
|
|
if len(pbeams) != len(gbeams):
|
|
|
|
raise ValueError(Errors.E079.format(pbeams=len(pbeams), gbeams=len(gbeams)))
|
|
|
|
for eg_id, (pbeam, gbeam, doc) in enumerate(zip(pbeams, gbeams, docs)):
|
|
|
|
if not pbeam.is_done:
|
|
|
|
for i in range(pbeam.size):
|
|
|
|
state = StateClass.borrow(<StateC*>pbeam.at(i), doc)
|
|
|
|
if not state.is_final():
|
|
|
|
key = tuple([eg_id] + pbeam.histories[i])
|
|
|
|
if key in seen:
|
|
|
|
raise ValueError(Errors.E080.format(key=key))
|
|
|
|
seen[key] = len(states)
|
|
|
|
p_indices.append(len(states))
|
|
|
|
states.append(state)
|
|
|
|
beam_map.update(seen)
|
|
|
|
if not gbeam.is_done:
|
|
|
|
for i in range(gbeam.size):
|
|
|
|
state = StateClass.borrow(<StateC*>gbeam.at(i), doc)
|
|
|
|
if not state.is_final():
|
|
|
|
key = tuple([eg_id] + gbeam.histories[i])
|
|
|
|
if key in seen:
|
|
|
|
g_indices.append(seen[key])
|
|
|
|
else:
|
|
|
|
g_indices.append(len(states))
|
|
|
|
beam_map[key] = len(states)
|
|
|
|
states.append(state)
|
|
|
|
p_indices = numpy.asarray(p_indices, dtype='i')
|
|
|
|
g_indices = numpy.asarray(g_indices, dtype='i')
|
|
|
|
return states, p_indices, g_indices, beam_map
|
|
|
|
|
|
|
|
|
|
|
|
def get_gradient(nr_class, beam_maps, histories, losses):
|
|
|
|
"""The global model assigns a loss to each parse. The beam scores
|
|
|
|
are additive, so the same gradient is applied to each action
|
|
|
|
in the history. This gives the gradient of a single *action*
|
|
|
|
for a beam state -- so we have "the gradient of loss for taking
|
|
|
|
action i given history H."
|
|
|
|
|
2021-01-06 14:02:32 +03:00
|
|
|
Histories: Each history is a list of actions
|
2020-12-13 04:08:32 +03:00
|
|
|
Each candidate has a history
|
|
|
|
Each beam has multiple candidates
|
|
|
|
Each batch has multiple beams
|
|
|
|
So history is list of lists of lists of ints
|
|
|
|
"""
|
|
|
|
grads = []
|
|
|
|
nr_steps = []
|
|
|
|
for eg_id, hists in enumerate(histories):
|
|
|
|
nr_step = 0
|
|
|
|
for loss, hist in zip(losses[eg_id], hists):
|
|
|
|
assert not numpy.isnan(loss)
|
|
|
|
if loss != 0.0:
|
|
|
|
nr_step = max(nr_step, len(hist))
|
|
|
|
nr_steps.append(nr_step)
|
|
|
|
for i in range(max(nr_steps)):
|
|
|
|
grads.append(numpy.zeros((max(beam_maps[i].values())+1, nr_class),
|
|
|
|
dtype='f'))
|
|
|
|
if len(histories) != len(losses):
|
|
|
|
raise ValueError(Errors.E081.format(n_hist=len(histories), losses=len(losses)))
|
|
|
|
for eg_id, hists in enumerate(histories):
|
|
|
|
for loss, hist in zip(losses[eg_id], hists):
|
|
|
|
assert not numpy.isnan(loss)
|
|
|
|
if loss == 0.0:
|
|
|
|
continue
|
|
|
|
key = tuple([eg_id])
|
|
|
|
# Adjust loss for length
|
|
|
|
# We need to do this because each state in a short path is scored
|
|
|
|
# multiple times, as we add in the average cost when we run out
|
|
|
|
# of actions.
|
|
|
|
avg_loss = loss / len(hist)
|
|
|
|
loss += avg_loss * (nr_steps[eg_id] - len(hist))
|
|
|
|
for step, clas in enumerate(hist):
|
|
|
|
i = beam_maps[step][key]
|
|
|
|
# In step j, at state i action clas
|
|
|
|
# resulted in loss
|
|
|
|
grads[step][i, clas] += loss
|
|
|
|
key = key + tuple([clas])
|
|
|
|
return grads
|