From 8656a08777978cb4a77113bdcd17f404c9d8e1e8 Mon Sep 17 00:00:00 2001
From: Matthew Honnibal <honnibal+gh@gmail.com>
Date: Sun, 13 Dec 2020 12:08:32 +1100
Subject: [PATCH] Add beam_parser and beam_ner components for v3 (#6369)

* Get basic beam tests working

* Get basic beam tests working

* Compile _beam_utils

* Remove prints

* Test beam density

* Beam parser seems to train

* Draft beam NER

* Upd beam

* Add hypothesis as dev dependency

* Implement missing is-gold-parse method

* Implement early update

* Fix state hashing

* Fix test

* Fix test

* Default to non-beam in parser constructor

* Improve oracle for beam

* Start refactoring beam

* Update test

* Refactor beam

* Update nn

* Refactor beam and weight by cost

* Update ner beam settings

* Update test

* Add __init__.pxd

* Upd test

* Fix test

* Upd test

* Fix test

* Remove ring buffer history from StateC

* WIP change arc-eager transitions

* Add state tests

* Support ternary sent start values

* Fix arc eager

* Fix NER

* Pass oracle cut size for beam

* Fix ner test

* Fix beam

* Improve StateC.clone

* Improve StateClass.borrow

* Work directly with StateC, not StateClass

* Remove print statements

* Fix state copy

* Improve state class

* Refactor parser oracles

* Fix arc eager oracle

* Fix arc eager oracle

* Use a vector to implement the stack

* Refactor state data structure

* Fix alignment of sent start

* Add get_aligned_sent_starts method

* Add test for ae oracle when bad sentence starts

* Fix sentence segment handling

* Avoid Reduce that inserts illegal sentence

* Update preset SBD test

* Fix test

* Remove prints

* Fix sent starts in Example

* Improve python API of StateClass

* Tweak comments and debug output of arc eager

* Upd test

* Fix state test

* Fix state test
---
 requirements.txt                              |   1 +
 setup.py                                      |   1 +
 spacy/pipeline/_parser_internals/__init__.pxd |   0
 .../_parser_internals/_beam_utils.pxd         |   6 +
 .../_parser_internals/_beam_utils.pyx         | 296 ++++++++++
 spacy/pipeline/_parser_internals/_state.pxd   | 535 +++++++----------
 .../pipeline/_parser_internals/arc_eager.pxd  |   6 +-
 .../pipeline/_parser_internals/arc_eager.pyx  | 546 +++++++++---------
 spacy/pipeline/_parser_internals/ner.pyx      |  92 +--
 .../pipeline/_parser_internals/stateclass.pxd | 116 +---
 .../pipeline/_parser_internals/stateclass.pyx | 158 ++++-
 .../_parser_internals/transition_system.pxd   |  15 +-
 .../_parser_internals/transition_system.pyx   |  17 +-
 spacy/pipeline/dep_parser.pyx                 |  87 +++
 spacy/pipeline/ner.pyx                        |  73 +++
 spacy/pipeline/transition_parser.pyx          | 104 +++-
 spacy/tests/parser/test_arc_eager_oracle.py   |  63 +-
 spacy/tests/parser/test_ner.py                |   2 +-
 spacy/tests/parser/test_nn_beam.py            | 144 +++++
 spacy/tests/parser/test_preset_sbd.py         |  14 +-
 spacy/tests/parser/test_state.py              |  74 +++
 spacy/tests/regression/test_issue4001-4500.py |   3 +-
 .../serialize/test_serialize_pipeline.py      |  15 +
 spacy/training/example.pyx                    |  18 +
 24 files changed, 1570 insertions(+), 816 deletions(-)
 create mode 100644 spacy/pipeline/_parser_internals/__init__.pxd
 create mode 100644 spacy/pipeline/_parser_internals/_beam_utils.pxd
 create mode 100644 spacy/pipeline/_parser_internals/_beam_utils.pyx
 create mode 100644 spacy/tests/parser/test_state.py

diff --git a/requirements.txt b/requirements.txt
index 13c34d601..44f53bdb4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -27,3 +27,4 @@ pytest>=4.6.5
 pytest-timeout>=1.3.0,<2.0.0
 mock>=2.0.0,<3.0.0
 flake8>=3.5.0,<3.6.0
+hypothesis
diff --git a/setup.py b/setup.py
index 6e6e08988..14f8486ca 100755
--- a/setup.py
+++ b/setup.py
@@ -48,6 +48,7 @@ MOD_NAMES = [
     "spacy.pipeline._parser_internals._state",
     "spacy.pipeline._parser_internals.stateclass",
     "spacy.pipeline._parser_internals.transition_system",
+    "spacy.pipeline._parser_internals._beam_utils",
     "spacy.tokenizer",
     "spacy.training.align",
     "spacy.training.gold_io",
diff --git a/spacy/pipeline/_parser_internals/__init__.pxd b/spacy/pipeline/_parser_internals/__init__.pxd
new file mode 100644
index 000000000..e69de29bb
diff --git a/spacy/pipeline/_parser_internals/_beam_utils.pxd b/spacy/pipeline/_parser_internals/_beam_utils.pxd
new file mode 100644
index 000000000..de3573fbc
--- /dev/null
+++ b/spacy/pipeline/_parser_internals/_beam_utils.pxd
@@ -0,0 +1,6 @@
+from ...typedefs cimport class_t, hash_t
+
+# These are passed as callbacks to thinc.search.Beam
+cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1
+
+cdef int check_final_state(void* _state, void* extra_args) except -1
diff --git a/spacy/pipeline/_parser_internals/_beam_utils.pyx b/spacy/pipeline/_parser_internals/_beam_utils.pyx
new file mode 100644
index 000000000..a7f34daaf
--- /dev/null
+++ b/spacy/pipeline/_parser_internals/_beam_utils.pyx
@@ -0,0 +1,296 @@
+# cython: infer_types=True
+# cython: profile=True
+cimport numpy as np
+import numpy
+from cpython.ref cimport PyObject, Py_XDECREF
+from thinc.extra.search cimport Beam
+from thinc.extra.search import MaxViolation
+from thinc.extra.search cimport MaxViolation
+
+from ...typedefs cimport hash_t, class_t
+from .transition_system cimport TransitionSystem, Transition
+from ...errors import Errors
+from .stateclass cimport StateC, StateClass
+
+
+# These are passed as callbacks to thinc.search.Beam
+cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
+    dest = <StateC*>_dest
+    src = <StateC*>_src
+    moves = <const Transition*>_moves
+    dest.clone(src)
+    moves[clas].do(dest, moves[clas].label)
+
+
+cdef int check_final_state(void* _state, void* extra_args) except -1:
+    state = <StateC*>_state
+    return state.is_final()
+
+
+cdef class BeamBatch(object):
+    cdef public TransitionSystem moves
+    cdef public object states
+    cdef public object docs
+    cdef public object golds
+    cdef public object beams
+
+    def __init__(self, TransitionSystem moves, states, golds,
+                 int width, float density=0.):
+        cdef StateClass state
+        self.moves = moves
+        self.states = states
+        self.docs = [state.doc for state in states]
+        self.golds = golds
+        self.beams = []
+        cdef Beam beam
+        cdef StateC* st
+        for state in states:
+            beam = Beam(self.moves.n_moves, width, min_density=density)
+            beam.initialize(self.moves.init_beam_state,
+                            self.moves.del_beam_state, state.c.length,
+                            <void*>state.c._sent)
+            for i in range(beam.width):
+                st = <StateC*>beam.at(i)
+                st.offset = state.c.offset
+            beam.check_done(check_final_state, NULL)
+            self.beams.append(beam)
+
+    @property
+    def is_done(self):
+        return all(b.is_done for b in self.beams)
+
+    def __getitem__(self, i):
+        return self.beams[i]
+
+    def __len__(self):
+        return len(self.beams)
+
+    def get_states(self):
+        cdef Beam beam
+        cdef StateC* state
+        cdef StateClass stcls
+        states = []
+        for beam, doc in zip(self, self.docs):
+            for i in range(beam.size):
+                state = <StateC*>beam.at(i)
+                stcls = StateClass.borrow(state, doc)
+                states.append(stcls)
+        return states
+
+    def get_unfinished_states(self):
+        return [st for st in self.get_states() if not st.is_final()]
+
+    def advance(self, float[:, ::1] scores, follow_gold=False):
+        cdef Beam beam
+        cdef int nr_class = scores.shape[1]
+        cdef const float* c_scores = &scores[0, 0]
+        docs = self.docs
+        for i, beam in enumerate(self):
+            if not beam.is_done:
+                nr_state = self._set_scores(beam, c_scores, nr_class)
+                assert nr_state
+                if self.golds is not None:
+                    self._set_costs(
+                        beam,
+                        docs[i],
+                        self.golds[i],
+                        follow_gold=follow_gold
+                    )
+                c_scores += nr_state * nr_class
+                beam.advance(transition_state, NULL, <void*>self.moves.c)
+                beam.check_done(check_final_state, NULL)
+
+    cdef int _set_scores(self, Beam beam, const float* scores, int nr_class) except -1:
+        cdef int nr_state = 0
+        for i in range(beam.size):
+            state = <StateC*>beam.at(i)
+            if not state.is_final():
+                for j in range(nr_class):
+                    beam.scores[i][j] = scores[nr_state * nr_class + j]
+                self.moves.set_valid(beam.is_valid[i], state)
+                nr_state += 1
+            else:
+                for j in range(beam.nr_class):
+                    beam.scores[i][j] = 0
+                    beam.costs[i][j] = 0
+        return nr_state
+
+    def _set_costs(self, Beam beam, doc, gold, int follow_gold=False):
+        cdef const StateC* state
+        for i in range(beam.size):
+            state = <const StateC*>beam.at(i)
+            if state.is_final():
+                for j in range(beam.nr_class):
+                    beam.is_valid[i][j] = 0
+                    beam.costs[i][j] = 9000
+            else:
+                self.moves.set_costs(beam.is_valid[i], beam.costs[i],
+                                     state, gold)
+                if follow_gold:
+                    min_cost = 0
+                    for j in range(beam.nr_class):
+                        if beam.is_valid[i][j] and beam.costs[i][j] < min_cost:
+                            min_cost = beam.costs[i][j]
+                    for j in range(beam.nr_class):
+                        if beam.costs[i][j] > min_cost:
+                            beam.is_valid[i][j] = 0
+
+
+def update_beam(TransitionSystem moves, states, golds, model, int width, beam_density=0.0):
+    cdef MaxViolation violn
+    pbeam = BeamBatch(moves, states, golds, width=width, density=beam_density)
+    gbeam = BeamBatch(moves, states, golds, width=width, density=0.0)
+    cdef StateClass state
+    beam_maps = []
+    backprops = []
+    violns = [MaxViolation() for _ in range(len(states))]
+    dones = [False for _ in states]
+    while not pbeam.is_done or not gbeam.is_done:
+        # The beam maps let us find the right row in the flattened scores
+        # array for each state. States are identified by (example id,
+        # history). We keep a different beam map for each step (since we'll
+        # have a flat scores array for each step). The beam map will let us
+        # take the per-state losses, and compute the gradient for each (step,
+        # state, class).
+        # Gather all states from the two beams in a list. Some stats may occur
+        # in both beams. To figure out which beam each state belonged to,
+        # we keep two lists of indices, p_indices and g_indices
+        states, p_indices, g_indices, beam_map = get_unique_states(pbeam, gbeam)
+        beam_maps.append(beam_map)
+        if not states:
+            break
+        # Now that we have our flat list of states, feed them through the model
+        scores, bp_scores = model.begin_update(states)
+        assert scores.size != 0
+        # Store the callbacks for the backward pass
+        backprops.append(bp_scores)
+        # Unpack the scores for the two beams. The indices arrays
+        # tell us which example and state the scores-row refers to.
+        # Now advance the states in the beams. The gold beam is constrained to
+        # to follow only gold analyses.
+        if not pbeam.is_done:
+            pbeam.advance(model.ops.as_contig(scores[p_indices]))
+        if not gbeam.is_done:
+            gbeam.advance(model.ops.as_contig(scores[g_indices]), follow_gold=True)
+        # Track the "maximum violation", to use in the update.
+        for i, violn in enumerate(violns):
+            if not dones[i]:
+                violn.check_crf(pbeam[i], gbeam[i])
+                if pbeam[i].is_done and gbeam[i].is_done:
+                    dones[i] = True
+    histories = []
+    grads = []
+    for violn in violns:
+        if violn.p_hist:
+            histories.append(violn.p_hist + violn.g_hist)
+            d_loss = [d_l * violn.cost for d_l in violn.p_probs + violn.g_probs]
+            grads.append(d_loss)
+        else:
+            histories.append([])
+            grads.append([])
+    loss = 0.0
+    states_d_scores = get_gradient(moves.n_moves, beam_maps, histories, grads)
+    for i, (d_scores, bp_scores) in enumerate(zip(states_d_scores, backprops)):
+        loss += (d_scores**2).mean()
+        bp_scores(d_scores)
+    return loss
+
+
+def collect_states(beams, docs):
+    cdef StateClass state
+    cdef Beam beam
+    states = []
+    for state_or_beam, doc in zip(beams, docs):
+        if isinstance(state_or_beam, StateClass):
+            states.append(state_or_beam)
+        else:
+            beam = state_or_beam
+            state = StateClass.borrow(<StateC*>beam.at(0), doc)
+            states.append(state)
+    return states
+
+
+def get_unique_states(pbeams, gbeams):
+    seen = {}
+    states = []
+    p_indices = []
+    g_indices = []
+    beam_map = {}
+    docs = pbeams.docs
+    cdef Beam pbeam, gbeam
+    if len(pbeams) != len(gbeams):
+        raise ValueError(Errors.E079.format(pbeams=len(pbeams), gbeams=len(gbeams)))
+    for eg_id, (pbeam, gbeam, doc) in enumerate(zip(pbeams, gbeams, docs)):
+        if not pbeam.is_done:
+            for i in range(pbeam.size):
+                state = StateClass.borrow(<StateC*>pbeam.at(i), doc)
+                if not state.is_final():
+                    key = tuple([eg_id] + pbeam.histories[i])
+                    if key in seen:
+                        raise ValueError(Errors.E080.format(key=key))
+                    seen[key] = len(states)
+                    p_indices.append(len(states))
+                    states.append(state)
+            beam_map.update(seen)
+        if not gbeam.is_done:
+            for i in range(gbeam.size):
+                state = StateClass.borrow(<StateC*>gbeam.at(i), doc)
+                if not state.is_final():
+                    key = tuple([eg_id] + gbeam.histories[i])
+                    if key in seen:
+                        g_indices.append(seen[key])
+                    else:
+                        g_indices.append(len(states))
+                        beam_map[key] = len(states)
+                        states.append(state)
+    p_indices = numpy.asarray(p_indices, dtype='i')
+    g_indices = numpy.asarray(g_indices, dtype='i')
+    return states, p_indices, g_indices, beam_map
+
+
+def get_gradient(nr_class, beam_maps, histories, losses):
+    """The global model assigns a loss to each parse. The beam scores
+    are additive, so the same gradient is applied to each action
+    in the history. This gives the gradient of a single *action*
+    for a beam state -- so we have "the gradient of loss for taking
+    action i given history H."
+
+    Histories: Each hitory is a list of actions
+    Each candidate has a history
+    Each beam has multiple candidates
+    Each batch has multiple beams
+    So history is list of lists of lists of ints
+    """
+    grads = []
+    nr_steps = []
+    for eg_id, hists in enumerate(histories):
+        nr_step = 0
+        for loss, hist in zip(losses[eg_id], hists):
+            assert not numpy.isnan(loss)
+            if loss != 0.0:
+                nr_step = max(nr_step, len(hist))
+        nr_steps.append(nr_step)
+    for i in range(max(nr_steps)):
+        grads.append(numpy.zeros((max(beam_maps[i].values())+1, nr_class),
+                                 dtype='f'))
+    if len(histories) != len(losses):
+        raise ValueError(Errors.E081.format(n_hist=len(histories), losses=len(losses)))
+    for eg_id, hists in enumerate(histories):
+        for loss, hist in zip(losses[eg_id], hists):
+            assert not numpy.isnan(loss)
+            if loss == 0.0:
+                continue
+            key = tuple([eg_id])
+            # Adjust loss for length
+            # We need to do this because each state in a short path is scored
+            # multiple times, as we add in the average cost when we run out
+            # of actions.
+            avg_loss = loss / len(hist)
+            loss += avg_loss * (nr_steps[eg_id] - len(hist))
+            for step, clas in enumerate(hist):
+                i = beam_maps[step][key]
+                # In step j, at state i action clas
+                # resulted in loss
+                grads[step][i, clas] += loss
+                key = key + tuple([clas])
+    return grads
diff --git a/spacy/pipeline/_parser_internals/_state.pxd b/spacy/pipeline/_parser_internals/_state.pxd
index 0d0dd8c05..a6bf926f9 100644
--- a/spacy/pipeline/_parser_internals/_state.pxd
+++ b/spacy/pipeline/_parser_internals/_state.pxd
@@ -1,6 +1,9 @@
 from libc.string cimport memcpy, memset
 from libc.stdlib cimport calloc, free
 from libc.stdint cimport uint32_t, uint64_t
+cimport libcpp
+from libcpp.vector cimport vector
+from libcpp.set cimport set
 from cpython.exc cimport PyErr_CheckSignals, PyErr_SetFromErrno
 from murmurhash.mrmr cimport hash64
 
@@ -14,89 +17,48 @@ from ...typedefs cimport attr_t
 cdef inline bint is_space_token(const TokenC* token) nogil:
     return Lexeme.c_check_flag(token.lex, IS_SPACE)
 
-cdef struct RingBufferC:
-    int[8] data
-    int i
-    int default
-
-cdef inline int ring_push(RingBufferC* ring, int value) nogil:
-    ring.data[ring.i] = value
-    ring.i += 1
-    if ring.i >= 8:
-        ring.i = 0
-
-cdef inline int ring_get(RingBufferC* ring, int i) nogil:
-    if i >= ring.i:
-        return ring.default
-    else:
-        return ring.data[ring.i-i]
+cdef struct ArcC:
+    int head
+    int child
+    attr_t label
 
 
 cdef cppclass StateC:
-    int* _stack
-    int* _buffer
-    bint* shifted
-    TokenC* _sent
-    SpanC* _ents
+    int* _heads
+    const TokenC* _sent
+    vector[int] _stack
+    vector[int] _rebuffer
+    vector[SpanC] _ents
+    vector[ArcC] _left_arcs
+    vector[ArcC] _right_arcs
+    vector[libcpp.bool] _unshiftable
+    set[int] _sent_starts
     TokenC _empty_token
-    RingBufferC _hist
     int length
     int offset
-    int _s_i
     int _b_i
-    int _e_i
-    int _break
 
     __init__(const TokenC* sent, int length) nogil:
-        cdef int PADDING = 5
-        this._buffer = <int*>calloc(length + (PADDING * 2), sizeof(int))
-        this._stack = <int*>calloc(length + (PADDING * 2), sizeof(int))
-        this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint))
-        this._sent = <TokenC*>calloc(length + (PADDING * 2), sizeof(TokenC))
-        this._ents = <SpanC*>calloc(length + (PADDING * 2), sizeof(SpanC))
-        if not (this._buffer and this._stack and this.shifted
-                and this._sent and this._ents):
+        this._sent = sent
+        this._heads = <int*>calloc(length, sizeof(int))
+        if not (this._sent and this._heads):
             with gil:
                 PyErr_SetFromErrno(MemoryError)
                 PyErr_CheckSignals()
