Minor cleanup

This commit is contained in:
Paul O'Leary McCann 2021-05-24 14:56:43 +09:00
parent 0942a0b51b
commit d6fd5fe1c0

View File

@ -170,13 +170,15 @@ def span_embeddings_forward(
# TODO support attention here
tokvecs = xp.concatenate(tokvecs)
spans = [tokvecs[ii:jj] for ii, jj in mentions.tolist()]
spans = [tokvecs[ii:jj] for ii, jj in mentions]
avgs = [xp.mean(ss, axis=0) for ss in spans]
spanvecs = ops.asarray2f(avgs)
# first and last token embeds
starts = [tokvecs[ii] for ii in mentions[:, 0]]
ends = [tokvecs[jj] for jj in mentions[:, 1]]
# XXX probably would be faster to get these at once
#starts = [tokvecs[ii] for ii in mentions[:, 0]]
#ends = [tokvecs[jj] for jj in mentions[:, 1]]
starts, ends = zip(*[(tokvecs[ii], tokvecs[jj]) for ii, jj in mentions])
starts = ops.asarray2f(starts)
ends = ops.asarray2f(ends)
@ -366,6 +368,7 @@ def ant_scorer_forward(
# make a mask so antecedents precede referrents
ant_range = xp.arange(0, cvecs.shape[0])
# TODO use python warning
# with xp.errstate(divide="ignore"):
# mask = xp.log(
# (xp.expand_dims(ant_range, 1) - xp.expand_dims(ant_range, 0)) >= 1