Add beam_parser and beam_ner components for v3 (#6369)

* Get basic beam tests working

* Get basic beam tests working

* Compile _beam_utils

* Remove prints

* Test beam density

* Beam parser seems to train

* Draft beam NER

* Upd beam

* Add hypothesis as dev dependency

* Implement missing is-gold-parse method

* Implement early update

* Fix state hashing

* Fix test

* Fix test

* Default to non-beam in parser constructor

* Improve oracle for beam

* Start refactoring beam

* Update test

* Refactor beam

* Update nn

* Refactor beam and weight by cost

* Update ner beam settings

* Update test

* Add __init__.pxd

* Upd test

* Fix test

* Upd test

* Fix test

* Remove ring buffer history from StateC

* WIP change arc-eager transitions

* Add state tests

* Support ternary sent start values

* Fix arc eager

* Fix NER

* Pass oracle cut size for beam

* Fix ner test

* Fix beam

* Improve StateC.clone

* Improve StateClass.borrow

* Work directly with StateC, not StateClass

* Remove print statements

* Fix state copy

* Improve state class

* Refactor parser oracles

* Fix arc eager oracle

* Fix arc eager oracle

* Use a vector to implement the stack

* Refactor state data structure

* Fix alignment of sent start

* Add get_aligned_sent_starts method

* Add test for ae oracle when bad sentence starts

* Fix sentence segment handling

* Avoid Reduce that inserts illegal sentence

* Update preset SBD test

* Fix test

* Remove prints

* Fix sent starts in Example

* Improve python API of StateClass

* Tweak comments and debug output of arc eager

* Upd test

* Fix state test

* Fix state test
This commit is contained in:
Matthew Honnibal 2020-12-13 12:08:32 +11:00 committed by GitHub
parent 85ca8c2bdd
commit 8656a08777
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 1570 additions and 816 deletions

View File

@ -27,3 +27,4 @@ pytest>=4.6.5
pytest-timeout>=1.3.0,<2.0.0 pytest-timeout>=1.3.0,<2.0.0
mock>=2.0.0,<3.0.0 mock>=2.0.0,<3.0.0
flake8>=3.5.0,<3.6.0 flake8>=3.5.0,<3.6.0
hypothesis

View File

@ -48,6 +48,7 @@ MOD_NAMES = [
"spacy.pipeline._parser_internals._state", "spacy.pipeline._parser_internals._state",
"spacy.pipeline._parser_internals.stateclass", "spacy.pipeline._parser_internals.stateclass",
"spacy.pipeline._parser_internals.transition_system", "spacy.pipeline._parser_internals.transition_system",
"spacy.pipeline._parser_internals._beam_utils",
"spacy.tokenizer", "spacy.tokenizer",
"spacy.training.align", "spacy.training.align",
"spacy.training.gold_io", "spacy.training.gold_io",

View File

@ -0,0 +1,6 @@
from ...typedefs cimport class_t, hash_t
# These are passed as callbacks to thinc.search.Beam
cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1
cdef int check_final_state(void* _state, void* extra_args) except -1

View File

@ -0,0 +1,296 @@
# cython: infer_types=True
# cython: profile=True
cimport numpy as np
import numpy
from cpython.ref cimport PyObject, Py_XDECREF
from thinc.extra.search cimport Beam
from thinc.extra.search import MaxViolation
from thinc.extra.search cimport MaxViolation
from ...typedefs cimport hash_t, class_t
from .transition_system cimport TransitionSystem, Transition
from ...errors import Errors
from .stateclass cimport StateC, StateClass
# These are passed as callbacks to thinc.search.Beam
cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
dest = <StateC*>_dest
src = <StateC*>_src
moves = <const Transition*>_moves
dest.clone(src)
moves[clas].do(dest, moves[clas].label)
cdef int check_final_state(void* _state, void* extra_args) except -1:
state = <StateC*>_state
return state.is_final()
cdef class BeamBatch(object):
cdef public TransitionSystem moves
cdef public object states
cdef public object docs
cdef public object golds
cdef public object beams
def __init__(self, TransitionSystem moves, states, golds,
int width, float density=0.):
cdef StateClass state
self.moves = moves
self.states = states
self.docs = [state.doc for state in states]
self.golds = golds
self.beams = []
cdef Beam beam
cdef StateC* st
for state in states:
beam = Beam(self.moves.n_moves, width, min_density=density)
beam.initialize(self.moves.init_beam_state,
self.moves.del_beam_state, state.c.length,
<void*>state.c._sent)
for i in range(beam.width):
st = <StateC*>beam.at(i)
st.offset = state.c.offset
beam.check_done(check_final_state, NULL)
self.beams.append(beam)
@property
def is_done(self):
return all(b.is_done for b in self.beams)
def __getitem__(self, i):
return self.beams[i]
def __len__(self):
return len(self.beams)
def get_states(self):
cdef Beam beam
cdef StateC* state
cdef StateClass stcls
states = []
for beam, doc in zip(self, self.docs):
for i in range(beam.size):
state = <StateC*>beam.at(i)
stcls = StateClass.borrow(state, doc)
states.append(stcls)
return states
def get_unfinished_states(self):
return [st for st in self.get_states() if not st.is_final()]
def advance(self, float[:, ::1] scores, follow_gold=False):
cdef Beam beam
cdef int nr_class = scores.shape[1]
cdef const float* c_scores = &scores[0, 0]
docs = self.docs
for i, beam in enumerate(self):
if not beam.is_done:
nr_state = self._set_scores(beam, c_scores, nr_class)
assert nr_state
if self.golds is not None:
self._set_costs(
beam,
docs[i],
self.golds[i],
follow_gold=follow_gold
)
c_scores += nr_state * nr_class
beam.advance(transition_state, NULL, <void*>self.moves.c)
beam.check_done(check_final_state, NULL)
cdef int _set_scores(self, Beam beam, const float* scores, int nr_class) except -1:
cdef int nr_state = 0
for i in range(beam.size):
state = <StateC*>beam.at(i)
if not state.is_final():
for j in range(nr_class):
beam.scores[i][j] = scores[nr_state * nr_class + j]
self.moves.set_valid(beam.is_valid[i], state)
nr_state += 1
else:
for j in range(beam.nr_class):
beam.scores[i][j] = 0
beam.costs[i][j] = 0
return nr_state
def _set_costs(self, Beam beam, doc, gold, int follow_gold=False):
cdef const StateC* state
for i in range(beam.size):
state = <const StateC*>beam.at(i)
if state.is_final():
for j in range(beam.nr_class):
beam.is_valid[i][j] = 0
beam.costs[i][j] = 9000
else:
self.moves.set_costs(beam.is_valid[i], beam.costs[i],
state, gold)
if follow_gold:
min_cost = 0
for j in range(beam.nr_class):
if beam.is_valid[i][j] and beam.costs[i][j] < min_cost:
min_cost = beam.costs[i][j]
for j in range(beam.nr_class):
if beam.costs[i][j] > min_cost:
beam.is_valid[i][j] = 0
def update_beam(TransitionSystem moves, states, golds, model, int width, beam_density=0.0):
cdef MaxViolation violn
pbeam = BeamBatch(moves, states, golds, width=width, density=beam_density)
gbeam = BeamBatch(moves, states, golds, width=width, density=0.0)
cdef StateClass state
beam_maps = []
backprops = []
violns = [MaxViolation() for _ in range(len(states))]
dones = [False for _ in states]
while not pbeam.is_done or not gbeam.is_done:
# The beam maps let us find the right row in the flattened scores
# array for each state. States are identified by (example id,
# history). We keep a different beam map for each step (since we'll
# have a flat scores array for each step). The beam map will let us
# take the per-state losses, and compute the gradient for each (step,
# state, class).
# Gather all states from the two beams in a list. Some stats may occur
# in both beams. To figure out which beam each state belonged to,
# we keep two lists of indices, p_indices and g_indices
states, p_indices, g_indices, beam_map = get_unique_states(pbeam, gbeam)
beam_maps.append(beam_map)
if not states:
break
# Now that we have our flat list of states, feed them through the model
scores, bp_scores = model.begin_update(states)
assert scores.size != 0
# Store the callbacks for the backward pass
backprops.append(bp_scores)
# Unpack the scores for the two beams. The indices arrays
# tell us which example and state the scores-row refers to.
# Now advance the states in the beams. The gold beam is constrained to
# to follow only gold analyses.
if not pbeam.is_done:
pbeam.advance(model.ops.as_contig(scores[p_indices]))
if not gbeam.is_done:
gbeam.advance(model.ops.as_contig(scores[g_indices]), follow_gold=True)
# Track the "maximum violation", to use in the update.
for i, violn in enumerate(violns):
if not dones[i]:
violn.check_crf(pbeam[i], gbeam[i])
if pbeam[i].is_done and gbeam[i].is_done:
dones[i] = True
histories = []
grads = []
for violn in violns:
if violn.p_hist:
histories.append(violn.p_hist + violn.g_hist)
d_loss = [d_l * violn.cost for d_l in violn.p_probs + violn.g_probs]
grads.append(d_loss)
else:
histories.append([])
grads.append([])
loss = 0.0
states_d_scores = get_gradient(moves.n_moves, beam_maps, histories, grads)
for i, (d_scores, bp_scores) in enumerate(zip(states_d_scores, backprops)):
loss += (d_scores**2).mean()
bp_scores(d_scores)
return loss
def collect_states(beams, docs):
cdef StateClass state
cdef Beam beam
states = []
for state_or_beam, doc in zip(beams, docs):
if isinstance(state_or_beam, StateClass):
states.append(state_or_beam)
else:
beam = state_or_beam
state = StateClass.borrow(<StateC*>beam.at(0), doc)
states.append(state)
return states
def get_unique_states(pbeams, gbeams):
seen = {}
states = []
p_indices = []
g_indices = []
beam_map = {}
docs = pbeams.docs
cdef Beam pbeam, gbeam
if len(pbeams) != len(gbeams):
raise ValueError(Errors.E079.format(pbeams=len(pbeams), gbeams=len(gbeams)))
for eg_id, (pbeam, gbeam, doc) in enumerate(zip(pbeams, gbeams, docs)):
if not pbeam.is_done:
for i in range(pbeam.size):
state = StateClass.borrow(<StateC*>pbeam.at(i), doc)
if not state.is_final():
key = tuple([eg_id] + pbeam.histories[i])
if key in seen:
raise ValueError(Errors.E080.format(key=key))
seen[key] = len(states)
p_indices.append(len(states))
states.append(state)
beam_map.update(seen)
if not gbeam.is_done:
for i in range(gbeam.size):
state = StateClass.borrow(<StateC*>gbeam.at(i), doc)
if not state.is_final():
key = tuple([eg_id] + gbeam.histories[i])
if key in seen:
g_indices.append(seen[key])
else:
g_indices.append(len(states))
beam_map[key] = len(states)
states.append(state)
p_indices = numpy.asarray(p_indices, dtype='i')
g_indices = numpy.asarray(g_indices, dtype='i')
return states, p_indices, g_indices, beam_map
def get_gradient(nr_class, beam_maps, histories, losses):
"""The global model assigns a loss to each parse. The beam scores
are additive, so the same gradient is applied to each action
in the history. This gives the gradient of a single *action*
for a beam state -- so we have "the gradient of loss for taking
action i given history H."
Histories: Each hitory is a list of actions
Each candidate has a history
Each beam has multiple candidates
Each batch has multiple beams
So history is list of lists of lists of ints
"""
grads = []
nr_steps = []
for eg_id, hists in enumerate(histories):
nr_step = 0
for loss, hist in zip(losses[eg_id], hists):
assert not numpy.isnan(loss)
if loss != 0.0:
nr_step = max(nr_step, len(hist))
nr_steps.append(nr_step)
for i in range(max(nr_steps)):
grads.append(numpy.zeros((max(beam_maps[i].values())+1, nr_class),
dtype='f'))
if len(histories) != len(losses):
raise ValueError(Errors.E081.format(n_hist=len(histories), losses=len(losses)))
for eg_id, hists in enumerate(histories):
for loss, hist in zip(losses[eg_id], hists):
assert not numpy.isnan(loss)
if loss == 0.0:
continue
key = tuple([eg_id])
# Adjust loss for length
# We need to do this because each state in a short path is scored
# multiple times, as we add in the average cost when we run out
# of actions.
avg_loss = loss / len(hist)
loss += avg_loss * (nr_steps[eg_id] - len(hist))
for step, clas in enumerate(hist):
i = beam_maps[step][key]
# In step j, at state i action clas
# resulted in loss
grads[step][i, clas] += loss
key = key + tuple([clas])
return grads

View File

@ -1,6 +1,9 @@
from libc.string cimport memcpy, memset from libc.string cimport memcpy, memset
from libc.stdlib cimport calloc, free from libc.stdlib cimport calloc, free
from libc.stdint cimport uint32_t, uint64_t from libc.stdint cimport uint32_t, uint64_t
cimport libcpp
from libcpp.vector cimport vector
from libcpp.set cimport set
from cpython.exc cimport PyErr_CheckSignals, PyErr_SetFromErrno from cpython.exc cimport PyErr_CheckSignals, PyErr_SetFromErrno
from murmurhash.mrmr cimport hash64 from murmurhash.mrmr cimport hash64
@ -14,89 +17,48 @@ from ...typedefs cimport attr_t
cdef inline bint is_space_token(const TokenC* token) nogil: cdef inline bint is_space_token(const TokenC* token) nogil:
return Lexeme.c_check_flag(token.lex, IS_SPACE) return Lexeme.c_check_flag(token.lex, IS_SPACE)
cdef struct RingBufferC: cdef struct ArcC:
int[8] data int head
int i int child
int default attr_t label
cdef inline int ring_push(RingBufferC* ring, int value) nogil:
ring.data[ring.i] = value
ring.i += 1
if ring.i >= 8:
ring.i = 0
cdef inline int ring_get(RingBufferC* ring, int i) nogil:
if i >= ring.i:
return ring.default
else:
return ring.data[ring.i-i]
cdef cppclass StateC: cdef cppclass StateC:
int* _stack int* _heads
int* _buffer const TokenC* _sent
bint* shifted vector[int] _stack
TokenC* _sent vector[int] _rebuffer
SpanC* _ents vector[SpanC] _ents
vector[ArcC] _left_arcs
vector[ArcC] _right_arcs
vector[libcpp.bool] _unshiftable
set[int] _sent_starts
TokenC _empty_token TokenC _empty_token
RingBufferC _hist
int length int length
int offset int offset
int _s_i
int _b_i int _b_i
int _e_i
int _break
__init__(const TokenC* sent, int length) nogil: __init__(const TokenC* sent, int length) nogil:
cdef int PADDING = 5 this._sent = sent
this._buffer = <int*>calloc(length + (PADDING * 2), sizeof(int)) this._heads = <int*>calloc(length, sizeof(int))
this._stack = <int*>calloc(length + (PADDING * 2), sizeof(int)) if not (this._sent and this._heads):
this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint))
this._sent = <TokenC*>calloc(length + (PADDING * 2), sizeof(TokenC))
this._ents = <SpanC*>calloc(length + (PADDING * 2), sizeof(SpanC))
if not (this._buffer and this._stack and this.shifted
and this._sent and this._ents):
with gil: with gil:
PyErr_SetFromErrno(MemoryError) PyErr_SetFromErrno(MemoryError)
PyErr_CheckSignals() PyErr_CheckSignals()
memset(&this._hist, 0, sizeof(this._hist))
this.offset = 0 this.offset = 0
cdef int i
for i in range(length + (PADDING * 2)):
this._ents[i].end = -1
this._sent[i].l_edge = i
this._sent[i].r_edge = i
for i in range(PADDING):
this._sent[i].lex = &EMPTY_LEXEME
this._sent += PADDING
this._ents += PADDING
this._buffer += PADDING
this._stack += PADDING
this.shifted += PADDING
this.length = length this.length = length
this._break = -1
this._s_i = 0
this._b_i = 0 this._b_i = 0
this._e_i = 0
for i in range(length): for i in range(length):
this._buffer[i] = i this._heads[i] = -1
this._unshiftable.push_back(0)
memset(&this._empty_token, 0, sizeof(TokenC)) memset(&this._empty_token, 0, sizeof(TokenC))
this._empty_token.lex = &EMPTY_LEXEME this._empty_token.lex = &EMPTY_LEXEME
for i in range(length):
this._sent[i] = sent[i]
this._buffer[i] = i
for i in range(length, length+PADDING):
this._sent[i].lex = &EMPTY_LEXEME
__dealloc__(): __dealloc__():
cdef int PADDING = 5 free(this._heads)
free(this._sent - PADDING)
free(this._ents - PADDING)
free(this._buffer - PADDING)
free(this._stack - PADDING)
free(this.shifted - PADDING)
void set_context_tokens(int* ids, int n) nogil: void set_context_tokens(int* ids, int n) nogil:
cdef int i, j
if n == 1: if n == 1:
if this.B(0) >= 0: if this.B(0) >= 0:
ids[0] = this.B(0) ids[0] = this.B(0)
@ -145,22 +107,18 @@ cdef cppclass StateC:
ids[11] = this.R(this.S(1), 1) ids[11] = this.R(this.S(1), 1)
ids[12] = this.R(this.S(1), 2) ids[12] = this.R(this.S(1), 2)
elif n == 6: elif n == 6:
for i in range(6):
ids[i] = -1
if this.B(0) >= 0: if this.B(0) >= 0:
ids[0] = this.B(0) ids[0] = this.B(0)
ids[1] = this.B(0)-1 if this.entity_is_open():
else: ent = this.get_ent()
ids[0] = -1 j = 1
ids[1] = -1 for i in range(ent.start, this.B(0)):
ids[2] = this.B(1) ids[j] = i
ids[3] = this.E(0) j += 1
if ids[3] >= 1: if j >= 6:
ids[4] = this.E(0)-1 break
else:
ids[4] = -1
if (ids[3]+1) < this.length:
ids[5] = this.E(0)+1
else:
ids[5] = -1
else: else:
# TODO error =/ # TODO error =/
pass pass
@ -171,329 +129,256 @@ cdef cppclass StateC:
ids[i] = -1 ids[i] = -1
int S(int i) nogil const: int S(int i) nogil const:
if i >= this._s_i: if i >= this._stack.size():
return -1 return -1
return this._stack[this._s_i - (i+1)] elif i < 0:
return -1
return this._stack.at(this._stack.size() - (i+1))
int B(int i) nogil const: int B(int i) nogil const:
if (i + this._b_i) >= this.length: if i < 0:
return -1 return -1
return this._buffer[this._b_i + i] elif i < this._rebuffer.size():
return this._rebuffer.at(this._rebuffer.size() - (i+1))
const TokenC* S_(int i) nogil const: else:
return this.safe_get(this.S(i)) b_i = this._b_i + (i - this._rebuffer.size())
if b_i >= this.length:
return -1
else:
return b_i
const TokenC* B_(int i) nogil const: const TokenC* B_(int i) nogil const:
return this.safe_get(this.B(i)) return this.safe_get(this.B(i))
const TokenC* H_(int i) nogil const:
return this.safe_get(this.H(i))
const TokenC* E_(int i) nogil const: const TokenC* E_(int i) nogil const:
return this.safe_get(this.E(i)) return this.safe_get(this.E(i))
const TokenC* L_(int i, int idx) nogil const:
return this.safe_get(this.L(i, idx))
const TokenC* R_(int i, int idx) nogil const:
return this.safe_get(this.R(i, idx))
const TokenC* safe_get(int i) nogil const: const TokenC* safe_get(int i) nogil const:
if i < 0 or i >= this.length: if i < 0 or i >= this.length:
return &this._empty_token return &this._empty_token
else: else:
return &this._sent[i] return &this._sent[i]
int H(int i) nogil const: void get_arcs(vector[ArcC]* arcs) nogil const:
if i < 0 or i >= this.length: for i in range(this._left_arcs.size()):
arc = this._left_arcs.at(i)
if arc.head != -1 and arc.child != -1:
arcs.push_back(arc)
for i in range(this._right_arcs.size()):
arc = this._right_arcs.at(i)
if arc.head != -1 and arc.child != -1:
arcs.push_back(arc)
int H(int child) nogil const:
if child >= this.length or child < 0:
return -1 return -1
return this._sent[i].head + i else:
return this._heads[child]
int E(int i) nogil const: int E(int i) nogil const:
if this._e_i <= 0 or this._e_i >= this.length: if this._ents.size() == 0:
return -1 return -1
if i < 0 or i >= this._e_i:
return -1
return this._ents[this._e_i - (i+1)].start
int L(int i, int idx) nogil const:
if idx < 1:
return -1
if i < 0 or i >= this.length:
return -1
cdef const TokenC* target = &this._sent[i]
if target.l_kids < <uint32_t>idx:
return -1
cdef const TokenC* ptr = &this._sent[target.l_edge]
while ptr < target:
# If this head is still to the right of us, we can skip to it
# No token that's between this token and this head could be our
# child.
if (ptr.head >= 1) and (ptr + ptr.head) < target:
ptr += ptr.head
elif ptr + ptr.head == target:
idx -= 1
if idx == 0:
return ptr - this._sent
ptr += 1
else: else:
ptr += 1 return this._ents.back().start
return -1
int R(int i, int idx) nogil const: int L(int head, int idx) nogil const:
if idx < 1: if idx < 1 or this._left_arcs.size() == 0:
return -1 return -1
if i < 0 or i >= this.length: cdef vector[int] lefts
for i in range(this._left_arcs.size()):
arc = this._left_arcs.at(i)
if arc.head == head and arc.child != -1 and arc.child < head:
lefts.push_back(arc.child)
idx = (<int>lefts.size()) - idx
if idx < 0:
return -1 return -1
cdef const TokenC* target = &this._sent[i]
if target.r_kids < <uint32_t>idx:
return -1
cdef const TokenC* ptr = &this._sent[target.r_edge]
while ptr > target:
# If this head is still to the right of us, we can skip to it
# No token that's between this token and this head could be our
# child.
if (ptr.head < 0) and ((ptr + ptr.head) > target):
ptr += ptr.head
elif ptr + ptr.head == target:
idx -= 1
if idx == 0:
return ptr - this._sent
ptr -= 1
else: else:
ptr -= 1 return lefts.at(idx)
int R(int head, int idx) nogil const:
if idx < 1 or this._right_arcs.size() == 0:
return -1 return -1
cdef vector[int] rights
for i in range(this._right_arcs.size()):
arc = this._right_arcs.at(i)
if arc.head == head and arc.child != -1 and arc.child > head:
rights.push_back(arc.child)
idx = (<int>rights.size()) - idx
if idx < 0:
return -1
else:
return rights.at(idx)
bint empty() nogil const: bint empty() nogil const:
return this._s_i <= 0 return this._stack.size() == 0
bint eol() nogil const: bint eol() nogil const:
return this.buffer_length() == 0 return this.buffer_length() == 0
bint at_break() nogil const:
return this._break != -1
bint is_final() nogil const: bint is_final() nogil const:
return this.stack_depth() <= 0 and this._b_i >= this.length return this.stack_depth() <= 0 and this.eol()
bint has_head(int i) nogil const: int cannot_sent_start(int word) nogil const:
return this.safe_get(i).head != 0 if word < 0 or word >= this.length:
return 0
elif this._sent[word].sent_start == -1:
return 1
else:
return 0
int n_L(int i) nogil const: int is_sent_start(int word) nogil const:
return this.safe_get(i).l_kids if word < 0 or word >= this.length:
return 0
elif this._sent[word].sent_start == 1:
return 1
elif this._sent_starts.count(word) >= 1:
return 1
else:
return 0
int n_R(int i) nogil const: void set_sent_start(int word, int value) nogil:
return this.safe_get(i).r_kids if value >= 1:
this._sent_starts.insert(word)
bint has_head(int child) nogil const:
return this._heads[child] >= 0
int l_edge(int word) nogil const:
return word
int r_edge(int word) nogil const:
return word
int n_L(int head) nogil const:
cdef int n = 0
for i in range(this._left_arcs.size()):
arc = this._left_arcs.at(i)
if arc.head == head and arc.child != -1 and arc.child < arc.head:
n += 1
return n
int n_R(int head) nogil const:
cdef int n = 0
for i in range(this._right_arcs.size()):
arc = this._right_arcs.at(i)
if arc.head == head and arc.child != -1 and arc.child > arc.head:
n += 1
return n
bint stack_is_connected() nogil const: bint stack_is_connected() nogil const:
return False return False
bint entity_is_open() nogil const: bint entity_is_open() nogil const:
if this._e_i < 1: if this._ents.size() == 0:
return False return False
return this._ents[this._e_i-1].end == -1 else:
return this._ents.back().end == -1
int stack_depth() nogil const: int stack_depth() nogil const:
return this._s_i return this._stack.size()
int buffer_length() nogil const: int buffer_length() nogil const:
if this._break != -1:
return this._break - this._b_i
else:
return this.length - this._b_i return this.length - this._b_i
uint64_t hash() nogil const:
cdef TokenC[11] sig
sig[0] = this.S_(2)[0]
sig[1] = this.S_(1)[0]
sig[2] = this.R_(this.S(1), 1)[0]
sig[3] = this.L_(this.S(0), 1)[0]
sig[4] = this.L_(this.S(0), 2)[0]
sig[5] = this.S_(0)[0]
sig[6] = this.R_(this.S(0), 2)[0]
sig[7] = this.R_(this.S(0), 1)[0]
sig[8] = this.B_(0)[0]
sig[9] = this.E_(0)[0]
sig[10] = this.E_(1)[0]
return hash64(sig, sizeof(sig), this._s_i) \
+ hash64(<void*>&this._hist, sizeof(RingBufferC), 1)
void push_hist(int act) nogil:
ring_push(&this._hist, act+1)
int get_hist(int i) nogil:
return ring_get(&this._hist, i)
void push() nogil: void push() nogil:
if this.B(0) != -1: b0 = this.B(0)
this._stack[this._s_i] = this.B(0) if this._rebuffer.size():
this._s_i += 1 b0 = this._rebuffer.back()
this._rebuffer.pop_back()
else:
b0 = this._b_i
this._b_i += 1 this._b_i += 1
if this.safe_get(this.B_(0).l_edge).sent_start == 1: this._stack.push_back(b0)
this.set_break(this.B_(0).l_edge)
if this._b_i > this._break:
this._break = -1
void pop() nogil: void pop() nogil:
if this._s_i >= 1: this._stack.pop_back()
this._s_i -= 1
void force_final() nogil: void force_final() nogil:
# This should only be used in desperate situations, as it may leave # This should only be used in desperate situations, as it may leave
# the analysis in an unexpected state. # the analysis in an unexpected state.
this._s_i = 0 this._stack.clear()
this._b_i = this.length this._b_i = this.length
void unshift() nogil: void unshift() nogil:
this._b_i -= 1 s0 = this._stack.back()
this._buffer[this._b_i] = this.S(0) this._unshiftable[s0] = 1
this._s_i -= 1 this._rebuffer.push_back(s0)
this.shifted[this.B(0)] = True this._stack.pop_back()
int is_unshiftable(int item) nogil const:
if item >= this._unshiftable.size():
return 0
else:
return this._unshiftable.at(item)
void set_reshiftable(int item) nogil:
if item < this._unshiftable.size():
this._unshiftable[item] = 0
void add_arc(int head, int child, attr_t label) nogil: void add_arc(int head, int child, attr_t label) nogil:
if this.has_head(child): if this.has_head(child):
this.del_arc(this.H(child), child) this.del_arc(this.H(child), child)
cdef ArcC arc
cdef int dist = head - child arc.head = head
this._sent[child].head = dist arc.child = child
this._sent[child].dep = label arc.label = label
cdef int i if head > child:
if child > head: this._left_arcs.push_back(arc)
this._sent[head].r_kids += 1
# Some transition systems can have a word in the buffer have a
# rightward child, e.g. from Unshift.
this._sent[head].r_edge = this._sent[child].r_edge
i = 0
while this.has_head(head) and i < this.length:
head = this.H(head)
this._sent[head].r_edge = this._sent[child].r_edge
i += 1 # Guard against infinite loops
else: else:
this._sent[head].l_kids += 1 this._right_arcs.push_back(arc)
this._sent[head].l_edge = this._sent[child].l_edge this._heads[child] = head
void del_arc(int h_i, int c_i) nogil: void del_arc(int h_i, int c_i) nogil:
cdef int dist = h_i - c_i cdef vector[ArcC]* arcs
cdef TokenC* h = &this._sent[h_i] if h_i > c_i:
cdef int i = 0 arcs = &this._left_arcs
if c_i > h_i:
# this.R_(h_i, 2) returns the second-rightmost child token of h_i
# If we have more than 2 rightmost children, our 2nd rightmost child's
# rightmost edge is going to be our new rightmost edge.
h.r_edge = this.R_(h_i, 2).r_edge if h.r_kids >= 2 else h_i
h.r_kids -= 1
new_edge = h.r_edge
# Correct upwards in the tree --- see Issue #251
while h.head < 0 and i < this.length: # Guard infinite loop
h += h.head
h.r_edge = new_edge
i += 1
else: else:
# Same logic applies for left edge, but we don't need to walk up arcs = &this._right_arcs
# the tree, as the head is off the stack. if arcs.size() == 0:
h.l_edge = this.L_(h_i, 2).l_edge if h.l_kids >= 2 else h_i return
h.l_kids -= 1 arc = arcs.back()
if arc.head == h_i and arc.child == c_i:
arcs.pop_back()
else:
for i in range(arcs.size()-1):
arc = arcs.at(i)
if arc.head == h_i and arc.child == c_i:
arc.head = -1
arc.child = -1
arc.label = 0
break
SpanC get_ent() nogil const:
cdef SpanC ent
if this._ents.size() == 0:
ent.start = 0
ent.end = 0
ent.label = 0
return ent
else:
return this._ents.back()
void open_ent(attr_t label) nogil: void open_ent(attr_t label) nogil:
this._ents[this._e_i].start = this.B(0) cdef SpanC ent
this._ents[this._e_i].label = label ent.start = this.B(0)
this._ents[this._e_i].end = -1 ent.label = label
this._e_i += 1 ent.end = -1
this._ents.push_back(ent)
void close_ent() nogil: void close_ent() nogil:
# Note that we don't decrement _e_i here! We want to maintain all this._ents.back().end = this.B(0)+1
# entities, not over-write them...
this._ents[this._e_i-1].end = this.B(0)+1
this._sent[this.B(0)].ent_iob = 1
void set_ent_tag(int i, int ent_iob, attr_t ent_type) nogil:
if 0 <= i < this.length:
this._sent[i].ent_iob = ent_iob
this._sent[i].ent_type = ent_type
void set_break(int i) nogil:
if 0 <= i < this.length:
this._sent[i].sent_start = 1
this._break = this._b_i
void clone(const StateC* src) nogil: void clone(const StateC* src) nogil:
this.length = src.length this.length = src.length
memcpy(this._sent, src._sent, this.length * sizeof(TokenC)) this._sent = src._sent
memcpy(this._stack, src._stack, this.length * sizeof(int)) this._stack = src._stack
memcpy(this._buffer, src._buffer, this.length * sizeof(int)) this._rebuffer = src._rebuffer
memcpy(this._ents, src._ents, this.length * sizeof(SpanC)) this._sent_starts = src._sent_starts
memcpy(this.shifted, src.shifted, this.length * sizeof(this.shifted[0])) this._unshiftable = src._unshiftable
memcpy(this._heads, src._heads, this.length * sizeof(this._heads[0]))
this._ents = src._ents
this._left_arcs = src._left_arcs
this._right_arcs = src._right_arcs
this._b_i = src._b_i this._b_i = src._b_i
this._s_i = src._s_i
this._e_i = src._e_i
this._break = src._break
this.offset = src.offset this.offset = src.offset
this._empty_token = src._empty_token this._empty_token = src._empty_token
void fast_forward() nogil:
# space token attachement policy:
# - attach space tokens always to the last preceding real token
# - except if it's the beginning of a sentence, then attach to the first following
# - boundary case: a document containing multiple space tokens but nothing else,
# then make the last space token the head of all others
while is_space_token(this.B_(0)) \
or this.buffer_length() == 0 \
or this.stack_depth() == 0:
if this.buffer_length() == 0:
# remove the last sentence's root from the stack
if this.stack_depth() == 1:
this.pop()
# parser got stuck: reduce stack or unshift
elif this.stack_depth() > 1:
if this.has_head(this.S(0)):
this.pop()
else:
this.unshift()
# stack is empty but there is another sentence on the buffer
elif (this.length - this._b_i) >= 1:
this.push()
else: # stack empty and nothing else coming
break
elif is_space_token(this.B_(0)):
# the normal case: we're somewhere inside a sentence
if this.stack_depth() > 0:
# assert not is_space_token(this.S_(0))
# attach all coming space tokens to their last preceding
# real token (which should be on the top of the stack)
while is_space_token(this.B_(0)):
this.add_arc(this.S(0),this.B(0),0)
this.push()
this.pop()
# the rare case: we're at the beginning of a document:
# space tokens are attached to the first real token on the buffer
elif this.stack_depth() == 0:
# store all space tokens on the stack until a real token shows up
# or the last token on the buffer is reached
while is_space_token(this.B_(0)) and this.buffer_length() > 1:
this.push()
# empty the stack by attaching all space tokens to the
# first token on the buffer
# boundary case: if all tokens are space tokens, the last one
# becomes the head of all others
while this.stack_depth() > 0:
this.add_arc(this.B(0),this.S(0),0)
this.pop()
# move the first token onto the stack
this.push()
elif this.stack_depth() == 0:
# for one token sentences (?)
if this.buffer_length() == 1:
this.push()
this.pop()
# with an empty stack and a non-empty buffer
# only shift is valid anyway
elif (this.length - this._b_i) >= 1:
this.push()
else: # can this even happen?
break