-        memset(&this._hist, 0, sizeof(this._hist))
         this.offset = 0
-        cdef int i
-        for i in range(length + (PADDING * 2)):
-            this._ents[i].end = -1
-            this._sent[i].l_edge = i
-            this._sent[i].r_edge = i
-        for i in range(PADDING):
-            this._sent[i].lex = &EMPTY_LEXEME
-        this._sent += PADDING
-        this._ents += PADDING
-        this._buffer += PADDING
-        this._stack += PADDING
-        this.shifted += PADDING
         this.length = length
-        this._break = -1
-        this._s_i = 0
         this._b_i = 0
-        this._e_i = 0
         for i in range(length):
-            this._buffer[i] = i
+            this._heads[i] = -1
+            this._unshiftable.push_back(0)
         memset(&this._empty_token, 0, sizeof(TokenC))
         this._empty_token.lex = &EMPTY_LEXEME
-        for i in range(length):
-            this._sent[i] = sent[i]
-            this._buffer[i] = i
-        for i in range(length, length+PADDING):
-            this._sent[i].lex = &EMPTY_LEXEME
 
     __dealloc__():
-        cdef int PADDING = 5
-        free(this._sent - PADDING)
-        free(this._ents - PADDING)
-        free(this._buffer - PADDING)
-        free(this._stack - PADDING)
-        free(this.shifted - PADDING)
+        free(this._heads)
 
     void set_context_tokens(int* ids, int n) nogil:
+        cdef int i, j
         if n == 1:
             if this.B(0) >= 0:
                 ids[0] = this.B(0)
@@ -145,22 +107,18 @@ cdef cppclass StateC:
             ids[11] = this.R(this.S(1), 1)
             ids[12] = this.R(this.S(1), 2)
         elif n == 6:
+            for i in range(6):
+                ids[i] = -1
             if this.B(0) >= 0:
                 ids[0] = this.B(0)
-                ids[1] = this.B(0)-1
-            else:
-                ids[0] = -1
-                ids[1] = -1
-            ids[2] = this.B(1)
-            ids[3] = this.E(0)
-            if ids[3] >= 1:
-                ids[4] = this.E(0)-1
-            else:
-                ids[4] = -1
-            if (ids[3]+1) < this.length:
-                ids[5] = this.E(0)+1
-            else:
-                ids[5] = -1
+            if this.entity_is_open():
+                ent = this.get_ent()
+                j = 1
+                for i in range(ent.start, this.B(0)):
+                    ids[j] = i
+                    j += 1
+                    if j >= 6:
+                        break
         else:
             # TODO error =/
             pass
@@ -171,329 +129,256 @@ cdef cppclass StateC:
                 ids[i] = -1
 
     int S(int i) nogil const:
-        if i >= this._s_i:
+        if i >= this._stack.size():
             return -1
-        return this._stack[this._s_i - (i+1)]
+        elif i < 0:
+            return -1
+        return this._stack.at(this._stack.size() - (i+1))
 
     int B(int i) nogil const:
-        if (i + this._b_i) >= this.length:
+        if i < 0:
             return -1
-        return this._buffer[this._b_i + i]
-
-    const TokenC* S_(int i) nogil const:
-        return this.safe_get(this.S(i))
+        elif i < this._rebuffer.size():
+            return this._rebuffer.at(this._rebuffer.size() - (i+1))
+        else:
+            b_i = this._b_i + (i - this._rebuffer.size())
+            if b_i >= this.length:
+                return -1
+            else:
+                return b_i
 
     const TokenC* B_(int i) nogil const:
         return this.safe_get(this.B(i))
 
-    const TokenC* H_(int i) nogil const:
-        return this.safe_get(this.H(i))
-
     const TokenC* E_(int i) nogil const:
         return this.safe_get(this.E(i))
 
-    const TokenC* L_(int i, int idx) nogil const:
-        return this.safe_get(this.L(i, idx))
-
-    const TokenC* R_(int i, int idx) nogil const:
-        return this.safe_get(this.R(i, idx))
-
     const TokenC* safe_get(int i) nogil const:
         if i < 0 or i >= this.length:
             return &this._empty_token
         else:
             return &this._sent[i]
 
-    int H(int i) nogil const:
-        if i < 0 or i >= this.length:
+    void get_arcs(vector[ArcC]* arcs) nogil const:
+        for i in range(this._left_arcs.size()):
+            arc = this._left_arcs.at(i)
+            if arc.head != -1 and arc.child != -1:
+                arcs.push_back(arc)
+        for i in range(this._right_arcs.size()):
+            arc = this._right_arcs.at(i)
+            if arc.head != -1 and arc.child != -1:
+                arcs.push_back(arc)
+
+    int H(int child) nogil const:
+        if child >= this.length or child < 0:
             return -1
-        return this._sent[i].head + i
+        else:
+            return this._heads[child]
 
     int E(int i) nogil const:
-        if this._e_i <= 0 or this._e_i >= this.length:
+        if this._ents.size() == 0:
             return -1
-        if i < 0 or i >= this._e_i:
-            return -1
-        return this._ents[this._e_i - (i+1)].start
+        else:
+            return this._ents.back().start
 
-    int L(int i, int idx) nogil const:
-        if idx < 1:
+    int L(int head, int idx) nogil const:
+        if idx < 1 or this._left_arcs.size() == 0:
             return -1
-        if i < 0 or i >= this.length:
+        cdef vector[int] lefts
+        for i in range(this._left_arcs.size()):
+            arc = this._left_arcs.at(i)
+            if arc.head == head and arc.child != -1 and arc.child < head:
+                lefts.push_back(arc.child)
+        idx = (<int>lefts.size()) - idx
+        if idx < 0:
             return -1
-        cdef const TokenC* target = &this._sent[i]
-        if target.l_kids < <uint32_t>idx:
-            return -1
-        cdef const TokenC* ptr = &this._sent[target.l_edge]
+        else:
+            return lefts.at(idx)
 
-        while ptr < target:
-            # If this head is still to the right of us, we can skip to it
-            # No token that's between this token and this head could be our
-            # child.
-            if (ptr.head >= 1) and (ptr + ptr.head) < target:
-                ptr += ptr.head
-
-            elif ptr + ptr.head == target:
-                idx -= 1
-                if idx == 0:
-                    return ptr - this._sent
-                ptr += 1
-            else:
-                ptr += 1
-        return -1
-
-    int R(int i, int idx) nogil const:
-        if idx < 1:
+    int R(int head, int idx) nogil const:
+        if idx < 1 or this._right_arcs.size() == 0:
             return -1
-        if i < 0 or i >= this.length:
+        cdef vector[int] rights
+        for i in range(this._right_arcs.size()):
+            arc = this._right_arcs.at(i)
+            if arc.head == head and arc.child != -1 and arc.child > head:
+                rights.push_back(arc.child)
+        idx = (<int>rights.size()) - idx
+        if idx < 0:
             return -1
-        cdef const TokenC* target = &this._sent[i]
-        if target.r_kids < <uint32_t>idx:
-            return -1
-        cdef const TokenC* ptr = &this._sent[target.r_edge]
-        while ptr > target:
-            # If this head is still to the right of us, we can skip to it
-            # No token that's between this token and this head could be our
-            # child.
-            if (ptr.head < 0) and ((ptr + ptr.head) > target):
-                ptr += ptr.head
-            elif ptr + ptr.head == target:
-                idx -= 1
-                if idx == 0:
-                    return ptr - this._sent
-                ptr -= 1
-            else:
-                ptr -= 1
-        return -1
+        else:
+            return rights.at(idx)
 
     bint empty() nogil const:
-        return this._s_i <= 0
+        return this._stack.size() == 0
 
     bint eol() nogil const:
         return this.buffer_length() == 0
 
-    bint at_break() nogil const:
-        return this._break != -1
-
     bint is_final() nogil const:
-        return this.stack_depth() <= 0 and this._b_i >= this.length
+        return this.stack_depth() <= 0 and this.eol()
 
-    bint has_head(int i) nogil const:
-        return this.safe_get(i).head != 0
+    int cannot_sent_start(int word) nogil const:
+        if word < 0 or word >= this.length:
+            return 0
+        elif this._sent[word].sent_start == -1:
+            return 1
+        else:
+            return 0
 
-    int n_L(int i) nogil const:
-        return this.safe_get(i).l_kids
+    int is_sent_start(int word) nogil const:
+        if word < 0 or word >= this.length:
+            return 0
+        elif this._sent[word].sent_start == 1:
+            return 1
+        elif this._sent_starts.count(word) >= 1:
+            return 1
+        else:
+            return 0
 
-    int n_R(int i) nogil const:
-        return this.safe_get(i).r_kids
+    void set_sent_start(int word, int value) nogil:
+        if value >= 1:
+            this._sent_starts.insert(word)
+
+    bint has_head(int child) nogil const:
+        return this._heads[child] >= 0
+
+    int l_edge(int word) nogil const:
+        return word
+
+    int r_edge(int word) nogil const:
+        return word
+ 
+    int n_L(int head) nogil const:
+        cdef int n = 0
+        for i in range(this._left_arcs.size()):
+            arc = this._left_arcs.at(i) 
+            if arc.head == head and arc.child != -1 and arc.child < arc.head:
+                n += 1
+        return n
+
+    int n_R(int head) nogil const:
+        cdef int n = 0
+        for i in range(this._right_arcs.size()):
+            arc = this._right_arcs.at(i) 
+            if arc.head == head and arc.child != -1 and arc.child > arc.head:
+                n += 1
+        return n
 
     bint stack_is_connected() nogil const:
         return False
 
     bint entity_is_open() nogil const:
-        if this._e_i < 1:
+        if this._ents.size() == 0:
             return False
-        return this._ents[this._e_i-1].end == -1
+        else:
+            return this._ents.back().end == -1
 
     int stack_depth() nogil const:
-        return this._s_i
+        return this._stack.size()
 
     int buffer_length() nogil const:
-        if this._break != -1:
-            return this._break - this._b_i
-        else:
-            return this.length - this._b_i
-
-    uint64_t hash() nogil const:
-        cdef TokenC[11] sig
-        sig[0] = this.S_(2)[0]
-        sig[1] = this.S_(1)[0]
-        sig[2] = this.R_(this.S(1), 1)[0]
-        sig[3] = this.L_(this.S(0), 1)[0]
-        sig[4] = this.L_(this.S(0), 2)[0]
-        sig[5] = this.S_(0)[0]
-        sig[6] = this.R_(this.S(0), 2)[0]
-        sig[7] = this.R_(this.S(0), 1)[0]
-        sig[8] = this.B_(0)[0]
-        sig[9] = this.E_(0)[0]
-        sig[10] = this.E_(1)[0]
-        return hash64(sig, sizeof(sig), this._s_i) \
-             + hash64(<void*>&this._hist, sizeof(RingBufferC), 1)
-
-    void push_hist(int act) nogil:
-        ring_push(&this._hist, act+1)
-
-    int get_hist(int i) nogil:
-        return ring_get(&this._hist, i)
+        return this.length - this._b_i
 
     void push() nogil:
-        if this.B(0) != -1:
-            this._stack[this._s_i] = this.B(0)
-        this._s_i += 1
-        this._b_i += 1
-        if this.safe_get(this.B_(0).l_edge).sent_start == 1:
-            this.set_break(this.B_(0).l_edge)
-        if this._b_i > this._break:
-            this._break = -1
+        b0 = this.B(0)
+        if this._rebuffer.size():
+            b0 = this._rebuffer.back()
+            this._rebuffer.pop_back()
+        else:
+            b0 = this._b_i
+            this._b_i += 1
+        this._stack.push_back(b0)
 
     void pop() nogil:
-        if this._s_i >= 1:
-            this._s_i -= 1
+        this._stack.pop_back()
 
     void force_final() nogil:
         # This should only be used in desperate situations, as it may leave
         # the analysis in an unexpected state.
-        this._s_i = 0
+        this._stack.clear()
         this._b_i = this.length
 
     void unshift() nogil:
