# cython: infer_types=True
# cython: profile=True
import numpy
from thinc.extra.search cimport Beam
from thinc.extra.search import MaxViolation
from thinc.extra.search cimport MaxViolation

from ...typedefs cimport class_t
from .transition_system cimport Transition, TransitionSystem
from ...errors import Errors
from .stateclass cimport StateC, StateClass


# These are passed as callbacks to thinc.search.Beam
cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
    # The beam passes untyped pointers, so they must be cast back to their
    # concrete types before any member can be accessed.
    dest = <StateC*>_dest
    src = <StateC*>_src
    moves = <const Transition*>_moves
    # Copy the source state into the destination slot, then apply the
    # transition identified by `clas` to the copy.
    dest.clone(src)
    moves[clas].do(dest, moves[clas].label)


cdef int check_final_state(void* _state, void* extra_args) except -1:
    # `extra_args` is unused; it is part of the callback signature.
    state = <const StateC*>_state
    return state.is_final()


cdef class BeamBatch(object):
    """A batch of beams, one per input state.

    Each beam is created with `width` slots and is initialized from the
    corresponding input state. The docs and gold annotations are kept
    alongside the beams so that they can be looked up by the same index.
    """
    cdef public TransitionSystem moves
    cdef public object states
    cdef public object docs
    cdef public object golds
    cdef public object beams

    def __init__(self, TransitionSystem moves, states, golds,
                 int width, float density=0.):
        """Create one beam per entry of `states`.

        moves: The transition system providing the beam state callbacks.
        states: The initial StateClass objects, one per example.
        golds: Gold annotations aligned with `states`, or None.
        width: The number of slots in each beam.
        density: Passed to each Beam as `min_density`.
        """
        cdef StateClass state
        cdef Beam beam
        cdef StateC* st
        self.moves = moves
        self.states = states
        self.docs = [state.doc for state in states]
        self.golds = golds
        self.beams = []
        for state in states:
            beam = Beam(self.moves.n_moves, width, min_density=density)
            beam.initialize(
                self.moves.init_beam_state,
                self.moves.del_beam_state,
                state.c.length,
                <void*>state.c._sent
            )
            # Propagate the offset of the source state to every beam slot.
            for i in range(beam.width):
                st = <StateC*>beam.at(i)
                st.offset = state.c.offset
            beam.check_done(check_final_state, NULL)
            self.beams.append(beam)

    @property
    def is_done(self):
        """True when every beam in the batch is done."""
        return all(b.is_done for b in self.beams)

    def __getitem__(self, i):
        return self.beams[i]

    def __len__(self):
        return len(self.beams)

    def get_states(self):
        """Return a flat list of StateClass objects borrowed from all beams.

        The first `beam.size` slots of each beam are included, in beam
        order, whether or not the state is final.
        """
        cdef Beam beam
        cdef StateC* state
        cdef StateClass stcls
        states = []
        for beam, doc in zip(self, self.docs):
            for i in range(beam.size):
                state = <StateC*>beam.at(i)
                stcls = StateClass.borrow(state, doc)
                states.append(stcls)
        return states

    def get_unfinished_states(self):
        """Return the states from `get_states()` that are not final."""
        return [st for st in self.get_states() if not st.is_final()]

    def advance(self, float[:, ::1] scores, follow_gold=False):
        """Advance every beam that is not done by one step.

        scores: C-contiguous 2d array holding one row of class scores per
            non-final state, in the order the states occur in the beams.
        follow_gold: Forwarded to `_set_costs` when gold annotations are
            available.
        """
        cdef Beam beam
        cdef int nr_class = scores.shape[1]
        cdef const float* c_scores = &scores[0, 0]
        docs = self.docs
        for i, beam in enumerate(self):
            if not beam.is_done:
                nr_state = self._set_scores(beam, c_scores, nr_class)
                assert nr_state
                if self.golds is not None:
                    self._set_costs(
                        beam,
                        docs[i],
                        self.golds[i],
                        follow_gold=follow_gold
                    )
                # Skip past the score rows consumed by this beam.
                c_scores += nr_state * nr_class
                beam.advance(transition_state, NULL, <void*>self.moves.c)
                beam.check_done(check_final_state, NULL)

    cdef int _set_scores(self, Beam beam, const float* scores, int nr_class) except -1:
        # Copy one row of scores into the beam for each non-final state and
        # mark which transitions are valid. Final states consume no score
        # row; their scores and costs are zeroed instead. Returns the number
        # of score rows that were consumed.
        cdef int nr_state = 0
        for i in range(beam.size):
            state = <const StateC*>beam.at(i)
            if not state.is_final():
                for j in range(nr_class):
                    beam.scores[i][j] = scores[nr_state * nr_class + j]
                self.moves.set_valid(beam.is_valid[i], state)
                nr_state += 1
            else:
                for j in range(beam.nr_class):
                    beam.scores[i][j] = 0
                    beam.costs[i][j] = 0
        return nr_state

    def _set_costs(self, Beam beam, doc, gold, int follow_gold=False):
        """Fill in the transition costs for every state of `beam`.

        Final states get all transitions invalidated and a cost of 9000.
        When `follow_gold` is set, every transition whose cost exceeds the
        minimum found among the valid transitions is invalidated as well.
        `doc` is accepted but not used here.
        """
        cdef const StateC* state
        for i in range(beam.size):
            state = <const StateC*>beam.at(i)
            if state.is_final():
                for j in range(beam.nr_class):
                    beam.is_valid[i][j] = 0
                    beam.costs[i][j] = 9000
            else:
                self.moves.set_costs(beam.is_valid[i], beam.costs[i], state, gold)
                if follow_gold:
                    # The search for the minimum starts at 0, so it only
                    # moves if some valid transition has a negative cost.
                    min_cost = 0
                    for j in range(beam.nr_class):
                        if beam.is_valid[i][j] and beam.costs[i][j] < min_cost:
                            min_cost = beam.costs[i][j]
                    for j in range(beam.nr_class):
                        if beam.costs[i][j] > min_cost:
                            beam.is_valid[i][j] = 0


