Replace squeeze with flatten

At a few points in the code it's normal to get a "2d" array where each
row is a single entry. Calling squeeze will make that a proper 1d
array... unless it's just one entry, in which case it turns into a 0d
scalar. That's not what we want; flatten() provides the desired
behavior.
This commit is contained in:
Paul O'Leary McCann 2021-06-12 19:48:01 +09:00
parent e728b0e45d
commit d71198ed36

View File

@ -240,7 +240,7 @@ def coarse_prune(
Mentions can contain other mentions, but candidate mentions cannot cross each other.
"""
rawscores, spanembeds = inputs
scores = rawscores.squeeze()
scores = rawscores.flatten()
mention_limit = model.attrs["mention_limit"]
# XXX: Issue here. Don't need docs to find crossing spans, but might for the limits.
# In old code the limit can be:
@ -287,7 +287,7 @@ def coarse_prune(
dYscores, dYembeds = dY
dXscores = model.ops.alloc1f(idxlen)
dXscores[selected] = dYscores.squeeze()
dXscores[selected] = dYscores.flatten()
dXvecs = model.ops.alloc2f(*vecshape)
dXvecs[selected] = dYembeds.vectors.data
@ -362,7 +362,7 @@ def ant_scorer_forward(
pw_prod, prod_back = pairwise_product(bilinear, dropout, cvecs, is_train)
# now calculate the pairwise mention scores
ms = mscores[offset:hi].squeeze()
ms = mscores[offset:hi].flatten()
pw_sum, pw_sum_back = pairwise_sum(ops, ms)
# make a mask so antecedents precede referrents