-        this._b_i -= 1
-        this._buffer[this._b_i] = this.S(0)
-        this._s_i -= 1
-        this.shifted[this.B(0)] = True
+        s0 = this._stack.back()
+        this._unshiftable[s0] = 1
+        this._rebuffer.push_back(s0)
+        this._stack.pop_back()
+
+    int is_unshiftable(int item) nogil const:
+        if item >= this._unshiftable.size():
+            return 0
+        else:
+            return this._unshiftable.at(item)
+
+    void set_reshiftable(int item) nogil:
+        if item < this._unshiftable.size():
+            this._unshiftable[item] = 0
 
     void add_arc(int head, int child, attr_t label) nogil:
         if this.has_head(child):
             this.del_arc(this.H(child), child)
-
-        cdef int dist = head - child
-        this._sent[child].head = dist
-        this._sent[child].dep = label
-        cdef int i
-        if child > head:
-            this._sent[head].r_kids += 1
-            # Some transition systems can have a word in the buffer have a
-            # rightward child, e.g. from Unshift.
-            this._sent[head].r_edge = this._sent[child].r_edge
-            i = 0
-            while this.has_head(head) and i < this.length:
-                head = this.H(head)
-                this._sent[head].r_edge = this._sent[child].r_edge
-                i += 1 # Guard against infinite loops
+        cdef ArcC arc
+        arc.head = head
+        arc.child = child
+        arc.label = label
+        if head > child:
+            this._left_arcs.push_back(arc)
         else:
-            this._sent[head].l_kids += 1
-            this._sent[head].l_edge = this._sent[child].l_edge
+            this._right_arcs.push_back(arc)
+        this._heads[child] = head
 
     void del_arc(int h_i, int c_i) nogil:
-        cdef int dist = h_i - c_i
-        cdef TokenC* h = &this._sent[h_i]
-        cdef int i = 0
-        if c_i > h_i:
-            # this.R_(h_i, 2) returns the second-rightmost child token of h_i
-            # If we have more than 2 rightmost children, our 2nd rightmost child's
-            # rightmost edge is going to be our new rightmost edge.
-            h.r_edge = this.R_(h_i, 2).r_edge if h.r_kids >= 2 else h_i
-            h.r_kids -= 1
-            new_edge = h.r_edge
-            # Correct upwards in the tree --- see Issue #251
-            while h.head < 0 and i < this.length: # Guard infinite loop
-                h += h.head
-                h.r_edge = new_edge
-                i += 1
+        cdef vector[ArcC]* arcs
+        if h_i > c_i:
+            arcs = &this._left_arcs
         else:
-            # Same logic applies for left edge, but we don't need to walk up
-            # the tree, as the head is off the stack.
-            h.l_edge = this.L_(h_i, 2).l_edge if h.l_kids >= 2 else h_i
-            h.l_kids -= 1
+            arcs = &this._right_arcs
+        if arcs.size() == 0:
+            return
+        arc = arcs.back()
+        if arc.head == h_i and arc.child == c_i:
+            arcs.pop_back()
+        else:
+            for i in range(arcs.size()-1):
+                arc = arcs.at(i)
+                if arc.head == h_i and arc.child == c_i:
+                    arc.head = -1
+                    arc.child = -1
+                    arc.label = 0
+                    break
+
+    SpanC get_ent() nogil const:
+        cdef SpanC ent
+        if this._ents.size() == 0:
+            ent.start = 0
+            ent.end = 0
+            ent.label = 0
+            return ent
+        else:
+            return this._ents.back()
 
     void open_ent(attr_t label) nogil:
-        this._ents[this._e_i].start = this.B(0)
-        this._ents[this._e_i].label = label
-        this._ents[this._e_i].end = -1
-        this._e_i += 1
+        cdef SpanC ent
+        ent.start = this.B(0)
+        ent.label = label
+        ent.end = -1
+        this._ents.push_back(ent)
 
     void close_ent() nogil:
-        # Note that we don't decrement _e_i here! We want to maintain all
-        # entities, not over-write them...
-        this._ents[this._e_i-1].end = this.B(0)+1
-        this._sent[this.B(0)].ent_iob = 1
-
-    void set_ent_tag(int i, int ent_iob, attr_t ent_type) nogil:
-        if 0 <= i < this.length:
-            this._sent[i].ent_iob = ent_iob
-            this._sent[i].ent_type = ent_type
-
-    void set_break(int i) nogil:
-        if 0 <= i < this.length:
-            this._sent[i].sent_start = 1
-            this._break = this._b_i
+        this._ents.back().end = this.B(0)+1
 
     void clone(const StateC* src) nogil:
         this.length = src.length
-        memcpy(this._sent, src._sent, this.length * sizeof(TokenC))
-        memcpy(this._stack, src._stack, this.length * sizeof(int))
-        memcpy(this._buffer, src._buffer, this.length * sizeof(int))
-        memcpy(this._ents, src._ents, this.length * sizeof(SpanC))
-        memcpy(this.shifted, src.shifted, this.length * sizeof(this.shifted[0]))
+        this._sent = src._sent
+        this._stack = src._stack
+        this._rebuffer = src._rebuffer
+        this._sent_starts = src._sent_starts
+        this._unshiftable = src._unshiftable
+        memcpy(this._heads, src._heads, this.length * sizeof(this._heads[0]))
+        this._ents = src._ents
+        this._left_arcs = src._left_arcs
+        this._right_arcs = src._right_arcs
         this._b_i = src._b_i
-        this._s_i = src._s_i
-        this._e_i = src._e_i
-        this._break = src._break
         this.offset = src.offset
         this._empty_token = src._empty_token
-
-    void fast_forward() nogil:
-        # space token attachement policy:
-        # - attach space tokens always to the last preceding real token
-        # - except if it's the beginning of a sentence, then attach to the first following
-        # - boundary case: a document containing multiple space tokens but nothing else,
-        #   then make the last space token the head of all others
-
-        while is_space_token(this.B_(0)) \
-        or this.buffer_length() == 0 \
-        or this.stack_depth() == 0:
-            if this.buffer_length() == 0:
-                # remove the last sentence's root from the stack
-                if this.stack_depth() == 1:
-                    this.pop()
-                # parser got stuck: reduce stack or unshift
-                elif this.stack_depth() > 1:
-                    if this.has_head(this.S(0)):
-                        this.pop()
-                    else:
-                        this.unshift()
-                # stack is empty but there is another sentence on the buffer
-                elif (this.length - this._b_i) >= 1:
-                    this.push()
-                else: # stack empty and nothing else coming
-                    break
-
-            elif is_space_token(this.B_(0)):
-                # the normal case: we're somewhere inside a sentence
-                if this.stack_depth() > 0:
-                    # assert not is_space_token(this.S_(0))
-                    # attach all coming space tokens to their last preceding
-                    # real token (which should be on the top of the stack)
-                    while is_space_token(this.B_(0)):
-                        this.add_arc(this.S(0),this.B(0),0)
-                        this.push()
-                        this.pop()
-                # the rare case: we're at the beginning of a document:
-                # space tokens are attached to the first real token on the buffer
-                elif this.stack_depth() == 0:
-                    # store all space tokens on the stack until a real token shows up
-                    # or the last token on the buffer is reached
-                    while is_space_token(this.B_(0)) and this.buffer_length() > 1:
-                        this.push()
-                    # empty the stack by attaching all space tokens to the
-                    # first token on the buffer
-                    # boundary case: if all tokens are space tokens, the last one
-                    # becomes the head of all others
-                    while this.stack_depth() > 0:
-                        this.add_arc(this.B(0),this.S(0),0)
-                        this.pop()
-                    # move the first token onto the stack
-                    this.push()
-
-            elif this.stack_depth() == 0:
-                # for one token sentences (?)
-                if this.buffer_length() == 1:
-                    this.push()
-                    this.pop()
-                # with an empty stack and a non-empty buffer
-                # only shift is valid anyway
-                elif (this.length - this._b_i) >= 1:
-                    this.push()
-
-            else: # can this even happen?
-                break
diff --git a/spacy/pipeline/_parser_internals/arc_eager.pxd b/spacy/pipeline/_parser_internals/arc_eager.pxd
index e05a34f56..3732dd1b7 100644
--- a/spacy/pipeline/_parser_internals/arc_eager.pxd
+++ b/spacy/pipeline/_parser_internals/arc_eager.pxd
@@ -1,11 +1,7 @@
-from .stateclass cimport StateClass
+from ._state cimport StateC
 from ...typedefs cimport weight_t, attr_t
 from .transition_system cimport Transition, TransitionSystem
 
 
 cdef class ArcEager(TransitionSystem):
     pass
-
-
-cdef weight_t push_cost(StateClass stcls, const void* _gold, int target) nogil
-cdef weight_t arc_cost(StateClass stcls, const void* _gold, int head, int child) nogil
diff --git a/spacy/pipeline/_parser_internals/arc_eager.pyx b/spacy/pipeline/_parser_internals/arc_eager.pyx
index 69f015bda..cddb6cbd9 100644
--- a/spacy/pipeline/_parser_internals/arc_eager.pyx
+++ b/spacy/pipeline/_parser_internals/arc_eager.pyx
@@ -14,16 +14,11 @@ from ._state cimport StateC
 
 from ...errors import Errors
 
-# Calculate cost as gold/not gold. We don't use scalar value anyway.
-cdef int BINARY_COSTS = 1
 cdef weight_t MIN_SCORE = -90000
 cdef attr_t SUBTOK_LABEL = hash_string(u'subtok')
 
 DEF NON_MONOTONIC = True
-DEF USE_BREAK = True
 
-# Break transition from here
-# http://www.aclweb.org/anthology/P13-1074
 cdef enum:
     SHIFT
     REDUCE
@@ -61,9 +56,11 @@ cdef struct GoldParseStateC:
     int32_t* n_kids
     int32_t length
     int32_t stride
+    weight_t push_cost
+    weight_t pop_cost
 
 
-cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls,
+cdef GoldParseStateC create_gold_state(Pool mem, const StateC* state,
         heads, labels, sent_starts) except *:
     cdef GoldParseStateC gs
     gs.length = len(heads)
@@ -142,10 +139,12 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls,
             if head != i:
                 gs.kids[head][js[head]] = i
                 js[head] += 1
+    gs.push_cost = push_cost(state, &gs)
+    gs.pop_cost = pop_cost(state, &gs)
     return gs
 
 
-cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) nogil:
+cdef void update_gold_state(GoldParseStateC* gs, const StateC* s) nogil:
     for i in range(gs.length):
         gs.state_bits[i] = set_state_flag(
             gs.state_bits[i],
@@ -160,9 +159,9 @@ cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) nogil:
         gs.n_kids_in_stack[i] = 0
         gs.n_kids_in_buffer[i] = 0
 
-    for i in range(stcls.stack_depth()):
-        s_i = stcls.S(i)
-        if not is_head_unknown(gs, s_i):
+    for i in range(s.stack_depth()):
+        s_i = s.S(i)
+        if not is_head_unknown(gs, s_i) and gs.heads[s_i] != s_i:
             gs.n_kids_in_stack[gs.heads[s_i]] += 1
         for kid in gs.kids[s_i][:gs.n_kids[s_i]]:
             gs.state_bits[kid] = set_state_flag(
@@ -170,9 +169,11 @@ cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) nogil:
                 HEAD_IN_STACK,
                 1
             )
-    for i in range(stcls.buffer_length()):
-        b_i = stcls.B(i)
-        if not is_head_unknown(gs, b_i):
+    for i in range(s.buffer_length()):
+        b_i = s.B(i)
+        if s.is_sent_start(b_i):
+            break
+        if not is_head_unknown(gs, b_i) and gs.heads[b_i] != b_i:
             gs.n_kids_in_buffer[gs.heads[b_i]] += 1
         for kid in gs.kids[b_i][:gs.n_kids[b_i]]:
             gs.state_bits[kid] = set_state_flag(
@@ -180,6 +181,8 @@ cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) nogil:
                 HEAD_IN_BUFFER,
                 1
             )
+    gs.push_cost = push_cost(s, gs)
+    gs.pop_cost = pop_cost(s, gs)
 
 
 cdef class ArcEagerGold:
@@ -191,17 +194,17 @@ cdef class ArcEagerGold:
         heads, labels = example.get_aligned_parse(projectivize=True)
         labels = [label if label is not None else "" for label in labels]
         labels = [example.x.vocab.strings.add(label) for label in labels]
-        sent_starts = example.get_aligned("SENT_START")
-        assert len(heads) == len(labels) == len(sent_starts)
-        self.c = create_gold_state(self.mem, stcls, heads, labels, sent_starts)
+        sent_starts = example.get_aligned_sent_starts()
+        assert len(heads) == len(labels) == len(sent_starts), (len(heads), len(labels), len(sent_starts))
+        self.c = create_gold_state(self.mem, stcls.c, heads, labels, sent_starts)
 
     def update(self, StateClass stcls):
-        update_gold_state(&self.c, stcls)
+        update_gold_state(&self.c, stcls.c)
 
 
 cdef int check_state_gold(char state_bits, char flag) nogil:
     cdef char one = 1
-    return state_bits & (one << flag)
+    return 1 if (state_bits & (one << flag)) else 0
 
 
 cdef int set_state_flag(char state_bits, char flag, int value) nogil:
@@ -232,41 +235,30 @@ cdef int is_sent_start_unknown(const GoldParseStateC* gold, int i) nogil:
 
 # Helper functions for the arc-eager oracle
 
-cdef weight_t push_cost(StateClass stcls, const void* _gold, int target) nogil:
-    gold = <const GoldParseStateC*>_gold
+cdef weight_t push_cost(const StateC* state, const GoldParseStateC* gold) nogil:
     cdef weight_t cost = 0
-    if is_head_in_stack(gold, target):
+    b0 = state.B(0)
+    if b0 < 0:
+        return 9000
+    if is_head_in_stack(gold, b0):
         cost += 1
-    cost += gold.n_kids_in_stack[target]
-    if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0:
+    cost += gold.n_kids_in_stack[b0]
+    if Break.is_valid(state, 0) and is_sent_start(gold, state.B(1)):
         cost += 1
     return cost
 
 
-cdef weight_t pop_cost(StateClass stcls, const void* _gold, int target) nogil:
-    gold = <const GoldParseStateC*>_gold
+cdef weight_t pop_cost(const StateC* state, const GoldParseStateC* gold) nogil:
     cdef weight_t cost = 0
-    if is_head_in_buffer(gold, target):
-        cost += 1
-    cost += gold[0].n_kids_in_buffer[target]
-    if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0:
+    s0 = state.S(0)
+    if s0 < 0:
+        return 9000
+    if is_head_in_buffer(gold, s0):
         cost += 1
+    cost += gold.n_kids_in_buffer[s0]
     return cost
 
 
-cdef weight_t arc_cost(StateClass stcls, const void* _gold, int head, int child) nogil:
-    gold = <const GoldParseStateC*>_gold
-    if arc_is_gold(gold, head, child):
-        return 0
-    elif stcls.H(child) == gold.heads[child]:
-        return 1
-    # Head in buffer
-    elif is_head_in_buffer(gold, child):
-        return 1
-    else:
-        return 0
-
-
 cdef bint arc_is_gold(const GoldParseStateC* gold, int head, int child) nogil:
     if is_head_unknown(gold, child):
         return True
@@ -276,7 +268,7 @@ cdef bint arc_is_gold(const GoldParseStateC* gold, int head, int child) nogil:
         return False
 
 
-cdef bint label_is_gold(const GoldParseStateC* gold, int head, int child, attr_t label) nogil:
+cdef bint label_is_gold(const GoldParseStateC* gold, int child, attr_t label) nogil:
     if is_head_unknown(gold, child):
         return True
     elif label == 0:
@@ -292,218 +284,251 @@ cdef bint _is_gold_root(const GoldParseStateC* gold, int word) nogil:
 
 
 cdef class Shift:
+    """Move the first word of the buffer onto the stack and mark it as "shifted"
+
+    Validity:
+    * If stack is empty
+    * At least two words in sentence
+    * Word has not been shifted before
+
+    Cost: push_cost 
+
+    Action:
+    * Mark B[0] as 'shifted'
+    * Push stack
+    * Advance buffer
+    """
     @staticmethod
     cdef bint is_valid(const StateC* st, attr_t label) nogil:
