mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-14 21:57:15 +03:00
75f7c15187
* Span/SpanGroup: wrap SpanC in shared_ptr When a Span that was retrieved from a SpanGroup was modified, these changes were not reflected in the SpanGroup because the underlying SpanC struct was copied. This change applies the solution proposed by @nrodnova, to wrap SpanC in a shared_ptr. This makes a SpanGroup and Spans derived from it share the same SpanC. So, changes made through a Span are visible in the SpanGroup as well. Fixes #9556 * Test that a SpanGroup is modified through its Spans * SpanGroup.push_back: remove nogil Modifying std::vector is not thread-safe. * C++ >= 11 does not allow const T in vector<T> * Add Span.span_c as a shorthand for Span.c.get Since this method is cdef'ed, it is only visible from Cython, so we avoid using raw pointers in Python Replace existing uses of span.c.get() to use this new method. * Fix formatting * Style fix: pointer types * SpanGroup.to_bytes: reduce number of shared_ptr::get calls * Mark SpanGroup modification test with issue Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
695 lines
23 KiB
Cython
695 lines
23 KiB
Cython
import os
|
|
import random
|
|
from libc.stdint cimport int32_t
|
|
from libcpp.memory cimport shared_ptr
|
|
from libcpp.vector cimport vector
|
|
from cymem.cymem cimport Pool
|
|
|
|
from collections import Counter
|
|
from thinc.extra.search cimport Beam
|
|
|
|
from ...tokens.doc cimport Doc
|
|
from ...tokens.span import Span
|
|
from ...tokens.span cimport Span
|
|
from ...typedefs cimport weight_t, attr_t
|
|
from ...lexeme cimport Lexeme
|
|
from ...attrs cimport IS_SPACE
|
|
from ...structs cimport TokenC, SpanC
|
|
from ...training.example cimport Example
|
|
from .stateclass cimport StateClass
|
|
from ._state cimport StateC
|
|
from .transition_system cimport Transition, do_func_t
|
|
|
|
from ...errors import Errors
|
|
|
|
|
|
cdef enum:
|
|
MISSING
|
|
BEGIN
|
|
IN
|
|
LAST
|
|
UNIT
|
|
OUT
|
|
N_MOVES
|
|
|
|
|
|
MOVE_NAMES = [None] * N_MOVES
|
|
MOVE_NAMES[MISSING] = 'M'
|
|
MOVE_NAMES[BEGIN] = 'B'
|
|
MOVE_NAMES[IN] = 'I'
|
|
MOVE_NAMES[LAST] = 'L'
|
|
MOVE_NAMES[UNIT] = 'U'
|
|
MOVE_NAMES[OUT] = 'O'
|
|
|
|
|
|
cdef struct GoldNERStateC:
|
|
Transition* ner
|
|
vector[shared_ptr[SpanC]] negs
|
|
|
|
|
|
cdef class BiluoGold:
|
|
cdef Pool mem
|
|
cdef GoldNERStateC c
|
|
|
|
def __init__(self, BiluoPushDown moves, StateClass stcls, Example example, neg_key):
|
|
self.mem = Pool()
|
|
self.c = create_gold_state(self.mem, moves, stcls.c, example, neg_key)
|
|
|
|
def update(self, StateClass stcls):
|
|
update_gold_state(&self.c, stcls.c)
|
|
|
|
|
|
cdef GoldNERStateC create_gold_state(
|
|
Pool mem,
|
|
BiluoPushDown moves,
|
|
const StateC* stcls,
|
|
Example example,
|
|
neg_key
|
|
) except *:
|
|
cdef GoldNERStateC gs
|
|
cdef Span neg
|
|
if neg_key is not None:
|
|
negs = example.get_aligned_spans_y2x(
|
|
example.y.spans.get(neg_key, []),
|
|
allow_overlap=True
|
|
)
|
|
else:
|
|
negs = []
|
|
assert example.x.length > 0
|
|
gs.ner = <Transition*>mem.alloc(example.x.length, sizeof(Transition))
|
|
ner_ents, ner_tags = example.get_aligned_ents_and_ner()
|
|
for i, ner_tag in enumerate(ner_tags):
|
|
gs.ner[i] = moves.lookup_transition(ner_tag)
|
|
|
|
# Prevent conflicting spans in the data. For NER, spans are equal if they have the same offsets and label.
|
|
neg_span_triples = {(neg_ent.start_char, neg_ent.end_char, neg_ent.label) for neg_ent in negs}
|
|
for pos_span in ner_ents:
|
|
if (pos_span.start_char, pos_span.end_char, pos_span.label) in neg_span_triples:
|
|
raise ValueError(Errors.E868.format(span=(pos_span.start_char, pos_span.end_char, pos_span.label_)))
|
|
|
|
# In order to handle negative samples, we need to maintain the full
|
|
# (start, end, label) triple. If we break it down to the 'isnt B-LOC'
|
|
# thing, we'll get blocked if there's an incorrect prefix.
|
|
for neg in negs:
|
|
gs.negs.push_back(neg.c)
|
|
return gs
|
|
|
|
|
|
cdef void update_gold_state(GoldNERStateC* gs, const StateC* state) except *:
|
|
# We don't need to update each time, unlike the parser.
|
|
pass
|
|
|
|
|
|
cdef do_func_t[N_MOVES] do_funcs
|
|
|
|
|
|
cdef bint _entity_is_sunk(const StateC* state, Transition* golds) nogil:
|
|
if not state.entity_is_open():
|
|
return False
|
|
|
|
cdef const Transition* gold = &golds[state.E(0)]
|
|
ent = state.get_ent()
|
|
if gold.move != BEGIN and gold.move != UNIT:
|
|
return True
|
|
elif gold.label != ent.label:
|
|
return True
|
|
else:
|
|
return False
|
|
|
|
|
|
cdef class BiluoPushDown(TransitionSystem):
|
|
def __init__(self, *args, **kwargs):
|
|
TransitionSystem.__init__(self, *args, **kwargs)
|
|
|
|
@classmethod
|
|
def get_actions(cls, **kwargs):
|
|
actions = {
|
|
MISSING: Counter(),
|
|
BEGIN: Counter(),
|
|
IN: Counter(),
|
|
LAST: Counter(),
|
|
UNIT: Counter(),
|
|
OUT: Counter()
|
|
}
|
|
actions[OUT][''] = 1 # Represents a token predicted to be outside of any entity
|
|
actions[UNIT][''] = 1 # Represents a token prohibited to be in an entity
|
|
for entity_type in kwargs.get('entity_types', []):
|
|
for action in (BEGIN, IN, LAST, UNIT):
|
|
actions[action][entity_type] = 1
|
|
moves = ('M', 'B', 'I', 'L', 'U')
|
|
for example in kwargs.get('examples', []):
|
|
for token in example.y:
|
|
ent_type = token.ent_type_
|
|
if ent_type:
|
|
for action in (BEGIN, IN, LAST, UNIT):
|
|
actions[action][ent_type] += 1
|
|
return actions
|
|
|
|
@property
|
|
def action_types(self):
|
|
return (BEGIN, IN, LAST, UNIT, OUT)
|
|
|
|
def get_doc_labels(self, doc):
|
|
labels = set()
|
|
for token in doc:
|
|
if token.ent_type:
|
|
labels.add(token.ent_type_)
|
|
return labels
|
|
|
|
def move_name(self, int move, attr_t label):
|
|
if move == OUT:
|
|
return 'O'
|
|
elif move == MISSING:
|
|
return 'M'
|
|
else:
|
|
return MOVE_NAMES[move] + '-' + self.strings[label]
|
|
|
|
def init_gold_batch(self, examples):
|
|
all_states = self.init_batch([eg.predicted for eg in examples])
|
|
golds = []
|
|
states = []
|
|
for state, eg in zip(all_states, examples):
|
|
if self.has_gold(eg) and not state.is_final():
|
|
golds.append(self.init_gold(state, eg))
|
|
states.append(state)
|
|
n_steps = sum([len(s.queue) for s in states])
|
|
return states, golds, n_steps
|
|
|
|
cdef Transition lookup_transition(self, object name) except *:
|
|
cdef attr_t label
|
|
if name == '-' or name == '' or name is None:
|
|
return Transition(clas=0, move=MISSING, label=0, score=0)
|
|
elif '-' in name:
|
|
move_str, label_str = name.split('-', 1)
|
|
# Deprecated, hacky way to denote 'not this entity'
|
|
if label_str.startswith('!'):
|
|
raise ValueError(Errors.E869.format(label=name))
|
|
label = self.strings.add(label_str)
|
|
else:
|
|
move_str = name
|
|
label = 0
|
|
move = MOVE_NAMES.index(move_str)
|
|
for i in range(self.n_moves):
|
|
if self.c[i].move == move and self.c[i].label == label:
|
|
return self.c[i]
|
|
raise KeyError(Errors.E022.format(name=name))
|
|
|
|
cdef Transition init_transition(self, int clas, int move, attr_t label) except *:
|
|
# TODO: Apparent Cython bug here when we try to use the Transition()
|
|
# constructor with the function pointers
|
|
cdef Transition t
|
|
t.score = 0
|
|
t.clas = clas
|
|
t.move = move
|
|
t.label = label
|
|
if move == MISSING:
|
|
t.is_valid = Missing.is_valid
|
|
t.do = Missing.transition
|
|
t.get_cost = Missing.cost
|
|
elif move == BEGIN:
|
|
t.is_valid = Begin.is_valid
|
|
t.do = Begin.transition
|
|
t.get_cost = Begin.cost
|
|
elif move == IN:
|
|
t.is_valid = In.is_valid
|
|
t.do = In.transition
|
|
t.get_cost = In.cost
|
|
elif move == LAST:
|
|
t.is_valid = Last.is_valid
|
|
t.do = Last.transition
|
|
t.get_cost = Last.cost
|
|
elif move == UNIT:
|
|
t.is_valid = Unit.is_valid
|
|
t.do = Unit.transition
|
|
t.get_cost = Unit.cost
|
|
elif move == OUT:
|
|
t.is_valid = Out.is_valid
|
|
t.do = Out.transition
|
|
t.get_cost = Out.cost
|
|
else:
|
|
raise ValueError(Errors.E019.format(action=move, src='ner'))
|
|
return t
|
|
|
|
def add_action(self, int action, label_name, freq=None):
|
|
cdef attr_t label_id
|
|
if not isinstance(label_name, (int, long)):
|
|
label_id = self.strings.add(label_name)
|
|
else:
|
|
label_id = label_name
|
|
if action == OUT and label_id != 0:
|
|
return None
|
|
if action == MISSING:
|
|
return None
|
|
# Check we're not creating a move we already have, so that this is
|
|
# idempotent
|
|
for trans in self.c[:self.n_moves]:
|
|
if trans.move == action and trans.label == label_id:
|
|
return 0
|
|
if self.n_moves >= self._size:
|
|
self._size = self.n_moves
|
|
self._size *= 2
|
|
self.c = <Transition*>self.mem.realloc(self.c, self._size * sizeof(self.c[0]))
|
|
self.c[self.n_moves] = self.init_transition(self.n_moves, action, label_id)
|
|
self.n_moves += 1
|
|
if self.labels.get(action, []):
|
|
freq = min(0, min(self.labels[action].values()))
|
|
self.labels[action][label_name] = freq-1
|
|
else:
|
|
self.labels[action] = Counter()
|
|
self.labels[action][label_name] = -1
|
|
return 1
|
|
|
|
def set_annotations(self, StateClass state, Doc doc):
|
|
cdef int i
|
|
ents = []
|
|
for i in range(state.c._ents.size()):
|
|
ent = state.c._ents.at(i)
|
|
if ent.start != -1 and ent.end != -1:
|
|
ents.append(Span(doc, ent.start, ent.end, label=ent.label, kb_id=doc.c[ent.start].ent_kb_id))
|
|
doc.set_ents(ents, default="unmodified")
|
|
# Set non-blocked tokens to O
|
|
for i in range(doc.length):
|
|
if doc.c[i].ent_iob == 0:
|
|
doc.c[i].ent_iob = 2
|
|
|
|
def get_beam_parses(self, Beam beam):
|
|
parses = []
|
|
probs = beam.probs
|
|
for i in range(beam.size):
|
|
state = <StateC*>beam.at(i)
|
|
if state.is_final():
|
|
prob = probs[i]
|
|
parse = []
|
|
for j in range(state._ents.size()):
|
|
ent = state._ents.at(j)
|
|
if ent.start != -1 and ent.end != -1:
|
|
parse.append((ent.start, ent.end, self.strings[ent.label]))
|
|
parses.append((prob, parse))
|
|
return parses
|
|
|
|
def init_gold(self, StateClass state, Example example):
|
|
return BiluoGold(self, state, example, self.neg_key)
|
|
|
|
def has_gold(self, Example eg, start=0, end=None):
|
|
# We get x and y referring to X, we want to check relative to Y,
|
|
# the reference
|
|
y_spans = eg.get_aligned_spans_x2y([eg.x[start:end]])
|
|
if not y_spans:
|
|
y_spans = [eg.y[:]]
|
|
y_span = y_spans[0]
|
|
start = y_span.start
|
|
end = y_span.end
|
|
neg_key = self.neg_key
|
|
if neg_key is not None:
|
|
# If we have any negative samples, count that as having annotation.
|
|
for span in eg.y.spans.get(neg_key, []):
|
|
if span.start >= start and span.end <= end:
|
|
return True
|
|
for word in eg.y[start:end]:
|
|
if word.ent_iob != 0:
|
|
return True
|
|
else:
|
|
return False
|
|
|
|
def get_cost(self, StateClass stcls, gold, int i):
|
|
if not isinstance(gold, BiluoGold):
|
|
raise TypeError(Errors.E909.format(name="BiluoGold"))
|
|
cdef BiluoGold gold_ = gold
|
|
gold_state = gold_.c
|
|
n_gold = 0
|
|
if self.c[i].is_valid(stcls.c, self.c[i].label):
|
|
cost = self.c[i].get_cost(stcls.c, &gold_state, self.c[i].label)
|
|
else:
|
|
cost = 9000
|
|
return cost
|
|
|
|
cdef int set_costs(self, int* is_valid, weight_t* costs,
|
|
const StateC* state, gold) except -1:
|
|
if not isinstance(gold, BiluoGold):
|
|
raise TypeError(Errors.E909.format(name="BiluoGold"))
|
|
cdef BiluoGold gold_ = gold
|
|
gold_state = gold_.c
|
|
update_gold_state(&gold_state, state)
|
|
n_gold = 0
|
|
self.set_valid(is_valid, state)
|
|
for i in range(self.n_moves):
|
|
if is_valid[i]:
|
|
costs[i] = self.c[i].get_cost(state, &gold_state, self.c[i].label)
|
|
n_gold += costs[i] <= 0
|
|
else:
|
|
costs[i] = 9000
|
|
|
|
|
|
cdef class Missing:
|
|
@staticmethod
|
|
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
|
return False
|
|
|
|
@staticmethod
|
|
cdef int transition(StateC* s, attr_t label) nogil:
|
|
pass
|
|
|
|
@staticmethod
|
|
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
|
|
return 9000
|
|
|
|
|
|
cdef class Begin:
|
|
@staticmethod
|
|
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
|
cdef int preset_ent_iob = st.B_(0).ent_iob
|
|
cdef attr_t preset_ent_label = st.B_(0).ent_type
|
|
if st.entity_is_open():
|
|
return False
|
|
if st.buffer_length() < 2:
|
|
# If we're the last token of the input, we can't B -- must U or O.
|
|
return False
|
|
elif label == 0:
|
|
return False
|
|
elif preset_ent_iob == 1:
|
|
# Ensure we don't clobber preset entities. If no entity preset,
|
|
# ent_iob is 0
|
|
return False
|
|
elif preset_ent_iob == 3:
|
|
# Okay, we're in a preset entity.
|
|
if label != preset_ent_label:
|
|
# If label isn't right, reject
|
|
return False
|
|
elif st.B_(1).ent_iob != 1:
|
|
# If next token isn't marked I, we need to make U, not B.
|
|
return False
|
|
else:
|
|
# Otherwise, force acceptance, even if we're across a sentence
|
|
# boundary or the token is whitespace.
|
|
return True
|
|
elif st.B_(1).ent_iob == 3:
|
|
# If the next word is B, we can't B now
|
|
return False
|
|
elif st.B_(1).sent_start == 1:
|
|
# Don't allow entities to extend across sentence boundaries
|
|
return False
|
|
# Don't allow entities to start on whitespace
|
|
elif Lexeme.get_struct_attr(st.B_(0).lex, IS_SPACE):
|
|
return False
|
|
else:
|
|
return True
|
|
|
|
@staticmethod
|
|
cdef int transition(StateC* st, attr_t label) nogil:
|
|
st.open_ent(label)
|
|
st.push()
|
|
st.pop()
|
|
|
|
@staticmethod
|
|
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
|
|
gold = <GoldNERStateC*>_gold
|
|
b0 = s.B(0)
|
|
cdef int cost = 0
|
|
cdef int g_act = gold.ner[b0].move
|
|
cdef attr_t g_tag = gold.ner[b0].label
|
|
|
|
cdef shared_ptr[SpanC] span
|
|
|
|
if g_act == MISSING:
|
|
pass
|
|
elif g_act == BEGIN:
|
|
# B, Gold B --> Label match
|
|
cost += label != g_tag
|
|
else:
|
|
# B, Gold I --> False (P)
|
|
# B, Gold L --> False (P)
|
|
# B, Gold O --> False (P)
|
|
# B, Gold U --> False (P)
|
|
cost += 1
|
|
if s.buffer_length() < 3:
|
|
# Handle negatives. In general we can't really do much to block
|
|
# B, because we don't know whether the whole entity is going to
|
|
# be correct or not. However, we can at least tell whether we're
|
|
# going to be opening an entity where there's only one possible
|
|
# L.
|
|
for span in gold.negs:
|
|
if span.get().label == label and span.get().start == b0:
|
|
cost += 1
|
|
break
|
|
return cost
|
|
|
|
|
|
cdef class In:
|
|
@staticmethod
|
|
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
|
if not st.entity_is_open():
|
|
return False
|
|
if st.buffer_length() < 2:
|
|
# If we're at the end, we can't I.
|
|
return False
|
|
ent = st.get_ent()
|
|
cdef int preset_ent_iob = st.B_(0).ent_iob
|
|
cdef attr_t preset_ent_label = st.B_(0).ent_type
|
|
if label == 0:
|
|
return False
|
|
elif ent.label != label:
|
|
return False
|
|
elif preset_ent_iob == 3:
|
|
return False
|
|
elif st.B_(1).ent_iob == 3:
|
|
# If we know the next word is B, we can't be I (must be L)
|
|
return False
|
|
elif preset_ent_iob == 1:
|
|
if st.B_(1).ent_iob in (0, 2):
|
|
# if next preset is missing or O, this can't be I (must be L)
|
|
return False
|
|
elif label != preset_ent_label:
|
|
# If label isn't right, reject
|
|
return False
|
|
else:
|
|
# Otherwise, force acceptance, even if we're across a sentence
|
|
# boundary or the token is whitespace.
|
|
return True
|
|
elif st.B(1) != -1 and st.B_(1).sent_start == 1:
|
|
# Don't allow entities to extend across sentence boundaries
|
|
return False
|
|
else:
|
|
return True
|
|
|
|
@staticmethod
|
|
cdef int transition(StateC* st, attr_t label) nogil:
|
|
st.push()
|
|
st.pop()
|
|
|
|
@staticmethod
|
|
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
|
|
gold = <GoldNERStateC*>_gold
|
|
move = IN
|
|
cdef int next_act = gold.ner[s.B(1)].move if s.B(1) >= 0 else OUT
|
|
cdef int g_act = gold.ner[s.B(0)].move
|
|
cdef attr_t g_tag = gold.ner[s.B(0)].label
|
|
cdef bint is_sunk = _entity_is_sunk(s, gold.ner)
|
|
|
|
if g_act == MISSING:
|
|
return 0
|
|
elif g_act == BEGIN:
|
|
# I, Gold B --> True
|
|
# (P of bad open entity sunk, R of this entity sunk)
|
|
return 0
|
|
elif g_act == IN:
|
|
# I, Gold I --> True
|
|
# (label forced by prev, if mismatch, P and R both sunk)
|
|
return 0
|
|
elif g_act == LAST:
|
|
# I, Gold L --> True iff this entity sunk and next tag == O
|
|
return not (is_sunk and (next_act == OUT or next_act == MISSING))
|
|
elif g_act == OUT:
|
|
# I, Gold O --> True iff next tag == O
|
|
return not (next_act == OUT or next_act == MISSING)
|
|
elif g_act == UNIT:
|
|
# I, Gold U --> True iff next tag == O
|
|
return next_act != OUT
|
|
else:
|
|
return 1
|
|
|
|
cdef class Last:
|
|
@staticmethod
|
|
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
|
cdef int preset_ent_iob = st.B_(0).ent_iob
|
|
cdef attr_t preset_ent_label = st.B_(0).ent_type
|
|
if label == 0:
|
|
return False
|
|
elif not st.entity_is_open():
|
|
return False
|
|
elif preset_ent_iob == 1 and st.B_(1).ent_iob != 1:
|
|
# If a preset entity has I followed by not-I, is L
|
|
if label != preset_ent_label:
|
|
# If label isn't right, reject
|
|
return False
|
|
else:
|
|
# Otherwise, force acceptance, even if we're across a sentence
|
|
# boundary or the token is whitespace.
|
|
return True
|
|
elif st.get_ent().label != label:
|
|
return False
|
|
elif st.B_(1).ent_iob == 1:
|
|
# If a preset entity has I next, we can't L here.
|
|
return False
|
|
else:
|
|
return True
|
|
|
|
@staticmethod
|
|
cdef int transition(StateC* st, attr_t label) nogil:
|
|
st.close_ent()
|
|
st.push()
|
|
st.pop()
|
|
|
|
@staticmethod
|
|
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
|
|
gold = <GoldNERStateC*>_gold
|
|
move = LAST
|
|
b0 = s.B(0)
|
|
ent_start = s.E(0)
|
|
|
|
cdef int g_act = gold.ner[b0].move
|
|
cdef attr_t g_tag = gold.ner[b0].label
|
|
|
|
cdef int cost = 0
|
|
|
|
if g_act == MISSING:
|
|
pass
|
|
elif g_act == BEGIN:
|
|
# L, Gold B --> True
|
|
pass
|
|
elif g_act == IN:
|
|
# L, Gold I --> True iff this entity sunk
|
|
cost += not _entity_is_sunk(s, gold.ner)
|
|
elif g_act == LAST:
|
|
# L, Gold L --> True
|
|
pass
|
|
elif g_act == OUT:
|
|
# L, Gold O --> True
|
|
pass
|
|
elif g_act == UNIT:
|
|
# L, Gold U --> True
|
|
pass
|
|
else:
|
|
cost += 1
|
|
# If we have negative-example entities, integrate them into the objective,
|
|
# by marking actions that close an entity that we know is incorrect
|
|
# as costly.
|
|
cdef shared_ptr[SpanC] span
|
|
for span in gold.negs:
|
|
if span.get().label == label and (span.get().end-1) == b0 and span.get().start == ent_start:
|
|
cost += 1
|
|
break
|
|
return cost
|
|
|
|
|
|
cdef class Unit:
|
|
@staticmethod
|
|
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
|
cdef int preset_ent_iob = st.B_(0).ent_iob
|
|
cdef attr_t preset_ent_label = st.B_(0).ent_type
|
|
if label == 0:
|
|
# this is only allowed if it's a preset blocked annotation
|
|
if preset_ent_label == 0 and preset_ent_iob == 3:
|
|
return True
|
|
else:
|
|
return False
|
|
elif st.entity_is_open():
|
|
return False
|
|
elif st.B_(1).ent_iob == 1:
|
|
# If next token is In, we can't be Unit -- must be Begin
|
|
return False
|
|
elif preset_ent_iob == 3:
|
|
# Okay, there's a preset entity here
|
|
if label != preset_ent_label:
|
|
# Require labels to match
|
|
return False
|
|
else:
|
|
# Otherwise return True, ignoring the whitespace constraint.
|
|
return True
|
|
elif Lexeme.get_struct_attr(st.B_(0).lex, IS_SPACE):
|
|
return False
|
|
else:
|
|
return True
|
|
|
|
@staticmethod
|
|
cdef int transition(StateC* st, attr_t label) nogil:
|
|
st.open_ent(label)
|
|
st.close_ent()
|
|
st.push()
|
|
st.pop()
|
|
|
|
@staticmethod
|
|
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
|
|
gold = <GoldNERStateC*>_gold
|
|
cdef int g_act = gold.ner[s.B(0)].move
|
|
cdef attr_t g_tag = gold.ner[s.B(0)].label
|
|
cdef int cost = 0
|
|
|
|
if g_act == MISSING:
|
|
pass
|
|
elif g_act == UNIT:
|
|
# U, Gold U --> True iff tag match
|
|
cost += label != g_tag
|
|
else:
|
|
# U, Gold B --> False
|
|
# U, Gold I --> False
|
|
# U, Gold L --> False
|
|
# U, Gold O --> False
|
|
cost += 1
|
|
# If we have negative-example entities, integrate them into the objective.
|
|
# This is fairly straight-forward for U- entities, as we have a single
|
|
# action
|
|
cdef int b0 = s.B(0)
|
|
cdef shared_ptr[SpanC] span
|
|
for span in gold.negs:
|
|
if span.get().label == label and span.get().start == b0 and span.get().end == (b0+1):
|
|
cost += 1
|
|
break
|
|
return cost
|
|
|
|
|
|
|
|
cdef class Out:
|
|
@staticmethod
|
|
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
|
cdef int preset_ent_iob = st.B_(0).ent_iob
|
|
if st.entity_is_open():
|
|
return False
|
|
elif preset_ent_iob == 3:
|
|
return False
|
|
elif preset_ent_iob == 1:
|
|
return False
|
|
else:
|
|
return True
|
|
|
|
@staticmethod
|
|
cdef int transition(StateC* st, attr_t label) nogil:
|
|
st.push()
|
|
st.pop()
|
|
|
|
@staticmethod
|
|
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
|
|
gold = <GoldNERStateC*>_gold
|
|
cdef int g_act = gold.ner[s.B(0)].move
|
|
cdef attr_t g_tag = gold.ner[s.B(0)].label
|
|
cdef weight_t cost = 0
|
|
if g_act == MISSING:
|
|
pass
|
|
elif g_act == BEGIN:
|
|
# O, Gold B --> False
|
|
cost += 1
|
|
elif g_act == IN:
|
|
# O, Gold I --> True
|
|
pass
|
|
elif g_act == LAST:
|
|
# O, Gold L --> True
|
|
pass
|
|
elif g_act == OUT:
|
|
# O, Gold O --> True
|
|
pass
|
|
elif g_act == UNIT:
|
|
# O, Gold U --> False
|
|
cost += 1
|
|
else:
|
|
cost += 1
|
|
return cost
|