mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 01:04:34 +03:00
Add beam_parser and beam_ner components for v3 (#6369)
* Get basic beam tests working * Get basic beam tests working * Compile _beam_utils * Remove prints * Test beam density * Beam parser seems to train * Draft beam NER * Upd beam * Add hypothesis as dev dependency * Implement missing is-gold-parse method * Implement early update * Fix state hashing * Fix test * Fix test * Default to non-beam in parser constructor * Improve oracle for beam * Start refactoring beam * Update test * Refactor beam * Update nn * Refactor beam and weight by cost * Update ner beam settings * Update test * Add __init__.pxd * Upd test * Fix test * Upd test * Fix test * Remove ring buffer history from StateC * WIP change arc-eager transitions * Add state tests * Support ternary sent start values * Fix arc eager * Fix NER * Pass oracle cut size for beam * Fix ner test * Fix beam * Improve StateC.clone * Improve StateClass.borrow * Work directly with StateC, not StateClass * Remove print statements * Fix state copy * Improve state class * Refactor parser oracles * Fix arc eager oracle * Fix arc eager oracle * Use a vector to implement the stack * Refactor state data structure * Fix alignment of sent start * Add get_aligned_sent_starts method * Add test for ae oracle when bad sentence starts * Fix sentence segment handling * Avoid Reduce that inserts illegal sentence * Update preset SBD test * Fix test * Remove prints * Fix sent starts in Example * Improve python API of StateClass * Tweak comments and debug output of arc eager * Upd test * Fix state test * Fix state test
This commit is contained in:
parent
85ca8c2bdd
commit
8656a08777
|
@ -27,3 +27,4 @@ pytest>=4.6.5
|
|||
pytest-timeout>=1.3.0,<2.0.0
|
||||
mock>=2.0.0,<3.0.0
|
||||
flake8>=3.5.0,<3.6.0
|
||||
hypothesis
|
||||
|
|
1
setup.py
1
setup.py
|
@ -48,6 +48,7 @@ MOD_NAMES = [
|
|||
"spacy.pipeline._parser_internals._state",
|
||||
"spacy.pipeline._parser_internals.stateclass",
|
||||
"spacy.pipeline._parser_internals.transition_system",
|
||||
"spacy.pipeline._parser_internals._beam_utils",
|
||||
"spacy.tokenizer",
|
||||
"spacy.training.align",
|
||||
"spacy.training.gold_io",
|
||||
|
|
0
spacy/pipeline/_parser_internals/__init__.pxd
Normal file
0
spacy/pipeline/_parser_internals/__init__.pxd
Normal file
6
spacy/pipeline/_parser_internals/_beam_utils.pxd
Normal file
6
spacy/pipeline/_parser_internals/_beam_utils.pxd
Normal file
|
@ -0,0 +1,6 @@
|
|||
from ...typedefs cimport class_t, hash_t
|
||||
|
||||
# These are passed as callbacks to thinc.search.Beam
|
||||
cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1
|
||||
|
||||
cdef int check_final_state(void* _state, void* extra_args) except -1
|
296
spacy/pipeline/_parser_internals/_beam_utils.pyx
Normal file
296
spacy/pipeline/_parser_internals/_beam_utils.pyx
Normal file
|
@ -0,0 +1,296 @@
|
|||
# cython: infer_types=True
|
||||
# cython: profile=True
|
||||
cimport numpy as np
|
||||
import numpy
|
||||
from cpython.ref cimport PyObject, Py_XDECREF
|
||||
from thinc.extra.search cimport Beam
|
||||
from thinc.extra.search import MaxViolation
|
||||
from thinc.extra.search cimport MaxViolation
|
||||
|
||||
from ...typedefs cimport hash_t, class_t
|
||||
from .transition_system cimport TransitionSystem, Transition
|
||||
from ...errors import Errors
|
||||
from .stateclass cimport StateC, StateClass
|
||||
|
||||
|
||||
# These are passed as callbacks to thinc.search.Beam
|
||||
cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
    # Beam callback (passed to thinc.search.Beam): overwrite the destination
    # state with a copy of the source state, then apply transition number
    # `clas` from the `_moves` table to that copy.
    #
    # _dest:  untyped pointer, cast to the StateC that receives the result.
    # _src:   untyped pointer, cast to the StateC being extended.
    # clas:   index of the transition to apply.
    # _moves: untyped pointer, cast to an array of Transition structs.
    target = <StateC*>_dest
    parent = <StateC*>_src
    actions = <const Transition*>_moves
    # Copy first, so the source state itself is never passed to the action.
    target.clone(parent)
    actions[clas].do(target, actions[clas].label)
    # No failure path here; 0 is the non-error value for `except -1`.
    return 0
|
||||
|
||||
|
||||
cdef int check_final_state(void* _state, void* extra_args) except -1:
    # Beam callback (passed to thinc.search.Beam): report whether the parse
    # state behind `_state` is final. `extra_args` is not used.
    parse_state = <StateC*>_state
    return parse_state.is_final()
|
||||
|
||||
|
||||
cdef class BeamBatch(object):
    """A batch of search beams, one `Beam` per input state.

    The batch supports indexing and `len()` over its beams, can flatten all
    candidate states into a single list, and can advance every unfinished
    beam by one transition given a flat array of scores.
    """
    cdef public TransitionSystem moves
    cdef public object states
    cdef public object docs
    cdef public object golds
    cdef public object beams

    def __init__(self, TransitionSystem moves, states, golds,
                 int width, float density=0.):
        """Create one beam per state.

        moves: The transition system providing the beam-state constructor,
            destructor and the number of moves.
        states: Iterable of StateClass objects, one per example.
        golds: Per-example gold data, or None. Only used for cost
            computation in `advance`.
        width: Beam width.
        density: Passed to `Beam` as `min_density`.
        """
        cdef StateClass state
        cdef Beam beam
        cdef StateC* slot
        self.moves = moves
        self.states = states
        self.golds = golds
        self.docs = [state.doc for state in states]
        self.beams = []
        for state in states:
            beam = Beam(self.moves.n_moves, width, min_density=density)
            beam.initialize(self.moves.init_beam_state,
                            self.moves.del_beam_state, state.c.length,
                            <void*>state.c._sent)
            # Propagate the source state's offset to every slot of the beam.
            for i in range(beam.width):
                slot = <StateC*>beam.at(i)
                slot.offset = state.c.offset
            beam.check_done(check_final_state, NULL)
            self.beams.append(beam)

    @property
    def is_done(self):
        """True when every beam in the batch reports `is_done`."""
        for beam in self.beams:
            if not beam.is_done:
                return False
        return True

    def __getitem__(self, i):
        """Return the beam at position `i`."""
        return self.beams[i]

    def __len__(self):
        """Return the number of beams in the batch."""
        return len(self.beams)

    def get_states(self):
        """Return a flat list of borrowed StateClass wrappers for every
        candidate (indices 0..beam.size-1) of every beam, in beam order.
        """
        cdef Beam beam
        cdef StateC* state
        cdef StateClass stcls
        collected = []
        for beam, doc in zip(self, self.docs):
            for i in range(beam.size):
                state = <StateC*>beam.at(i)
                stcls = StateClass.borrow(state, doc)
                collected.append(stcls)
        return collected

    def get_unfinished_states(self):
        """Return the states from `get_states()` that are not final."""
        unfinished = []
        for st in self.get_states():
            if not st.is_final():
                unfinished.append(st)
        return unfinished

    def advance(self, float[:, ::1] scores, follow_gold=False):
        """Advance every unfinished beam by one transition.

        scores: C-contiguous 2-d float array with one row per non-final
            state, consumed in beam order, and one column per class.
        follow_gold: If set (and golds were provided), restrict each state
            to its lowest-cost actions before advancing.
        """
        cdef Beam beam
        cdef int nr_class = scores.shape[1]
        # Cursor into the flat score buffer; moved forward past each beam's
        # rows once that beam has consumed them.
        cdef const float* c_scores = &scores[0, 0]
        for i, beam in enumerate(self):
            if beam.is_done:
                continue
            nr_state = self._set_scores(beam, c_scores, nr_class)
            assert nr_state
            if self.golds is not None:
                self._set_costs(
                    beam,
                    self.docs[i],
                    self.golds[i],
                    follow_gold=follow_gold
                )
            c_scores += nr_state * nr_class
            beam.advance(transition_state, NULL, <void*>self.moves.c)
            beam.check_done(check_final_state, NULL)

    cdef int _set_scores(self, Beam beam, const float* scores, int nr_class) except -1:
        # Copy rows of `scores` into the beam's score table, one row per
        # non-final state, and refresh the validity mask for those states.
        # Final states get zeroed scores and costs instead. Returns the
        # number of score rows consumed (the count of non-final states).
        cdef int nr_state = 0
        for i in range(beam.size):
            state = <StateC*>beam.at(i)
            if state.is_final():
                for j in range(beam.nr_class):
                    beam.scores[i][j] = 0
                    beam.costs[i][j] = 0
            else:
                for j in range(nr_class):
                    beam.scores[i][j] = scores[nr_state * nr_class + j]
                self.moves.set_valid(beam.is_valid[i], state)
                nr_state += 1
        return nr_state

    def _set_costs(self, Beam beam, doc, gold, int follow_gold=False):
        """Fill the beam's cost table (and validity mask) from the gold data.

        Final states have every action marked invalid with a cost of 9000
        (presumably a "very expensive" sentinel). For other states the
        transition system computes validity and costs; with `follow_gold`,
        any action whose cost exceeds the minimum found is marked invalid.
        """
        cdef const StateC* state
        for i in range(beam.size):
            state = <const StateC*>beam.at(i)
            if state.is_final():
                for j in range(beam.nr_class):
                    beam.is_valid[i][j] = 0
                    beam.costs[i][j] = 9000
                continue
            self.moves.set_costs(beam.is_valid[i], beam.costs[i],
                                 state, gold)
            if follow_gold:
                # The minimum starts at 0, so it only drops if a valid
                # action has a negative cost.
                min_cost = 0
                for j in range(beam.nr_class):
                    if beam.is_valid[i][j] and beam.costs[i][j] < min_cost:
                        min_cost = beam.costs[i][j]
                for j in range(beam.nr_class):
                    if beam.costs[i][j] > min_cost:
                        beam.is_valid[i][j] = 0
|
||||
|
||||
|
||||
def update_beam(TransitionSystem moves, states, golds, model, int width, beam_density=0.0):
    """Run a prediction beam and a gold-following beam over a batch, track
    the maximum violation per example, backpropagate the resulting
    gradients through the model, and return the accumulated loss.

    moves: The transition system.
    states: One StateClass per example.
    golds: Per-example gold data.
    model: Object exposing `begin_update(states)` and `ops.as_contig`.
    width: Beam width used for both beams.
    beam_density: Minimum density for the prediction beam only; the gold
        beam always uses 0.0.
    RETURNS (float): Sum over steps of the mean squared gradient.
    """
    cdef MaxViolation violn
    pbeam = BeamBatch(moves, states, golds, width=width, density=beam_density)
    gbeam = BeamBatch(moves, states, golds, width=width, density=0.0)
    beam_maps = []
    backprops = []
    violns = [MaxViolation() for _ in states]
    dones = [False for _ in states]
    while not (pbeam.is_done and gbeam.is_done):
        # The beam maps let us find the right row in the flattened scores
        # array for each state. States are identified by (example id,
        # history). We keep a different beam map for each step (since we'll
        # have a flat scores array for each step). The beam map will let us
        # take the per-state losses, and compute the gradient for each (step,
        # state, class).
        # Gather all states from the two beams in a list. Some states may
        # occur in both beams. To figure out which beam each state belonged
        # to, we keep two lists of indices, p_indices and g_indices.
        step_states, p_indices, g_indices, beam_map = get_unique_states(pbeam, gbeam)
        beam_maps.append(beam_map)
        if not step_states:
            break
        # Now that we have our flat list of states, feed them through the model
        scores, bp_scores = model.begin_update(step_states)
        assert scores.size != 0
        # Store the callbacks for the backward pass
        backprops.append(bp_scores)
        # Unpack the scores for the two beams. The indices arrays
        # tell us which example and state the scores-row refers to.
        # Now advance the states in the beams. The gold beam is constrained
        # to follow only gold analyses.
        if not pbeam.is_done:
            pbeam.advance(model.ops.as_contig(scores[p_indices]))
        if not gbeam.is_done:
            gbeam.advance(model.ops.as_contig(scores[g_indices]), follow_gold=True)
        # Track the "maximum violation", to use in the update.
        for i, violn in enumerate(violns):
            if dones[i]:
                continue
            violn.check_crf(pbeam[i], gbeam[i])
            if pbeam[i].is_done and gbeam[i].is_done:
                dones[i] = True
    # Per example: the action histories of both beams, and their
    # probabilities scaled by the violation cost. Examples with no
    # recorded prediction history contribute nothing.
    histories = []
    grads = []
    for violn in violns:
        if violn.p_hist:
            histories.append(violn.p_hist + violn.g_hist)
            grads.append([d_l * violn.cost for d_l in violn.p_probs + violn.g_probs])
        else:
            histories.append([])
            grads.append([])
    states_d_scores = get_gradient(moves.n_moves, beam_maps, histories, grads)
    loss = 0.0
    for d_scores, bp_scores in zip(states_d_scores, backprops):
        loss += (d_scores**2).mean()
        bp_scores(d_scores)
    return loss
|
||||
|
||||
|
||||
def collect_states(beams, docs):
    """Return one StateClass per (item, doc) pair.

    Items that are already StateClass instances are passed through
    unchanged. Any other item is treated as a Beam, and a borrowed wrapper
    around its entry at index 0 is returned in its place.
    """
    cdef StateClass state
    cdef Beam beam
    collected = []
    for item, doc in zip(beams, docs):
        if not isinstance(item, StateClass):
            beam = item
            state = StateClass.borrow(<StateC*>beam.at(0), doc)
            collected.append(state)
        else:
            collected.append(item)
    return collected
|
||||
|
||||
|
||||
def get_unique_states(pbeams, gbeams):
    """Gather the non-final states of a prediction and a gold beam batch
    into one flat list, without duplicating states shared by both.

    States are keyed by `(example id, *action history)`.

    pbeams: Prediction beam batch (must expose `.docs`).
    gbeams: Gold beam batch, same length as `pbeams`.
    RETURNS: `(states, p_indices, g_indices, beam_map)`, where the index
        arrays give each beam's rows in `states`, and `beam_map` maps every
        key to its row.
    RAISES: ValueError (E079) if the batches differ in length; ValueError
        (E080) if two prediction states share a key.
    """
    cdef Beam pbeam, gbeam
    if len(pbeams) != len(gbeams):
        raise ValueError(Errors.E079.format(pbeams=len(pbeams), gbeams=len(gbeams)))
    seen = {}
    beam_map = {}
    states = []
    p_indices = []
    g_indices = []
    docs = pbeams.docs
    for eg_id, (pbeam, gbeam, doc) in enumerate(zip(pbeams, gbeams, docs)):
        if not pbeam.is_done:
            for i in range(pbeam.size):
                state = StateClass.borrow(<StateC*>pbeam.at(i), doc)
                if state.is_final():
                    continue
                key = tuple([eg_id] + pbeam.histories[i])
                if key in seen:
                    raise ValueError(Errors.E080.format(key=key))
                row = len(states)
                # Every prediction key is recorded in both mappings.
                seen[key] = row
                beam_map[key] = row
                p_indices.append(row)
                states.append(state)
        if not gbeam.is_done:
            for i in range(gbeam.size):
                state = StateClass.borrow(<StateC*>gbeam.at(i), doc)
                if state.is_final():
                    continue
                key = tuple([eg_id] + gbeam.histories[i])
                if key in seen:
                    # Same state as one in the prediction beam: reuse its row.
                    g_indices.append(seen[key])
                else:
                    row = len(states)
                    g_indices.append(row)
                    beam_map[key] = row
                    states.append(state)
    p_indices = numpy.asarray(p_indices, dtype='i')
    g_indices = numpy.asarray(g_indices, dtype='i')
    return states, p_indices, g_indices, beam_map
|
||||
|
||||
|
||||
def get_gradient(nr_class, beam_maps, histories, losses):
    """The global model assigns a loss to each parse. The beam scores
    are additive, so the same gradient is applied to each action
    in the history. This gives the gradient of a single *action*
    for a beam state -- so we have "the gradient of loss for taking
    action i given history H."

    Histories: Each history is a list of actions
               Each candidate has a history
               Each beam has multiple candidates
               Each batch has multiple beams
               So history is list of lists of lists of ints

    nr_class: Number of action classes (columns of each gradient array).
    beam_maps: One dict per step, mapping `(eg_id, *actions so far)` to
        the row of that state in the step's flat scores array.
    histories: Per example, the list of candidate action histories.
    losses: Per example, one loss per candidate history.
    RETURNS (list): One float32 array of shape (n_rows, nr_class) per step,
        or an empty list if there is nothing to update.
    RAISES: ValueError (E081) if `histories` and `losses` differ in length.
    """
    # Validate before anything indexes `losses` by example id; otherwise a
    # short `losses` surfaces as an IndexError instead of E081.
    if len(histories) != len(losses):
        raise ValueError(Errors.E081.format(n_hist=len(histories), losses=len(losses)))
    # For each example, the longest history that carries a non-zero loss.
    nr_steps = []
    for eg_losses, hists in zip(losses, histories):
        nr_step = 0
        for loss, hist in zip(eg_losses, hists):
            assert not numpy.isnan(loss)
            if loss != 0.0:
                nr_step = max(nr_step, len(hist))
        nr_steps.append(nr_step)
    # An empty batch has no steps; `max()` of an empty list would raise.
    total_steps = max(nr_steps) if nr_steps else 0
    grads = [
        numpy.zeros((max(beam_maps[step].values()) + 1, nr_class), dtype='f')
        for step in range(total_steps)
    ]
    for eg_id, hists in enumerate(histories):
        for loss, hist in zip(losses[eg_id], hists):
            if loss == 0.0:
                continue
            key = (eg_id,)
            # Adjust loss for length
            # We need to do this because each state in a short path is scored
            # multiple times, as we add in the average cost when we run out
            # of actions.
            avg_loss = loss / len(hist)
            loss += avg_loss * (nr_steps[eg_id] - len(hist))
            for step, clas in enumerate(hist):
                row = beam_maps[step][key]
                # At this step, in the state stored at `row`, action `clas`
                # resulted in `loss`.
                grads[step][row, clas] += loss
                key = key + (clas,)
    return grads
|
|
@ -1,6 +1,9 @@
|
|||
from libc.string cimport memcpy, memset
|
||||
from libc.stdlib cimport calloc, free
|
||||
from libc.stdint cimport uint32_t, uint64_t
|
||||
cimport libcpp
|
||||
from libcpp.vector cimport vector
|
||||
from libcpp.set cimport set
|
||||
from cpython.exc cimport PyErr_CheckSignals, PyErr_SetFromErrno
|
||||
from murmurhash.mrmr cimport hash64
|
||||
|
||||
|
@ -14,89 +17,48 @@ from ...typedefs cimport attr_t
|
|||
cdef inline bint is_space_token(const TokenC* token) nogil:
|
||||
return Lexeme.c_check_flag(token.lex, IS_SPACE)
|
||||
|
||||
cdef struct RingBufferC:
|
||||
int[8] data
|
||||
int i
|
||||
int default
|
||||
|
||||
cdef inline int ring_push(RingBufferC* ring, int value) nogil:
|
||||
ring.data[ring.i] = value
|
||||
ring.i += 1
|
||||
if ring.i >= 8:
|
||||
ring.i = 0
|
||||
|
||||
cdef inline int ring_get(RingBufferC* ring, int i) nogil:
|
||||
if i >= ring.i:
|
||||
return ring.default
|
||||
else:
|
||||
return ring.data[ring.i-i]
|
||||
cdef struct ArcC:
|
||||
int head
|
||||
int child
|
||||
attr_t label
|
||||
|
||||
|
||||
cdef cppclass StateC:
|
||||
int* _stack
|
||||
int* _buffer
|
||||
bint* shifted
|
||||
TokenC* _sent
|
||||
SpanC* _ents
|
||||
int* _heads
|
||||
const TokenC* _sent
|
||||
vector[int] _stack
|
||||
vector[int] _rebuffer
|
||||
vector[SpanC] _ents
|
||||
vector[ArcC] _left_arcs
|
||||
vector[ArcC] _right_arcs
|
||||
vector[libcpp.bool] _unshiftable
|
||||
set[int] _sent_starts
|
||||
TokenC _empty_token
|
||||
RingBufferC _hist
|
||||
int length
|
||||
int offset
|
||||
int _s_i
|
||||
int _b_i
|
||||
int _e_i
|
||||
int _break
|
||||
|
||||
__init__(const TokenC* sent, int length) nogil:
|
||||
cdef int PADDING = 5
|
||||
this._buffer = <int*>calloc(length + (PADDING * 2), sizeof(int))
|
||||
this._stack = <int*>calloc(length + (PADDING * 2), sizeof(int))
|
||||
this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint))
|
||||
this._sent = <TokenC*>calloc(length + (PADDING * 2), sizeof(TokenC))
|
||||
this._ents = <SpanC*>calloc(length + (PADDING * 2), sizeof(SpanC))
|
||||
if not (this._buffer and this._stack and this.shifted
|
||||
and this._sent and this._ents):
|
||||
this._sent = sent
|
||||
this._heads = <int*>calloc(length, sizeof(int))
|
||||
if not (this._sent and this._heads):
|
||||
with gil:
|
||||
PyErr_SetFromErrno(MemoryError)
|
||||
PyErr_CheckSignals()
|
||||
memset(&this._hist, 0, sizeof(this._hist))
|
||||
this.offset = 0
|
||||
cdef int i
|
||||
for i in range(length + (PADDING * 2)):
|
||||
this._ents[i].end = -1
|
||||
this._sent[i].l_edge = i
|
||||
this._sent[i].r_edge = i
|
||||
for i in range(PADDING):
|
||||
this._sent[i].lex = &EMPTY_LEXEME
|
||||
this._sent += PADDING
|
||||
this._ents += PADDING
|
||||
this._buffer += PADDING
|
||||
this._stack += PADDING
|
||||
this.shifted += PADDING
|
||||
this.length = length
|
||||
this._break = -1
|
||||
this._s_i = 0
|
||||
this._b_i = 0
|
||||
this._e_i = 0
|
||||
for i in range(length):
|
||||
this._buffer[i] = i
|
||||
this._heads[i] = -1
|
||||
this._unshiftable.push_back(0)
|
||||
memset(&this._empty_token, 0, sizeof(TokenC))
|
||||
this._empty_token.lex = &EMPTY_LEXEME
|
||||
for i in range(length):
|
||||
this._sent[i] = sent[i]
|
||||
this._buffer[i] = i
|
||||
for i in range(length, length+PADDING):
|
||||
this._sent[i].lex = &EMPTY_LEXEME
|
||||
|
||||
__dealloc__():
|
||||
cdef int PADDING = 5
|
||||
free(this._sent - PADDING)
|
||||
free(this._ents - PADDING)
|
||||
free(this._buffer - PADDING)
|
||||
free(this._stack - PADDING)
|
||||
free(this.shifted - PADDING)
|
||||
free(this._heads)
|
||||
|
||||
void set_context_tokens(int* ids, int n) nogil:
|
||||
cdef int i, j
|
||||
if n == 1:
|
||||
if this.B(0) >= 0:
|
||||
ids[0] = this.B(0)
|
||||
|
@ -145,22 +107,18 @@ cdef cppclass StateC:
|
|||
ids[11] = this.R(this.S(1), 1)
|
||||
ids[12] = this.R(this.S(1), 2)
|
||||
elif n == 6:
|
||||
for i in range(6):
|
||||
ids[i] = -1
|
||||
if this.B(0) >= 0:
|
||||
ids[0] = this.B(0)
|
||||
ids[1] = this.B(0)-1
|
||||
else:
|
||||
ids[0] = -1
|
||||
ids[1] = -1
|
||||
ids[2] = this.B(1)
|
||||
ids[3] = this.E(0)
|
||||
if ids[3] >= 1:
|
||||
ids[4] = this.E(0)-1
|
||||
else:
|
||||
ids[4] = -1
|
||||
if (ids[3]+1) < this.length:
|
||||
ids[5] = this.E(0)+1
|
||||
else:
|
||||
ids[5] = -1
|
||||
if this.entity_is_open():
|
||||
ent = this.get_ent()
|
||||
j = 1
|
||||
for i in range(ent.start, this.B(0)):
|
||||
ids[j] = i
|
||||
j += 1
|
||||
if j >= 6:
|
||||
break
|
||||
else:
|
||||
# TODO error =/
|
||||
pass
|
||||
|
@ -171,329 +129,256 @@ cdef cppclass StateC:
|
|||
ids[i] = -1
|
||||
|
||||
int S(int i) nogil const:
|
||||
if i >= this._s_i:
|
||||
if i >= this._stack.size():
|
||||
return -1
|
||||
return this._stack[this._s_i - (i+1)]
|
||||
elif i < 0:
|
||||
return -1
|
||||
return this._stack.at(this._stack.size() - (i+1))
|
||||
|
||||
int B(int i) nogil const:
|
||||
if (i + this._b_i) >= this.length:
|
||||
if i < 0:
|
||||
return -1
|
||||
return this._buffer[this._b_i + i]
|
||||
|
||||
const TokenC* S_(int i) nogil const:
|
||||
return this.safe_get(this.S(i))
|
||||
elif i < this._rebuffer.size():
|
||||
return this._rebuffer.at(this._rebuffer.size() - (i+1))
|
||||
else:
|
||||
b_i = this._b_i + (i - this._rebuffer.size())
|
||||
if b_i >= this.length:
|
||||
return -1
|
||||
else:
|
||||
return b_i
|
||||
|
||||
const TokenC* B_(int i) nogil const:
|
||||
return this.safe_get(this.B(i))
|
||||
|
||||
const TokenC* H_(int i) nogil const:
|
||||
return this.safe_get(this.H(i))
|
||||
|
||||
const TokenC* E_(int i) nogil const:
|
||||
return this.safe_get(this.E(i))
|
||||
|
||||
const TokenC* L_(int i, int idx) nogil const:
|
||||
return this.safe_get(this.L(i, idx))
|
||||
|
||||
const TokenC* R_(int i, int idx) nogil const:
|
||||
return this.safe_get(this.R(i, idx))
|
||||
|
||||
const TokenC* safe_get(int i) nogil const:
|
||||
if i < 0 or i >= this.length:
|
||||
return &this._empty_token
|
||||
else:
|
||||
return &this._sent[i]
|
||||
|
||||
int H(int i) nogil const:
|
||||
if i < 0 or i >= this.length:
|
||||
void get_arcs(vector[ArcC]* arcs) nogil const:
|
||||
for i in range(this._left_arcs.size()):
|
||||
arc = this._left_arcs.at(i)
|
||||
if arc.head != -1 and arc.child != -1:
|
||||
arcs.push_back(arc)
|
||||
for i in range(this._right_arcs.size()):
|
||||
arc = this._right_arcs.at(i)
|
||||
if arc.head != -1 and arc.child != -1:
|
||||
arcs.push_back(arc)
|
||||
|
||||
int H(int child) nogil const:
|
||||
if child >= this.length or child < 0:
|
||||
return -1
|
||||
return this._sent[i].head + i
|
||||
else:
|
||||
return this._heads[child]
|
||||
|
||||
int E(int i) nogil const:
|
||||
if this._e_i <= 0 or this._e_i >= this.length:
|
||||
if this._ents.size() == 0:
|
||||
return -1
|
||||
if i < 0 or i >= this._e_i:
|
||||
return -1
|
||||
return this._ents[this._e_i - (i+1)].start
|
||||
|
||||
int L(int i, int idx) nogil const:
|
||||
if idx < 1:
|
||||
return -1
|
||||
if i < 0 or i >= this.length:
|
||||
return -1
|
||||
cdef const TokenC* target = &this._sent[i]
|
||||
if target.l_kids < <uint32_t>idx:
|
||||
return -1
|
||||
cdef const TokenC* ptr = &this._sent[target.l_edge]
|
||||
|
||||
while ptr < target:
|
||||
# If this head is still to the right of us, we can skip to it
|
||||
# No token that's between this token and this head could be our
|
||||
# child.
|
||||
if (ptr.head >= 1) and (ptr + ptr.head) < target:
|
||||
ptr += ptr.head
|
||||
|
||||
elif ptr + ptr.head == target:
|
||||
idx -= 1
|
||||
if idx == 0:
|
||||
return ptr - this._sent
|
||||
ptr += 1
|
||||
else:
|
||||
ptr += 1
|
||||
return -1
|
||||
return this._ents.back().start
|
||||
|
||||
int R(int i, int idx) nogil const:
|
||||
if idx < 1:
|
||||
int L(int head, int idx) nogil const:
|
||||
if idx < 1 or this._left_arcs.size() == 0:
|
||||
return -1
|
||||
if i < 0 or i >= this.length:
|
||||
cdef vector[int] lefts
|
||||
for i in range(this._left_arcs.size()):
|
||||
arc = this._left_arcs.at(i)
|
||||
if arc.head == head and arc.child != -1 and arc.child < head:
|
||||
lefts.push_back(arc.child)
|
||||
idx = (<int>lefts.size()) - idx
|
||||
if idx < 0:
|
||||
return -1
|
||||
cdef const TokenC* target = &this._sent[i]
|
||||
if target.r_kids < <uint32_t>idx:
|
||||
return -1
|
||||
cdef const TokenC* ptr = &this._sent[target.r_edge]
|
||||
while ptr > target:
|
||||
# If this head is still to the right of us, we can skip to it
|
||||
# No token that's between this token and this head could be our
|
||||
# child.
|
||||
if (ptr.head < 0) and ((ptr + ptr.head) > target):
|
||||
ptr += ptr.head
|
||||
elif ptr + ptr.head == target:
|
||||
idx -= 1
|
||||
if idx == 0:
|
||||
return ptr - this._sent
|
||||
ptr -= 1
|
||||
else:
|
||||
ptr -= 1
|
||||
return lefts.at(idx)
|
||||
|
||||
int R(int head, int idx) nogil const:
|
||||
if idx < 1 or this._right_arcs.size() == 0:
|
||||
return -1
|
||||
cdef vector[int] rights
|
||||
for i in range(this._right_arcs.size()):
|
||||
arc = this._right_arcs.at(i)
|
||||
if arc.head == head and arc.child != -1 and arc.child > head:
|
||||
rights.push_back(arc.child)
|
||||
idx = (<int>rights.size()) - idx
|
||||
if idx < 0:
|
||||
return -1
|
||||
else:
|
||||
return rights.at(idx)
|
||||
|
||||
bint empty() nogil const:
|
||||
return this._s_i <= 0
|
||||
return this._stack.size() == 0
|
||||
|
||||
bint eol() nogil const:
|
||||
return this.buffer_length() == 0
|
||||
|
||||
bint at_break() nogil const:
|
||||
return this._break != -1
|
||||
|
||||
bint is_final() nogil const:
|
||||
return this.stack_depth() <= 0 and this._b_i >= this.length
|
||||
return this.stack_depth() <= 0 and this.eol()
|
||||
|
||||
bint has_head(int i) nogil const:
|
||||
return this.safe_get(i).head != 0
|
||||
int cannot_sent_start(int word) nogil const:
|
||||
if word < 0 or word >= this.length:
|
||||
return 0
|
||||
elif this._sent[word].sent_start == -1:
|
||||
return 1
|
||||
else:
|
||||
return 0
|
||||
|
||||
int n_L(int i) nogil const:
|
||||
return this.safe_get(i).l_kids
|
||||
int is_sent_start(int word) nogil const:
|
||||
if word < 0 or word >= this.length:
|
||||
return 0
|
||||
elif this._sent[word].sent_start == 1:
|
||||
return 1
|
||||
elif this._sent_starts.count(word) >= 1:
|
||||
return 1
|
||||
else:
|
||||
return 0
|
||||
|
||||
int n_R(int i) nogil const:
|
||||
return this.safe_get(i).r_kids
|
||||
void set_sent_start(int word, int value) nogil:
|
||||
if value >= 1:
|
||||
this._sent_starts.insert(word)
|
||||
|
||||
bint has_head(int child) nogil const:
|
||||
return this._heads[child] >= 0
|
||||
|
||||
int l_edge(int word) nogil const:
|
||||
return word
|
||||
|
||||
int r_edge(int word) nogil const:
|
||||
return word
|
||||
|
||||
int n_L(int head) nogil const:
|
||||
cdef int n = 0
|
||||
for i in range(this._left_arcs.size()):
|
||||
arc = this._left_arcs.at(i)
|
||||
if arc.head == head and arc.child != -1 and arc.child < arc.head:
|
||||
n += 1
|
||||
return n
|
||||
|
||||
int n_R(int head) nogil const:
|
||||
cdef int n = 0
|
||||
for i in range(this._right_arcs.size()):
|
||||
arc = this._right_arcs.at(i)
|
||||
if arc.head == head and arc.child != -1 and arc.child > arc.head:
|
||||
n += 1
|
||||
return n
|
||||
|
||||
bint stack_is_connected() nogil const:
|
||||
return False
|
||||
|
||||
bint entity_is_open() nogil const:
|
||||
if this._e_i < 1:
|
||||
if this._ents.size() == 0:
|
||||
return False
|
||||
return this._ents[this._e_i-1].end == -1
|
||||
else:
|
||||
return this._ents.back().end == -1
|
||||
|
||||
int stack_depth() nogil const:
|
||||
return this._s_i
|
||||
return this._stack.size()
|
||||
|
||||
int buffer_length() nogil const:
|
||||
if this._break != -1:
|
||||
return this._break - this._b_i
|
||||
else:
|
||||
return this.length - this._b_i
|
||||
|
||||
uint64_t hash() nogil const:
|
||||
cdef TokenC[11] sig
|
||||
sig[0] = this.S_(2)[0]
|
||||
sig[1] = this.S_(1)[0]
|
||||
sig[2] = this.R_(this.S(1), 1)[0]
|
||||
sig[3] = this.L_(this.S(0), 1)[0]
|
||||
sig[4] = this.L_(this.S(0), 2)[0]
|
||||
sig[5] = this.S_(0)[0]
|
||||
sig[6] = this.R_(this.S(0), 2)[0]
|
||||
sig[7] = this.R_(this.S(0), 1)[0]
|
||||
sig[8] = this.B_(0)[0]
|
||||
sig[9] = this.E_(0)[0]
|
||||
sig[10] = this.E_(1)[0]
|
||||
return hash64(sig, sizeof(sig), this._s_i) \
|
||||
+ hash64(<void*>&this._hist, sizeof(RingBufferC), 1)
|
||||
|
||||
void push_hist(int act) nogil:
|
||||
ring_push(&this._hist, act+1)
|
||||
|
||||
int get_hist(int i) nogil:
|
||||
return ring_get(&this._hist, i)
|
||||
|
||||
void push() nogil:
|
||||
if this.B(0) != -1:
|
||||
this._stack[this._s_i] = this.B(0)
|
||||
this._s_i += 1
|
||||
b0 = this.B(0)
|
||||
if this._rebuffer.size():
|
||||
b0 = this._rebuffer.back()
|
||||
this._rebuffer.pop_back()
|
||||
else:
|
||||
b0 = this._b_i
|
||||
this._b_i += 1
|
||||
if this.safe_get(this.B_(0).l_edge).sent_start == 1:
|
||||
this.set_break(this.B_(0).l_edge)
|
||||
if this._b_i > this._break:
|
||||
this._break = -1
|
||||
this._stack.push_back(b0)
|
||||
|
||||
void pop() nogil:
|
||||
if this._s_i >= 1:
|
||||
this._s_i -= 1
|
||||
this._stack.pop_back()
|
||||
|
||||
void force_final() nogil:
|
||||
# This should only be used in desperate situations, as it may leave
|
||||
# the analysis in an unexpected state.
|
||||
this._s_i = 0
|
||||
this._stack.clear()
|
||||
this._b_i = this.length
|
||||
|
||||
void unshift() nogil:
|
||||
this._b_i -= 1
|
||||
this._buffer[this._b_i] = this.S(0)
|
||||
this._s_i -= 1
|
||||
this.shifted[this.B(0)] = True
|
||||
s0 = this._stack.back()
|
||||
this._unshiftable[s0] = 1
|
||||
this._rebuffer.push_back(s0)
|
||||
this._stack.pop_back()
|
||||
|
||||
int is_unshiftable(int item) nogil const:
|
||||
if item >= this._unshiftable.size():
|
||||
return 0
|
||||
else:
|
||||
return this._unshiftable.at(item)
|
||||
|
||||
void set_reshiftable(int item) nogil:
|
||||
if item < this._unshiftable.size():
|
||||
this._unshiftable[item] = 0
|
||||
|
||||
void add_arc(int head, int child, attr_t label) nogil:
|
||||
if this.has_head(child):
|
||||
this.del_arc(this.H(child), child)
|
||||
|
||||
cdef int dist = head - child
|
||||
this._sent[child].head = dist
|
||||
this._sent[child].dep = label
|
||||
cdef int i
|
||||
if child > head:
|
||||
this._sent[head].r_kids += 1
|
||||
# Some transition systems can have a word in the buffer have a
|
||||
# rightward child, e.g. from Unshift.
|
||||
this._sent[head].r_edge = this._sent[child].r_edge
|
||||
i = 0
|
||||
while this.has_head(head) and i < this.length:
|
||||
head = this.H(head)
|
||||
this._sent[head].r_edge = this._sent[child].r_edge
|
||||
i += 1 # Guard against infinite loops
|
||||
cdef ArcC arc
|
||||
arc.head = head
|
||||
arc.child = child
|
||||
arc.label = label
|
||||
if head > child:
|
||||
this._left_arcs.push_back(arc)
|
||||
else:
|
||||
this._sent[head].l_kids += 1
|
||||
this._sent[head].l_edge = this._sent[child].l_edge
|
||||
this._right_arcs.push_back(arc)
|
||||
this._heads[child] = head
|
||||
|
||||
void del_arc(int h_i, int c_i) nogil:
|
||||
cdef int dist = h_i - c_i
|
||||
cdef TokenC* h = &this._sent[h_i]
|
||||
cdef int i = 0
|
||||
if c_i > h_i:
|
||||
# this.R_(h_i, 2) returns the second-rightmost child token of h_i
|
||||
# If we have more than 2 rightmost children, our 2nd rightmost child's
|
||||
# rightmost edge is going to be our new rightmost edge.
|
||||
h.r_edge = this.R_(h_i, 2).r_edge if h.r_kids >= 2 else h_i
|
||||
h.r_kids -= 1
|
||||
new_edge = h.r_edge
|
||||
# Correct upwards in the tree --- see Issue #251
|
||||
while h.head < 0 and i < this.length: # Guard infinite loop
|
||||
h += h.head
|
||||
h.r_edge = new_edge
|
||||
i += 1
|
||||
cdef vector[ArcC]* arcs
|
||||
if h_i > c_i:
|
||||
arcs = &this._left_arcs
|
||||
else:
|
||||
# Same logic applies for left edge, but we don't need to walk up
|
||||
# the tree, as the head is off the stack.
|
||||
h.l_edge = this.L_(h_i, 2).l_edge if h.l_kids >= 2 else h_i
|
||||
h.l_kids -= 1
|
||||
arcs = &this._right_arcs
|
||||
if arcs.size() == 0:
|
||||
return
|
||||
arc = arcs.back()
|
||||
if arc.head == h_i and arc.child == c_i:
|
||||
arcs.pop_back()
|
||||
else:
|
||||
for i in range(arcs.size()-1):
|
||||
arc = arcs.at(i)
|
||||
if arc.head == h_i and arc.child == c_i:
|
||||
arc.head = -1
|
||||
arc.child = -1
|
||||
arc.label = 0
|
||||
break
|
||||
|
||||
SpanC get_ent() nogil const:
|
||||
cdef SpanC ent
|
||||
if this._ents.size() == 0:
|
||||
ent.start = 0
|
||||
ent.end = 0
|
||||
ent.label = 0
|
||||
return ent
|
||||
else:
|
||||
return this._ents.back()
|
||||
|
||||
void open_ent(attr_t label) nogil:
|
||||
this._ents[this._e_i].start = this.B(0)
|
||||
this._ents[this._e_i].label = label
|
||||
this._ents[this._e_i].end = -1
|
||||
this._e_i += 1
|
||||
cdef SpanC ent
|
||||
ent.start = this.B(0)
|
||||
ent.label = label
|
||||
ent.end = -1
|
||||
this._ents.push_back(ent)
|
||||
|
||||
void close_ent() nogil:
|
||||
# Note that we don't decrement _e_i here! We want to maintain all
|
||||
# entities, not over-write them...
|
||||
this._ents[this._e_i-1].end = this.B(0)+1
|
||||
this._sent[this.B(0)].ent_iob = 1
|
||||
|
||||
void set_ent_tag(int i, int ent_iob, attr_t ent_type) nogil:
|
||||
if 0 <= i < this.length:
|
||||
this._sent[i].ent_iob = ent_iob
|
||||
this._sent[i].ent_type = ent_type
|
||||
|
||||
void set_break(int i) nogil:
|
||||
if 0 <= i < this.length:
|
||||
this._sent[i].sent_start = 1
|
||||
this._break = this._b_i
|
||||
this._ents.back().end = this.B(0)+1
|
||||
|
||||
void clone(const StateC* src) nogil:
|
||||
this.length = src.length
|
||||
memcpy(this._sent, src._sent, this.length * sizeof(TokenC))
|
||||
memcpy(this._stack, src._stack, this.length * sizeof(int))
|
||||
memcpy(this._buffer, src._buffer, this.length * sizeof(int))
|
||||
memcpy(this._ents, src._ents, this.length * sizeof(SpanC))
|
||||
memcpy(this.shifted, src.shifted, this.length * sizeof(this.shifted[0]))
|
||||
this._sent = src._sent
|
||||
this._stack = src._stack
|
||||
this._rebuffer = src._rebuffer
|
||||
this._sent_starts = src._sent_starts
|
||||
this._unshiftable = src._unshiftable
|
||||
memcpy(this._heads, src._heads, this.length * sizeof(this._heads[0]))
|
||||
this._ents = src._ents
|
||||
this._left_arcs = src._left_arcs
|
||||
this._right_arcs = src._right_arcs
|
||||
this._b_i = src._b_i
|
||||
this._s_i = src._s_i
|
||||
this._e_i = src._e_i
|
||||
this._break = src._break
|
||||
this.offset = src.offset
|
||||
this._empty_token = src._empty_token
|
||||
|
||||
void fast_forward() nogil:
|
||||
# space token attachment policy:
|
||||
# - attach space tokens always to the last preceding real token
|
||||
# - except if it's the beginning of a sentence, then attach to the first following
|
||||
# - boundary case: a document containing multiple space tokens but nothing else,
|
||||
# then make the last space token the head of all others
|
||||
|
||||
while is_space_token(this.B_(0)) \
|
||||
or this.buffer_length() == 0 \
|
||||
or this.stack_depth() == 0:
|
||||
if this.buffer_length() == 0:
|
||||
# remove the last sentence's root from the stack
|
||||
if this.stack_depth() == 1:
|
||||
this.pop()
|
||||
# parser got stuck: reduce stack or unshift
|
||||
elif this.stack_depth() > 1:
|
||||
if this.has_head(this.S(0)):
|
||||
this.pop()
|
||||
else:
|
||||
this.unshift()
|
||||
# stack is empty but there is another sentence on the buffer
|
||||
elif (this.length - this._b_i) >= 1:
|
||||
this.push()
|
||||
else: # stack empty and nothing else coming
|
||||
break
|
||||
|
||||
elif is_space_token(this.B_(0)):
|
||||
# the normal case: we're somewhere inside a sentence
|
||||
if this.stack_depth() > 0:
|
||||
# assert not is_space_token(this.S_(0))
|
||||
# attach all coming space tokens to their last preceding
|
||||
# real token (which should be on the top of the stack)
|
||||
while is_space_token(this.B_(0)):
|
||||
this.add_arc(this.S(0),this.B(0),0)
|
||||
this.push()
|
||||
this.pop()
|
||||
# the rare case: we're at the beginning of a document:
|
||||
# space tokens are attached to the first real token on the buffer
|
||||
elif this.stack_depth() == 0:
|
||||
# store all space tokens on the stack until a real token shows up
|
||||
# or the last token on the buffer is reached
|
||||
while is_space_token(this.B_(0)) and this.buffer_length() > 1:
|
||||
this.push()
|
||||
# empty the stack by attaching all space tokens to the
|
||||
# first token on the buffer
|
||||
# boundary case: if all tokens are space tokens, the last one
|
||||
# becomes the head of all others
|
||||
while this.stack_depth() > 0:
|
||||
this.add_arc(this.B(0),this.S(0),0)
|
||||
this.pop()
|
||||
# move the first token onto the stack
|
||||
this.push()
|
||||
|
||||
elif this.stack_depth() == 0:
|
||||
# for one token sentences (?)
|
||||
if this.buffer_length() == 1:
|
||||
this.push()
|
||||
this.pop()
|
||||
# with an empty stack and a non-empty buffer
|
||||
# only shift is valid anyway
|
||||
elif (this.length - this._b_i) >= 1:
|
||||
this.push()
|
||||
|
||||
else: # can this even happen?
|
||||
break
|
||||
|
|
|
@ -1,11 +1,7 @@
|
|||
from .stateclass cimport StateClass
|
||||
from ._state cimport StateC
|
||||
from ...typedefs cimport weight_t, attr_t
|
||||
from .transition_system cimport Transition, TransitionSystem
|
||||
|
||||
|
||||
cdef class ArcEager(TransitionSystem):
|
||||
pass
|
||||
|
||||
|
||||
cdef weight_t push_cost(StateClass stcls, const void* _gold, int target) nogil
|
||||
cdef weight_t arc_cost(StateClass stcls, const void* _gold, int head, int child) nogil
|
||||
|
|
|
@ -14,16 +14,11 @@ from ._state cimport StateC
|
|||
|
||||
from ...errors import Errors
|
||||
|
||||
# Calculate cost as gold/not gold. We don't use scalar value anyway.
|
||||
cdef int BINARY_COSTS = 1
|
||||
cdef weight_t MIN_SCORE = -90000
|
||||
cdef attr_t SUBTOK_LABEL = hash_string(u'subtok')
|
||||
|
||||
DEF NON_MONOTONIC = True
|
||||
DEF USE_BREAK = True
|
||||
|
||||
# Break transition from here
|
||||
# http://www.aclweb.org/anthology/P13-1074
|
||||
cdef enum:
|
||||
SHIFT
|
||||
REDUCE
|
||||
|
@ -61,9 +56,11 @@ cdef struct GoldParseStateC:
|
|||
int32_t* n_kids
|
||||
int32_t length
|
||||
int32_t stride
|
||||
weight_t push_cost
|
||||
weight_t pop_cost
|
||||
|
||||
|
||||
cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls,
|
||||
cdef GoldParseStateC create_gold_state(Pool mem, const StateC* state,
|
||||
heads, labels, sent_starts) except *:
|
||||
cdef GoldParseStateC gs
|
||||
gs.length = len(heads)
|
||||
|
@ -142,10 +139,12 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls,
|
|||
if head != i:
|
||||
gs.kids[head][js[head]] = i
|
||||
js[head] += 1
|
||||
gs.push_cost = push_cost(state, &gs)
|
||||
gs.pop_cost = pop_cost(state, &gs)
|
||||
return gs
|
||||
|
||||
|
||||
cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) nogil:
|
||||
cdef void update_gold_state(GoldParseStateC* gs, const StateC* s) nogil:
|
||||
for i in range(gs.length):
|
||||
gs.state_bits[i] = set_state_flag(
|
||||
gs.state_bits[i],
|
||||
|
@ -160,9 +159,9 @@ cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) nogil:
|
|||
gs.n_kids_in_stack[i] = 0
|
||||
gs.n_kids_in_buffer[i] = 0
|
||||
|
||||
for i in range(stcls.stack_depth()):
|
||||
s_i = stcls.S(i)
|
||||
if not is_head_unknown(gs, s_i):
|
||||
for i in range(s.stack_depth()):
|
||||
s_i = s.S(i)
|
||||
if not is_head_unknown(gs, s_i) and gs.heads[s_i] != s_i:
|
||||
gs.n_kids_in_stack[gs.heads[s_i]] += 1
|
||||
for kid in gs.kids[s_i][:gs.n_kids[s_i]]:
|
||||
gs.state_bits[kid] = set_state_flag(
|
||||
|
@ -170,9 +169,11 @@ cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) nogil:
|
|||
HEAD_IN_STACK,
|
||||
1
|
||||
)
|
||||
for i in range(stcls.buffer_length()):
|
||||
b_i = stcls.B(i)
|
||||
if not is_head_unknown(gs, b_i):
|
||||
for i in range(s.buffer_length()):
|
||||
b_i = s.B(i)
|
||||
if s.is_sent_start(b_i):
|
||||
break
|
||||
if not is_head_unknown(gs, b_i) and gs.heads[b_i] != b_i:
|
||||
gs.n_kids_in_buffer[gs.heads[b_i]] += 1
|
||||
for kid in gs.kids[b_i][:gs.n_kids[b_i]]:
|
||||
gs.state_bits[kid] = set_state_flag(
|
||||
|
@ -180,6 +181,8 @@ cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) nogil:
|
|||
HEAD_IN_BUFFER,
|
||||
1
|
||||
)
|
||||
gs.push_cost = push_cost(s, gs)
|
||||
gs.pop_cost = pop_cost(s, gs)
|
||||
|
||||
|
||||
cdef class ArcEagerGold:
|
||||
|
@ -191,17 +194,17 @@ cdef class ArcEagerGold:
|
|||
heads, labels = example.get_aligned_parse(projectivize=True)
|
||||
labels = [label if label is not None else "" for label in labels]
|
||||
labels = [example.x.vocab.strings.add(label) for label in labels]
|
||||
sent_starts = example.get_aligned("SENT_START")
|
||||
assert len(heads) == len(labels) == len(sent_starts)
|
||||
self.c = create_gold_state(self.mem, stcls, heads, labels, sent_starts)
|
||||
sent_starts = example.get_aligned_sent_starts()
|
||||
assert len(heads) == len(labels) == len(sent_starts), (len(heads), len(labels), len(sent_starts))
|
||||
self.c = create_gold_state(self.mem, stcls.c, heads, labels, sent_starts)
|
||||
|
||||
def update(self, StateClass stcls):
|
||||
update_gold_state(&self.c, stcls)
|
||||
update_gold_state(&self.c, stcls.c)
|
||||
|
||||
|
||||
cdef int check_state_gold(char state_bits, char flag) nogil:
|
||||
cdef char one = 1
|
||||
return state_bits & (one << flag)
|
||||
return 1 if (state_bits & (one << flag)) else 0
|
||||
|
||||
|
||||
cdef int set_state_flag(char state_bits, char flag, int value) nogil:
|
||||
|
@ -232,41 +235,30 @@ cdef int is_sent_start_unknown(const GoldParseStateC* gold, int i) nogil:
|
|||
|
||||
# Helper functions for the arc-eager oracle
|
||||
|
||||
cdef weight_t push_cost(StateClass stcls, const void* _gold, int target) nogil:
|
||||
gold = <const GoldParseStateC*>_gold
|
||||
cdef weight_t push_cost(const StateC* state, const GoldParseStateC* gold) nogil:
|
||||
cdef weight_t cost = 0
|
||||
if is_head_in_stack(gold, target):
|
||||
b0 = state.B(0)
|
||||
if b0 < 0:
|
||||
return 9000
|
||||
if is_head_in_stack(gold, b0):
|
||||
cost += 1
|
||||
cost += gold.n_kids_in_stack[target]
|
||||
if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0:
|
||||
cost += gold.n_kids_in_stack[b0]
|
||||
if Break.is_valid(state, 0) and is_sent_start(gold, state.B(1)):
|
||||
cost += 1
|
||||
return cost
|
||||
|
||||
|
||||
cdef weight_t pop_cost(StateClass stcls, const void* _gold, int target) nogil:
|
||||
gold = <const GoldParseStateC*>_gold
|
||||
cdef weight_t pop_cost(const StateC* state, const GoldParseStateC* gold) nogil:
|
||||
cdef weight_t cost = 0
|
||||
if is_head_in_buffer(gold, target):
|
||||
cost += 1
|
||||
cost += gold[0].n_kids_in_buffer[target]
|
||||
if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0:
|
||||
s0 = state.S(0)
|
||||
if s0 < 0:
|
||||
return 9000
|
||||
if is_head_in_buffer(gold, s0):
|
||||
cost += 1
|
||||
cost += gold.n_kids_in_buffer[s0]
|
||||
return cost
|
||||
|
||||
|
||||
cdef weight_t arc_cost(StateClass stcls, const void* _gold, int head, int child) nogil:
|
||||
gold = <const GoldParseStateC*>_gold
|
||||
if arc_is_gold(gold, head, child):
|
||||
return 0
|
||||
elif stcls.H(child) == gold.heads[child]:
|
||||
return 1
|
||||
# Head in buffer
|
||||
elif is_head_in_buffer(gold, child):
|
||||
return 1
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
cdef bint arc_is_gold(const GoldParseStateC* gold, int head, int child) nogil:
|
||||
if is_head_unknown(gold, child):
|
||||
return True
|
||||
|
@ -276,7 +268,7 @@ cdef bint arc_is_gold(const GoldParseStateC* gold, int head, int child) nogil:
|
|||
return False
|
||||
|
||||
|
||||
cdef bint label_is_gold(const GoldParseStateC* gold, int head, int child, attr_t label) nogil:
|
||||
cdef bint label_is_gold(const GoldParseStateC* gold, int child, attr_t label) nogil:
|
||||
if is_head_unknown(gold, child):
|
||||
return True
|
||||
elif label == 0:
|
||||
|
@ -292,218 +284,251 @@ cdef bint _is_gold_root(const GoldParseStateC* gold, int word) nogil:
|
|||
|
||||
|
||||
cdef class Shift:
|
||||
"""Move the first word of the buffer onto the stack and mark it as "shifted"
|
||||
|
||||
Validity:
|
||||
* If stack is empty
|
||||
* At least two words in sentence
|
||||
* Word has not been shifted before
|
||||
|
||||
Cost: push_cost
|
||||
|
||||
Action:
|
||||
* Mark B[0] as 'shifted'
|
||||
* Push stack
|
||||
* Advance buffer
|
||||
"""
|
||||
@staticmethod
|
||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
sent_start = st._sent[st.B_(0).l_edge].sent_start
|
||||
return st.buffer_length() >= 2 and not st.shifted[st.B(0)] and sent_start != 1
|
||||
if st.stack_depth() == 0:
|
||||
return 1
|
||||
elif st.buffer_length() < 2:
|
||||
return 0
|
||||
elif st.is_sent_start(st.B(0)):
|
||||
return 0
|
||||
elif st.is_unshiftable(st.B(0)):
|
||||
return 0
|
||||
else:
|
||||
return 1
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
st.push()
|
||||
st.fast_forward()
|
||||
|
||||
@staticmethod
|
||||
cdef weight_t cost(StateClass st, const void* _gold, attr_t label) nogil:
|
||||
cdef weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil:
|
||||
gold = <const GoldParseStateC*>_gold
|
||||
return Shift.move_cost(st, gold) + Shift.label_cost(st, gold, label)
|
||||
|
||||
@staticmethod
|
||||
cdef inline weight_t move_cost(StateClass s, const void* _gold) nogil:
|
||||
gold = <const GoldParseStateC*>_gold
|
||||
return push_cost(s, gold, s.B(0))
|
||||
|
||||
@staticmethod
|
||||
cdef inline weight_t label_cost(StateClass s, const void* _gold, attr_t label) nogil:
|
||||
return 0
|
||||
return gold.push_cost
|
||||
|
||||
|
||||
cdef class Reduce:
|
||||
"""
|
||||
Pop from the stack. If it has no head and the stack isn't empty, place
|
||||
it back on the buffer.
|
||||
|
||||
Validity:
|
||||
* Stack not empty
|
||||
* Buffer not empty
|
||||
* Not the case that stack depth is 1 and l_edge(B[0]) cannot be a sentence start
|
||||
|
||||
Cost:
|
||||
* If B[0] is the start of a sentence, cost is 0
|
||||
* Arcs between stack and buffer
|
||||
* If S[0] has no head, we're saving arcs between S[0] and S[1:], so decrement
|
||||
cost by those arcs.
|
||||
"""
|
||||
@staticmethod
|
||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
return st.stack_depth() >= 2
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
if st.has_head(st.S(0)):
|
||||
st.pop()
|
||||
else:
|
||||
st.unshift()
|
||||
st.fast_forward()
|
||||
|
||||
@staticmethod
|
||||
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
|
||||
gold = <const GoldParseStateC*>_gold
|
||||
return Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label)
|
||||
|
||||
@staticmethod
|
||||
cdef inline weight_t move_cost(StateClass st, const void* _gold) nogil:
|
||||
gold = <const GoldParseStateC*>_gold
|
||||
s0 = st.S(0)
|
||||
cost = pop_cost(st, gold, s0)
|
||||
return_to_buffer = not st.has_head(s0)
|
||||
if return_to_buffer:
|
||||
# Decrement cost for the arcs we save, as we'll be putting this
|
||||
# back to the buffer
|
||||
if is_head_in_stack(gold, s0):
|
||||
cost -= 1
|
||||
cost -= gold.n_kids_in_stack[s0]
|
||||
if Break.is_valid(st.c, 0) and Break.move_cost(st, gold) == 0:
|
||||
cost -= 1
|
||||
return cost
|
||||
|
||||
@staticmethod
|
||||
cdef inline weight_t label_cost(StateClass s, const void* gold, attr_t label) nogil:
|
||||
return 0
|
||||
|
||||
|
||||
cdef class LeftArc:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
if label == SUBTOK_LABEL and st.S(0) != (st.B(0)-1):
|
||||
return 0
|
||||
sent_start = st._sent[st.B_(0).l_edge].sent_start
|
||||
return sent_start != 1
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
st.add_arc(st.B(0), st.S(0), label)
|
||||
st.pop()
|
||||
st.fast_forward()
|
||||
|
||||
@staticmethod
|
||||
cdef inline weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
|
||||
gold = <const GoldParseStateC*>_gold
|
||||
return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label)
|
||||
|
||||
@staticmethod
|
||||
cdef inline weight_t move_cost(StateClass s, const GoldParseStateC* gold) nogil:
|
||||
cdef weight_t cost = 0
|
||||
s0 = s.S(0)
|
||||
b0 = s.B(0)
|
||||
if arc_is_gold(gold, b0, s0):
|
||||
# Have a negative cost if we 'recover' from the wrong dependency
|
||||
return 0 if not s.has_head(s0) else -1
|
||||
else:
|
||||
# Account for deps we might lose between S0 and stack
|
||||
if not s.has_head(s0):
|
||||
cost += gold.n_kids_in_stack[s0]
|
||||
if is_head_in_buffer(gold, s0):
|
||||
cost += 1
|
||||
return cost + pop_cost(s, gold, s.S(0)) + arc_cost(s, gold, s.B(0), s.S(0))
|
||||
|
||||
@staticmethod
|
||||
cdef inline weight_t label_cost(StateClass s, const GoldParseStateC* gold, attr_t label) nogil:
|
||||
return arc_is_gold(gold, s.B(0), s.S(0)) and not label_is_gold(gold, s.B(0), s.S(0), label)
|
||||
|
||||
|
||||
cdef class RightArc:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
# If there's (perhaps partial) parse pre-set, don't allow cycle.
|
||||
if label == SUBTOK_LABEL and st.S(0) != (st.B(0)-1):
|
||||
return 0
|
||||
sent_start = st._sent[st.B_(0).l_edge].sent_start
|
||||
return sent_start != 1 and st.H(st.S(0)) != st.B(0)
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
st.add_arc(st.S(0), st.B(0), label)
|
||||
st.push()
|
||||
st.fast_forward()
|
||||
|
||||
@staticmethod
|
||||
cdef inline weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
|
||||
gold = <const GoldParseStateC*>_gold
|
||||
return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label)
|
||||
|
||||
@staticmethod
|
||||
cdef inline weight_t move_cost(StateClass s, const void* _gold) nogil:
|
||||
gold = <const GoldParseStateC*>_gold
|
||||
if arc_is_gold(gold, s.S(0), s.B(0)):
|
||||
return 0
|
||||
elif s.c.shifted[s.B(0)]:
|
||||
return push_cost(s, gold, s.B(0))
|
||||
else:
|
||||
return push_cost(s, gold, s.B(0)) + arc_cost(s, gold, s.S(0), s.B(0))
|
||||
|
||||
@staticmethod
|
||||
cdef weight_t label_cost(StateClass s, const void* _gold, attr_t label) nogil:
|
||||
gold = <const GoldParseStateC*>_gold
|
||||
return arc_is_gold(gold, s.S(0), s.B(0)) and not label_is_gold(gold, s.S(0), s.B(0), label)
|
||||
|
||||
|
||||
cdef class Break:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
cdef int i
|
||||
if not USE_BREAK:
|
||||
if st.stack_depth() == 0:
|
||||
return False
|
||||
elif st.at_break():
|
||||
return False
|
||||
elif st.stack_depth() < 1:
|
||||
return False
|
||||
elif st.B_(0).l_edge < 0:
|
||||
return False
|
||||
elif st._sent[st.B_(0).l_edge].sent_start < 0:
|
||||
elif st.buffer_length() == 0:
|
||||
return True
|
||||
elif st.stack_depth() == 1 and st.cannot_sent_start(st.l_edge(st.B(0))):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
st.set_break(st.B_(0).l_edge)
|
||||
st.fast_forward()
|
||||
|
||||
@staticmethod
|
||||
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
|
||||
gold = <const GoldParseStateC*>_gold
|
||||
return Break.move_cost(s, gold) + Break.label_cost(s, gold, label)
|
||||
|
||||
@staticmethod
|
||||
cdef inline weight_t move_cost(StateClass s, const void* _gold) nogil:
|
||||
gold = <const GoldParseStateC*>_gold
|
||||
cost = 0
|
||||
for i in range(s.stack_depth()):
|
||||
S_i = s.S(i)
|
||||
cost += gold.n_kids_in_buffer[S_i]
|
||||
if is_head_in_buffer(gold, S_i):
|
||||
cost += 1
|
||||
# It's weird not to check the gold sentence boundaries but if we do,
|
||||
# we can't account for "sunk costs", i.e. situations where we're already
|
||||
# wrong.
|
||||
s0_root = _get_root(s.S(0), gold)
|
||||
b0_root = _get_root(s.B(0), gold)
|
||||
if s0_root != b0_root or s0_root == -1 or b0_root == -1:
|
||||
return cost
|
||||
if st.has_head(st.S(0)) or st.stack_depth() == 1:
|
||||
st.pop()
|
||||
else:
|
||||
return cost + 1
|
||||
st.unshift()
|
||||
|
||||
@staticmethod
|
||||
cdef inline weight_t label_cost(StateClass s, const void* gold, attr_t label) nogil:
|
||||
cdef weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil:
|
||||
gold = <const GoldParseStateC*>_gold
|
||||
if state.is_sent_start(state.B(0)):
|
||||
return 0
|
||||
s0 = state.S(0)
|
||||
cost = gold.pop_cost
|
||||
if not state.has_head(s0):
|
||||
# Decrement cost for the arcs we save, as we'll be putting this
|
||||
# back to the buffer
|
||||
if is_head_in_stack(gold, s0):
|
||||
cost -= 1
|
||||
cost -= gold.n_kids_in_stack[s0]
|
||||
return cost
|
||||
|
||||
cdef int _get_root(int word, const GoldParseStateC* gold) nogil:
|
||||
if is_head_unknown(gold, word):
|
||||
return -1
|
||||
while gold.heads[word] != word and word >= 0:
|
||||
word = gold.heads[word]
|
||||
if is_head_unknown(gold, word):
|
||||
return -1
|
||||
|
||||
cdef class LeftArc:
|
||||
"""Add an arc between B[0] and S[0], replacing the previous head of S[0] if
|
||||
one is set. Pop S[0] from the stack.
|
||||
|
||||
Validity:
|
||||
* len(S) >= 1
|
||||
* len(B) >= 1
|
||||
* not is_sent_start(B[0])
|
||||
|
||||
Cost:
|
||||
pop_cost - Arc(B[0], S[0], label) + (Arc(S[1], S[0]) if H(S[0]) else Arcs(S, S[0]))
|
||||
"""
|
||||
@staticmethod
|
||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
if st.stack_depth() == 0:
|
||||
return 0
|
||||
elif st.buffer_length() == 0:
|
||||
return 0
|
||||
elif st.is_sent_start(st.B(0)):
|
||||
return 0
|
||||
elif label == SUBTOK_LABEL and st.S(0) != (st.B(0)-1):
|
||||
return 0
|
||||
else:
|
||||
return word
|
||||
return 1
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
st.add_arc(st.B(0), st.S(0), label)
|
||||
# If we change the stack, it's okay to remove the shifted mark, as
|
||||
# we can't get in an infinite loop this way.
|
||||
st.set_reshiftable(st.B(0))
|
||||
st.pop()
|
||||
|
||||
@staticmethod
|
||||
cdef inline weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil:
|
||||
gold = <const GoldParseStateC*>_gold
|
||||
cdef weight_t cost = gold.pop_cost
|
||||
s0 = state.S(0)
|
||||
s1 = state.S(1)
|
||||
b0 = state.B(0)
|
||||
if state.has_head(s0):
|
||||
# Increment cost if we're clobbering a correct arc
|
||||
cost += gold.heads[s0] == s1
|
||||
else:
|
||||
# If there's no head, we're losing arcs between S0 and S[1:].
|
||||
cost += is_head_in_stack(gold, s0)
|
||||
cost += gold.n_kids_in_stack[s0]
|
||||
if b0 != -1 and s0 != -1 and gold.heads[s0] == b0:
|
||||
cost -= 1
|
||||
cost += not label_is_gold(gold, s0, label)
|
||||
return cost
|
||||
|
||||
|
||||
cdef class RightArc:
|
||||
"""
|
||||
Add an arc from S[0] to B[0]. Push B[0].
|
||||
|
||||
Validity:
|
||||
* len(S) >= 1
|
||||
* len(B) >= 1
|
||||
* not is_sent_start(B[0])
|
||||
|
||||
Cost:
|
||||
push_cost + (not shifted[b0] and Arc(B[1:], B[0])) - Arc(S[0], B[0], label)
|
||||
"""
|
||||
@staticmethod
|
||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
if st.stack_depth() == 0:
|
||||
return 0
|
||||
elif st.buffer_length() == 0:
|
||||
return 0
|
||||
elif st.is_sent_start(st.B(0)):
|
||||
return 0
|
||||
elif label == SUBTOK_LABEL and st.S(0) != (st.B(0)-1):
|
||||
# If there's (perhaps partial) parse pre-set, don't allow cycle.
|
||||
return 0
|
||||
else:
|
||||
return 1
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
st.add_arc(st.S(0), st.B(0), label)
|
||||
st.push()
|
||||
|
||||
@staticmethod
|
||||
cdef inline weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil:
|
||||
gold = <const GoldParseStateC*>_gold
|
||||
cost = gold.push_cost
|
||||
s0 = state.S(0)
|
||||
b0 = state.B(0)
|
||||
if s0 != -1 and b0 != -1 and gold.heads[b0] == s0:
|
||||
cost -= 1
|
||||
cost += not label_is_gold(gold, b0, label)
|
||||
elif is_head_in_buffer(gold, b0) and not state.is_unshiftable(b0):
|
||||
cost += 1
|
||||
return cost
|
||||
|
||||
|
||||
cdef class Break:
|
||||
"""Mark the second word of the buffer as the start of a
|
||||
sentence.
|
||||
|
||||
Validity:
|
||||
* len(buffer) >= 2
|
||||
* B[1] == B[0] + 1
|
||||
* not is_sent_start(B[1])
|
||||
* not cannot_sent_start(B[1])
|
||||
|
||||
Action:
|
||||
* mark_sent_start(B[1])
|
||||
|
||||
Cost:
|
||||
* not is_sent_start(B[1])
|
||||
* Arcs between B[0] and B[1:]
|
||||
* Arcs between S and B[1]
|
||||
"""
|
||||
@staticmethod
|
||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
cdef int i
|
||||
if st.buffer_length() < 2:
|
||||
return False
|
||||
elif st.B(1) != st.B(0) + 1:
|
||||
return False
|
||||
elif st.is_sent_start(st.B(1)):
|
||||
return False
|
||||
elif st.cannot_sent_start(st.B(1)):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
st.set_sent_start(st.B(1), 1)
|
||||
|
||||
@staticmethod
|
||||
cdef weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil:
|
||||
gold = <const GoldParseStateC*>_gold
|
||||
cdef int b0 = state.B(0)
|
||||
cdef int cost = 0
|
||||
cdef int si
|
||||
for i in range(state.stack_depth()):
|
||||
si = state.S(i)
|
||||
if is_head_in_buffer(gold, si):
|
||||
cost += 1
|
||||
cost += gold.n_kids_in_buffer[si]
|
||||
# We need to score into B[1:], so subtract deps that are at b0
|
||||
if gold.heads[b0] == si:
|
||||
cost -= 1
|
||||
if gold.heads[si] == b0:
|
||||
cost -= 1
|
||||
if not is_sent_start(gold, state.B(1)) \
|
||||
and not is_sent_start_unknown(gold, state.B(1)):
|
||||
cost += 1
|
||||
return cost
|
||||
|
||||
|
||||
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
|
||||
st = new StateC(<const TokenC*>tokens, length)
|
||||
for i in range(st.length):
|
||||
if st._sent[i].dep == 0:
|
||||
st._sent[i].l_edge = i
|
||||
st._sent[i].r_edge = i
|
||||
st._sent[i].head = 0
|
||||
st._sent[i].dep = 0
|
||||
st._sent[i].l_kids = 0
|
||||
st._sent[i].r_kids = 0
|
||||
st.fast_forward()
|
||||
return <void*>st
|
||||
|
||||
|
||||
|
@ -515,6 +540,8 @@ cdef int _del_state(Pool mem, void* state, void* x) except -1:
|
|||
cdef class ArcEager(TransitionSystem):
|
||||
def __init__(self, *args, **kwargs):
|
||||
TransitionSystem.__init__(self, *args, **kwargs)
|
||||
self.init_beam_state = _init_state
|
||||
self.del_beam_state = _del_state
|
||||
|
||||
@classmethod
|
||||
def get_actions(cls, **kwargs):
|
||||
|
@ -537,7 +564,7 @@ cdef class ArcEager(TransitionSystem):
|
|||
label = 'ROOT'
|
||||
if head == child:
|
||||
actions[BREAK][label] += 1
|
||||
elif head < child:
|
||||
if head < child:
|
||||
actions[RIGHT][label] += 1
|
||||
actions[REDUCE][''] += 1
|
||||
elif head > child:
|
||||
|
@ -567,8 +594,14 @@ cdef class ArcEager(TransitionSystem):
|
|||
t.do(state.c, t.label)
|
||||
return state
|
||||
|
||||
def is_gold_parse(self, StateClass state, gold):
|
||||
raise NotImplementedError
|
||||
def is_gold_parse(self, StateClass state, ArcEagerGold gold):
|
||||
for i in range(state.c.length):
|
||||
token = state.c.safe_get(i)
|
||||
if not arc_is_gold(&gold.c, i, i+token.head):
|
||||
return False
|
||||
elif not label_is_gold(&gold.c, i, token.dep):
|
||||
return False
|
||||
return True
|
||||
|
||||
def init_gold(self, StateClass state, Example example):
|
||||
gold = ArcEagerGold(self, state, example)
|
||||
|
@ -576,6 +609,7 @@ cdef class ArcEager(TransitionSystem):
|
|||
return gold
|
||||
|
||||
def init_gold_batch(self, examples):
|
||||
# TODO: Projectivity?
|
||||
all_states = self.init_batch([eg.predicted for eg in examples])
|
||||
golds = []
|
||||
states = []
|
||||
|
@ -662,24 +696,13 @@ cdef class ArcEager(TransitionSystem):
|
|||
raise ValueError(Errors.E019.format(action=move, src='arc_eager'))
|
||||
return t
|
||||
|
||||
cdef int initialize_state(self, StateC* st) nogil:
|
||||
for i in range(st.length):
|
||||
if st._sent[i].dep == 0:
|
||||
st._sent[i].l_edge = i
|
||||
st._sent[i].r_edge = i
|
||||
st._sent[i].head = 0
|
||||
st._sent[i].dep = 0
|
||||
st._sent[i].l_kids = 0
|
||||
st._sent[i].r_kids = 0
|
||||
st.fast_forward()
|
||||
|
||||
cdef int finalize_state(self, StateC* st) nogil:
|
||||
cdef int i
|
||||
for i in range(st.length):
|
||||
if st._sent[i].head == 0:
|
||||
st._sent[i].dep = self.root_label
|
||||
|
||||
def finalize_doc(self, Doc doc):
|
||||
def set_annotations(self, StateClass state, Doc doc):
|
||||
for arc in state.arcs:
|
||||
doc.c[arc["child"]].head = arc["head"] - arc["child"]
|
||||
doc.c[arc["child"]].dep = arc["label"]
|
||||
for i in range(doc.length):
|
||||
if doc.c[i].head == 0:
|
||||
doc.c[i].dep = self.root_label
|
||||
set_children_from_heads(doc.c, 0, doc.length)
|
||||
|
||||
def has_gold(self, Example eg, start=0, end=None):
|
||||
|
@ -690,7 +713,7 @@ cdef class ArcEager(TransitionSystem):
|
|||
return False
|
||||
|
||||
cdef int set_valid(self, int* output, const StateC* st) nogil:
|
||||
cdef bint[N_MOVES] is_valid
|
||||
cdef int[N_MOVES] is_valid
|
||||
is_valid[SHIFT] = Shift.is_valid(st, 0)
|
||||
is_valid[REDUCE] = Reduce.is_valid(st, 0)
|
||||
is_valid[LEFT] = LeftArc.is_valid(st, 0)
|
||||
|
@ -710,29 +733,31 @@ cdef class ArcEager(TransitionSystem):
|
|||
gold_state = gold_.c
|
||||
n_gold = 0
|
||||
if self.c[i].is_valid(stcls.c, self.c[i].label):
|
||||
cost = self.c[i].get_cost(stcls, &gold_state, self.c[i].label)
|
||||
cost = self.c[i].get_cost(stcls.c, &gold_state, self.c[i].label)
|
||||
else:
|
||||
cost = 9000
|
||||
return cost
|
||||
|
||||
cdef int set_costs(self, int* is_valid, weight_t* costs,
|
||||
StateClass stcls, gold) except -1:
|
||||
const StateC* state, gold) except -1:
|
||||
if not isinstance(gold, ArcEagerGold):
|
||||
raise TypeError(Errors.E909.format(name="ArcEagerGold"))
|
||||
cdef ArcEagerGold gold_ = gold
|
||||
gold_.update(stcls)
|
||||
gold_state = gold_.c
|
||||
update_gold_state(&gold_state, state)
|
||||
self.set_valid(is_valid, state)
|
||||
cdef int n_gold = 0
|
||||
for i in range(self.n_moves):
|
||||
if self.c[i].is_valid(stcls.c, self.c[i].label):
|
||||
is_valid[i] = True
|
||||
costs[i] = self.c[i].get_cost(stcls, &gold_state, self.c[i].label)
|
||||
if is_valid[i]:
|
||||
costs[i] = self.c[i].get_cost(state, &gold_state, self.c[i].label)
|
||||
if costs[i] <= 0:
|
||||
n_gold += 1
|
||||
else:
|
||||
is_valid[i] = False
|
||||
costs[i] = 9000
|
||||
if n_gold < 1:
|
||||
for i in range(self.n_moves):
|
||||
print(self.get_class_name(i), is_valid[i], costs[i])
|
||||
print("Gold sent starts?", is_sent_start(&gold_state, state.B(0)), is_sent_start(&gold_state, state.B(1)))
|
||||
raise ValueError
|
||||
|
||||
def get_oracle_sequence_from_state(self, StateClass state, ArcEagerGold gold, _debug=None):
|
||||
|
@ -748,12 +773,13 @@ cdef class ArcEager(TransitionSystem):
|
|||
failed = False
|
||||
while not state.is_final():
|
||||
try:
|
||||
self.set_costs(is_valid, costs, state, gold)
|
||||
self.set_costs(is_valid, costs, state.c, gold)
|
||||
except ValueError:
|
||||
failed = True
|
||||
break
|
||||
min_cost = min(costs[i] for i in range(self.n_moves))
|
||||
for i in range(self.n_moves):
|
||||
if is_valid[i] and costs[i] <= 0:
|
||||
if is_valid[i] and costs[i] <= min_cost:
|
||||
action = self.c[i]
|
||||
history.append(i)
|
||||
s0 = state.S(0)
|
||||
|
@ -762,9 +788,7 @@ cdef class ArcEager(TransitionSystem):
|
|||
example = _debug
|
||||
debug_log.append(" ".join((
|
||||
self.get_class_name(i),
|
||||
"S0=", (example.x[s0].text if s0 >= 0 else "__"),
|
||||
"B0=", (example.x[b0].text if b0 >= 0 else "__"),
|
||||
"S0 head?", str(state.has_head(state.S(0))),
|
||||
state.print_state()
|
||||
)))
|
||||
action.do(state.c, action.label)
|
||||
break
|
||||
|
@ -783,6 +807,8 @@ cdef class ArcEager(TransitionSystem):
|
|||
print("Aligned heads")
|
||||
for i, head in enumerate(aligned_heads):
|
||||
print(example.x[i], example.x[head] if head is not None else "__")
|
||||
print("Aligned sent starts")
|
||||
print(example.get_aligned_sent_starts())
|
||||
|
||||
print("Predicted tokens")
|
||||
print([(w.i, w.text) for w in example.x])
|
||||
|
|
|
@ -3,9 +3,12 @@ from cymem.cymem cimport Pool
|
|||
|
||||
from collections import Counter
|
||||
|
||||
from ...tokens.doc cimport Doc
|
||||
from ...tokens.span import Span
|
||||
from ...typedefs cimport weight_t, attr_t
|
||||
from ...lexeme cimport Lexeme
|
||||
from ...attrs cimport IS_SPACE
|
||||
from ...structs cimport TokenC
|
||||
from ...training.example cimport Example
|
||||
from .stateclass cimport StateClass
|
||||
from ._state cimport StateC
|
||||
|
@ -46,17 +49,17 @@ cdef class BiluoGold:
|
|||
|
||||
def __init__(self, BiluoPushDown moves, StateClass stcls, Example example):
|
||||
self.mem = Pool()
|
||||
self.c = create_gold_state(self.mem, moves, stcls, example)
|
||||
self.c = create_gold_state(self.mem, moves, stcls.c, example)
|
||||
|
||||
def update(self, StateClass stcls):
|
||||
update_gold_state(&self.c, stcls)
|
||||
update_gold_state(&self.c, stcls.c)
|
||||
|
||||
|
||||
|
||||
cdef GoldNERStateC create_gold_state(
|
||||
Pool mem,
|
||||
BiluoPushDown moves,
|
||||
StateClass stcls,
|
||||
const StateC* stcls,
|
||||
Example example
|
||||
) except *:
|
||||
cdef GoldNERStateC gs
|
||||
|
@ -67,7 +70,7 @@ cdef GoldNERStateC create_gold_state(
|
|||
return gs
|
||||
|
||||
|
||||
cdef void update_gold_state(GoldNERStateC* gs, StateClass stcls) except *:
|
||||
cdef void update_gold_state(GoldNERStateC* gs, const StateC* state) except *:
|
||||
# We don't need to update each time, unlike the parser.
|
||||
pass
|
||||
|
||||
|
@ -75,14 +78,15 @@ cdef void update_gold_state(GoldNERStateC* gs, StateClass stcls) except *:
|
|||
cdef do_func_t[N_MOVES] do_funcs
|
||||
|
||||
|
||||
cdef bint _entity_is_sunk(StateClass st, Transition* golds) nogil:
|
||||
if not st.entity_is_open():
|
||||
cdef bint _entity_is_sunk(const StateC* state, Transition* golds) nogil:
|
||||
if not state.entity_is_open():
|
||||
return False
|
||||
|
||||
cdef const Transition* gold = &golds[st.E(0)]
|
||||
cdef const Transition* gold = &golds[state.E(0)]
|
||||
ent = state.get_ent()
|
||||
if gold.move != BEGIN and gold.move != UNIT:
|
||||
return True
|
||||
elif gold.label != st.E_(0).ent_type:
|
||||
elif gold.label != ent.label:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
@ -228,15 +232,18 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
self.labels[action][label_name] = -1
|
||||
return 1
|
||||
|
||||
cdef int initialize_state(self, StateC* st) nogil:
|
||||
# This is especially necessary when we use limited training data.
|
||||
for i in range(st.length):
|
||||
if st._sent[i].ent_type != 0:
|
||||
with gil:
|
||||
self.add_action(BEGIN, st._sent[i].ent_type)
|
||||
self.add_action(IN, st._sent[i].ent_type)
|
||||
self.add_action(UNIT, st._sent[i].ent_type)
|
||||
self.add_action(LAST, st._sent[i].ent_type)
|
||||
def set_annotations(self, StateClass state, Doc doc):
    """Write the entities recorded on a parse state onto a Doc.

    state (StateClass): The state whose `c._ents` entries are read.
    doc (Doc): The document that receives the entity spans.

    Entries whose start or end is -1 are skipped. The remaining spans are
    applied with `doc.set_ents(..., default="unmodified")`, after which any
    token whose `ent_iob` is still 0 is set to 2 (O).
    """
    cdef int idx
    spans = []
    for idx in range(state.c._ents.size()):
        ent = state.c._ents.at(idx)
        # Skip entries that lack a start or an end offset.
        if ent.start == -1 or ent.end == -1:
            continue
        spans.append(Span(doc, ent.start, ent.end, label=ent.label))
    doc.set_ents(spans, default="unmodified")
    # Set non-blocked tokens to O
    for idx in range(doc.length):
        if doc.c[idx].ent_iob == 0:
            doc.c[idx].ent_iob = 2
|
||||
|
||||
def init_gold(self, StateClass state, Example example):
|
||||
return BiluoGold(self, state, example)
|
||||
|
@ -255,26 +262,25 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
gold_state = gold_.c
|
||||
n_gold = 0
|
||||
if self.c[i].is_valid(stcls.c, self.c[i].label):
|
||||
cost = self.c[i].get_cost(stcls, &gold_state, self.c[i].label)
|
||||
cost = self.c[i].get_cost(stcls.c, &gold_state, self.c[i].label)
|
||||
else:
|
||||
cost = 9000
|
||||
return cost
|
||||
|
||||
cdef int set_costs(self, int* is_valid, weight_t* costs,
|
||||
StateClass stcls, gold) except -1:
|
||||
const StateC* state, gold) except -1:
|
||||
if not isinstance(gold, BiluoGold):
|
||||
raise TypeError(Errors.E909.format(name="BiluoGold"))
|
||||
cdef BiluoGold gold_ = gold
|
||||
gold_.update(stcls)
|
||||
gold_state = gold_.c
|
||||
update_gold_state(&gold_state, state)
|
||||
n_gold = 0
|
||||
self.set_valid(is_valid, state)
|
||||
for i in range(self.n_moves):
|
||||
if self.c[i].is_valid(stcls.c, self.c[i].label):
|
||||
is_valid[i] = 1
|
||||
costs[i] = self.c[i].get_cost(stcls, &gold_state, self.c[i].label)
|
||||
if is_valid[i]:
|
||||
costs[i] = self.c[i].get_cost(state, &gold_state, self.c[i].label)
|
||||
n_gold += costs[i] <= 0
|
||||
else:
|
||||
is_valid[i] = 0
|
||||
costs[i] = 9000
|
||||
if n_gold < 1:
|
||||
raise ValueError
|
||||
|
@ -290,7 +296,7 @@ cdef class Missing:
|
|||
pass
|
||||
|
||||
@staticmethod
|
||||
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
|
||||
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
|
||||
return 9000
|
||||
|
||||
|
||||
|
@ -299,10 +305,10 @@ cdef class Begin:
|
|||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
cdef int preset_ent_iob = st.B_(0).ent_iob
|
||||
cdef attr_t preset_ent_label = st.B_(0).ent_type
|
||||
# If we're the last token of the input, we can't B -- must U or O.
|
||||
if st.B(1) == -1:
|
||||
if st.entity_is_open():
|
||||
return False
|
||||
elif st.entity_is_open():
|
||||
if st.buffer_length() < 2:
|
||||
# If we're the last token of the input, we can't B -- must U or O.
|
||||
return False
|
||||
elif label == 0:
|
||||
return False
|
||||
|
@ -337,12 +343,11 @@ cdef class Begin:
|
|||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
st.open_ent(label)
|
||||
st.set_ent_tag(st.B(0), 3, label)
|
||||
st.push()
|
||||
st.pop()
|
||||
|
||||
@staticmethod
|
||||
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
|
||||
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
|
||||
gold = <GoldNERStateC*>_gold
|
||||
cdef int g_act = gold.ner[s.B(0)].move
|
||||
cdef attr_t g_tag = gold.ner[s.B(0)].label
|
||||
|
@ -366,16 +371,17 @@ cdef class Begin:
|
|||
cdef class In:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
if not st.entity_is_open():
|
||||
return False
|
||||
if st.buffer_length() < 2:
|
||||
# If we're at the end, we can't I.
|
||||
return False
|
||||
ent = st.get_ent()
|
||||
cdef int preset_ent_iob = st.B_(0).ent_iob
|
||||
cdef attr_t preset_ent_label = st.B_(0).ent_type
|
||||
if label == 0:
|
||||
return False
|
||||
elif st.E_(0).ent_type != label:
|
||||
return False
|
||||
elif not st.entity_is_open():
|
||||
return False
|
||||
elif st.B(1) == -1:
|
||||
# If we're at the end, we can't I.
|
||||
elif ent.label != label:
|
||||
return False
|
||||
elif preset_ent_iob == 3:
|
||||
return False
|
||||
|
@ -401,12 +407,11 @@ cdef class In:
|
|||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
st.set_ent_tag(st.B(0), 1, label)
|
||||
st.push()
|
||||
st.pop()
|
||||
|
||||
@staticmethod
|
||||
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
|
||||
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
|
||||
gold = <GoldNERStateC*>_gold
|
||||
move = IN
|
||||
cdef int next_act = gold.ner[s.B(1)].move if s.B(1) >= 0 else OUT
|
||||
|
@ -457,7 +462,7 @@ cdef class Last:
|
|||
# Otherwise, force acceptance, even if we're across a sentence
|
||||
# boundary or the token is whitespace.
|
||||
return True
|
||||
elif st.E_(0).ent_type != label:
|
||||
elif st.get_ent().label != label:
|
||||
return False
|
||||
elif st.B_(1).ent_iob == 1:
|
||||
# If a preset entity has I next, we can't L here.
|
||||
|
@ -468,12 +473,11 @@ cdef class Last:
|
|||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
st.close_ent()
|
||||
st.set_ent_tag(st.B(0), 1, label)
|
||||
st.push()
|
||||
st.pop()
|
||||
|
||||
@staticmethod
|
||||
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
|
||||
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
|
||||
gold = <GoldNERStateC*>_gold
|
||||
move = LAST
|
||||
|
||||
|
@ -537,12 +541,11 @@ cdef class Unit:
|
|||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
st.open_ent(label)
|
||||
st.close_ent()
|
||||
st.set_ent_tag(st.B(0), 3, label)
|
||||
st.push()
|
||||
st.pop()
|
||||
|
||||
@staticmethod
|
||||
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
|
||||
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
|
||||
gold = <GoldNERStateC*>_gold
|
||||
cdef int g_act = gold.ner[s.B(0)].move
|
||||
cdef attr_t g_tag = gold.ner[s.B(0)].label
|
||||
|
@ -578,12 +581,11 @@ cdef class Out:
|
|||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
st.set_ent_tag(st.B(0), 2, 0)
|
||||
st.push()
|
||||
st.pop()
|
||||
|
||||
@staticmethod
|
||||
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
|
||||
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
|
||||
gold = <GoldNERStateC*>_gold
|
||||
cdef int g_act = gold.ner[s.B(0)].move
|
||||
cdef attr_t g_tag = gold.ner[s.B(0)].label
|
||||
|
|
|
@ -2,30 +2,24 @@ from cymem.cymem cimport Pool
|
|||
|
||||
from ...structs cimport TokenC, SpanC
|
||||
from ...typedefs cimport attr_t
|
||||
from ...tokens.doc cimport Doc
|
||||
|
||||
from ._state cimport StateC
|
||||
|
||||
|
||||
cdef class StateClass:
|
||||
cdef Pool mem
|
||||
cdef StateC* c
|
||||
cdef readonly Doc doc
|
||||
cdef int _borrowed
|
||||
|
||||
@staticmethod
|
||||
cdef inline StateClass init(const TokenC* sent, int length):
|
||||
cdef inline StateClass borrow(StateC* ptr, Doc doc):
|
||||
cdef StateClass self = StateClass()
|
||||
self.c = new StateC(sent, length)
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
cdef inline StateClass borrow(StateC* ptr):
|
||||
cdef StateClass self = StateClass()
|
||||
del self.c
|
||||
self.c = ptr
|
||||
self._borrowed = 1
|
||||
self.doc = doc
|
||||
return self
|
||||
|
||||
|
||||
@staticmethod
|
||||
cdef inline StateClass init_offset(const TokenC* sent, int length, int
|
||||
offset):
|
||||
|
@ -33,105 +27,3 @@ cdef class StateClass:
|
|||
self.c = new StateC(sent, length)
|
||||
self.c.offset = offset
|
||||
return self
|
||||
|
||||
cdef inline int S(self, int i) nogil:
|
||||
return self.c.S(i)
|
||||
|
||||
cdef inline int B(self, int i) nogil:
|
||||
return self.c.B(i)
|
||||
|
||||
cdef inline const TokenC* S_(self, int i) nogil:
|
||||
return self.c.S_(i)
|
||||
|
||||
cdef inline const TokenC* B_(self, int i) nogil:
|
||||
return self.c.B_(i)
|
||||
|
||||
cdef inline const TokenC* H_(self, int i) nogil:
|
||||
return self.c.H_(i)
|
||||
|
||||
cdef inline const TokenC* E_(self, int i) nogil:
|
||||
return self.c.E_(i)
|
||||
|
||||
cdef inline const TokenC* L_(self, int i, int idx) nogil:
|
||||
return self.c.L_(i, idx)
|
||||
|
||||
cdef inline const TokenC* R_(self, int i, int idx) nogil:
|
||||
return self.c.R_(i, idx)
|
||||
|
||||
cdef inline const TokenC* safe_get(self, int i) nogil:
|
||||
return self.c.safe_get(i)
|
||||
|
||||
cdef inline int H(self, int i) nogil:
|
||||
return self.c.H(i)
|
||||
|
||||
cdef inline int E(self, int i) nogil:
|
||||
return self.c.E(i)
|
||||
|
||||
cdef inline int L(self, int i, int idx) nogil:
|
||||
return self.c.L(i, idx)
|
||||
|
||||
cdef inline int R(self, int i, int idx) nogil:
|
||||
return self.c.R(i, idx)
|
||||
|
||||
cdef inline bint empty(self) nogil:
|
||||
return self.c.empty()
|
||||
|
||||
cdef inline bint eol(self) nogil:
|
||||
return self.c.eol()
|
||||
|
||||
cdef inline bint at_break(self) nogil:
|
||||
return self.c.at_break()
|
||||
|
||||
cdef inline bint has_head(self, int i) nogil:
|
||||
return self.c.has_head(i)
|
||||
|
||||
cdef inline int n_L(self, int i) nogil:
|
||||
return self.c.n_L(i)
|
||||
|
||||
cdef inline int n_R(self, int i) nogil:
|
||||
return self.c.n_R(i)
|
||||
|
||||
cdef inline bint stack_is_connected(self) nogil:
|
||||
return False
|
||||
|
||||
cdef inline bint entity_is_open(self) nogil:
|
||||
return self.c.entity_is_open()
|
||||
|
||||
cdef inline int stack_depth(self) nogil:
|
||||
return self.c.stack_depth()
|
||||
|
||||
cdef inline int buffer_length(self) nogil:
|
||||
return self.c.buffer_length()
|
||||
|
||||
cdef inline void push(self) nogil:
|
||||
self.c.push()
|
||||
|
||||
cdef inline void pop(self) nogil:
|
||||
self.c.pop()
|
||||
|
||||
cdef inline void unshift(self) nogil:
|
||||
self.c.unshift()
|
||||
|
||||
cdef inline void add_arc(self, int head, int child, attr_t label) nogil:
|
||||
self.c.add_arc(head, child, label)
|
||||
|
||||
cdef inline void del_arc(self, int head, int child) nogil:
|
||||
self.c.del_arc(head, child)
|
||||
|
||||
cdef inline void open_ent(self, attr_t label) nogil:
|
||||
self.c.open_ent(label)
|
||||
|
||||
cdef inline void close_ent(self) nogil:
|
||||
self.c.close_ent()
|
||||
|
||||
cdef inline void set_ent_tag(self, int i, int ent_iob, attr_t ent_type) nogil:
|
||||
self.c.set_ent_tag(i, ent_iob, ent_type)
|
||||
|
||||
cdef inline void set_break(self, int i) nogil:
|
||||
self.c.set_break(i)
|
||||
|
||||
cdef inline void clone(self, StateClass src) nogil:
|
||||
self.c.clone(src.c)
|
||||
|
||||
cdef inline void fast_forward(self) nogil:
|
||||
self.c.fast_forward()
|
||||
|
|
|
@ -1,17 +1,20 @@
|
|||
# cython: infer_types=True
|
||||
import numpy
|
||||
from libcpp.vector cimport vector
|
||||
from ._state cimport ArcC
|
||||
|
||||
from ...tokens.doc cimport Doc
|
||||
|
||||
|
||||
cdef class StateClass:
|
||||
def __init__(self, Doc doc=None, int offset=0):
|
||||
cdef Pool mem = Pool()
|
||||
self.mem = mem
|
||||
self._borrowed = 0
|
||||
if doc is not None:
|
||||
self.c = new StateC(doc.c, doc.length)
|
||||
self.c.offset = offset
|
||||
self.doc = doc
|
||||
else:
|
||||
self.doc = None
|
||||
|
||||
def __dealloc__(self):
|
||||
if self._borrowed != 1:
|
||||
|
@ -19,36 +22,157 @@ cdef class StateClass:
|
|||
|
||||
@property
|
||||
def stack(self):
|
||||
return {self.S(i) for i in range(self.c._s_i)}
|
||||
return [self.S(i) for i in range(self.c.stack_depth())]
|
||||
|
||||
@property
|
||||
def queue(self):
|
||||
return {self.B(i) for i in range(self.c.buffer_length())}
|
||||
return [self.B(i) for i in range(self.c.buffer_length())]
|
||||
|
||||
@property
|
||||
def token_vector_lenth(self):
|
||||
return self.doc.tensor.shape[1]
|
||||
|
||||
@property
|
||||
def history(self):
|
||||
hist = numpy.ndarray((8,), dtype='i')
|
||||
for i in range(8):
|
||||
hist[i] = self.c.get_hist(i+1)
|
||||
return hist
|
||||
def arcs(self):
|
||||
cdef vector[ArcC] arcs
|
||||
self.c.get_arcs(&arcs)
|
||||
return list(arcs)
|
||||
#py_arcs = []
|
||||
#for arc in arcs:
|
||||
# if arc.head != -1 and arc.child != -1:
|
||||
# py_arcs.append((arc.head, arc.child, arc.label))
|
||||
#return arcs
|
||||
|
||||
def add_arc(self, int head, int child, int label):
|
||||
self.c.add_arc(head, child, label)
|
||||
|
||||
def del_arc(self, int head, int child):
|
||||
self.c.del_arc(head, child)
|
||||
|
||||
def H(self, int child):
|
||||
return self.c.H(child)
|
||||
|
||||
def L(self, int head, int idx):
|
||||
return self.c.L(head, idx)
|
||||
|
||||
def R(self, int head, int idx):
|
||||
return self.c.R(head, idx)
|
||||
|
||||
@property
|
||||
def _b_i(self):
|
||||
return self.c._b_i
|
||||
|
||||
@property
|
||||
def length(self):
|
||||
return self.c.length
|
||||
|
||||
def is_final(self):
|
||||
return self.c.is_final()
|
||||
|
||||
def copy(self):
|
||||
cdef StateClass new_state = StateClass.init(self.c._sent, self.c.length)
|
||||
cdef StateClass new_state = StateClass(doc=self.doc, offset=self.c.offset)
|
||||
new_state.c.clone(self.c)
|
||||
return new_state
|
||||
|
||||
def print_state(self, words):
|
||||
def print_state(self):
|
||||
words = [token.text for token in self.doc]
|
||||
words = list(words) + ['_']
|
||||
top = f"{words[self.S(0)]}_{self.S_(0).head}"
|
||||
second = f"{words[self.S(1)]}_{self.S_(1).head}"
|
||||
third = f"{words[self.S(2)]}_{self.S_(2).head}"
|
||||
n0 = words[self.B(0)]
|
||||
n1 = words[self.B(1)]
|
||||
return ' '.join((third, second, top, '|', n0, n1))
|
||||
bools = ["F", "T"]
|
||||
sent_starts = [bools[self.c.is_sent_start(i)] for i in range(len(self.doc))]
|
||||
shifted = [1 if self.c.is_unshiftable(i) else 0 for i in range(self.c.length)]
|
||||
shifted.append("")
|
||||
sent_starts.append("")
|
||||
top = f"{self.S(0)}{words[self.S(0)]}_{words[self.H(self.S(0))]}_{shifted[self.S(0)]}"
|
||||
second = f"{self.S(1)}{words[self.S(1)]}_{words[self.H(self.S(1))]}_{shifted[self.S(1)]}"
|
||||
third = f"{self.S(2)}{words[self.S(2)]}_{words[self.H(self.S(2))]}_{shifted[self.S(2)]}"
|
||||
n0 = f"{self.B(0)}{words[self.B(0)]}_{sent_starts[self.B(0)]}_{shifted[self.B(0)]}"
|
||||
n1 = f"{self.B(1)}{words[self.B(1)]}_{sent_starts[self.B(1)]}_{shifted[self.B(1)]}"
|
||||
return ' '.join((str(self.stack_depth()), str(self.buffer_length()), third, second, top, '|', n0, n1))
|
||||
|
||||
def S(self, int i):
|
||||
return self.c.S(i)
|
||||
|
||||
def B(self, int i):
|
||||
return self.c.B(i)
|
||||
|
||||
def H(self, int i):
|
||||
return self.c.H(i)
|
||||
|
||||
def E(self, int i):
|
||||
return self.c.E(i)
|
||||
|
||||
def L(self, int i, int idx):
|
||||
return self.c.L(i, idx)
|
||||
|
||||
def R(self, int i, int idx):
|
||||
return self.c.R(i, idx)
|
||||
|
||||
def S_(self, int i):
|
||||
return self.doc[self.c.S(i)]
|
||||
|
||||
def B_(self, int i):
|
||||
return self.doc[self.c.B(i)]
|
||||
|
||||
def H_(self, int i):
|
||||
return self.doc[self.c.H(i)]
|
||||
|
||||
def E_(self, int i):
|
||||
return self.doc[self.c.E(i)]
|
||||
|
||||
def L_(self, int i, int idx):
|
||||
return self.doc[self.c.L(i, idx)]
|
||||
|
||||
def R_(self, int i, int idx):
|
||||
return self.doc[self.c.R(i, idx)]
|
||||
|
||||
def empty(self):
|
||||
return self.c.empty()
|
||||
|
||||
def eol(self):
|
||||
return self.c.eol()
|
||||
|
||||
def at_break(self):
|
||||
return False
|
||||
#return self.c.at_break()
|
||||
|
||||
def has_head(self, int i):
|
||||
return self.c.has_head(i)
|
||||
|
||||
def n_L(self, int i):
|
||||
return self.c.n_L(i)
|
||||
|
||||
def n_R(self, int i):
|
||||
return self.c.n_R(i)
|
||||
|
||||
def entity_is_open(self):
|
||||
return self.c.entity_is_open()
|
||||
|
||||
def stack_depth(self):
|
||||
return self.c.stack_depth()
|
||||
|
||||
def buffer_length(self):
|
||||
return self.c.buffer_length()
|
||||
|
||||
def push(self):
|
||||
self.c.push()
|
||||
|
||||
def pop(self):
|
||||
self.c.pop()
|
||||
|
||||
def unshift(self):
|
||||
self.c.unshift()
|
||||
|
||||
def add_arc(self, int head, int child, attr_t label):
|
||||
self.c.add_arc(head, child, label)
|
||||
|
||||
def del_arc(self, int head, int child):
|
||||
self.c.del_arc(head, child)
|
||||
|
||||
def open_ent(self, attr_t label):
|
||||
self.c.open_ent(label)
|
||||
|
||||
def close_ent(self):
|
||||
self.c.close_ent()
|
||||
|
||||
def clone(self, StateClass src):
|
||||
self.c.clone(src.c)
|
||||
|
|
|
@ -16,14 +16,14 @@ cdef struct Transition:
|
|||
weight_t score
|
||||
|
||||
bint (*is_valid)(const StateC* state, attr_t label) nogil
|
||||
weight_t (*get_cost)(StateClass state, const void* gold, attr_t label) nogil
|
||||
weight_t (*get_cost)(const StateC* state, const void* gold, attr_t label) nogil
|
||||
int (*do)(StateC* state, attr_t label) nogil
|
||||
|
||||
|
||||
ctypedef weight_t (*get_cost_func_t)(StateClass state, const void* gold,
|
||||
ctypedef weight_t (*get_cost_func_t)(const StateC* state, const void* gold,
|
||||
attr_tlabel) nogil
|
||||
ctypedef weight_t (*move_cost_func_t)(StateClass state, const void* gold) nogil
|
||||
ctypedef weight_t (*label_cost_func_t)(StateClass state, const void*
|
||||
ctypedef weight_t (*move_cost_func_t)(const StateC* state, const void* gold) nogil
|
||||
ctypedef weight_t (*label_cost_func_t)(const StateC* state, const void*
|
||||
gold, attr_t label) nogil
|
||||
|
||||
ctypedef int (*do_func_t)(StateC* state, attr_t label) nogil
|
||||
|
@ -41,9 +41,8 @@ cdef class TransitionSystem:
|
|||
cdef public attr_t root_label
|
||||
cdef public freqs
|
||||
cdef public object labels
|
||||
|
||||
cdef int initialize_state(self, StateC* state) nogil
|
||||
cdef int finalize_state(self, StateC* state) nogil
|
||||
cdef init_state_t init_beam_state
|
||||
cdef del_state_t del_beam_state
|
||||
|
||||
cdef Transition lookup_transition(self, object name) except *
|
||||
|
||||
|
@ -52,4 +51,4 @@ cdef class TransitionSystem:
|
|||
cdef int set_valid(self, int* output, const StateC* st) nogil
|
||||
|
||||
cdef int set_costs(self, int* is_valid, weight_t* costs,
|
||||
StateClass state, gold) except -1
|
||||
const StateC* state, gold) except -1
|
||||
|
|
|
@ -5,6 +5,7 @@ from cymem.cymem cimport Pool
|
|||
from collections import Counter
|
||||
import srsly
|
||||
|
||||
from . cimport _beam_utils
|
||||
from ...typedefs cimport weight_t, attr_t
|
||||
from ...tokens.doc cimport Doc
|
||||
from ...structs cimport TokenC
|
||||
|
@ -44,6 +45,8 @@ cdef class TransitionSystem:
|
|||
if labels_by_action:
|
||||
self.initialize_actions(labels_by_action, min_freq=min_freq)
|
||||
self.root_label = self.strings.add('ROOT')
|
||||
self.init_beam_state = _init_state
|
||||
self.del_beam_state = _del_state
|
||||
|
||||
def __reduce__(self):
|
||||
return (self.__class__, (self.strings, self.labels), None, None)
|
||||
|
@ -54,7 +57,6 @@ cdef class TransitionSystem:
|
|||
offset = 0
|
||||
for doc in docs:
|
||||
state = StateClass(doc, offset=offset)
|
||||
self.initialize_state(state.c)
|
||||
states.append(state)
|
||||
offset += len(doc)
|
||||
return states
|
||||
|
@ -80,7 +82,7 @@ cdef class TransitionSystem:
|
|||
history = []
|
||||
debug_log = []
|
||||
while not state.is_final():
|
||||
self.set_costs(is_valid, costs, state, gold)
|
||||
self.set_costs(is_valid, costs, state.c, gold)
|
||||
for i in range(self.n_moves):
|
||||
if is_valid[i] and costs[i] <= 0:
|
||||
action = self.c[i]
|
||||
|
@ -124,15 +126,6 @@ cdef class TransitionSystem:
|
|||
action = self.lookup_transition(name)
|
||||
action.do(state.c, action.label)
|
||||
|
||||
cdef int initialize_state(self, StateC* state) nogil:
|
||||
pass
|
||||
|
||||
cdef int finalize_state(self, StateC* state) nogil:
|
||||
pass
|
||||
|
||||
def finalize_doc(self, doc):
|
||||
pass
|
||||
|
||||
cdef Transition lookup_transition(self, object name) except *:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -151,7 +144,7 @@ cdef class TransitionSystem:
|
|||
is_valid[i] = self.c[i].is_valid(st, self.c[i].label)
|
||||
|
||||
cdef int set_costs(self, int* is_valid, weight_t* costs,
|
||||
StateClass stcls, gold) except -1:
|
||||
const StateC* state, gold) except -1:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_class_name(self, int clas):
|
||||
|
|
|
@ -105,6 +105,93 @@ def make_parser(
|
|||
update_with_oracle_cut_size=update_with_oracle_cut_size,
|
||||
multitasks=[],
|
||||
learn_tokens=learn_tokens,
|
||||
min_action_freq=min_action_freq,
|
||||
beam_width=1,
|
||||
beam_density=0.0,
|
||||
beam_update_prob=0.0,
|
||||
)
|
||||
|
||||
@Language.factory(
    "beam_parser",
    assigns=["token.dep", "token.head", "token.is_sent_start", "doc.sents"],
    default_config={
        "beam_width": 8,
        "beam_density": 0.01,
        "beam_update_prob": 0.5,
        "moves": None,
        "update_with_oracle_cut_size": 100,
        "learn_tokens": False,
        "min_action_freq": 30,
        "model": DEFAULT_PARSER_MODEL,
    },
    default_score_weights={
        "dep_uas": 0.5,
        "dep_las": 0.5,
        "dep_las_per_type": None,
        "sents_p": None,
        "sents_r": None,
        "sents_f": 0.0,
    },
)
def make_beam_parser(
    nlp: Language,
    name: str,
    model: Model,
    moves: Optional[list],
    update_with_oracle_cut_size: int,
    learn_tokens: bool,
    min_action_freq: int,
    beam_width: int,
    beam_density: float,
    beam_update_prob: float,
):
    """Build a beam-search variant of the transition-based DependencyParser.

    Sentence segmentation and labelled dependency parsing are learned
    jointly, and the component can optionally learn to merge tokens that the
    tokenizer over-segmented.

    The transition system is a variant of the non-monotonic arc-eager system
    of Honnibal and Johnson (2014), extended with a "break" transition that
    performs the sentence segmentation. Non-projective parses are predicted
    via Nivre's pseudo-projective dependency transformation.

    Training uses a global objective: the parser learns to assign
    probabilities to whole parses.

    nlp (Language): The pipeline object; its vocab is handed to the component.
    name (str): The component name, forwarded to DependencyParser.
    model (Model): The model for the transition-based parser. It needs a
        specific substructure of named components --- see
        spacy.ml.tb_framework.TransitionModel for details.
    moves (List[str]): A list of transition names. Inferred from the data if
        not provided.
    update_with_oracle_cut_size (int): Forwarded unchanged to
        DependencyParser.
    learn_tokens (bool): Whether to learn to merge subtokens that are split
        relative to the gold standard. Experimental.
    min_action_freq (int): The minimum frequency of labelled actions to
        retain. Rarer labelled actions have their label backed-off to "dep".
        This mainly affects label accuracy, but it can also affect the
        attachment structure, because the labels encode the
        pseudo-projectivity transformation.
    beam_width (int): The number of candidate analyses to maintain.
    beam_density (float): The minimum ratio between the scores of the first
        and last candidates in the beam. It lets the parser avoid exploring
        candidates that are too far behind. Mostly an efficiency measure,
        though it can also help accuracy since deeper search is not always
        better.
    beam_update_prob (float): The chance of making a beam update rather than
        a greedy update. Greedy updates approximate the beam updates and are
        faster to compute.
    """
    # Keyword-only settings are gathered first, then applied in a single call.
    parser_settings = dict(
        moves=moves,
        update_with_oracle_cut_size=update_with_oracle_cut_size,
        multitasks=[],
        learn_tokens=learn_tokens,
        min_action_freq=min_action_freq,
        beam_width=beam_width,
        beam_density=beam_density,
        beam_update_prob=beam_update_prob,
    )
    return DependencyParser(nlp.vocab, model, name, **parser_settings)
