diff --git a/spacy/syntax/ner.pyx b/spacy/syntax/ner.pyx index 04fbbefc1..4ee41779e 100644 --- a/spacy/syntax/ner.pyx +++ b/spacy/syntax/ner.pyx @@ -49,6 +49,10 @@ cdef class BiluoGold: self.mem = Pool() self.c = create_gold_state(self.mem, moves, stcls, example) + def update(self, StateClass stcls): + update_gold_state(&self.c, stcls) + + cdef GoldNERStateC create_gold_state( Pool mem, @@ -58,30 +62,15 @@ cdef GoldNERStateC create_gold_state( ) except *: cdef GoldNERStateC gs gs.ner = mem.alloc(example.x.length, sizeof(Transition)) - ner_tags = get_aligned_ner(example) + ner_tags = example.get_aligned_ner() for i, ner_tag in enumerate(ner_tags): gs.ner[i] = moves.lookup_transition(ner_tag) return gs -def get_aligned_ner(Example example): - cand_to_gold = example.alignment.cand_to_gold - i2j_multi = example.alignment.i2j_multi - y_tags = biluo_tags_from_offsets( - example.y, - [(e.start_char, e.end_char, e.label_) for e in example.y.ents] - ) - x_tags = [None] * example.x.length - for i in range(example.x.length): - if example.x[i].is_space: - pass - elif cand_to_gold[i] is not None: - x_tags[i] = y_tags[cand_to_gold[i]] - elif i in i2j_multi: - # Assign O/- for many-to-one O/- NER tags - if y_tags[i2j_multi[i]] in ("O", "-"): - x_tags[i] = y_tags[i2j_multi[i]] - return y_tags +cdef void update_gold_state(GoldNERStateC* gs, StateClass stcls) except *: + # We don't need to update each time, unlike the parser. + pass cdef do_func_t[N_MOVES] do_funcs @@ -120,11 +109,12 @@ cdef class BiluoPushDown(TransitionSystem): for action in (BEGIN, IN, LAST, UNIT): actions[action][entity_type] = 1 moves = ('M', 'B', 'I', 'L', 'U') - for example in kwargs.get('gold_parses', []): - for ner_tag in example.get_aligned("ENT_TYPE", as_string=True): - if ner_tag != 'O' and ner_tag != '-': + for example in kwargs.get('examples', []): + for token in example.y: + ent_type = token.ent_type_ + if ent_type: for action in (BEGIN, IN, LAST, UNIT): - actions[action][ner_tag] += 1 + actions[action][ent_type] += 1 return actions @property @@ -247,6 +237,37 @@ cdef class BiluoPushDown(TransitionSystem): self.add_action(UNIT, st._sent[i].ent_type) self.add_action(LAST, st._sent[i].ent_type) + def get_cost(self, StateClass stcls, gold, int i): + if not isinstance(gold, BiluoGold): + raise TypeError("Expected BiluoGold") + cdef BiluoGold gold_ = gold + gold_state = gold_.c + n_gold = 0 + if self.c[i].is_valid(stcls.c, self.c[i].label): + cost = self.c[i].get_cost(stcls, &gold_state, self.c[i].label) + else: + cost = 9000 + return cost + + cdef int set_costs(self, int* is_valid, weight_t* costs, + StateClass stcls, gold) except -1: + if not isinstance(gold, BiluoGold): + raise TypeError("Expected BiluoGold") + cdef BiluoGold gold_ = gold + gold_.update(stcls) + gold_state = gold_.c + n_gold = 0 + for i in range(self.n_moves): + if self.c[i].is_valid(stcls.c, self.c[i].label): + is_valid[i] = True + costs[i] = self.c[i].get_cost(stcls, &gold_state, self.c[i].label) + n_gold += costs[i] <= 0 + else: + is_valid[i] = False + costs[i] = 9000 + if n_gold < 1: + raise ValueError + cdef class Missing: @staticmethod