def update_beam(TransitionSystem moves, states, golds, model, int width, beam_density=0.0):
    """Run a prediction beam and a gold-constrained beam in parallel and
    backpropagate the gradient derived from the recorded violations.

    moves: The transition system.
    states: The initial states, one per example.
    golds: Gold annotations aligned with `states`.
    model: Object providing `begin_update(states)` and `ops.as_contig`.
    width: The beam width.
    beam_density: Minimum density for the prediction beam; the gold beam
        always uses 0.0.
    RETURNS: The sum over steps of the mean squared gradient.
    """
    cdef MaxViolation violn
    pbeam = BeamBatch(moves, states, golds, width=width, density=beam_density)
    gbeam = BeamBatch(moves, states, golds, width=width, density=0.0)
    beam_maps = []
    backprops = []
    violns = [MaxViolation() for _ in range(len(states))]
    dones = [False for _ in states]
    while not pbeam.is_done or not gbeam.is_done:
        # The beam maps let us find the right row in the flattened scores
        # array for each state. States are identified by (example id,
        # history). We keep a different beam map for each step (since we'll
        # have a flat scores array for each step). The beam map will let us
        # take the per-state losses, and compute the gradient for each (step,
        # state, class).
        # Gather all states from the two beams in a list. Some states may
        # occur in both beams. To figure out which beam each state belonged
        # to, we keep two lists of indices, p_indices and g_indices
        states, p_indices, g_indices, beam_map = get_unique_states(pbeam, gbeam)
        beam_maps.append(beam_map)
        if not states:
            break
        # Now that we have our flat list of states, feed them through the model
        scores, bp_scores = model.begin_update(states)
        assert scores.size != 0
        # Store the callbacks for the backward pass
        backprops.append(bp_scores)
        # Unpack the scores for the two beams. The indices arrays
        # tell us which example and state the scores-row refers to.
        # Now advance the states in the beams. The gold beam is constrained
        # to follow only gold analyses.
        if not pbeam.is_done:
            pbeam.advance(model.ops.as_contig(scores[p_indices]))
        if not gbeam.is_done:
            gbeam.advance(model.ops.as_contig(scores[g_indices]), follow_gold=True)
        # Track the "maximum violation", to use in the update.
        for i, violn in enumerate(violns):
            if not dones[i]:
                violn.check_crf(pbeam[i], gbeam[i])
                if pbeam[i].is_done and gbeam[i].is_done:
                    dones[i] = True
    histories = []
    grads = []
    for violn in violns:
        if violn.p_hist:
            histories.append(violn.p_hist + violn.g_hist)
            d_loss = [d_l * violn.cost for d_l in violn.p_probs + violn.g_probs]
            grads.append(d_loss)
        else:
            histories.append([])
            grads.append([])
    loss = 0.0
    states_d_scores = get_gradient(moves.n_moves, beam_maps, histories, grads)
    # zip() stops at the shorter sequence, so steps without a stored
    # backprop callback are ignored.
    for d_scores, bp_scores in zip(states_d_scores, backprops):
        loss += (d_scores**2).mean()
        bp_scores(d_scores)
    return loss


