mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-13 13:17:06 +03:00
a183db3cef
* Try to fix doc.copy * Set dev version * Make vocab always own lexemes * Change version * Add SpanGroups.copy method * Fix set_annotations during Parser.update * Fix dict proxy copy * Upd version * Fix copying SpanGroups * Fix set_annotations in parser.update * Fix parser set_annotations during update * Revert "Fix parser set_annotations during update" This reverts commit eb138c89ed.
* Revert "Fix set_annotations in parser.update" This reverts commit c6df0eafd0.
* Fix set_annotations during parser update * Inc version * Handle final states in get_oracle_sequence * Inc version * Try to fix parser training * Inc version * Fix * Inc version * Fix parser oracle * Inc version * Inc version * Fix transition has_gold * Inc version * Try to use real histories, not oracle * Inc version * Upd parser * Inc version * WIP on rewrite parser * WIP refactor parser * New progress on parser model refactor * Prepare to remove parser_model.pyx * Convert parser from cdef class * Delete spacy.ml.parser_model * Delete _precomputable_affine module * Wire up tb_framework to new parser model * Wire up parser model * Uncython ner.pyx and dep_parser.pyx * Uncython * Work on parser model * Support unseen_classes in parser model * Support unseen classes in parser * Cleaner handling of unseen classes * Work through tests * Keep working through errors * Keep working through errors * Work on parser. 15 tests failing * Xfail beam stuff. 9 failures * More xfail. 7 failures * Xfail. 6 failures * cleanup * formatting * fixes * pass nO through * Fix empty doc in update * Hackishly fix resizing. 3 failures * Fix redundant test. 2 failures * Add reference version * black formatting * Get tests passing with reference implementation * Fix missing prints * Add missing file * Improve indexing on reference implementation * Get non-reference forward func working * Start rigging beam back up * removing redundant tests, cf #8106 * black formatting * temporarily xfailing issue 4314 * make flake8 happy again * mypy fixes * ensure labels are added upon predict * cleanup remnants from merge conflicts * Improve unseen label masking Two changes to speed up masking by ~10%: - Use a bool array rather than an array of float32. - Let the mask indicate whether a label was seen, rather than unseen. The mask is most frequently used to index scores for seen labels. However, since the mask marked unseen labels, this required computing an intermediate flipped mask.
* Write moves costs directly into numpy array (#10163) This avoids elementwise indexing and the allocation of an additional array. Gives a ~15% speed improvement when using batch_by_sequence with size 32. * Temporarily disable ner and rehearse tests Until rehearse is implemented again in the refactored parser. * Fix loss serialization issue (#10600) * Fix loss serialization issue Serialization of a model fails with: TypeError: array(738.3855, dtype=float32) is not JSON serializable Fix this using float conversion. * Disable CI steps that require spacy.TransitionBasedParser.v2 After finishing the refactor, TransitionBasedParser.v2 should be provided for backwards compat. * Add back support for beam parsing to the refactored parser (#10633) * Add back support for beam parsing Beam parsing was already implemented as part of the `BeamBatch` class. This change makes its counterpart `GreedyBatch`. Both classes are hooked up in `TransitionModel`, selecting `GreedyBatch` when the beam size is one, or `BeamBatch` otherwise. 
* Use kwarg for beam width Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Avoid implicit default for beam_width and beam_density * Parser.{beam,greedy}_parse: ensure labels are added * Remove 'deprecated' comments Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Parser `StateC` optimizations (#10746) * `StateC`: Optimizations Avoid GIL acquisition in `__init__` Increase default buffer capacities on init Reduce C++ exception overhead * Fix typo * Replace `set::count` with `set::find` * Add exception attribute to c'tor * Remove unused import * Use a power-of-two value for initial capacity Use default-insert to init `_heads` and `_unshiftable` * Merge `cdef` variable declarations and assignments * Vectorize `example.get_aligned_parses` (#10789) * `example`: Vectorize `get_aligned_parse` Rename `numpy` import * Convert aligned array to lists before returning * Revert import renaming * Elide slice arguments when selecting the entire range * Tagger/morphologizer alignment performance optimizations (#10798) * `example`: Unwrap `numpy` scalar arrays before passing them to `StringStore.__getitem__` * `AlignmentArray`: Use native list as staging buffer for offset calculation * `example`: Vectorize `get_aligned` * Hoist inner functions out of `get_aligned` * Replace inline `if..else` clause in assignment statement * `AlignmentArray`: Use raw indexing into offset and data `numpy` arrays * `example`: Replace array unique value check with `groupby` * `example`: Correctly exclude tokens with no alignment in `_get_aligned_vectorized` Simplify `_get_aligned_non_vectorized` * `util`: Update `all_equal` docstring * Explicitly use `int32_t*` * Restore C CPU inference in the refactored parser (#10747) * Bring back the C parsing model The C parsing model is used for CPU inference and is still faster for CPU inference than the forward pass of the Thinc model. 
* Use C sgemm provided by the Ops implementation * Make tb_framework module Cython, merge in C forward implementation * TransitionModel: raise in backprop returned from forward_cpu * Re-enable greedy parse test * Return transition scores when forward_cpu is used * Apply suggestions from code review Import `Model` from `thinc.api` Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Use relative imports in tb_framework * Don't assume a default for beam_width * We don't have a direct dependency on BLIS anymore * Rename forwards to _forward_{fallback,greedy_cpu} * Require thinc >=8.1.0,<8.2.0 * tb_framework: clean up imports * Fix return type of _get_seen_mask * Move up _forward_greedy_cpu * Style fixes. * Lower thinc lowerbound to 8.1.0.dev0 * Formatting fix Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Reimplement parser rehearsal function (#10878) * Reimplement parser rehearsal function Before the parser refactor, rehearsal was driven by a loop in the `rehearse` method itself. For each parsing step, the loops would: 1. Get the predictions of the teacher. 2. Get the predictions and backprop function of the student. 3. Compute the loss and backprop into the student. 4. Move the teacher and student forward with the predictions of the student. In the refactored parser, we cannot perform search stepwise rehearsal anymore, since the model now predicts all parsing steps at once. Therefore, rehearsal is performed in the following steps: 1. Get the predictions of all parsing steps from the student, along with its backprop function. 2. Get the predictions from the teacher, but use the predictions of the student to advance the parser while doing so. 3. Compute the loss and backprop into the student. 
To support the second step a new method, `advance_with_actions` is added to `GreedyBatch`, which performs the provided parsing steps. * tb_framework: wrap upper_W and upper_b in Linear Thinc's Optimizer cannot handle resizing of existing parameters. Until it does, we work around this by wrapping the weights/biases of the upper layer of the parser model in Linear. When the upper layer is resized, we copy over the existing parameters into a new Linear instance. This does not trigger an error in Optimizer, because it sees the resized layer as a new set of parameters. * Add test for TransitionSystem.apply_actions * Better FIXME marker Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com> * Fixes from Madeesh * Apply suggestions from Sofie Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Remove useless assignment Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com> Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Rename some identifiers in the parser refactor (#10935) * Rename _parseC to _parse_batch * tb_framework: prefix many auxiliary functions with underscore To clearly state the intent that they are private. * Rename `lower` to `hidden`, `upper` to `output` * Parser slow test fixup We don't have TransitionBasedParser.{v1,v2} until we bring it back as a legacy option. * Remove last vestiges of PrecomputableAffine This does not exist anymore as a separate layer. * ner: re-enable sentence boundary checks * Re-enable test that works now. * test_ner: make loss test more strict again * Remove commented line * Re-enable some more beam parser tests * Remove unused _forward_reference function * Update for CBlas changes in Thinc 8.1.0.dev2 Bump thinc dependency to 8.1.0.dev3. * Remove references to spacy.TransitionBasedParser.{v1,v2} Since they will not be offered starting with spaCy v4. 
* `tb_framework`: Replace references to `thinc.backends.linalg` with `CBlas` * dont use get_array_module (#11056) (#11293) Co-authored-by: kadarakos <kadar.akos@gmail.com> * Move `thinc.extra.search` to `spacy.pipeline._parser_internals` (#11317) * `search`: Move from `thinc.extra.search` Fix NPE in `Beam.__dealloc__` * `pytest`: Add support for executing Cython tests Move `search` tests from thinc and patch them to run with `pytest` * `mypy` fix * Update comment * `conftest`: Expose `register_cython_tests` * Remove unused import * Move `argmax` impls to new `_parser_utils` Cython module (#11410) * Parser does not have to be a cdef class anymore This also fixes validation of the initialization schema. * Add back spacy.TransitionBasedParser.v2 * Fix a rename that was missed in #10878. So that rehearsal tests pass. * Remove module from setup.py that got added during the merge * Bring back support for `update_with_oracle_cut_size` (#12086) * Bring back support for `update_with_oracle_cut_size` This option was available in the pre-refactor parser, but was never implemented in the refactored parser. This option cuts transition sequences that are longer than `update_with_oracle_cut` size into separate sequences that have at most `update_with_oracle_cut` transitions. The oracle (gold standard) transition sequence is used to determine the cuts and the initial states for the additional sequences. Applying this cut makes the batches more homogeneous in the transition sequence lengths, making forward passes (and as a consequence training) much faster. Training time 1000 steps on de_core_news_lg: - Before this change: 149s - After this change: 68s - Pre-refactor parser: 81s * Fix a rename that was missed in #10878. So that rehearsal tests pass. * Apply suggestions from @shadeMe * Use chained conditional * Test with update_with_oracle_cut_size={0, 1, 5, 100} And fix a git that occurs with a cut size of 1. 
* Fix up some merge fall out * Update parser distillation for the refactor In the old parser, we'd iterate over the transitions in the distill function and compute the loss/gradients on the go. In the refactored parser, we first let the student model parse the inputs. Then we'll let the teacher compute the transition probabilities of the states in the student's transition sequence. We can then compute the gradients of the student given the teacher. * Add back spacy.TransitionBasedParser.v1 references - Accordion in the architecture docs. - Test in test_parse, but disabled until we have a spacy-legacy release. Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com> Co-authored-by: svlandeg <svlandeg@github.com> Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com> Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> Co-authored-by: kadarakos <kadar.akos@gmail.com>
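The `update_with_oracle_cut_size` behaviour described in the commit message above can be sketched in plain Python. This is a simplified illustration under stated assumptions, not spaCy's implementation: `cut_oracle_sequence` is a hypothetical helper, the action ids are made up, and treating a non-positive cut size as "do not cut" is a convention chosen for this sketch. The real parser also has to derive the initial parser state for each additional sequence from the oracle, which is omitted here.

```python
from typing import List, Sequence


def cut_oracle_sequence(actions: Sequence[int], cut_size: int) -> List[List[int]]:
    """Split one oracle transition sequence into chunks of at most `cut_size`.

    Hypothetical helper for illustration only. A non-positive `cut_size`
    is treated as "do not cut" (an assumption of this sketch).
    """
    if cut_size <= 0 or len(actions) <= cut_size:
        return [list(actions)]
    # Consecutive slices keep the original order of the transitions.
    return [list(actions[i:i + cut_size]) for i in range(0, len(actions), cut_size)]


oracle = [1, 5, 5, 2, 5, 3, 4]  # made-up action ids
chunks = cut_oracle_sequence(oracle, 3)
print(chunks)
```

Because every chunk has at most `cut_size` transitions, batches built from such chunks have more uniform sequence lengths, which is the property the commit message credits for the faster forward passes.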
698 lines
23 KiB
Cython
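The listing that follows defines a transition system for named entity recognition whose moves are named B (begin), I (in), L (last), U (unit) and O (out), plus M for missing annotation. As orientation, the sketch below shows how entity spans correspond to such tags. It is a minimal pure-Python illustration: `spans_to_biluo` is a hypothetical helper written for this page, not a spaCy function, and it assumes non-empty, non-overlapping `(start, end, label)` token spans with an exclusive `end`.

```python
from typing import List, Sequence, Tuple


def spans_to_biluo(n_tokens: int, spans: Sequence[Tuple[int, int, str]]) -> List[str]:
    """Encode (start, end, label) token spans as BILUO tags.

    Hypothetical helper; assumes spans are non-empty and do not overlap.
    """
    tags = ["O"] * n_tokens
    for start, end, label in spans:
        if end - start == 1:
            # A single-token entity is one Unit move.
            tags[start] = f"U-{label}"
        else:
            # Longer entities open with Begin, continue with In, close with Last.
            tags[start] = f"B-{label}"
            for i in range(start + 1, end - 1):
                tags[i] = f"I-{label}"
            tags[end - 1] = f"L-{label}"
    return tags


tokens = ["Apple", "hired", "Tim", "Cook", "Jr.", "yesterday"]
tags = spans_to_biluo(len(tokens), [(0, 1, "ORG"), (2, 5, "PERSON")])
print(list(zip(tokens, tags)))
```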
import os
import random
from libc.stdint cimport int32_t
from libcpp.memory cimport shared_ptr
from libcpp.vector cimport vector
from cymem.cymem cimport Pool

from collections import Counter

from ...tokens.doc cimport Doc
from ...tokens.span import Span
from ...tokens.span cimport Span
from ...typedefs cimport weight_t, attr_t
from ...lexeme cimport Lexeme
from ...attrs cimport IS_SPACE
from ...structs cimport TokenC, SpanC
from ...training import split_bilu_label
from ...training.example cimport Example
from .search cimport Beam
from .stateclass cimport StateClass
from ._state cimport StateC
from .transition_system cimport Transition, do_func_t

from ...errors import Errors


cdef enum:
    MISSING
    BEGIN
    IN
    LAST
    UNIT
    OUT
    N_MOVES


MOVE_NAMES = [None] * N_MOVES
MOVE_NAMES[MISSING] = 'M'
MOVE_NAMES[BEGIN] = 'B'
MOVE_NAMES[IN] = 'I'
MOVE_NAMES[LAST] = 'L'
MOVE_NAMES[UNIT] = 'U'
MOVE_NAMES[OUT] = 'O'


cdef struct GoldNERStateC:
    Transition* ner
    vector[shared_ptr[SpanC]] negs


cdef class BiluoGold:
    cdef Pool mem
    cdef GoldNERStateC c

    def __init__(self, BiluoPushDown moves, StateClass stcls, Example example, neg_key):
        self.mem = Pool()
        self.c = create_gold_state(self.mem, moves, stcls.c, example, neg_key)

    def update(self, StateClass stcls):
        update_gold_state(&self.c, stcls.c)


cdef GoldNERStateC create_gold_state(
    Pool mem,
    BiluoPushDown moves,
    const StateC* stcls,
    Example example,
    neg_key
) except *:
    cdef GoldNERStateC gs
    cdef Span neg
    if neg_key is not None:
        negs = example.get_aligned_spans_y2x(
            example.y.spans.get(neg_key, []),
            allow_overlap=True
        )
    else:
        negs = []
    assert example.x.length > 0
    gs.ner = <Transition*>mem.alloc(example.x.length, sizeof(Transition))
    ner_ents, ner_tags = example.get_aligned_ents_and_ner()
    for i, ner_tag in enumerate(ner_tags):
        gs.ner[i] = moves.lookup_transition(ner_tag)

    # Prevent conflicting spans in the data. For NER, spans are equal if they have the same offsets and label.
    neg_span_triples = {(neg_ent.start_char, neg_ent.end_char, neg_ent.label) for neg_ent in negs}
    for pos_span in ner_ents:
        if (pos_span.start_char, pos_span.end_char, pos_span.label) in neg_span_triples:
            raise ValueError(Errors.E868.format(span=(pos_span.start_char, pos_span.end_char, pos_span.label_)))

    # In order to handle negative samples, we need to maintain the full
    # (start, end, label) triple. If we break it down to the 'isnt B-LOC'
    # thing, we'll get blocked if there's an incorrect prefix.
    for neg in negs:
        gs.negs.push_back(neg.c)
    return gs


cdef void update_gold_state(GoldNERStateC* gs, const StateC* state) except *:
    # We don't need to update each time, unlike the parser.
    pass


cdef do_func_t[N_MOVES] do_funcs


cdef bint _entity_is_sunk(const StateC* state, Transition* golds) nogil:
    if not state.entity_is_open():
        return False

    cdef const Transition* gold = &golds[state.E(0)]
    ent = state.get_ent()
    if gold.move != BEGIN and gold.move != UNIT:
        return True
    elif gold.label != ent.label:
        return True
    else:
        return False


cdef class BiluoPushDown(TransitionSystem):
    def __init__(self, *args, **kwargs):
        TransitionSystem.__init__(self, *args, **kwargs)

    @classmethod
    def get_actions(cls, **kwargs):
        actions = {
            MISSING: Counter(),
            BEGIN: Counter(),
            IN: Counter(),
            LAST: Counter(),
            UNIT: Counter(),
            OUT: Counter()
        }
        actions[OUT][''] = 1  # Represents a token predicted to be outside of any entity
        actions[UNIT][''] = 1  # Represents a token prohibited to be in an entity
        for entity_type in kwargs.get('entity_types', []):
            for action in (BEGIN, IN, LAST, UNIT):
                actions[action][entity_type] = 1
        moves = ('M', 'B', 'I', 'L', 'U')
        for example in kwargs.get('examples', []):
            for token in example.y:
                ent_type = token.ent_type_
                if ent_type:
                    for action in (BEGIN, IN, LAST, UNIT):
                        actions[action][ent_type] += 1
        return actions

    @property
    def action_types(self):
        return (BEGIN, IN, LAST, UNIT, OUT)

    def get_doc_labels(self, doc):
        labels = set()
        for token in doc:
            if token.ent_type:
                labels.add(token.ent_type_)
        return labels

    def move_name(self, int move, attr_t label):
        if move == OUT:
            return 'O'
        elif move == MISSING:
            return 'M'
        else:
            return MOVE_NAMES[move] + '-' + self.strings[label]

    def init_gold_batch(self, examples):
        all_states = self.init_batch([eg.predicted for eg in examples])
        golds = []
        states = []
        for state, eg in zip(all_states, examples):
            if self.has_gold(eg) and not state.is_final():
                golds.append(self.init_gold(state, eg))
                states.append(state)
        n_steps = sum([len(s.queue) for s in states])
        return states, golds, n_steps

    cdef Transition lookup_transition(self, object name) except *:
        cdef attr_t label
        if name == '-' or name == '' or name is None:
            return Transition(clas=0, move=MISSING, label=0, score=0)
        elif '-' in name:
            move_str, label_str = split_bilu_label(name)
            # Deprecated, hacky way to denote 'not this entity'
            if label_str.startswith('!'):
                raise ValueError(Errors.E869.format(label=name))
            label = self.strings.add(label_str)
        else:
            move_str = name
            label = 0
        move = MOVE_NAMES.index(move_str)
        for i in range(self.n_moves):
            if self.c[i].move == move and self.c[i].label == label:
                return self.c[i]
        raise KeyError(Errors.E022.format(name=name))

    cdef Transition init_transition(self, int clas, int move, attr_t label) except *:
        # TODO: Apparent Cython bug here when we try to use the Transition()
        # constructor with the function pointers
        cdef Transition t
        t.score = 0
        t.clas = clas
        t.move = move
        t.label = label
        if move == MISSING:
            t.is_valid = Missing.is_valid
            t.do = Missing.transition
            t.get_cost = Missing.cost
        elif move == BEGIN:
            t.is_valid = Begin.is_valid
            t.do = Begin.transition
            t.get_cost = Begin.cost
        elif move == IN:
            t.is_valid = In.is_valid
            t.do = In.transition
            t.get_cost = In.cost
        elif move == LAST:
            t.is_valid = Last.is_valid
            t.do = Last.transition
            t.get_cost = Last.cost
        elif move == UNIT:
            t.is_valid = Unit.is_valid
            t.do = Unit.transition
            t.get_cost = Unit.cost
        elif move == OUT:
            t.is_valid = Out.is_valid
            t.do = Out.transition
            t.get_cost = Out.cost
        else:
            raise ValueError(Errors.E019.format(action=move, src='ner'))
        return t

    def add_action(self, int action, label_name, freq=None):
        cdef attr_t label_id
        if not isinstance(label_name, (int, long)):
            label_id = self.strings.add(label_name)
        else:
            label_id = label_name
        if action == OUT and label_id != 0:
            return None
        if action == MISSING:
            return None
        # Check we're not creating a move we already have, so that this is
        # idempotent
        for trans in self.c[:self.n_moves]:
            if trans.move == action and trans.label == label_id:
                return 0
        if self.n_moves >= self._size:
            self._size = self.n_moves
            self._size *= 2
            self.c = <Transition*>self.mem.realloc(self.c, self._size * sizeof(self.c[0]))
        self.c[self.n_moves] = self.init_transition(self.n_moves, action, label_id)
        self.n_moves += 1
        if self.labels.get(action, []):
            freq = min(0, min(self.labels[action].values()))
            self.labels[action][label_name] = freq-1
        else:
            self.labels[action] = Counter()
            self.labels[action][label_name] = -1
        return 1

    def set_annotations(self, StateClass state, Doc doc):
        cdef int i
        ents = []
        for i in range(state.c._ents.size()):
            ent = state.c._ents.at(i)
            if ent.start != -1 and ent.end != -1:
                ents.append(Span(doc, ent.start, ent.end, label=ent.label, kb_id=doc.c[ent.start].ent_kb_id))
        doc.set_ents(ents, default="unmodified")
        # Set non-blocked tokens to O
        for i in range(doc.length):
            if doc.c[i].ent_iob == 0:
                doc.c[i].ent_iob = 2

    def get_beam_parses(self, Beam beam):
        parses = []
        probs = beam.probs
        for i in range(beam.size):
            state = <StateC*>beam.at(i)
            if state.is_final():
                prob = probs[i]
                parse = []
                for j in range(state._ents.size()):
                    ent = state._ents.at(j)
                    if ent.start != -1 and ent.end != -1:
                        parse.append((ent.start, ent.end, self.strings[ent.label]))
                parses.append((prob, parse))
        return parses

    def init_gold(self, StateClass state, Example example):
        return BiluoGold(self, state, example, self.neg_key)

    def has_gold(self, Example eg, start=0, end=None):
        # We get x and y referring to X, we want to check relative to Y,
        # the reference
        y_spans = eg.get_aligned_spans_x2y([eg.x[start:end]])
        if not y_spans:
            y_spans = [eg.y[:]]
        y_span = y_spans[0]
        start = y_span.start
        end = y_span.end
        neg_key = self.neg_key
        if neg_key is not None:
            # If we have any negative samples, count that as having annotation.
            for span in eg.y.spans.get(neg_key, []):
                if span.start >= start and span.end <= end:
                    return True
        if end is not None and end < 0:
            end = None
        for word in eg.y[start:end]:
            if word.ent_iob != 0:
                return True
        else:
            return False

    def get_cost(self, StateClass stcls, gold, int i):
        if not isinstance(gold, BiluoGold):
            raise TypeError(Errors.E909.format(name="BiluoGold"))
        cdef BiluoGold gold_ = gold
        gold_state = gold_.c
        n_gold = 0
        if self.c[i].is_valid(stcls.c, self.c[i].label):
            cost = self.c[i].get_cost(stcls.c, &gold_state, self.c[i].label)
        else:
            cost = 9000
        return cost

    cdef int set_costs(self, int* is_valid, weight_t* costs,
                       const StateC* state, gold) except -1:
        if not isinstance(gold, BiluoGold):
            raise TypeError(Errors.E909.format(name="BiluoGold"))
        cdef BiluoGold gold_ = gold
        gold_state = gold_.c
        update_gold_state(&gold_state, state)
        n_gold = 0
        self.set_valid(is_valid, state)
        for i in range(self.n_moves):
            if is_valid[i]:
                costs[i] = self.c[i].get_cost(state, &gold_state, self.c[i].label)
                n_gold += costs[i] <= 0
            else:
                costs[i] = 9000


cdef class Missing:
    @staticmethod
    cdef bint is_valid(const StateC* st, attr_t label) nogil:
        return False

    @staticmethod
    cdef int transition(StateC* s, attr_t label) nogil:
        pass

    @staticmethod
    cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
        return 9000


cdef class Begin:
    @staticmethod
    cdef bint is_valid(const StateC* st, attr_t label) nogil:
        cdef int preset_ent_iob = st.B_(0).ent_iob
        cdef attr_t preset_ent_label = st.B_(0).ent_type
        if st.entity_is_open():
            return False
        if st.buffer_length() < 2:
            # If we're the last token of the input, we can't B -- must U or O.
            return False
        elif label == 0:
            return False
        elif preset_ent_iob == 1:
            # Ensure we don't clobber preset entities. If no entity preset,
            # ent_iob is 0
            return False
        elif preset_ent_iob == 3:
            # Okay, we're in a preset entity.
            if label != preset_ent_label:
                # If label isn't right, reject
                return False
            elif st.B_(1).ent_iob != 1:
                # If next token isn't marked I, we need to make U, not B.
                return False
            else:
                # Otherwise, force acceptance, even if we're across a sentence
                # boundary or the token is whitespace.
                return True
        elif st.B_(1).ent_iob == 3:
            # If the next word is B, we can't B now
            return False
        elif st.B_(1).sent_start == 1:
            # Don't allow entities to extend across sentence boundaries
            return False
        # Don't allow entities to start on whitespace
        elif Lexeme.get_struct_attr(st.B_(0).lex, IS_SPACE):
            return False
        else:
            return True

    @staticmethod
    cdef int transition(StateC* st, attr_t label) nogil:
        st.open_ent(label)
        st.push()
        st.pop()

    @staticmethod
    cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
        gold = <GoldNERStateC*>_gold
        b0 = s.B(0)
        cdef int cost = 0
        cdef int g_act = gold.ner[b0].move
        cdef attr_t g_tag = gold.ner[b0].label

        cdef shared_ptr[SpanC] span

        if g_act == MISSING:
            pass
        elif g_act == BEGIN:
            # B, Gold B --> Label match
            cost += label != g_tag
        else:
            # B, Gold I --> False (P)
            # B, Gold L --> False (P)
            # B, Gold O --> False (P)
            # B, Gold U --> False (P)
            cost += 1
        if s.buffer_length() < 3:
            # Handle negatives. In general we can't really do much to block
            # B, because we don't know whether the whole entity is going to
            # be correct or not. However, we can at least tell whether we're
            # going to be opening an entity where there's only one possible
            # L.
            for span in gold.negs:
                if span.get().label == label and span.get().start == b0:
                    cost += 1
                    break
        return cost


cdef class In:
    @staticmethod
    cdef bint is_valid(const StateC* st, attr_t label) nogil:
        if not st.entity_is_open():
            return False
        if st.buffer_length() < 2:
            # If we're at the end, we can't I.
            return False
        ent = st.get_ent()
        cdef int preset_ent_iob = st.B_(0).ent_iob
        cdef attr_t preset_ent_label = st.B_(0).ent_type
        if label == 0:
            return False
        elif ent.label != label:
            return False
        elif preset_ent_iob == 3:
            return False
        elif st.B_(1).ent_iob == 3:
            # If we know the next word is B, we can't be I (must be L)
            return False
        elif preset_ent_iob == 1:
            if st.B_(1).ent_iob in (0, 2):
                # if next preset is missing or O, this can't be I (must be L)
                return False
            elif label != preset_ent_label:
                # If label isn't right, reject
                return False
            else:
                # Otherwise, force acceptance, even if we're across a sentence
                # boundary or the token is whitespace.
                return True
        elif st.B(1) != -1 and st.B_(1).sent_start == 1:
            # Don't allow entities to extend across sentence boundaries
            return False
        else:
            return True

    @staticmethod
    cdef int transition(StateC* st, attr_t label) nogil:
        st.push()
        st.pop()

    @staticmethod
    cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
        gold = <GoldNERStateC*>_gold
        move = IN
        cdef int next_act = gold.ner[s.B(1)].move if s.B(1) >= 0 else OUT
        cdef int g_act = gold.ner[s.B(0)].move
        cdef attr_t g_tag = gold.ner[s.B(0)].label
        cdef bint is_sunk = _entity_is_sunk(s, gold.ner)

        if g_act == MISSING:
            return 0
        elif g_act == BEGIN:
            # I, Gold B --> True
            # (P of bad open entity sunk, R of this entity sunk)
            return 0
        elif g_act == IN:
            # I, Gold I --> True
            # (label forced by prev, if mismatch, P and R both sunk)
            return 0
        elif g_act == LAST:
            # I, Gold L --> True iff this entity sunk and next tag == O
            return not (is_sunk and (next_act == OUT or next_act == MISSING))
        elif g_act == OUT:
            # I, Gold O --> True iff next tag == O
            return not (next_act == OUT or next_act == MISSING)
        elif g_act == UNIT:
            # I, Gold U --> True iff next tag == O
            return next_act != OUT
        else:
            return 1


cdef class Last:
    @staticmethod
    cdef bint is_valid(const StateC* st, attr_t label) nogil:
        cdef int preset_ent_iob = st.B_(0).ent_iob
        cdef attr_t preset_ent_label = st.B_(0).ent_type
        if label == 0:
            return False
        elif not st.entity_is_open():
            return False
        elif preset_ent_iob == 1 and st.B_(1).ent_iob != 1:
            # If a preset entity has I followed by not-I, is L
            if label != preset_ent_label:
                # If label isn't right, reject
                return False
            else:
                # Otherwise, force acceptance, even if we're across a sentence
                # boundary or the token is whitespace.
                return True
        elif st.get_ent().label != label:
            return False
        elif st.B_(1).ent_iob == 1:
            # If a preset entity has I next, we can't L here.
            return False
        else:
            return True

    @staticmethod
    cdef int transition(StateC* st, attr_t label) nogil:
        st.close_ent()
        st.push()
        st.pop()

    @staticmethod
    cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
        gold = <GoldNERStateC*>_gold
        move = LAST
        b0 = s.B(0)
        ent_start = s.E(0)

        cdef int g_act = gold.ner[b0].move
        cdef attr_t g_tag = gold.ner[b0].label

        cdef int cost = 0

        if g_act == MISSING:
            pass
        elif g_act == BEGIN:
            # L, Gold B --> True
            pass
        elif g_act == IN:
            # L, Gold I --> True iff this entity sunk
            cost += not _entity_is_sunk(s, gold.ner)
        elif g_act == LAST:
            # L, Gold L --> True
            pass
        elif g_act == OUT:
            # L, Gold O --> True
            pass
        elif g_act == UNIT:
            # L, Gold U --> True
            pass
        else:
            cost += 1
        # If we have negative-example entities, integrate them into the objective,
        # by marking actions that close an entity that we know is incorrect
        # as costly.
        cdef shared_ptr[SpanC] span
        for span in gold.negs:
            if span.get().label == label and (span.get().end-1) == b0 and span.get().start == ent_start:
                cost += 1
                break
        return cost


cdef class Unit:
    @staticmethod
    cdef bint is_valid(const StateC* st, attr_t label) nogil:
        cdef int preset_ent_iob = st.B_(0).ent_iob
        cdef attr_t preset_ent_label = st.B_(0).ent_type
        if label == 0:
            # this is only allowed if it's a preset blocked annotation
            if preset_ent_label == 0 and preset_ent_iob == 3:
                return True
            else:
                return False
        elif st.entity_is_open():
            return False
        elif st.B_(1).ent_iob == 1:
            # If next token is In, we can't be Unit -- must be Begin
            return False
        elif preset_ent_iob == 3:
            # Okay, there's a preset entity here
            if label != preset_ent_label:
                # Require labels to match
                return False
            else:
                # Otherwise return True, ignoring the whitespace constraint.
                return True
        elif Lexeme.get_struct_attr(st.B_(0).lex, IS_SPACE):
            return False
        else:
            return True

    @staticmethod
    cdef int transition(StateC* st, attr_t label) nogil:
        st.open_ent(label)
        st.close_ent()
        st.push()
        st.pop()

    @staticmethod
    cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
        gold = <GoldNERStateC*>_gold
        cdef int g_act = gold.ner[s.B(0)].move
        cdef attr_t g_tag = gold.ner[s.B(0)].label
        cdef int cost = 0

        if g_act == MISSING:
            pass
        elif g_act == UNIT:
            # U, Gold U --> True iff tag match
            cost += label != g_tag
        else:
            # U, Gold B --> False
            # U, Gold I --> False
            # U, Gold L --> False
            # U, Gold O --> False
            cost += 1
        # If we have negative-example entities, integrate them into the objective.
        # This is fairly straight-forward for U- entities, as we have a single
        # action
        cdef int b0 = s.B(0)
        cdef shared_ptr[SpanC] span
        for span in gold.negs:
            if span.get().label == label and span.get().start == b0 and span.get().end == (b0+1):
                cost += 1
                break
        return cost


cdef class Out:
    @staticmethod
    cdef bint is_valid(const StateC* st, attr_t label) nogil:
        cdef int preset_ent_iob = st.B_(0).ent_iob
        if st.entity_is_open():
            return False
        elif preset_ent_iob == 3:
            return False
        elif preset_ent_iob == 1:
            return False
        else:
            return True

    @staticmethod
    cdef int transition(StateC* st, attr_t label) nogil:
        st.push()
        st.pop()

    @staticmethod
    cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
        gold = <GoldNERStateC*>_gold
        cdef int g_act = gold.ner[s.B(0)].move
        cdef attr_t g_tag = gold.ner[s.B(0)].label
        cdef weight_t cost = 0
        if g_act == MISSING:
            pass
        elif g_act == BEGIN:
            # O, Gold B --> False
            cost += 1
        elif g_act == IN:
            # O, Gold I --> True
            pass
        elif g_act == LAST:
            # O, Gold L --> True
            pass
        elif g_act == OUT:
            # O, Gold O --> True
            pass
        elif g_act == UNIT:
            # O, Gold U --> False
            cost += 1
        else:
            cost += 1
        return cost
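The cost functions in the listing above combine two signals: a mismatch with the gold move for the current token, and a penalty for predicting a span that is listed as a negative example. As a rough pure-Python paraphrase of the `Unit.cost` logic (a sketch, not the Cython code itself), the function below uses string labels instead of `attr_t` hashes and plain `(start, end, label)` tuples instead of `SpanC` structs; both substitutions, and the name `unit_cost`, are assumptions made for readability.

```python
from typing import Sequence, Tuple

# Same ordering as the `cdef enum` in the listing.
MISSING, BEGIN, IN, LAST, UNIT, OUT = range(6)


def unit_cost(
    b0: int,
    label: str,
    gold_move: int,
    gold_label: str,
    neg_spans: Sequence[Tuple[int, int, str]] = (),
) -> int:
    """Cost of predicting U-`label` at buffer position `b0` (illustrative only)."""
    cost = 0
    if gold_move == MISSING:
        pass  # no gold annotation for this token, so nothing to contradict
    elif gold_move == UNIT:
        cost += int(label != gold_label)  # right move, possibly wrong label
    else:
        cost += 1  # gold says B, I, L or O here
    # A single-token negative span with this label adds a further penalty.
    for start, end, neg_label in neg_spans:
        if neg_label == label and start == b0 and end == b0 + 1:
            cost += 1
            break
    return cost


print(unit_cost(2, "ORG", UNIT, "ORG"))
print(unit_cost(2, "ORG", MISSING, "", [(2, 3, "ORG")]))
```

With missing gold annotation the first term contributes nothing, so a negative span is the only source of cost in that case, which matches how the listing lets negative examples supply a training signal for otherwise unannotated tokens.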