View File

@ -1,11 +1,7 @@
from .stateclass cimport StateClass from ._state cimport StateC
from ...typedefs cimport weight_t, attr_t from ...typedefs cimport weight_t, attr_t
from .transition_system cimport Transition, TransitionSystem from .transition_system cimport Transition, TransitionSystem
cdef class ArcEager(TransitionSystem): cdef class ArcEager(TransitionSystem):
pass pass
cdef weight_t push_cost(StateClass stcls, const void* _gold, int target) nogil
cdef weight_t arc_cost(StateClass stcls, const void* _gold, int head, int child) nogil

View File

@ -14,16 +14,11 @@ from ._state cimport StateC
from ...errors import Errors from ...errors import Errors
# Calculate cost as gold/not gold. We don't use scalar value anyway.
cdef int BINARY_COSTS = 1
cdef weight_t MIN_SCORE = -90000 cdef weight_t MIN_SCORE = -90000
cdef attr_t SUBTOK_LABEL = hash_string(u'subtok') cdef attr_t SUBTOK_LABEL = hash_string(u'subtok')
DEF NON_MONOTONIC = True DEF NON_MONOTONIC = True
DEF USE_BREAK = True
# Break transition from here
# http://www.aclweb.org/anthology/P13-1074
cdef enum: cdef enum:
SHIFT SHIFT
REDUCE REDUCE
@ -61,9 +56,11 @@ cdef struct GoldParseStateC:
int32_t* n_kids int32_t* n_kids
int32_t length int32_t length
int32_t stride int32_t stride
weight_t push_cost
weight_t pop_cost
cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, cdef GoldParseStateC create_gold_state(Pool mem, const StateC* state,
heads, labels, sent_starts) except *: heads, labels, sent_starts) except *:
cdef GoldParseStateC gs cdef GoldParseStateC gs
gs.length = len(heads) gs.length = len(heads)
@ -142,10 +139,12 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls,
if head != i: if head != i:
gs.kids[head][js[head]] = i gs.kids[head][js[head]] = i
js[head] += 1 js[head] += 1
gs.push_cost = push_cost(state, &gs)
gs.pop_cost = pop_cost(state, &gs)
return gs return gs
cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) nogil: cdef void update_gold_state(GoldParseStateC* gs, const StateC* s) nogil:
for i in range(gs.length): for i in range(gs.length):
gs.state_bits[i] = set_state_flag( gs.state_bits[i] = set_state_flag(
gs.state_bits[i], gs.state_bits[i],
@ -160,9 +159,9 @@ cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) nogil:
gs.n_kids_in_stack[i] = 0 gs.n_kids_in_stack[i] = 0
gs.n_kids_in_buffer[i] = 0 gs.n_kids_in_buffer[i] = 0
for i in range(stcls.stack_depth()): for i in range(s.stack_depth()):
s_i = stcls.S(i) s_i = s.S(i)
if not is_head_unknown(gs, s_i): if not is_head_unknown(gs, s_i) and gs.heads[s_i] != s_i:
gs.n_kids_in_stack[gs.heads[s_i]] += 1 gs.n_kids_in_stack[gs.heads[s_i]] += 1
for kid in gs.kids[s_i][:gs.n_kids[s_i]]: for kid in gs.kids[s_i][:gs.n_kids[s_i]]:
gs.state_bits[kid] = set_state_flag( gs.state_bits[kid] = set_state_flag(
@ -170,9 +169,11 @@ cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) nogil:
HEAD_IN_STACK, HEAD_IN_STACK,
1 1
) )
for i in range(stcls.buffer_length()): for i in range(s.buffer_length()):
b_i = stcls.B(i) b_i = s.B(i)
if not is_head_unknown(gs, b_i): if s.is_sent_start(b_i):
break
if not is_head_unknown(gs, b_i) and gs.heads[b_i] != b_i:
gs.n_kids_in_buffer[gs.heads[b_i]] += 1 gs.n_kids_in_buffer[gs.heads[b_i]] += 1
for kid in gs.kids[b_i][:gs.n_kids[b_i]]: for kid in gs.kids[b_i][:gs.n_kids[b_i]]:
gs.state_bits[kid] = set_state_flag( gs.state_bits[kid] = set_state_flag(
@ -180,6 +181,8 @@ cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) nogil:
HEAD_IN_BUFFER, HEAD_IN_BUFFER,
1 1
) )
gs.push_cost = push_cost(s, gs)
gs.pop_cost = pop_cost(s, gs)
cdef class ArcEagerGold: cdef class ArcEagerGold:
@ -191,17 +194,17 @@ cdef class ArcEagerGold:
heads, labels = example.get_aligned_parse(projectivize=True) heads, labels = example.get_aligned_parse(projectivize=True)
labels = [label if label is not None else "" for label in labels] labels = [label if label is not None else "" for label in labels]
labels = [example.x.vocab.strings.add(label) for label in labels] labels = [example.x.vocab.strings.add(label) for label in labels]
sent_starts = example.get_aligned("SENT_START") sent_starts = example.get_aligned_sent_starts()
assert len(heads) == len(labels) == len(sent_starts) assert len(heads) == len(labels) == len(sent_starts), (len(heads), len(labels), len(sent_starts))
self.c = create_gold_state(self.mem, stcls, heads, labels, sent_starts) self.c = create_gold_state(self.mem, stcls.c, heads, labels, sent_starts)
def update(self, StateClass stcls): def update(self, StateClass stcls):
update_gold_state(&self.c, stcls) update_gold_state(&self.c, stcls.c)
cdef int check_state_gold(char state_bits, char flag) nogil: cdef int check_state_gold(char state_bits, char flag) nogil:
cdef char one = 1 cdef char one = 1
return state_bits & (one << flag) return 1 if (state_bits & (one << flag)) else 0
cdef int set_state_flag(char state_bits, char flag, int value) nogil: cdef int set_state_flag(char state_bits, char flag, int value) nogil:
@ -232,41 +235,30 @@ cdef int is_sent_start_unknown(const GoldParseStateC* gold, int i) nogil:
# Helper functions for the arc-eager oracle # Helper functions for the arc-eager oracle
cdef weight_t push_cost(StateClass stcls, const void* _gold, int target) nogil: cdef weight_t push_cost(const StateC* state, const GoldParseStateC* gold) nogil:
gold = <const GoldParseStateC*>_gold
cdef weight_t cost = 0 cdef weight_t cost = 0
if is_head_in_stack(gold, target): b0 = state.B(0)
if b0 < 0:
return 9000
if is_head_in_stack(gold, b0):
cost += 1 cost += 1
cost += gold.n_kids_in_stack[target] cost += gold.n_kids_in_stack[b0]
if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0: if Break.is_valid(state, 0) and is_sent_start(gold, state.B(1)):
cost += 1 cost += 1
return cost return cost
cdef weight_t pop_cost(StateClass stcls, const void* _gold, int target) nogil: cdef weight_t pop_cost(const StateC* state, const GoldParseStateC* gold) nogil:
gold = <const GoldParseStateC*>_gold
cdef weight_t cost = 0 cdef weight_t cost = 0
if is_head_in_buffer(gold, target): s0 = state.S(0)
cost += 1 if s0 < 0:
cost += gold[0].n_kids_in_buffer[target] return 9000
if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0: if is_head_in_buffer(gold, s0):
cost += 1 cost += 1
cost += gold.n_kids_in_buffer[s0]
return cost return cost
cdef weight_t arc_cost(StateClass stcls, const void* _gold, int head, int child) nogil:
gold = <const GoldParseStateC*>_gold
if arc_is_gold(gold, head, child):
return 0
elif stcls.H(child) == gold.heads[child]:
return 1
# Head in buffer
elif is_head_in_buffer(gold, child):
return 1
else:
return 0
cdef bint arc_is_gold(const GoldParseStateC* gold, int head, int child) nogil: cdef bint arc_is_gold(const GoldParseStateC* gold, int head, int child) nogil:
if is_head_unknown(gold, child): if is_head_unknown(gold, child):
return True return True
@ -276,7 +268,7 @@ cdef bint arc_is_gold(const GoldParseStateC* gold, int head, int child) nogil:
return False return False
cdef bint label_is_gold(const GoldParseStateC* gold, int head, int child, attr_t label) nogil: cdef bint label_is_gold(const GoldParseStateC* gold, int child, attr_t label) nogil:
if is_head_unknown(gold, child): if is_head_unknown(gold, child):
return True return True
elif label == 0: elif label == 0:
@ -292,218 +284,251 @@ cdef bint _is_gold_root(const GoldParseStateC* gold, int word) nogil:
cdef class Shift: cdef class Shift:
"""Move the first word of the buffer onto the stack and mark it as "shifted"
Validity:
* If stack is empty
* At least two words in sentence
* Word has not been shifted before
Cost: push_cost
Action:
* Mark B[0] as 'shifted'
* Push stack
* Advance buffer
"""
@staticmethod @staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil: cdef bint is_valid(const StateC* st, attr_t label) nogil:
sent_start = st._sent[st.B_(0).l_edge].sent_start if st.stack_depth() == 0:
return st.buffer_length() >= 2 and not st.shifted[st.B(0)] and sent_start != 1 return 1
elif st.buffer_length() < 2:
return 0
elif st.is_sent_start(st.B(0)):
return 0
elif st.is_unshiftable(st.B(0)):
return 0
else:
return 1
@staticmethod @staticmethod
cdef int transition(StateC* st, attr_t label) nogil: cdef int transition(StateC* st, attr_t label) nogil:
st.push() st.push()
st.fast_forward()
@staticmethod @staticmethod
cdef weight_t cost(StateClass st, const void* _gold, attr_t label) nogil: cdef weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil:
gold = <const GoldParseStateC*>_gold gold = <const GoldParseStateC*>_gold
return Shift.move_cost(st, gold) + Shift.label_cost(st, gold, label) return gold.push_cost
@staticmethod
cdef inline weight_t move_cost(StateClass s, const void* _gold) nogil:
gold = <const GoldParseStateC*>_gold
return push_cost(s, gold, s.B(0))
@staticmethod
cdef inline weight_t label_cost(StateClass s, const void* _gold, attr_t label) nogil:
return 0
cdef class Reduce: cdef class Reduce:
"""
Pop from the stack. If it has no head and the stack isn't empty, place
it back on the buffer.
Validity:
* Stack not empty
* Buffer nt empty
* Stack depth 1 and cannot sent start l_edge(st.B(0))
Cost:
* If B[0] is the start of a sentence, cost is 0
* Arcs between stack and buffer
* If arc has no head, we're saving arcs between S[0] and S[1:], so decrement
cost by those arcs.
"""
@staticmethod @staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil: cdef bint is_valid(const StateC* st, attr_t label) nogil:
return st.stack_depth() >= 2 if st.stack_depth() == 0:
@staticmethod
cdef int transition(StateC* st, attr_t label) nogil:
if st.has_head(st.S(0)):
st.pop()
else:
st.unshift()
st.fast_forward()
@staticmethod
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
gold = <const GoldParseStateC*>_gold
return Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label)
@staticmethod
cdef inline weight_t move_cost(StateClass st, const void* _gold) nogil:
gold = <const GoldParseStateC*>_gold
s0 = st.S(0)
cost = pop_cost(st, gold, s0)
return_to_buffer = not st.has_head(s0)
if return_to_buffer:
# Decrement cost for the arcs we save, as we'll be putting this
# back to the buffer
if is_head_in_stack(gold, s0):
cost -= 1
cost -= gold.n_kids_in_stack[s0]
if Break.is_valid(st.c, 0) and Break.move_cost(st, gold) == 0:
cost -= 1
return cost
@staticmethod
cdef inline weight_t label_cost(StateClass s, const void* gold, attr_t label) nogil:
return 0
cdef class LeftArc:
@staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil:
if label == SUBTOK_LABEL and st.S(0) != (st.B(0)-1):
return 0
sent_start = st._sent[st.B_(0).l_edge].sent_start
return sent_start != 1
@staticmethod
cdef int transition(StateC* st, attr_t label) nogil:
st.add_arc(st.B(0), st.S(0), label)
st.pop()
st.fast_forward()
@staticmethod
cdef inline weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
gold = <const GoldParseStateC*>_gold
return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label)
@staticmethod
cdef inline weight_t move_cost(StateClass s, const GoldParseStateC* gold) nogil:
cdef weight_t cost = 0
s0 = s.S(0)
b0 = s.B(0)
if arc_is_gold(gold, b0, s0):
# Have a negative cost if we 'recover' from the wrong dependency
return 0 if not s.has_head(s0) else -1
else:
# Account for deps we might lose between S0 and stack
if not s.has_head(s0):
cost += gold.n_kids_in_stack[s0]
if is_head_in_buffer(gold, s0):
cost += 1
return cost + pop_cost(s, gold, s.S(0)) + arc_cost(s, gold, s.B(0), s.S(0))
@staticmethod
cdef inline weight_t label_cost(StateClass s, const GoldParseStateC* gold, attr_t label) nogil:
return arc_is_gold(gold, s.B(0), s.S(0)) and not label_is_gold(gold, s.B(0), s.S(0), label)
cdef class RightArc:
@staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil:
# If there's (perhaps partial) parse pre-set, don't allow cycle.
if label == SUBTOK_LABEL and st.S(0) != (st.B(0)-1):
return 0
sent_start = st._sent[st.B_(0).l_edge].sent_start
return sent_start != 1 and st.H(st.S(0)) != st.B(0)
@staticmethod
cdef int transition(StateC* st, attr_t label) nogil:
st.add_arc(st.S(0), st.B(0), label)
st.push()
st.fast_forward()
@staticmethod
cdef inline weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
gold = <const GoldParseStateC*>_gold
return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label)
@staticmethod
cdef inline weight_t move_cost(StateClass s, const void* _gold) nogil:
gold = <const GoldParseStateC*>_gold
if arc_is_gold(gold, s.S(0), s.B(0)):
return 0
elif s.c.shifted[s.B(0)]:
return push_cost(s, gold, s.B(0))
else:
return push_cost(s, gold, s.B(0)) + arc_cost(s, gold, s.S(0), s.B(0))
@staticmethod
cdef weight_t label_cost(StateClass s, const void* _gold, attr_t label) nogil:
gold = <const GoldParseStateC*>_gold
return arc_is_gold(gold, s.S(0), s.B(0)) and not label_is_gold(gold, s.S(0), s.B(0), label)
cdef class Break:
@staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil:
cdef int i
if not USE_BREAK:
return False return False
elif st.at_break(): elif st.buffer_length() == 0:
return False return True
elif st.stack_depth() < 1: elif st.stack_depth() == 1 and st.cannot_sent_start(st.l_edge(st.B(0))):
return False
elif st.B_(0).l_edge < 0:
return False
elif st._sent[st.B_(0).l_edge].sent_start < 0:
return False return False
else: else:
return True return True
@staticmethod @staticmethod
cdef int transition(StateC* st, attr_t label) nogil: cdef int transition(StateC* st, attr_t label) nogil:
st.set_break(st.B_(0).l_edge) if st.has_head(st.S(0)) or st.stack_depth() == 1:
st.fast_forward() st.pop()
@staticmethod
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
gold = <const GoldParseStateC*>_gold
return Break.move_cost(s, gold) + Break.label_cost(s, gold, label)
@staticmethod
cdef inline weight_t move_cost(StateClass s, const void* _gold) nogil:
gold = <const GoldParseStateC*>_gold
cost = 0
for i in range(s.stack_depth()):
S_i = s.S(i)
cost += gold.n_kids_in_buffer[S_i]
if is_head_in_buffer(gold, S_i):
cost += 1
# It's weird not to check the gold sentence boundaries but if we do,
# we can't account for "sunk costs", i.e. situations where we're already
# wrong.
s0_root = _get_root(s.S(0), gold)
b0_root = _get_root(s.B(0), gold)
if s0_root != b0_root or s0_root == -1 or b0_root == -1:
return cost
else: else:
return cost + 1 st.unshift()
@staticmethod @staticmethod
cdef inline weight_t label_cost(StateClass s, const void* gold, attr_t label) nogil: cdef weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil:
gold = <const GoldParseStateC*>_gold
if state.is_sent_start(state.B(0)):
return 0 return 0
s0 = state.S(0)
cost = gold.pop_cost
if not state.has_head(s0):
# Decrement cost for the arcs we save, as we'll be putting this
# back to the buffer
if is_head_in_stack(gold, s0):
cost -= 1
cost -= gold.n_kids_in_stack[s0]
return cost
cdef int _get_root(int word, const GoldParseStateC* gold) nogil:
if is_head_unknown(gold, word): cdef class LeftArc:
return -1 """Add an arc between B[0] and S[0], replacing the previous head of S[0] if
while gold.heads[word] != word and word >= 0: one is set. Pop S[0] from the stack.
word = gold.heads[word]
if is_head_unknown(gold, word): Validity:
return -1 * len(S) >= 1
* len(B) >= 1
* not is_sent_start(B[0])
Cost:
pop_cost - Arc(B[0], S[0], label) + (Arc(S[1], S[0]) if H(S[0]) else Arcs(S, S[0]))
"""
@staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil:
if st.stack_depth() == 0:
return 0
elif st.buffer_length() == 0:
return 0
elif st.is_sent_start(st.B(0)):
return 0
elif label == SUBTOK_LABEL and st.S(0) != (st.B(0)-1):
return 0
else: else:
return word return 1
@staticmethod
cdef int transition(StateC* st, attr_t label) nogil:
st.add_arc(st.B(0), st.S(0), label)
# If we change the stack, it's okay to remove the shifted mark, as
# we can't get in an infinite loop this way.
st.set_reshiftable(st.B(0))
st.pop()
@staticmethod
cdef inline weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil:
gold = <const GoldParseStateC*>_gold
cdef weight_t cost = gold.pop_cost
s0 = state.S(0)
s1 = state.S(1)
b0 = state.B(0)
if state.has_head(s0):
# Increment cost if we're clobbering a correct arc
cost += gold.heads[s0] == s1
else:
# If there's no head, we're losing arcs between S0 and S[1:].
cost += is_head_in_stack(gold, s0)
cost += gold.n_kids_in_stack[s0]
if b0 != -1 and s0 != -1 and gold.heads[s0] == b0:
cost -= 1
cost += not label_is_gold(gold, s0, label)
return cost
cdef class RightArc:
"""
Add an arc from S[0] to B[0]. Push B[0].
Validity:
* len(S) >= 1
* len(B) >= 1
* not is_sent_start(B[0])
Cost:
push_cost + (not shifted[b0] and Arc(B[1:], B[0])) - Arc(S[0], B[0], label)
"""
@staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil:
if st.stack_depth() == 0:
return 0
elif st.buffer_length() == 0:
return 0
elif st.is_sent_start(st.B(0)):
return 0
elif label == SUBTOK_LABEL and st.S(0) != (st.B(0)-1):
# If there's (perhaps partial) parse pre-set, don't allow cycle.
return 0
else:
return 1
@staticmethod
cdef int transition(StateC* st, attr_t label) nogil:
st.add_arc(st.S(0), st.B(0), label)
st.push()
@staticmethod
cdef inline weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil:
gold = <const GoldParseStateC*>_gold
cost = gold.push_cost
s0 = state.S(0)
b0 = state.B(0)
if s0 != -1 and b0 != -1 and gold.heads[b0] == s0:
cost -= 1
cost += not label_is_gold(gold, b0, label)
elif is_head_in_buffer(gold, b0) and not state.is_unshiftable(b0):
cost += 1
return cost
cdef class Break:
"""Mark the second word of the buffer as the start of a
sentence.
Validity:
* len(buffer) >= 2
* B[1] == B[0] + 1
* not is_sent_start(B[1])
* not cannot_sent_start(B[1])
Action:
* mark_sent_start(B[1])
Cost:
* not is_sent_start(B[1])
* Arcs between B[0] and B[1:]
* Arcs between S and B[1]
"""
@staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil:
cdef int i
if st.buffer_length() < 2:
return False
elif st.B(1) != st.B(0) + 1:
return False
elif st.is_sent_start(st.B(1)):
return False
elif st.cannot_sent_start(st.B(1)):
return False
else:
return True
@staticmethod
cdef int transition(StateC* st, attr_t label) nogil:
st.set_sent_start(st.B(1), 1)
@staticmethod
cdef weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil:
gold = <const GoldParseStateC*>_gold
cdef int b0 = state.B(0)
cdef int cost = 0
cdef int si
for i in range(state.stack_depth()):
si = state.S(i)
if is_head_in_buffer(gold, si):
cost += 1
cost += gold.n_kids_in_buffer[si]
# We need to score into B[1:], so subtract deps that are at b0
if gold.heads[b0] == si:
cost -= 1
if gold.heads[si] == b0:
cost -= 1
if not is_sent_start(gold, state.B(1)) \
and not is_sent_start_unknown(gold, state.B(1)):
cost += 1
return cost
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL: cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
st = new StateC(<const TokenC*>tokens, length) st = new StateC(<const TokenC*>tokens, length)
for i in range(st.length):
if st._sent[i].dep == 0:
st._sent[i].l_edge = i
st._sent[i].r_edge = i
st._sent[i].head = 0
st._sent[i].dep = 0
st._sent[i].l_kids = 0
st._sent[i].r_kids = 0
st.fast_forward()
return <void*>st return <void*>st
@ -515,6 +540,8 @@ cdef int _del_state(Pool mem, void* state, void* x) except -1:
cdef class ArcEager(TransitionSystem): cdef class ArcEager(TransitionSystem):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
TransitionSystem.__init__(self, *args, **kwargs) TransitionSystem.__init__(self, *args, **kwargs)
self.init_beam_state = _init_state
self.del_beam_state = _del_state
@classmethod @classmethod
def get_actions(cls, **kwargs): def get_actions(cls, **kwargs):
@ -537,7 +564,7 @@ cdef class ArcEager(TransitionSystem):
label = 'ROOT' label = 'ROOT'
if head == child: if head == child:
actions[BREAK][label] += 1 actions[BREAK][label] += 1
elif head < child: if head < child:
actions[RIGHT][label] += 1 actions[RIGHT][label] += 1
actions[REDUCE][''] += 1 actions[REDUCE][''] += 1
elif head > child: elif head > child:
@ -567,8 +594,14 @@ cdef class ArcEager(TransitionSystem):
t.do(state.c, t.label) t.do(state.c, t.label)
return state return state
def is_gold_parse(self, StateClass state, gold): def is_gold_parse(self, StateClass state, ArcEagerGold gold):
raise NotImplementedError for i in range(state.c.length):
token = state.c.safe_get(i)
if not arc_is_gold(&gold.c, i, i+token.head):
return False
elif not label_is_gold(&gold.c, i, token.dep):
return False
return True
def init_gold(self, StateClass state, Example example): def init_gold(self, StateClass state, Example example):
gold = ArcEagerGold(self, state, example) gold = ArcEagerGold(self, state, example)
@ -576,6 +609,7 @@ cdef class ArcEager(TransitionSystem):
return gold return gold
def init_gold_batch(self, examples): def init_gold_batch(self, examples):
# TODO: Projectivitity?
all_states = self.init_batch([eg.predicted for eg in examples]) all_states = self.init_batch([eg.predicted for eg in examples])
golds = [] golds = []
states = [] states = []
@ -662,24 +696,13 @@ cdef class ArcEager(TransitionSystem):
raise ValueError(Errors.E019.format(action=move, src='arc_eager')) raise ValueError(Errors.E019.format(action=move, src='arc_eager'))
return t return t
cdef int initialize_state(self, StateC* st) nogil: def set_annotations(self, StateClass state, Doc doc):
for i in range(st.length): for arc in state.arcs:
if st._sent[i].dep == 0: doc.c[arc["child"]].head = arc["head"] - arc["child"]
st._sent[i].l_edge = i doc.c[arc["child"]].dep = arc["label"]
st._sent[i].r_edge = i for i in range(doc.length):
st._sent[i].head = 0 if doc.c[i].head == 0:
st._sent[i].dep = 0 doc.c[i].dep = self.root_label
st._sent[i].l_kids = 0
st._sent[i].r_kids = 0
st.fast_forward()
cdef int finalize_state(self, StateC* st) nogil:
cdef int i
for i in range(st.length):
if st._sent[i].head == 0:
st._sent[i].dep = self.root_label
def finalize_doc(self, Doc doc):
set_children_from_heads(doc.c, 0, doc.length) set_children_from_heads(doc.c, 0, doc.length)
def has_gold(self, Example eg, start=0, end=None): def has_gold(self, Example eg, start=0, end=None):
@ -690,7 +713,7 @@ cdef class ArcEager(TransitionSystem):
return False return False
cdef int set_valid(self, int* output, const StateC* st) nogil: cdef int set_valid(self, int* output, const StateC* st) nogil:
cdef bint[N_MOVES] is_valid cdef int[N_MOVES] is_valid
is_valid[SHIFT] = Shift.is_valid(st, 0) is_valid[SHIFT] = Shift.is_valid(st, 0)
is_valid[REDUCE] = Reduce.is_valid(st, 0) is_valid[REDUCE] = Reduce.is_valid(st, 0)
is_valid[LEFT] = LeftArc.is_valid(st, 0) is_valid[LEFT] = LeftArc.is_valid(st, 0)
@ -710,29 +733,31 @@ cdef class ArcEager(TransitionSystem):
gold_state = gold_.c gold_state = gold_.c
n_gold = 0 n_gold = 0
if self.c[i].is_valid(stcls.c, self.c[i].label): if self.c[i].is_valid(stcls.c, self.c[i].label):
cost = self.c[i].get_cost(stcls, &gold_state, self.c[i].label) cost = self.c[i].get_cost(stcls.c, &gold_state, self.c[i].label)
else: else:
cost = 9000 cost = 9000
return cost return cost
cdef int set_costs(self, int* is_valid, weight_t* costs, cdef int set_costs(self, int* is_valid, weight_t* costs,
StateClass stcls, gold) except -1: const StateC* state, gold) except -1:
if not isinstance(gold, ArcEagerGold): if not isinstance(gold, ArcEagerGold):
raise TypeError(Errors.E909.format(name="ArcEagerGold")) raise TypeError(Errors.E909.format(name="ArcEagerGold"))
cdef ArcEagerGold gold_ = gold cdef ArcEagerGold gold_ = gold
gold_.update(stcls)
gold_state = gold_.c gold_state = gold_.c
update_gold_state(&gold_state, state)
self.set_valid(is_valid, state)
cdef int n_gold = 0 cdef int n_gold = 0
for i in range(self.n_moves): for i in range(self.n_moves):
if self.c[i].is_valid(stcls.c, self.c[i].label): if is_valid[i]:
is_valid[i] = True costs[i] = self.c[i].get_cost(state, &gold_state, self.c[i].label)
costs[i] = self.c[i].get_cost(stcls, &gold_state, self.c[i].label)
if costs[i] <= 0: if costs[i] <= 0:
n_gold += 1 n_gold += 1
else: else:
is_valid[i] = False
costs[i] = 9000 costs[i] = 9000
if n_gold < 1: if n_gold < 1:
for i in range(self.n_moves):
print(self.get_class_name(i), is_valid[i], costs[i])
print("Gold sent starts?", is_sent_start(&gold_state, state.B(0)), is_sent_start(&gold_state, state.B(1)))
raise ValueError raise ValueError
def get_oracle_sequence_from_state(self, StateClass state, ArcEagerGold gold, _debug=None): def get_oracle_sequence_from_state(self, StateClass state, ArcEagerGold gold, _debug=None):
@ -748,12 +773,13 @@ cdef class ArcEager(TransitionSystem):
failed = False failed = False
while not state.is_final(): while not state.is_final():
try: try:
self.set_costs(is_valid, costs, state, gold) self.set_costs(is_valid, costs, state.c, gold)
except ValueError: except ValueError:
failed = True failed = True
break break
min_cost = min(costs[i] for i in range(self.n_moves))
for i in range(self.n_moves): for i in range(self.n_moves):
if is_valid[i] and costs[i] <= 0: if is_valid[i] and costs[i] <= min_cost:
action = self.c[i] action = self.c[i]
history.append(i) history.append(i)
s0 = state.S(0) s0 = state.S(0)
@ -762,9 +788,7 @@ cdef class ArcEager(TransitionSystem):
example = _debug example = _debug
debug_log.append(" ".join(( debug_log.append(" ".join((
self.get_class_name(i), self.get_class_name(i),
"S0=", (example.x[s0].text if s0 >= 0 else "__"), state.print_state()
"B0=", (example.x[b0].text if b0 >= 0 else "__"),
"S0 head?", str(state.has_head(state.S(0))),
))) )))
action.do(state.c, action.label) action.do(state.c, action.label)
break break
@ -783,6 +807,8 @@ cdef class ArcEager(TransitionSystem):
print("Aligned heads") print("Aligned heads")
for i, head in enumerate(aligned_heads): for i, head in enumerate(aligned_heads):
print(example.x[i], example.x[head] if head is not None else "__") print(example.x[i], example.x[head] if head is not None else "__")
print("Aligned sent starts")
print(example.get_aligned_sent_starts())
print("Predicted tokens") print("Predicted tokens")
print([(w.i, w.text) for w in example.x]) print([(w.i, w.text) for w in example.x])