|
||||
|
||||
|
|
|
@ -82,6 +82,79 @@ def make_ner(
|
|||
multitasks=[],
|
||||
min_action_freq=1,
|
||||
learn_tokens=False,
|
||||
beam_width=1,
|
||||
beam_density=0.0,
|
||||
beam_update_prob=0.0,
|
||||
)
|
||||
|
||||
@Language.factory(
    "beam_ner",
    assigns=["doc.ents", "token.ent_iob", "token.ent_type"],
    default_config={
        "moves": None,
        "update_with_oracle_cut_size": 100,
        "model": DEFAULT_NER_MODEL,
        "beam_density": 0.01,
        "beam_update_prob": 0.5,
        "beam_width": 32
    },
    default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0, "ents_per_type": None},
)
def make_beam_ner(
    nlp: Language,
    name: str,
    model: Model,
    moves: Optional[list],
    update_with_oracle_cut_size: int,
    beam_width: int,
    beam_density: float,
    beam_update_prob: float,
):
    """Build a beam-search variant of the transition-based EntityRecognizer.

    The entity recognizer identifies non-overlapping labelled spans of tokens.

    The transition-based algorithm bakes in assumptions that suit
    "traditional" named entity recognition, but that may not fit every span
    identification problem. The loss optimizes for whole-entity accuracy, so
    low inter-annotator agreement on boundary tokens will likely hurt
    performance. The algorithm also assumes the most decisive information
    about an entity is close to its initial tokens; long entities that are
    characterised by tokens in their middle will likely be handled poorly.

    nlp (Language): The pipeline object; its vocab is handed to the component.
    name (str): The component name, forwarded to EntityRecognizer.
    model (Model): The model for the transition-based parser. It needs a
        specific substructure of named components --- see
        spacy.ml.tb_framework.TransitionModel for details.
    moves (list[str]): A list of transition names. Inferred from the data if
        not provided.
    update_with_oracle_cut_size (int): During training, long sequences are
        cut into shorter segments by creating intermediate states from the
        gold-standard history. The model is not very sensitive to this
        value, so it rarely needs changing. 100 is a good default.
    beam_width (int): The number of candidate analyses to maintain.
    beam_density (float): The minimum ratio between the scores of the first
        and last candidates in the beam. It lets the parser avoid exploring
        candidates that are too far behind. Mostly an efficiency measure,
        though it can also help accuracy since deeper search is not always
        better.
    beam_update_prob (float): The chance of making a beam update rather than
        a greedy update. Greedy updates approximate the beam updates and are
        faster to compute.
    """
    # min_action_freq and learn_tokens are fixed for this component rather
    # than exposed through the factory config.
    ner_settings = dict(
        moves=moves,
        update_with_oracle_cut_size=update_with_oracle_cut_size,
        multitasks=[],
        min_action_freq=1,
        learn_tokens=False,
        beam_width=beam_width,
        beam_density=beam_density,
        beam_update_prob=beam_update_prob,
    )
    return EntityRecognizer(nlp.vocab, model, name, **ner_settings)
