mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-19 20:52:23 +03:00
Silence warning
This commit is contained in:
parent
7efbc721a1
commit
e728b0e45d
|
@ -1,4 +1,5 @@
|
|||
from dataclasses import dataclass
|
||||
import warnings
|
||||
|
||||
from thinc.api import Model, Linear, Relu, Dropout, chain, noop
|
||||
from thinc.types import Floats2d, Floats1d, Ints2d, Ragged
|
||||
|
@ -366,14 +367,14 @@ def ant_scorer_forward(
|
|||
|
||||
# make a mask so antecedents precede referrents
|
||||
ant_range = xp.arange(0, cvecs.shape[0])
|
||||
# TODO use python warning
|
||||
# with xp.errstate(divide="ignore"):
|
||||
# mask = xp.log(
|
||||
# (xp.expand_dims(ant_range, 1) - xp.expand_dims(ant_range, 0)) >= 1
|
||||
# ).astype(float)
|
||||
mask = xp.log(
|
||||
(xp.expand_dims(ant_range, 1) - xp.expand_dims(ant_range, 0)) >= 1
|
||||
).astype(float)
|
||||
|
||||
# This will take the log of 0, which causes a warning, but we're doing
|
||||
# it on purpose so we can just ignore the warning.
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore', category=RuntimeWarning)
|
||||
mask = xp.log(
|
||||
(xp.expand_dims(ant_range, 1) - xp.expand_dims(ant_range, 0)) >= 1
|
||||
).astype(float)
|
||||
|
||||
scores = pw_prod + pw_sum + mask
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user