View File

@ -3,9 +3,12 @@ from cymem.cymem cimport Pool
from collections import Counter from collections import Counter
from ...tokens.doc cimport Doc
from ...tokens.span import Span
from ...typedefs cimport weight_t, attr_t from ...typedefs cimport weight_t, attr_t
from ...lexeme cimport Lexeme from ...lexeme cimport Lexeme
from ...attrs cimport IS_SPACE from ...attrs cimport IS_SPACE
from ...structs cimport TokenC
from ...training.example cimport Example from ...training.example cimport Example
from .stateclass cimport StateClass from .stateclass cimport StateClass
from ._state cimport StateC from ._state cimport StateC
@ -46,17 +49,17 @@ cdef class BiluoGold:
def __init__(self, BiluoPushDown moves, StateClass stcls, Example example): def __init__(self, BiluoPushDown moves, StateClass stcls, Example example):
self.mem = Pool() self.mem = Pool()
self.c = create_gold_state(self.mem, moves, stcls, example) self.c = create_gold_state(self.mem, moves, stcls.c, example)
def update(self, StateClass stcls): def update(self, StateClass stcls):
update_gold_state(&self.c, stcls) update_gold_state(&self.c, stcls.c)
cdef GoldNERStateC create_gold_state( cdef GoldNERStateC create_gold_state(
Pool mem, Pool mem,
BiluoPushDown moves, BiluoPushDown moves,
StateClass stcls, const StateC* stcls,
Example example Example example
) except *: ) except *:
cdef GoldNERStateC gs cdef GoldNERStateC gs
@ -67,7 +70,7 @@ cdef GoldNERStateC create_gold_state(
return gs return gs
cdef void update_gold_state(GoldNERStateC* gs, StateClass stcls) except *: cdef void update_gold_state(GoldNERStateC* gs, const StateC* state) except *:
# We don't need to update each time, unlike the parser. # We don't need to update each time, unlike the parser.
pass pass
@ -75,14 +78,15 @@ cdef void update_gold_state(GoldNERStateC* gs, StateClass stcls) except *:
cdef do_func_t[N_MOVES] do_funcs cdef do_func_t[N_MOVES] do_funcs
cdef bint _entity_is_sunk(StateClass st, Transition* golds) nogil: cdef bint _entity_is_sunk(const StateC* state, Transition* golds) nogil:
if not st.entity_is_open(): if not state.entity_is_open():
return False return False
cdef const Transition* gold = &golds[st.E(0)] cdef const Transition* gold = &golds[state.E(0)]
ent = state.get_ent()
if gold.move != BEGIN and gold.move != UNIT: if gold.move != BEGIN and gold.move != UNIT:
return True return True
elif gold.label != st.E_(0).ent_type: elif gold.label != ent.label:
return True return True
else: else:
return False return False
@ -228,15 +232,18 @@ cdef class BiluoPushDown(TransitionSystem):
self.labels[action][label_name] = -1 self.labels[action][label_name] = -1
return 1 return 1
cdef int initialize_state(self, StateC* st) nogil: def set_annotations(self, StateClass state, Doc doc):
# This is especially necessary when we use limited training data. cdef int i
for i in range(st.length): ents = []
if st._sent[i].ent_type != 0: for i in range(state.c._ents.size()):
with gil: ent = state.c._ents.at(i)
self.add_action(BEGIN, st._sent[i].ent_type) if ent.start != -1 and ent.end != -1:
self.add_action(IN, st._sent[i].ent_type) ents.append(Span(doc, ent.start, ent.end, label=ent.label))
self.add_action(UNIT, st._sent[i].ent_type) doc.set_ents(ents, default="unmodified")
self.add_action(LAST, st._sent[i].ent_type) # Set non-blocked tokens to O
for i in range(doc.length):
if doc.c[i].ent_iob == 0:
doc.c[i].ent_iob = 2
def init_gold(self, StateClass state, Example example): def init_gold(self, StateClass state, Example example):
return BiluoGold(self, state, example) return BiluoGold(self, state, example)
@ -255,26 +262,25 @@ cdef class BiluoPushDown(TransitionSystem):
gold_state = gold_.c gold_state = gold_.c
n_gold = 0 n_gold = 0
if self.c[i].is_valid(stcls.c, self.c[i].label): if self.c[i].is_valid(stcls.c, self.c[i].label):
cost = self.c[i].get_cost(stcls, &gold_state, self.c[i].label) cost = self.c[i].get_cost(stcls.c, &gold_state, self.c[i].label)
else: else:
cost = 9000 cost = 9000
return cost return cost
cdef int set_costs(self, int* is_valid, weight_t* costs, cdef int set_costs(self, int* is_valid, weight_t* costs,
StateClass stcls, gold) except -1: const StateC* state, gold) except -1:
if not isinstance(gold, BiluoGold): if not isinstance(gold, BiluoGold):
raise TypeError(Errors.E909.format(name="BiluoGold")) raise TypeError(Errors.E909.format(name="BiluoGold"))
cdef BiluoGold gold_ = gold cdef BiluoGold gold_ = gold
gold_.update(stcls)
gold_state = gold_.c gold_state = gold_.c
update_gold_state(&gold_state, state)
n_gold = 0 n_gold = 0
self.set_valid(is_valid, state)
for i in range(self.n_moves): for i in range(self.n_moves):
if self.c[i].is_valid(stcls.c, self.c[i].label): if is_valid[i]:
is_valid[i] = 1 costs[i] = self.c[i].get_cost(state, &gold_state, self.c[i].label)
costs[i] = self.c[i].get_cost(stcls, &gold_state, self.c[i].label)
n_gold += costs[i] <= 0 n_gold += costs[i] <= 0
else: else:
is_valid[i] = 0
costs[i] = 9000 costs[i] = 9000
if n_gold < 1: if n_gold < 1:
raise ValueError raise ValueError
@ -290,7 +296,7 @@ cdef class Missing:
pass pass
@staticmethod @staticmethod
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil: cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
return 9000 return 9000
@ -299,10 +305,10 @@ cdef class Begin:
cdef bint is_valid(const StateC* st, attr_t label) nogil: cdef bint is_valid(const StateC* st, attr_t label) nogil:
cdef int preset_ent_iob = st.B_(0).ent_iob cdef int preset_ent_iob = st.B_(0).ent_iob
cdef attr_t preset_ent_label = st.B_(0).ent_type cdef attr_t preset_ent_label = st.B_(0).ent_type
# If we're the last token of the input, we can't B -- must U or O. if st.entity_is_open():
if st.B(1) == -1:
return False return False
elif st.entity_is_open(): if st.buffer_length() < 2:
# If we're the last token of the input, we can't B -- must U or O.
return False return False
elif label == 0: elif label == 0:
return False return False
@ -337,12 +343,11 @@ cdef class Begin:
@staticmethod @staticmethod
cdef int transition(StateC* st, attr_t label) nogil: cdef int transition(StateC* st, attr_t label) nogil:
st.open_ent(label) st.open_ent(label)
st.set_ent_tag(st.B(0), 3, label)
st.push() st.push()
st.pop() st.pop()
@staticmethod @staticmethod
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil: cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
gold = <GoldNERStateC*>_gold gold = <GoldNERStateC*>_gold
cdef int g_act = gold.ner[s.B(0)].move cdef int g_act = gold.ner[s.B(0)].move
cdef attr_t g_tag = gold.ner[s.B(0)].label cdef attr_t g_tag = gold.ner[s.B(0)].label
@ -366,16 +371,17 @@ cdef class Begin:
cdef class In: cdef class In:
@staticmethod @staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil: cdef bint is_valid(const StateC* st, attr_t label) nogil:
if not st.entity_is_open():
return False
if st.buffer_length() < 2:
# If we're at the end, we can't I.
return False
ent = st.get_ent()
cdef int preset_ent_iob = st.B_(0).ent_iob cdef int preset_ent_iob = st.B_(0).ent_iob
cdef attr_t preset_ent_label = st.B_(0).ent_type cdef attr_t preset_ent_label = st.B_(0).ent_type
if label == 0: if label == 0:
return False return False
elif st.E_(0).ent_type != label: elif ent.label != label:
return False
elif not st.entity_is_open():
return False
elif st.B(1) == -1:
# If we're at the end, we can't I.
return False return False
elif preset_ent_iob == 3: elif preset_ent_iob == 3:
return False return False
@ -401,12 +407,11 @@ cdef class In:
@staticmethod @staticmethod
cdef int transition(StateC* st, attr_t label) nogil: cdef int transition(StateC* st, attr_t label) nogil:
st.set_ent_tag(st.B(0), 1, label)
st.push() st.push()
st.pop() st.pop()
@staticmethod @staticmethod
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil: cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
gold = <GoldNERStateC*>_gold gold = <GoldNERStateC*>_gold
move = IN move = IN
cdef int next_act = gold.ner[s.B(1)].move if s.B(1) >= 0 else OUT cdef int next_act = gold.ner[s.B(1)].move if s.B(1) >= 0 else OUT
@ -457,7 +462,7 @@ cdef class Last:
# Otherwise, force acceptance, even if we're across a sentence # Otherwise, force acceptance, even if we're across a sentence
# boundary or the token is whitespace. # boundary or the token is whitespace.
return True return True
elif st.E_(0).ent_type != label: elif st.get_ent().label != label:
return False return False
elif st.B_(1).ent_iob == 1: elif st.B_(1).ent_iob == 1:
# If a preset entity has I next, we can't L here. # If a preset entity has I next, we can't L here.
@ -468,12 +473,11 @@ cdef class Last:
@staticmethod @staticmethod
cdef int transition(StateC* st, attr_t label) nogil: cdef int transition(StateC* st, attr_t label) nogil:
st.close_ent() st.close_ent()
st.set_ent_tag(st.B(0), 1, label)
st.push() st.push()
st.pop() st.pop()
@staticmethod @staticmethod
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil: cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
gold = <GoldNERStateC*>_gold gold = <GoldNERStateC*>_gold
move = LAST move = LAST
@ -537,12 +541,11 @@ cdef class Unit:
cdef int transition(StateC* st, attr_t label) nogil: cdef int transition(StateC* st, attr_t label) nogil:
st.open_ent(label) st.open_ent(label)
st.close_ent() st.close_ent()
st.set_ent_tag(st.B(0), 3, label)
st.push() st.push()
st.pop() st.pop()
@staticmethod @staticmethod
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil: cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
gold = <GoldNERStateC*>_gold gold = <GoldNERStateC*>_gold
cdef int g_act = gold.ner[s.B(0)].move cdef int g_act = gold.ner[s.B(0)].move
cdef attr_t g_tag = gold.ner[s.B(0)].label cdef attr_t g_tag = gold.ner[s.B(0)].label
@ -578,12 +581,11 @@ cdef class Out:
@staticmethod @staticmethod
cdef int transition(StateC* st, attr_t label) nogil: cdef int transition(StateC* st, attr_t label) nogil:
st.set_ent_tag(st.B(0), 2, 0)
st.push() st.push()
st.pop() st.pop()
@staticmethod @staticmethod
cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil: cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
gold = <GoldNERStateC*>_gold gold = <GoldNERStateC*>_gold
cdef int g_act = gold.ner[s.B(0)].move cdef int g_act = gold.ner[s.B(0)].move
cdef attr_t g_tag = gold.ner[s.B(0)].label cdef attr_t g_tag = gold.ner[s.B(0)].label

View File

@ -2,30 +2,24 @@ from cymem.cymem cimport Pool
from ...structs cimport TokenC, SpanC from ...structs cimport TokenC, SpanC
from ...typedefs cimport attr_t from ...typedefs cimport attr_t
from ...tokens.doc cimport Doc
from ._state cimport StateC from ._state cimport StateC
cdef class StateClass: cdef class StateClass:
cdef Pool mem
cdef StateC* c cdef StateC* c
cdef readonly Doc doc
cdef int _borrowed cdef int _borrowed
@staticmethod @staticmethod
cdef inline StateClass init(const TokenC* sent, int length): cdef inline StateClass borrow(StateC* ptr, Doc doc):
cdef StateClass self = StateClass() cdef StateClass self = StateClass()
self.c = new StateC(sent, length)
return self
@staticmethod
cdef inline StateClass borrow(StateC* ptr):
cdef StateClass self = StateClass()
del self.c
self.c = ptr self.c = ptr
self._borrowed = 1 self._borrowed = 1
self.doc = doc
return self return self
@staticmethod @staticmethod
cdef inline StateClass init_offset(const TokenC* sent, int length, int cdef inline StateClass init_offset(const TokenC* sent, int length, int
offset): offset):
@ -33,105 +27,3 @@ cdef class StateClass:
self.c = new StateC(sent, length) self.c = new StateC(sent, length)
self.c.offset = offset self.c.offset = offset
return self return self
cdef inline int S(self, int i) nogil:
return self.c.S(i)
cdef inline int B(self, int i) nogil:
return self.c.B(i)
cdef inline const TokenC* S_(self, int i) nogil:
return self.c.S_(i)
cdef inline const TokenC* B_(self, int i) nogil:
return self.c.B_(i)
cdef inline const TokenC* H_(self, int i) nogil:
return self.c.H_(i)
cdef inline const TokenC* E_(self, int i) nogil:
return self.c.E_(i)
cdef inline const TokenC* L_(self, int i, int idx) nogil:
return self.c.L_(i, idx)
cdef inline const TokenC* R_(self, int i, int idx) nogil:
return self.c.R_(i, idx)
cdef inline const TokenC* safe_get(self, int i) nogil:
return self.c.safe_get(i)
cdef inline int H(self, int i) nogil:
return self.c.H(i)
cdef inline int E(self, int i) nogil:
return self.c.E(i)
cdef inline int L(self, int i, int idx) nogil:
return self.c.L(i, idx)
cdef inline int R(self, int i, int idx) nogil:
return self.c.R(i, idx)
cdef inline bint empty(self) nogil:
return self.c.empty()
cdef inline bint eol(self) nogil:
return self.c.eol()
cdef inline bint at_break(self) nogil:
return self.c.at_break()
cdef inline bint has_head(self, int i) nogil:
return self.c.has_head(i)
cdef inline int n_L(self, int i) nogil:
return self.c.n_L(i)
cdef inline int n_R(self, int i) nogil:
return self.c.n_R(i)
cdef inline bint stack_is_connected(self) nogil:
return False
cdef inline bint entity_is_open(self) nogil:
return self.c.entity_is_open()
cdef inline int stack_depth(self) nogil:
return self.c.stack_depth()
cdef inline int buffer_length(self) nogil:
return self.c.buffer_length()
cdef inline void push(self) nogil:
self.c.push()
cdef inline void pop(self) nogil:
self.c.pop()
cdef inline void unshift(self) nogil:
self.c.unshift()
cdef inline void add_arc(self, int head, int child, attr_t label) nogil:
self.c.add_arc(head, child, label)
cdef inline void del_arc(self, int head, int child) nogil:
self.c.del_arc(head, child)
cdef inline void open_ent(self, attr_t label) nogil:
self.c.open_ent(label)
cdef inline void close_ent(self) nogil:
self.c.close_ent()
cdef inline void set_ent_tag(self, int i, int ent_iob, attr_t ent_type) nogil:
self.c.set_ent_tag(i, ent_iob, ent_type)
cdef inline void set_break(self, int i) nogil:
self.c.set_break(i)
cdef inline void clone(self, StateClass src) nogil:
self.c.clone(src.c)
cdef inline void fast_forward(self) nogil:
self.c.fast_forward()

View File

@ -1,17 +1,20 @@
# cython: infer_types=True # cython: infer_types=True
import numpy import numpy
from libcpp.vector cimport vector
from ._state cimport ArcC
from ...tokens.doc cimport Doc from ...tokens.doc cimport Doc
cdef class StateClass: cdef class StateClass:
def __init__(self, Doc doc=None, int offset=0): def __init__(self, Doc doc=None, int offset=0):
cdef Pool mem = Pool()
self.mem = mem
self._borrowed = 0 self._borrowed = 0
if doc is not None: if doc is not None:
self.c = new StateC(doc.c, doc.length) self.c = new StateC(doc.c, doc.length)
self.c.offset = offset self.c.offset = offset
self.doc = doc
else:
self.doc = None
def __dealloc__(self): def __dealloc__(self):
if self._borrowed != 1: if self._borrowed != 1:
@ -19,36 +22,157 @@ cdef class StateClass:
@property @property
def stack(self): def stack(self):
return {self.S(i) for i in range(self.c._s_i)} return [self.S(i) for i in range(self.c.stack_depth())]
@property @property
def queue(self): def queue(self):
return {self.B(i) for i in range(self.c.buffer_length())} return [self.B(i) for i in range(self.c.buffer_length())]
@property @property
def token_vector_lenth(self): def token_vector_lenth(self):
return self.doc.tensor.shape[1] return self.doc.tensor.shape[1]
@property @property
def history(self): def arcs(self):
hist = numpy.ndarray((8,), dtype='i') cdef vector[ArcC] arcs
for i in range(8): self.c.get_arcs(&arcs)
hist[i] = self.c.get_hist(i+1) return list(arcs)
return hist #py_arcs = []
#for arc in arcs:
# if arc.head != -1 and arc.child != -1:
# py_arcs.append((arc.head, arc.child, arc.label))
#return arcs
def add_arc(self, int head, int child, int label):
self.c.add_arc(head, child, label)
def del_arc(self, int head, int child):
self.c.del_arc(head, child)
def H(self, int child):
return self.c.H(child)
def L(self, int head, int idx):
return self.c.L(head, idx)
def R(self, int head, int idx):
return self.c.R(head, idx)
@property
def _b_i(self):
return self.c._b_i
@property
def length(self):
return self.c.length
def is_final(self): def is_final(self):
return self.c.is_final() return self.c.is_final()
def copy(self): def copy(self):
cdef StateClass new_state = StateClass.init(self.c._sent, self.c.length) cdef StateClass new_state = StateClass(doc=self.doc, offset=self.c.offset)
new_state.c.clone(self.c) new_state.c.clone(self.c)
return new_state return new_state
def print_state(self, words): def print_state(self):
words = [token.text for token in self.doc]
words = list(words) + ['_'] words = list(words) + ['_']
top = f"{words[self.S(0)]}_{self.S_(0).head}" bools = ["F", "T"]
second = f"{words[self.S(1)]}_{self.S_(1).head}" sent_starts = [bools[self.c.is_sent_start(i)] for i in range(len(self.doc))]
third = f"{words[self.S(2)]}_{self.S_(2).head}" shifted = [1 if self.c.is_unshiftable(i) else 0 for i in range(self.c.length)]
n0 = words[self.B(0)] shifted.append("")
n1 = words[self.B(1)] sent_starts.append("")
return ' '.join((third, second, top, '|', n0, n1)) top = f"{self.S(0)}{words[self.S(0)]}_{words[self.H(self.S(0))]}_{shifted[self.S(0)]}"
second = f"{self.S(1)}{words[self.S(1)]}_{words[self.H(self.S(1))]}_{shifted[self.S(1)]}"
third = f"{self.S(2)}{words[self.S(2)]}_{words[self.H(self.S(2))]}_{shifted[self.S(2)]}"
n0 = f"{self.B(0)}{words[self.B(0)]}_{sent_starts[self.B(0)]}_{shifted[self.B(0)]}"
n1 = f"{self.B(1)}{words[self.B(1)]}_{sent_starts[self.B(1)]}_{shifted[self.B(1)]}"
return ' '.join((str(self.stack_depth()), str(self.buffer_length()), third, second, top, '|', n0, n1))
def S(self, int i):
return self.c.S(i)
def B(self, int i):
return self.c.B(i)
def H(self, int i):
return self.c.H(i)
def E(self, int i):
return self.c.E(i)
def L(self, int i, int idx):
return self.c.L(i, idx)
def R(self, int i, int idx):
return self.c.R(i, idx)
def S_(self, int i):
return self.doc[self.c.S(i)]
def B_(self, int i):
return self.doc[self.c.B(i)]
def H_(self, int i):
return self.doc[self.c.H(i)]
def E_(self, int i):
return self.doc[self.c.E(i)]
def L_(self, int i, int idx):
return self.doc[self.c.L(i, idx)]
def R_(self, int i, int idx):
return self.doc[self.c.R(i, idx)]
def empty(self):
return self.c.empty()
def eol(self):
return self.c.eol()
def at_break(self):
return False
#return self.c.at_break()
def has_head(self, int i):
return self.c.has_head(i)
def n_L(self, int i):
return self.c.n_L(i)
def n_R(self, int i):
return self.c.n_R(i)
def entity_is_open(self):
return self.c.entity_is_open()
def stack_depth(self):
return self.c.stack_depth()
def buffer_length(self):
return self.c.buffer_length()
def push(self):
self.c.push()
def pop(self):
self.c.pop()
def unshift(self):
self.c.unshift()
def add_arc(self, int head, int child, attr_t label):
self.c.add_arc(head, child, label)
def del_arc(self, int head, int child):
self.c.del_arc(head, child)
def open_ent(self, attr_t label):
self.c.open_ent(label)
def close_ent(self):
self.c.close_ent()
def clone(self, StateClass src):
self.c.clone(src.c)

View File

@ -16,14 +16,14 @@ cdef struct Transition:
weight_t score weight_t score
bint (*is_valid)(const StateC* state, attr_t label) nogil bint (*is_valid)(const StateC* state, attr_t label) nogil
weight_t (*get_cost)(StateClass state, const void* gold, attr_t label) nogil weight_t (*get_cost)(const StateC* state, const void* gold, attr_t label) nogil
int (*do)(StateC* state, attr_t label) nogil int (*do)(StateC* state, attr_t label) nogil
ctypedef weight_t (*get_cost_func_t)(StateClass state, const void* gold, ctypedef weight_t (*get_cost_func_t)(const StateC* state, const void* gold,
attr_tlabel) nogil attr_tlabel) nogil
ctypedef weight_t (*move_cost_func_t)(StateClass state, const void* gold) nogil ctypedef weight_t (*move_cost_func_t)(const StateC* state, const void* gold) nogil
ctypedef weight_t (*label_cost_func_t)(StateClass state, const void* ctypedef weight_t (*label_cost_func_t)(const StateC* state, const void*
gold, attr_t label) nogil gold, attr_t label) nogil
ctypedef int (*do_func_t)(StateC* state, attr_t label) nogil ctypedef int (*do_func_t)(StateC* state, attr_t label) nogil
@ -41,9 +41,8 @@ cdef class TransitionSystem:
cdef public attr_t root_label cdef public attr_t root_label
cdef public freqs cdef public freqs
cdef public object labels cdef public object labels
cdef init_state_t init_beam_state
cdef int initialize_state(self, StateC* state) nogil cdef del_state_t del_beam_state
cdef int finalize_state(self, StateC* state) nogil
cdef Transition lookup_transition(self, object name) except * cdef Transition lookup_transition(self, object name) except *
@ -52,4 +51,4 @@ cdef class TransitionSystem:
cdef int set_valid(self, int* output, const StateC* st) nogil cdef int set_valid(self, int* output, const StateC* st) nogil
cdef int set_costs(self, int* is_valid, weight_t* costs, cdef int set_costs(self, int* is_valid, weight_t* costs,
StateClass state, gold) except -1 const StateC* state, gold) except -1

View File

@ -5,6 +5,7 @@ from cymem.cymem cimport Pool
from collections import Counter from collections import Counter
import srsly import srsly
from . cimport _beam_utils
from ...typedefs cimport weight_t, attr_t from ...typedefs cimport weight_t, attr_t
from ...tokens.doc cimport Doc from ...tokens.doc cimport Doc
from ...structs cimport TokenC from ...structs cimport TokenC
@ -44,6 +45,8 @@ cdef class TransitionSystem:
if labels_by_action: if labels_by_action:
self.initialize_actions(labels_by_action, min_freq=min_freq) self.initialize_actions(labels_by_action, min_freq=min_freq)
self.root_label = self.strings.add('ROOT') self.root_label = self.strings.add('ROOT')
self.init_beam_state = _init_state
self.del_beam_state = _del_state
def __reduce__(self): def __reduce__(self):
return (self.__class__, (self.strings, self.labels), None, None) return (self.__class__, (self.strings, self.labels), None, None)
@ -54,7 +57,6 @@ cdef class TransitionSystem:
offset = 0 offset = 0
for doc in docs: for doc in docs:
state = StateClass(doc, offset=offset) state = StateClass(doc, offset=offset)
self.initialize_state(state.c)
states.append(state) states.append(state)
offset += len(doc) offset += len(doc)
return states return states
@ -80,7 +82,7 @@ cdef class TransitionSystem:
history = [] history = []
debug_log = [] debug_log = []
while not state.is_final(): while not state.is_final():
self.set_costs(is_valid, costs, state, gold) self.set_costs(is_valid, costs, state.c, gold)
for i in range(self.n_moves): for i in range(self.n_moves):
if is_valid[i] and costs[i] <= 0: if is_valid[i] and costs[i] <= 0:
action = self.c[i] action = self.c[i]
@ -124,15 +126,6 @@ cdef class TransitionSystem:
action = self.lookup_transition(name) action = self.lookup_transition(name)
action.do(state.c, action.label) action.do(state.c, action.label)
cdef int initialize_state(self, StateC* state) nogil:
pass
cdef int finalize_state(self, StateC* state) nogil:
pass
def finalize_doc(self, doc):
pass
cdef Transition lookup_transition(self, object name) except *: cdef Transition lookup_transition(self, object name) except *:
raise NotImplementedError raise NotImplementedError
@ -151,7 +144,7 @@ cdef class TransitionSystem:
is_valid[i] = self.c[i].is_valid(st, self.c[i].label) is_valid[i] = self.c[i].is_valid(st, self.c[i].label)
cdef int set_costs(self, int* is_valid, weight_t* costs, cdef int set_costs(self, int* is_valid, weight_t* costs,
StateClass stcls, gold) except -1: const StateC* state, gold) except -1:
raise NotImplementedError raise NotImplementedError
def get_class_name(self, int clas): def get_class_name(self, int clas):

View File

@ -105,6 +105,93 @@ def make_parser(
update_with_oracle_cut_size=update_with_oracle_cut_size, update_with_oracle_cut_size=update_with_oracle_cut_size,
multitasks=[], multitasks=[],
learn_tokens=learn_tokens, learn_tokens=learn_tokens,
min_action_freq=min_action_freq,
beam_width=1,
beam_density=0.0,
beam_update_prob=0.0,
)
@Language.factory(
"beam_parser",
assigns=["token.dep", "token.head", "token.is_sent_start", "doc.sents"],
default_config={
"beam_width": 8,
"beam_density": 0.01,
"beam_update_prob": 0.5,
"moves": None,
"update_with_oracle_cut_size": 100,
"learn_tokens": False,
"min_action_freq": 30,
"model": DEFAULT_PARSER_MODEL,
},
default_score_weights={
"dep_uas": 0.5,
"dep_las": 0.5,
"dep_las_per_type": None,
"sents_p": None,
"sents_r": None,
"sents_f": 0.0,
},
)
def make_beam_parser(
nlp: Language,
name: str,
model: Model,
moves: Optional[list],
update_with_oracle_cut_size: int,
learn_tokens: bool,
min_action_freq: int,
beam_width: int,
beam_density: float,
beam_update_prob: float,
):
"""Create a transition-based DependencyParser component that uses beam-search.
The dependency parser jointly learns sentence segmentation and labelled
dependency parsing, and can optionally learn to merge tokens that had been
over-segmented by the tokenizer.
The parser uses a variant of the non-monotonic arc-eager transition-system
described by Honnibal and Johnson (2014), with the addition of a "break"
transition to perform the sentence segmentation. Nivre's pseudo-projective
dependency transformation is used to allow the parser to predict
non-projective parses.
The parser is trained using a global objective. That is, it learns to assign
probabilities to whole parses.
model (Model): The model for the transition-based parser. The model needs
to have a specific substructure of named components --- see the
spacy.ml.tb_framework.TransitionModel for details.
moves (List[str]): A list of transition names. Inferred from the data if not
provided.
beam_width (int): The number of candidate analyses to maintain.
beam_density (float): The minimum ratio between the scores of the first and
last candidates in the beam. This allows the parser to avoid exploring
candidates that are too far behind. This is mostly intended to improve
efficiency, but it can also improve accuracy as deeper search is not
always better.
beam_update_prob (float): The chance of making a beam update, instead of a
greedy update. Greedy updates are an approximation for the beam updates,
and are faster to compute.
learn_tokens (bool): Whether to learn to merge subtokens that are split
relative to the gold standard. Experimental.
min_action_freq (int): The minimum frequency of labelled actions to retain.
Rarer labelled actions have their label backed-off to "dep". While this
primarily affects the label accuracy, it can also affect the attachment
structure, as the labels are used to represent the pseudo-projectivity
transformation.
"""
return DependencyParser(
nlp.vocab,
model,
name,
moves=moves,
update_with_oracle_cut_size=update_with_oracle_cut_size,
beam_width=beam_width,
beam_density=beam_density,
beam_update_prob=beam_update_prob,
multitasks=[],
learn_tokens=learn_tokens,
min_action_freq=min_action_freq min_action_freq=min_action_freq
) )

