From d765a4f8ee81d4dacb41044344f35a5ed5972e05 Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Mon, 25 Oct 2021 22:34:29 +0200
Subject: [PATCH] Cleaner handling of unseen classes

---
 spacy/ml/tb_framework.py             | 7 +++++++
 spacy/pipeline/transition_parser.pyx | 6 ------
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/spacy/ml/tb_framework.py b/spacy/ml/tb_framework.py
index 9cb93c9a2..006d5a384 100644
--- a/spacy/ml/tb_framework.py
+++ b/spacy/ml/tb_framework.py
@@ -151,6 +151,13 @@ def forward(model, docs_moves, is_train):
 
     def backprop_parser(d_states_d_scores):
         _, d_scores = d_states_d_scores
+        if model.attrs.get("unseen_classes"):
+            # If we have a negative gradient (i.e. the probability should
+            # increase) on any classes we filtered out as unseen, mark
+            # them as seen.
+            for clas in set(model.attrs["unseen_classes"]):
+                if (d_scores[:, clas] < 0).any():
+                    model.attrs["unseen_classes"].remove(clas)
         d_scores *= unseen_mask
         ids = ops.xp.concatenate(all_ids)
         statevecs = ops.xp.concatenate(all_statevecs)
diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx
index 1bf2140ab..c86a32a12 100644
--- a/spacy/pipeline/transition_parser.pyx
+++ b/spacy/pipeline/transition_parser.pyx
@@ -267,12 +267,6 @@ class Parser(TrainablePipe):
         gZ = exp_gscores.sum(axis=1, keepdims=True)
         d_scores = exp_scores / Z
         d_scores[is_gold] -= exp_gscores / gZ
-        if "unseen_classes" in model.attrs:
-            for i in range(costs.shape[0]):
-                for clas in range(costs.shape[1]):
-                    if costs[i, clas] <= best_costs[i, 0]:
-                        if clas in model.attrs["unseen_classes"]:
-                            model.attrs["unseen_classes"].remove(clas)
         return d_scores
 
     def _get_costs_from_histories(self, examples, histories):