spaCy/spacy/pipeline/_parser_internals/ner.pyx
Daniël de Kok a183db3cef
Merge the parser refactor into v4 (#10940)
* Try to fix doc.copy

* Set dev version

* Make vocab always own lexemes

* Change version

* Add SpanGroups.copy method

* Fix set_annotations during Parser.update

* Fix dict proxy copy

* Upd version

* Fix copying SpanGroups

* Fix set_annotations in parser.update

* Fix parser set_annotations during update

* Revert "Fix parser set_annotations during update"

This reverts commit eb138c89ed.

* Revert "Fix set_annotations in parser.update"

This reverts commit c6df0eafd0.

* Fix set_annotations during parser update

* Inc version

* Handle final states in get_oracle_sequence

* Inc version

* Try to fix parser training

* Inc version

* Fix

* Inc version

* Fix parser oracle

* Inc version

* Inc version

* Fix transition has_gold

* Inc version

* Try to use real histories, not oracle

* Inc version

* Upd parser

* Inc version

* WIP on rewrite parser

* WIP refactor parser

* New progress on parser model refactor

* Prepare to remove parser_model.pyx

* Convert parser from cdef class

* Delete spacy.ml.parser_model

* Delete _precomputable_affine module

* Wire up tb_framework to new parser model

* Wire up parser model

* Uncython ner.pyx and dep_parser.pyx

* Uncython

* Work on parser model

* Support unseen_classes in parser model

* Support unseen classes in parser

* Cleaner handling of unseen classes

* Work through tests

* Keep working through errors

* Keep working through errors

* Work on parser. 15 tests failing

* Xfail beam stuff. 9 failures

* More xfail. 7 failures

* Xfail. 6 failures

* cleanup

* formatting

* fixes

* pass nO through

* Fix empty doc in update

* Hackishly fix resizing. 3 failures

* Fix redundant test. 2 failures

* Add reference version

* black formatting

* Get tests passing with reference implementation

* Fix missing prints

* Add missing file

* Improve indexing on reference implementation

* Get non-reference forward func working

* Start rigging beam back up

* removing redundant tests, cf #8106

* black formatting

* temporarily xfailing issue 4314

* make flake8 happy again

* mypy fixes

* ensure labels are added upon predict

* cleanup remnants from merge conflicts

* Improve unseen label masking

Two changes to speed up masking by ~10%:

- Use a bool array rather than an array of float32.

- Let the mask indicate whether a label was seen, rather than
  unseen. The mask is most frequently used to index scores for
  seen labels. However, since the mask marked unseen labels,
  this required computing an intermediate flipped mask.
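
A rough illustration of the second point, with plain Python lists
standing in for the arrays. The names `scores`, `unseen_mask` and
`seen_mask` are made up for this sketch and are not the parser's
actual variables:

```python
# Hypothetical scores for four labels; the last two were never seen.
scores = [0.7, 0.1, 0.9, 0.3]

# Old convention: the mask marks *unseen* labels, so selecting the
# scores of seen labels first needs a flipped copy of the mask.
unseen_mask = [False, False, True, True]
flipped = [not m for m in unseen_mask]
seen_scores_old = [s for s, keep in zip(scores, flipped) if keep]

# New convention: the mask marks *seen* labels and is used directly.
seen_mask = [True, True, False, False]
seen_scores_new = [s for s, keep in zip(scores, seen_mask) if keep]
```

Both conventions select the same scores; the second one skips the
extra pass that builds `flipped`.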

* Write moves costs directly into numpy array (#10163)

This avoids elementwise indexing and the allocation of an additional
array.

Gives a ~15% speed improvement when using batch_by_sequence with size
32.

* Temporarily disable ner and rehearse tests

Until rehearse is implemented again in the refactored parser.

* Fix loss serialization issue (#10600)

* Fix loss serialization issue

Serialization of a model fails with:

TypeError: array(738.3855, dtype=float32) is not JSON serializable

Fix this using float conversion.

* Disable CI steps that require spacy.TransitionBasedParser.v2

After finishing the refactor, TransitionBasedParser.v2 should be
provided for backwards compat.

* Add back support for beam parsing to the refactored parser (#10633)

* Add back support for beam parsing

Beam parsing was already implemented as part of the `BeamBatch` class.
This change makes its counterpart `GreedyBatch`. Both classes are hooked
up in `TransitionModel`, selecting `GreedyBatch` when the beam size is
one, or `BeamBatch` otherwise.

* Use kwarg for beam width

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Avoid implicit default for beam_width and beam_density

* Parser.{beam,greedy}_parse: ensure labels are added

* Remove 'deprecated' comments

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Parser `StateC` optimizations (#10746)

* `StateC`: Optimizations

Avoid GIL acquisition in `__init__`
Increase default buffer capacities on init
Reduce C++ exception overhead

* Fix typo

* Replace `set::count` with `set::find`

* Add exception attribute to c'tor

* Remove unused import

* Use a power-of-two value for initial capacity
Use default-insert to init `_heads` and `_unshiftable`

* Merge `cdef` variable declarations and assignments

* Vectorize `example.get_aligned_parses` (#10789)

* `example`: Vectorize `get_aligned_parse`
Rename `numpy` import

* Convert aligned array to lists before returning

* Revert import renaming

* Elide slice arguments when selecting the entire range

* Tagger/morphologizer alignment performance optimizations (#10798)

* `example`: Unwrap `numpy` scalar arrays before passing them to `StringStore.__getitem__`

* `AlignmentArray`: Use native list as staging buffer for offset calculation

* `example`: Vectorize `get_aligned`

* Hoist inner functions out of `get_aligned`

* Replace inline `if..else` clause in assignment statement

* `AlignmentArray`: Use raw indexing into offset and data `numpy` arrays

* `example`: Replace array unique value check with `groupby`

* `example`: Correctly exclude tokens with no alignment in `_get_aligned_vectorized`
Simplify `_get_aligned_non_vectorized`

* `util`: Update `all_equal` docstring

* Explicitly use `int32_t*`

* Restore C CPU inference in the refactored parser (#10747)

* Bring back the C parsing model

The C parsing model is used for CPU inference and is still faster for
CPU inference than the forward pass of the Thinc model.

* Use C sgemm provided by the Ops implementation

* Make tb_framework module Cython, merge in C forward implementation

* TransitionModel: raise in backprop returned from forward_cpu

* Re-enable greedy parse test

* Return transition scores when forward_cpu is used

* Apply suggestions from code review

Import `Model` from `thinc.api`

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Use relative imports in tb_framework

* Don't assume a default for beam_width

* We don't have a direct dependency on BLIS anymore

* Rename forwards to _forward_{fallback,greedy_cpu}

* Require thinc >=8.1.0,<8.2.0

* tb_framework: clean up imports

* Fix return type of _get_seen_mask

* Move up _forward_greedy_cpu

* Style fixes.

* Lower thinc lowerbound to 8.1.0.dev0

* Formatting fix

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>

* Reimplement parser rehearsal function (#10878)

* Reimplement parser rehearsal function

Before the parser refactor, rehearsal was driven by a loop in the
`rehearse` method itself. For each parsing step, the loop would:

1. Get the predictions of the teacher.
2. Get the predictions and backprop function of the student.
3. Compute the loss and backprop into the student.
4. Move the teacher and student forward with the predictions of
   the student.

In the refactored parser, we cannot perform such stepwise rehearsal
anymore, since the model now predicts all parsing steps at once.
Therefore, rehearsal is performed in the following steps:

1. Get the predictions of all parsing steps from the student, along
   with its backprop function.
2. Get the predictions from the teacher, but use the predictions of
   the student to advance the parser while doing so.
3. Compute the loss and backprop into the student.

To support the second step, a new method, `advance_with_actions`, is
added to `GreedyBatch`, which performs the provided parsing steps.
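
The three steps can be sketched with a toy model. Everything below is
hypothetical: states are plain integers, `student_scores`,
`teacher_scores` and `advance` are stand-ins invented for this sketch,
and a squared-error loss is assumed. It only shows the order of
operations, not the actual parser code:

```python
def student_scores(state):
    # Hypothetical student: scores for two actions given a state.
    return [0.2 * state, 1.0]

def teacher_scores(state):
    # Hypothetical teacher with slightly different scores.
    return [0.1 * state, 0.8]

def argmax(xs):
    return max(range(len(xs)), key=xs.__getitem__)

def advance(state, action):
    # Toy transition: the next state depends on the chosen action.
    return state + action + 1

def rehearse(start_state, n_steps):
    # 1. The student predicts all steps, following its own best actions.
    state, actions, s_scores = start_state, [], []
    for _ in range(n_steps):
        scores = student_scores(state)
        action = argmax(scores)
        s_scores.append(scores)
        actions.append(action)
        state = advance(state, action)
    # 2. The teacher scores the same states, advanced with the
    #    student's actions rather than its own.
    state, t_scores = start_state, []
    for action in actions:
        t_scores.append(teacher_scores(state))
        state = advance(state, action)
    # 3. Loss between student and teacher scores (squared error assumed).
    loss = sum(
        (s - t) ** 2
        for ss, ts in zip(s_scores, t_scores)
        for s, t in zip(ss, ts)
    )
    return actions, loss
```

In this sketch the second loop plays the role that
`advance_with_actions` has in the description above: it replays a
fixed action sequence instead of choosing actions itself.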

* tb_framework: wrap upper_W and upper_b in Linear

Thinc's Optimizer cannot handle resizing of existing parameters. Until
it does, we work around this by wrapping the weights/biases of the upper
layer of the parser model in Linear. When the upper layer is resized, we
copy over the existing parameters into a new Linear instance. This does
not trigger an error in Optimizer, because it sees the resized layer as
a new set of parameters.
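
A minimal sketch of the copy-on-resize idea, using nested lists in
place of real weight arrays. The function `resize_output` is
hypothetical and assumes the new size is at least the old size; it is
not Thinc's or spaCy's API:

```python
def resize_output(old_W, old_b, new_nO):
    """Return fresh (W, b) with new_nO rows, copying existing rows over."""
    nI = len(old_W[0])
    # New parameters are allocated from scratch, so an optimizer that
    # keys its state on parameter identity sees them as new.
    new_W = [[0.0] * nI for _ in range(new_nO)]
    new_b = [0.0] * new_nO
    for i, row in enumerate(old_W):
        new_W[i] = list(row)
        new_b[i] = old_b[i]
    return new_W, new_b
```

For example, `resize_output([[1.0, 2.0]], [0.5], 3)` keeps the first
row and bias and fills the two added rows with zeros.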

* Add test for TransitionSystem.apply_actions

* Better FIXME marker

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>

* Fixes from Madeesh

* Apply suggestions from Sofie

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Remove useless assignment

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Rename some identifiers in the parser refactor (#10935)

* Rename _parseC to _parse_batch

* tb_framework: prefix many auxiliary functions with underscore

To clearly state the intent that they are private.

* Rename `lower` to `hidden`, `upper` to `output`

* Parser slow test fixup

We don't have TransitionBasedParser.{v1,v2} until we bring it back as a
legacy option.

* Remove last vestiges of PrecomputableAffine

This does not exist anymore as a separate layer.

* ner: re-enable sentence boundary checks

* Re-enable test that works now.

* test_ner: make loss test more strict again

* Remove commented line

* Re-enable some more beam parser tests

* Remove unused _forward_reference function

* Update for CBlas changes in Thinc 8.1.0.dev2

Bump thinc dependency to 8.1.0.dev3.

* Remove references to spacy.TransitionBasedParser.{v1,v2}

Since they will not be offered starting with spaCy v4.

* `tb_framework`: Replace references to `thinc.backends.linalg` with `CBlas`

* dont use get_array_module (#11056) (#11293)

Co-authored-by: kadarakos <kadar.akos@gmail.com>

* Move `thinc.extra.search` to `spacy.pipeline._parser_internals` (#11317)

* `search`: Move from `thinc.extra.search`
Fix NPE in `Beam.__dealloc__`

* `pytest`: Add support for executing Cython tests
Move `search` tests from thinc and patch them to run with `pytest`

* `mypy` fix

* Update comment

* `conftest`: Expose `register_cython_tests`

* Remove unused import

* Move `argmax` impls to new `_parser_utils` Cython module (#11410)

* Parser does not have to be a cdef class anymore

This also fixes validation of the initialization schema.

* Add back spacy.TransitionBasedParser.v2

* Fix a rename that was missed in #10878.

So that rehearsal tests pass.

* Remove module from setup.py that got added during the merge

* Bring back support for `update_with_oracle_cut_size` (#12086)

* Bring back support for `update_with_oracle_cut_size`

This option was available in the pre-refactor parser, but was never
implemented in the refactored parser. This option cuts transition
sequences that are longer than `update_with_oracle_cut_size` into
separate sequences that have at most `update_with_oracle_cut_size`
transitions. The oracle (gold standard) transition sequence is used to
determine the cuts and the initial states for the additional sequences.

Applying this cut makes the batches more homogeneous in the transition
sequence lengths, making forward passes (and as a consequence training)
much faster.
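
The splitting itself can be sketched as follows. `cut_sequence` is a
hypothetical helper written for this message; treating a cut size of 0
as "no cut" is an assumption of the sketch, and the real
implementation additionally has to compute the initial parser state
for each piece from the oracle sequence:

```python
def cut_sequence(transitions, cut_size):
    """Split a transition sequence into pieces of at most cut_size."""
    if cut_size <= 0:
        # Assumption: a non-positive cut size disables cutting.
        return [list(transitions)]
    return [
        list(transitions[i:i + cut_size])
        for i in range(0, len(transitions), cut_size)
    ]
```

With a cut size of 4, a six-step oracle sequence becomes one piece of
four transitions and one piece of two.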

Training time 1000 steps on de_core_news_lg:

- Before this change: 149s
- After this change: 68s
- Pre-refactor parser: 81s

* Fix a rename that was missed in #10878.

So that rehearsal tests pass.

* Apply suggestions from @shadeMe

* Use chained conditional

* Test with update_with_oracle_cut_size={0, 1, 5, 100}

And fix a bug that occurs with a cut size of 1.

* Fix up some merge fall out

* Update parser distillation for the refactor

In the old parser, we'd iterate over the transitions in the distill
function and compute the loss/gradients on the go. In the refactored
parser, we first let the student model parse the inputs. Then we'll let
the teacher compute the transition probabilities of the states in the
student's transition sequence. We can then compute the gradients of the
student given the teacher.

* Add back spacy.TransitionBasedParser.v1 references

- Accordion in the architecture docs.
- Test in test_parse, but disabled until we have a spacy-legacy release.

Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com>
Co-authored-by: svlandeg <svlandeg@github.com>
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
Co-authored-by: kadarakos <kadar.akos@gmail.com>
2023-01-18 11:27:45 +01:00


import os
import random
from libc.stdint cimport int32_t
from libcpp.memory cimport shared_ptr
from libcpp.vector cimport vector
from cymem.cymem cimport Pool
from collections import Counter
from ...tokens.doc cimport Doc
from ...tokens.span import Span
from ...tokens.span cimport Span
from ...typedefs cimport weight_t, attr_t
from ...lexeme cimport Lexeme
from ...attrs cimport IS_SPACE
from ...structs cimport TokenC, SpanC
from ...training import split_bilu_label
from ...training.example cimport Example
from .search cimport Beam
from .stateclass cimport StateClass
from ._state cimport StateC
from .transition_system cimport Transition, do_func_t
from ...errors import Errors


cdef enum:
    MISSING
    BEGIN
    IN
    LAST
    UNIT
    OUT
    N_MOVES


MOVE_NAMES = [None] * N_MOVES
MOVE_NAMES[MISSING] = 'M'
MOVE_NAMES[BEGIN] = 'B'
MOVE_NAMES[IN] = 'I'
MOVE_NAMES[LAST] = 'L'
MOVE_NAMES[UNIT] = 'U'
MOVE_NAMES[OUT] = 'O'


cdef struct GoldNERStateC:
    Transition* ner
    vector[shared_ptr[SpanC]] negs


cdef class BiluoGold:
    cdef Pool mem
    cdef GoldNERStateC c

    def __init__(self, BiluoPushDown moves, StateClass stcls, Example example, neg_key):
        self.mem = Pool()
        self.c = create_gold_state(self.mem, moves, stcls.c, example, neg_key)

    def update(self, StateClass stcls):
        update_gold_state(&self.c, stcls.c)


cdef GoldNERStateC create_gold_state(
    Pool mem,
    BiluoPushDown moves,
    const StateC* stcls,
    Example example,
    neg_key
) except *:
    cdef GoldNERStateC gs
    cdef Span neg
    if neg_key is not None:
        negs = example.get_aligned_spans_y2x(
            example.y.spans.get(neg_key, []),
            allow_overlap=True
        )
    else:
        negs = []
    assert example.x.length > 0
    gs.ner = <Transition*>mem.alloc(example.x.length, sizeof(Transition))
    ner_ents, ner_tags = example.get_aligned_ents_and_ner()
    for i, ner_tag in enumerate(ner_tags):
        gs.ner[i] = moves.lookup_transition(ner_tag)

    # Prevent conflicting spans in the data. For NER, spans are equal if they have the same offsets and label.
    neg_span_triples = {(neg_ent.start_char, neg_ent.end_char, neg_ent.label) for neg_ent in negs}
    for pos_span in ner_ents:
        if (pos_span.start_char, pos_span.end_char, pos_span.label) in neg_span_triples:
            raise ValueError(Errors.E868.format(span=(pos_span.start_char, pos_span.end_char, pos_span.label_)))

    # In order to handle negative samples, we need to maintain the full
    # (start, end, label) triple. If we break it down to the 'isnt B-LOC'
    # thing, we'll get blocked if there's an incorrect prefix.
    for neg in negs:
        gs.negs.push_back(neg.c)
    return gs


cdef void update_gold_state(GoldNERStateC* gs, const StateC* state) except *:
    # We don't need to update each time, unlike the parser.
    pass


cdef do_func_t[N_MOVES] do_funcs


cdef bint _entity_is_sunk(const StateC* state, Transition* golds) nogil:
    if not state.entity_is_open():
        return False

    cdef const Transition* gold = &golds[state.E(0)]
    ent = state.get_ent()
    if gold.move != BEGIN and gold.move != UNIT:
        return True
    elif gold.label != ent.label:
        return True
    else:
        return False


cdef class BiluoPushDown(TransitionSystem):
    def __init__(self, *args, **kwargs):
        TransitionSystem.__init__(self, *args, **kwargs)

    @classmethod
    def get_actions(cls, **kwargs):
        actions = {
            MISSING: Counter(),
            BEGIN: Counter(),
            IN: Counter(),
            LAST: Counter(),
            UNIT: Counter(),
            OUT: Counter()
        }
        actions[OUT][''] = 1 # Represents a token predicted to be outside of any entity
        actions[UNIT][''] = 1 # Represents a token prohibited to be in an entity
        for entity_type in kwargs.get('entity_types', []):
            for action in (BEGIN, IN, LAST, UNIT):
                actions[action][entity_type] = 1
        moves = ('M', 'B', 'I', 'L', 'U')
        for example in kwargs.get('examples', []):
            for token in example.y:
                ent_type = token.ent_type_
                if ent_type:
                    for action in (BEGIN, IN, LAST, UNIT):
                        actions[action][ent_type] += 1
        return actions

    @property
    def action_types(self):
        return (BEGIN, IN, LAST, UNIT, OUT)

    def get_doc_labels(self, doc):
        labels = set()
        for token in doc:
            if token.ent_type:
                labels.add(token.ent_type_)
        return labels

    def move_name(self, int move, attr_t label):
        if move == OUT:
            return 'O'
        elif move == MISSING:
            return 'M'
        else:
            return MOVE_NAMES[move] + '-' + self.strings[label]

    def init_gold_batch(self, examples):
        all_states = self.init_batch([eg.predicted for eg in examples])
        golds = []
        states = []
        for state, eg in zip(all_states, examples):
            if self.has_gold(eg) and not state.is_final():
                golds.append(self.init_gold(state, eg))
                states.append(state)
        n_steps = sum([len(s.queue) for s in states])
        return states, golds, n_steps

    cdef Transition lookup_transition(self, object name) except *:
        cdef attr_t label
        if name == '-' or name == '' or name is None:
            return Transition(clas=0, move=MISSING, label=0, score=0)
        elif '-' in name:
            move_str, label_str = split_bilu_label(name)
            # Deprecated, hacky way to denote 'not this entity'
            if label_str.startswith('!'):
                raise ValueError(Errors.E869.format(label=name))
            label = self.strings.add(label_str)
        else:
            move_str = name
            label = 0
        move = MOVE_NAMES.index(move_str)
        for i in range(self.n_moves):
            if self.c[i].move == move and self.c[i].label == label:
                return self.c[i]
        raise KeyError(Errors.E022.format(name=name))

    cdef Transition init_transition(self, int clas, int move, attr_t label) except *:
        # TODO: Apparent Cython bug here when we try to use the Transition()
        # constructor with the function pointers
        cdef Transition t
        t.score = 0
        t.clas = clas
        t.move = move
        t.label = label
        if move == MISSING:
            t.is_valid = Missing.is_valid
            t.do = Missing.transition
            t.get_cost = Missing.cost
        elif move == BEGIN:
            t.is_valid = Begin.is_valid
            t.do = Begin.transition
            t.get_cost = Begin.cost
        elif move == IN:
            t.is_valid = In.is_valid
            t.do = In.transition
            t.get_cost = In.cost
        elif move == LAST:
            t.is_valid = Last.is_valid
            t.do = Last.transition
            t.get_cost = Last.cost
        elif move == UNIT:
            t.is_valid = Unit.is_valid
            t.do = Unit.transition
            t.get_cost = Unit.cost
        elif move == OUT:
            t.is_valid = Out.is_valid
            t.do = Out.transition
            t.get_cost = Out.cost
        else:
            raise ValueError(Errors.E019.format(action=move, src='ner'))
        return t

    def add_action(self, int action, label_name, freq=None):
        cdef attr_t label_id
        if not isinstance(label_name, (int, long)):
            label_id = self.strings.add(label_name)
        else:
            label_id = label_name
        if action == OUT and label_id != 0:
            return None
        if action == MISSING:
            return None
        # Check we're not creating a move we already have, so that this is
        # idempotent
        for trans in self.c[:self.n_moves]:
            if trans.move == action and trans.label == label_id:
                return 0
        if self.n_moves >= self._size:
            self._size = self.n_moves
            self._size *= 2
            self.c = <Transition*>self.mem.realloc(self.c, self._size * sizeof(self.c[0]))
        self.c[self.n_moves] = self.init_transition(self.n_moves, action, label_id)
        self.n_moves += 1
        if self.labels.get(action, []):
            freq = min(0, min(self.labels[action].values()))
            self.labels[action][label_name] = freq-1
        else:
            self.labels[action] = Counter()
            self.labels[action][label_name] = -1
        return 1

    def set_annotations(self, StateClass state, Doc doc):
        cdef int i
        ents = []
        for i in range(state.c._ents.size()):
            ent = state.c._ents.at(i)
            if ent.start != -1 and ent.end != -1:
                ents.append(Span(doc, ent.start, ent.end, label=ent.label, kb_id=doc.c[ent.start].ent_kb_id))
        doc.set_ents(ents, default="unmodified")
        # Set non-blocked tokens to O
        for i in range(doc.length):
            if doc.c[i].ent_iob == 0:
                doc.c[i].ent_iob = 2

    def get_beam_parses(self, Beam beam):
        parses = []
        probs = beam.probs
        for i in range(beam.size):
            state = <StateC*>beam.at(i)
            if state.is_final():
                prob = probs[i]
                parse = []
                for j in range(state._ents.size()):
                    ent = state._ents.at(j)
                    if ent.start != -1 and ent.end != -1:
                        parse.append((ent.start, ent.end, self.strings[ent.label]))
                parses.append((prob, parse))
        return parses

    def init_gold(self, StateClass state, Example example):
        return BiluoGold(self, state, example, self.neg_key)

    def has_gold(self, Example eg, start=0, end=None):
        # We get x and y referring to X, we want to check relative to Y,
        # the reference
        y_spans = eg.get_aligned_spans_x2y([eg.x[start:end]])
        if not y_spans:
            y_spans = [eg.y[:]]
        y_span = y_spans[0]
        start = y_span.start
        end = y_span.end
        neg_key = self.neg_key
        if neg_key is not None:
            # If we have any negative samples, count that as having annotation.
            for span in eg.y.spans.get(neg_key, []):
                if span.start >= start and span.end <= end:
                    return True

        if end is not None and end < 0:
            end = None
        for word in eg.y[start:end]:
            if word.ent_iob != 0:
                return True
        else:
            return False

    def get_cost(self, StateClass stcls, gold, int i):
        if not isinstance(gold, BiluoGold):
            raise TypeError(Errors.E909.format(name="BiluoGold"))
        cdef BiluoGold gold_ = gold
        gold_state = gold_.c
        n_gold = 0
        if self.c[i].is_valid(stcls.c, self.c[i].label):
            cost = self.c[i].get_cost(stcls.c, &gold_state, self.c[i].label)
        else:
            cost = 9000
        return cost

    cdef int set_costs(self, int* is_valid, weight_t* costs,
                       const StateC* state, gold) except -1:
        if not isinstance(gold, BiluoGold):
            raise TypeError(Errors.E909.format(name="BiluoGold"))
        cdef BiluoGold gold_ = gold
        gold_state = gold_.c
        update_gold_state(&gold_state, state)
        n_gold = 0
        self.set_valid(is_valid, state)
        for i in range(self.n_moves):
            if is_valid[i]:
                costs[i] = self.c[i].get_cost(state, &gold_state, self.c[i].label)
                n_gold += costs[i] <= 0
            else:
                costs[i] = 9000


cdef class Missing:
    @staticmethod
    cdef bint is_valid(const StateC* st, attr_t label) nogil:
        return False

    @staticmethod
    cdef int transition(StateC* s, attr_t label) nogil:
        pass

    @staticmethod
    cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
        return 9000


cdef class Begin:
    @staticmethod
    cdef bint is_valid(const StateC* st, attr_t label) nogil:
        cdef int preset_ent_iob = st.B_(0).ent_iob
        cdef attr_t preset_ent_label = st.B_(0).ent_type
        if st.entity_is_open():
            return False
        if st.buffer_length() < 2:
            # If we're the last token of the input, we can't B -- must U or O.
            return False
        elif label == 0:
            return False
        elif preset_ent_iob == 1:
            # Ensure we don't clobber preset entities. If no entity preset,
            # ent_iob is 0
            return False
        elif preset_ent_iob == 3:
            # Okay, we're in a preset entity.
            if label != preset_ent_label:
                # If label isn't right, reject
                return False
            elif st.B_(1).ent_iob != 1:
                # If next token isn't marked I, we need to make U, not B.
                return False
            else:
                # Otherwise, force acceptance, even if we're across a sentence
                # boundary or the token is whitespace.
                return True
        elif st.B_(1).ent_iob == 3:
            # If the next word is B, we can't B now
            return False
        elif st.B_(1).sent_start == 1:
            # Don't allow entities to extend across sentence boundaries
            return False
        # Don't allow entities to start on whitespace
        elif Lexeme.get_struct_attr(st.B_(0).lex, IS_SPACE):
            return False
        else:
            return True

    @staticmethod
    cdef int transition(StateC* st, attr_t label) nogil:
        st.open_ent(label)
        st.push()
        st.pop()

    @staticmethod
    cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
        gold = <GoldNERStateC*>_gold
        b0 = s.B(0)
        cdef int cost = 0
        cdef int g_act = gold.ner[b0].move
        cdef attr_t g_tag = gold.ner[b0].label
        cdef shared_ptr[SpanC] span

        if g_act == MISSING:
            pass
        elif g_act == BEGIN:
            # B, Gold B --> Label match
            cost += label != g_tag
        else:
            # B, Gold I --> False (P)
            # B, Gold L --> False (P)
            # B, Gold O --> False (P)
            # B, Gold U --> False (P)
            cost += 1
        if s.buffer_length() < 3:
            # Handle negatives. In general we can't really do much to block
            # B, because we don't know whether the whole entity is going to
            # be correct or not. However, we can at least tell whether we're
            # going to be opening an entity where there's only one possible
            # L.
            for span in gold.negs:
                if span.get().label == label and span.get().start == b0:
                    cost += 1
                    break
        return cost


cdef class In:
    @staticmethod
    cdef bint is_valid(const StateC* st, attr_t label) nogil:
        if not st.entity_is_open():
            return False
        if st.buffer_length() < 2:
            # If we're at the end, we can't I.
            return False
        ent = st.get_ent()
        cdef int preset_ent_iob = st.B_(0).ent_iob
        cdef attr_t preset_ent_label = st.B_(0).ent_type
        if label == 0:
            return False
        elif ent.label != label:
            return False
        elif preset_ent_iob == 3:
            return False
        elif st.B_(1).ent_iob == 3:
            # If we know the next word is B, we can't be I (must be L)
            return False
        elif preset_ent_iob == 1:
            if st.B_(1).ent_iob in (0, 2):
                # if next preset is missing or O, this can't be I (must be L)
                return False
            elif label != preset_ent_label:
                # If label isn't right, reject
                return False
            else:
                # Otherwise, force acceptance, even if we're across a sentence
                # boundary or the token is whitespace.
                return True
        elif st.B(1) != -1 and st.B_(1).sent_start == 1:
            # Don't allow entities to extend across sentence boundaries
            return False
        else:
            return True

    @staticmethod
    cdef int transition(StateC* st, attr_t label) nogil:
        st.push()
        st.pop()

    @staticmethod
    cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
        gold = <GoldNERStateC*>_gold
        move = IN
        cdef int next_act = gold.ner[s.B(1)].move if s.B(1) >= 0 else OUT
        cdef int g_act = gold.ner[s.B(0)].move
        cdef attr_t g_tag = gold.ner[s.B(0)].label
        cdef bint is_sunk = _entity_is_sunk(s, gold.ner)

        if g_act == MISSING:
            return 0
        elif g_act == BEGIN:
            # I, Gold B --> True
            # (P of bad open entity sunk, R of this entity sunk)
            return 0
        elif g_act == IN:
            # I, Gold I --> True
            # (label forced by prev, if mismatch, P and R both sunk)
            return 0
        elif g_act == LAST:
            # I, Gold L --> True iff this entity sunk and next tag == O
            return not (is_sunk and (next_act == OUT or next_act == MISSING))
        elif g_act == OUT:
            # I, Gold O --> True iff next tag == O
            return not (next_act == OUT or next_act == MISSING)
        elif g_act == UNIT:
            # I, Gold U --> True iff next tag == O
            return next_act != OUT
        else:
            return 1


cdef class Last:
    @staticmethod
    cdef bint is_valid(const StateC* st, attr_t label) nogil:
        cdef int preset_ent_iob = st.B_(0).ent_iob
        cdef attr_t preset_ent_label = st.B_(0).ent_type
        if label == 0:
            return False
        elif not st.entity_is_open():
            return False
        elif preset_ent_iob == 1 and st.B_(1).ent_iob != 1:
            # If a preset entity has I followed by not-I, is L
            if label != preset_ent_label:
                # If label isn't right, reject
                return False
            else:
                # Otherwise, force acceptance, even if we're across a sentence
                # boundary or the token is whitespace.
                return True
        elif st.get_ent().label != label:
            return False
        elif st.B_(1).ent_iob == 1:
            # If a preset entity has I next, we can't L here.
            return False
        else:
            return True

    @staticmethod
    cdef int transition(StateC* st, attr_t label) nogil:
        st.close_ent()
        st.push()
        st.pop()

    @staticmethod
    cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
        gold = <GoldNERStateC*>_gold
        move = LAST
        b0 = s.B(0)
        ent_start = s.E(0)

        cdef int g_act = gold.ner[b0].move
        cdef attr_t g_tag = gold.ner[b0].label

        cdef int cost = 0

        if g_act == MISSING:
            pass
        elif g_act == BEGIN:
            # L, Gold B --> True
            pass
        elif g_act == IN:
            # L, Gold I --> True iff this entity sunk
            cost += not _entity_is_sunk(s, gold.ner)
        elif g_act == LAST:
            # L, Gold L --> True
            pass
        elif g_act == OUT:
            # L, Gold O --> True
            pass
        elif g_act == UNIT:
            # L, Gold U --> True
            pass
        else:
            cost += 1
        # If we have negative-example entities, integrate them into the objective,
        # by marking actions that close an entity that we know is incorrect
        # as costly.
        cdef shared_ptr[SpanC] span
        for span in gold.negs:
            if span.get().label == label and (span.get().end-1) == b0 and span.get().start == ent_start:
                cost += 1
                break
        return cost


cdef class Unit:
    @staticmethod
    cdef bint is_valid(const StateC* st, attr_t label) nogil:
        cdef int preset_ent_iob = st.B_(0).ent_iob
        cdef attr_t preset_ent_label = st.B_(0).ent_type
        if label == 0:
            # this is only allowed if it's a preset blocked annotation
            if preset_ent_label == 0 and preset_ent_iob == 3:
                return True
            else:
                return False
        elif st.entity_is_open():
            return False
        elif st.B_(1).ent_iob == 1:
            # If next token is In, we can't be Unit -- must be Begin
            return False
        elif preset_ent_iob == 3:
            # Okay, there's a preset entity here
            if label != preset_ent_label:
                # Require labels to match
                return False
            else:
                # Otherwise return True, ignoring the whitespace constraint.
                return True
        elif Lexeme.get_struct_attr(st.B_(0).lex, IS_SPACE):
            return False
        else:
            return True

    @staticmethod
    cdef int transition(StateC* st, attr_t label) nogil:
        st.open_ent(label)
        st.close_ent()
        st.push()
        st.pop()

    @staticmethod
    cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
        gold = <GoldNERStateC*>_gold
        cdef int g_act = gold.ner[s.B(0)].move
        cdef attr_t g_tag = gold.ner[s.B(0)].label
        cdef int cost = 0

        if g_act == MISSING:
            pass
        elif g_act == UNIT:
            # U, Gold U --> True iff tag match
            cost += label != g_tag
        else:
            # U, Gold B --> False
            # U, Gold I --> False
            # U, Gold L --> False
            # U, Gold O --> False
            cost += 1
        # If we have negative-example entities, integrate them into the objective.
        # This is fairly straight-forward for U- entities, as we have a single
        # action
        cdef int b0 = s.B(0)
        cdef shared_ptr[SpanC] span
        for span in gold.negs:
            if span.get().label == label and span.get().start == b0 and span.get().end == (b0+1):
                cost += 1
                break
        return cost


cdef class Out:
    @staticmethod
    cdef bint is_valid(const StateC* st, attr_t label) nogil:
        cdef int preset_ent_iob = st.B_(0).ent_iob
        if st.entity_is_open():
            return False
        elif preset_ent_iob == 3:
            return False
        elif preset_ent_iob == 1:
            return False
        else:
            return True

    @staticmethod
    cdef int transition(StateC* st, attr_t label) nogil:
        st.push()
        st.pop()

    @staticmethod
    cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
        gold = <GoldNERStateC*>_gold
        cdef int g_act = gold.ner[s.B(0)].move
        cdef attr_t g_tag = gold.ner[s.B(0)].label
        cdef weight_t cost = 0
        if g_act == MISSING:
            pass
        elif g_act == BEGIN:
            # O, Gold B --> False
            cost += 1
        elif g_act == IN:
            # O, Gold I --> True
            pass
        elif g_act == LAST:
            # O, Gold L --> True
            pass
        elif g_act == OUT:
            # O, Gold O --> True
            pass
        elif g_act == UNIT:
            # O, Gold U --> False
            cost += 1
        else:
            cost += 1
        return cost