diff --git a/spacy/syntax/_parser_model.pyx b/spacy/syntax/_parser_model.pyx index 60d22a1ab..f7b8bc266 100644 --- a/spacy/syntax/_parser_model.pyx +++ b/spacy/syntax/_parser_model.pyx @@ -16,7 +16,6 @@ from thinc.api import Linear, Model, CupyOps, NumpyOps, use_ops, noop from ..typedefs cimport weight_t, class_t, hash_t from ..tokens.doc cimport Doc -from ..gold cimport GoldParse from .stateclass cimport StateClass from .transition_system cimport Transition diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 19be95f3f..47dcee156 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -8,7 +8,6 @@ import json from ..typedefs cimport hash_t, attr_t from ..strings cimport hash_string -from ..gold cimport GoldParse, GoldParseC from ..structs cimport TokenC from ..tokens.doc cimport Doc, set_children_from_heads from .stateclass cimport StateClass @@ -49,40 +48,75 @@ MOVE_NAMES[RIGHT] = 'R' MOVE_NAMES[BREAK] = 'B' +cdef enum: + HEAD_ON_STACK = 0 + HEAD_IN_BUFFER + IS_SENT_START + HEAD_UNKNOWN + + +cdef struct GoldParseStateC: + char* state_bits + attr_t* labels + int32_t* heads + int32_t* n_kids_in_buffer + int32_t* n_kids_on_stack + int32_t length + int32_t stride + + +cdef int check_state_flag(char state_bits, char flag) nogil: + cdef char one = 1 + return state_bits & (one << flag) + + +cdef int set_state_flag(char state_bits, char flag, int value) nogil: + cdef char one = 1 + if value: + return state_bits | (one << flag) + else: + return state_bits & ~(one << flag) + + +cdef int is_head_on_stack(GoldParseStateC gold, int i) nogil: + return check_state_flag(gold.state_bits[i], HEAD_ON_STACK) + + +cdef int is_head_in_buffer(GoldParseStateC gold, int i) nogil: + return check_state_flag(gold.state_bits[i], HEAD_IN_BUFFER) + + +cdef int is_sent_start(GoldParseStateC gold, int i) nogil: + return check_state_flag(gold.state_bits[i], IS_SENT_START) + + +cdef int is_head_unknown(GoldParseStateC gold, int i) nogil: + return
check_state_flag(gold.state_bits[i], HEAD_UNKNOWN) + + # Helper functions for the arc-eager oracle -cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) nogil: +cdef weight_t push_cost(StateClass stcls, const GoldParseStateC* gold, int target) nogil: cdef weight_t cost = 0 - cdef int i, S_i - for i in range(stcls.stack_depth()): - S_i = stcls.S(i) - if gold.heads[target] == S_i: - cost += 1 - if gold.heads[S_i] == target and (NON_MONOTONIC or not stcls.has_head(S_i)): - cost += 1 - if BINARY_COSTS and cost >= 1: - return cost - cost += Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0 - return cost - - -cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil: - cdef weight_t cost = 0 - cdef int i, B_i - for i in range(stcls.buffer_length()): - B_i = stcls.B(i) - cost += gold.heads[B_i] == target - cost += gold.heads[target] == B_i - if gold.heads[B_i] == B_i or gold.heads[B_i] < target: - break - if BINARY_COSTS and cost >= 1: - return cost + if is_head_on_stack(gold[0], target): + cost += 1 + cost += gold.n_kids_on_stack[target] if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0: cost += 1 return cost -cdef weight_t arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) nogil: +cdef weight_t pop_cost(StateClass stcls, const GoldParseStateC* gold, int target) nogil: + cdef weight_t cost = 0 + if is_head_in_buffer(gold[0], target): + cost += 1 + cost += gold[0].n_kids_in_buffer[target] + if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0: + cost += 1 + return cost + + +cdef weight_t arc_cost(StateClass stcls, const GoldParseStateC* gold, int head, int child) nogil: if arc_is_gold(gold, head, child): return 0 elif stcls.H(child) == gold.heads[child]: @@ -94,8 +128,8 @@ cdef weight_t arc_cost(StateClass stcls, const GoldParseC* gold, int head, int c return 0 -cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) nogil: - if not
gold.has_dep[child]: +cdef bint arc_is_gold(const GoldParseStateC* gold, int head, int child) nogil: + if is_head_unknown(gold[0], child): return True elif gold.heads[child] == head: return True @@ -103,8 +137,8 @@ cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) nogil: return False -cdef bint label_is_gold(const GoldParseC* gold, int head, int child, attr_t label) nogil: - if not gold.has_dep[child]: +cdef bint label_is_gold(const GoldParseStateC* gold, int head, int child, attr_t label) nogil: + if is_head_unknown(gold[0], child): return True elif label == 0: return True @@ -114,8 +148,9 @@ cdef bint label_is_gold(const GoldParseC* gold, int head, int child, attr_t labe return False -cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil: - return gold.heads[word] == word or not gold.has_dep[word] +cdef bint _is_gold_root(const GoldParseStateC* gold, int word) nogil: + return gold.heads[word] == word or is_head_unknown(gold[0], word) + cdef class Shift: @staticmethod @@ -129,15 +164,16 @@ cdef class Shift: st.fast_forward() @staticmethod - cdef weight_t cost(StateClass st, const GoldParseC* gold, attr_t label) nogil: + cdef weight_t cost(StateClass st, const void* _gold, attr_t label) nogil: + gold = <const GoldParseStateC*>_gold return Shift.move_cost(st, gold) + Shift.label_cost(st, gold, label) @staticmethod - cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil: + cdef inline weight_t move_cost(StateClass s, const GoldParseStateC* gold) nogil: return push_cost(s, gold, s.B(0)) @staticmethod - cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, attr_t label) nogil: + cdef inline weight_t label_cost(StateClass s, const GoldParseStateC* gold, attr_t label) nogil: return 0 @@ -155,26 +191,27 @@ cdef class Reduce: st.fast_forward() @staticmethod - cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil: + cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil: + gold = <const GoldParseStateC*>_gold return
Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label) @staticmethod - cdef inline weight_t move_cost(StateClass st, const GoldParseC* gold) nogil: - cost = pop_cost(st, gold, st.S(0)) - if not st.has_head(st.S(0)): - # Decrement cost for the arcs e save - for i in range(1, st.stack_depth()): - S_i = st.S(i) - if gold.heads[st.S(0)] == S_i: - cost -= 1 - if gold.heads[S_i] == st.S(0): - cost -= 1 + cdef inline weight_t move_cost(StateClass st, const GoldParseStateC* gold) nogil: + s0 = st.S(0) + cost = pop_cost(st, gold, s0) + return_to_buffer = not st.has_head(s0) + if return_to_buffer: + # Decrement cost for the arcs we save, as we'll be putting this + # back to the buffer + if is_head_on_stack(gold[0], s0): + cost -= 1 + cost -= gold.n_kids_on_stack[s0] if Break.is_valid(st.c, 0) and Break.move_cost(st, gold) == 0: cost -= 1 return cost @staticmethod - cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, attr_t label) nogil: + cdef inline weight_t label_cost(StateClass s, const GoldParseStateC* gold, attr_t label) nogil: return 0 @@ -193,49 +230,12 @@ cdef class LeftArc: st.fast_forward() @staticmethod - cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil: - return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label) - - @staticmethod - cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil: - cdef weight_t cost = 0 - if arc_is_gold(gold, s.B(0), s.S(0)): - # Have a negative cost if we 'recover' from the wrong dependency - return 0 if not s.has_head(s.S(0)) else -1 - else: - # Account for deps we might lose between S0 and stack - if not s.has_head(s.S(0)): - for i in range(1, s.stack_depth()): - cost += gold.heads[s.S(i)] == s.S(0) - cost += gold.heads[s.S(0)] == s.S(i) - return cost + pop_cost(s, gold, s.S(0)) + arc_cost(s, gold, s.B(0), s.S(0)) - - @staticmethod - cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, attr_t label) nogil: - return arc_is_gold(gold,
s.B(0), s.S(0)) and not label_is_gold(gold, s.B(0), s.S(0), label) - - -cdef class RightArc: - @staticmethod - cdef bint is_valid(const StateC* st, attr_t label) nogil: - # If there's (perhaps partial) parse pre-set, don't allow cycle. - if label == SUBTOK_LABEL and st.S(0) != (st.B(0)-1): - return 0 - sent_start = st._sent[st.B_(0).l_edge].sent_start - return sent_start != 1 and st.H(st.S(0)) != st.B(0) - - @staticmethod - cdef int transition(StateC* st, attr_t label) nogil: - st.add_arc(st.S(0), st.B(0), label) - st.push() - st.fast_forward() - - @staticmethod - cdef inline weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil: + cdef inline weight_t cost(StateClass s, const void* _gold, attr_t label) nogil: + gold = <const GoldParseStateC*>_gold return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label) @staticmethod - cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil: + cdef inline weight_t move_cost(StateClass s, const GoldParseStateC* gold) nogil: if arc_is_gold(gold, s.S(0), s.B(0)): return 0 elif s.c.shifted[s.B(0)]: @@ -244,7 +244,7 @@ cdef class RightArc: return push_cost(s, gold, s.B(0)) + arc_cost(s, gold, s.S(0), s.B(0)) @staticmethod - cdef weight_t label_cost(StateClass s, const GoldParseC* gold, attr_t label) nogil: + cdef weight_t label_cost(StateClass s, const GoldParseStateC* gold, attr_t label) nogil: return arc_is_gold(gold, s.S(0), s.B(0)) and not label_is_gold(gold, s.S(0), s.B(0), label) @@ -271,21 +271,19 @@ cdef class Break: st.fast_forward() @staticmethod - cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil: + cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil: + gold = <const GoldParseStateC*>_gold return Break.move_cost(s, gold) + Break.label_cost(s, gold, label) @staticmethod - cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil: + cdef inline weight_t move_cost(StateClass s, const GoldParseStateC* gold) nogil: cdef weight_t cost = 0 cdef int i, j, S_i, B_i
for i in range(s.stack_depth()): S_i = s.S(i) - for j in range(s.buffer_length()): - B_i = s.B(j) - cost += gold.heads[S_i] == B_i - cost += gold.heads[B_i] == S_i - if cost != 0: - return cost + cost += gold.n_kids_in_buffer[S_i] + if is_head_in_buffer(gold[0], S_i): + cost += 1 # Check for sentence boundary --- if it's here, we can't have any deps # between stack and buffer, so rest of action is irrelevant. s0_root = _get_root(s.S(0), gold) @@ -296,14 +294,16 @@ cdef class Break: return cost + 1 @staticmethod - cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, attr_t label) nogil: + cdef inline weight_t label_cost(StateClass s, const GoldParseStateC* gold, attr_t label) nogil: return 0 -cdef int _get_root(int word, const GoldParseC* gold) nogil: - while gold.heads[word] != word and gold.has_dep[word] and word >= 0: - word = gold.heads[word] - if not gold.has_dep[word]: +cdef int _get_root(int word, const GoldParseStateC* gold) nogil: + if is_head_unknown(gold[0], word): return -1 + while gold.heads[word] != word and word >= 0: + word = gold.heads[word] + if is_head_unknown(gold[0], word): + return -1 else: return word @@ -378,86 +378,22 @@ cdef class ArcEager(TransitionSystem): def action_types(self): return (SHIFT, REDUCE, LEFT, RIGHT, BREAK) - def get_cost(self, StateClass state, GoldParse gold, action): - cdef Transition t = self.lookup_transition(action) - if not t.is_valid(state.c, t.label): - return 9000 - else: - return t.get_cost(state, &gold.c, t.label) + def get_cost(self, StateClass state, NewExample gold, action): + raise NotImplementedError def transition(self, StateClass state, action): cdef Transition t = self.lookup_transition(action) t.do(state.c, t.label) return state - def is_gold_parse(self, StateClass state, GoldParse gold): - predicted = set() - truth = set() - for i in range(gold.length): - if gold.cand_to_gold[i] is None: - continue - if state.safe_get(i).dep: - predicted.add((i, state.H(i), -
self.strings[state.safe_get(i).dep])) - else: - predicted.add((i, state.H(i), 'ROOT')) - id_ = gold.orig.ids[gold.cand_to_gold[i]] - head = gold.orig.heads[gold.cand_to_gold[i]] - dep = gold.orig.deps[gold.cand_to_gold[i]] - truth.add((id_, head, dep)) - return truth == predicted + def is_gold_parse(self, StateClass state, gold): + raise NotImplementedError - def has_gold(self, GoldParse gold, start=0, end=None): - end = end or len(gold.heads) - if all([tag is None for tag in gold.heads[start:end]]): - return False - else: - return True + def has_gold(self, gold, start=0, end=None): + raise NotImplementedError - def preprocess_gold(self, GoldParse gold): - if not self.has_gold(gold): - return None - # Figure out whether we're using subtok - use_subtok = False - for action, labels in self.labels.items(): - if SUBTOK_LABEL in labels: - use_subtok = True - break - for i, (head, dep) in enumerate(zip(gold.heads, gold.labels)): - # Missing values - if head is None or dep is None: - gold.c.heads[i] = i - gold.c.has_dep[i] = False - elif dep == SUBTOK_LABEL and not use_subtok: - # If we're not doing the joint tokenization and parsing, - # regard these subtok labels as missing - gold.c.heads[i] = i - gold.c.labels[i] = 0 - gold.c.has_dep[i] = False - else: - if head > i: - action = LEFT - elif head < i: - action = RIGHT - else: - action = BREAK - if dep not in self.labels[action]: - if action == BREAK: - dep = 'ROOT' - elif nonproj.is_decorated(dep): - backoff = nonproj.decompose(dep)[0] - if backoff in self.labels[action]: - dep = backoff - else: - dep = 'dep' - else: - dep = 'dep' - gold.c.has_dep[i] = True - if dep.upper() == 'ROOT': - dep = 'ROOT' - gold.c.heads[i] = head - gold.c.labels[i] = self.strings.add(dep) - return gold + def preprocess_gold(self, gold): + raise NotImplementedError def get_beam_parses(self, Beam beam): parses = [] @@ -569,7 +505,9 @@ cdef class ArcEager(TransitionSystem): output[i] = is_valid[self.c[i].move] cdef int set_costs(self, int* 
is_valid, weight_t* costs, - StateClass stcls, GoldParse gold) except -1: + StateClass stcls, NewExample example) except -1: + cdef Pool mem = Pool() + gold_state = create_gold_state(mem, stcls, example) cdef int i, move cdef attr_t label cdef label_cost_func_t[N_MOVES] label_cost_funcs @@ -599,8 +537,8 @@ cdef class ArcEager(TransitionSystem): move = self.c[i].move label = self.c[i].label if move_costs[move] == 9000: - move_costs[move] = move_cost_funcs[move](stcls, &gold.c) - costs[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold.c, label) + move_costs[move] = move_cost_funcs[move](stcls, gold_state) + costs[i] = move_costs[move] + label_cost_funcs[move](stcls, gold_state, label) n_gold += costs[i] <= 0 else: is_valid[i] = False diff --git a/spacy/syntax/ner.pxd b/spacy/syntax/ner.pxd index 647f98fc0..989593a92 100644 --- a/spacy/syntax/ner.pxd +++ b/spacy/syntax/ner.pxd @@ -1,6 +1,5 @@ from .transition_system cimport TransitionSystem from .transition_system cimport Transition -from ..gold cimport GoldParseC from ..typedefs cimport attr_t diff --git a/spacy/syntax/ner.pyx b/spacy/syntax/ner.pyx index ff74be601..60deaad28 100644 --- a/spacy/syntax/ner.pyx +++ b/spacy/syntax/ner.pyx @@ -7,11 +7,11 @@ from .stateclass cimport StateClass from ._state cimport StateC from .transition_system cimport Transition from .transition_system cimport do_func_t -from ..gold cimport GoldParseC, GoldParse from ..lexeme cimport Lexeme from ..attrs cimport IS_SPACE from ..errors import Errors +from .gold_parse cimport GoldParseC cdef enum: @@ -91,19 +91,11 @@ cdef class BiluoPushDown(TransitionSystem): else: return MOVE_NAMES[move] + '-' + self.strings[label] - def has_gold(self, GoldParse gold, start=0, end=None): - end = end or len(gold.ner) - if all([tag in ('-', None) for tag in gold.ner[start:end]]): - return False - else: - return True + def has_gold(self, gold, start=0, end=None): + raise NotImplementedError - def preprocess_gold(self, GoldParse gold): - if not 
self.has_gold(gold): - return None - for i in range(gold.length): - gold.c.ner[i] = self.lookup_transition(gold.ner[i]) - return gold + def preprocess_gold(self, gold): + raise NotImplementedError def get_beam_annot(self, Beam beam): entities = {} @@ -248,7 +240,7 @@ cdef class Missing: pass @staticmethod - cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil: + cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil: return 9000 @@ -300,7 +292,8 @@ cdef class Begin: st.pop() @staticmethod - cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil: + cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil: + gold = <const GoldParseC*>_gold cdef int g_act = gold.ner[s.B(0)].move cdef attr_t g_tag = gold.ner[s.B(0)].label @@ -363,7 +356,8 @@ cdef class In: st.pop() @staticmethod - cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil: + cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil: + gold = <const GoldParseC*>_gold move = IN cdef int next_act = gold.ner[s.B(1)].move if s.B(1) >= 0 else OUT cdef int g_act = gold.ner[s.B(0)].move @@ -429,7 +423,8 @@ cdef class Last: st.pop() @staticmethod - cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil: + cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil: + gold = <const GoldParseC*>_gold move = LAST cdef int g_act = gold.ner[s.B(0)].move @@ -497,7 +492,8 @@ cdef class Unit: st.pop() @staticmethod - cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil: + cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil: + gold = <const GoldParseC*>_gold cdef int g_act = gold.ner[s.B(0)].move cdef attr_t g_tag = gold.ner[s.B(0)].label @@ -537,7 +533,8 @@ cdef class Out: st.pop() @staticmethod - cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil: + cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil: + gold = <const GoldParseC*>_gold cdef int g_act = gold.ner[s.B(0)].move cdef
attr_t g_tag = gold.ner[s.B(0)].label diff --git a/spacy/syntax/transition_system.pxd b/spacy/syntax/transition_system.pxd index 14a0d7741..9720b98d8 100644 --- a/spacy/syntax/transition_system.pxd +++ b/spacy/syntax/transition_system.pxd @@ -5,6 +5,7 @@ from ..structs cimport TokenC from ..strings cimport StringStore from .stateclass cimport StateClass from ._state cimport StateC +from ..gold.new_example cimport NewExample cdef struct Transition: