Silence warning

This commit is contained in:
Paul O'Leary McCann 2021-06-12 19:31:35 +09:00
parent 7efbc721a1
commit e728b0e45d

View File

@ -1,4 +1,5 @@
from dataclasses import dataclass
import warnings
from thinc.api import Model, Linear, Relu, Dropout, chain, noop
from thinc.types import Floats2d, Floats1d, Ints2d, Ragged
@ -366,14 +367,14 @@ def ant_scorer_forward(
# make a mask so antecedents precede referrents
ant_range = xp.arange(0, cvecs.shape[0])
# TODO use python warning
# with xp.errstate(divide="ignore"):
# mask = xp.log(
# (xp.expand_dims(ant_range, 1) - xp.expand_dims(ant_range, 0)) >= 1
# ).astype(float)
mask = xp.log(
(xp.expand_dims(ant_range, 1) - xp.expand_dims(ant_range, 0)) >= 1
).astype(float)
# This will take the log of 0, which causes a warning, but we're doing
# it on purpose so we can just ignore the warning.
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=RuntimeWarning)
mask = xp.log(
(xp.expand_dims(ant_range, 1) - xp.expand_dims(ant_range, 0)) >= 1
).astype(float)
scores = pw_prod + pw_sum + mask