View File

@ -82,6 +82,79 @@ def make_ner(
multitasks=[], multitasks=[],
min_action_freq=1, min_action_freq=1,
learn_tokens=False, learn_tokens=False,
beam_width=1,
beam_density=0.0,
beam_update_prob=0.0,
)
@Language.factory(
"beam_ner",
assigns=["doc.ents", "token.ent_iob", "token.ent_type"],
default_config={
"moves": None,
"update_with_oracle_cut_size": 100,
"model": DEFAULT_NER_MODEL,
"beam_density": 0.01,
"beam_update_prob": 0.5,
"beam_width": 32
},
default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0, "ents_per_type": None},
)
def make_beam_ner(
nlp: Language,
name: str,
model: Model,
moves: Optional[list],
update_with_oracle_cut_size: int,
beam_width: int,
beam_density: float,
beam_update_prob: float,
):
"""Create a transition-based EntityRecognizer component that uses beam-search.
The entity recognizer identifies non-overlapping labelled spans of tokens.
The transition-based algorithm used encodes certain assumptions that are
effective for "traditional" named entity recognition tasks, but may not be
a good fit for every span identification problem. Specifically, the loss
function optimizes for whole entity accuracy, so if your inter-annotator
agreement on boundary tokens is low, the component will likely perform poorly
on your problem. The transition-based algorithm also assumes that the most
decisive information about your entities will be close to their initial tokens.
If your entities are long and characterised by tokens in their middle, the
component will likely do poorly on your task.
model (Model): The model for the transition-based parser. The model needs
to have a specific substructure of named components --- see the
spacy.ml.tb_framework.TransitionModel for details.
moves (list[str]): A list of transition names. Inferred from the data if not
provided.
update_with_oracle_cut_size (int):
During training, cut long sequences into shorter segments by creating
intermediate states based on the gold-standard history. The model is
not very sensitive to this parameter, so you usually won't need to change
it. 100 is a good default.
beam_width (int): The number of candidate analyses to maintain.
beam_density (float): The minimum ratio between the scores of the first and
last candidates in the beam. This allows the parser to avoid exploring
candidates that are too far behind. This is mostly intended to improve
efficiency, but it can also improve accuracy as deeper search is not
always better.
beam_update_prob (float): The chance of making a beam update, instead of a
greedy update. Greedy updates are an approximation for the beam updates,
and are faster to compute.
"""
return EntityRecognizer(
nlp.vocab,
model,
name,
moves=moves,
update_with_oracle_cut_size=update_with_oracle_cut_size,
multitasks=[],
min_action_freq=1,
learn_tokens=False,
beam_width=beam_width,
beam_density=beam_density,
beam_update_prob=beam_update_prob,
) )

