mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-11 09:00:36 +03:00
Replace unseen labels for parser
This commit is contained in:
parent
c65f0ed8f6
commit
53931be9a1
|
@ -617,9 +617,29 @@ cdef class ArcEager(TransitionSystem):
|
||||||
keeps = [i for i, s in enumerate(states) if not s.is_final()]
|
keeps = [i for i, s in enumerate(states) if not s.is_final()]
|
||||||
states = [states[i] for i in keeps]
|
states = [states[i] for i in keeps]
|
||||||
golds = [ArcEagerGold(self, states[i], examples[i]) for i in keeps]
|
golds = [ArcEagerGold(self, states[i], examples[i]) for i in keeps]
|
||||||
|
for gold in golds:
|
||||||
|
self._replace_unseen_labels(gold)
|
||||||
n_steps = sum([len(s.queue) * 4 for s in states])
|
n_steps = sum([len(s.queue) * 4 for s in states])
|
||||||
return states, golds, n_steps
|
return states, golds, n_steps
|
||||||
|
|
||||||
|
def _replace_unseen_labels(self, ArcEagerGold gold):
    """Overwrite gold labels that are missing from this system's label tables.

    For every token in ``gold.c`` whose head is known, the token's label is
    checked against one of ``self.labels[LEFT]``, ``self.labels[RIGHT]`` or
    ``self.labels[BREAK]``, selected by comparing the head index with the
    token index. A label that is not found is replaced in place:

    * head index greater than the token index -> checked against
      ``self.labels[LEFT]``, replaced by the ``"dep"`` entry.
    * head index smaller than the token index -> checked against
      ``self.labels[RIGHT]``, replaced by the ``"dep"`` entry.
    * head index equal to the token index -> checked against
      ``self.labels[BREAK]``, replaced by the ``"ROOT"`` entry.

    gold (ArcEagerGold): The gold-standard record to patch; ``gold.c.labels``
        is modified in place.
    RETURNS (ArcEagerGold): The same ``gold`` object that was passed in.
    """
    # Replacement values are fetched once from the string table.
    dep_fallback = self.strings["dep"]
    root_fallback = self.strings["ROOT"]
    known_left = self.labels[LEFT]
    known_right = self.labels[RIGHT]
    known_break = self.labels[BREAK]
    for idx in range(gold.c.length):
        # Tokens without a known head are left untouched.
        if is_head_unknown(&gold.c, idx):
            continue
        head = gold.c.heads[idx]
        # NOTE(review): presumably maps the stored label id back to its
        # string form so it can be compared with the label tables -- confirm.
        label = self.strings[gold.c.labels[idx]]
        if head > idx:
            if label not in known_left:
                gold.c.labels[idx] = dep_fallback
        elif head < idx:
            if label not in known_right:
                gold.c.labels[idx] = dep_fallback
        elif label not in known_break:
            # Only reachable when head == idx.
            gold.c.labels[idx] = root_fallback
    return gold
|
||||||
|
|
||||||
cdef Transition lookup_transition(self, object name_or_id) except *:
|
cdef Transition lookup_transition(self, object name_or_id) except *:
|
||||||
if isinstance(name_or_id, int):
|
if isinstance(name_or_id, int):
|
||||||
return self.c[name_or_id]
|
return self.c[name_or_id]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user