-        sent_start = st._sent[st.B_(0).l_edge].sent_start
-        return st.buffer_length() >= 2 and not st.shifted[st.B(0)] and sent_start != 1
+        if st.stack_depth() == 0:
+            return 1
+        elif st.buffer_length() < 2:
+            return 0
+        elif st.is_sent_start(st.B(0)):
+            return 0
+        elif st.is_unshiftable(st.B(0)):
+            return 0
+        else:
+            return 1
 
     @staticmethod
     cdef int transition(StateC* st, attr_t label) nogil:
         st.push()
-        st.fast_forward()
 
     @staticmethod
-    cdef weight_t cost(StateClass st, const void* _gold, attr_t label) nogil:
+    cdef weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil:
         gold = <const GoldParseStateC*>_gold
-        return Shift.move_cost(st, gold) + Shift.label_cost(st, gold, label)
-
-    @staticmethod
-    cdef inline weight_t move_cost(StateClass s, const void* _gold) nogil:
-        gold = <const GoldParseStateC*>_gold
-        return push_cost(s, gold, s.B(0))
-
-    @staticmethod
-    cdef inline weight_t label_cost(StateClass s, const void* _gold, attr_t label) nogil:
-        return 0
+        return gold.push_cost
 
 
 cdef class Reduce:
+    """
+    Pop from the stack. If it has no head and the stack isn't empty, place
+    it back on the buffer.
+
+    Validity:
+    * Stack not empty
+    * Buffer nt empty
+    * Stack depth 1 and cannot sent start l_edge(st.B(0))
+
+    Cost:
+    * If B[0] is the start of a sentence, cost is 0
+    * Arcs between stack and buffer
+    * If arc has no head, we're saving arcs between S[0] and S[1:], so decrement
+        cost by those arcs.
+    """
     @staticmethod
     cdef bint is_valid(const StateC* st, attr_t label) nogil:
-        return st.stack_depth() >= 2
-
-    @staticmethod
-    cdef int transition(StateC* st, attr_t label) nogil:
-        if st.has_head(st.S(0)):
-            st.pop()
-        else:
-            st.unshift()
-        st.fast_forward()
-
-    @staticmethod
-    cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
-        gold = <const GoldParseStateC*>_gold
-        return Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label)
-
-    @staticmethod
-    cdef inline weight_t move_cost(StateClass st, const void* _gold) nogil:
-        gold = <const GoldParseStateC*>_gold
-        s0 = st.S(0)
-        cost = pop_cost(st, gold, s0)
-        return_to_buffer = not st.has_head(s0)
-        if return_to_buffer:
-            # Decrement cost for the arcs we save, as we'll be putting this
-            # back to the buffer
-            if is_head_in_stack(gold, s0):
-                cost -= 1
-            cost -= gold.n_kids_in_stack[s0]
-            if Break.is_valid(st.c, 0) and Break.move_cost(st, gold) == 0:
-                cost -= 1
-        return cost
-
-    @staticmethod
-    cdef inline weight_t label_cost(StateClass s, const void* gold, attr_t label) nogil:
-        return 0
-
-
-cdef class LeftArc:
-    @staticmethod
-    cdef bint is_valid(const StateC* st, attr_t label) nogil:
-        if label == SUBTOK_LABEL and st.S(0) != (st.B(0)-1):
-            return 0
-        sent_start = st._sent[st.B_(0).l_edge].sent_start
-        return sent_start != 1
-
-    @staticmethod
-    cdef int transition(StateC* st, attr_t label) nogil:
-        st.add_arc(st.B(0), st.S(0), label)
-        st.pop()
-        st.fast_forward()
-
-    @staticmethod
-    cdef inline weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
-        gold = <const GoldParseStateC*>_gold
-        return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label)
-
-    @staticmethod
-    cdef inline weight_t move_cost(StateClass s, const GoldParseStateC* gold) nogil:
-        cdef weight_t cost = 0
-        s0 = s.S(0)
-        b0 = s.B(0)
-        if arc_is_gold(gold, b0, s0):
-            # Have a negative cost if we 'recover' from the wrong dependency
-            return 0 if not s.has_head(s0) else -1
-        else:
-            # Account for deps we might lose between S0 and stack
-            if not s.has_head(s0):
-                cost += gold.n_kids_in_stack[s0]
-                if is_head_in_buffer(gold, s0):
-                    cost += 1
-            return cost + pop_cost(s, gold, s.S(0)) + arc_cost(s, gold, s.B(0), s.S(0))
-
-    @staticmethod
-    cdef inline weight_t label_cost(StateClass s, const GoldParseStateC* gold, attr_t label) nogil:
-        return arc_is_gold(gold, s.B(0), s.S(0)) and not label_is_gold(gold, s.B(0), s.S(0), label)
-
-
-cdef class RightArc:
-    @staticmethod
-    cdef bint is_valid(const StateC* st, attr_t label) nogil:
-        # If there's (perhaps partial) parse pre-set, don't allow cycle.
-        if label == SUBTOK_LABEL and st.S(0) != (st.B(0)-1):
-            return 0
-        sent_start = st._sent[st.B_(0).l_edge].sent_start
-        return sent_start != 1 and st.H(st.S(0)) != st.B(0)
-
-    @staticmethod
-    cdef int transition(StateC* st, attr_t label) nogil:
-        st.add_arc(st.S(0), st.B(0), label)
-        st.push()
-        st.fast_forward()
-
-    @staticmethod
-    cdef inline weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
-        gold = <const GoldParseStateC*>_gold
-        return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label)
-
-    @staticmethod
-    cdef inline weight_t move_cost(StateClass s, const void* _gold) nogil:
-        gold = <const GoldParseStateC*>_gold
-        if arc_is_gold(gold, s.S(0), s.B(0)):
-            return 0
-        elif s.c.shifted[s.B(0)]:
-            return push_cost(s, gold, s.B(0))
-        else:
-            return push_cost(s, gold, s.B(0)) + arc_cost(s, gold, s.S(0), s.B(0))
-
-    @staticmethod
-    cdef weight_t label_cost(StateClass s, const void* _gold, attr_t label) nogil:
-        gold = <const GoldParseStateC*>_gold
-        return arc_is_gold(gold, s.S(0), s.B(0)) and not label_is_gold(gold, s.S(0), s.B(0), label)
-
-
-cdef class Break:
-    @staticmethod
-    cdef bint is_valid(const StateC* st, attr_t label) nogil:
-        cdef int i
-        if not USE_BREAK:
+        if st.stack_depth() == 0:
             return False
-        elif st.at_break():
-            return False
-        elif st.stack_depth() < 1:
-            return False
-        elif st.B_(0).l_edge < 0:
-            return False
-        elif st._sent[st.B_(0).l_edge].sent_start < 0:
+        elif st.buffer_length() == 0:
+            return True
+        elif st.stack_depth() == 1 and st.cannot_sent_start(st.l_edge(st.B(0))):
             return False
         else:
             return True
 
     @staticmethod
     cdef int transition(StateC* st, attr_t label) nogil:
-        st.set_break(st.B_(0).l_edge)
-        st.fast_forward()
-
-    @staticmethod
-    cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
-        gold = <const GoldParseStateC*>_gold
-        return Break.move_cost(s, gold) + Break.label_cost(s, gold, label)
-
-    @staticmethod
-    cdef inline weight_t move_cost(StateClass s, const void* _gold) nogil:
-        gold = <const GoldParseStateC*>_gold
-        cost = 0
-        for i in range(s.stack_depth()):
-            S_i = s.S(i)
-            cost += gold.n_kids_in_buffer[S_i]
-            if is_head_in_buffer(gold, S_i):
-                cost += 1
-        # It's weird not to check the gold sentence boundaries but if we do,
-        # we can't account for "sunk costs", i.e. situations where we're already
-        # wrong.
-        s0_root = _get_root(s.S(0), gold)
-        b0_root = _get_root(s.B(0), gold)
-        if s0_root != b0_root or s0_root == -1 or b0_root == -1:
-            return cost
+        if st.has_head(st.S(0)) or st.stack_depth() == 1:
+            st.pop()
         else:
-            return cost + 1
+            st.unshift()
 
     @staticmethod
-    cdef inline weight_t label_cost(StateClass s, const void* gold, attr_t label) nogil:
-        return 0
+    cdef weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil:
+        gold = <const GoldParseStateC*>_gold
+        if state.is_sent_start(state.B(0)):
+            return 0
+        s0 = state.S(0)
+        cost = gold.pop_cost
+        if not state.has_head(s0):
+            # Decrement cost for the arcs we save, as we'll be putting this
+            # back to the buffer
+            if is_head_in_stack(gold, s0):
+                cost -= 1
+            cost -= gold.n_kids_in_stack[s0]
+        return cost
 
-cdef int _get_root(int word, const GoldParseStateC* gold) nogil:
-    if is_head_unknown(gold, word):
-        return -1
-    while gold.heads[word] != word and word >= 0:
-        word = gold.heads[word]
-        if is_head_unknown(gold, word):
-            return -1
-    else:
-        return word
+
+cdef class LeftArc:
+    """Add an arc between B[0] and S[0], replacing the previous head of S[0] if
+    one is set. Pop S[0] from the stack.
+
+    Validity:
+    * len(S) >= 1
+    * len(B) >= 1
+    * not is_sent_start(B[0])
+
+    Cost:
+        pop_cost - Arc(B[0], S[0], label) + (Arc(S[1], S[0]) if H(S[0]) else Arcs(S, S[0]))
+    """
+    @staticmethod
+    cdef bint is_valid(const StateC* st, attr_t label) nogil:
+        if st.stack_depth() == 0:
+            return 0
+        elif st.buffer_length() == 0:
+            return 0
+        elif st.is_sent_start(st.B(0)):
+            return 0
+        elif label == SUBTOK_LABEL and st.S(0) != (st.B(0)-1):
+            return 0
+        else:
+            return 1
+
+    @staticmethod
+    cdef int transition(StateC* st, attr_t label) nogil:
+        st.add_arc(st.B(0), st.S(0), label)
+        # If we change the stack, it's okay to remove the shifted mark, as
+        # we can't get in an infinite loop this way.
+        st.set_reshiftable(st.B(0))
+        st.pop()
+
+    @staticmethod
+    cdef inline weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil:
+        gold = <const GoldParseStateC*>_gold
+        cdef weight_t cost = gold.pop_cost
+        s0 = state.S(0)
+        s1 = state.S(1)
+        b0 = state.B(0)
+        if state.has_head(s0):
+            # Increment cost if we're clobbering a correct arc
+            cost += gold.heads[s0] == s1
+        else:
+            # If there's no head, we're losing arcs between S0 and S[1:].
+            cost += is_head_in_stack(gold, s0)
+            cost += gold.n_kids_in_stack[s0]
+        if b0 != -1 and s0 != -1 and gold.heads[s0] == b0:
+            cost -= 1
+            cost += not label_is_gold(gold, s0, label)
+        return cost
+
+
+cdef class RightArc:
+    """
+    Add an arc from S[0] to B[0]. Push B[0].
+
+    Validity:
+    * len(S) >= 1
+    * len(B) >= 1
+    * not is_sent_start(B[0])
+
+    Cost:
+        push_cost + (not shifted[b0] and Arc(B[1:], B[0])) - Arc(S[0], B[0], label)
+    """
+    @staticmethod
+    cdef bint is_valid(const StateC* st, attr_t label) nogil:
+        if st.stack_depth() == 0:
+            return 0
+        elif st.buffer_length() == 0:
+            return 0
+        elif st.is_sent_start(st.B(0)):
+            return 0
+        elif label == SUBTOK_LABEL and st.S(0) != (st.B(0)-1):
+            # If there's (perhaps partial) parse pre-set, don't allow cycle.
+            return 0
+        else:
+            return 1
+
+    @staticmethod
+    cdef int transition(StateC* st, attr_t label) nogil:
+        st.add_arc(st.S(0), st.B(0), label)
+        st.push()
+
+    @staticmethod
+    cdef inline weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil:
+        gold = <const GoldParseStateC*>_gold
+        cost = gold.push_cost
+        s0 = state.S(0)
+        b0 = state.B(0)
+        if s0 != -1 and b0 != -1 and gold.heads[b0] == s0:
+            cost -= 1
+            cost += not label_is_gold(gold, b0, label)
+        elif is_head_in_buffer(gold, b0) and not state.is_unshiftable(b0):
+            cost += 1
+        return cost
+
+
+cdef class Break:
+    """Mark the second word of the buffer as the start of a 
+    sentence. 
+
+    Validity:
+    * len(buffer) >= 2
+    * B[1] == B[0] + 1
+    * not is_sent_start(B[1])
+    * not cannot_sent_start(B[1])
+
+    Action:
+    * mark_sent_start(B[1])
+
+    Cost:
+    * not is_sent_start(B[1])
+    * Arcs between B[0] and B[1:]
+    * Arcs between S and B[1]
+    """
+    @staticmethod
+    cdef bint is_valid(const StateC* st, attr_t label) nogil:
+        cdef int i
+        if st.buffer_length() < 2:
+            return False
+        elif st.B(1) != st.B(0) + 1:
+            return False
+        elif st.is_sent_start(st.B(1)):
+            return False
+        elif st.cannot_sent_start(st.B(1)):
+            return False
+        else:
+            return True
+
+    @staticmethod
+    cdef int transition(StateC* st, attr_t label) nogil:
+        st.set_sent_start(st.B(1), 1)
+
+    @staticmethod
+    cdef weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil:
+        gold = <const GoldParseStateC*>_gold
+        cdef int b0 = state.B(0)
+        cdef int cost = 0
+        cdef int si
+        for i in range(state.stack_depth()):
+            si = state.S(i)
+            if is_head_in_buffer(gold, si):
+                cost += 1
+            cost += gold.n_kids_in_buffer[si]
+            # We need to score into B[1:], so subtract deps that are at b0
+            if gold.heads[b0] == si:
+                cost -= 1
+            if gold.heads[si] == b0:
+                cost -= 1
+        if not is_sent_start(gold, state.B(1)) \
+        and not is_sent_start_unknown(gold, state.B(1)):
+            cost += 1
+        return cost
 
 
 cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
     st = new StateC(<const TokenC*>tokens, length)
-    for i in range(st.length):
-        if st._sent[i].dep == 0:
-            st._sent[i].l_edge = i
-            st._sent[i].r_edge = i
-            st._sent[i].head = 0
-            st._sent[i].dep = 0
-            st._sent[i].l_kids = 0
-            st._sent[i].r_kids = 0
-    st.fast_forward()
     return <void*>st
 
 
@@ -515,6 +540,8 @@ cdef int _del_state(Pool mem, void* state, void* x) except -1:
 cdef class ArcEager(TransitionSystem):
     def __init__(self, *args, **kwargs):
         TransitionSystem.__init__(self, *args, **kwargs)
+        self.init_beam_state = _init_state
+        self.del_beam_state = _del_state
 
     @classmethod
     def get_actions(cls, **kwargs):
@@ -537,7 +564,7 @@ cdef class ArcEager(TransitionSystem):
                     label = 'ROOT'
                 if head == child:
                     actions[BREAK][label] += 1
-                elif head < child:
+                if head < child:
                     actions[RIGHT][label] += 1
                     actions[REDUCE][''] += 1
                 elif head > child:
@@ -567,8 +594,14 @@ cdef class ArcEager(TransitionSystem):
         t.do(state.c, t.label)
         return state
 
-    def is_gold_parse(self, StateClass state, gold):
-        raise NotImplementedError
+    def is_gold_parse(self, StateClass state, ArcEagerGold gold):
+        for i in range(state.c.length):
+            token = state.c.safe_get(i)
+            if not arc_is_gold(&gold.c, i, i+token.head):
+                return False
+            elif not label_is_gold(&gold.c, i, token.dep):
+                return False
+        return True
 
     def init_gold(self, StateClass state, Example example):
         gold = ArcEagerGold(self, state, example)