View File

@ -4,13 +4,14 @@ from cymem.cymem cimport Pool
cimport numpy as np cimport numpy as np
from itertools import islice from itertools import islice
from libcpp.vector cimport vector from libcpp.vector cimport vector
from libc.string cimport memset from libc.string cimport memset, memcpy
from libc.stdlib cimport calloc, free from libc.stdlib cimport calloc, free
import random import random
from typing import Optional from typing import Optional
import srsly import srsly
from thinc.api import set_dropout_rate from thinc.api import set_dropout_rate, CupyOps
from thinc.extra.search cimport Beam
import numpy.random import numpy.random
import numpy import numpy
import warnings import warnings
@ -22,6 +23,8 @@ from ..ml.parser_model cimport WeightsC, ActivationsC, SizesC, cpu_log_loss
from ..ml.parser_model cimport get_c_weights, get_c_sizes from ..ml.parser_model cimport get_c_weights, get_c_sizes
from ..tokens.doc cimport Doc from ..tokens.doc cimport Doc
from .trainable_pipe import TrainablePipe from .trainable_pipe import TrainablePipe
from ._parser_internals cimport _beam_utils
from ._parser_internals import _beam_utils
from ..training import validate_examples, validate_get_examples from ..training import validate_examples, validate_get_examples
from ..errors import Errors, Warnings from ..errors import Errors, Warnings
@ -41,9 +44,12 @@ cdef class Parser(TrainablePipe):
moves=None, moves=None,
*, *,
update_with_oracle_cut_size, update_with_oracle_cut_size,
multitasks=tuple(),
min_action_freq, min_action_freq,
learn_tokens, learn_tokens,
beam_width=1,
beam_density=0.0,
beam_update_prob=0.0,
multitasks=tuple(),
): ):
"""Create a Parser. """Create a Parser.
@ -61,7 +67,10 @@ cdef class Parser(TrainablePipe):
"update_with_oracle_cut_size": update_with_oracle_cut_size, "update_with_oracle_cut_size": update_with_oracle_cut_size,
"multitasks": list(multitasks), "multitasks": list(multitasks),
"min_action_freq": min_action_freq, "min_action_freq": min_action_freq,
"learn_tokens": learn_tokens "learn_tokens": learn_tokens,
"beam_width": beam_width,
"beam_density": beam_density,
"beam_update_prob": beam_update_prob
} }
if moves is None: if moves is None:
# defined by EntityRecognizer as a BiluoPushDown # defined by EntityRecognizer as a BiluoPushDown
@ -183,7 +192,15 @@ cdef class Parser(TrainablePipe):
result = self.moves.init_batch(docs) result = self.moves.init_batch(docs)
self._resize() self._resize()
return result return result
if self.cfg["beam_width"] == 1:
return self.greedy_parse(docs, drop=0.0) return self.greedy_parse(docs, drop=0.0)
else:
return self.beam_parse(
docs,
drop=0.0,
beam_width=self.cfg["beam_width"],
beam_density=self.cfg["beam_density"]
)
def greedy_parse(self, docs, drop=0.): def greedy_parse(self, docs, drop=0.):
cdef vector[StateC*] states cdef vector[StateC*] states
@ -207,6 +224,31 @@ cdef class Parser(TrainablePipe):
del model del model
return batch return batch
def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
cdef Beam beam
cdef Doc doc
batch = _beam_utils.BeamBatch(
self.moves,
self.moves.init_batch(docs),
None,
beam_width,
density=beam_density
)
# This is pretty dirty, but the NER can resize itself in init_batch,
# if labels are missing. We therefore have to check whether we need to
# expand our model output.
self._resize()
model = self.model.predict(docs)
while not batch.is_done:
states = batch.get_unfinished_states()
if not states:
break
scores = model.predict(states)
batch.advance(scores)
model.clear_memory()
del model
return list(batch)
cdef void _parseC(self, StateC** states, cdef void _parseC(self, StateC** states,
WeightsC weights, SizesC sizes) nogil: WeightsC weights, SizesC sizes) nogil:
cdef int i, j cdef int i, j
@ -227,14 +269,13 @@ cdef class Parser(TrainablePipe):
unfinished.clear() unfinished.clear()
free_activations(&activations) free_activations(&activations)
def set_annotations(self, docs, states): def set_annotations(self, docs, states_or_beams):
cdef StateClass state cdef StateClass state
cdef Beam beam
cdef Doc doc cdef Doc doc
states = _beam_utils.collect_states(states_or_beams, docs)
for i, (state, doc) in enumerate(zip(states, docs)): for i, (state, doc) in enumerate(zip(states, docs)):
self.moves.finalize_state(state.c) self.moves.set_annotations(state, doc)
for j in range(doc.length):
doc.c[j] = state.c._sent[j]
self.moves.finalize_doc(doc)
for hook in self.postprocesses: for hook in self.postprocesses:
hook(doc) hook(doc)
@ -265,7 +306,6 @@ cdef class Parser(TrainablePipe):
else: else:
action = self.moves.c[guess] action = self.moves.c[guess]
action.do(states[i], action.label) action.do(states[i], action.label)
states[i].push_hist(guess)
free(is_valid) free(is_valid)
def update(self, examples, *, drop=0., set_annotations=False, sgd=None, losses=None): def update(self, examples, *, drop=0., set_annotations=False, sgd=None, losses=None):
@ -276,13 +316,23 @@ cdef class Parser(TrainablePipe):
validate_examples(examples, "Parser.update") validate_examples(examples, "Parser.update")
for multitask in self._multitasks: for multitask in self._multitasks:
multitask.update(examples, drop=drop, sgd=sgd) multitask.update(examples, drop=drop, sgd=sgd)
n_examples = len([eg for eg in examples if self.moves.has_gold(eg)]) n_examples = len([eg for eg in examples if self.moves.has_gold(eg)])
if n_examples == 0: if n_examples == 0:
return losses return losses
set_dropout_rate(self.model, drop) set_dropout_rate(self.model, drop)
# Prepare the stepwise model, and get the callback for finishing the batch # The probability we use beam update, instead of falling back to
model, backprop_tok2vec = self.model.begin_update( # a greedy update
[eg.predicted for eg in examples]) beam_update_prob = self.cfg["beam_update_prob"]
if self.cfg['beam_width'] >= 2 and numpy.random.random() < beam_update_prob:
return self.update_beam(
examples,
beam_width=self.cfg["beam_width"],
set_annotations=set_annotations,
sgd=sgd,
losses=losses,
beam_density=self.cfg["beam_density"]
)
max_moves = self.cfg["update_with_oracle_cut_size"] max_moves = self.cfg["update_with_oracle_cut_size"]
if max_moves >= 1: if max_moves >= 1:
# Chop sequences into lengths of this many words, to make the # Chop sequences into lengths of this many words, to make the
@ -296,6 +346,8 @@ cdef class Parser(TrainablePipe):
states, golds, _ = self.moves.init_gold_batch(examples) states, golds, _ = self.moves.init_gold_batch(examples)
if not states: if not states:
return losses return losses
model, backprop_tok2vec = self.model.begin_update([eg.x for eg in examples])
all_states = list(states) all_states = list(states)
states_golds = list(zip(states, golds)) states_golds = list(zip(states, golds))
n_moves = 0 n_moves = 0
@ -379,6 +431,27 @@ cdef class Parser(TrainablePipe):
del tutor del tutor
return losses return losses
def update_beam(self, examples, *, beam_width,
drop=0., sgd=None, losses=None, set_annotations=False, beam_density=0.0):
states, golds, _ = self.moves.init_gold_batch(examples)
if not states:
return losses
# Prepare the stepwise model, and get the callback for finishing the batch
model, backprop_tok2vec = self.model.begin_update(
[eg.predicted for eg in examples])
loss = _beam_utils.update_beam(
self.moves,
states,
golds,
model,
beam_width,
beam_density=beam_density,
)
losses[self.name] += loss
backprop_tok2vec(golds)
if sgd is not None:
self.finish_update(sgd)
def get_batch_loss(self, states, golds, float[:, ::1] scores, losses): def get_batch_loss(self, states, golds, float[:, ::1] scores, losses):
cdef StateClass state cdef StateClass state
cdef Pool mem = Pool() cdef Pool mem = Pool()
@ -396,7 +469,7 @@ cdef class Parser(TrainablePipe):
for i, (state, gold) in enumerate(zip(states, golds)): for i, (state, gold) in enumerate(zip(states, golds)):
memset(is_valid, 0, self.moves.n_moves * sizeof(int)) memset(is_valid, 0, self.moves.n_moves * sizeof(int))
memset(costs, 0, self.moves.n_moves * sizeof(float)) memset(costs, 0, self.moves.n_moves * sizeof(float))
self.moves.set_costs(is_valid, costs, state, gold) self.moves.set_costs(is_valid, costs, state.c, gold)
for j in range(self.moves.n_moves): for j in range(self.moves.n_moves):
if costs[j] <= 0.0 and j in unseen_classes: if costs[j] <= 0.0 and j in unseen_classes:
unseen_classes.remove(j) unseen_classes.remove(j)
@ -539,7 +612,6 @@ cdef class Parser(TrainablePipe):
for clas in oracle_actions[i:i+max_length]: for clas in oracle_actions[i:i+max_length]:
action = self.moves.c[clas] action = self.moves.c[clas]
action.do(state.c, action.label) action.do(state.c, action.label)
state.c.push_hist(action.clas)
if state.is_final(): if state.is_final():
break break
if self.moves.has_gold(eg, start_state.B(0), state.B(0)): if self.moves.has_gold(eg, start_state.B(0), state.B(0)):

