diff --git a/spacy/ml/tb_framework.py b/spacy/ml/tb_framework.py
index ddc283216..714a4e43e 100644
--- a/spacy/ml/tb_framework.py
+++ b/spacy/ml/tb_framework.py
@@ -21,7 +21,7 @@ def TransitionModel(
     layer and a linear output layer.
     """
     t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
-    tok2vec_projected = chain(tok2vec, list2array(), Linear(hidden_width, t2v_width))
+    tok2vec_projected = chain(tok2vec, list2array(), Linear(hidden_width, t2v_width))  # type: ignore
     tok2vec_projected.set_dim("nO", hidden_width)
 
     return Model(
@@ -47,17 +47,32 @@ def TransitionModel(
         attrs={
             "unseen_classes": set(unseen_classes),
             "resize_output": resize_output,
-            "make_step_model": make_step_model,
         },
     )
 
 
-def make_step_model(model: Model) -> Model[List[State], Floats2d]:
-    ...
-
-
-def resize_output(model: Model) -> Model:
-    ...
+def resize_output(model: Model, new_nO: int) -> Model:
+    old_nO = model.maybe_get_dim("nO")
+    if old_nO is None:
+        model.set_dim("nO", new_nO)
+        return model
+    elif new_nO <= old_nO:
+        return model
+    elif model.has_param("upper_W"):
+        nH = model.get_dim("nH")
+        new_W = model.ops.alloc2f(new_nO, nH)
+        new_b = model.ops.alloc1f(new_nO)
+        old_W = model.get_param("upper_W")
+        old_b = model.get_param("upper_b")
+        new_W[:old_nO] = old_W  # type: ignore
+        new_b[:old_nO] = old_b  # type: ignore
+        for i in range(old_nO, new_nO):
+            model.attrs["unseen_classes"].add(i)
+        # Install the resized parameters, then record the new output width.
+        model.set_param("upper_W", new_W)
+        model.set_param("upper_b", new_b)
+    model.set_dim("nO", new_nO, force=True)
+    return model
 
 
 def init(
@@ -87,9 +102,9 @@ def init(
     padl = ops.alloc4f(1, nF, nH, nP)
     Wu = ops.alloc2f(nO, nH)
     bu = ops.alloc1f(nO)
-    Wl = normal_init(ops, Wl.shape, mean=float(ops.xp.sqrt(1.0 / nF * nI)))
-    padl = normal_init(ops, padl.shape, mean=1.0)
-    # TODO: Experiment with whether better to initialize Wu
+    Wl = normal_init(ops, Wl.shape, mean=float(ops.xp.sqrt(1.0 / nF * nI)))  # type: ignore
+    padl = normal_init(ops, padl.shape, mean=1.0)  # type: ignore
+    # TODO: Experiment with whether better to initialize upper_W
     model.set_param("lower_W", Wl)
     model.set_param("lower_b", bl)
     model.set_param("lower_pad", padl)
@@ -101,11 +116,11 @@ def init(
 
 def forward(model, docs_moves, is_train):
     tok2vec = model.get_ref("tok2vec")
-    state2scores = model.get_ref("state2scores")
-    # Get a reference to the parameters. We need to work with
-    # stable references through the forward/backward pass, to make
-    # sure we don't have a stale reference if there's concurrent shenanigans.
-    params = {name: model.get_param(name) for name in model.param_names}
+    lower_pad = model.get_param("lower_pad")
+    lower_b = model.get_param("lower_b")
+    upper_W = model.get_param("upper_W")
+    upper_b = model.get_param("upper_b")
+    ops = model.ops
 
     docs, moves = docs_moves
     states = moves.init_batch(docs)
@@ -113,108 +128,62 @@ def forward(model, docs_moves, is_train):
     feats, backprop_feats = _forward_precomputable_affine(model, tokvecs, is_train)
     memory = []
     all_scores = []
-    while states:
-        states, scores, memory = _step_parser(
-            ops, params, moves, states, feats, memory, is_train
-        )
+    next_states = list(states)
+    while next_states:
+        ids = moves.get_state_ids(next_states)
+        preacts = _sum_state_features(ops, feats, ids)  # TODO: handle lower_pad for ids < 0
+        # * Add the bias
+        preacts += lower_b
+        # * Apply the activation (maxout)
+        statevecs, which = ops.maxout(preacts)
+        # * Multiply the state-vector by the scores weights
+        scores = ops.gemm(statevecs, upper_W, trans2=True)
+        # * Add the bias
+        scores += upper_b
+        next_states = moves.transition_states(next_states, scores)
         all_scores.append(scores)
+        if is_train:
+            memory.append((ids, statevecs, which))
 
     def backprop_parser(d_states_d_scores):
         _, d_scores = d_states_d_scores
-        d_feats, ids = _backprop_parser_steps(ops, params, memory, d_scores)
-        d_tokvecs = backprop_feats((d_feats, ids))
-        return backprop_tok2vec(d_tokvecs), None
+        ids, statevecs, whiches = [ops.xp.concatenate(item) for item in zip(*memory)]
+        # TODO: Unseen class masking
+        # Calculate the gradients for the parameters of the upper layer.
+        model.inc_grad("upper_b", d_scores.sum(axis=0))
+        model.inc_grad("upper_W", model.ops.gemm(d_scores, statevecs, trans1=True))
+        # Now calculate d_statevecs, by backpropping through the upper linear layer.
+        d_statevecs = model.ops.gemm(d_scores, upper_W)
+        # Backprop through the maxout activation
+        d_preacts = model.ops.backprop_maxout(
+            d_statevecs, whiches, model.get_dim("nP")
+        )
+        # We don't need to backprop the summation, because we pass back the IDs instead
+        d_tokvecs = backprop_feats((d_preacts, ids))
+        return (backprop_tok2vec(d_tokvecs), None)
 
     return (states, all_scores), backprop_parser
 
 
-def _step_parser(ops, params, moves, states, feats, memory, is_train):
-    ids = moves.get_state_ids(states)
-    statevecs, which, scores = _score_ids(ops, params, ids, feats, is_train)
-    next_states = moves.transition_states(states, scores)
-    if is_train:
-        memory.append((ids, statevecs, which))
-    return next_states, scores, memory
-
-
-def _score_ids(ops, params, ids, feats, is_train):
-    lower_pad = params["lower_pad"]
-    lower_b = params["lower_b"]
-    upper_W = params["upper_W"]
-    upper_b = params["upper_b"]
-    # During each step of the parser, we do:
-    # * Index into the features, to get the pre-activated vector
-    #   for each (token, feature) and sum the feature vectors
-    preacts = _sum_state_features(feats, lower_pad, ids)
-    # * Add the bias
-    preacts += lower_b
-    # * Apply the activation (maxout)
-    statevecs, which = ops.maxout(preacts)
-    # * Multiply the state-vector by the scores weights
-    scores = ops.gemm(statevecs, upper_W, trans2=True)
-    # * Add the bias
-    scores += upper_b
-    # * Apply the is-class-unseen masking
-    # TODO
-    return statevecs, which, scores
-
-
-def _sum_state_features(ops: Ops, feats: Floats3d, ids: Ints2d) -> Floats2d:
+def _sum_state_features(ops: Ops, feats: Floats3d, ids: Ints2d, _arange=[]) -> Floats2d:
     # Here's what we're trying to implement here:
     #
     # for i in range(ids.shape[0]):
     #     for j in range(ids.shape[1]):
     #         output[i] += feats[ids[i, j], j]
     #
-    # Reshape the feats into 2d, to make indexing easier. Instead of getting an
-    # array of indices where the cell at (4, 2) needs to refer to the row at
-    # feats[4, 2], we'll translate the index so that it directly addresses
-    # feats[18]. This lets us make the indices array 1d, leading to fewer
-    # numpy shennanigans.
-    feats2d = ops.reshape2f(feats, feats.shape[0] * feats.shape[1], feats.shape[2])
-    # Now translate the ids. If we're looking for the row that used to be at
-    # (4, 1) and we have 4 features, we'll find it at (4*4)+1=17.
-    oob_ids = ids < 0  # Retain the -1 values
-    ids = ids * feats.shape[1] + ops.xp.arange(feats.shape[1])
-    ids[oob_ids] = -1
-    unsummed2d = feats2d[ops.reshape1i(ids, ids.size)]
-    unsummed3d = ops.reshape3f(
-        unsummed2d, feats.shape[0], feats.shape[1], feats.shape[2]
-    )
-    summed = unsummed3d.sum(axis=1)  # type: ignore
-    return summed
-
-
-def _process_memory(ops, memory):
-    """Concatenate the memory buffers from each state into contiguous
-    buffers for the whole batch.
-    """
-    return [ops.xp.concatenate(*item) for item in zip(*memory)]
-
-
-def _backprop_parser_steps(model, upper_W, memory, d_scores):
-    # During each step of the parser, we do:
-    # * Index into the features, to get the pre-activated vector
-    #   for each (token, feature)
-    # * Sum the feature vectors
-    # * Add the bias
-    # * Apply the activation (maxout)
-    # * Multiply the state-vector by the scores weights
-    # * Add the bias
-    # * Apply the is-class-unseen masking
-    #
-    # So we have to backprop through all those steps.
-    ids, statevecs, whiches = _process_memory(model.ops, memory)
-    # TODO: Unseen class masking
-    # Calculate the gradients for the parameters of the upper layer.
-    model.inc_grad("upper_b", d_scores.sum(axis=0))
-    model.inc_grad("upper_W", model.ops.gemm(d_scores, statevecs, trans1=True))
-    # Now calculate d_statevecs, by backproping through the upper linear layer.
-    d_statevecs = model.ops.gemm(d_scores, upper_W)
-    # Backprop through the maxout activation
-    d_preacts = model.ops.backprop_maxount(d_statevecs, whiches, model.get_dim("nP"))
-    # We don't need to backprop the summation, because we pass back the IDs instead
-    return d_preacts, ids
+    # _arange caches a column-index vector that broadcasts against ids, so
+    # feats[ids, _arange[0]] gathers feats[ids[i, j], j] for every
+    # (state, feature) pair, which is exactly the loop above.
+    if not _arange:
+        _arange.append(ops.xp.arange(ids.shape[1]))
+    if _arange[0].size != ids.shape[1]:
+        _arange[0] = ops.xp.arange(ids.shape[1])
+    return feats[ids, _arange[0]].sum(axis=1)  # type: ignore
 
 
 def _forward_precomputable_affine(model, X: Floats2d, is_train: bool):
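+    # Run the lower model once over the whole batch: precompute the affine
+    # output for every (token, feature) pair. Each parser step in forward
+    # then only gathers and sums rows of this table via _sum_state_features,
+    # instead of reapplying the lower weights state by state.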