|
||||
|
||||
|
||||
|
|
|
@ -4,13 +4,14 @@ from cymem.cymem cimport Pool
|
|||
cimport numpy as np
|
||||
from itertools import islice
|
||||
from libcpp.vector cimport vector
|
||||
from libc.string cimport memset
|
||||
from libc.string cimport memset, memcpy
|
||||
from libc.stdlib cimport calloc, free
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
import srsly
|
||||
from thinc.api import set_dropout_rate
|
||||
from thinc.api import set_dropout_rate, CupyOps
|
||||
from thinc.extra.search cimport Beam
|
||||
import numpy.random
|
||||
import numpy
|
||||
import warnings
|
||||
|
@ -22,6 +23,8 @@ from ..ml.parser_model cimport WeightsC, ActivationsC, SizesC, cpu_log_loss
|
|||
from ..ml.parser_model cimport get_c_weights, get_c_sizes
|
||||
from ..tokens.doc cimport Doc
|
||||
from .trainable_pipe import TrainablePipe
|
||||
from ._parser_internals cimport _beam_utils
|
||||
from ._parser_internals import _beam_utils
|
||||
|
||||
from ..training import validate_examples, validate_get_examples
|
||||
from ..errors import Errors, Warnings
|
||||
|
@ -41,9 +44,12 @@ cdef class Parser(TrainablePipe):
|
|||
moves=None,
|
||||
*,
|
||||
update_with_oracle_cut_size,
|
||||
multitasks=tuple(),
|
||||
min_action_freq,
|
||||
learn_tokens,
|
||||
beam_width=1,
|
||||
beam_density=0.0,
|
||||
beam_update_prob=0.0,
|
||||
multitasks=tuple(),
|
||||
):
|
||||
"""Create a Parser.
|
||||
|
||||
|
@ -61,7 +67,10 @@ cdef class Parser(TrainablePipe):
|
|||
"update_with_oracle_cut_size": update_with_oracle_cut_size,
|
||||
"multitasks": list(multitasks),
|
||||
"min_action_freq": min_action_freq,
|
||||
"learn_tokens": learn_tokens
|
||||
"learn_tokens": learn_tokens,
|
||||
"beam_width": beam_width,
|
||||
"beam_density": beam_density,
|
||||
"beam_update_prob": beam_update_prob
|
||||
}
|
||||
if moves is None:
|
||||
# defined by EntityRecognizer as a BiluoPushDown
|
||||
|
@ -183,7 +192,15 @@ cdef class Parser(TrainablePipe):
|
|||
result = self.moves.init_batch(docs)
|
||||
self._resize()
|
||||
return result
|
||||
if self.cfg["beam_width"] == 1:
|
||||
return self.greedy_parse(docs, drop=0.0)
|
||||
else:
|
||||
return self.beam_parse(
|
||||
docs,
|
||||
drop=0.0,
|
||||
beam_width=self.cfg["beam_width"],
|
||||
beam_density=self.cfg["beam_density"]
|
||||
)
|
||||
|
||||
def greedy_parse(self, docs, drop=0.):
|
||||
cdef vector[StateC*] states
|
||||
|
@ -207,6 +224,31 @@ cdef class Parser(TrainablePipe):
|
|||
del model
|
||||
return batch
|
||||
|
||||
def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
    """Parse a batch of documents with beam search instead of greedy decoding.

    docs: The documents to parse; passed to `self.moves.init_batch` and to
        `self.model.predict`.
    beam_width (int): Beam width handed to the BeamBatch.
    drop (float): Not read in this method.
    beam_density: Density value handed to the BeamBatch.
    RETURNS (list): The items produced by iterating the finished BeamBatch.
    """
    batch = _beam_utils.BeamBatch(
        self.moves,
        self.moves.init_batch(docs),
        None,
        beam_width,
        density=beam_density
    )
    # This is pretty dirty, but the NER can resize itself in init_batch,
    # if labels are missing. We therefore have to check whether we need to
    # expand our model output.
    self._resize()
    model = self.model.predict(docs)
    # Score the unfinished states and advance the beams until the batch
    # reports it is done or no unfinished states remain.
    while not batch.is_done:
        states = batch.get_unfinished_states()
        if not states:
            break
        scores = model.predict(states)
        batch.advance(scores)
    model.clear_memory()
    del model
    return list(batch)
|
||||
|
||||
cdef void _parseC(self, StateC** states,
|
||||
WeightsC weights, SizesC sizes) nogil:
|
||||
cdef int i, j
|
||||
|
@ -227,14 +269,13 @@ cdef class Parser(TrainablePipe):
|
|||
unfinished.clear()
|
||||
free_activations(&activations)
|
||||
|
||||
def set_annotations(self, docs, states):
|
||||
def set_annotations(self, docs, states_or_beams):
|
||||
cdef StateClass state
|
||||
cdef Beam beam
|
||||
cdef Doc doc
|
||||
states = _beam_utils.collect_states(states_or_beams, docs)
|
||||
for i, (state, doc) in enumerate(zip(states, docs)):
|
||||
self.moves.finalize_state(state.c)
|
||||
for j in range(doc.length):
|
||||
doc.c[j] = state.c._sent[j]
|
||||
self.moves.finalize_doc(doc)
|
||||
self.moves.set_annotations(state, doc)
|
||||
for hook in self.postprocesses:
|
||||
hook(doc)
|
||||
|
||||
|
@ -265,7 +306,6 @@ cdef class Parser(TrainablePipe):
|
|||
else:
|
||||
action = self.moves.c[guess]
|
||||
action.do(states[i], action.label)
|
||||
states[i].push_hist(guess)
|
||||
free(is_valid)
|
||||
|
||||
def update(self, examples, *, drop=0., set_annotations=False, sgd=None, losses=None):
|
||||
|
@ -276,13 +316,23 @@ cdef class Parser(TrainablePipe):
|
|||
validate_examples(examples, "Parser.update")
|
||||
for multitask in self._multitasks:
|
||||
multitask.update(examples, drop=drop, sgd=sgd)
|
||||
|
||||
n_examples = len([eg for eg in examples if self.moves.has_gold(eg)])
|
||||
if n_examples == 0:
|
||||
return losses
|
||||
set_dropout_rate(self.model, drop)
|
||||
# Prepare the stepwise model, and get the callback for finishing the batch
|
||||
model, backprop_tok2vec = self.model.begin_update(
|
||||
[eg.predicted for eg in examples])
|
||||
# The probability we use beam update, instead of falling back to
|
||||
# a greedy update
|
||||
beam_update_prob = self.cfg["beam_update_prob"]
|
||||
if self.cfg['beam_width'] >= 2 and numpy.random.random() < beam_update_prob:
|
||||
return self.update_beam(
|
||||
examples,
|
||||
beam_width=self.cfg["beam_width"],
|
||||
set_annotations=set_annotations,
|
||||
sgd=sgd,
|
||||
losses=losses,
|
||||
beam_density=self.cfg["beam_density"]
|
||||
)
|
||||
max_moves = self.cfg["update_with_oracle_cut_size"]
|
||||
if max_moves >= 1:
|
||||
# Chop sequences into lengths of this many words, to make the
|
||||
|
@ -296,6 +346,8 @@ cdef class Parser(TrainablePipe):
|
|||
states, golds, _ = self.moves.init_gold_batch(examples)
|
||||
if not states:
|
||||
return losses
|
||||
model, backprop_tok2vec = self.model.begin_update([eg.x for eg in examples])
|
||||
|
||||
all_states = list(states)
|
||||
states_golds = list(zip(states, golds))
|
||||
n_moves = 0
|
||||
|
@ -379,6 +431,27 @@ cdef class Parser(TrainablePipe):
|
|||
del tutor
|
||||
return losses
|
||||
|
||||
def update_beam(self, examples, *, beam_width,
        drop=0., sgd=None, losses=None, set_annotations=False, beam_density=0.0):
    """Run one beam-search training update over a batch of examples.

    examples: The training examples; each must expose `.predicted`.
    beam_width: Beam width forwarded to `_beam_utils.update_beam`.
    drop: Not read in this method.
    sgd: Optional optimizer. When given, `self.finish_update(sgd)` is called
        after backpropagation.
    losses (dict): Optional loss accumulator. A new dict is created when
        None, and an entry for `self.name` is ensured before it is used.
    set_annotations: Not read in this method.
    beam_density: Density value forwarded to `_beam_utils.update_beam`.
    RETURNS (dict): The losses dict, on every path.
    """
    # Callers may invoke this directly without a prepared accumulator.
    if losses is None:
        losses = {}
    losses.setdefault(self.name, 0.0)
    states, golds, _ = self.moves.init_gold_batch(examples)
    if not states:
        return losses
    # Prepare the stepwise model, and get the callback for finishing the batch
    model, backprop_tok2vec = self.model.begin_update(
        [eg.predicted for eg in examples])
    loss = _beam_utils.update_beam(
        self.moves,
        states,
        golds,
        model,
        beam_width,
        beam_density=beam_density,
    )
    losses[self.name] += loss
    backprop_tok2vec(golds)
    if sgd is not None:
        self.finish_update(sgd)
    # `update` returns this method's result directly, so the accumulator must
    # be handed back here as well as on the early exit above.
    return losses
|
||||
|
||||
def get_batch_loss(self, states, golds, float[:, ::1] scores, losses):
|
||||
cdef StateClass state
|
||||
cdef Pool mem = Pool()
|
||||
|
@ -396,7 +469,7 @@ cdef class Parser(TrainablePipe):
|
|||
for i, (state, gold) in enumerate(zip(states, golds)):
|
||||
memset(is_valid, 0, self.moves.n_moves * sizeof(int))
|
||||
memset(costs, 0, self.moves.n_moves * sizeof(float))
|
||||
self.moves.set_costs(is_valid, costs, state, gold)
|
||||
self.moves.set_costs(is_valid, costs, state.c, gold)
|
||||
for j in range(self.moves.n_moves):
|
||||
if costs[j] <= 0.0 and j in unseen_classes:
|
||||
unseen_classes.remove(j)
|
||||
|
@ -539,7 +612,6 @@ cdef class Parser(TrainablePipe):
|
|||
for clas in oracle_actions[i:i+max_length]:
|
||||
action = self.moves.c[clas]
|
||||
action.do(state.c, action.label)
|
||||
state.c.push_hist(action.clas)
|
||||
if state.is_final():
|
||||
break
|
||||
if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
|
||||
|
|
|
@ -7,6 +7,7 @@ from spacy.tokens import Doc
|
|||
from spacy.pipeline._parser_internals.nonproj import projectivize
|
||||
from spacy.pipeline._parser_internals.arc_eager import ArcEager
|
||||
from spacy.pipeline.dep_parser import DEFAULT_PARSER_MODEL
|
||||
from spacy.pipeline._parser_internals.stateclass import StateClass
|
||||
|
||||
|
||||
def get_sequence_costs(M, words, heads, deps, transitions):
|
||||
|
@ -47,15 +48,24 @@ def test_oracle_four_words(arc_eager, vocab):
|
|||
for dep in deps:
|
||||
arc_eager.add_action(2, dep) # Left
|
||||
arc_eager.add_action(3, dep) # Right
|
||||
actions = ["L-left", "B-ROOT", "L-left"]
|
||||
actions = ["S", "L-left", "B-ROOT", "S", "D", "S", "L-left", "S", "D"]
|
||||
state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions)
|
||||
expected_gold = [
|
||||
["S"],
|
||||
["B-ROOT", "L-left"],
|
||||
["B-ROOT"],
|
||||
["S"],
|
||||
["D"],
|
||||
["S"],
|
||||
["L-left"],
|
||||
["S"],
|
||||
["D"]
|
||||
]
|
||||
assert state.is_final()
|
||||
for i, state_costs in enumerate(cost_history):
|
||||
# Check gold moves is 0 cost
|
||||
assert state_costs[actions[i]] == 0.0, actions[i]
|
||||
for other_action, cost in state_costs.items():
|
||||
if other_action != actions[i]:
|
||||
assert cost >= 1, (i, other_action)
|
||||
golds = [act for act, cost in state_costs.items() if cost < 1]
|
||||
assert golds == expected_gold[i], (i, golds, expected_gold[i])
|
||||
|
||||
|
||||
annot_tuples = [
|
||||
|
@ -169,12 +179,15 @@ def test_oracle_dev_sentence(vocab, arc_eager):
|
|||
. punct said
|
||||
"""
|
||||
expected_transitions = [
|
||||
"S", # Shift "Rolls-Royce"
|
||||
"S", # Shift 'Motor'
|
||||
"S", # Shift 'Cars'
|
||||
"L-nn", # Attach 'Cars' to 'Inc.'
|
||||
"L-nn", # Attach 'Motor' to 'Inc.'
|
||||
"L-nn", # Attach 'Rolls-Royce' to 'Inc.', force shift
|
||||
"L-nn", # Attach 'Rolls-Royce' to 'Inc.'
|
||||
"S", # Shift "Inc."
|
||||
"L-nsubj", # Attach 'Inc.' to 'said'
|
||||
"S", # Shift 'said'
|
||||
"S", # Shift 'it'
|
||||
"L-nsubj", # Attach 'it.' to 'expects'
|
||||
"R-ccomp", # Attach 'expects' to 'said'
|
||||
|
@ -204,6 +217,8 @@ def test_oracle_dev_sentence(vocab, arc_eager):
|
|||
"D", # Reduce "steady"
|
||||
"D", # Reduce "expects"
|
||||
"R-punct", # Attach "." to "said"
|
||||
"D", # Reduce "."
|
||||
"D", # Reduce "said"
|
||||
]
|
||||
|
||||
gold_words = []
|
||||
|
@ -221,10 +236,40 @@ def test_oracle_dev_sentence(vocab, arc_eager):
|
|||
for dep in gold_deps:
|
||||
arc_eager.add_action(2, dep) # Left
|
||||
arc_eager.add_action(3, dep) # Right
|
||||
|
||||
doc = Doc(Vocab(), words=gold_words)
|
||||
example = Example.from_dict(doc, {"heads": gold_heads, "deps": gold_deps})
|
||||
|
||||
ae_oracle_actions = arc_eager.get_oracle_sequence(example)
|
||||
ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False)
|
||||
ae_oracle_actions = [arc_eager.get_class_name(i) for i in ae_oracle_actions]
|
||||
assert ae_oracle_actions == expected_transitions
|
||||
|
||||
|
||||
def test_oracle_bad_tokenization(vocab, arc_eager):
    """The oracle must still produce actions when the predicted tokens differ
    from the reference tokens ("[catalase]" vs "[", "catalase", "]")."""
    words_deps_heads = """
    [catalase] dep is
    : punct is
    that nsubj is
    is root is
    bad comp is
    """
    # One (word, dep, head-word) row per non-empty line of the table.
    rows = [
        stripped.split()
        for stripped in (raw.strip() for raw in words_deps_heads.strip().split("\n"))
        if stripped
    ]
    gold_words = [word for word, _, _ in rows]
    gold_deps = [dep for _, dep, _ in rows]
    # Heads are given as words; resolve them to token indices.
    gold_heads = [gold_words.index(head) for _, _, head in rows]
    for dep in gold_deps:
        arc_eager.add_action(2, dep)  # Left
        arc_eager.add_action(3, dep)  # Right
    reference = Doc(Vocab(), words=gold_words, deps=gold_deps, heads=gold_heads)
    predicted = Doc(reference.vocab, words=["[", "catalase", "]", ":", "that", "is", "bad"])
    example = Example(predicted=predicted, reference=reference)
    action_ids = arc_eager.get_oracle_sequence(example, _debug=False)
    ae_oracle_actions = [arc_eager.get_class_name(i) for i in action_ids]
    assert ae_oracle_actions
