From ecf192aa705b740484e0110288faedd5bdfcb41d Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 21 Jun 2020 17:17:34 +0200 Subject: [PATCH] Use get_aligned_parse in ArcEager --- spacy/syntax/arc_eager.pyx | 44 +++++++++++++++----------------------- 1 file changed, 17 insertions(+), 27 deletions(-) diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index c7ecbceea..3d9071bcb 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -75,31 +75,20 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp gs.n_kids_in_buffer = mem.alloc(gs.length, sizeof(gs.n_kids_in_buffer[0])) gs.n_kids_in_stack = mem.alloc(gs.length, sizeof(gs.n_kids_in_stack[0])) - cand_to_gold = example.alignment.cand_to_gold - gold_to_cand = example.alignment.cand_to_gold + heads, labels = example.get_aligned_parse(projectivize=True) cdef TokenC ref_tok - for cand_i in range(example.x.length): - gold_i = cand_to_gold[cand_i] - if gold_i is not None: # Alignment found - ref_tok = example.y.c[gold_i] - gold_head = gold_to_cand[gold_i + ref_tok.head] - if gold_head is not None: - gs.heads[cand_i] = gold_head - gs.labels[cand_i] = ref_tok.dep - gs.state_bits[cand_i] = set_state_flag( - gs.state_bits[cand_i], - HEAD_UNKNOWN, - 0 - ) - else: - gs.state_bits[cand_i] = set_state_flag( - gs.state_bits[cand_i], - HEAD_UNKNOWN, - 1 - ) + for i, (head, label) in enumerate(zip(heads, labels)): + if head is not None: + gs.heads[i] = head + gs.labels[i] = example.x.vocab.strings.add(label) + gs.state_bits[i] = set_state_flag( + gs.state_bits[i], + HEAD_UNKNOWN, + 0 + ) else: - gs.state_bits[cand_i] = set_state_flag( - gs.state_bits[cand_i], + gs.state_bits[i] = set_state_flag( + gs.state_bits[i], HEAD_UNKNOWN, 1 ) @@ -529,10 +518,11 @@ cdef class ArcEager(TransitionSystem): for label in kwargs.get('right_labels', []): actions[RIGHT][label] = 1 actions[REDUCE][label] = 1 - for example in kwargs.get('gold_parses', []): - heads, labels = nonproj.projectivize(example.get_aligned("HEAD"), - example.get_aligned("DEP")) - for child, head, label in zip(example.get_aligned("ID"), heads, labels): + for example in kwargs.get('examples', []): + heads, labels = example.get_aligned_parse(projectivize=True) + for child, (head, label) in enumerate(zip(heads, labels)): + if head is None or label is None: + continue if label.upper() == 'ROOT' : label = 'ROOT' if head == child: