spaCy/spacy/pipeline/_parser_internals/_beam_utils.pyx
2023-07-19 16:37:31 +02:00

304 lines
11 KiB
Cython

# cython: infer_types=True
# cython: profile=True
import numpy
from thinc.extra.search cimport Beam
from thinc.extra.search import MaxViolation
from thinc.extra.search cimport MaxViolation
from ...typedefs cimport class_t
from .transition_system cimport Transition, TransitionSystem
from ...errors import Errors
from .batch cimport Batch
from .search cimport Beam, MaxViolation
from .search import MaxViolation
from .stateclass cimport StateC, StateClass
# These are passed as callbacks to .search.Beam
cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
dest = <StateC*>_dest
src = <StateC*>_src
moves = <const Transition*>_moves
dest.clone(src)
moves[clas].do(dest, moves[clas].label)
cdef int check_final_state(void* _state, void* extra_args) except -1:
state = <StateC*>_state
return state.is_final()
cdef class BeamBatch(Batch):
cdef public TransitionSystem moves
cdef public object states
cdef public object docs
cdef public object golds
cdef public object beams
def __init__(self, TransitionSystem moves, states, golds,
int width, float density=0.):
cdef StateClass state
self.moves = moves
self.states = states
self.docs = [state.doc for state in states]
self.golds = golds
self.beams = []
cdef Beam beam
cdef StateC* st
for state in states:
beam = Beam(self.moves.n_moves, width, min_density=density)
beam.initialize(self.moves.init_beam_state,
self.moves.del_beam_state, state.c.length,
<void*>state.c._sent)
for i in range(beam.width):
st = <StateC*>beam.at(i)
st.offset = state.c.offset
beam.check_done(check_final_state, NULL)
self.beams.append(beam)
@property
def is_done(self):
return all(b.is_done for b in self.beams)
def __getitem__(self, i):
return self.beams[i]
def __len__(self):
return len(self.beams)
def get_states(self):
cdef Beam beam
cdef StateC* state
cdef StateClass stcls
states = []
for beam, doc in zip(self, self.docs):
for i in range(beam.size):
state = <StateC*>beam.at(i)
stcls = StateClass.borrow(state, doc)
states.append(stcls)
return states
def get_unfinished_states(self):
return [st for st in self.get_states() if not st.is_final()]
def advance(self, float[:, ::1] scores, follow_gold=False):
cdef Beam beam
cdef int nr_class = scores.shape[1]
cdef const float* c_scores = &scores[0, 0]
docs = self.docs
for i, beam in enumerate(self):
if not beam.is_done:
nr_state = self._set_scores(beam, c_scores, nr_class)
assert nr_state
if self.golds is not None:
self._set_costs(
beam,
docs[i],
self.golds[i],
follow_gold=follow_gold
)
c_scores += nr_state * nr_class
beam.advance(transition_state, NULL, <void*>self.moves.c)
beam.check_done(check_final_state, NULL)
cdef int _set_scores(self, Beam beam, const float* scores, int nr_class) except -1:
cdef int nr_state = 0
for i in range(beam.size):
state = <StateC*>beam.at(i)
if not state.is_final():
for j in range(nr_class):
beam.scores[i][j] = scores[nr_state * nr_class + j]
self.moves.set_valid(beam.is_valid[i], state)
nr_state += 1
else:
for j in range(beam.nr_class):
beam.scores[i][j] = 0
beam.costs[i][j] = 0
return nr_state
def _set_costs(self, Beam beam, doc, gold, int follow_gold=False):
cdef const StateC* state
for i in range(beam.size):
state = <const StateC*>beam.at(i)
if state.is_final():
for j in range(beam.nr_class):
beam.is_valid[i][j] = 0
beam.costs[i][j] = 9000
else:
self.moves.set_costs(beam.is_valid[i], beam.costs[i],
state, gold)
if follow_gold:
min_cost = 0
for j in range(beam.nr_class):
if beam.is_valid[i][j] and beam.costs[i][j] < min_cost:
min_cost = beam.costs[i][j]
for j in range(beam.nr_class):
if beam.costs[i][j] > min_cost:
beam.is_valid[i][j] = 0
def update_beam(TransitionSystem moves, states, golds, model, int width, beam_density=0.0):
cdef MaxViolation violn
pbeam = BeamBatch(moves, states, golds, width=width, density=beam_density)
gbeam = BeamBatch(moves, states, golds, width=width, density=0.0)
beam_maps = []
backprops = []
violns = [MaxViolation() for _ in range(len(states))]
dones = [False for _ in states]
while not pbeam.is_done or not gbeam.is_done:
# The beam maps let us find the right row in the flattened scores
# array for each state. States are identified by (example id,
# history). We keep a different beam map for each step (since we'll
# have a flat scores array for each step). The beam map will let us
# take the per-state losses, and compute the gradient for each (step,
# state, class).
# Gather all states from the two beams in a list. Some stats may occur
# in both beams. To figure out which beam each state belonged to,
# we keep two lists of indices, p_indices and g_indices
states, p_indices, g_indices, beam_map = get_unique_states(pbeam, gbeam)
beam_maps.append(beam_map)
if not states:
break
# Now that we have our flat list of states, feed them through the model
scores, bp_scores = model.begin_update(states)
assert scores.size != 0
# Store the callbacks for the backward pass
backprops.append(bp_scores)
# Unpack the scores for the two beams. The indices arrays
# tell us which example and state the scores-row refers to.
# Now advance the states in the beams. The gold beam is constrained to
# to follow only gold analyses.
if not pbeam.is_done:
pbeam.advance(model.ops.as_contig(scores[p_indices]))
if not gbeam.is_done:
gbeam.advance(model.ops.as_contig(scores[g_indices]), follow_gold=True)
# Track the "maximum violation", to use in the update.
for i, violn in enumerate(violns):
if not dones[i]:
violn.check_crf(pbeam[i], gbeam[i])
if pbeam[i].is_done and gbeam[i].is_done:
dones[i] = True
histories = []
grads = []
for violn in violns:
if violn.p_hist:
histories.append(violn.p_hist + violn.g_hist)
d_loss = [d_l * violn.cost for d_l in violn.p_probs + violn.g_probs]
grads.append(d_loss)
else:
histories.append([])
grads.append([])
loss = 0.0
states_d_scores = get_gradient(moves.n_moves, beam_maps, histories, grads)
for i, (d_scores, bp_scores) in enumerate(zip(states_d_scores, backprops)):
loss += (d_scores**2).mean()
bp_scores(d_scores)
return loss
def collect_states(beams, docs):
cdef StateClass state
cdef Beam beam
states = []
for state_or_beam, doc in zip(beams, docs):
if isinstance(state_or_beam, StateClass):
states.append(state_or_beam)
else:
beam = state_or_beam
state = StateClass.borrow(<StateC*>beam.at(0), doc)
states.append(state)
return states
def get_unique_states(pbeams, gbeams):
seen = {}
states = []
p_indices = []
g_indices = []
beam_map = {}
docs = pbeams.docs
cdef Beam pbeam, gbeam
if len(pbeams) != len(gbeams):
raise ValueError(Errors.E079.format(pbeams=len(pbeams), gbeams=len(gbeams)))
for eg_id, (pbeam, gbeam, doc) in enumerate(zip(pbeams, gbeams, docs)):
if not pbeam.is_done:
for i in range(pbeam.size):
state = StateClass.borrow(<StateC*>pbeam.at(i), doc)
if not state.is_final():
key = tuple([eg_id] + pbeam.histories[i])
if key in seen:
raise ValueError(Errors.E080.format(key=key))
seen[key] = len(states)
p_indices.append(len(states))
states.append(state)
beam_map.update(seen)
if not gbeam.is_done:
for i in range(gbeam.size):
state = StateClass.borrow(<StateC*>gbeam.at(i), doc)
if not state.is_final():
key = tuple([eg_id] + gbeam.histories[i])
if key in seen:
g_indices.append(seen[key])
else:
g_indices.append(len(states))
beam_map[key] = len(states)
states.append(state)
p_indices = numpy.asarray(p_indices, dtype='i')
g_indices = numpy.asarray(g_indices, dtype='i')
return states, p_indices, g_indices, beam_map
def get_gradient(nr_class, beam_maps, histories, losses):
"""The global model assigns a loss to each parse. The beam scores
are additive, so the same gradient is applied to each action
in the history. This gives the gradient of a single *action*
for a beam state -- so we have "the gradient of loss for taking
action i given history H."
Histories: Each history is a list of actions
Each candidate has a history
Each beam has multiple candidates
Each batch has multiple beams
So history is list of lists of lists of ints
"""
grads = []
nr_steps = []
for eg_id, hists in enumerate(histories):
nr_step = 0
for loss, hist in zip(losses[eg_id], hists):
assert not numpy.isnan(loss)
if loss != 0.0:
nr_step = max(nr_step, len(hist))
nr_steps.append(nr_step)
for i in range(max(nr_steps)):
grads.append(numpy.zeros((max(beam_maps[i].values())+1, nr_class),
dtype='f'))
if len(histories) != len(losses):
raise ValueError(Errors.E081.format(n_hist=len(histories), losses=len(losses)))
for eg_id, hists in enumerate(histories):
for loss, hist in zip(losses[eg_id], hists):
assert not numpy.isnan(loss)
if loss == 0.0:
continue
key = tuple([eg_id])
# Adjust loss for length
# We need to do this because each state in a short path is scored
# multiple times, as we add in the average cost when we run out
# of actions.
avg_loss = loss / len(hist)
loss += avg_loss * (nr_steps[eg_id] - len(hist))
for step, clas in enumerate(hist):
i = beam_maps[step][key]
# In step j, at state i action clas
# resulted in loss
grads[step][i, clas] += loss
key = key + tuple([clas])
return grads