|
||||
|
|
|
@ -54,7 +54,7 @@ def tsys(vocab, entity_types):
|
|||
|
||||
def test_get_oracle_moves(tsys, doc, entity_annots):
|
||||
example = Example.from_dict(doc, {"entities": entity_annots})
|
||||
act_classes = tsys.get_oracle_sequence(example)
|
||||
act_classes = tsys.get_oracle_sequence(example, _debug=False)
|
||||
names = [tsys.get_class_name(act) for act in act_classes]
|
||||
assert names == ["U-PERSON", "O", "O", "B-GPE", "L-GPE", "O"]
|
||||
|
||||
|
|
|
@ -0,0 +1,144 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import pytest
|
||||
import hypothesis
|
||||
import hypothesis.strategies
|
||||
import numpy
|
||||
from spacy.vocab import Vocab
|
||||
from spacy.language import Language
|
||||
from spacy.pipeline import DependencyParser
|
||||
from spacy.pipeline._parser_internals.arc_eager import ArcEager
|
||||
from spacy.tokens import Doc
|
||||
from spacy.pipeline._parser_internals._beam_utils import BeamBatch
|
||||
from spacy.pipeline._parser_internals.stateclass import StateClass
|
||||
from spacy.training import Example
|
||||
from thinc.tests.strategies import ndarrays_of_shape
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def vocab():
    """Module-scoped fixture returning a fresh, empty Vocab."""
    shared_vocab = Vocab()
    return shared_vocab
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def moves(vocab):
    """Module-scoped ArcEager transition system with a fixed set of actions."""
    aeager = ArcEager(vocab.strings, {})
    # (action id, label) pairs, registered in this exact order.
    # NOTE(review): (2, "aux") is registered twice — confirm the repeat is intended.
    registrations = [
        (0, ""),
        (1, ""),
        (2, "nsubj"),
        (2, "punct"),
        (2, "aux"),
        (2, "nsubjpass"),
        (3, "dobj"),
        (2, "aux"),
        (4, "ROOT"),
    ]
    for action, label in registrations:
        aeager.add_action(action, label)
    return aeager
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def docs(vocab):
    """Module-scoped fixture: a single annotated three-word Doc in a list."""
    rats_doc = Doc(
        vocab,
        words=["Rats", "bite", "things"],
        heads=[1, 1, 1],
        deps=["nsubj", "ROOT", "dobj"],
        sent_starts=[True, False, False],
    )
    return [rats_doc]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def examples(docs):
    """Module-scoped fixture pairing each Doc with a copy of itself in an Example."""
    paired = []
    for doc in docs:
        paired.append(Example(doc, doc.copy()))
    return paired
