diff --git a/spacy/ml/tb_framework.py b/spacy/ml/tb_framework.py
index 9aac5b801..dd2ff6c19 100644
--- a/spacy/ml/tb_framework.py
+++ b/spacy/ml/tb_framework.py
@@ -150,7 +150,7 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: boo
     all_statevecs = []
     all_scores = []
     next_states = [s for s in states if not s.is_final()]
-    unseen_mask = _get_unseen_mask(model)
+    seen_mask = _get_seen_mask(model)
     ids = numpy.zeros((len(states), nF), dtype="i")
     arange = model.ops.xp.arange(nF)
     while next_states:
@@ -168,9 +168,10 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: boo
         # to get the logits.
         scores = model.ops.gemm(statevecs, upper_W, trans2=True)
         scores += upper_b
-        scores[:, unseen_mask == 0] = model.ops.xp.nanmin(scores)
+        scores[:, seen_mask] = model.ops.xp.nanmin(scores)
         # Transition the states, filtering out any that are finished.
-        next_states = moves.transition_states(next_states, scores)
+        cpu_scores = model.ops.to_numpy(scores)
+        next_states = moves.transition_states(next_states, cpu_scores)
         all_scores.append(scores)
         if is_train:
             # Remember intermediate results for the backprop.
@@ -191,7 +192,7 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: boo
             for clas in set(model.attrs["unseen_classes"]):
                 if (d_scores[:, clas] < 0).any():
                     model.attrs["unseen_classes"].remove(clas)
-        d_scores *= unseen_mask
+        d_scores *= seen_mask == False
         # Calculate the gradients for the parameters of the upper layer.
         # The weight gemm is (nS, nO) @ (nS, nH).T
         model.inc_grad("upper_b", d_scores.sum(axis=0))
@@ -240,7 +241,7 @@ def _forward_reference(
     all_scores = []
     all_tokfeats = []
     next_states = [s for s in states if not s.is_final()]
-    unseen_mask = _get_unseen_mask(model)
+    seen_mask = _get_seen_mask(model)
     ids = numpy.zeros((len(states), nF), dtype="i")
     while next_states:
         ids = ids[: len(next_states)]
@@ -258,7 +259,7 @@ def _forward_reference(
         # to get the logits.
         scores = model.ops.gemm(statevecs, upper_W, trans2=True)
         scores += upper_b
-        scores[:, unseen_mask == 0] = model.ops.xp.nanmin(scores)
+        scores[:, seen_mask] = model.ops.xp.nanmin(scores)
         # Transition the states, filtering out any that are finished.
         next_states = moves.transition_states(next_states, scores)
         all_scores.append(scores)
@@ -285,7 +286,7 @@ def _forward_reference(
             for clas in set(model.attrs["unseen_classes"]):
                 if (d_scores[:, clas] < 0).any():
                     model.attrs["unseen_classes"].remove(clas)
-        d_scores *= unseen_mask
+        d_scores *= seen_mask == False
         assert statevecs.shape == (nS, nH), statevecs.shape
         assert d_scores.shape == (nS, nO), d_scores.shape
         # Calculate the gradients for the parameters of the upper layer.
@@ -314,11 +315,10 @@ def _forward_reference(
     return (states, all_scores), backprop_parser
 
 
-def _get_unseen_mask(model: Model) -> Floats1d:
-    mask = model.ops.alloc1f(model.get_dim("nO"))
-    mask.fill(1)
+def _get_seen_mask(model: Model) -> Floats1d:
+    mask = model.ops.xp.zeros(model.get_dim("nO"), dtype="bool")
     for class_ in model.attrs.get("unseen_classes", set()):
-        mask[class_] = 0
+        mask[class_] = True
     return mask
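
For readers skimming the diff: the patch replaces the float `unseen_mask` (1.0 for seen classes, 0.0 for unseen) with a boolean `seen_mask` that is True exactly at the indices of classes never seen in training, so the score clamping can use direct boolean indexing instead of the `unseen_mask == 0` comparison. Below is a minimal standalone sketch of that masking trick in plain NumPy; the names and sizes are illustrative only, not spaCy's API.

import numpy

n_classes = 5
unseen_classes = {1, 3}  # hypothetical indices of never-seen classes

# As in _get_seen_mask above: False everywhere, True at unseen classes.
mask = numpy.zeros(n_classes, dtype="bool")
for class_ in unseen_classes:
    mask[class_] = True

scores = numpy.random.default_rng(0).normal(size=(2, n_classes)).astype("f")

# Forward pass: clamp unseen classes to the global minimum score, so a
# row-wise argmax can never pick them.
scores[:, mask] = numpy.nanmin(scores)
assert not set(int(i) for i in scores.argmax(axis=1)) & unseen_classes

# Backward pass: `mask == False` (written that way to mirror the diff) is
# 1 for seen classes and 0 for unseen ones, so the multiply zeroes out any
# gradient flowing to unseen classes.
d_scores = numpy.ones_like(scores)
d_scores *= mask == False
assert (d_scores[:, sorted(unseen_classes)] == 0).all()

The other change, copying `scores` to host memory with `model.ops.to_numpy` before `moves.transition_states`, appears to be there so the CPU-side transition code receives a NumPy array even when the scores were computed on GPU.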