View File

@ -7,6 +7,7 @@ from spacy.tokens import Doc
from spacy.pipeline._parser_internals.nonproj import projectivize from spacy.pipeline._parser_internals.nonproj import projectivize
from spacy.pipeline._parser_internals.arc_eager import ArcEager from spacy.pipeline._parser_internals.arc_eager import ArcEager
from spacy.pipeline.dep_parser import DEFAULT_PARSER_MODEL from spacy.pipeline.dep_parser import DEFAULT_PARSER_MODEL
from spacy.pipeline._parser_internals.stateclass import StateClass
def get_sequence_costs(M, words, heads, deps, transitions): def get_sequence_costs(M, words, heads, deps, transitions):
@ -47,15 +48,24 @@ def test_oracle_four_words(arc_eager, vocab):
for dep in deps: for dep in deps:
arc_eager.add_action(2, dep) # Left arc_eager.add_action(2, dep) # Left
arc_eager.add_action(3, dep) # Right arc_eager.add_action(3, dep) # Right
actions = ["L-left", "B-ROOT", "L-left"] actions = ["S", "L-left", "B-ROOT", "S", "D", "S", "L-left", "S", "D"]
state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions) state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions)
expected_gold = [
["S"],
["B-ROOT", "L-left"],
["B-ROOT"],
["S"],
["D"],
["S"],
["L-left"],
["S"],
["D"]
]
assert state.is_final() assert state.is_final()
for i, state_costs in enumerate(cost_history): for i, state_costs in enumerate(cost_history):
# Check gold moves is 0 cost # Check gold moves is 0 cost
assert state_costs[actions[i]] == 0.0, actions[i] golds = [act for act, cost in state_costs.items() if cost < 1]
for other_action, cost in state_costs.items(): assert golds == expected_gold[i], (i, golds, expected_gold[i])
if other_action != actions[i]:
assert cost >= 1, (i, other_action)
annot_tuples = [ annot_tuples = [
@ -169,12 +179,15 @@ def test_oracle_dev_sentence(vocab, arc_eager):
. punct said . punct said
""" """
expected_transitions = [ expected_transitions = [
"S", # Shift "Rolls-Royce"
"S", # Shift 'Motor' "S", # Shift 'Motor'
"S", # Shift 'Cars' "S", # Shift 'Cars'
"L-nn", # Attach 'Cars' to 'Inc.' "L-nn", # Attach 'Cars' to 'Inc.'
"L-nn", # Attach 'Motor' to 'Inc.' "L-nn", # Attach 'Motor' to 'Inc.'
"L-nn", # Attach 'Rolls-Royce' to 'Inc.', force shift "L-nn", # Attach 'Rolls-Royce' to 'Inc.'
"S", # Shift "Inc."
"L-nsubj", # Attach 'Inc.' to 'said' "L-nsubj", # Attach 'Inc.' to 'said'
"S", # Shift 'said'
"S", # Shift 'it' "S", # Shift 'it'
"L-nsubj", # Attach 'it.' to 'expects' "L-nsubj", # Attach 'it.' to 'expects'
"R-ccomp", # Attach 'expects' to 'said' "R-ccomp", # Attach 'expects' to 'said'
@ -204,6 +217,8 @@ def test_oracle_dev_sentence(vocab, arc_eager):
"D", # Reduce "steady" "D", # Reduce "steady"
"D", # Reduce "expects" "D", # Reduce "expects"
"R-punct", # Attach "." to "said" "R-punct", # Attach "." to "said"
"D", # Reduce "."
"D", # Reduce "said"
] ]
gold_words = [] gold_words = []
@ -221,10 +236,40 @@ def test_oracle_dev_sentence(vocab, arc_eager):
for dep in gold_deps: for dep in gold_deps:
arc_eager.add_action(2, dep) # Left arc_eager.add_action(2, dep) # Left
arc_eager.add_action(3, dep) # Right arc_eager.add_action(3, dep) # Right
doc = Doc(Vocab(), words=gold_words) doc = Doc(Vocab(), words=gold_words)
example = Example.from_dict(doc, {"heads": gold_heads, "deps": gold_deps}) example = Example.from_dict(doc, {"heads": gold_heads, "deps": gold_deps})
ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False)
ae_oracle_actions = arc_eager.get_oracle_sequence(example)
ae_oracle_actions = [arc_eager.get_class_name(i) for i in ae_oracle_actions] ae_oracle_actions = [arc_eager.get_class_name(i) for i in ae_oracle_actions]
assert ae_oracle_actions == expected_transitions assert ae_oracle_actions == expected_transitions
def test_oracle_bad_tokenization(vocab, arc_eager):
words_deps_heads = """
[catalase] dep is
: punct is
that nsubj is
is root is
bad comp is
"""
gold_words = []
gold_deps = []
gold_heads = []
for line in words_deps_heads.strip().split("\n"):
line = line.strip()
if not line:
continue
word, dep, head = line.split()
gold_words.append(word)
gold_deps.append(dep)
gold_heads.append(head)
gold_heads = [gold_words.index(head) for head in gold_heads]
for dep in gold_deps:
arc_eager.add_action(2, dep) # Left
arc_eager.add_action(3, dep) # Right
reference = Doc(Vocab(), words=gold_words, deps=gold_deps, heads=gold_heads)
predicted = Doc(reference.vocab, words=["[", "catalase", "]", ":", "that", "is", "bad"])
example = Example(predicted=predicted, reference=reference)
ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False)
ae_oracle_actions = [arc_eager.get_class_name(i) for i in ae_oracle_actions]
assert ae_oracle_actions