@@ -576,6 +609,7 @@ cdef class ArcEager(TransitionSystem):
         return gold
 
     def init_gold_batch(self, examples):
+        # TODO: Projectivitity?
         all_states = self.init_batch([eg.predicted for eg in examples])
         golds = []
         states = []
@@ -662,24 +696,13 @@ cdef class ArcEager(TransitionSystem):
             raise ValueError(Errors.E019.format(action=move, src='arc_eager'))
         return t
 
-    cdef int initialize_state(self, StateC* st) nogil:
-        for i in range(st.length):
-            if st._sent[i].dep == 0:
-                st._sent[i].l_edge = i
-                st._sent[i].r_edge = i
-                st._sent[i].head = 0
-                st._sent[i].dep = 0
-                st._sent[i].l_kids = 0
-                st._sent[i].r_kids = 0
-        st.fast_forward()
-
-    cdef int finalize_state(self, StateC* st) nogil:
-        cdef int i
-        for i in range(st.length):
-            if st._sent[i].head == 0:
-                st._sent[i].dep = self.root_label
-
-    def finalize_doc(self, Doc doc):
+    def set_annotations(self, StateClass state, Doc doc):
+        for arc in state.arcs:
+            doc.c[arc["child"]].head = arc["head"] - arc["child"]
+            doc.c[arc["child"]].dep = arc["label"]
+        for i in range(doc.length):
+            if doc.c[i].head == 0:
+                doc.c[i].dep = self.root_label
         set_children_from_heads(doc.c, 0, doc.length)
 
     def has_gold(self, Example eg, start=0, end=None):
@@ -690,7 +713,7 @@ cdef class ArcEager(TransitionSystem):
             return False
 
     cdef int set_valid(self, int* output, const StateC* st) nogil:
-        cdef bint[N_MOVES] is_valid
+        cdef int[N_MOVES] is_valid
         is_valid[SHIFT] = Shift.is_valid(st, 0)
         is_valid[REDUCE] = Reduce.is_valid(st, 0)
         is_valid[LEFT] = LeftArc.is_valid(st, 0)
@@ -710,29 +733,31 @@ cdef class ArcEager(TransitionSystem):
         gold_state = gold_.c
         n_gold = 0
         if self.c[i].is_valid(stcls.c, self.c[i].label):
-            cost = self.c[i].get_cost(stcls, &gold_state, self.c[i].label)
+            cost = self.c[i].get_cost(stcls.c, &gold_state, self.c[i].label)
         else:
             cost = 9000
         return cost
 
     cdef int set_costs(self, int* is_valid, weight_t* costs,
-                       StateClass stcls, gold) except -1:
+                       const StateC* state, gold) except -1:
         if not isinstance(gold, ArcEagerGold):
             raise TypeError(Errors.E909.format(name="ArcEagerGold"))
         cdef ArcEagerGold gold_ = gold
-        gold_.update(stcls)
         gold_state = gold_.c
+        update_gold_state(&gold_state, state)
+        self.set_valid(is_valid, state)
         cdef int n_gold = 0
         for i in range(self.n_moves):
-            if self.c[i].is_valid(stcls.c, self.c[i].label):
-                is_valid[i] = True
-                costs[i] = self.c[i].get_cost(stcls, &gold_state, self.c[i].label)
+            if is_valid[i]:
+                costs[i] = self.c[i].get_cost(state, &gold_state, self.c[i].label)
                 if costs[i] <= 0:
                     n_gold += 1
             else:
-                is_valid[i] = False
                 costs[i] = 9000
         if n_gold < 1:
+            for i in range(self.n_moves):
+                print(self.get_class_name(i), is_valid[i], costs[i])
+            print("Gold sent starts?", is_sent_start(&gold_state, state.B(0)), is_sent_start(&gold_state, state.B(1)))
             raise ValueError
 
     def get_oracle_sequence_from_state(self, StateClass state, ArcEagerGold gold, _debug=None):
@@ -748,12 +773,13 @@ cdef class ArcEager(TransitionSystem):
         failed = False
         while not state.is_final():
             try:
-                self.set_costs(is_valid, costs, state, gold)
+                self.set_costs(is_valid, costs, state.c, gold)
             except ValueError:
                 failed = True
                 break
+            min_cost = min(costs[i] for i in range(self.n_moves))
             for i in range(self.n_moves):
-                if is_valid[i] and costs[i] <= 0:
+                if is_valid[i] and costs[i] <= min_cost:
                     action = self.c[i]
                     history.append(i)
                     s0 = state.S(0)
@@ -762,9 +788,7 @@ cdef class ArcEager(TransitionSystem):
                         example = _debug
                         debug_log.append(" ".join((
                             self.get_class_name(i),
-                            "S0=", (example.x[s0].text if s0 >= 0 else "__"),
-                            "B0=", (example.x[b0].text if b0 >= 0 else "__"),
-                            "S0 head?", str(state.has_head(state.S(0))),
+                            state.print_state()
                         )))
                     action.do(state.c, action.label)
                     break
@@ -783,6 +807,8 @@ cdef class ArcEager(TransitionSystem):
             print("Aligned heads")
             for i, head in enumerate(aligned_heads):
                 print(example.x[i], example.x[head] if head is not None else "__")
+            print("Aligned sent starts")
+            print(example.get_aligned_sent_starts())
 
             print("Predicted tokens")
             print([(w.i, w.text) for w in example.x])
diff --git a/spacy/pipeline/_parser_internals/ner.pyx b/spacy/pipeline/_parser_internals/ner.pyx
index 4f142caaf..7f4d332db 100644
--- a/spacy/pipeline/_parser_internals/ner.pyx
+++ b/spacy/pipeline/_parser_internals/ner.pyx
@@ -3,9 +3,12 @@ from cymem.cymem cimport Pool
 
 from collections import Counter
 
+from ...tokens.doc cimport Doc
+from ...tokens.span import Span
 from ...typedefs cimport weight_t, attr_t
 from ...lexeme cimport Lexeme
 from ...attrs cimport IS_SPACE
+from ...structs cimport TokenC
 from ...training.example cimport Example
 from .stateclass cimport StateClass
 from ._state cimport StateC
@@ -46,17 +49,17 @@ cdef class BiluoGold:
 
     def __init__(self, BiluoPushDown moves, StateClass stcls, Example example):
         self.mem = Pool()
-        self.c = create_gold_state(self.mem, moves, stcls, example)
+        self.c = create_gold_state(self.mem, moves, stcls.c, example)
 
     def update(self, StateClass stcls):
-        update_gold_state(&self.c, stcls)
+        update_gold_state(&self.c, stcls.c)
 
 
 
 cdef GoldNERStateC create_gold_state(
     Pool mem,
     BiluoPushDown moves,
-    StateClass stcls,
+    const StateC* stcls,
     Example example
 ) except *:
     cdef GoldNERStateC gs
@@ -67,7 +70,7 @@ cdef GoldNERStateC create_gold_state(
     return gs
 
 
-cdef void update_gold_state(GoldNERStateC* gs, StateClass stcls) except *:
+cdef void update_gold_state(GoldNERStateC* gs, const StateC* state) except *:
     # We don't need to update each time, unlike the parser.
     pass
 
@@ -75,14 +78,15 @@ cdef void update_gold_state(GoldNERStateC* gs, StateClass stcls) except *:
 cdef do_func_t[N_MOVES] do_funcs
 
 
-cdef bint _entity_is_sunk(StateClass st, Transition* golds) nogil:
-    if not st.entity_is_open():
+cdef bint _entity_is_sunk(const StateC* state, Transition* golds) nogil:
+    if not state.entity_is_open():
         return False
 
-    cdef const Transition* gold = &golds[st.E(0)]
+    cdef const Transition* gold = &golds[state.E(0)]
+    ent = state.get_ent()
     if gold.move != BEGIN and gold.move != UNIT:
         return True
-    elif gold.label != st.E_(0).ent_type:
+    elif gold.label != ent.label:
         return True
     else:
         return False
@@ -228,15 +232,18 @@ cdef class BiluoPushDown(TransitionSystem):
             self.labels[action][label_name] = -1
         return 1
 
-    cdef int initialize_state(self, StateC* st) nogil:
-        # This is especially necessary when we use limited training data.
-        for i in range(st.length):
-            if st._sent[i].ent_type != 0:
-                with gil:
-                    self.add_action(BEGIN, st._sent[i].ent_type)
-                    self.add_action(IN, st._sent[i].ent_type)
-                    self.add_action(UNIT, st._sent[i].ent_type)
-                    self.add_action(LAST, st._sent[i].ent_type)
+    def set_annotations(self, StateClass state, Doc doc):
+        cdef int i
+        ents = []
+        for i in range(state.c._ents.size()):
+            ent = state.c._ents.at(i)
+            if ent.start != -1 and ent.end != -1:
+                ents.append(Span(doc, ent.start, ent.end, label=ent.label))
+        doc.set_ents(ents, default="unmodified")
+        # Set non-blocked tokens to O
+        for i in range(doc.length):
+            if doc.c[i].ent_iob == 0:
+                doc.c[i].ent_iob = 2
 
     def init_gold(self, StateClass state, Example example):
         return BiluoGold(self, state, example)
@@ -255,26 +262,25 @@ cdef class BiluoPushDown(TransitionSystem):
         gold_state = gold_.c
         n_gold = 0
         if self.c[i].is_valid(stcls.c, self.c[i].label):
-            cost = self.c[i].get_cost(stcls, &gold_state, self.c[i].label)
+            cost = self.c[i].get_cost(stcls.c, &gold_state, self.c[i].label)
         else:
             cost = 9000
         return cost
 
     cdef int set_costs(self, int* is_valid, weight_t* costs,
-                       StateClass stcls, gold) except -1:
+                       const StateC* state, gold) except -1:
         if not isinstance(gold, BiluoGold):
             raise TypeError(Errors.E909.format(name="BiluoGold"))
         cdef BiluoGold gold_ = gold
-        gold_.update(stcls)
         gold_state = gold_.c
+        update_gold_state(&gold_state, state)
         n_gold = 0
+        self.set_valid(is_valid, state)
         for i in range(self.n_moves):
-            if self.c[i].is_valid(stcls.c, self.c[i].label):
-                is_valid[i] = 1
-                costs[i] = self.c[i].get_cost(stcls, &gold_state, self.c[i].label)
+            if is_valid[i]:
+                costs[i] = self.c[i].get_cost(state, &gold_state, self.c[i].label)
                 n_gold += costs[i] <= 0
             else:
-                is_valid[i] = 0
                 costs[i] = 9000
         if n_gold < 1:
             raise ValueError
@@ -290,7 +296,7 @@ cdef class Missing:
         pass
 
     @staticmethod
-    cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
+    cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
         return 9000
 
 
@@ -299,10 +305,10 @@ cdef class Begin:
     cdef bint is_valid(const StateC* st, attr_t label) nogil:
         cdef int preset_ent_iob = st.B_(0).ent_iob
         cdef attr_t preset_ent_label = st.B_(0).ent_type
-        # If we're the last token of the input, we can't B -- must U or O.
-        if st.B(1) == -1:
+        if st.entity_is_open():
             return False
-        elif st.entity_is_open():
+        if st.buffer_length() < 2:
+            # If we're the last token of the input, we can't B -- must U or O.
             return False
         elif label == 0:
             return False
@@ -337,12 +343,11 @@ cdef class Begin:
     @staticmethod
     cdef int transition(StateC* st, attr_t label) nogil:
         st.open_ent(label)
-        st.set_ent_tag(st.B(0), 3, label)
         st.push()
         st.pop()
 
     @staticmethod
-    cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
+    cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
         gold = <GoldNERStateC*>_gold
         cdef int g_act = gold.ner[s.B(0)].move
         cdef attr_t g_tag = gold.ner[s.B(0)].label
@@ -366,16 +371,17 @@ cdef class Begin:
 cdef class In:
     @staticmethod
     cdef bint is_valid(const StateC* st, attr_t label) nogil:
+        if not st.entity_is_open():
+            return False
+        if st.buffer_length() < 2:
+            # If we're at the end, we can't I.
+            return False
+        ent = st.get_ent()
         cdef int preset_ent_iob = st.B_(0).ent_iob
         cdef attr_t preset_ent_label = st.B_(0).ent_type
         if label == 0:
             return False
-        elif st.E_(0).ent_type != label:
-            return False
-        elif not st.entity_is_open():
-            return False
-        elif st.B(1) == -1:
-            # If we're at the end, we can't I.
+        elif ent.label != label:
             return False
         elif preset_ent_iob == 3:
             return False
@@ -401,12 +407,11 @@ cdef class In:
 
     @staticmethod
     cdef int transition(StateC* st, attr_t label) nogil:
-        st.set_ent_tag(st.B(0), 1, label)
         st.push()
         st.pop()
 
     @staticmethod
-    cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
+    cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
         gold = <GoldNERStateC*>_gold
         move = IN
         cdef int next_act = gold.ner[s.B(1)].move if s.B(1) >= 0 else OUT
@@ -457,7 +462,7 @@ cdef class Last:
                 # Otherwise, force acceptance, even if we're across a sentence
                 # boundary or the token is whitespace.
                 return True
-        elif st.E_(0).ent_type != label:
+        elif st.get_ent().label != label:
             return False
         elif st.B_(1).ent_iob == 1:
             # If a preset entity has I next, we can't L here.
@@ -468,12 +473,11 @@ cdef class Last:
     @staticmethod
     cdef int transition(StateC* st, attr_t label) nogil:
         st.close_ent()
-        st.set_ent_tag(st.B(0), 1, label)
         st.push()
         st.pop()
 
     @staticmethod
-    cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
+    cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
         gold = <GoldNERStateC*>_gold
         move = LAST
 
@@ -537,12 +541,11 @@ cdef class Unit:
     cdef int transition(StateC* st, attr_t label) nogil:
         st.open_ent(label)
         st.close_ent()
-        st.set_ent_tag(st.B(0), 3, label)
         st.push()
         st.pop()
 
     @staticmethod
-    cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
+    cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
         gold = <GoldNERStateC*>_gold
         cdef int g_act = gold.ner[s.B(0)].move
         cdef attr_t g_tag = gold.ner[s.B(0)].label
@@ -578,12 +581,11 @@ cdef class Out:
 
     @staticmethod
     cdef int transition(StateC* st, attr_t label) nogil:
-        st.set_ent_tag(st.B(0), 2, 0)
         st.push()
         st.pop()
 
     @staticmethod
-    cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
+    cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
         gold = <GoldNERStateC*>_gold
         cdef int g_act = gold.ner[s.B(0)].move
         cdef attr_t g_tag = gold.ner[s.B(0)].label
diff --git a/spacy/pipeline/_parser_internals/stateclass.pxd b/spacy/pipeline/_parser_internals/stateclass.pxd
index 1d9f05538..54ff344b9 100644
--- a/spacy/pipeline/_parser_internals/stateclass.pxd
+++ b/spacy/pipeline/_parser_internals/stateclass.pxd
@@ -2,30 +2,24 @@ from cymem.cymem cimport Pool
 
 from ...structs cimport TokenC, SpanC
 from ...typedefs cimport attr_t
+from ...tokens.doc cimport Doc
 
 from ._state cimport StateC
 
 
 cdef class StateClass:
-    cdef Pool mem
     cdef StateC* c
+    cdef readonly Doc doc
     cdef int _borrowed
 
     @staticmethod
-    cdef inline StateClass init(const TokenC* sent, int length):
+    cdef inline StateClass borrow(StateC* ptr, Doc doc):
         cdef StateClass self = StateClass()
-        self.c = new StateC(sent, length)
-        return self
-    
-    @staticmethod
-    cdef inline StateClass borrow(StateC* ptr):
-        cdef StateClass self = StateClass()
-        del self.c
         self.c = ptr
         self._borrowed = 1
+        self.doc = doc
         return self
 
-
     @staticmethod
     cdef inline StateClass init_offset(const TokenC* sent, int length, int
                                        offset):
