import os import random from libc.stdint cimport int32_t from libcpp.memory cimport shared_ptr from libcpp.vector cimport vector from cymem.cymem cimport Pool from collections import Counter from thinc.extra.search cimport Beam from ...tokens.doc cimport Doc from ...tokens.span import Span from ...tokens.span cimport Span from ...typedefs cimport weight_t, attr_t from ...lexeme cimport Lexeme from ...attrs cimport IS_SPACE from ...structs cimport TokenC, SpanC from ...training import split_bilu_label from ...training.example cimport Example from .stateclass cimport StateClass from ._state cimport StateC from .transition_system cimport Transition, do_func_t from ...errors import Errors cdef enum: MISSING BEGIN IN LAST UNIT OUT N_MOVES MOVE_NAMES = [None] * N_MOVES MOVE_NAMES[MISSING] = 'M' MOVE_NAMES[BEGIN] = 'B' MOVE_NAMES[IN] = 'I' MOVE_NAMES[LAST] = 'L' MOVE_NAMES[UNIT] = 'U' MOVE_NAMES[OUT] = 'O' cdef struct GoldNERStateC: Transition* ner vector[shared_ptr[SpanC]] negs cdef class BiluoGold: cdef Pool mem cdef GoldNERStateC c def __init__(self, BiluoPushDown moves, StateClass stcls, Example example, neg_key): self.mem = Pool() self.c = create_gold_state(self.mem, moves, stcls.c, example, neg_key) def update(self, StateClass stcls): update_gold_state(&self.c, stcls.c) cdef GoldNERStateC create_gold_state( Pool mem, BiluoPushDown moves, const StateC* stcls, Example example, neg_key ) except *: cdef GoldNERStateC gs cdef Span neg if neg_key is not None: negs = example.get_aligned_spans_y2x( example.y.spans.get(neg_key, []), allow_overlap=True ) else: negs = [] assert example.x.length > 0 gs.ner = mem.alloc(example.x.length, sizeof(Transition)) ner_ents, ner_tags = example.get_aligned_ents_and_ner() for i, ner_tag in enumerate(ner_tags): gs.ner[i] = moves.lookup_transition(ner_tag) # Prevent conflicting spans in the data. For NER, spans are equal if they have the same offsets and label. neg_span_triples = {(neg_ent.start_char, neg_ent.end_char, neg_ent.label) for neg_ent in negs} for pos_span in ner_ents: if (pos_span.start_char, pos_span.end_char, pos_span.label) in neg_span_triples: raise ValueError(Errors.E868.format(span=(pos_span.start_char, pos_span.end_char, pos_span.label_))) # In order to handle negative samples, we need to maintain the full # (start, end, label) triple. If we break it down to the 'isnt B-LOC' # thing, we'll get blocked if there's an incorrect prefix. for neg in negs: gs.negs.push_back(neg.c) return gs cdef void update_gold_state(GoldNERStateC* gs, const StateC* state) except *: # We don't need to update each time, unlike the parser. pass cdef do_func_t[N_MOVES] do_funcs cdef bint _entity_is_sunk(const StateC* state, Transition* golds) nogil: if not state.entity_is_open(): return False cdef const Transition* gold = &golds[state.E(0)] ent = state.get_ent() if gold.move != BEGIN and gold.move != UNIT: return True elif gold.label != ent.label: return True else: return False cdef class BiluoPushDown(TransitionSystem): def __init__(self, *args, **kwargs): TransitionSystem.__init__(self, *args, **kwargs) @classmethod def get_actions(cls, **kwargs): actions = { MISSING: Counter(), BEGIN: Counter(), IN: Counter(), LAST: Counter(), UNIT: Counter(), OUT: Counter() } actions[OUT][''] = 1 # Represents a token predicted to be outside of any entity actions[UNIT][''] = 1 # Represents a token prohibited to be in an entity for entity_type in kwargs.get('entity_types', []): for action in (BEGIN, IN, LAST, UNIT): actions[action][entity_type] = 1 moves = ('M', 'B', 'I', 'L', 'U') for example in kwargs.get('examples', []): for token in example.y: ent_type = token.ent_type_ if ent_type: for action in (BEGIN, IN, LAST, UNIT): actions[action][ent_type] += 1 return actions @property def action_types(self): return (BEGIN, IN, LAST, UNIT, OUT) def get_doc_labels(self, doc): labels = set() for token in doc: if token.ent_type: labels.add(token.ent_type_) return labels def move_name(self, int move, attr_t label): if move == OUT: return 'O' elif move == MISSING: return 'M' else: return MOVE_NAMES[move] + '-' + self.strings[label] def init_gold_batch(self, examples): all_states = self.init_batch([eg.predicted for eg in examples]) golds = [] states = [] for state, eg in zip(all_states, examples): if self.has_gold(eg) and not state.is_final(): golds.append(self.init_gold(state, eg)) states.append(state) n_steps = sum([len(s.queue) for s in states]) return states, golds, n_steps cdef Transition lookup_transition(self, object name) except *: cdef attr_t label if name == '-' or name == '' or name is None: return Transition(clas=0, move=MISSING, label=0, score=0) elif '-' in name: move_str, label_str = split_bilu_label(name) # Deprecated, hacky way to denote 'not this entity' if label_str.startswith('!'): raise ValueError(Errors.E869.format(label=name)) label = self.strings.add(label_str) else: move_str = name label = 0 move = MOVE_NAMES.index(move_str) for i in range(self.n_moves): if self.c[i].move == move and self.c[i].label == label: return self.c[i] raise KeyError(Errors.E022.format(name=name)) cdef Transition init_transition(self, int clas, int move, attr_t label) except *: # TODO: Apparent Cython bug here when we try to use the Transition() # constructor with the function pointers cdef Transition t t.score = 0 t.clas = clas t.move = move t.label = label if move == MISSING: t.is_valid = Missing.is_valid t.do = Missing.transition t.get_cost = Missing.cost elif move == BEGIN: t.is_valid = Begin.is_valid t.do = Begin.transition t.get_cost = Begin.cost elif move == IN: t.is_valid = In.is_valid t.do = In.transition t.get_cost = In.cost elif move == LAST: t.is_valid = Last.is_valid t.do = Last.transition t.get_cost = Last.cost elif move == UNIT: t.is_valid = Unit.is_valid t.do = Unit.transition t.get_cost = Unit.cost elif move == OUT: t.is_valid = Out.is_valid t.do = Out.transition t.get_cost = Out.cost else: raise ValueError(Errors.E019.format(action=move, src='ner')) return t def add_action(self, int action, label_name, freq=None): cdef attr_t label_id if not isinstance(label_name, (int, long)): label_id = self.strings.add(label_name) else: label_id = label_name if action == OUT and label_id != 0: return None if action == MISSING: return None # Check we're not creating a move we already have, so that this is # idempotent for trans in self.c[:self.n_moves]: if trans.move == action and trans.label == label_id: return 0 if self.n_moves >= self._size: self._size = self.n_moves self._size *= 2 self.c = self.mem.realloc(self.c, self._size * sizeof(self.c[0])) self.c[self.n_moves] = self.init_transition(self.n_moves, action, label_id) self.n_moves += 1 if self.labels.get(action, []): freq = min(0, min(self.labels[action].values())) self.labels[action][label_name] = freq-1 else: self.labels[action] = Counter() self.labels[action][label_name] = -1 return 1 def set_annotations(self, StateClass state, Doc doc): cdef int i ents = [] for i in range(state.c._ents.size()): ent = state.c._ents.at(i) if ent.start != -1 and ent.end != -1: ents.append(Span(doc, ent.start, ent.end, label=ent.label, kb_id=doc.c[ent.start].ent_kb_id)) doc.set_ents(ents, default="unmodified") # Set non-blocked tokens to O for i in range(doc.length): if doc.c[i].ent_iob == 0: doc.c[i].ent_iob = 2 def get_beam_parses(self, Beam beam): parses = [] probs = beam.probs for i in range(beam.size): state = beam.at(i) if state.is_final(): prob = probs[i] parse = [] for j in range(state._ents.size()): ent = state._ents.at(j) if ent.start != -1 and ent.end != -1: parse.append((ent.start, ent.end, self.strings[ent.label])) parses.append((prob, parse)) return parses def init_gold(self, StateClass state, Example example): return BiluoGold(self, state, example, self.neg_key) def has_gold(self, Example eg, start=0, end=None): # We get x and y referring to X, we want to check relative to Y, # the reference y_spans = eg.get_aligned_spans_x2y([eg.x[start:end]]) if not y_spans: y_spans = [eg.y[:]] y_span = y_spans[0] start = y_span.start end = y_span.end neg_key = self.neg_key if neg_key is not None: # If we have any negative samples, count that as having annotation. for span in eg.y.spans.get(neg_key, []): if span.start >= start and span.end <= end: return True for word in eg.y[start:end]: if word.ent_iob != 0: return True else: return False def get_cost(self, StateClass stcls, gold, int i): if not isinstance(gold, BiluoGold): raise TypeError(Errors.E909.format(name="BiluoGold")) cdef BiluoGold gold_ = gold gold_state = gold_.c n_gold = 0 if self.c[i].is_valid(stcls.c, self.c[i].label): cost = self.c[i].get_cost(stcls.c, &gold_state, self.c[i].label) else: cost = 9000 return cost cdef int set_costs(self, int* is_valid, weight_t* costs, const StateC* state, gold) except -1: if not isinstance(gold, BiluoGold): raise TypeError(Errors.E909.format(name="BiluoGold")) cdef BiluoGold gold_ = gold gold_state = gold_.c update_gold_state(&gold_state, state) n_gold = 0 self.set_valid(is_valid, state) for i in range(self.n_moves): if is_valid[i]: costs[i] = self.c[i].get_cost(state, &gold_state, self.c[i].label) n_gold += costs[i] <= 0 else: costs[i] = 9000 cdef class Missing: @staticmethod cdef bint is_valid(const StateC* st, attr_t label) nogil: return False @staticmethod cdef int transition(StateC* s, attr_t label) nogil: pass @staticmethod cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil: return 9000 cdef class Begin: @staticmethod cdef bint is_valid(const StateC* st, attr_t label) nogil: cdef int preset_ent_iob = st.B_(0).ent_iob cdef attr_t preset_ent_label = st.B_(0).ent_type if st.entity_is_open(): return False if st.buffer_length() < 2: # If we're the last token of the input, we can't B -- must U or O. return False elif label == 0: return False elif preset_ent_iob == 1: # Ensure we don't clobber preset entities. If no entity preset, # ent_iob is 0 return False elif preset_ent_iob == 3: # Okay, we're in a preset entity. if label != preset_ent_label: # If label isn't right, reject return False elif st.B_(1).ent_iob != 1: # If next token isn't marked I, we need to make U, not B. return False else: # Otherwise, force acceptance, even if we're across a sentence # boundary or the token is whitespace. return True elif st.B_(1).ent_iob == 3: # If the next word is B, we can't B now return False elif st.B_(1).sent_start == 1: # Don't allow entities to extend across sentence boundaries return False # Don't allow entities to start on whitespace elif Lexeme.get_struct_attr(st.B_(0).lex, IS_SPACE): return False else: return True @staticmethod cdef int transition(StateC* st, attr_t label) nogil: st.open_ent(label) st.push() st.pop() @staticmethod cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil: gold = _gold b0 = s.B(0) cdef int cost = 0 cdef int g_act = gold.ner[b0].move cdef attr_t g_tag = gold.ner[b0].label cdef shared_ptr[SpanC] span if g_act == MISSING: pass elif g_act == BEGIN: # B, Gold B --> Label match cost += label != g_tag else: # B, Gold I --> False (P) # B, Gold L --> False (P) # B, Gold O --> False (P) # B, Gold U --> False (P) cost += 1 if s.buffer_length() < 3: # Handle negatives. In general we can't really do much to block # B, because we don't know whether the whole entity is going to # be correct or not. However, we can at least tell whether we're # going to be opening an entity where there's only one possible # L. for span in gold.negs: if span.get().label == label and span.get().start == b0: cost += 1 break return cost cdef class In: @staticmethod cdef bint is_valid(const StateC* st, attr_t label) nogil: if not st.entity_is_open(): return False if st.buffer_length() < 2: # If we're at the end, we can't I. return False ent = st.get_ent() cdef int preset_ent_iob = st.B_(0).ent_iob cdef attr_t preset_ent_label = st.B_(0).ent_type if label == 0: return False elif ent.label != label: return False elif preset_ent_iob == 3: return False elif st.B_(1).ent_iob == 3: # If we know the next word is B, we can't be I (must be L) return False elif preset_ent_iob == 1: if st.B_(1).ent_iob in (0, 2): # if next preset is missing or O, this can't be I (must be L) return False elif label != preset_ent_label: # If label isn't right, reject return False else: # Otherwise, force acceptance, even if we're across a sentence # boundary or the token is whitespace. return True elif st.B(1) != -1 and st.B_(1).sent_start == 1: # Don't allow entities to extend across sentence boundaries return False else: return True @staticmethod cdef int transition(StateC* st, attr_t label) nogil: st.push() st.pop() @staticmethod cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil: gold = _gold move = IN cdef int next_act = gold.ner[s.B(1)].move if s.B(1) >= 0 else OUT cdef int g_act = gold.ner[s.B(0)].move cdef attr_t g_tag = gold.ner[s.B(0)].label cdef bint is_sunk = _entity_is_sunk(s, gold.ner) if g_act == MISSING: return 0 elif g_act == BEGIN: # I, Gold B --> True # (P of bad open entity sunk, R of this entity sunk) return 0 elif g_act == IN: # I, Gold I --> True # (label forced by prev, if mismatch, P and R both sunk) return 0 elif g_act == LAST: # I, Gold L --> True iff this entity sunk and next tag == O return not (is_sunk and (next_act == OUT or next_act == MISSING)) elif g_act == OUT: # I, Gold O --> True iff next tag == O return not (next_act == OUT or next_act == MISSING) elif g_act == UNIT: # I, Gold U --> True iff next tag == O return next_act != OUT else: return 1 cdef class Last: @staticmethod cdef bint is_valid(const StateC* st, attr_t label) nogil: cdef int preset_ent_iob = st.B_(0).ent_iob cdef attr_t preset_ent_label = st.B_(0).ent_type if label == 0: return False elif not st.entity_is_open(): return False elif preset_ent_iob == 1 and st.B_(1).ent_iob != 1: # If a preset entity has I followed by not-I, is L if label != preset_ent_label: # If label isn't right, reject return False else: # Otherwise, force acceptance, even if we're across a sentence # boundary or the token is whitespace. return True elif st.get_ent().label != label: return False elif st.B_(1).ent_iob == 1: # If a preset entity has I next, we can't L here. return False else: return True @staticmethod cdef int transition(StateC* st, attr_t label) nogil: st.close_ent() st.push() st.pop() @staticmethod cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil: gold = _gold move = LAST b0 = s.B(0) ent_start = s.E(0) cdef int g_act = gold.ner[b0].move cdef attr_t g_tag = gold.ner[b0].label cdef int cost = 0 if g_act == MISSING: pass elif g_act == BEGIN: # L, Gold B --> True pass elif g_act == IN: # L, Gold I --> True iff this entity sunk cost += not _entity_is_sunk(s, gold.ner) elif g_act == LAST: # L, Gold L --> True pass elif g_act == OUT: # L, Gold O --> True pass elif g_act == UNIT: # L, Gold U --> True pass else: cost += 1 # If we have negative-example entities, integrate them into the objective, # by marking actions that close an entity that we know is incorrect # as costly. cdef shared_ptr[SpanC] span for span in gold.negs: if span.get().label == label and (span.get().end-1) == b0 and span.get().start == ent_start: cost += 1 break return cost cdef class Unit: @staticmethod cdef bint is_valid(const StateC* st, attr_t label) nogil: cdef int preset_ent_iob = st.B_(0).ent_iob cdef attr_t preset_ent_label = st.B_(0).ent_type if label == 0: # this is only allowed if it's a preset blocked annotation if preset_ent_label == 0 and preset_ent_iob == 3: return True else: return False elif st.entity_is_open(): return False elif st.B_(1).ent_iob == 1: # If next token is In, we can't be Unit -- must be Begin return False elif preset_ent_iob == 3: # Okay, there's a preset entity here if label != preset_ent_label: # Require labels to match return False else: # Otherwise return True, ignoring the whitespace constraint. return True elif Lexeme.get_struct_attr(st.B_(0).lex, IS_SPACE): return False else: return True @staticmethod cdef int transition(StateC* st, attr_t label) nogil: st.open_ent(label) st.close_ent() st.push() st.pop() @staticmethod cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil: gold = _gold cdef int g_act = gold.ner[s.B(0)].move cdef attr_t g_tag = gold.ner[s.B(0)].label cdef int cost = 0 if g_act == MISSING: pass elif g_act == UNIT: # U, Gold U --> True iff tag match cost += label != g_tag else: # U, Gold B --> False # U, Gold I --> False # U, Gold L --> False # U, Gold O --> False cost += 1 # If we have negative-example entities, integrate them into the objective. # This is fairly straight-forward for U- entities, as we have a single # action cdef int b0 = s.B(0) cdef shared_ptr[SpanC] span for span in gold.negs: if span.get().label == label and span.get().start == b0 and span.get().end == (b0+1): cost += 1 break return cost cdef class Out: @staticmethod cdef bint is_valid(const StateC* st, attr_t label) nogil: cdef int preset_ent_iob = st.B_(0).ent_iob if st.entity_is_open(): return False elif preset_ent_iob == 3: return False elif preset_ent_iob == 1: return False else: return True @staticmethod cdef int transition(StateC* st, attr_t label) nogil: st.push() st.pop() @staticmethod cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil: gold = _gold cdef int g_act = gold.ner[s.B(0)].move cdef attr_t g_tag = gold.ner[s.B(0)].label cdef weight_t cost = 0 if g_act == MISSING: pass elif g_act == BEGIN: # O, Gold B --> False cost += 1 elif g_act == IN: # O, Gold I --> True pass elif g_act == LAST: # O, Gold L --> True pass elif g_act == OUT: # O, Gold O --> True pass elif g_act == UNIT: # O, Gold U --> False cost += 1 else: cost += 1 return cost