View File

@ -54,7 +54,7 @@ def tsys(vocab, entity_types):
def test_get_oracle_moves(tsys, doc, entity_annots): def test_get_oracle_moves(tsys, doc, entity_annots):
example = Example.from_dict(doc, {"entities": entity_annots}) example = Example.from_dict(doc, {"entities": entity_annots})
act_classes = tsys.get_oracle_sequence(example) act_classes = tsys.get_oracle_sequence(example, _debug=False)
names = [tsys.get_class_name(act) for act in act_classes] names = [tsys.get_class_name(act) for act in act_classes]
assert names == ["U-PERSON", "O", "O", "B-GPE", "L-GPE", "O"] assert names == ["U-PERSON", "O", "O", "B-GPE", "L-GPE", "O"]

View File

@ -0,0 +1,144 @@
# coding: utf8
from __future__ import unicode_literals
import pytest
import hypothesis
import hypothesis.strategies
import numpy
from spacy.vocab import Vocab
from spacy.language import Language
from spacy.pipeline import DependencyParser
from spacy.pipeline._parser_internals.arc_eager import ArcEager
from spacy.tokens import Doc
from spacy.pipeline._parser_internals._beam_utils import BeamBatch
from spacy.pipeline._parser_internals.stateclass import StateClass
from spacy.training import Example
from thinc.tests.strategies import ndarrays_of_shape
@pytest.fixture(scope="module")
def vocab():
return Vocab()
@pytest.fixture(scope="module")
def moves(vocab):
aeager = ArcEager(vocab.strings, {})
aeager.add_action(0, "")
aeager.add_action(1, "")
aeager.add_action(2, "nsubj")
aeager.add_action(2, "punct")
aeager.add_action(2, "aux")
aeager.add_action(2, "nsubjpass")
aeager.add_action(3, "dobj")
aeager.add_action(2, "aux")
aeager.add_action(4, "ROOT")
return aeager
@pytest.fixture(scope="module")
def docs(vocab):
return [
Doc(
vocab,
words=["Rats", "bite", "things"],
heads=[1, 1, 1],
deps=["nsubj", "ROOT", "dobj"],
sent_starts=[True, False, False]
)
]
@pytest.fixture(scope="module")
def examples(docs):
return [Example(doc, doc.copy()) for doc in docs]
@pytest.fixture
def states(docs):
return [StateClass(doc) for doc in docs]
@pytest.fixture
def tokvecs(docs, vector_size):
output = []
for doc in docs:
vec = numpy.random.uniform(-0.1, 0.1, (len(doc), vector_size))
output.append(numpy.asarray(vec))
return output
@pytest.fixture(scope="module")
def batch_size(docs):
return len(docs)
@pytest.fixture(scope="module")
def beam_width():
return 4
@pytest.fixture(params=[0.0, 0.5, 1.0])
def beam_density(request):
return request.param
@pytest.fixture
def vector_size():
return 6
@pytest.fixture
def beam(moves, examples, beam_width):
states, golds, _ = moves.init_gold_batch(examples)
return BeamBatch(moves, states, golds, width=beam_width, density=0.0)
@pytest.fixture
def scores(moves, batch_size, beam_width):
return numpy.asarray(
numpy.concatenate(
[
numpy.random.uniform(-0.1, 0.1, (beam_width, moves.n_moves))
for _ in range(batch_size)
]
), dtype="float32")
def test_create_beam(beam):
pass
def test_beam_advance(beam, scores):
beam.advance(scores)
def test_beam_advance_too_few_scores(beam, scores):
n_state = sum(len(beam) for beam in beam)
scores = scores[:n_state]
with pytest.raises(IndexError):
beam.advance(scores[:-1])
def test_beam_parse(examples, beam_width):
nlp = Language()
parser = nlp.add_pipe("beam_parser")
parser.cfg["beam_width"] = beam_width
parser.add_label("nsubj")
parser.initialize(lambda: examples)
doc = nlp.make_doc("Australia is a country")
parser(doc)
@hypothesis.given(hyp=hypothesis.strategies.data())
def test_beam_density(moves, examples, beam_width, hyp):
beam_density = float(hyp.draw(hypothesis.strategies.floats(0.0, 1.0, width=32)))
states, golds, _ = moves.init_gold_batch(examples)
beam = BeamBatch(moves, states, golds, width=beam_width, density=beam_density)
n_state = sum(len(beam) for beam in beam)
scores = hyp.draw(ndarrays_of_shape((n_state, moves.n_moves)))
beam.advance(scores)
for b in beam:
beam_probs = b.probs
assert b.min_density == beam_density
assert beam_probs[-1] >= beam_probs[0] * beam_density