@@ -33,105 +27,3 @@ cdef class StateClass:
         self.c = new StateC(sent, length)
         self.c.offset = offset
         return self
-
-    cdef inline int S(self, int i) nogil:
-        return self.c.S(i)
-
-    cdef inline int B(self, int i) nogil:
-        return self.c.B(i)
-
-    cdef inline const TokenC* S_(self, int i) nogil:
-        return self.c.S_(i)
-
-    cdef inline const TokenC* B_(self, int i) nogil:
-        return self.c.B_(i)
-
-    cdef inline const TokenC* H_(self, int i) nogil:
-        return self.c.H_(i)
-
-    cdef inline const TokenC* E_(self, int i) nogil:
-        return self.c.E_(i)
-
-    cdef inline const TokenC* L_(self, int i, int idx) nogil:
-        return self.c.L_(i, idx)
-
-    cdef inline const TokenC* R_(self, int i, int idx) nogil:
-        return self.c.R_(i, idx)
-
-    cdef inline const TokenC* safe_get(self, int i) nogil:
-        return self.c.safe_get(i)
-
-    cdef inline int H(self, int i) nogil:
-        return self.c.H(i)
-    
-    cdef inline int E(self, int i) nogil:
-        return self.c.E(i)
-
-    cdef inline int L(self, int i, int idx) nogil:
-        return self.c.L(i, idx)
-
-    cdef inline int R(self, int i, int idx) nogil:
-        return self.c.R(i, idx)
-
-    cdef inline bint empty(self) nogil:
-        return self.c.empty()
-
-    cdef inline bint eol(self) nogil:
-        return self.c.eol()
-
-    cdef inline bint at_break(self) nogil:
-        return self.c.at_break()
-
-    cdef inline bint has_head(self, int i) nogil:
-        return self.c.has_head(i)
-
-    cdef inline int n_L(self, int i) nogil:
-        return self.c.n_L(i)
-
-    cdef inline int n_R(self, int i) nogil:
-        return self.c.n_R(i)
-
-    cdef inline bint stack_is_connected(self) nogil:
-        return False
-
-    cdef inline bint entity_is_open(self) nogil:
-        return self.c.entity_is_open()
-
-    cdef inline int stack_depth(self) nogil:
-        return self.c.stack_depth()
-
-    cdef inline int buffer_length(self) nogil:
-        return self.c.buffer_length()
-
-    cdef inline void push(self) nogil:
-        self.c.push()
-
-    cdef inline void pop(self) nogil:
-        self.c.pop()
-
-    cdef inline void unshift(self) nogil:
-        self.c.unshift()
-
-    cdef inline void add_arc(self, int head, int child, attr_t label) nogil:
-        self.c.add_arc(head, child, label)
-
-    cdef inline void del_arc(self, int head, int child) nogil:
-        self.c.del_arc(head, child)
-
-    cdef inline void open_ent(self, attr_t label) nogil:
-        self.c.open_ent(label)
-
-    cdef inline void close_ent(self) nogil:
-        self.c.close_ent()
-
-    cdef inline void set_ent_tag(self, int i, int ent_iob, attr_t ent_type) nogil:
-        self.c.set_ent_tag(i, ent_iob, ent_type)
-
-    cdef inline void set_break(self, int i) nogil:
-        self.c.set_break(i)
-
-    cdef inline void clone(self, StateClass src) nogil:
-        self.c.clone(src.c)
-
-    cdef inline void fast_forward(self) nogil:
-        self.c.fast_forward()
diff --git a/spacy/pipeline/_parser_internals/stateclass.pyx b/spacy/pipeline/_parser_internals/stateclass.pyx
index 880cf6cc5..4eaddd997 100644
--- a/spacy/pipeline/_parser_internals/stateclass.pyx
+++ b/spacy/pipeline/_parser_internals/stateclass.pyx
@@ -1,17 +1,20 @@
 # cython: infer_types=True
 import numpy
+from libcpp.vector cimport vector
+from ._state cimport ArcC
 
 from ...tokens.doc cimport Doc
 
 
 cdef class StateClass:
     def __init__(self, Doc doc=None, int offset=0):
-        cdef Pool mem = Pool()
-        self.mem = mem
         self._borrowed = 0
         if doc is not None:
             self.c = new StateC(doc.c, doc.length)
             self.c.offset = offset
+            self.doc = doc
+        else:
+            self.doc = None
 
     def __dealloc__(self):
         if self._borrowed != 1:
@@ -19,36 +22,157 @@ cdef class StateClass:
 
     @property
     def stack(self):
-        return {self.S(i) for i in range(self.c._s_i)}
+        return [self.S(i) for i in range(self.c.stack_depth())]
 
     @property
     def queue(self):
-        return {self.B(i) for i in range(self.c.buffer_length())}
+        return [self.B(i) for i in range(self.c.buffer_length())]
 
     @property
     def token_vector_lenth(self):
         return self.doc.tensor.shape[1]
 
     @property
-    def history(self):
-        hist = numpy.ndarray((8,), dtype='i')
-        for i in range(8):
-            hist[i] = self.c.get_hist(i+1)
-        return hist
+    def arcs(self):
+        cdef vector[ArcC] arcs
+        self.c.get_arcs(&arcs)
+        return list(arcs)
+        #py_arcs = []
+        #for arc in arcs:
+        #    if arc.head != -1 and arc.child != -1:
+        #        py_arcs.append((arc.head, arc.child, arc.label))
+        #return arcs
+
+    def add_arc(self, int head, int child, int label):
+        self.c.add_arc(head, child, label)
+
+    def del_arc(self, int head, int child):
+        self.c.del_arc(head, child)
+
+    def H(self, int child):
+        return self.c.H(child)
+    
+    def L(self, int head, int idx):
+        return self.c.L(head, idx)
+    
+    def R(self, int head, int idx):
+        return self.c.R(head, idx)
+
+    @property
+    def _b_i(self):
+        return self.c._b_i
+
+    @property
+    def length(self):
+        return self.c.length
 
     def is_final(self):
         return self.c.is_final()
 
     def copy(self):
-        cdef StateClass new_state = StateClass.init(self.c._sent, self.c.length)
+        cdef StateClass new_state = StateClass(doc=self.doc, offset=self.c.offset)
         new_state.c.clone(self.c)
         return new_state
 
-    def print_state(self, words):
+    def print_state(self):
+        words = [token.text for token in self.doc]
         words = list(words) + ['_']
-        top = f"{words[self.S(0)]}_{self.S_(0).head}"
-        second = f"{words[self.S(1)]}_{self.S_(1).head}"
-        third = f"{words[self.S(2)]}_{self.S_(2).head}"
-        n0 = words[self.B(0)]
-        n1 = words[self.B(1)]
-        return ' '.join((third, second, top, '|', n0, n1))
+        bools = ["F", "T"]
+        sent_starts = [bools[self.c.is_sent_start(i)] for i in range(len(self.doc))]
+        shifted = [1 if self.c.is_unshiftable(i) else 0 for i in range(self.c.length)]
+        shifted.append("")
+        sent_starts.append("")
+        top = f"{self.S(0)}{words[self.S(0)]}_{words[self.H(self.S(0))]}_{shifted[self.S(0)]}"
+        second = f"{self.S(1)}{words[self.S(1)]}_{words[self.H(self.S(1))]}_{shifted[self.S(1)]}"
+        third = f"{self.S(2)}{words[self.S(2)]}_{words[self.H(self.S(2))]}_{shifted[self.S(2)]}"
+        n0 = f"{self.B(0)}{words[self.B(0)]}_{sent_starts[self.B(0)]}_{shifted[self.B(0)]}"
+        n1 = f"{self.B(1)}{words[self.B(1)]}_{sent_starts[self.B(1)]}_{shifted[self.B(1)]}"
+        return ' '.join((str(self.stack_depth()), str(self.buffer_length()), third, second, top, '|', n0, n1))
+
+    def S(self, int i):
+        return self.c.S(i)
+
+    def B(self, int i):
+        return self.c.B(i)
+
+    def H(self, int i):
+        return self.c.H(i)
+    
+    def E(self, int i):
+        return self.c.E(i)
+
+    def L(self, int i, int idx):
+        return self.c.L(i, idx)
+
+    def R(self, int i, int idx):
+        return self.c.R(i, idx)
+
+    def S_(self, int i):
+        return self.doc[self.c.S(i)]
+
+    def B_(self, int i):
+        return self.doc[self.c.B(i)]
+
+    def H_(self, int i):
+        return self.doc[self.c.H(i)]
+    
+    def E_(self, int i):
+        return self.doc[self.c.E(i)]
+
+    def L_(self, int i, int idx):
+        return self.doc[self.c.L(i, idx)]
+
+    def R_(self, int i, int idx):
+        return self.doc[self.c.R(i, idx)]
+ 
+    def empty(self):
+        return self.c.empty()
+
+    def eol(self):
+        return self.c.eol()
+
+    def at_break(self):
+        return False
+        #return self.c.at_break()
+
+    def has_head(self, int i):
+        return self.c.has_head(i)
+
+    def  n_L(self, int i):
+        return self.c.n_L(i)
+
+    def n_R(self, int i):
+        return self.c.n_R(i)
+
+    def entity_is_open(self):
+        return self.c.entity_is_open()
+
+    def stack_depth(self):
+        return self.c.stack_depth()
+
+    def buffer_length(self):
+        return self.c.buffer_length()
+
+    def push(self):
+        self.c.push()
+
+    def pop(self):
+        self.c.pop()
+
+    def unshift(self):
+        self.c.unshift()
+
+    def add_arc(self, int head, int child, attr_t label):
+        self.c.add_arc(head, child, label)
+
+    def del_arc(self, int head, int child):
+        self.c.del_arc(head, child)
+
+    def open_ent(self, attr_t label):
+        self.c.open_ent(label)
+
+    def close_ent(self):
+        self.c.close_ent()
+
+    def clone(self, StateClass src):
+        self.c.clone(src.c)
diff --git a/spacy/pipeline/_parser_internals/transition_system.pxd b/spacy/pipeline/_parser_internals/transition_system.pxd
index 458f1d5f9..eed347b98 100644
--- a/spacy/pipeline/_parser_internals/transition_system.pxd
+++ b/spacy/pipeline/_parser_internals/transition_system.pxd
@@ -16,14 +16,14 @@ cdef struct Transition:
     weight_t score
 
     bint (*is_valid)(const StateC* state, attr_t label) nogil
-    weight_t (*get_cost)(StateClass state, const void* gold, attr_t label) nogil
+    weight_t (*get_cost)(const StateC* state, const void* gold, attr_t label) nogil
     int (*do)(StateC* state, attr_t label) nogil
 
 
-ctypedef weight_t (*get_cost_func_t)(StateClass state, const void* gold,
+ctypedef weight_t (*get_cost_func_t)(const StateC* state, const void* gold,
         attr_tlabel) nogil
-ctypedef weight_t (*move_cost_func_t)(StateClass state, const void* gold) nogil
-ctypedef weight_t (*label_cost_func_t)(StateClass state, const void*
+ctypedef weight_t (*move_cost_func_t)(const StateC* state, const void* gold) nogil
+ctypedef weight_t (*label_cost_func_t)(const StateC* state, const void*
         gold, attr_t label) nogil
 
 ctypedef int (*do_func_t)(StateC* state, attr_t label) nogil
@@ -41,9 +41,8 @@ cdef class TransitionSystem:
     cdef public attr_t root_label
     cdef public freqs
     cdef public object labels
-
-    cdef int initialize_state(self, StateC* state) nogil
-    cdef int finalize_state(self, StateC* state) nogil
+    cdef init_state_t init_beam_state
+    cdef del_state_t del_beam_state
 
     cdef Transition lookup_transition(self, object name) except *
 
@@ -52,4 +51,4 @@ cdef class TransitionSystem:
     cdef int set_valid(self, int* output, const StateC* st) nogil
 
     cdef int set_costs(self, int* is_valid, weight_t* costs,
-                       StateClass state, gold) except -1
+                       const StateC* state, gold) except -1
diff --git a/spacy/pipeline/_parser_internals/transition_system.pyx b/spacy/pipeline/_parser_internals/transition_system.pyx
index 7694e7f34..9bb4f7f5f 100644
--- a/spacy/pipeline/_parser_internals/transition_system.pyx
+++ b/spacy/pipeline/_parser_internals/transition_system.pyx
@@ -5,6 +5,7 @@ from cymem.cymem cimport Pool
 from collections import Counter
 import srsly
 
+from . cimport _beam_utils
 from ...typedefs cimport weight_t, attr_t
 from ...tokens.doc cimport Doc
 from ...structs cimport TokenC
@@ -44,6 +45,8 @@ cdef class TransitionSystem:
         if labels_by_action:
             self.initialize_actions(labels_by_action, min_freq=min_freq)
         self.root_label = self.strings.add('ROOT')
+        self.init_beam_state = _init_state
+        self.del_beam_state = _del_state
 
     def __reduce__(self):
         return (self.__class__, (self.strings, self.labels), None, None)
@@ -54,7 +57,6 @@ cdef class TransitionSystem:
         offset = 0
         for doc in docs:
             state = StateClass(doc, offset=offset)
-            self.initialize_state(state.c)
             states.append(state)
             offset += len(doc)
         return states
@@ -80,7 +82,7 @@ cdef class TransitionSystem:
         history = []
         debug_log = []
         while not state.is_final():
-            self.set_costs(is_valid, costs, state, gold)
+            self.set_costs(is_valid, costs, state.c, gold)
             for i in range(self.n_moves):
                 if is_valid[i] and costs[i] <= 0:
                     action = self.c[i]
@@ -124,15 +126,6 @@ cdef class TransitionSystem:
         action = self.lookup_transition(name)
         action.do(state.c, action.label)
 
-    cdef int initialize_state(self, StateC* state) nogil:
-        pass
-
-    cdef int finalize_state(self, StateC* state) nogil:
-        pass
-
-    def finalize_doc(self, doc):
-        pass
-
     cdef Transition lookup_transition(self, object name) except *:
         raise NotImplementedError
 
@@ -151,7 +144,7 @@ cdef class TransitionSystem:
             is_valid[i] = self.c[i].is_valid(st, self.c[i].label)
 
     cdef int set_costs(self, int* is_valid, weight_t* costs,
-                       StateClass stcls, gold) except -1:
+                       const StateC* state, gold) except -1:
         raise NotImplementedError
 
     def get_class_name(self, int clas):