|
||||
|
||||
|
||||
@pytest.fixture
def states(docs):
    """One StateClass per Doc, rebuilt for every test."""
    return list(map(StateClass, docs))
|
||||
|
||||
|
||||
@pytest.fixture
def tokvecs(docs, vector_size):
    """Random (len(doc), vector_size) arrays, one per Doc, drawn from [-0.1, 0.1)."""
    return [
        numpy.asarray(numpy.random.uniform(-0.1, 0.1, (len(doc), vector_size)))
        for doc in docs
    ]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def batch_size(docs):
    """Number of Docs supplied by the `docs` fixture."""
    n_docs = len(docs)
    return n_docs
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def beam_width():
    """Beam width used throughout this module's tests."""
    width = 4
    return width
|
||||
|
||||
@pytest.fixture(params=[0.0, 0.5, 1.0])
def beam_density(request):
    """Parametrized density value: one test run per entry in `params`."""
    density = request.param
    return density
|
||||
|
||||
@pytest.fixture
def vector_size():
    """Width of the random token vectors built by `tokvecs`."""
    size = 6
    return size
|
||||
|
||||
|
||||
@pytest.fixture
def beam(moves, examples, beam_width):
    """A BeamBatch over the gold batch for `examples`, with density 0.0."""
    states, golds, _unused = moves.init_gold_batch(examples)
    batch = BeamBatch(moves, states, golds, width=beam_width, density=0.0)
    return batch
