From 53931be9a18cb0c7dfda410ac667c49499244992 Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Mon, 22 Jun 2020 16:00:45 +0200
Subject: [PATCH] Replace unseen labels for parser

---
Notes:
    Review commentary only; this text sits after the "---" separator and
    is not part of the commit message or of the diff below.

    The hunk adds ArcEager._replace_unseen_labels(gold) and calls it once
    for every ArcEagerGold built for the non-final states, before n_steps
    is computed. The call site ignores the return value; the method
    rewrites gold.c.labels in place and also returns the same gold object.

    For each index i in range(gold.c.length) where
    is_head_unknown(&gold.c, i) is false, the label is looked up via
    self.strings[gold.c.labels[i]] and replaced when it is missing from
    the label collection that matches the arc direction:
      head > i  and label not in self.labels[LEFT]   -> self.strings["dep"]
      head < i  and label not in self.labels[RIGHT]  -> self.strings["dep"]
      head == i and label not in self.labels[BREAK]  -> self.strings["ROOT"]
    Entries whose head is unknown are left untouched.

    NOTE(review): self.strings is used in both directions here (str ->
    value stored in gold.c.labels, stored value -> str compared against
    self.labels[...]); presumably a StringStore-style two-way table --
    confirm.
    NOTE(review): nothing in this hunk checks that "dep" / "ROOT" are
    themselves present in the corresponding self.labels[...] entries --
    TODO confirm that this is guaranteed elsewhere.

 spacy/syntax/arc_eager.pyx | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx
index 58427a3a8..28787f97d 100644
--- a/spacy/syntax/arc_eager.pyx
+++ b/spacy/syntax/arc_eager.pyx
@@ -617,9 +617,29 @@ cdef class ArcEager(TransitionSystem):
         keeps = [i for i, s in enumerate(states) if not s.is_final()]
         states = [states[i] for i in keeps]
         golds = [ArcEagerGold(self, states[i], examples[i]) for i in keeps]
+        for gold in golds:
+            self._replace_unseen_labels(gold)
         n_steps = sum([len(s.queue) * 4 for s in states])
         return states, golds, n_steps
 
+    def _replace_unseen_labels(self, ArcEagerGold gold):
+        backoff_label = self.strings["dep"]
+        root_label = self.strings["ROOT"]
+        left_labels = self.labels[LEFT]
+        right_labels = self.labels[RIGHT]
+        break_labels = self.labels[BREAK]
+        for i in range(gold.c.length):
+            if not is_head_unknown(&gold.c, i):
+                head = gold.c.heads[i]
+                label = self.strings[gold.c.labels[i]]
+                if head > i and label not in left_labels:
+                    gold.c.labels[i] = backoff_label
+                elif head < i and label not in right_labels:
+                    gold.c.labels[i] = backoff_label
+                elif head == i and label not in break_labels:
+                    gold.c.labels[i] = root_label
+        return gold
+
     cdef Transition lookup_transition(self, object name_or_id) except *:
         if isinstance(name_or_id, int):
             return self.c[name_or_id]