diff --git a/spacy/pipeline/dep_parser.pyx b/spacy/pipeline/dep_parser.pyx
index a9dcd705e..724eb6cd1 100644
--- a/spacy/pipeline/dep_parser.pyx
+++ b/spacy/pipeline/dep_parser.pyx
@@ -105,6 +105,93 @@ def make_parser(
         update_with_oracle_cut_size=update_with_oracle_cut_size,
         multitasks=[],
         learn_tokens=learn_tokens,
+        min_action_freq=min_action_freq,
+        beam_width=1,
+        beam_density=0.0,
+        beam_update_prob=0.0,
+    )
+
+@Language.factory(
+    "beam_parser",
+    assigns=["token.dep", "token.head", "token.is_sent_start", "doc.sents"],
+    default_config={
+        "beam_width": 8,
+        "beam_density": 0.01,
+        "beam_update_prob": 0.5,
+        "moves": None,
+        "update_with_oracle_cut_size": 100,
+        "learn_tokens": False,
+        "min_action_freq": 30,
+        "model": DEFAULT_PARSER_MODEL,
+    },
+    default_score_weights={
+        "dep_uas": 0.5,
+        "dep_las": 0.5,
+        "dep_las_per_type": None,
+        "sents_p": None,
+        "sents_r": None,
+        "sents_f": 0.0,
+    },
+)
+def make_beam_parser(
+    nlp: Language,
+    name: str,
+    model: Model,
+    moves: Optional[list],
+    update_with_oracle_cut_size: int,
+    learn_tokens: bool,
+    min_action_freq: int,
+    beam_width: int,
+    beam_density: float,
+    beam_update_prob: float,
+):
+    """Create a transition-based DependencyParser component that uses beam-search.
+    The dependency parser jointly learns sentence segmentation and labelled
+    dependency parsing, and can optionally learn to merge tokens that had been
+    over-segmented by the tokenizer.
+
+    The parser uses a variant of the non-monotonic arc-eager transition-system
+    described by Honnibal and Johnson (2014), with the addition of a "break"
+    transition to perform the sentence segmentation. Nivre's pseudo-projective
+    dependency transformation is used to allow the parser to predict
+    non-projective parses.
+
+    The parser is trained using a global objective. That is, it learns to assign
+    probabilities to whole parses.
+
+    model (Model): The model for the transition-based parser. The model needs
+        to have a specific substructure of named components --- see the
+        spacy.ml.tb_framework.TransitionModel for details.
+    moves (List[str]): A list of transition names. Inferred from the data if not
+        provided.
+    beam_width (int): The number of candidate analyses to maintain.
+    beam_density (float): The minimum ratio between the scores of the first and
+        last candidates in the beam. This allows the parser to avoid exploring
+        candidates that are too far behind. This is mostly intended to improve
+        efficiency, but it can also improve accuracy as deeper search is not
+        always better.
+    beam_update_prob (float): The chance of making a beam update, instead of a
+        greedy update. Greedy updates are an approximation for the beam updates,
+        and are faster to compute.
+    learn_tokens (bool): Whether to learn to merge subtokens that are split
+        relative to the gold standard. Experimental.
+    min_action_freq (int): The minimum frequency of labelled actions to retain.
+        Rarer labelled actions have their label backed-off to "dep". While this
+        primarily affects the label accuracy, it can also affect the attachment
+        structure, as the labels are used to represent the pseudo-projectivity
+        transformation.
+    """
+    return DependencyParser(
+        nlp.vocab,
+        model,
+        name,
+        moves=moves,
+        update_with_oracle_cut_size=update_with_oracle_cut_size,
+        beam_width=beam_width,
+        beam_density=beam_density,
+        beam_update_prob=beam_update_prob,
+        multitasks=[],
+        learn_tokens=learn_tokens,
         min_action_freq=min_action_freq
     )
 
diff --git a/spacy/pipeline/ner.pyx b/spacy/pipeline/ner.pyx
index 0f93b43ac..e748d95fd 100644
--- a/spacy/pipeline/ner.pyx
+++ b/spacy/pipeline/ner.pyx
@@ -82,6 +82,79 @@ def make_ner(
         multitasks=[],
         min_action_freq=1,
         learn_tokens=False,
+        beam_width=1,
+        beam_density=0.0,
+        beam_update_prob=0.0,
+    )
+
+@Language.factory(
+    "beam_ner",
+    assigns=["doc.ents", "token.ent_iob", "token.ent_type"],
+    default_config={
+        "moves": None,
+        "update_with_oracle_cut_size": 100,
+        "model": DEFAULT_NER_MODEL,
+        "beam_density": 0.01,
+        "beam_update_prob": 0.5,
+        "beam_width": 32
+    },
+    default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0, "ents_per_type": None},
+)
+def make_beam_ner(
+    nlp: Language,
+    name: str,
+    model: Model,
+    moves: Optional[list],
+    update_with_oracle_cut_size: int,
+    beam_width: int,
+    beam_density: float,
+    beam_update_prob: float,
+):
+    """Create a transition-based EntityRecognizer component that uses beam-search.
+    The entity recognizer identifies non-overlapping labelled spans of tokens.
+
+    The transition-based algorithm used encodes certain assumptions that are
+    effective for "traditional" named entity recognition tasks, but may not be
+    a good fit for every span identification problem. Specifically, the loss
+    function optimizes for whole entity accuracy, so if your inter-annotator
+    agreement on boundary tokens is low, the component will likely perform poorly
+    on your problem. The transition-based algorithm also assumes that the most
+    decisive information about your entities will be close to their initial tokens.
+    If your entities are long and characterised by tokens in their middle, the
+    component will likely do poorly on your task.
+
+    model (Model): The model for the transition-based parser. The model needs
+        to have a specific substructure of named components --- see the
+        spacy.ml.tb_framework.TransitionModel for details.
+    moves (list[str]): A list of transition names. Inferred from the data if not
+        provided.
+    update_with_oracle_cut_size (int):
+        During training, cut long sequences into shorter segments by creating
+        intermediate states based on the gold-standard history. The model is
+        not very sensitive to this parameter, so you usually won't need to change
+        it. 100 is a good default.
+    beam_width (int): The number of candidate analyses to maintain.
+    beam_density (float): The minimum ratio between the scores of the first and
+        last candidates in the beam. This allows the parser to avoid exploring
+        candidates that are too far behind. This is mostly intended to improve
+        efficiency, but it can also improve accuracy as deeper search is not
+        always better.
+    beam_update_prob (float): The chance of making a beam update, instead of a
+        greedy update. Greedy updates are an approximation for the beam updates,
+        and are faster to compute.
+    """
+    return EntityRecognizer(
+        nlp.vocab,
+        model,
+        name,
+        moves=moves,
+        update_with_oracle_cut_size=update_with_oracle_cut_size,
+        multitasks=[],
+        min_action_freq=1,
+        learn_tokens=False,
+        beam_width=beam_width,
+        beam_density=beam_density,
+        beam_update_prob=beam_update_prob,
     )
 
 
diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx
index 63a8595cc..8aeacbafb 100644
--- a/spacy/pipeline/transition_parser.pyx
+++ b/spacy/pipeline/transition_parser.pyx
@@ -4,13 +4,14 @@ from cymem.cymem cimport Pool
 cimport numpy as np
 from itertools import islice
 from libcpp.vector cimport vector
-from libc.string cimport memset
+from libc.string cimport memset, memcpy
 from libc.stdlib cimport calloc, free
 import random
 from typing import Optional
 
 import srsly
-from thinc.api import set_dropout_rate
+from thinc.api import set_dropout_rate, CupyOps
+from thinc.extra.search cimport Beam
 import numpy.random
 import numpy
 import warnings
@@ -22,6 +23,8 @@ from ..ml.parser_model cimport WeightsC, ActivationsC, SizesC, cpu_log_loss
 from ..ml.parser_model cimport get_c_weights, get_c_sizes
 from ..tokens.doc cimport Doc
 from .trainable_pipe import TrainablePipe
+from ._parser_internals cimport _beam_utils
+from ._parser_internals import _beam_utils
 
 from ..training import validate_examples, validate_get_examples
 from ..errors import Errors, Warnings
@@ -41,9 +44,12 @@ cdef class Parser(TrainablePipe):
         moves=None,
         *,
         update_with_oracle_cut_size,
-        multitasks=tuple(),
         min_action_freq,
         learn_tokens,
+        beam_width=1,
+        beam_density=0.0,
+        beam_update_prob=0.0,
+        multitasks=tuple(),
     ):
         """Create a Parser.
 
@@ -61,7 +67,10 @@ cdef class Parser(TrainablePipe):
             "update_with_oracle_cut_size": update_with_oracle_cut_size,
             "multitasks": list(multitasks),
             "min_action_freq": min_action_freq,
-            "learn_tokens": learn_tokens
+            "learn_tokens": learn_tokens,
+            "beam_width": beam_width,
+            "beam_density": beam_density,
+            "beam_update_prob": beam_update_prob
         }
         if moves is None:
             # defined by EntityRecognizer as a BiluoPushDown
@@ -183,7 +192,15 @@ cdef class Parser(TrainablePipe):
             result = self.moves.init_batch(docs)
             self._resize()
             return result
-        return self.greedy_parse(docs, drop=0.0)
+        if self.cfg["beam_width"] == 1:
+            return self.greedy_parse(docs, drop=0.0)
+        else:
+            return self.beam_parse(
+                docs,
+                drop=0.0,
+                beam_width=self.cfg["beam_width"],
+                beam_density=self.cfg["beam_density"]
+            )
 
     def greedy_parse(self, docs, drop=0.):
         cdef vector[StateC*] states
@@ -207,6 +224,31 @@ cdef class Parser(TrainablePipe):
         del model
         return batch
 
+    def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
+        cdef Beam beam
+        cdef Doc doc
+        batch = _beam_utils.BeamBatch(
+            self.moves,
+            self.moves.init_batch(docs),
+            None,
+            beam_width,
+            density=beam_density
+        )
+        # This is pretty dirty, but the NER can resize itself in init_batch,
+        # if labels are missing. We therefore have to check whether we need to
+        # expand our model output.
+        self._resize()
+        model = self.model.predict(docs)
+        while not batch.is_done:
+            states = batch.get_unfinished_states()
+            if not states:
+                break
+            scores = model.predict(states)
+            batch.advance(scores)
+        model.clear_memory()
+        del model
+        return list(batch)
+
     cdef void _parseC(self, StateC** states,
             WeightsC weights, SizesC sizes) nogil:
         cdef int i, j
@@ -227,14 +269,13 @@ cdef class Parser(TrainablePipe):
             unfinished.clear()
         free_activations(&activations)
 
-    def set_annotations(self, docs, states):
+    def set_annotations(self, docs, states_or_beams):
         cdef StateClass state
+        cdef Beam beam
         cdef Doc doc
+        states = _beam_utils.collect_states(states_or_beams, docs)
         for i, (state, doc) in enumerate(zip(states, docs)):
-            self.moves.finalize_state(state.c)
-            for j in range(doc.length):
-                doc.c[j] = state.c._sent[j]
-            self.moves.finalize_doc(doc)
+            self.moves.set_annotations(state, doc)
             for hook in self.postprocesses:
                 hook(doc)
 
@@ -265,7 +306,6 @@ cdef class Parser(TrainablePipe):
             else:
                 action = self.moves.c[guess]
                 action.do(states[i], action.label)
-                states[i].push_hist(guess)
         free(is_valid)
 
     def update(self, examples, *, drop=0., set_annotations=False, sgd=None, losses=None):
@@ -276,13 +316,23 @@ cdef class Parser(TrainablePipe):
         validate_examples(examples, "Parser.update")
         for multitask in self._multitasks:
             multitask.update(examples, drop=drop, sgd=sgd)
+    
         n_examples = len([eg for eg in examples if self.moves.has_gold(eg)])
         if n_examples == 0:
             return losses
         set_dropout_rate(self.model, drop)
-        # Prepare the stepwise model, and get the callback for finishing the batch
-        model, backprop_tok2vec = self.model.begin_update(
-            [eg.predicted for eg in examples])
+        # The probability we use beam update, instead of falling back to
+        # a greedy update
+        beam_update_prob = self.cfg["beam_update_prob"]
+        if self.cfg['beam_width'] >= 2 and numpy.random.random() < beam_update_prob:
+            return self.update_beam(
+                examples,
+                beam_width=self.cfg["beam_width"],
+                set_annotations=set_annotations,
+                sgd=sgd,
+                losses=losses,
+                beam_density=self.cfg["beam_density"]
+            )
         max_moves = self.cfg["update_with_oracle_cut_size"]
         if max_moves >= 1:
             # Chop sequences into lengths of this many words, to make the
@@ -296,6 +346,8 @@ cdef class Parser(TrainablePipe):
             states, golds, _ = self.moves.init_gold_batch(examples)
         if not states:
             return losses
+        model, backprop_tok2vec = self.model.begin_update([eg.x for eg in examples])
+ 
         all_states = list(states)
         states_golds = list(zip(states, golds))
         n_moves = 0
@@ -379,6 +431,27 @@ cdef class Parser(TrainablePipe):
         del tutor
         return losses
 
+    def update_beam(self, examples, *, beam_width,
+            drop=0., sgd=None, losses=None, set_annotations=False, beam_density=0.0):
+        states, golds, _ = self.moves.init_gold_batch(examples)
+        if not states:
+            return losses
+        # Prepare the stepwise model, and get the callback for finishing the batch
+        model, backprop_tok2vec = self.model.begin_update(
+            [eg.predicted for eg in examples])
+        loss = _beam_utils.update_beam(
+            self.moves,
+            states,
+            golds,
+            model,
+            beam_width,
+            beam_density=beam_density,
+        )
+        losses[self.name] += loss
+        backprop_tok2vec(golds)
+        if sgd is not None:
+            self.finish_update(sgd)
+
     def get_batch_loss(self, states, golds, float[:, ::1] scores, losses):
         cdef StateClass state
         cdef Pool mem = Pool()
@@ -396,7 +469,7 @@ cdef class Parser(TrainablePipe):
         for i, (state, gold) in enumerate(zip(states, golds)):
             memset(is_valid, 0, self.moves.n_moves * sizeof(int))
             memset(costs, 0, self.moves.n_moves * sizeof(float))
-            self.moves.set_costs(is_valid, costs, state, gold)
+            self.moves.set_costs(is_valid, costs, state.c, gold)
             for j in range(self.moves.n_moves):
                 if costs[j] <= 0.0 and j in unseen_classes:
                     unseen_classes.remove(j)
@@ -539,7 +612,6 @@ cdef class Parser(TrainablePipe):
                 for clas in oracle_actions[i:i+max_length]:
                     action = self.moves.c[clas]
                     action.do(state.c, action.label)
-                    state.c.push_hist(action.clas)
                     if state.is_final():
                         break
                 if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
diff --git a/spacy/tests/parser/test_arc_eager_oracle.py b/spacy/tests/parser/test_arc_eager_oracle.py
index 84070db73..fa78301af 100644
--- a/spacy/tests/parser/test_arc_eager_oracle.py
+++ b/spacy/tests/parser/test_arc_eager_oracle.py
@@ -7,6 +7,7 @@ from spacy.tokens import Doc
 from spacy.pipeline._parser_internals.nonproj import projectivize
 from spacy.pipeline._parser_internals.arc_eager import ArcEager
 from spacy.pipeline.dep_parser import DEFAULT_PARSER_MODEL
+from spacy.pipeline._parser_internals.stateclass import StateClass
 
 
 def get_sequence_costs(M, words, heads, deps, transitions):
@@ -47,15 +48,24 @@ def test_oracle_four_words(arc_eager, vocab):
     for dep in deps:
         arc_eager.add_action(2, dep)  # Left
         arc_eager.add_action(3, dep)  # Right
-    actions = ["L-left", "B-ROOT", "L-left"]
+    actions = ["S", "L-left", "B-ROOT", "S", "D", "S", "L-left", "S", "D"]
     state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions)
+    expected_gold = [
+        ["S"],
+        ["B-ROOT", "L-left"],
+        ["B-ROOT"],
+        ["S"],
+        ["D"],
+        ["S"],
+        ["L-left"],
+        ["S"],
+        ["D"]
+    ]
     assert state.is_final()
     for i, state_costs in enumerate(cost_history):
         # Check gold moves is 0 cost
-        assert state_costs[actions[i]] == 0.0, actions[i]
-        for other_action, cost in state_costs.items():
-            if other_action != actions[i]:
-                assert cost >= 1, (i, other_action)
+        golds = [act for act, cost in state_costs.items() if cost < 1]
+        assert golds == expected_gold[i], (i, golds, expected_gold[i])
 
 
 annot_tuples = [
@@ -169,12 +179,15 @@ def test_oracle_dev_sentence(vocab, arc_eager):
         . punct said
     """
     expected_transitions = [
+        "S",  # Shift "Rolls-Royce"
         "S",  # Shift 'Motor'
         "S",  # Shift 'Cars'
         "L-nn",  # Attach 'Cars' to 'Inc.'
         "L-nn",  # Attach 'Motor' to 'Inc.'
-        "L-nn",  # Attach 'Rolls-Royce' to 'Inc.', force shift
+        "L-nn",  # Attach 'Rolls-Royce' to 'Inc.'
+        "S",     # Shift "Inc."
         "L-nsubj",  # Attach 'Inc.' to 'said'
+        "S",        # Shift 'said'
         "S",  # Shift 'it'
         "L-nsubj",  # Attach 'it.' to 'expects'
         "R-ccomp",  # Attach 'expects' to 'said'
@@ -204,6 +217,8 @@ def test_oracle_dev_sentence(vocab, arc_eager):
         "D",  # Reduce "steady"
         "D",  # Reduce "expects"
         "R-punct",  # Attach "." to "said"
+        "D",  # Reduce "."
+        "D",  # Reduce "said"
     ]
 
     gold_words = []
