mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-27 16:33:18 +03:00
Merge branch 'add/exclusive-spancat' of github.com:ljvmiranda921/spaCy into add/exclusive-spancat
This commit is contained in:
commit
60a8df7c5f
|
@ -1,5 +1,5 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, cast
|
from typing import Callable, Dict, Iterable, List, Optional, Tuple, cast
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
from thinc.api import Config, Model, Ops, Optimizer
|
from thinc.api import Config, Model, Ops, Optimizer
|
||||||
|
@ -370,7 +370,8 @@ class SpanCategorizerExclusive(TrainablePipe):
|
||||||
target[negative_samples, self._negative_label] = 1.0 # type: ignore
|
target[negative_samples, self._negative_label] = 1.0 # type: ignore
|
||||||
d_scores = scores - target
|
d_scores = scores - target
|
||||||
neg_weight = cast(float, self.cfg["negative_weight"])
|
neg_weight = cast(float, self.cfg["negative_weight"])
|
||||||
d_scores[negative_samples] *= neg_weight
|
if neg_weight != 1.0:
|
||||||
|
d_scores[negative_samples] *= neg_weight
|
||||||
loss = float((d_scores**2).sum())
|
loss = float((d_scores**2).sum())
|
||||||
return loss, d_scores
|
return loss, d_scores
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user