View File

@ -22,6 +22,7 @@ def _parser_example(parser):
@pytest.fixture @pytest.fixture
def parser(vocab): def parser(vocab):
vocab.strings.add("ROOT")
config = { config = {
"learn_tokens": False, "learn_tokens": False,
"min_action_freq": 30, "min_action_freq": 30,
@ -76,13 +77,16 @@ def test_sents_1_2(parser):
def test_sents_1_3(parser): def test_sents_1_3(parser):
doc = Doc(parser.vocab, words=["a", "b", "c", "d"]) doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
doc[1].sent_start = True doc[0].is_sent_start = True
doc[3].sent_start = True doc[1].is_sent_start = True
doc[2].is_sent_start = None
doc[3].is_sent_start = True
doc = parser(doc) doc = parser(doc)
assert len(list(doc.sents)) >= 3 assert len(list(doc.sents)) >= 3
doc = Doc(parser.vocab, words=["a", "b", "c", "d"]) doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
doc[1].sent_start = True doc[0].is_sent_start = True
doc[2].sent_start = False doc[1].is_sent_start = True
doc[3].sent_start = True doc[2].is_sent_start = False
doc[3].is_sent_start = True
doc = parser(doc) doc = parser(doc)
assert len(list(doc.sents)) == 3 assert len(list(doc.sents)) == 3

View File

@ -0,0 +1,74 @@
import pytest
from spacy.tokens.doc import Doc
from spacy.vocab import Vocab
from spacy.pipeline._parser_internals.stateclass import StateClass
@pytest.fixture
def vocab():
return Vocab()
@pytest.fixture
def doc(vocab):
return Doc(vocab, words=["a", "b", "c", "d"])
def test_init_state(doc):
state = StateClass(doc)
assert state.stack == []
assert state.queue == list(range(len(doc)))
assert not state.is_final()
assert state.buffer_length() == 4
def test_push_pop(doc):
state = StateClass(doc)
state.push()
assert state.buffer_length() == 3
assert state.stack == [0]
assert 0 not in state.queue
state.push()
assert state.stack == [1, 0]
assert 1 not in state.queue
assert state.buffer_length() == 2
state.pop()
assert state.stack == [0]
assert 1 not in state.queue
def test_stack_depth(doc):
state = StateClass(doc)
assert state.stack_depth() == 0
assert state.buffer_length() == len(doc)
state.push()
assert state.buffer_length() == 3
assert state.stack_depth() == 1
def test_H(doc):
state = StateClass(doc)
assert state.H(0) == -1
state.add_arc(1, 0, 0)
assert state.arcs == [{"head": 1, "child": 0, "label": 0}]
assert state.H(0) == 1
state.add_arc(3, 1, 0)
assert state.H(1) == 3
def test_L(doc):
state = StateClass(doc)
assert state.L(2, 1) == -1
state.add_arc(2, 1, 0)
assert state.arcs == [{"head": 2, "child": 1, "label": 0}]
assert state.L(2, 1) == 1
state.add_arc(2, 0, 0)
assert state.L(2, 1) == 0
assert state.n_L(2) == 2
def test_R(doc):
state = StateClass(doc)
assert state.R(0, 1) == -1
state.add_arc(0, 1, 0)
assert state.arcs == [{"head": 0, "child": 1, "label": 0}]
assert state.R(0, 1) == 1
state.add_arc(0, 2, 0)
assert state.R(0, 1) == 2
assert state.n_R(0) == 2

View File

@ -122,7 +122,8 @@ def test_issue4042_bug2():
assert "SOME_LABEL" in ner1.labels assert "SOME_LABEL" in ner1.labels
apple_ent = Span(doc1, 5, 6, label="MY_ORG") apple_ent = Span(doc1, 5, 6, label="MY_ORG")
doc1.ents = list(doc1.ents) + [apple_ent] doc1.ents = list(doc1.ents) + [apple_ent]
# reapply the NER - at this point it should resize itself # Add the label explicitly. Previously we didn't require this.
ner1.add_label("MY_ORG")
ner1(doc1) ner1(doc1)
assert len(ner1.labels) == 2 assert len(ner1.labels) == 2
assert "SOME_LABEL" in ner1.labels assert "SOME_LABEL" in ner1.labels

View File

@ -22,6 +22,9 @@ def parser(en_vocab):
"learn_tokens": False, "learn_tokens": False,
"min_action_freq": 30, "min_action_freq": 30,
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
"beam_width": 1,
"beam_update_prob": 1.0,
"beam_density": 0.0
} }
cfg = {"model": DEFAULT_PARSER_MODEL} cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.resolve(cfg, validate=True)["model"] model = registry.resolve(cfg, validate=True)["model"]
@ -36,6 +39,9 @@ def blank_parser(en_vocab):
"learn_tokens": False, "learn_tokens": False,
"min_action_freq": 30, "min_action_freq": 30,
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
"beam_width": 1,
"beam_update_prob": 1.0,
"beam_density": 0.0
} }
cfg = {"model": DEFAULT_PARSER_MODEL} cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.resolve(cfg, validate=True)["model"] model = registry.resolve(cfg, validate=True)["model"]
@ -58,6 +64,9 @@ def test_serialize_parser_roundtrip_bytes(en_vocab, Parser):
"learn_tokens": False, "learn_tokens": False,
"min_action_freq": 0, "min_action_freq": 0,
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
"beam_width": 1,
"beam_update_prob": 1.0,
"beam_density": 0.0
} }
cfg = {"model": DEFAULT_PARSER_MODEL} cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.resolve(cfg, validate=True)["model"] model = registry.resolve(cfg, validate=True)["model"]
@ -79,6 +88,9 @@ def test_serialize_parser_strings(Parser):
"learn_tokens": False, "learn_tokens": False,
"min_action_freq": 0, "min_action_freq": 0,
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
"beam_width": 1,
"beam_update_prob": 1.0,
"beam_density": 0.0
} }
cfg = {"model": DEFAULT_PARSER_MODEL} cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.resolve(cfg, validate=True)["model"] model = registry.resolve(cfg, validate=True)["model"]
@ -98,6 +110,9 @@ def test_serialize_parser_roundtrip_disk(en_vocab, Parser):
"learn_tokens": False, "learn_tokens": False,
"min_action_freq": 0, "min_action_freq": 0,
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
"beam_width": 1,
"beam_update_prob": 1.0,
"beam_density": 0.0
} }
cfg = {"model": DEFAULT_PARSER_MODEL} cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.resolve(cfg, validate=True)["model"] model = registry.resolve(cfg, validate=True)["model"]

View File

@ -191,6 +191,24 @@ cdef class Example:
aligned_deps[cand_i] = deps[gold_i] aligned_deps[cand_i] = deps[gold_i]
return aligned_heads, aligned_deps return aligned_heads, aligned_deps
def get_aligned_sent_starts(self):
"""Get list of SENT_START attributes aligned to the predicted tokenization.
If the reference has not sentence starts, return a list of None values.
The aligned sentence starts use the get_aligned_spans method, rather
than aligning the list of tags, so that it handles cases where a mistaken
tokenization starts the sentence.
"""
if self.y.has_annotation("SENT_START"):
align = self.alignment.y2x
sent_starts = [False] * len(self.x)
for y_sent in self.y.sents:
x_start = int(align[y_sent.start].dataXd[0])
sent_starts[x_start] = True
return sent_starts
else:
return [None] * len(self.x)
def get_aligned_spans_x2y(self, x_spans): def get_aligned_spans_x2y(self, x_spans):
return self._get_aligned_spans(self.y, x_spans, self.alignment.x2y) return self._get_aligned_spans(self.y, x_spans, self.alignment.x2y)