mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-19 20:52:23 +03:00
Replace squeeze with flatten
At a few points in the code it's normal to get a "2d" array where each row is a single entry. Calling squeeze will make that a proper 1d array... unless it's just one entry, in which case it turns into a 0d scalar. That's not what we want; flatten() provides the desired behavior.
This commit is contained in:
parent
e728b0e45d
commit
d71198ed36
|
@ -240,7 +240,7 @@ def coarse_prune(
|
|||
Mentions can contain other mentions, but candidate mentions cannot cross each other.
|
||||
"""
|
||||
rawscores, spanembeds = inputs
|
||||
scores = rawscores.squeeze()
|
||||
scores = rawscores.flatten()
|
||||
mention_limit = model.attrs["mention_limit"]
|
||||
# XXX: Issue here. Don't need docs to find crossing spans, but might for the limits.
|
||||
# In old code the limit can be:
|
||||
|
@ -287,7 +287,7 @@ def coarse_prune(
|
|||
dYscores, dYembeds = dY
|
||||
|
||||
dXscores = model.ops.alloc1f(idxlen)
|
||||
dXscores[selected] = dYscores.squeeze()
|
||||
dXscores[selected] = dYscores.flatten()
|
||||
|
||||
dXvecs = model.ops.alloc2f(*vecshape)
|
||||
dXvecs[selected] = dYembeds.vectors.data
|
||||
|
@ -362,7 +362,7 @@ def ant_scorer_forward(
|
|||
pw_prod, prod_back = pairwise_product(bilinear, dropout, cvecs, is_train)
|
||||
|
||||
# now calculate the pairwise mention scores
|
||||
ms = mscores[offset:hi].squeeze()
|
||||
ms = mscores[offset:hi].flatten()
|
||||
pw_sum, pw_sum_back = pairwise_sum(ops, ms)
|
||||
|
||||
# make a mask so antecedents precede referrents
|
||||
|
|
Loading…
Reference in New Issue
Block a user