mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 10:46:29 +03:00
Cleaner handling of unseen classes
This commit is contained in:
parent
07a3581ff8
commit
d765a4f8ee
|
@ -151,6 +151,13 @@ def forward(model, docs_moves, is_train):
|
|||
|
||||
def backprop_parser(d_states_d_scores):
|
||||
_, d_scores = d_states_d_scores
|
||||
if model.attrs.get("unseen_classes"):
|
||||
# If we have a negative gradient (i.e. the probability should
|
||||
# increase) on any classes we filtered out as unseen, mark
|
||||
# them as seen.
|
||||
for clas in set(model.attrs["unseen_classes"]):
|
||||
if (d_scores[:, clas] < 0).any():
|
||||
model.attrs["unseen_classes"].remove(clas)
|
||||
d_scores *= unseen_mask
|
||||
ids = ops.xp.concatenate(all_ids)
|
||||
statevecs = ops.xp.concatenate(all_statevecs)
|
||||
|
|
|
@ -267,12 +267,6 @@ class Parser(TrainablePipe):
|
|||
gZ = exp_gscores.sum(axis=1, keepdims=True)
|
||||
d_scores = exp_scores / Z
|
||||
d_scores[is_gold] -= exp_gscores / gZ
|
||||
if "unseen_classes" in model.attrs:
|
||||
for i in range(costs.shape[0]):
|
||||
for clas in range(costs.shape[1]):
|
||||
if costs[i, clas] <= best_costs[i, 0]:
|
||||
if clas in model.attrs["unseen_classes"]:
|
||||
model.attrs["unseen_classes"].remove(clas)
|
||||
return d_scores
|
||||
|
||||
def _get_costs_from_histories(self, examples, histories):
|
||||
|
|
Loading…
Reference in New Issue
Block a user