mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-19 20:52:23 +03:00
Minor cleanup
This commit is contained in:
parent
0942a0b51b
commit
d6fd5fe1c0
|
@ -170,13 +170,15 @@ def span_embeddings_forward(
|
|||
|
||||
# TODO support attention here
|
||||
tokvecs = xp.concatenate(tokvecs)
|
||||
spans = [tokvecs[ii:jj] for ii, jj in mentions.tolist()]
|
||||
spans = [tokvecs[ii:jj] for ii, jj in mentions]
|
||||
avgs = [xp.mean(ss, axis=0) for ss in spans]
|
||||
spanvecs = ops.asarray2f(avgs)
|
||||
|
||||
# first and last token embeds
|
||||
starts = [tokvecs[ii] for ii in mentions[:, 0]]
|
||||
ends = [tokvecs[jj] for jj in mentions[:, 1]]
|
||||
# XXX probably would be faster to get these at once
|
||||
#starts = [tokvecs[ii] for ii in mentions[:, 0]]
|
||||
#ends = [tokvecs[jj] for jj in mentions[:, 1]]
|
||||
starts, ends = zip(*[(tokvecs[ii], tokvecs[jj]) for ii, jj in mentions])
|
||||
|
||||
starts = ops.asarray2f(starts)
|
||||
ends = ops.asarray2f(ends)
|
||||
|
@ -366,6 +368,7 @@ def ant_scorer_forward(
|
|||
|
||||
# make a mask so antecedents precede referrents
|
||||
ant_range = xp.arange(0, cvecs.shape[0])
|
||||
# TODO use python warning
|
||||
# with xp.errstate(divide="ignore"):
|
||||
# mask = xp.log(
|
||||
# (xp.expand_dims(ant_range, 1) - xp.expand_dims(ant_range, 0)) >= 1
|
||||
|
|
Loading…
Reference in New Issue
Block a user