mirror of
https://github.com/explosion/spaCy.git
synced 2025-03-13 07:55:49 +03:00
Fix arc_eager oracle
This commit is contained in:
parent
a68d0e63f0
commit
420a986d15
|
@ -200,7 +200,6 @@ cdef class ArcEagerGold:
|
||||||
sent_starts = example.get_aligned("SENT_START")
|
sent_starts = example.get_aligned("SENT_START")
|
||||||
assert len(heads) == len(labels) == len(sent_starts)
|
assert len(heads) == len(labels) == len(sent_starts)
|
||||||
self.c = create_gold_state(self.mem, stcls, heads, labels, sent_starts)
|
self.c = create_gold_state(self.mem, stcls, heads, labels, sent_starts)
|
||||||
self.update(stcls)
|
|
||||||
|
|
||||||
def update(self, StateClass stcls):
|
def update(self, StateClass stcls):
|
||||||
update_gold_state(&self.c, stcls)
|
update_gold_state(&self.c, stcls)
|
||||||
|
@ -577,17 +576,12 @@ cdef class ArcEager(TransitionSystem):
|
||||||
def is_gold_parse(self, StateClass state, gold):
|
def is_gold_parse(self, StateClass state, gold):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def has_gold(self, gold, start=0, end=None):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def preprocess_gold(self, example):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def init_gold_batch(self, examples):
|
def init_gold_batch(self, examples):
|
||||||
|
examples = [eg for eg in examples if self.has_gold(eg)]
|
||||||
states = self.init_batch([eg.predicted for eg in examples])
|
states = self.init_batch([eg.predicted for eg in examples])
|
||||||
keeps = [i for i, s in enumerate(states) if not s.is_final()]
|
keeps = [i for i, s in enumerate(states) if not s.is_final()]
|
||||||
states = [states[i] for i in keeps]
|
|
||||||
golds = [ArcEagerGold(self, states[i], examples[i]) for i in keeps]
|
golds = [ArcEagerGold(self, states[i], examples[i]) for i in keeps]
|
||||||
|
states = [states[i] for i in keeps]
|
||||||
for gold in golds:
|
for gold in golds:
|
||||||
self._replace_unseen_labels(gold)
|
self._replace_unseen_labels(gold)
|
||||||
n_steps = sum([len(s.queue) * 4 for s in states])
|
n_steps = sum([len(s.queue) * 4 for s in states])
|
||||||
|
@ -690,6 +684,9 @@ cdef class ArcEager(TransitionSystem):
|
||||||
doc.is_parsed = True
|
doc.is_parsed = True
|
||||||
set_children_from_heads(doc.c, doc.length)
|
set_children_from_heads(doc.c, doc.length)
|
||||||
|
|
||||||
|
def has_gold(self, Example eg):
|
||||||
|
return eg.y.is_parsed
|
||||||
|
|
||||||
cdef int set_valid(self, int* output, const StateC* st) nogil:
|
cdef int set_valid(self, int* output, const StateC* st) nogil:
|
||||||
cdef bint[N_MOVES] is_valid
|
cdef bint[N_MOVES] is_valid
|
||||||
is_valid[SHIFT] = Shift.is_valid(st, 0)
|
is_valid[SHIFT] = Shift.is_valid(st, 0)
|
||||||
|
@ -736,21 +733,29 @@ cdef class ArcEager(TransitionSystem):
|
||||||
raise ValueError
|
raise ValueError
|
||||||
|
|
||||||
def get_oracle_sequence(self, Example example):
|
def get_oracle_sequence(self, Example example):
|
||||||
|
cdef StateClass state
|
||||||
|
cdef ArcEagerGold gold
|
||||||
|
states, golds, n_steps = self.init_gold_batch([example])
|
||||||
|
if not golds:
|
||||||
|
return []
|
||||||
|
|
||||||
cdef Pool mem = Pool()
|
cdef Pool mem = Pool()
|
||||||
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
|
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
|
||||||
assert self.n_moves > 0
|
assert self.n_moves > 0
|
||||||
costs = <float*>mem.alloc(self.n_moves, sizeof(float))
|
costs = <float*>mem.alloc(self.n_moves, sizeof(float))
|
||||||
is_valid = <int*>mem.alloc(self.n_moves, sizeof(int))
|
is_valid = <int*>mem.alloc(self.n_moves, sizeof(int))
|
||||||
|
|
||||||
cdef StateClass state
|
|
||||||
cdef ArcEagerGold gold
|
|
||||||
states, golds, n_steps = self.init_gold_batch([example])
|
|
||||||
state = states[0]
|
state = states[0]
|
||||||
gold = golds[0]
|
gold = golds[0]
|
||||||
history = []
|
history = []
|
||||||
debug_log = []
|
debug_log = []
|
||||||
|
failed = False
|
||||||
while not state.is_final():
|
while not state.is_final():
|
||||||
self.set_costs(is_valid, costs, state, gold)
|
try:
|
||||||
|
self.set_costs(is_valid, costs, state, gold)
|
||||||
|
except ValueError:
|
||||||
|
failed = True
|
||||||
|
break
|
||||||
for i in range(self.n_moves):
|
for i in range(self.n_moves):
|
||||||
if is_valid[i] and costs[i] <= 0:
|
if is_valid[i] and costs[i] <= 0:
|
||||||
action = self.c[i]
|
action = self.c[i]
|
||||||
|
@ -766,36 +771,39 @@ cdef class ArcEager(TransitionSystem):
|
||||||
action.do(state.c, action.label)
|
action.do(state.c, action.label)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
print("Actions")
|
failed = False
|
||||||
for i in range(self.n_moves):
|
break
|
||||||
print(self.get_class_name(i))
|
if failed:
|
||||||
print("Gold")
|
print("Actions")
|
||||||
for token in example.y:
|
for i in range(self.n_moves):
|
||||||
print(token.i, token.text, token.dep_, token.head.text)
|
print(self.get_class_name(i))
|
||||||
aligned_heads, aligned_labels = example.get_aligned_parse()
|
print("Gold")
|
||||||
print("Aligned heads")
|
for token in example.y:
|
||||||
for i, head in enumerate(aligned_heads):
|
print(token.i, token.text, token.dep_, token.head.text)
|
||||||
print(example.x[i], example.x[head] if head is not None else "__")
|
aligned_heads, aligned_labels = example.get_aligned_parse()
|
||||||
|
print("Aligned heads")
|
||||||
|
for i, head in enumerate(aligned_heads):
|
||||||
|
print(example.x[i], example.x[head] if head is not None else "__")
|
||||||
|
|
||||||
print("Predicted tokens")
|
print("Predicted tokens")
|
||||||
print([(w.i, w.text) for w in example.x])
|
print([(w.i, w.text) for w in example.x])
|
||||||
s0 = state.S(0)
|
s0 = state.S(0)
|
||||||
b0 = state.B(0)
|
b0 = state.B(0)
|
||||||
debug_log.append(" ".join((
|
debug_log.append(" ".join((
|
||||||
"?",
|
"?",
|
||||||
"S0=", (example.x[s0].text if s0 >= 0 else "-"),
|
"S0=", (example.x[s0].text if s0 >= 0 else "-"),
|
||||||
"B0=", (example.x[b0].text if b0 >= 0 else "-"),
|
"B0=", (example.x[b0].text if b0 >= 0 else "-"),
|
||||||
"S0 head?", str(state.has_head(state.S(0))),
|
"S0 head?", str(state.has_head(state.S(0))),
|
||||||
)))
|
)))
|
||||||
s0 = state.S(0)
|
s0 = state.S(0)
|
||||||
b0 = state.B(0)
|
b0 = state.B(0)
|
||||||
print("\n".join(debug_log))
|
print("\n".join(debug_log))
|
||||||
print("Arc is gold B0, S0?", arc_is_gold(&gold.c, b0, s0))
|
print("Arc is gold B0, S0?", arc_is_gold(&gold.c, b0, s0))
|
||||||
print("Arc is gold S0, B0?", arc_is_gold(&gold.c, s0, b0))
|
print("Arc is gold S0, B0?", arc_is_gold(&gold.c, s0, b0))
|
||||||
print("is_head_unknown(s0)", is_head_unknown(&gold.c, s0))
|
print("is_head_unknown(s0)", is_head_unknown(&gold.c, s0))
|
||||||
print("is_head_unknown(b0)", is_head_unknown(&gold.c, b0))
|
print("is_head_unknown(b0)", is_head_unknown(&gold.c, b0))
|
||||||
print("b0", b0, "gold.heads[s0]", gold.c.heads[s0])
|
print("b0", b0, "gold.heads[s0]", gold.c.heads[s0])
|
||||||
print("Stack", [example.x[i] for i in state.stack])
|
print("Stack", [example.x[i] for i in state.stack])
|
||||||
print("Buffer", [example.x[i] for i in state.queue])
|
print("Buffer", [example.x[i] for i in state.queue])
|
||||||
raise ValueError(Errors.E024)
|
raise ValueError(Errors.E024)
|
||||||
return history
|
return history
|
||||||
|
|
Loading…
Reference in New Issue
Block a user