Improve unseen label masking

Two changes to speed up masking by ~10%:

- Use a bool array rather than an array of float32.

- Let the mask indicate whether a label was seen, rather than
  unseen. The mask is most frequently used to index scores for
  seen labels. However, since the mask marked unseen labels,
  this required computing an intermediate flipped mask (see the
  sketch below).
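A minimal sketch of why both changes help, using hypothetical sizes and class
indices rather than spaCy's actual code: a float32 mask needs an extra `== 0`
comparison to produce a boolean index on every call, while a bool mask can
index the score columns directly.

    import numpy

    nO = 8                       # number of output classes (hypothetical)
    unseen = {2, 5}              # classes never seen in training (hypothetical)

    # Old style: float32 mask.
    float_mask = numpy.ones(nO, dtype="float32")
    for c in unseen:
        float_mask[c] = 0.0

    # New style: bool mask that can index columns directly.
    bool_mask = numpy.zeros(nO, dtype="bool")
    for c in unseen:
        bool_mask[c] = True

    scores = numpy.random.rand(4, nO).astype("float32")

    # Old: an intermediate boolean mask is built on every call.
    scores_old = scores.copy()
    scores_old[:, float_mask == 0] = scores_old.min()

    # New: the bool mask selects the masked columns directly.
    # (spaCy uses model.ops.xp.nanmin; plain min() is enough for this sketch.)
    scores_new = scores.copy()
    scores_new[:, bool_mask] = scores_new.min()

    assert numpy.array_equal(scores_old, scores_new)

The only behavioural difference is that the bool mask skips the per-call
comparison inside the parsing loop, which is where the ~10% figure above
comes from.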
Daniël de Kok 2022-01-26 12:01:37 +01:00
parent ec0cae9db8
commit c2475a25de


@@ -150,7 +150,7 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: bool):
     all_statevecs = []
     all_scores = []
     next_states = [s for s in states if not s.is_final()]
-    unseen_mask = _get_unseen_mask(model)
+    seen_mask = _get_seen_mask(model)
     ids = numpy.zeros((len(states), nF), dtype="i")
     arange = model.ops.xp.arange(nF)
     while next_states:
@@ -168,9 +168,10 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: bool):
         # to get the logits.
         scores = model.ops.gemm(statevecs, upper_W, trans2=True)
         scores += upper_b
-        scores[:, unseen_mask == 0] = model.ops.xp.nanmin(scores)
+        scores[:, seen_mask] = model.ops.xp.nanmin(scores)
         # Transition the states, filtering out any that are finished.
-        next_states = moves.transition_states(next_states, scores)
+        cpu_scores = model.ops.to_numpy(scores)
+        next_states = moves.transition_states(next_states, cpu_scores)
         all_scores.append(scores)
         if is_train:
             # Remember intermediate results for the backprop.
@@ -191,7 +192,7 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: bool):
         for clas in set(model.attrs["unseen_classes"]):
             if (d_scores[:, clas] < 0).any():
                 model.attrs["unseen_classes"].remove(clas)
-        d_scores *= unseen_mask
+        d_scores *= seen_mask == False
         # Calculate the gradients for the parameters of the upper layer.
         # The weight gemm is (nS, nO) @ (nS, nH).T
         model.inc_grad("upper_b", d_scores.sum(axis=0))
@@ -240,7 +241,7 @@ def _forward_reference(
     all_scores = []
     all_tokfeats = []
     next_states = [s for s in states if not s.is_final()]
-    unseen_mask = _get_unseen_mask(model)
+    seen_mask = _get_seen_mask(model)
     ids = numpy.zeros((len(states), nF), dtype="i")
     while next_states:
         ids = ids[: len(next_states)]
@@ -258,7 +259,7 @@ def _forward_reference(
         # to get the logits.
         scores = model.ops.gemm(statevecs, upper_W, trans2=True)
         scores += upper_b
-        scores[:, unseen_mask == 0] = model.ops.xp.nanmin(scores)
+        scores[:, seen_mask] = model.ops.xp.nanmin(scores)
         # Transition the states, filtering out any that are finished.
         next_states = moves.transition_states(next_states, scores)
         all_scores.append(scores)
@@ -285,7 +286,7 @@ def _forward_reference(
         for clas in set(model.attrs["unseen_classes"]):
             if (d_scores[:, clas] < 0).any():
                 model.attrs["unseen_classes"].remove(clas)
-        d_scores *= unseen_mask
+        d_scores *= seen_mask == False
         assert statevecs.shape == (nS, nH), statevecs.shape
         assert d_scores.shape == (nS, nO), d_scores.shape
         # Calculate the gradients for the parameters of the upper layer.
@@ -314,11 +315,10 @@ def _forward_reference(
     return (states, all_scores), backprop_parser


-def _get_unseen_mask(model: Model) -> Floats1d:
-    mask = model.ops.alloc1f(model.get_dim("nO"))
-    mask.fill(1)
+def _get_seen_mask(model: Model) -> Floats1d:
+    mask = model.ops.xp.zeros(model.get_dim("nO"), dtype="bool")
     for class_ in model.attrs.get("unseen_classes", set()):
-        mask[class_] = 0
+        mask[class_] = True
     return mask
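A side note on the gradient masking in the backprop hunks above: multiplying a
float array by a boolean array in numpy casts True/False to 1.0/0.0, so
`d_scores *= seen_mask == False` keeps the columns where the mask is False and
zeroes the ones where it is True, without allocating a separate float mask. A
small self-contained illustration with hypothetical sizes, not spaCy code:

    import numpy

    nS, nO = 3, 5                        # states x output classes (hypothetical)
    d_scores = numpy.ones((nS, nO), dtype="float32")
    mask = numpy.zeros(nO, dtype="bool")
    mask[[1, 4]] = True                  # columns to zero out

    # In-place multiply of float32 by bool broadcasts and casts True/False to 1/0.
    # The `== False` comparison mirrors the flipped use of the mask in the diff.
    d_scores *= mask == False

    assert d_scores[:, 1].sum() == 0 and d_scores[:, 4].sum() == 0
    assert d_scores[:, 0].sum() == nS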