|
||||
|
||||
|
||||
@pytest.fixture
def scores(moves, batch_size, beam_width):
    """Random float32 scores: `beam_width` rows per document, one column per move."""
    per_doc = []
    for _ in range(batch_size):
        per_doc.append(numpy.random.uniform(-0.1, 0.1, (beam_width, moves.n_moves)))
    stacked = numpy.concatenate(per_doc)
    return numpy.asarray(stacked, dtype="float32")
|
||||
|
||||
|
||||
def test_create_beam(beam):
    """Passes as long as the `beam` fixture can be constructed."""
|
||||
|
||||
|
||||
def test_beam_advance(beam, scores):
    """Advancing the beam batch with a full score matrix must not raise."""
    beam.advance(scores)
|
||||
|
||||
|
||||
def test_beam_advance_too_few_scores(beam, scores):
    """Advancing with one score row fewer than there are states raises IndexError."""
    # Total number of states across every beam in the batch.
    n_state = sum(len(single_beam) for single_beam in beam)
    exact_scores = scores[:n_state]
    with pytest.raises(IndexError):
        beam.advance(exact_scores[:-1])
|
||||
|
||||
|
||||
def test_beam_parse(examples, beam_width):
    """Smoke test: an initialized "beam_parser" pipe can process a new Doc."""
    nlp = Language()
    parser = nlp.add_pipe("beam_parser")
    # Override the configured width with the module's beam_width fixture.
    parser.cfg["beam_width"] = beam_width
    parser.add_label("nsubj")
    parser.initialize(lambda: examples)
    unseen_doc = nlp.make_doc("Australia is a country")
    parser(unseen_doc)
