Improve unseen label masking

Two changes to speed up masking by ~10%:

- Use a bool array rather than an array of float32.

- Let the mask indicate whether a label was seen, rather than
  unseen. The mask is most frequently used to index scores for
  seen labels. However, since the mask marked unseen labels,
  this required computing an intermediate flipped mask (see the
  sketch after this list).
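
A minimal sketch of both changes in plain numpy (illustrative only:
`n_labels`, `unseen`, and the random scores are invented for the
example and do not appear in the patch):

    import numpy

    n_labels = 5
    unseen = {1, 3}  # hypothetical labels never seen during training

    # Before: float32 mask, 1.0 = keep, 0.0 = clamp. Selecting the
    # columns to clamp needs the intermediate flipped mask
    # "float_mask == 0".
    float_mask = numpy.ones(n_labels, dtype="float32")
    for i in unseen:
        float_mask[i] = 0.0
    scores = numpy.random.rand(2, n_labels).astype("float32")
    scores[:, float_mask == 0] = numpy.nanmin(scores)

    # After: a bool array indexes the same columns directly, without
    # materializing a comparison result first.
    bool_mask = numpy.zeros(n_labels, dtype="bool")
    for i in unseen:
        bool_mask[i] = True
    scores[:, bool_mask] = numpy.nanmin(scores)
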
Daniël de Kok 2022-01-26 12:01:37 +01:00
parent ec0cae9db8
commit c2475a25de


@@ -150,7 +150,7 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: boo
     all_statevecs = []
     all_scores = []
     next_states = [s for s in states if not s.is_final()]
-    unseen_mask = _get_unseen_mask(model)
+    seen_mask = _get_seen_mask(model)
     ids = numpy.zeros((len(states), nF), dtype="i")
     arange = model.ops.xp.arange(nF)
     while next_states:
@@ -168,9 +168,10 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: boo
         # to get the logits.
         scores = model.ops.gemm(statevecs, upper_W, trans2=True)
         scores += upper_b
-        scores[:, unseen_mask == 0] = model.ops.xp.nanmin(scores)
+        scores[:, seen_mask] = model.ops.xp.nanmin(scores)
         # Transition the states, filtering out any that are finished.
-        next_states = moves.transition_states(next_states, scores)
+        cpu_scores = model.ops.to_numpy(scores)
+        next_states = moves.transition_states(next_states, cpu_scores)
         all_scores.append(scores)
         if is_train:
             # Remember intermediate results for the backprop.
@@ -191,7 +192,7 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: boo
             for clas in set(model.attrs["unseen_classes"]):
                 if (d_scores[:, clas] < 0).any():
                     model.attrs["unseen_classes"].remove(clas)
-        d_scores *= unseen_mask
+        d_scores *= seen_mask == False
         # Calculate the gradients for the parameters of the upper layer.
         # The weight gemm is (nS, nO) @ (nS, nH).T
         model.inc_grad("upper_b", d_scores.sum(axis=0))
@@ -240,7 +241,7 @@ def _forward_reference(
     all_scores = []
     all_tokfeats = []
     next_states = [s for s in states if not s.is_final()]
-    unseen_mask = _get_unseen_mask(model)
+    seen_mask = _get_seen_mask(model)
     ids = numpy.zeros((len(states), nF), dtype="i")
     while next_states:
         ids = ids[: len(next_states)]
@@ -258,7 +259,7 @@ def _forward_reference(
         # to get the logits.
         scores = model.ops.gemm(statevecs, upper_W, trans2=True)
         scores += upper_b
-        scores[:, unseen_mask == 0] = model.ops.xp.nanmin(scores)
+        scores[:, seen_mask] = model.ops.xp.nanmin(scores)
         # Transition the states, filtering out any that are finished.
         next_states = moves.transition_states(next_states, scores)
         all_scores.append(scores)
@@ -285,7 +286,7 @@ def _forward_reference(
             for clas in set(model.attrs["unseen_classes"]):
                 if (d_scores[:, clas] < 0).any():
                     model.attrs["unseen_classes"].remove(clas)
-        d_scores *= unseen_mask
+        d_scores *= seen_mask == False
         assert statevecs.shape == (nS, nH), statevecs.shape
         assert d_scores.shape == (nS, nO), d_scores.shape
         # Calculate the gradients for the parameters of the upper layer.
@@ -314,11 +315,10 @@ def _forward_reference(
     return (states, all_scores), backprop_parser


-def _get_unseen_mask(model: Model) -> Floats1d:
-    mask = model.ops.alloc1f(model.get_dim("nO"))
-    mask.fill(1)
+def _get_seen_mask(model: Model) -> Floats1d:
+    mask = model.ops.xp.zeros(model.get_dim("nO"), dtype="bool")
     for class_ in model.attrs.get("unseen_classes", set()):
-        mask[class_] = 0
+        mask[class_] = True
     return mask
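
For completeness, a hedged numpy sketch of how the new mask drives both
passes above (numpy stands in for model.ops.xp, and the scores and the
masked label index are invented):

    import numpy

    seen_mask = numpy.zeros(4, dtype="bool")
    seen_mask[2] = True  # pretend label 2 is in unseen_classes

    # Forward: clamp the flagged label to the minimum score so it
    # cannot be predicted.
    scores = numpy.array([[0.3, 0.9, 0.5, 0.1]], dtype="float32")
    scores[:, seen_mask] = numpy.nanmin(scores)

    # Backward: multiplying by the flipped bool mask zeroes the
    # flagged column; numpy upcasts the bools during the multiply.
    d_scores = numpy.ones_like(scores)
    d_scores *= seen_mask == False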