@@ -221,10 +236,40 @@ def test_oracle_dev_sentence(vocab, arc_eager):
     for dep in gold_deps:
         arc_eager.add_action(2, dep)  # Left
         arc_eager.add_action(3, dep)  # Right
-
     doc = Doc(Vocab(), words=gold_words)
     example = Example.from_dict(doc, {"heads": gold_heads, "deps": gold_deps})
-
-    ae_oracle_actions = arc_eager.get_oracle_sequence(example)
+    ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False)
     ae_oracle_actions = [arc_eager.get_class_name(i) for i in ae_oracle_actions]
     assert ae_oracle_actions == expected_transitions
+
+
+def test_oracle_bad_tokenization(vocab, arc_eager):
+    words_deps_heads = """
+        [catalase] dep is
+        : punct is
+        that nsubj is
+        is root is
+        bad comp is
+    """
+ 
+    gold_words = []
+    gold_deps = []
+    gold_heads = []
+    for line in words_deps_heads.strip().split("\n"):
+        line = line.strip()
+        if not line:
+            continue
+        word, dep, head = line.split()
+        gold_words.append(word)
+        gold_deps.append(dep)
+        gold_heads.append(head)
+    gold_heads = [gold_words.index(head) for head in gold_heads]
+    for dep in gold_deps:
+        arc_eager.add_action(2, dep)  # Left
+        arc_eager.add_action(3, dep)  # Right
+    reference = Doc(Vocab(), words=gold_words, deps=gold_deps, heads=gold_heads)
+    predicted = Doc(reference.vocab, words=["[", "catalase", "]", ":", "that", "is", "bad"])
+    example = Example(predicted=predicted, reference=reference)
+    ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False)
+    ae_oracle_actions = [arc_eager.get_class_name(i) for i in ae_oracle_actions]
+    assert ae_oracle_actions
diff --git a/spacy/tests/parser/test_ner.py b/spacy/tests/parser/test_ner.py
index b4c22b48d..9ed87329c 100644
--- a/spacy/tests/parser/test_ner.py
+++ b/spacy/tests/parser/test_ner.py
@@ -54,7 +54,7 @@ def tsys(vocab, entity_types):
 
 def test_get_oracle_moves(tsys, doc, entity_annots):
     example = Example.from_dict(doc, {"entities": entity_annots})
-    act_classes = tsys.get_oracle_sequence(example)
+    act_classes = tsys.get_oracle_sequence(example, _debug=False)
     names = [tsys.get_class_name(act) for act in act_classes]
     assert names == ["U-PERSON", "O", "O", "B-GPE", "L-GPE", "O"]
 
diff --git a/spacy/tests/parser/test_nn_beam.py b/spacy/tests/parser/test_nn_beam.py
index e69de29bb..1f45b67c8 100644
--- a/spacy/tests/parser/test_nn_beam.py
+++ b/spacy/tests/parser/test_nn_beam.py
@@ -0,0 +1,144 @@
+# coding: utf8
+from __future__ import unicode_literals
+
+import pytest
+import hypothesis
+import hypothesis.strategies
+import numpy
+from spacy.vocab import Vocab
+from spacy.language import Language
+from spacy.pipeline import DependencyParser
+from spacy.pipeline._parser_internals.arc_eager import ArcEager
+from spacy.tokens import Doc
+from spacy.pipeline._parser_internals._beam_utils import BeamBatch
+from spacy.pipeline._parser_internals.stateclass import StateClass
+from spacy.training import Example
+from thinc.tests.strategies import ndarrays_of_shape
+
+
+@pytest.fixture(scope="module")
+def vocab():
+    return Vocab()
+
+
+@pytest.fixture(scope="module")
+def moves(vocab):
+    aeager = ArcEager(vocab.strings, {})
+    aeager.add_action(0, "")
+    aeager.add_action(1, "")
+    aeager.add_action(2, "nsubj")
+    aeager.add_action(2, "punct")
+    aeager.add_action(2, "aux")
+    aeager.add_action(2, "nsubjpass")
+    aeager.add_action(3, "dobj")
+    aeager.add_action(2, "aux")
+    aeager.add_action(4, "ROOT")
+    return aeager
+
+
+@pytest.fixture(scope="module")
+def docs(vocab):
+    return [
+        Doc(
+            vocab,
+            words=["Rats", "bite", "things"],
+            heads=[1, 1, 1],
+            deps=["nsubj", "ROOT", "dobj"],
+            sent_starts=[True, False, False]
+        )
+    ]
+
+
+@pytest.fixture(scope="module")
+def examples(docs):
+    return [Example(doc, doc.copy()) for doc in docs]
+
+
+@pytest.fixture
+def states(docs):
+    return [StateClass(doc) for doc in docs]
+
+
+@pytest.fixture
+def tokvecs(docs, vector_size):
+    output = []
+    for doc in docs:
+        vec = numpy.random.uniform(-0.1, 0.1, (len(doc), vector_size))
+        output.append(numpy.asarray(vec))
+    return output
+
+
+@pytest.fixture(scope="module")
+def batch_size(docs):
+    return len(docs)
+
+
+@pytest.fixture(scope="module")
+def beam_width():
+    return 4
+
+@pytest.fixture(params=[0.0, 0.5, 1.0])
+def beam_density(request):
+    return request.param
+
+@pytest.fixture
+def vector_size():
+    return 6
+
+
+@pytest.fixture
+def beam(moves, examples, beam_width):
+    states, golds, _ = moves.init_gold_batch(examples)
+    return BeamBatch(moves, states, golds, width=beam_width, density=0.0)
+
+
+@pytest.fixture
+def scores(moves, batch_size, beam_width):
+    return numpy.asarray(
+        numpy.concatenate(
+            [
+                numpy.random.uniform(-0.1, 0.1, (beam_width, moves.n_moves))
+                for _ in range(batch_size)
+            ]
+        ), dtype="float32")
+
+
+def test_create_beam(beam):
+    pass
+
+
+def test_beam_advance(beam, scores):
+    beam.advance(scores)
+
+
+def test_beam_advance_too_few_scores(beam, scores):
+    n_state = sum(len(beam) for beam in beam)
+    scores = scores[:n_state]
+    with pytest.raises(IndexError):
+        beam.advance(scores[:-1])
+
+
+def test_beam_parse(examples, beam_width):
+    nlp = Language()
+    parser = nlp.add_pipe("beam_parser")
+    parser.cfg["beam_width"] = beam_width
+    parser.add_label("nsubj")
+    parser.initialize(lambda: examples)
+    doc = nlp.make_doc("Australia is a country")
+    parser(doc)
+
+
+
+
+@hypothesis.given(hyp=hypothesis.strategies.data())
+def test_beam_density(moves, examples, beam_width, hyp):
+    beam_density = float(hyp.draw(hypothesis.strategies.floats(0.0, 1.0, width=32)))
+    states, golds, _ = moves.init_gold_batch(examples)
+    beam = BeamBatch(moves, states, golds, width=beam_width, density=beam_density)
+    n_state = sum(len(beam) for beam in beam)
+    scores = hyp.draw(ndarrays_of_shape((n_state, moves.n_moves)))
+    beam.advance(scores)
+    for b in beam:
+        beam_probs = b.probs
+        assert b.min_density == beam_density
+        assert beam_probs[-1] >= beam_probs[0] * beam_density
diff --git a/spacy/tests/parser/test_preset_sbd.py b/spacy/tests/parser/test_preset_sbd.py
index ab58ac17b..595bfa537 100644
--- a/spacy/tests/parser/test_preset_sbd.py
+++ b/spacy/tests/parser/test_preset_sbd.py
@@ -22,6 +22,7 @@ def _parser_example(parser):
 
 @pytest.fixture
 def parser(vocab):
+    vocab.strings.add("ROOT")
     config = {
         "learn_tokens": False,
         "min_action_freq": 30,
@@ -76,13 +77,16 @@ def test_sents_1_2(parser):
 
 def test_sents_1_3(parser):
     doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
-    doc[1].sent_start = True
-    doc[3].sent_start = True
+    doc[0].is_sent_start = True
+    doc[1].is_sent_start = True
+    doc[2].is_sent_start = None
+    doc[3].is_sent_start = True
     doc = parser(doc)
     assert len(list(doc.sents)) >= 3
     doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
-    doc[1].sent_start = True
-    doc[2].sent_start = False
-    doc[3].sent_start = True
+    doc[0].is_sent_start = True
+    doc[1].is_sent_start = True
+    doc[2].is_sent_start = False
+    doc[3].is_sent_start = True
     doc = parser(doc)
     assert len(list(doc.sents)) == 3
diff --git a/spacy/tests/parser/test_state.py b/spacy/tests/parser/test_state.py
new file mode 100644
index 000000000..7cd4b98e1
--- /dev/null
+++ b/spacy/tests/parser/test_state.py
@@ -0,0 +1,74 @@
+import pytest
+
+from spacy.tokens.doc import Doc
+from spacy.vocab import Vocab
+from spacy.pipeline._parser_internals.stateclass import StateClass
+
+@pytest.fixture
+def vocab():
+    return Vocab()
+
+@pytest.fixture
+def doc(vocab):
+    return Doc(vocab, words=["a", "b", "c", "d"])
+
+def test_init_state(doc):
+    state = StateClass(doc)
+    assert state.stack == []
+    assert state.queue == list(range(len(doc)))
+    assert not state.is_final()
+    assert state.buffer_length() == 4
+
+def test_push_pop(doc):
+    state = StateClass(doc)
+    state.push()
+    assert state.buffer_length() == 3
+    assert state.stack == [0]
+    assert 0 not in state.queue
+    state.push()
+    assert state.stack == [1, 0]
+    assert 1 not in state.queue
+    assert state.buffer_length() == 2
+    state.pop()
+    assert state.stack == [0]
+    assert 1 not in state.queue
+
+def test_stack_depth(doc):
+    state = StateClass(doc)
+    assert state.stack_depth() == 0
+    assert state.buffer_length() == len(doc)
+    state.push()
+    assert state.buffer_length() == 3
+    assert state.stack_depth() == 1
+
+
+def test_H(doc):
+    state = StateClass(doc)
+    assert state.H(0) == -1
+    state.add_arc(1, 0, 0)
+    assert state.arcs == [{"head": 1, "child": 0, "label": 0}]
+    assert state.H(0) == 1
+    state.add_arc(3, 1, 0)
+    assert state.H(1) == 3
+
+
+def test_L(doc):
+    state = StateClass(doc)
+    assert state.L(2, 1) == -1
+    state.add_arc(2, 1, 0)
+    assert state.arcs == [{"head": 2, "child": 1, "label": 0}]
+    assert state.L(2, 1) == 1
+    state.add_arc(2, 0, 0)
+    assert state.L(2, 1) == 0
+    assert state.n_L(2) == 2
+
+
+def test_R(doc):
+    state = StateClass(doc)
+    assert state.R(0, 1) == -1
+    state.add_arc(0, 1, 0)
+    assert state.arcs == [{"head": 0, "child": 1, "label": 0}]
+    assert state.R(0, 1) == 1
+    state.add_arc(0, 2, 0)
+    assert state.R(0, 1) == 2
+    assert state.n_R(0) == 2
diff --git a/spacy/tests/regression/test_issue4001-4500.py b/spacy/tests/regression/test_issue4001-4500.py
index 73aea5b4b..873ef9c1d 100644
--- a/spacy/tests/regression/test_issue4001-4500.py
+++ b/spacy/tests/regression/test_issue4001-4500.py
@@ -122,7 +122,8 @@ def test_issue4042_bug2():
     assert "SOME_LABEL" in ner1.labels
     apple_ent = Span(doc1, 5, 6, label="MY_ORG")
     doc1.ents = list(doc1.ents) + [apple_ent]
-    # reapply the NER - at this point it should resize itself
+    # Add the label explicitly. Previously we didn't require this.
+    ner1.add_label("MY_ORG")
     ner1(doc1)
     assert len(ner1.labels) == 2
     assert "SOME_LABEL" in ner1.labels
diff --git a/spacy/tests/serialize/test_serialize_pipeline.py b/spacy/tests/serialize/test_serialize_pipeline.py
index 951dd3035..2deaa180d 100644
--- a/spacy/tests/serialize/test_serialize_pipeline.py
+++ b/spacy/tests/serialize/test_serialize_pipeline.py
@@ -22,6 +22,9 @@ def parser(en_vocab):
         "learn_tokens": False,
         "min_action_freq": 30,
         "update_with_oracle_cut_size": 100,
+        "beam_width": 1,
+        "beam_update_prob": 1.0,
+        "beam_density": 0.0
     }
     cfg = {"model": DEFAULT_PARSER_MODEL}
     model = registry.resolve(cfg, validate=True)["model"]
@@ -36,6 +39,9 @@ def blank_parser(en_vocab):
         "learn_tokens": False,
         "min_action_freq": 30,
         "update_with_oracle_cut_size": 100,
+        "beam_width": 1,
+        "beam_update_prob": 1.0,
+        "beam_density": 0.0
     }
     cfg = {"model": DEFAULT_PARSER_MODEL}
     model = registry.resolve(cfg, validate=True)["model"]
@@ -58,6 +64,9 @@ def test_serialize_parser_roundtrip_bytes(en_vocab, Parser):
         "learn_tokens": False,
         "min_action_freq": 0,
         "update_with_oracle_cut_size": 100,
+        "beam_width": 1,
+        "beam_update_prob": 1.0,
+        "beam_density": 0.0
     }
     cfg = {"model": DEFAULT_PARSER_MODEL}
     model = registry.resolve(cfg, validate=True)["model"]
@@ -79,6 +88,9 @@ def test_serialize_parser_strings(Parser):
         "learn_tokens": False,
         "min_action_freq": 0,
         "update_with_oracle_cut_size": 100,
+        "beam_width": 1,
+        "beam_update_prob": 1.0,
+        "beam_density": 0.0
     }
     cfg = {"model": DEFAULT_PARSER_MODEL}
     model = registry.resolve(cfg, validate=True)["model"]
@@ -98,6 +110,9 @@ def test_serialize_parser_roundtrip_disk(en_vocab, Parser):
         "learn_tokens": False,
         "min_action_freq": 0,
         "update_with_oracle_cut_size": 100,
+        "beam_width": 1,
+        "beam_update_prob": 1.0,
+        "beam_density": 0.0
     }
     cfg = {"model": DEFAULT_PARSER_MODEL}
     model = registry.resolve(cfg, validate=True)["model"]
diff --git a/spacy/training/example.pyx b/spacy/training/example.pyx
index 6a556b5e7..21907e7dd 100644
--- a/spacy/training/example.pyx
+++ b/spacy/training/example.pyx
@@ -191,6 +191,24 @@ cdef class Example:
                     aligned_deps[cand_i] = deps[gold_i]
         return aligned_heads, aligned_deps
 
+    def get_aligned_sent_starts(self):
+        """Get list of SENT_START attributes aligned to the predicted tokenization.
+        If the reference has not sentence starts, return a list of None values.
+
+        The aligned sentence starts use the get_aligned_spans method, rather
+        than aligning the list of tags, so that it handles cases where a mistaken
+        tokenization starts the sentence.
+        """
+        if self.y.has_annotation("SENT_START"):
+            align = self.alignment.y2x
+            sent_starts = [False] * len(self.x)
+            for y_sent in self.y.sents:
+                x_start = int(align[y_sent.start].dataXd[0])
+                sent_starts[x_start] = True
+            return sent_starts
+        else:
+            return [None] * len(self.x)
+
     def get_aligned_spans_x2y(self, x_spans):
         return self._get_aligned_spans(self.y, x_spans, self.alignment.x2y)