|
||||
|
||||
|
||||
|
||||
|
||||
@hypothesis.given(hyp=hypothesis.strategies.data())
def test_beam_density(moves, examples, beam_width, hyp):
    """After one advance, each beam keeps the drawn density and its last
    probability is at least the first probability times that density."""
    density_strategy = hypothesis.strategies.floats(0.0, 1.0, width=32)
    beam_density = float(hyp.draw(density_strategy))
    states, golds, _ = moves.init_gold_batch(examples)
    beam = BeamBatch(moves, states, golds, width=beam_width, density=beam_density)
    # Total number of states across every beam in the batch.
    n_state = sum(len(single_beam) for single_beam in beam)
    scores = hyp.draw(ndarrays_of_shape((n_state, moves.n_moves)))
    beam.advance(scores)
    for b in beam:
        beam_probs = b.probs
        assert b.min_density == beam_density
        assert beam_probs[-1] >= beam_probs[0] * beam_density
|
|
@ -22,6 +22,7 @@ def _parser_example(parser):
|
|||
|
||||
@pytest.fixture
|
||||
def parser(vocab):
|
||||
vocab.strings.add("ROOT")
|
||||
config = {
|
||||
"learn_tokens": False,
|
||||
"min_action_freq": 30,
|
||||
|
@ -76,13 +77,16 @@ def test_sents_1_2(parser):
|
|||
|
||||
def test_sents_1_3(parser):
|
||||
doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
|
||||
doc[1].sent_start = True
|
||||
doc[3].sent_start = True
|
||||
doc[0].is_sent_start = True
|
||||
doc[1].is_sent_start = True
|
||||
doc[2].is_sent_start = None
|
||||
doc[3].is_sent_start = True
|
||||
doc = parser(doc)
|
||||
assert len(list(doc.sents)) >= 3
|
||||
doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
|
||||
doc[1].sent_start = True
|
||||
doc[2].sent_start = False
|
||||
doc[3].sent_start = True
|
||||
doc[0].is_sent_start = True
|
||||
doc[1].is_sent_start = True
|
||||
doc[2].is_sent_start = False
|
||||
doc[3].is_sent_start = True
|
||||
doc = parser(doc)
|
||||
assert len(list(doc.sents)) == 3
|
||||
|
|
74
spacy/tests/parser/test_state.py
Normal file
74
spacy/tests/parser/test_state.py
Normal file
|
@ -0,0 +1,74 @@
|
|||
import pytest
|
||||
|
||||
from spacy.tokens.doc import Doc
|
||||
from spacy.vocab import Vocab
|
||||
from spacy.pipeline._parser_internals.stateclass import StateClass
|
||||
|
||||
@pytest.fixture
def vocab():
    """Fresh, empty Vocab for each test."""
    fresh_vocab = Vocab()
    return fresh_vocab
|
||||
|
||||
@pytest.fixture
def doc(vocab):
    """Four-token Doc ("a", "b", "c", "d") with no annotations."""
    words = ["a", "b", "c", "d"]
    return Doc(vocab, words=words)
|
||||
|
||||
def test_init_state(doc):
    """A new state has an empty stack, every token in the queue, and is not final."""
    state = StateClass(doc)
    every_token = list(range(len(doc)))
    assert state.stack == []
    assert state.queue == every_token
    assert not state.is_final()
    assert state.buffer_length() == 4
|
||||
|
||||
def test_push_pop(doc):
    """Check stack, queue and buffer length across two pushes and one pop."""
    state = StateClass(doc)

    # First push: token 0 leaves the queue and lands on the stack.
    state.push()
    assert state.buffer_length() == 3
    assert state.stack == [0]
    assert 0 not in state.queue

    # Second push: the stack lists the most recently pushed token first.
    state.push()
    assert state.stack == [1, 0]
    assert 1 not in state.queue
    assert state.buffer_length() == 2

    # Pop: token 1 leaves the stack and does not reappear in the queue.
    state.pop()
    assert state.stack == [0]
    assert 1 not in state.queue
|
||||
|
||||
def test_stack_depth(doc):
    """Stack depth starts at 0 and grows by one per push as the buffer shrinks."""
    state = StateClass(doc)
    assert state.stack_depth() == 0
    assert state.buffer_length() == len(doc)

    state.push()
    assert state.buffer_length() == 3
    assert state.stack_depth() == 1
|
||||
|
||||
|
||||
def test_H(doc):
    """H(i) is -1 before token i has an arc and the head index once one is added."""
    state = StateClass(doc)
    assert state.H(0) == -1

    # add_arc takes (head, child, label), as the recorded arc dict shows.
    state.add_arc(1, 0, 0)
    expected_arcs = [{"head": 1, "child": 0, "label": 0}]
    assert state.arcs == expected_arcs
    assert state.H(0) == 1

    state.add_arc(3, 1, 0)
    assert state.H(1) == 3
|
||||
|
||||
|
||||
def test_L(doc):
    """Check L(2, 1) and n_L(2) as left children are attached to token 2."""
    state = StateClass(doc)
    assert state.L(2, 1) == -1

    state.add_arc(2, 1, 0)
    expected_arcs = [{"head": 2, "child": 1, "label": 0}]
    assert state.arcs == expected_arcs
    assert state.L(2, 1) == 1

    # Attaching token 0 as well changes what L(2, 1) reports.
    state.add_arc(2, 0, 0)
    assert state.L(2, 1) == 0
    assert state.n_L(2) == 2
|
||||
|
||||
|
||||
def test_R(doc):
    """Check R(0, 1) and n_R(0) as right children are attached to token 0."""
    state = StateClass(doc)
    assert state.R(0, 1) == -1

    state.add_arc(0, 1, 0)
    expected_arcs = [{"head": 0, "child": 1, "label": 0}]
    assert state.arcs == expected_arcs
    assert state.R(0, 1) == 1

    # Attaching token 2 as well changes what R(0, 1) reports.
    state.add_arc(0, 2, 0)
    assert state.R(0, 1) == 2
    assert state.n_R(0) == 2
|
|
@ -122,7 +122,8 @@ def test_issue4042_bug2():
|
|||
assert "SOME_LABEL" in ner1.labels
|
||||
apple_ent = Span(doc1, 5, 6, label="MY_ORG")
|
||||
doc1.ents = list(doc1.ents) + [apple_ent]
|
||||
# reapply the NER - at this point it should resize itself
|
||||
# Add the label explicitly. Previously we didn't require this.
|
||||
ner1.add_label("MY_ORG")
|
||||
ner1(doc1)
|
||||
assert len(ner1.labels) == 2
|
||||
assert "SOME_LABEL" in ner1.labels
|
||||
|
|
|
@ -22,6 +22,9 @@ def parser(en_vocab):
|
|||
"learn_tokens": False,
|
||||
"min_action_freq": 30,
|
||||
"update_with_oracle_cut_size": 100,
|
||||
"beam_width": 1,
|
||||
"beam_update_prob": 1.0,
|
||||
"beam_density": 0.0
|
||||
}
|
||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||
model = registry.resolve(cfg, validate=True)["model"]
|
||||
|
@ -36,6 +39,9 @@ def blank_parser(en_vocab):
|
|||
"learn_tokens": False,
|
||||
"min_action_freq": 30,
|
||||
"update_with_oracle_cut_size": 100,
|
||||
"beam_width": 1,
|
||||
"beam_update_prob": 1.0,
|
||||
"beam_density": 0.0
|
||||
}
|
||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||
model = registry.resolve(cfg, validate=True)["model"]
|
||||
|
@ -58,6 +64,9 @@ def test_serialize_parser_roundtrip_bytes(en_vocab, Parser):
|
|||
"learn_tokens": False,
|
||||
"min_action_freq": 0,
|
||||
"update_with_oracle_cut_size": 100,
|
||||
"beam_width": 1,
|
||||
"beam_update_prob": 1.0,
|
||||
"beam_density": 0.0
|
||||
}
|
||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||
model = registry.resolve(cfg, validate=True)["model"]
|
||||
|
@ -79,6 +88,9 @@ def test_serialize_parser_strings(Parser):
|
|||
"learn_tokens": False,
|
||||
"min_action_freq": 0,
|
||||
"update_with_oracle_cut_size": 100,
|
||||
"beam_width": 1,
|
||||
"beam_update_prob": 1.0,
|
||||
"beam_density": 0.0
|
||||
}
|
||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||
model = registry.resolve(cfg, validate=True)["model"]
|
||||
|
@ -98,6 +110,9 @@ def test_serialize_parser_roundtrip_disk(en_vocab, Parser):
|
|||
"learn_tokens": False,
|
||||
"min_action_freq": 0,
|
||||
"update_with_oracle_cut_size": 100,
|
||||
"beam_width": 1,
|
||||
"beam_update_prob": 1.0,
|
||||
"beam_density": 0.0
|
||||
}
|
||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||
model = registry.resolve(cfg, validate=True)["model"]
|
||||
|
|
|
@ -191,6 +191,24 @@ cdef class Example:
|
|||
aligned_deps[cand_i] = deps[gold_i]
|
||||
return aligned_heads, aligned_deps
|
||||
|
||||
def get_aligned_sent_starts(self):
    """Get list of SENT_START attributes aligned to the predicted tokenization.
    If the reference has no sentence starts, return a list of None values.

    Only the first token of each reference sentence is projected. The first
    predicted-token index that `alignment.y2x` lists for that token is
    marked True; every other predicted token is False.

    RETURNS (list): One entry per predicted token: True/False when the
        reference has SENT_START annotation, otherwise None.
    """
    n_predicted = len(self.x)
    if not self.y.has_annotation("SENT_START"):
        return [None] * n_predicted
    align = self.alignment.y2x
    sent_starts = [False] * n_predicted
    for y_sent in self.y.sents:
        # NOTE(review): assumes each sentence-initial reference token aligns
        # to at least one predicted token; an empty alignment row would
        # presumably fail on the [0] lookup — confirm.
        x_start = int(align[y_sent.start].dataXd[0])
        sent_starts[x_start] = True
    return sent_starts
|
||||
|
||||
def get_aligned_spans_x2y(self, x_spans):
|
||||
return self._get_aligned_spans(self.y, x_spans, self.alignment.x2y)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user