def collect_states(beams, docs):
    """Return one StateClass per entry of `beams`.

    Entries that already are StateClass objects are passed through; for
    Beam entries the state in slot 0 is borrowed.
    """
    cdef StateClass state
    cdef Beam beam
    states = []
    for state_or_beam, doc in zip(beams, docs):
        if isinstance(state_or_beam, StateClass):
            states.append(state_or_beam)
        else:
            beam = state_or_beam
            state = StateClass.borrow(<StateC*>beam.at(0), doc)
            states.append(state)
    return states


def get_unique_states(pbeams, gbeams):
    """Collect the non-final states of two beam batches into one flat list.

    States are keyed by (example id, *history). A gold-beam state whose key
    was already seen in the prediction beam is not added again; its index
    points at the existing entry.

    RETURNS: A tuple (states, p_indices, g_indices, beam_map), where the
        index arrays give the positions in `states` belonging to each beam
        and `beam_map` maps each key to its position in `states`.
    RAISES: ValueError if the batches differ in length (E079) or if the
        prediction beam holds two states with the same key (E080).
    """
    cdef Beam pbeam, gbeam
    seen = {}
    states = []
    p_indices = []
    g_indices = []
    beam_map = {}
    docs = pbeams.docs
    if len(pbeams) != len(gbeams):
        raise ValueError(Errors.E079.format(pbeams=len(pbeams), gbeams=len(gbeams)))
    for eg_id, (pbeam, gbeam, doc) in enumerate(zip(pbeams, gbeams, docs)):
        if not pbeam.is_done:
            for i in range(pbeam.size):
                state = StateClass.borrow(<StateC*>pbeam.at(i), doc)
                if not state.is_final():
                    key = tuple([eg_id] + pbeam.histories[i])
                    if key in seen:
                        raise ValueError(Errors.E080.format(key=key))
                    seen[key] = len(states)
                    p_indices.append(len(states))
                    states.append(state)
        beam_map.update(seen)
        if not gbeam.is_done:
            for i in range(gbeam.size):
                state = StateClass.borrow(<StateC*>gbeam.at(i), doc)
                if not state.is_final():
                    key = tuple([eg_id] + gbeam.histories[i])
                    if key in seen:
                        g_indices.append(seen[key])
                    else:
                        g_indices.append(len(states))
                        beam_map[key] = len(states)
                        states.append(state)
    p_indices = numpy.asarray(p_indices, dtype='i')
    g_indices = numpy.asarray(g_indices, dtype='i')
    return states, p_indices, g_indices, beam_map


def get_gradient(nr_class, beam_maps, histories, losses):
    """The global model assigns a loss to each parse. The beam scores
    are additive, so the same gradient is applied to each action
    in the history. This gives the gradient of a single *action*
    for a beam state -- so we have "the gradient of loss for taking
    action i given history H."

    Histories: Each history is a list of actions
               Each candidate has a history
               Each beam has multiple candidates
               Each batch has multiple beams
               So history is list of lists of lists of ints

    nr_class: Number of columns of each gradient array.
    beam_maps: One dict per step, mapping (example id, *history) keys to
        row indices.
    histories: The histories, as described above.
    losses: Per-example lists of losses, aligned with `histories`.
    RETURNS: A list with one float array of gradients per step; empty if
        no history carries a non-zero loss.
    RAISES: ValueError (E081) if `histories` and `losses` differ in length.
    """
    # Validate before indexing `losses` by example id, so that a length
    # mismatch is reported as E081 rather than as an IndexError.
    if len(histories) != len(losses):
        raise ValueError(Errors.E081.format(n_hist=len(histories), losses=len(losses)))
    grads = []
    nr_steps = []
    for eg_id, hists in enumerate(histories):
        nr_step = 0
        for loss, hist in zip(losses[eg_id], hists):
            assert not numpy.isnan(loss)
            if loss != 0.0:
                nr_step = max(nr_step, len(hist))
        nr_steps.append(nr_step)
    # `default=0` keeps an empty batch from raising in max().
    for i in range(max(nr_steps, default=0)):
        grads.append(numpy.zeros((max(beam_maps[i].values())+1, nr_class), dtype='f'))
    for eg_id, hists in enumerate(histories):
        for loss, hist in zip(losses[eg_id], hists):
            assert not numpy.isnan(loss)
            if loss == 0.0:
                continue
            key = tuple([eg_id])
            # Adjust loss for length
            # We need to do this because each state in a short path is scored
            # multiple times, as we add in the average cost when we run out
            # of actions.
            avg_loss = loss / len(hist)
            loss += avg_loss * (nr_steps[eg_id] - len(hist))
            for step, clas in enumerate(hist):
                i = beam_maps[step][key]
                # In step j, at state i action clas
                # resulted in loss
                grads[step][i, clas] += loss
                key = key + tuple([clas])
    return grads