mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	Add beam_parser and beam_ner components for v3 (#6369)
* Get basic beam tests working * Get basic beam tests working * Compile _beam_utils * Remove prints * Test beam density * Beam parser seems to train * Draft beam NER * Upd beam * Add hypothesis as dev dependency * Implement missing is-gold-parse method * Implement early update * Fix state hashing * Fix test * Fix test * Default to non-beam in parser constructor * Improve oracle for beam * Start refactoring beam * Update test * Refactor beam * Update nn * Refactor beam and weight by cost * Update ner beam settings * Update test * Add __init__.pxd * Upd test * Fix test * Upd test * Fix test * Remove ring buffer history from StateC * WIP change arc-eager transitions * Add state tests * Support ternary sent start values * Fix arc eager * Fix NER * Pass oracle cut size for beam * Fix ner test * Fix beam * Improve StateC.clone * Improve StateClass.borrow * Work directly with StateC, not StateClass * Remove print statements * Fix state copy * Improve state class * Refactor parser oracles * Fix arc eager oracle * Fix arc eager oracle * Use a vector to implement the stack * Refactor state data structure * Fix alignment of sent start * Add get_aligned_sent_starts method * Add test for ae oracle when bad sentence starts * Fix sentence segment handling * Avoid Reduce that inserts illegal sentence * Update preset SBD test * Fix test * Remove prints * Fix sent starts in Example * Improve python API of StateClass * Tweak comments and debug output of arc eager * Upd test * Fix state test * Fix state test
This commit is contained in:
		
							parent
							
								
									85ca8c2bdd
								
							
						
					
					
						commit
						8656a08777
					
				| 
						 | 
					@ -27,3 +27,4 @@ pytest>=4.6.5
 | 
				
			||||||
pytest-timeout>=1.3.0,<2.0.0
 | 
					pytest-timeout>=1.3.0,<2.0.0
 | 
				
			||||||
mock>=2.0.0,<3.0.0
 | 
					mock>=2.0.0,<3.0.0
 | 
				
			||||||
flake8>=3.5.0,<3.6.0
 | 
					flake8>=3.5.0,<3.6.0
 | 
				
			||||||
 | 
					hypothesis
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										1
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								setup.py
									
									
									
									
									
								
							| 
						 | 
					@ -48,6 +48,7 @@ MOD_NAMES = [
 | 
				
			||||||
    "spacy.pipeline._parser_internals._state",
 | 
					    "spacy.pipeline._parser_internals._state",
 | 
				
			||||||
    "spacy.pipeline._parser_internals.stateclass",
 | 
					    "spacy.pipeline._parser_internals.stateclass",
 | 
				
			||||||
    "spacy.pipeline._parser_internals.transition_system",
 | 
					    "spacy.pipeline._parser_internals.transition_system",
 | 
				
			||||||
 | 
					    "spacy.pipeline._parser_internals._beam_utils",
 | 
				
			||||||
    "spacy.tokenizer",
 | 
					    "spacy.tokenizer",
 | 
				
			||||||
    "spacy.training.align",
 | 
					    "spacy.training.align",
 | 
				
			||||||
    "spacy.training.gold_io",
 | 
					    "spacy.training.gold_io",
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										0
									
								
								spacy/pipeline/_parser_internals/__init__.pxd
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								spacy/pipeline/_parser_internals/__init__.pxd
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										6
									
								
								spacy/pipeline/_parser_internals/_beam_utils.pxd
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								spacy/pipeline/_parser_internals/_beam_utils.pxd
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,6 @@
 | 
				
			||||||
 | 
					from ...typedefs cimport class_t, hash_t
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# These are passed as callbacks to thinc.search.Beam
 | 
				
			||||||
 | 
					cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cdef int check_final_state(void* _state, void* extra_args) except -1
 | 
				
			||||||
							
								
								
									
										296
									
								
								spacy/pipeline/_parser_internals/_beam_utils.pyx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										296
									
								
								spacy/pipeline/_parser_internals/_beam_utils.pyx
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,296 @@
 | 
				
			||||||
 | 
					# cython: infer_types=True
 | 
				
			||||||
 | 
					# cython: profile=True
 | 
				
			||||||
 | 
					cimport numpy as np
 | 
				
			||||||
 | 
					import numpy
 | 
				
			||||||
 | 
					from cpython.ref cimport PyObject, Py_XDECREF
 | 
				
			||||||
 | 
					from thinc.extra.search cimport Beam
 | 
				
			||||||
 | 
					from thinc.extra.search import MaxViolation
 | 
				
			||||||
 | 
					from thinc.extra.search cimport MaxViolation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ...typedefs cimport hash_t, class_t
 | 
				
			||||||
 | 
					from .transition_system cimport TransitionSystem, Transition
 | 
				
			||||||
 | 
					from ...errors import Errors
 | 
				
			||||||
 | 
					from .stateclass cimport StateC, StateClass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# These are passed as callbacks to thinc.search.Beam
 | 
				
			||||||
 | 
					cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
 | 
				
			||||||
 | 
					    dest = <StateC*>_dest
 | 
				
			||||||
 | 
					    src = <StateC*>_src
 | 
				
			||||||
 | 
					    moves = <const Transition*>_moves
 | 
				
			||||||
 | 
					    dest.clone(src)
 | 
				
			||||||
 | 
					    moves[clas].do(dest, moves[clas].label)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cdef int check_final_state(void* _state, void* extra_args) except -1:
 | 
				
			||||||
 | 
					    state = <StateC*>_state
 | 
				
			||||||
 | 
					    return state.is_final()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cdef class BeamBatch(object):
 | 
				
			||||||
 | 
					    cdef public TransitionSystem moves
 | 
				
			||||||
 | 
					    cdef public object states
 | 
				
			||||||
 | 
					    cdef public object docs
 | 
				
			||||||
 | 
					    cdef public object golds
 | 
				
			||||||
 | 
					    cdef public object beams
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, TransitionSystem moves, states, golds,
 | 
				
			||||||
 | 
					                 int width, float density=0.):
 | 
				
			||||||
 | 
					        cdef StateClass state
 | 
				
			||||||
 | 
					        self.moves = moves
 | 
				
			||||||
 | 
					        self.states = states
 | 
				
			||||||
 | 
					        self.docs = [state.doc for state in states]
 | 
				
			||||||
 | 
					        self.golds = golds
 | 
				
			||||||
 | 
					        self.beams = []
 | 
				
			||||||
 | 
					        cdef Beam beam
 | 
				
			||||||
 | 
					        cdef StateC* st
 | 
				
			||||||
 | 
					        for state in states:
 | 
				
			||||||
 | 
					            beam = Beam(self.moves.n_moves, width, min_density=density)
 | 
				
			||||||
 | 
					            beam.initialize(self.moves.init_beam_state,
 | 
				
			||||||
 | 
					                            self.moves.del_beam_state, state.c.length,
 | 
				
			||||||
 | 
					                            <void*>state.c._sent)
 | 
				
			||||||
 | 
					            for i in range(beam.width):
 | 
				
			||||||
 | 
					                st = <StateC*>beam.at(i)
 | 
				
			||||||
 | 
					                st.offset = state.c.offset
 | 
				
			||||||
 | 
					            beam.check_done(check_final_state, NULL)
 | 
				
			||||||
 | 
					            self.beams.append(beam)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def is_done(self):
 | 
				
			||||||
 | 
					        return all(b.is_done for b in self.beams)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __getitem__(self, i):
 | 
				
			||||||
 | 
					        return self.beams[i]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __len__(self):
 | 
				
			||||||
 | 
					        return len(self.beams)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_states(self):
 | 
				
			||||||
 | 
					        cdef Beam beam
 | 
				
			||||||
 | 
					        cdef StateC* state
 | 
				
			||||||
 | 
					        cdef StateClass stcls
 | 
				
			||||||
 | 
					        states = []
 | 
				
			||||||
 | 
					        for beam, doc in zip(self, self.docs):
 | 
				
			||||||
 | 
					            for i in range(beam.size):
 | 
				
			||||||
 | 
					                state = <StateC*>beam.at(i)
 | 
				
			||||||
 | 
					                stcls = StateClass.borrow(state, doc)
 | 
				
			||||||
 | 
					                states.append(stcls)
 | 
				
			||||||
 | 
					        return states
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_unfinished_states(self):
 | 
				
			||||||
 | 
					        return [st for st in self.get_states() if not st.is_final()]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def advance(self, float[:, ::1] scores, follow_gold=False):
 | 
				
			||||||
 | 
					        cdef Beam beam
 | 
				
			||||||
 | 
					        cdef int nr_class = scores.shape[1]
 | 
				
			||||||
 | 
					        cdef const float* c_scores = &scores[0, 0]
 | 
				
			||||||
 | 
					        docs = self.docs
 | 
				
			||||||
 | 
					        for i, beam in enumerate(self):
 | 
				
			||||||
 | 
					            if not beam.is_done:
 | 
				
			||||||
 | 
					                nr_state = self._set_scores(beam, c_scores, nr_class)
 | 
				
			||||||
 | 
					                assert nr_state
 | 
				
			||||||
 | 
					                if self.golds is not None:
 | 
				
			||||||
 | 
					                    self._set_costs(
 | 
				
			||||||
 | 
					                        beam,
 | 
				
			||||||
 | 
					                        docs[i],
 | 
				
			||||||
 | 
					                        self.golds[i],
 | 
				
			||||||
 | 
					                        follow_gold=follow_gold
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                c_scores += nr_state * nr_class
 | 
				
			||||||
 | 
					                beam.advance(transition_state, NULL, <void*>self.moves.c)
 | 
				
			||||||
 | 
					                beam.check_done(check_final_state, NULL)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    cdef int _set_scores(self, Beam beam, const float* scores, int nr_class) except -1:
 | 
				
			||||||
 | 
					        cdef int nr_state = 0
 | 
				
			||||||
 | 
					        for i in range(beam.size):
 | 
				
			||||||
 | 
					            state = <StateC*>beam.at(i)
 | 
				
			||||||
 | 
					            if not state.is_final():
 | 
				
			||||||
 | 
					                for j in range(nr_class):
 | 
				
			||||||
 | 
					                    beam.scores[i][j] = scores[nr_state * nr_class + j]
 | 
				
			||||||
 | 
					                self.moves.set_valid(beam.is_valid[i], state)
 | 
				
			||||||
 | 
					                nr_state += 1
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                for j in range(beam.nr_class):
 | 
				
			||||||
 | 
					                    beam.scores[i][j] = 0
 | 
				
			||||||
 | 
					                    beam.costs[i][j] = 0
 | 
				
			||||||
 | 
					        return nr_state
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _set_costs(self, Beam beam, doc, gold, int follow_gold=False):
 | 
				
			||||||
 | 
					        cdef const StateC* state
 | 
				
			||||||
 | 
					        for i in range(beam.size):
 | 
				
			||||||
 | 
					            state = <const StateC*>beam.at(i)
 | 
				
			||||||
 | 
					            if state.is_final():
 | 
				
			||||||
 | 
					                for j in range(beam.nr_class):
 | 
				
			||||||
 | 
					                    beam.is_valid[i][j] = 0
 | 
				
			||||||
 | 
					                    beam.costs[i][j] = 9000
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                self.moves.set_costs(beam.is_valid[i], beam.costs[i],
 | 
				
			||||||
 | 
					                                     state, gold)
 | 
				
			||||||
 | 
					                if follow_gold:
 | 
				
			||||||
 | 
					                    min_cost = 0
 | 
				
			||||||
 | 
					                    for j in range(beam.nr_class):
 | 
				
			||||||
 | 
					                        if beam.is_valid[i][j] and beam.costs[i][j] < min_cost:
 | 
				
			||||||
 | 
					                            min_cost = beam.costs[i][j]
 | 
				
			||||||
 | 
					                    for j in range(beam.nr_class):
 | 
				
			||||||
 | 
					                        if beam.costs[i][j] > min_cost:
 | 
				
			||||||
 | 
					                            beam.is_valid[i][j] = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def update_beam(TransitionSystem moves, states, golds, model, int width, beam_density=0.0):
 | 
				
			||||||
 | 
					    cdef MaxViolation violn
 | 
				
			||||||
 | 
					    pbeam = BeamBatch(moves, states, golds, width=width, density=beam_density)
 | 
				
			||||||
 | 
					    gbeam = BeamBatch(moves, states, golds, width=width, density=0.0)
 | 
				
			||||||
 | 
					    cdef StateClass state
 | 
				
			||||||
 | 
					    beam_maps = []
 | 
				
			||||||
 | 
					    backprops = []
 | 
				
			||||||
 | 
					    violns = [MaxViolation() for _ in range(len(states))]
 | 
				
			||||||
 | 
					    dones = [False for _ in states]
 | 
				
			||||||
 | 
					    while not pbeam.is_done or not gbeam.is_done:
 | 
				
			||||||
 | 
					        # The beam maps let us find the right row in the flattened scores
 | 
				
			||||||
 | 
					        # array for each state. States are identified by (example id,
 | 
				
			||||||
 | 
					        # history). We keep a different beam map for each step (since we'll
 | 
				
			||||||
 | 
					        # have a flat scores array for each step). The beam map will let us
 | 
				
			||||||
 | 
					        # take the per-state losses, and compute the gradient for each (step,
 | 
				
			||||||
 | 
					        # state, class).
 | 
				
			||||||
 | 
					        # Gather all states from the two beams in a list. Some stats may occur
 | 
				
			||||||
 | 
					        # in both beams. To figure out which beam each state belonged to,
 | 
				
			||||||
 | 
					        # we keep two lists of indices, p_indices and g_indices
 | 
				
			||||||
 | 
					        states, p_indices, g_indices, beam_map = get_unique_states(pbeam, gbeam)
 | 
				
			||||||
 | 
					        beam_maps.append(beam_map)
 | 
				
			||||||
 | 
					        if not states:
 | 
				
			||||||
 | 
					            break
 | 
				
			||||||
 | 
					        # Now that we have our flat list of states, feed them through the model
 | 
				
			||||||
 | 
					        scores, bp_scores = model.begin_update(states)
 | 
				
			||||||
 | 
					        assert scores.size != 0
 | 
				
			||||||
 | 
					        # Store the callbacks for the backward pass
 | 
				
			||||||
 | 
					        backprops.append(bp_scores)
 | 
				
			||||||
 | 
					        # Unpack the scores for the two beams. The indices arrays
 | 
				
			||||||
 | 
					        # tell us which example and state the scores-row refers to.
 | 
				
			||||||
 | 
					        # Now advance the states in the beams. The gold beam is constrained to
 | 
				
			||||||
 | 
					        # to follow only gold analyses.
 | 
				
			||||||
 | 
					        if not pbeam.is_done:
 | 
				
			||||||
 | 
					            pbeam.advance(model.ops.as_contig(scores[p_indices]))
 | 
				
			||||||
 | 
					        if not gbeam.is_done:
 | 
				
			||||||
 | 
					            gbeam.advance(model.ops.as_contig(scores[g_indices]), follow_gold=True)
 | 
				
			||||||
 | 
					        # Track the "maximum violation", to use in the update.
 | 
				
			||||||
 | 
					        for i, violn in enumerate(violns):
 | 
				
			||||||
 | 
					            if not dones[i]:
 | 
				
			||||||
 | 
					                violn.check_crf(pbeam[i], gbeam[i])
 | 
				
			||||||
 | 
					                if pbeam[i].is_done and gbeam[i].is_done:
 | 
				
			||||||
 | 
					                    dones[i] = True
 | 
				
			||||||
 | 
					    histories = []
 | 
				
			||||||
 | 
					    grads = []
 | 
				
			||||||
 | 
					    for violn in violns:
 | 
				
			||||||
 | 
					        if violn.p_hist:
 | 
				
			||||||
 | 
					            histories.append(violn.p_hist + violn.g_hist)
 | 
				
			||||||
 | 
					            d_loss = [d_l * violn.cost for d_l in violn.p_probs + violn.g_probs]
 | 
				
			||||||
 | 
					            grads.append(d_loss)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            histories.append([])
 | 
				
			||||||
 | 
					            grads.append([])
 | 
				
			||||||
 | 
					    loss = 0.0
 | 
				
			||||||
 | 
					    states_d_scores = get_gradient(moves.n_moves, beam_maps, histories, grads)
 | 
				
			||||||
 | 
					    for i, (d_scores, bp_scores) in enumerate(zip(states_d_scores, backprops)):
 | 
				
			||||||
 | 
					        loss += (d_scores**2).mean()
 | 
				
			||||||
 | 
					        bp_scores(d_scores)
 | 
				
			||||||
 | 
					    return loss
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def collect_states(beams, docs):
 | 
				
			||||||
 | 
					    cdef StateClass state
 | 
				
			||||||
 | 
					    cdef Beam beam
 | 
				
			||||||
 | 
					    states = []
 | 
				
			||||||
 | 
					    for state_or_beam, doc in zip(beams, docs):
 | 
				
			||||||
 | 
					        if isinstance(state_or_beam, StateClass):
 | 
				
			||||||
 | 
					            states.append(state_or_beam)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            beam = state_or_beam
 | 
				
			||||||
 | 
					            state = StateClass.borrow(<StateC*>beam.at(0), doc)
 | 
				
			||||||
 | 
					            states.append(state)
 | 
				
			||||||
 | 
					    return states
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_unique_states(pbeams, gbeams):
 | 
				
			||||||
 | 
					    seen = {}
 | 
				
			||||||
 | 
					    states = []
 | 
				
			||||||
 | 
					    p_indices = []
 | 
				
			||||||
 | 
					    g_indices = []
 | 
				
			||||||
 | 
					    beam_map = {}
 | 
				
			||||||
 | 
					    docs = pbeams.docs
 | 
				
			||||||
 | 
					    cdef Beam pbeam, gbeam
 | 
				
			||||||
 | 
					    if len(pbeams) != len(gbeams):
 | 
				
			||||||
 | 
					        raise ValueError(Errors.E079.format(pbeams=len(pbeams), gbeams=len(gbeams)))
 | 
				
			||||||
 | 
					    for eg_id, (pbeam, gbeam, doc) in enumerate(zip(pbeams, gbeams, docs)):
 | 
				
			||||||
 | 
					        if not pbeam.is_done:
 | 
				
			||||||
 | 
					            for i in range(pbeam.size):
 | 
				
			||||||
 | 
					                state = StateClass.borrow(<StateC*>pbeam.at(i), doc)
 | 
				
			||||||
 | 
					                if not state.is_final():
 | 
				
			||||||
 | 
					                    key = tuple([eg_id] + pbeam.histories[i])
 | 
				
			||||||
 | 
					                    if key in seen:
 | 
				
			||||||
 | 
					                        raise ValueError(Errors.E080.format(key=key))
 | 
				
			||||||
 | 
					                    seen[key] = len(states)
 | 
				
			||||||
 | 
					                    p_indices.append(len(states))
 | 
				
			||||||
 | 
					                    states.append(state)
 | 
				
			||||||
 | 
					            beam_map.update(seen)
 | 
				
			||||||
 | 
					        if not gbeam.is_done:
 | 
				
			||||||
 | 
					            for i in range(gbeam.size):
 | 
				
			||||||
 | 
					                state = StateClass.borrow(<StateC*>gbeam.at(i), doc)
 | 
				
			||||||
 | 
					                if not state.is_final():
 | 
				
			||||||
 | 
					                    key = tuple([eg_id] + gbeam.histories[i])
 | 
				
			||||||
 | 
					                    if key in seen:
 | 
				
			||||||
 | 
					                        g_indices.append(seen[key])
 | 
				
			||||||
 | 
					                    else:
 | 
				
			||||||
 | 
					                        g_indices.append(len(states))
 | 
				
			||||||
 | 
					                        beam_map[key] = len(states)
 | 
				
			||||||
 | 
					                        states.append(state)
 | 
				
			||||||
 | 
					    p_indices = numpy.asarray(p_indices, dtype='i')
 | 
				
			||||||
 | 
					    g_indices = numpy.asarray(g_indices, dtype='i')
 | 
				
			||||||
 | 
					    return states, p_indices, g_indices, beam_map
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_gradient(nr_class, beam_maps, histories, losses):
 | 
				
			||||||
 | 
					    """The global model assigns a loss to each parse. The beam scores
 | 
				
			||||||
 | 
					    are additive, so the same gradient is applied to each action
 | 
				
			||||||
 | 
					    in the history. This gives the gradient of a single *action*
 | 
				
			||||||
 | 
					    for a beam state -- so we have "the gradient of loss for taking
 | 
				
			||||||
 | 
					    action i given history H."
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Histories: Each hitory is a list of actions
 | 
				
			||||||
 | 
					    Each candidate has a history
 | 
				
			||||||
 | 
					    Each beam has multiple candidates
 | 
				
			||||||
 | 
					    Each batch has multiple beams
 | 
				
			||||||
 | 
					    So history is list of lists of lists of ints
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    grads = []
 | 
				
			||||||
 | 
					    nr_steps = []
 | 
				
			||||||
 | 
					    for eg_id, hists in enumerate(histories):
 | 
				
			||||||
 | 
					        nr_step = 0
 | 
				
			||||||
 | 
					        for loss, hist in zip(losses[eg_id], hists):
 | 
				
			||||||
 | 
					            assert not numpy.isnan(loss)
 | 
				
			||||||
 | 
					            if loss != 0.0:
 | 
				
			||||||
 | 
					                nr_step = max(nr_step, len(hist))
 | 
				
			||||||
 | 
					        nr_steps.append(nr_step)
 | 
				
			||||||
 | 
					    for i in range(max(nr_steps)):
 | 
				
			||||||
 | 
					        grads.append(numpy.zeros((max(beam_maps[i].values())+1, nr_class),
 | 
				
			||||||
 | 
					                                 dtype='f'))
 | 
				
			||||||
 | 
					    if len(histories) != len(losses):
 | 
				
			||||||
 | 
					        raise ValueError(Errors.E081.format(n_hist=len(histories), losses=len(losses)))
 | 
				
			||||||
 | 
					    for eg_id, hists in enumerate(histories):
 | 
				
			||||||
 | 
					        for loss, hist in zip(losses[eg_id], hists):
 | 
				
			||||||
 | 
					            assert not numpy.isnan(loss)
 | 
				
			||||||
 | 
					            if loss == 0.0:
 | 
				
			||||||
 | 
					                continue
 | 
				
			||||||
 | 
					            key = tuple([eg_id])
 | 
				
			||||||
 | 
					            # Adjust loss for length
 | 
				
			||||||
 | 
					            # We need to do this because each state in a short path is scored
 | 
				
			||||||
 | 
					            # multiple times, as we add in the average cost when we run out
 | 
				
			||||||
 | 
					            # of actions.
 | 
				
			||||||
 | 
					            avg_loss = loss / len(hist)
 | 
				
			||||||
 | 
					            loss += avg_loss * (nr_steps[eg_id] - len(hist))
 | 
				
			||||||
 | 
					            for step, clas in enumerate(hist):
 | 
				
			||||||
 | 
					                i = beam_maps[step][key]
 | 
				
			||||||
 | 
					                # In step j, at state i action clas
 | 
				
			||||||
 | 
					                # resulted in loss
 | 
				
			||||||
 | 
					                grads[step][i, clas] += loss
 | 
				
			||||||
 | 
					                key = key + tuple([clas])
 | 
				
			||||||
 | 
					    return grads
 | 
				
			||||||
| 
						 | 
					@ -1,6 +1,9 @@
 | 
				
			||||||
from libc.string cimport memcpy, memset
 | 
					from libc.string cimport memcpy, memset
 | 
				
			||||||
from libc.stdlib cimport calloc, free
 | 
					from libc.stdlib cimport calloc, free
 | 
				
			||||||
from libc.stdint cimport uint32_t, uint64_t
 | 
					from libc.stdint cimport uint32_t, uint64_t
 | 
				
			||||||
 | 
					cimport libcpp
 | 
				
			||||||
 | 
					from libcpp.vector cimport vector
 | 
				
			||||||
 | 
					from libcpp.set cimport set
 | 
				
			||||||
from cpython.exc cimport PyErr_CheckSignals, PyErr_SetFromErrno
 | 
					from cpython.exc cimport PyErr_CheckSignals, PyErr_SetFromErrno
 | 
				
			||||||
from murmurhash.mrmr cimport hash64
 | 
					from murmurhash.mrmr cimport hash64
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -14,89 +17,48 @@ from ...typedefs cimport attr_t
 | 
				
			||||||
cdef inline bint is_space_token(const TokenC* token) nogil:
 | 
					cdef inline bint is_space_token(const TokenC* token) nogil:
 | 
				
			||||||
    return Lexeme.c_check_flag(token.lex, IS_SPACE)
 | 
					    return Lexeme.c_check_flag(token.lex, IS_SPACE)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef struct RingBufferC:
 | 
					cdef struct ArcC:
 | 
				
			||||||
    int[8] data
 | 
					    int head
 | 
				
			||||||
    int i
 | 
					    int child
 | 
				
			||||||
    int default
 | 
					    attr_t label
 | 
				
			||||||
 | 
					 | 
				
			||||||
cdef inline int ring_push(RingBufferC* ring, int value) nogil:
 | 
					 | 
				
			||||||
    ring.data[ring.i] = value
 | 
					 | 
				
			||||||
    ring.i += 1
 | 
					 | 
				
			||||||
    if ring.i >= 8:
 | 
					 | 
				
			||||||
        ring.i = 0
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
cdef inline int ring_get(RingBufferC* ring, int i) nogil:
 | 
					 | 
				
			||||||
    if i >= ring.i:
 | 
					 | 
				
			||||||
        return ring.default
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        return ring.data[ring.i-i]
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef cppclass StateC:
 | 
					cdef cppclass StateC:
 | 
				
			||||||
    int* _stack
 | 
					    int* _heads
 | 
				
			||||||
    int* _buffer
 | 
					    const TokenC* _sent
 | 
				
			||||||
    bint* shifted
 | 
					    vector[int] _stack
 | 
				
			||||||
    TokenC* _sent
 | 
					    vector[int] _rebuffer
 | 
				
			||||||
    SpanC* _ents
 | 
					    vector[SpanC] _ents
 | 
				
			||||||
 | 
					    vector[ArcC] _left_arcs
 | 
				
			||||||
 | 
					    vector[ArcC] _right_arcs
 | 
				
			||||||
 | 
					    vector[libcpp.bool] _unshiftable
 | 
				
			||||||
 | 
					    set[int] _sent_starts
 | 
				
			||||||
    TokenC _empty_token
 | 
					    TokenC _empty_token
 | 
				
			||||||
    RingBufferC _hist
 | 
					 | 
				
			||||||
    int length
 | 
					    int length
 | 
				
			||||||
    int offset
 | 
					    int offset
 | 
				
			||||||
    int _s_i
 | 
					 | 
				
			||||||
    int _b_i
 | 
					    int _b_i
 | 
				
			||||||
    int _e_i
 | 
					 | 
				
			||||||
    int _break
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    __init__(const TokenC* sent, int length) nogil:
 | 
					    __init__(const TokenC* sent, int length) nogil:
 | 
				
			||||||
        cdef int PADDING = 5
 | 
					        this._sent = sent
 | 
				
			||||||
        this._buffer = <int*>calloc(length + (PADDING * 2), sizeof(int))
 | 
					        this._heads = <int*>calloc(length, sizeof(int))
 | 
				
			||||||
        this._stack = <int*>calloc(length + (PADDING * 2), sizeof(int))
 | 
					        if not (this._sent and this._heads):
 | 
				
			||||||
        this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint))
 | 
					 | 
				
			||||||
        this._sent = <TokenC*>calloc(length + (PADDING * 2), sizeof(TokenC))
 | 
					 | 
				
			||||||
        this._ents = <SpanC*>calloc(length + (PADDING * 2), sizeof(SpanC))
 | 
					 | 
				
			||||||
        if not (this._buffer and this._stack and this.shifted
 | 
					 | 
				
			||||||
                and this._sent and this._ents):
 | 
					 | 
				
			||||||
            with gil:
 | 
					            with gil:
 | 
				
			||||||
                PyErr_SetFromErrno(MemoryError)
 | 
					                PyErr_SetFromErrno(MemoryError)
 | 
				
			||||||
                PyErr_CheckSignals()
 | 
					                PyErr_CheckSignals()
 | 
				
			||||||
        memset(&this._hist, 0, sizeof(this._hist))
 | 
					 | 
				
			||||||
        this.offset = 0
 | 
					        this.offset = 0
 | 
				
			||||||
        cdef int i
 | 
					 | 
				
			||||||
        for i in range(length + (PADDING * 2)):
 | 
					 | 
				
			||||||
            this._ents[i].end = -1
 | 
					 | 
				
			||||||
            this._sent[i].l_edge = i
 | 
					 | 
				
			||||||
            this._sent[i].r_edge = i
 | 
					 | 
				
			||||||
        for i in range(PADDING):
 | 
					 | 
				
			||||||
            this._sent[i].lex = &EMPTY_LEXEME
 | 
					 | 
				
			||||||
        this._sent += PADDING
 | 
					 | 
				
			||||||
        this._ents += PADDING
 | 
					 | 
				
			||||||
        this._buffer += PADDING
 | 
					 | 
				
			||||||
        this._stack += PADDING
 | 
					 | 
				
			||||||
        this.shifted += PADDING
 | 
					 | 
				
			||||||
        this.length = length
 | 
					        this.length = length
 | 
				
			||||||
        this._break = -1
 | 
					 | 
				
			||||||
        this._s_i = 0
 | 
					 | 
				
			||||||
        this._b_i = 0
 | 
					        this._b_i = 0
 | 
				
			||||||
        this._e_i = 0
 | 
					 | 
				
			||||||
        for i in range(length):
 | 
					        for i in range(length):
 | 
				
			||||||
            this._buffer[i] = i
 | 
					            this._heads[i] = -1
 | 
				
			||||||
 | 
					            this._unshiftable.push_back(0)
 | 
				
			||||||
        memset(&this._empty_token, 0, sizeof(TokenC))
 | 
					        memset(&this._empty_token, 0, sizeof(TokenC))
 | 
				
			||||||
        this._empty_token.lex = &EMPTY_LEXEME
 | 
					        this._empty_token.lex = &EMPTY_LEXEME
 | 
				
			||||||
        for i in range(length):
 | 
					 | 
				
			||||||
            this._sent[i] = sent[i]
 | 
					 | 
				
			||||||
            this._buffer[i] = i
 | 
					 | 
				
			||||||
        for i in range(length, length+PADDING):
 | 
					 | 
				
			||||||
            this._sent[i].lex = &EMPTY_LEXEME
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    __dealloc__():
 | 
					    __dealloc__():
 | 
				
			||||||
        cdef int PADDING = 5
 | 
					        free(this._heads)
 | 
				
			||||||
        free(this._sent - PADDING)
 | 
					 | 
				
			||||||
        free(this._ents - PADDING)
 | 
					 | 
				
			||||||
        free(this._buffer - PADDING)
 | 
					 | 
				
			||||||
        free(this._stack - PADDING)
 | 
					 | 
				
			||||||
        free(this.shifted - PADDING)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    void set_context_tokens(int* ids, int n) nogil:
 | 
					    void set_context_tokens(int* ids, int n) nogil:
 | 
				
			||||||
 | 
					        cdef int i, j
 | 
				
			||||||
        if n == 1:
 | 
					        if n == 1:
 | 
				
			||||||
            if this.B(0) >= 0:
 | 
					            if this.B(0) >= 0:
 | 
				
			||||||
                ids[0] = this.B(0)
 | 
					                ids[0] = this.B(0)
 | 
				
			||||||
| 
						 | 
					@ -145,22 +107,18 @@ cdef cppclass StateC:
 | 
				
			||||||
            ids[11] = this.R(this.S(1), 1)
 | 
					            ids[11] = this.R(this.S(1), 1)
 | 
				
			||||||
            ids[12] = this.R(this.S(1), 2)
 | 
					            ids[12] = this.R(this.S(1), 2)
 | 
				
			||||||
        elif n == 6:
 | 
					        elif n == 6:
 | 
				
			||||||
 | 
					            for i in range(6):
 | 
				
			||||||
 | 
					                ids[i] = -1
 | 
				
			||||||
            if this.B(0) >= 0:
 | 
					            if this.B(0) >= 0:
 | 
				
			||||||
                ids[0] = this.B(0)
 | 
					                ids[0] = this.B(0)
 | 
				
			||||||
                ids[1] = this.B(0)-1
 | 
					            if this.entity_is_open():
 | 
				
			||||||
            else:
 | 
					                ent = this.get_ent()
 | 
				
			||||||
                ids[0] = -1
 | 
					                j = 1
 | 
				
			||||||
                ids[1] = -1
 | 
					                for i in range(ent.start, this.B(0)):
 | 
				
			||||||
            ids[2] = this.B(1)
 | 
					                    ids[j] = i
 | 
				
			||||||
            ids[3] = this.E(0)
 | 
					                    j += 1
 | 
				
			||||||
            if ids[3] >= 1:
 | 
					                    if j >= 6:
 | 
				
			||||||
                ids[4] = this.E(0)-1
 | 
					                        break
 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                ids[4] = -1
 | 
					 | 
				
			||||||
            if (ids[3]+1) < this.length:
 | 
					 | 
				
			||||||
                ids[5] = this.E(0)+1
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                ids[5] = -1
 | 
					 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            # TODO error =/
 | 
					            # TODO error =/
 | 
				
			||||||
            pass
 | 
					            pass
 | 
				
			||||||
| 
						 | 
					@ -171,329 +129,256 @@ cdef cppclass StateC:
 | 
				
			||||||
                ids[i] = -1
 | 
					                ids[i] = -1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    int S(int i) nogil const:
 | 
					    int S(int i) nogil const:
 | 
				
			||||||
        if i >= this._s_i:
 | 
					        if i >= this._stack.size():
 | 
				
			||||||
            return -1
 | 
					            return -1
 | 
				
			||||||
        return this._stack[this._s_i - (i+1)]
 | 
					        elif i < 0:
 | 
				
			||||||
 | 
					            return -1
 | 
				
			||||||
 | 
					        return this._stack.at(this._stack.size() - (i+1))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    int B(int i) nogil const:
 | 
					    int B(int i) nogil const:
 | 
				
			||||||
        if (i + this._b_i) >= this.length:
 | 
					        if i < 0:
 | 
				
			||||||
            return -1
 | 
					            return -1
 | 
				
			||||||
        return this._buffer[this._b_i + i]
 | 
					        elif i < this._rebuffer.size():
 | 
				
			||||||
 | 
					            return this._rebuffer.at(this._rebuffer.size() - (i+1))
 | 
				
			||||||
    const TokenC* S_(int i) nogil const:
 | 
					        else:
 | 
				
			||||||
        return this.safe_get(this.S(i))
 | 
					            b_i = this._b_i + (i - this._rebuffer.size())
 | 
				
			||||||
 | 
					            if b_i >= this.length:
 | 
				
			||||||
 | 
					                return -1
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                return b_i
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const TokenC* B_(int i) nogil const:
 | 
					    const TokenC* B_(int i) nogil const:
 | 
				
			||||||
        return this.safe_get(this.B(i))
 | 
					        return this.safe_get(this.B(i))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const TokenC* H_(int i) nogil const:
 | 
					 | 
				
			||||||
        return this.safe_get(this.H(i))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    const TokenC* E_(int i) nogil const:
 | 
					    const TokenC* E_(int i) nogil const:
 | 
				
			||||||
        return this.safe_get(this.E(i))
 | 
					        return this.safe_get(this.E(i))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const TokenC* L_(int i, int idx) nogil const:
 | 
					 | 
				
			||||||
        return this.safe_get(this.L(i, idx))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    const TokenC* R_(int i, int idx) nogil const:
 | 
					 | 
				
			||||||
        return this.safe_get(this.R(i, idx))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    const TokenC* safe_get(int i) nogil const:
 | 
					    const TokenC* safe_get(int i) nogil const:
 | 
				
			||||||
        if i < 0 or i >= this.length:
 | 
					        if i < 0 or i >= this.length:
 | 
				
			||||||
            return &this._empty_token
 | 
					            return &this._empty_token
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            return &this._sent[i]
 | 
					            return &this._sent[i]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    int H(int i) nogil const:
 | 
					    void get_arcs(vector[ArcC]* arcs) nogil const:
 | 
				
			||||||
        if i < 0 or i >= this.length:
 | 
					        for i in range(this._left_arcs.size()):
 | 
				
			||||||
 | 
					            arc = this._left_arcs.at(i)
 | 
				
			||||||
 | 
					            if arc.head != -1 and arc.child != -1:
 | 
				
			||||||
 | 
					                arcs.push_back(arc)
 | 
				
			||||||
 | 
					        for i in range(this._right_arcs.size()):
 | 
				
			||||||
 | 
					            arc = this._right_arcs.at(i)
 | 
				
			||||||
 | 
					            if arc.head != -1 and arc.child != -1:
 | 
				
			||||||
 | 
					                arcs.push_back(arc)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    int H(int child) nogil const:
 | 
				
			||||||
 | 
					        if child >= this.length or child < 0:
 | 
				
			||||||
            return -1
 | 
					            return -1
 | 
				
			||||||
        return this._sent[i].head + i
 | 
					        else:
 | 
				
			||||||
 | 
					            return this._heads[child]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    int E(int i) nogil const:
 | 
					    int E(int i) nogil const:
 | 
				
			||||||
        if this._e_i <= 0 or this._e_i >= this.length:
 | 
					        if this._ents.size() == 0:
 | 
				
			||||||
            return -1
 | 
					            return -1
 | 
				
			||||||
        if i < 0 or i >= this._e_i:
 | 
					 | 
				
			||||||
            return -1
 | 
					 | 
				
			||||||
        return this._ents[this._e_i - (i+1)].start
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    int L(int i, int idx) nogil const:
 | 
					 | 
				
			||||||
        if idx < 1:
 | 
					 | 
				
			||||||
            return -1
 | 
					 | 
				
			||||||
        if i < 0 or i >= this.length:
 | 
					 | 
				
			||||||
            return -1
 | 
					 | 
				
			||||||
        cdef const TokenC* target = &this._sent[i]
 | 
					 | 
				
			||||||
        if target.l_kids < <uint32_t>idx:
 | 
					 | 
				
			||||||
            return -1
 | 
					 | 
				
			||||||
        cdef const TokenC* ptr = &this._sent[target.l_edge]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        while ptr < target:
 | 
					 | 
				
			||||||
            # If this head is still to the right of us, we can skip to it
 | 
					 | 
				
			||||||
            # No token that's between this token and this head could be our
 | 
					 | 
				
			||||||
            # child.
 | 
					 | 
				
			||||||
            if (ptr.head >= 1) and (ptr + ptr.head) < target:
 | 
					 | 
				
			||||||
                ptr += ptr.head
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            elif ptr + ptr.head == target:
 | 
					 | 
				
			||||||
                idx -= 1
 | 
					 | 
				
			||||||
                if idx == 0:
 | 
					 | 
				
			||||||
                    return ptr - this._sent
 | 
					 | 
				
			||||||
                ptr += 1
 | 
					 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
                ptr += 1
 | 
					            return this._ents.back().start
 | 
				
			||||||
        return -1
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    int R(int i, int idx) nogil const:
 | 
					    int L(int head, int idx) nogil const:
 | 
				
			||||||
        if idx < 1:
 | 
					        if idx < 1 or this._left_arcs.size() == 0:
 | 
				
			||||||
            return -1
 | 
					            return -1
 | 
				
			||||||
        if i < 0 or i >= this.length:
 | 
					        cdef vector[int] lefts
 | 
				
			||||||
 | 
					        for i in range(this._left_arcs.size()):
 | 
				
			||||||
 | 
					            arc = this._left_arcs.at(i)
 | 
				
			||||||
 | 
					            if arc.head == head and arc.child != -1 and arc.child < head:
 | 
				
			||||||
 | 
					                lefts.push_back(arc.child)
 | 
				
			||||||
 | 
					        idx = (<int>lefts.size()) - idx
 | 
				
			||||||
 | 
					        if idx < 0:
 | 
				
			||||||
            return -1
 | 
					            return -1
 | 
				
			||||||
        cdef const TokenC* target = &this._sent[i]
 | 
					 | 
				
			||||||
        if target.r_kids < <uint32_t>idx:
 | 
					 | 
				
			||||||
            return -1
 | 
					 | 
				
			||||||
        cdef const TokenC* ptr = &this._sent[target.r_edge]
 | 
					 | 
				
			||||||
        while ptr > target:
 | 
					 | 
				
			||||||
            # If this head is still to the right of us, we can skip to it
 | 
					 | 
				
			||||||
            # No token that's between this token and this head could be our
 | 
					 | 
				
			||||||
            # child.
 | 
					 | 
				
			||||||
            if (ptr.head < 0) and ((ptr + ptr.head) > target):
 | 
					 | 
				
			||||||
                ptr += ptr.head
 | 
					 | 
				
			||||||
            elif ptr + ptr.head == target:
 | 
					 | 
				
			||||||
                idx -= 1
 | 
					 | 
				
			||||||
                if idx == 0:
 | 
					 | 
				
			||||||
                    return ptr - this._sent
 | 
					 | 
				
			||||||
                ptr -= 1
 | 
					 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
                ptr -= 1
 | 
					            return lefts.at(idx)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    int R(int head, int idx) nogil const:
 | 
				
			||||||
 | 
					        if idx < 1 or this._right_arcs.size() == 0:
 | 
				
			||||||
            return -1
 | 
					            return -1
 | 
				
			||||||
 | 
					        cdef vector[int] rights
 | 
				
			||||||
 | 
					        for i in range(this._right_arcs.size()):
 | 
				
			||||||
 | 
					            arc = this._right_arcs.at(i)
 | 
				
			||||||
 | 
					            if arc.head == head and arc.child != -1 and arc.child > head:
 | 
				
			||||||
 | 
					                rights.push_back(arc.child)
 | 
				
			||||||
 | 
					        idx = (<int>rights.size()) - idx
 | 
				
			||||||
 | 
					        if idx < 0:
 | 
				
			||||||
 | 
					            return -1
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return rights.at(idx)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    bint empty() nogil const:
 | 
					    bint empty() nogil const:
 | 
				
			||||||
        return this._s_i <= 0
 | 
					        return this._stack.size() == 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    bint eol() nogil const:
 | 
					    bint eol() nogil const:
 | 
				
			||||||
        return this.buffer_length() == 0
 | 
					        return this.buffer_length() == 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    bint at_break() nogil const:
 | 
					 | 
				
			||||||
        return this._break != -1
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    bint is_final() nogil const:
 | 
					    bint is_final() nogil const:
 | 
				
			||||||
        return this.stack_depth() <= 0 and this._b_i >= this.length
 | 
					        return this.stack_depth() <= 0 and this.eol()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    bint has_head(int i) nogil const:
 | 
					    int cannot_sent_start(int word) nogil const:
 | 
				
			||||||
        return this.safe_get(i).head != 0
 | 
					        if word < 0 or word >= this.length:
 | 
				
			||||||
 | 
					            return 0
 | 
				
			||||||
 | 
					        elif this._sent[word].sent_start == -1:
 | 
				
			||||||
 | 
					            return 1
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    int n_L(int i) nogil const:
 | 
					    int is_sent_start(int word) nogil const:
 | 
				
			||||||
        return this.safe_get(i).l_kids
 | 
					        if word < 0 or word >= this.length:
 | 
				
			||||||
 | 
					            return 0
 | 
				
			||||||
 | 
					        elif this._sent[word].sent_start == 1:
 | 
				
			||||||
 | 
					            return 1
 | 
				
			||||||
 | 
					        elif this._sent_starts.count(word) >= 1:
 | 
				
			||||||
 | 
					            return 1
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    int n_R(int i) nogil const:
 | 
					    void set_sent_start(int word, int value) nogil:
 | 
				
			||||||
        return this.safe_get(i).r_kids
 | 
					        if value >= 1:
 | 
				
			||||||
 | 
					            this._sent_starts.insert(word)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    bint has_head(int child) nogil const:
 | 
				
			||||||
 | 
					        return this._heads[child] >= 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    int l_edge(int word) nogil const:
 | 
				
			||||||
 | 
					        return word
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    int r_edge(int word) nogil const:
 | 
				
			||||||
 | 
					        return word
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					    int n_L(int head) nogil const:
 | 
				
			||||||
 | 
					        cdef int n = 0
 | 
				
			||||||
 | 
					        for i in range(this._left_arcs.size()):
 | 
				
			||||||
 | 
					            arc = this._left_arcs.at(i) 
 | 
				
			||||||
 | 
					            if arc.head == head and arc.child != -1 and arc.child < arc.head:
 | 
				
			||||||
 | 
					                n += 1
 | 
				
			||||||
 | 
					        return n
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    int n_R(int head) nogil const:
 | 
				
			||||||
 | 
					        cdef int n = 0
 | 
				
			||||||
 | 
					        for i in range(this._right_arcs.size()):
 | 
				
			||||||
 | 
					            arc = this._right_arcs.at(i) 
 | 
				
			||||||
 | 
					            if arc.head == head and arc.child != -1 and arc.child > arc.head:
 | 
				
			||||||
 | 
					                n += 1
 | 
				
			||||||
 | 
					        return n
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    bint stack_is_connected() nogil const:
 | 
					    bint stack_is_connected() nogil const:
 | 
				
			||||||
        return False
 | 
					        return False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    bint entity_is_open() nogil const:
 | 
					    bint entity_is_open() nogil const:
 | 
				
			||||||
        if this._e_i < 1:
 | 
					        if this._ents.size() == 0:
 | 
				
			||||||
            return False
 | 
					            return False
 | 
				
			||||||
        return this._ents[this._e_i-1].end == -1
 | 
					        else:
 | 
				
			||||||
 | 
					            return this._ents.back().end == -1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    int stack_depth() nogil const:
 | 
					    int stack_depth() nogil const:
 | 
				
			||||||
        return this._s_i
 | 
					        return this._stack.size()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    int buffer_length() nogil const:
 | 
					    int buffer_length() nogil const:
 | 
				
			||||||
        if this._break != -1:
 | 
					 | 
				
			||||||
            return this._break - this._b_i
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
        return this.length - this._b_i
 | 
					        return this.length - this._b_i
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    uint64_t hash() nogil const:
 | 
					 | 
				
			||||||
        cdef TokenC[11] sig
 | 
					 | 
				
			||||||
        sig[0] = this.S_(2)[0]
 | 
					 | 
				
			||||||
        sig[1] = this.S_(1)[0]
 | 
					 | 
				
			||||||
        sig[2] = this.R_(this.S(1), 1)[0]
 | 
					 | 
				
			||||||
        sig[3] = this.L_(this.S(0), 1)[0]
 | 
					 | 
				
			||||||
        sig[4] = this.L_(this.S(0), 2)[0]
 | 
					 | 
				
			||||||
        sig[5] = this.S_(0)[0]
 | 
					 | 
				
			||||||
        sig[6] = this.R_(this.S(0), 2)[0]
 | 
					 | 
				
			||||||
        sig[7] = this.R_(this.S(0), 1)[0]
 | 
					 | 
				
			||||||
        sig[8] = this.B_(0)[0]
 | 
					 | 
				
			||||||
        sig[9] = this.E_(0)[0]
 | 
					 | 
				
			||||||
        sig[10] = this.E_(1)[0]
 | 
					 | 
				
			||||||
        return hash64(sig, sizeof(sig), this._s_i) \
 | 
					 | 
				
			||||||
             + hash64(<void*>&this._hist, sizeof(RingBufferC), 1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    void push_hist(int act) nogil:
 | 
					 | 
				
			||||||
        ring_push(&this._hist, act+1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    int get_hist(int i) nogil:
 | 
					 | 
				
			||||||
        return ring_get(&this._hist, i)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    void push() nogil:
 | 
					    void push() nogil:
 | 
				
			||||||
        if this.B(0) != -1:
 | 
					        b0 = this.B(0)
 | 
				
			||||||
            this._stack[this._s_i] = this.B(0)
 | 
					        if this._rebuffer.size():
 | 
				
			||||||
        this._s_i += 1
 | 
					            b0 = this._rebuffer.back()
 | 
				
			||||||
 | 
					            this._rebuffer.pop_back()
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            b0 = this._b_i
 | 
				
			||||||
            this._b_i += 1
 | 
					            this._b_i += 1
 | 
				
			||||||
        if this.safe_get(this.B_(0).l_edge).sent_start == 1:
 | 
					        this._stack.push_back(b0)
 | 
				
			||||||
            this.set_break(this.B_(0).l_edge)
 | 
					 | 
				
			||||||
        if this._b_i > this._break:
 | 
					 | 
				
			||||||
            this._break = -1
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    void pop() nogil:
 | 
					    void pop() nogil:
 | 
				
			||||||
        if this._s_i >= 1:
 | 
					        this._stack.pop_back()
 | 
				
			||||||
            this._s_i -= 1
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    void force_final() nogil:
 | 
					    void force_final() nogil:
 | 
				
			||||||
        # This should only be used in desperate situations, as it may leave
 | 
					        # This should only be used in desperate situations, as it may leave
 | 
				
			||||||
        # the analysis in an unexpected state.
 | 
					        # the analysis in an unexpected state.
 | 
				
			||||||
        this._s_i = 0
 | 
					        this._stack.clear()
 | 
				
			||||||
        this._b_i = this.length
 | 
					        this._b_i = this.length
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    void unshift() nogil:
 | 
					    void unshift() nogil:
 | 
				
			||||||
        this._b_i -= 1
 | 
					        s0 = this._stack.back()
 | 
				
			||||||
        this._buffer[this._b_i] = this.S(0)
 | 
					        this._unshiftable[s0] = 1
 | 
				
			||||||
        this._s_i -= 1
 | 
					        this._rebuffer.push_back(s0)
 | 
				
			||||||
        this.shifted[this.B(0)] = True
 | 
					        this._stack.pop_back()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    int is_unshiftable(int item) nogil const:
 | 
				
			||||||
 | 
					        if item >= this._unshiftable.size():
 | 
				
			||||||
 | 
					            return 0
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return this._unshiftable.at(item)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    void set_reshiftable(int item) nogil:
 | 
				
			||||||
 | 
					        if item < this._unshiftable.size():
 | 
				
			||||||
 | 
					            this._unshiftable[item] = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    void add_arc(int head, int child, attr_t label) nogil:
 | 
					    void add_arc(int head, int child, attr_t label) nogil:
 | 
				
			||||||
        if this.has_head(child):
 | 
					        if this.has_head(child):
 | 
				
			||||||
            this.del_arc(this.H(child), child)
 | 
					            this.del_arc(this.H(child), child)
 | 
				
			||||||
 | 
					        cdef ArcC arc
 | 
				
			||||||
        cdef int dist = head - child
 | 
					        arc.head = head
 | 
				
			||||||
        this._sent[child].head = dist
 | 
					        arc.child = child
 | 
				
			||||||
        this._sent[child].dep = label
 | 
					        arc.label = label
 | 
				
			||||||
        cdef int i
 | 
					        if head > child:
 | 
				
			||||||
        if child > head:
 | 
					            this._left_arcs.push_back(arc)
 | 
				
			||||||
            this._sent[head].r_kids += 1
 | 
					 | 
				
			||||||
            # Some transition systems can have a word in the buffer have a
 | 
					 | 
				
			||||||
            # rightward child, e.g. from Unshift.
 | 
					 | 
				
			||||||
            this._sent[head].r_edge = this._sent[child].r_edge
 | 
					 | 
				
			||||||
            i = 0
 | 
					 | 
				
			||||||
            while this.has_head(head) and i < this.length:
 | 
					 | 
				
			||||||
                head = this.H(head)
 | 
					 | 
				
			||||||
                this._sent[head].r_edge = this._sent[child].r_edge
 | 
					 | 
				
			||||||
                i += 1 # Guard against infinite loops
 | 
					 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            this._sent[head].l_kids += 1
 | 
					            this._right_arcs.push_back(arc)
 | 
				
			||||||
            this._sent[head].l_edge = this._sent[child].l_edge
 | 
					        this._heads[child] = head
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    void del_arc(int h_i, int c_i) nogil:
 | 
					    void del_arc(int h_i, int c_i) nogil:
 | 
				
			||||||
        cdef int dist = h_i - c_i
 | 
					        cdef vector[ArcC]* arcs
 | 
				
			||||||
        cdef TokenC* h = &this._sent[h_i]
 | 
					        if h_i > c_i:
 | 
				
			||||||
        cdef int i = 0
 | 
					            arcs = &this._left_arcs
 | 
				
			||||||
        if c_i > h_i:
 | 
					 | 
				
			||||||
            # this.R_(h_i, 2) returns the second-rightmost child token of h_i
 | 
					 | 
				
			||||||
            # If we have more than 2 rightmost children, our 2nd rightmost child's
 | 
					 | 
				
			||||||
            # rightmost edge is going to be our new rightmost edge.
 | 
					 | 
				
			||||||
            h.r_edge = this.R_(h_i, 2).r_edge if h.r_kids >= 2 else h_i
 | 
					 | 
				
			||||||
            h.r_kids -= 1
 | 
					 | 
				
			||||||
            new_edge = h.r_edge
 | 
					 | 
				
			||||||
            # Correct upwards in the tree --- see Issue #251
 | 
					 | 
				
			||||||
            while h.head < 0 and i < this.length: # Guard infinite loop
 | 
					 | 
				
			||||||
                h += h.head
 | 
					 | 
				
			||||||
                h.r_edge = new_edge
 | 
					 | 
				
			||||||
                i += 1
 | 
					 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            # Same logic applies for left edge, but we don't need to walk up
 | 
					            arcs = &this._right_arcs
 | 
				
			||||||
            # the tree, as the head is off the stack.
 | 
					        if arcs.size() == 0:
 | 
				
			||||||
            h.l_edge = this.L_(h_i, 2).l_edge if h.l_kids >= 2 else h_i
 | 
					            return
 | 
				
			||||||
            h.l_kids -= 1
 | 
					        arc = arcs.back()
 | 
				
			||||||
 | 
					        if arc.head == h_i and arc.child == c_i:
 | 
				
			||||||
 | 
					            arcs.pop_back()
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            for i in range(arcs.size()-1):
 | 
				
			||||||
 | 
					                arc = arcs.at(i)
 | 
				
			||||||
 | 
					                if arc.head == h_i and arc.child == c_i:
 | 
				
			||||||
 | 
					                    arc.head = -1
 | 
				
			||||||
 | 
					                    arc.child = -1
 | 
				
			||||||
 | 
					                    arc.label = 0
 | 
				
			||||||
 | 
					                    break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    SpanC get_ent() nogil const:
 | 
				
			||||||
 | 
					        cdef SpanC ent
 | 
				
			||||||
 | 
					        if this._ents.size() == 0:
 | 
				
			||||||
 | 
					            ent.start = 0
 | 
				
			||||||
 | 
					            ent.end = 0
 | 
				
			||||||
 | 
					            ent.label = 0
 | 
				
			||||||
 | 
					            return ent
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return this._ents.back()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    void open_ent(attr_t label) nogil:
 | 
					    void open_ent(attr_t label) nogil:
 | 
				
			||||||
        this._ents[this._e_i].start = this.B(0)
 | 
					        cdef SpanC ent
 | 
				
			||||||
        this._ents[this._e_i].label = label
 | 
					        ent.start = this.B(0)
 | 
				
			||||||
        this._ents[this._e_i].end = -1
 | 
					        ent.label = label
 | 
				
			||||||
        this._e_i += 1
 | 
					        ent.end = -1
 | 
				
			||||||
 | 
					        this._ents.push_back(ent)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    void close_ent() nogil:
 | 
					    void close_ent() nogil:
 | 
				
			||||||
        # Note that we don't decrement _e_i here! We want to maintain all
 | 
					        this._ents.back().end = this.B(0)+1
 | 
				
			||||||
        # entities, not over-write them...
 | 
					 | 
				
			||||||
        this._ents[this._e_i-1].end = this.B(0)+1
 | 
					 | 
				
			||||||
        this._sent[this.B(0)].ent_iob = 1
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    void set_ent_tag(int i, int ent_iob, attr_t ent_type) nogil:
 | 
					 | 
				
			||||||
        if 0 <= i < this.length:
 | 
					 | 
				
			||||||
            this._sent[i].ent_iob = ent_iob
 | 
					 | 
				
			||||||
            this._sent[i].ent_type = ent_type
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    void set_break(int i) nogil:
 | 
					 | 
				
			||||||
        if 0 <= i < this.length:
 | 
					 | 
				
			||||||
            this._sent[i].sent_start = 1
 | 
					 | 
				
			||||||
            this._break = this._b_i
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    void clone(const StateC* src) nogil:
 | 
					    void clone(const StateC* src) nogil:
 | 
				
			||||||
        this.length = src.length
 | 
					        this.length = src.length
 | 
				
			||||||
        memcpy(this._sent, src._sent, this.length * sizeof(TokenC))
 | 
					        this._sent = src._sent
 | 
				
			||||||
        memcpy(this._stack, src._stack, this.length * sizeof(int))
 | 
					        this._stack = src._stack
 | 
				
			||||||
        memcpy(this._buffer, src._buffer, this.length * sizeof(int))
 | 
					        this._rebuffer = src._rebuffer
 | 
				
			||||||
        memcpy(this._ents, src._ents, this.length * sizeof(SpanC))
 | 
					        this._sent_starts = src._sent_starts
 | 
				
			||||||
        memcpy(this.shifted, src.shifted, this.length * sizeof(this.shifted[0]))
 | 
					        this._unshiftable = src._unshiftable
 | 
				
			||||||
 | 
					        memcpy(this._heads, src._heads, this.length * sizeof(this._heads[0]))
 | 
				
			||||||
 | 
					        this._ents = src._ents
 | 
				
			||||||
 | 
					        this._left_arcs = src._left_arcs
 | 
				
			||||||
 | 
					        this._right_arcs = src._right_arcs
 | 
				
			||||||
        this._b_i = src._b_i
 | 
					        this._b_i = src._b_i
 | 
				
			||||||
        this._s_i = src._s_i
 | 
					 | 
				
			||||||
        this._e_i = src._e_i
 | 
					 | 
				
			||||||
        this._break = src._break
 | 
					 | 
				
			||||||
        this.offset = src.offset
 | 
					        this.offset = src.offset
 | 
				
			||||||
        this._empty_token = src._empty_token
 | 
					        this._empty_token = src._empty_token
 | 
				
			||||||
 | 
					 | 
				
			||||||
    void fast_forward() nogil:
 | 
					 | 
				
			||||||
        # space token attachement policy:
 | 
					 | 
				
			||||||
        # - attach space tokens always to the last preceding real token
 | 
					 | 
				
			||||||
        # - except if it's the beginning of a sentence, then attach to the first following
 | 
					 | 
				
			||||||
        # - boundary case: a document containing multiple space tokens but nothing else,
 | 
					 | 
				
			||||||
        #   then make the last space token the head of all others
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        while is_space_token(this.B_(0)) \
 | 
					 | 
				
			||||||
        or this.buffer_length() == 0 \
 | 
					 | 
				
			||||||
        or this.stack_depth() == 0:
 | 
					 | 
				
			||||||
            if this.buffer_length() == 0:
 | 
					 | 
				
			||||||
                # remove the last sentence's root from the stack
 | 
					 | 
				
			||||||
                if this.stack_depth() == 1:
 | 
					 | 
				
			||||||
                    this.pop()
 | 
					 | 
				
			||||||
                # parser got stuck: reduce stack or unshift
 | 
					 | 
				
			||||||
                elif this.stack_depth() > 1:
 | 
					 | 
				
			||||||
                    if this.has_head(this.S(0)):
 | 
					 | 
				
			||||||
                        this.pop()
 | 
					 | 
				
			||||||
                    else:
 | 
					 | 
				
			||||||
                        this.unshift()
 | 
					 | 
				
			||||||
                # stack is empty but there is another sentence on the buffer
 | 
					 | 
				
			||||||
                elif (this.length - this._b_i) >= 1:
 | 
					 | 
				
			||||||
                    this.push()
 | 
					 | 
				
			||||||
                else: # stack empty and nothing else coming
 | 
					 | 
				
			||||||
                    break
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            elif is_space_token(this.B_(0)):
 | 
					 | 
				
			||||||
                # the normal case: we're somewhere inside a sentence
 | 
					 | 
				
			||||||
                if this.stack_depth() > 0:
 | 
					 | 
				
			||||||
                    # assert not is_space_token(this.S_(0))
 | 
					 | 
				
			||||||
                    # attach all coming space tokens to their last preceding
 | 
					 | 
				
			||||||
                    # real token (which should be on the top of the stack)
 | 
					 | 
				
			||||||
                    while is_space_token(this.B_(0)):
 | 
					 | 
				
			||||||
                        this.add_arc(this.S(0),this.B(0),0)
 | 
					 | 
				
			||||||
                        this.push()
 | 
					 | 
				
			||||||
                        this.pop()
 | 
					 | 
				
			||||||
                # the rare case: we're at the beginning of a document:
 | 
					 | 
				
			||||||
                # space tokens are attached to the first real token on the buffer
 | 
					 | 
				
			||||||
                elif this.stack_depth() == 0:
 | 
					 | 
				
			||||||
                    # store all space tokens on the stack until a real token shows up
 | 
					 | 
				
			||||||
                    # or the last token on the buffer is reached
 | 
					 | 
				
			||||||
                    while is_space_token(this.B_(0)) and this.buffer_length() > 1:
 | 
					 | 
				
			||||||
                        this.push()
 | 
					 | 
				
			||||||
                    # empty the stack by attaching all space tokens to the
 | 
					 | 
				
			||||||
                    # first token on the buffer
 | 
					 | 
				
			||||||
                    # boundary case: if all tokens are space tokens, the last one
 | 
					 | 
				
			||||||
                    # becomes the head of all others
 | 
					 | 
				
			||||||
                    while this.stack_depth() > 0:
 | 
					 | 
				
			||||||
                        this.add_arc(this.B(0),this.S(0),0)
 | 
					 | 
				
			||||||
                        this.pop()
 | 
					 | 
				
			||||||
                    # move the first token onto the stack
 | 
					 | 
				
			||||||
                    this.push()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            elif this.stack_depth() == 0:
 | 
					 | 
				
			||||||
                # for one token sentences (?)
 | 
					 | 
				
			||||||
                if this.buffer_length() == 1:
 | 
					 | 
				
			||||||
                    this.push()
 | 
					 | 
				
			||||||
                    this.pop()
 | 
					 | 
				
			||||||
                # with an empty stack and a non-empty buffer
 | 
					 | 
				
			||||||
                # only shift is valid anyway
 | 
					 | 
				
			||||||
                elif (this.length - this._b_i) >= 1:
 | 
					 | 
				
			||||||
                    this.push()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            else: # can this even happen?
 | 
					 | 
				
			||||||
                break
 | 
					 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,11 +1,7 @@
 | 
				
			||||||
from .stateclass cimport StateClass
 | 
					from ._state cimport StateC
 | 
				
			||||||
from ...typedefs cimport weight_t, attr_t
 | 
					from ...typedefs cimport weight_t, attr_t
 | 
				
			||||||
from .transition_system cimport Transition, TransitionSystem
 | 
					from .transition_system cimport Transition, TransitionSystem
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef class ArcEager(TransitionSystem):
 | 
					cdef class ArcEager(TransitionSystem):
 | 
				
			||||||
    pass
 | 
					    pass
 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
cdef weight_t push_cost(StateClass stcls, const void* _gold, int target) nogil
 | 
					 | 
				
			||||||
cdef weight_t arc_cost(StateClass stcls, const void* _gold, int head, int child) nogil
 | 
					 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -14,16 +14,11 @@ from ._state cimport StateC
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ...errors import Errors
 | 
					from ...errors import Errors
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Calculate cost as gold/not gold. We don't use scalar value anyway.
 | 
					 | 
				
			||||||
cdef int BINARY_COSTS = 1
 | 
					 | 
				
			||||||
cdef weight_t MIN_SCORE = -90000
 | 
					cdef weight_t MIN_SCORE = -90000
 | 
				
			||||||
cdef attr_t SUBTOK_LABEL = hash_string(u'subtok')
 | 
					cdef attr_t SUBTOK_LABEL = hash_string(u'subtok')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
DEF NON_MONOTONIC = True
 | 
					DEF NON_MONOTONIC = True
 | 
				
			||||||
DEF USE_BREAK = True
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Break transition from here
 | 
					 | 
				
			||||||
# http://www.aclweb.org/anthology/P13-1074
 | 
					 | 
				
			||||||
cdef enum:
 | 
					cdef enum:
 | 
				
			||||||
    SHIFT
 | 
					    SHIFT
 | 
				
			||||||
    REDUCE
 | 
					    REDUCE
 | 
				
			||||||
| 
						 | 
					@ -61,9 +56,11 @@ cdef struct GoldParseStateC:
 | 
				
			||||||
    int32_t* n_kids
 | 
					    int32_t* n_kids
 | 
				
			||||||
    int32_t length
 | 
					    int32_t length
 | 
				
			||||||
    int32_t stride
 | 
					    int32_t stride
 | 
				
			||||||
 | 
					    weight_t push_cost
 | 
				
			||||||
 | 
					    weight_t pop_cost
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls,
 | 
					cdef GoldParseStateC create_gold_state(Pool mem, const StateC* state,
 | 
				
			||||||
        heads, labels, sent_starts) except *:
 | 
					        heads, labels, sent_starts) except *:
 | 
				
			||||||
    cdef GoldParseStateC gs
 | 
					    cdef GoldParseStateC gs
 | 
				
			||||||
    gs.length = len(heads)
 | 
					    gs.length = len(heads)
 | 
				
			||||||
| 
						 | 
					@ -142,10 +139,12 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls,
 | 
				
			||||||
            if head != i:
 | 
					            if head != i:
 | 
				
			||||||
                gs.kids[head][js[head]] = i
 | 
					                gs.kids[head][js[head]] = i
 | 
				
			||||||
                js[head] += 1
 | 
					                js[head] += 1
 | 
				
			||||||
 | 
					    gs.push_cost = push_cost(state, &gs)
 | 
				
			||||||
 | 
					    gs.pop_cost = pop_cost(state, &gs)
 | 
				
			||||||
    return gs
 | 
					    return gs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) nogil:
 | 
					cdef void update_gold_state(GoldParseStateC* gs, const StateC* s) nogil:
 | 
				
			||||||
    for i in range(gs.length):
 | 
					    for i in range(gs.length):
 | 
				
			||||||
        gs.state_bits[i] = set_state_flag(
 | 
					        gs.state_bits[i] = set_state_flag(
 | 
				
			||||||
            gs.state_bits[i],
 | 
					            gs.state_bits[i],
 | 
				
			||||||
| 
						 | 
					@ -160,9 +159,9 @@ cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) nogil:
 | 
				
			||||||
        gs.n_kids_in_stack[i] = 0
 | 
					        gs.n_kids_in_stack[i] = 0
 | 
				
			||||||
        gs.n_kids_in_buffer[i] = 0
 | 
					        gs.n_kids_in_buffer[i] = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    for i in range(stcls.stack_depth()):
 | 
					    for i in range(s.stack_depth()):
 | 
				
			||||||
        s_i = stcls.S(i)
 | 
					        s_i = s.S(i)
 | 
				
			||||||
        if not is_head_unknown(gs, s_i):
 | 
					        if not is_head_unknown(gs, s_i) and gs.heads[s_i] != s_i:
 | 
				
			||||||
            gs.n_kids_in_stack[gs.heads[s_i]] += 1
 | 
					            gs.n_kids_in_stack[gs.heads[s_i]] += 1
 | 
				
			||||||
        for kid in gs.kids[s_i][:gs.n_kids[s_i]]:
 | 
					        for kid in gs.kids[s_i][:gs.n_kids[s_i]]:
 | 
				
			||||||
            gs.state_bits[kid] = set_state_flag(
 | 
					            gs.state_bits[kid] = set_state_flag(
 | 
				
			||||||
| 
						 | 
					@ -170,9 +169,11 @@ cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) nogil:
 | 
				
			||||||
                HEAD_IN_STACK,
 | 
					                HEAD_IN_STACK,
 | 
				
			||||||
                1
 | 
					                1
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
    for i in range(stcls.buffer_length()):
 | 
					    for i in range(s.buffer_length()):
 | 
				
			||||||
        b_i = stcls.B(i)
 | 
					        b_i = s.B(i)
 | 
				
			||||||
        if not is_head_unknown(gs, b_i):
 | 
					        if s.is_sent_start(b_i):
 | 
				
			||||||
 | 
					            break
 | 
				
			||||||
 | 
					        if not is_head_unknown(gs, b_i) and gs.heads[b_i] != b_i:
 | 
				
			||||||
            gs.n_kids_in_buffer[gs.heads[b_i]] += 1
 | 
					            gs.n_kids_in_buffer[gs.heads[b_i]] += 1
 | 
				
			||||||
        for kid in gs.kids[b_i][:gs.n_kids[b_i]]:
 | 
					        for kid in gs.kids[b_i][:gs.n_kids[b_i]]:
 | 
				
			||||||
            gs.state_bits[kid] = set_state_flag(
 | 
					            gs.state_bits[kid] = set_state_flag(
 | 
				
			||||||
| 
						 | 
					@ -180,6 +181,8 @@ cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) nogil:
 | 
				
			||||||
                HEAD_IN_BUFFER,
 | 
					                HEAD_IN_BUFFER,
 | 
				
			||||||
                1
 | 
					                1
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					    gs.push_cost = push_cost(s, gs)
 | 
				
			||||||
 | 
					    gs.pop_cost = pop_cost(s, gs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef class ArcEagerGold:
 | 
					cdef class ArcEagerGold:
 | 
				
			||||||
| 
						 | 
					@ -191,17 +194,17 @@ cdef class ArcEagerGold:
 | 
				
			||||||
        heads, labels = example.get_aligned_parse(projectivize=True)
 | 
					        heads, labels = example.get_aligned_parse(projectivize=True)
 | 
				
			||||||
        labels = [label if label is not None else "" for label in labels]
 | 
					        labels = [label if label is not None else "" for label in labels]
 | 
				
			||||||
        labels = [example.x.vocab.strings.add(label) for label in labels]
 | 
					        labels = [example.x.vocab.strings.add(label) for label in labels]
 | 
				
			||||||
        sent_starts = example.get_aligned("SENT_START")
 | 
					        sent_starts = example.get_aligned_sent_starts()
 | 
				
			||||||
        assert len(heads) == len(labels) == len(sent_starts)
 | 
					        assert len(heads) == len(labels) == len(sent_starts), (len(heads), len(labels), len(sent_starts))
 | 
				
			||||||
        self.c = create_gold_state(self.mem, stcls, heads, labels, sent_starts)
 | 
					        self.c = create_gold_state(self.mem, stcls.c, heads, labels, sent_starts)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def update(self, StateClass stcls):
 | 
					    def update(self, StateClass stcls):
 | 
				
			||||||
        update_gold_state(&self.c, stcls)
 | 
					        update_gold_state(&self.c, stcls.c)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef int check_state_gold(char state_bits, char flag) nogil:
 | 
					cdef int check_state_gold(char state_bits, char flag) nogil:
 | 
				
			||||||
    cdef char one = 1
 | 
					    cdef char one = 1
 | 
				
			||||||
    return state_bits & (one << flag)
 | 
					    return 1 if (state_bits & (one << flag)) else 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef int set_state_flag(char state_bits, char flag, int value) nogil:
 | 
					cdef int set_state_flag(char state_bits, char flag, int value) nogil:
 | 
				
			||||||
| 
						 | 
					@ -232,41 +235,30 @@ cdef int is_sent_start_unknown(const GoldParseStateC* gold, int i) nogil:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Helper functions for the arc-eager oracle
 | 
					# Helper functions for the arc-eager oracle
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef weight_t push_cost(StateClass stcls, const void* _gold, int target) nogil:
 | 
					cdef weight_t push_cost(const StateC* state, const GoldParseStateC* gold) nogil:
 | 
				
			||||||
    gold = <const GoldParseStateC*>_gold
 | 
					 | 
				
			||||||
    cdef weight_t cost = 0
 | 
					    cdef weight_t cost = 0
 | 
				
			||||||
    if is_head_in_stack(gold, target):
 | 
					    b0 = state.B(0)
 | 
				
			||||||
 | 
					    if b0 < 0:
 | 
				
			||||||
 | 
					        return 9000
 | 
				
			||||||
 | 
					    if is_head_in_stack(gold, b0):
 | 
				
			||||||
        cost += 1
 | 
					        cost += 1
 | 
				
			||||||
    cost += gold.n_kids_in_stack[target]
 | 
					    cost += gold.n_kids_in_stack[b0]
 | 
				
			||||||
    if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0:
 | 
					    if Break.is_valid(state, 0) and is_sent_start(gold, state.B(1)):
 | 
				
			||||||
        cost += 1
 | 
					        cost += 1
 | 
				
			||||||
    return cost
 | 
					    return cost
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef weight_t pop_cost(StateClass stcls, const void* _gold, int target) nogil:
 | 
					cdef weight_t pop_cost(const StateC* state, const GoldParseStateC* gold) nogil:
 | 
				
			||||||
    gold = <const GoldParseStateC*>_gold
 | 
					 | 
				
			||||||
    cdef weight_t cost = 0
 | 
					    cdef weight_t cost = 0
 | 
				
			||||||
    if is_head_in_buffer(gold, target):
 | 
					    s0 = state.S(0)
 | 
				
			||||||
        cost += 1
 | 
					    if s0 < 0:
 | 
				
			||||||
    cost += gold[0].n_kids_in_buffer[target]
 | 
					        return 9000
 | 
				
			||||||
    if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0:
 | 
					    if is_head_in_buffer(gold, s0):
 | 
				
			||||||
        cost += 1
 | 
					        cost += 1
 | 
				
			||||||
 | 
					    cost += gold.n_kids_in_buffer[s0]
 | 
				
			||||||
    return cost
 | 
					    return cost
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef weight_t arc_cost(StateClass stcls, const void* _gold, int head, int child) nogil:
 | 
					 | 
				
			||||||
    gold = <const GoldParseStateC*>_gold
 | 
					 | 
				
			||||||
    if arc_is_gold(gold, head, child):
 | 
					 | 
				
			||||||
        return 0
 | 
					 | 
				
			||||||
    elif stcls.H(child) == gold.heads[child]:
 | 
					 | 
				
			||||||
        return 1
 | 
					 | 
				
			||||||
    # Head in buffer
 | 
					 | 
				
			||||||
    elif is_head_in_buffer(gold, child):
 | 
					 | 
				
			||||||
        return 1
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        return 0
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
cdef bint arc_is_gold(const GoldParseStateC* gold, int head, int child) nogil:
 | 
					cdef bint arc_is_gold(const GoldParseStateC* gold, int head, int child) nogil:
 | 
				
			||||||
    if is_head_unknown(gold, child):
 | 
					    if is_head_unknown(gold, child):
 | 
				
			||||||
        return True
 | 
					        return True
 | 
				
			||||||
| 
						 | 
					@ -276,7 +268,7 @@ cdef bint arc_is_gold(const GoldParseStateC* gold, int head, int child) nogil:
 | 
				
			||||||
        return False
 | 
					        return False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef bint label_is_gold(const GoldParseStateC* gold, int head, int child, attr_t label) nogil:
 | 
					cdef bint label_is_gold(const GoldParseStateC* gold, int child, attr_t label) nogil:
 | 
				
			||||||
    if is_head_unknown(gold, child):
 | 
					    if is_head_unknown(gold, child):
 | 
				
			||||||
        return True
 | 
					        return True
 | 
				
			||||||
    elif label == 0:
 | 
					    elif label == 0:
 | 
				
			||||||
| 
						 | 
					@ -292,218 +284,251 @@ cdef bint _is_gold_root(const GoldParseStateC* gold, int word) nogil:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef class Shift:
 | 
					cdef class Shift:
 | 
				
			||||||
 | 
					    """Move the first word of the buffer onto the stack and mark it as "shifted"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Validity:
 | 
				
			||||||
 | 
					    * If stack is empty
 | 
				
			||||||
 | 
					    * At least two words in sentence
 | 
				
			||||||
 | 
					    * Word has not been shifted before
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Cost: push_cost 
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Action:
 | 
				
			||||||
 | 
					    * Mark B[0] as 'shifted'
 | 
				
			||||||
 | 
					    * Push stack
 | 
				
			||||||
 | 
					    * Advance buffer
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    cdef bint is_valid(const StateC* st, attr_t label) nogil:
 | 
					    cdef bint is_valid(const StateC* st, attr_t label) nogil:
 | 
				
			||||||
        sent_start = st._sent[st.B_(0).l_edge].sent_start
 | 
					        if st.stack_depth() == 0:
 | 
				
			||||||
        return st.buffer_length() >= 2 and not st.shifted[st.B(0)] and sent_start != 1
 | 
					            return 1
 | 
				
			||||||
 | 
					        elif st.buffer_length() < 2:
 | 
				
			||||||
 | 
					            return 0
 | 
				
			||||||
 | 
					        elif st.is_sent_start(st.B(0)):
 | 
				
			||||||
 | 
					            return 0
 | 
				
			||||||
 | 
					        elif st.is_unshiftable(st.B(0)):
 | 
				
			||||||
 | 
					            return 0
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    cdef int transition(StateC* st, attr_t label) nogil:
 | 
					    cdef int transition(StateC* st, attr_t label) nogil:
 | 
				
			||||||
        st.push()
 | 
					        st.push()
 | 
				
			||||||
        st.fast_forward()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    cdef weight_t cost(StateClass st, const void* _gold, attr_t label) nogil:
 | 
					    cdef weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil:
 | 
				
			||||||
        gold = <const GoldParseStateC*>_gold
 | 
					        gold = <const GoldParseStateC*>_gold
 | 
				
			||||||
        return Shift.move_cost(st, gold) + Shift.label_cost(st, gold, label)
 | 
					        return gold.push_cost
 | 
				
			||||||
 | 
					 | 
				
			||||||
    @staticmethod
 | 
					 | 
				
			||||||
    cdef inline weight_t move_cost(StateClass s, const void* _gold) nogil:
 | 
					 | 
				
			||||||
        gold = <const GoldParseStateC*>_gold
 | 
					 | 
				
			||||||
        return push_cost(s, gold, s.B(0))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @staticmethod
 | 
					 | 
				
			||||||
    cdef inline weight_t label_cost(StateClass s, const void* _gold, attr_t label) nogil:
 | 
					 | 
				
			||||||
        return 0
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef class Reduce:
 | 
					cdef class Reduce:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Pop from the stack. If it has no head and the stack isn't empty, place
 | 
				
			||||||
 | 
					    it back on the buffer.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Validity:
 | 
				
			||||||
 | 
					    * Stack not empty
 | 
				
			||||||
 | 
					    * Buffer nt empty
 | 
				
			||||||
 | 
					    * Stack depth 1 and cannot sent start l_edge(st.B(0))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Cost:
 | 
				
			||||||
 | 
					    * If B[0] is the start of a sentence, cost is 0
 | 
				
			||||||
 | 
					    * Arcs between stack and buffer
 | 
				
			||||||
 | 
					    * If arc has no head, we're saving arcs between S[0] and S[1:], so decrement
 | 
				
			||||||
 | 
					        cost by those arcs.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    cdef bint is_valid(const StateC* st, attr_t label) nogil:
 | 
					    cdef bint is_valid(const StateC* st, attr_t label) nogil:
 | 
				
			||||||
        return st.stack_depth() >= 2
 | 
					        if st.stack_depth() == 0:
 | 
				
			||||||
 | 
					 | 
				
			||||||
    @staticmethod
 | 
					 | 
				
			||||||
    cdef int transition(StateC* st, attr_t label) nogil:
 | 
					 | 
				
			||||||
        if st.has_head(st.S(0)):
 | 
					 | 
				
			||||||
            st.pop()
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            st.unshift()
 | 
					 | 
				
			||||||
        st.fast_forward()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @staticmethod
 | 
					 | 
				
			||||||
    cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
 | 
					 | 
				
			||||||
        gold = <const GoldParseStateC*>_gold
 | 
					 | 
				
			||||||
        return Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @staticmethod
 | 
					 | 
				
			||||||
    cdef inline weight_t move_cost(StateClass st, const void* _gold) nogil:
 | 
					 | 
				
			||||||
        gold = <const GoldParseStateC*>_gold
 | 
					 | 
				
			||||||
        s0 = st.S(0)
 | 
					 | 
				
			||||||
        cost = pop_cost(st, gold, s0)
 | 
					 | 
				
			||||||
        return_to_buffer = not st.has_head(s0)
 | 
					 | 
				
			||||||
        if return_to_buffer:
 | 
					 | 
				
			||||||
            # Decrement cost for the arcs we save, as we'll be putting this
 | 
					 | 
				
			||||||
            # back to the buffer
 | 
					 | 
				
			||||||
            if is_head_in_stack(gold, s0):
 | 
					 | 
				
			||||||
                cost -= 1
 | 
					 | 
				
			||||||
            cost -= gold.n_kids_in_stack[s0]
 | 
					 | 
				
			||||||
            if Break.is_valid(st.c, 0) and Break.move_cost(st, gold) == 0:
 | 
					 | 
				
			||||||
                cost -= 1
 | 
					 | 
				
			||||||
        return cost
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @staticmethod
 | 
					 | 
				
			||||||
    cdef inline weight_t label_cost(StateClass s, const void* gold, attr_t label) nogil:
 | 
					 | 
				
			||||||
        return 0
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
cdef class LeftArc:
 | 
					 | 
				
			||||||
    @staticmethod
 | 
					 | 
				
			||||||
    cdef bint is_valid(const StateC* st, attr_t label) nogil:
 | 
					 | 
				
			||||||
        if label == SUBTOK_LABEL and st.S(0) != (st.B(0)-1):
 | 
					 | 
				
			||||||
            return 0
 | 
					 | 
				
			||||||
        sent_start = st._sent[st.B_(0).l_edge].sent_start
 | 
					 | 
				
			||||||
        return sent_start != 1
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @staticmethod
 | 
					 | 
				
			||||||
    cdef int transition(StateC* st, attr_t label) nogil:
 | 
					 | 
				
			||||||
        st.add_arc(st.B(0), st.S(0), label)
 | 
					 | 
				
			||||||
        st.pop()
 | 
					 | 
				
			||||||
        st.fast_forward()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @staticmethod
 | 
					 | 
				
			||||||
    cdef inline weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
 | 
					 | 
				
			||||||
        gold = <const GoldParseStateC*>_gold
 | 
					 | 
				
			||||||
        return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @staticmethod
 | 
					 | 
				
			||||||
    cdef inline weight_t move_cost(StateClass s, const GoldParseStateC* gold) nogil:
 | 
					 | 
				
			||||||
        cdef weight_t cost = 0
 | 
					 | 
				
			||||||
        s0 = s.S(0)
 | 
					 | 
				
			||||||
        b0 = s.B(0)
 | 
					 | 
				
			||||||
        if arc_is_gold(gold, b0, s0):
 | 
					 | 
				
			||||||
            # Have a negative cost if we 'recover' from the wrong dependency
 | 
					 | 
				
			||||||
            return 0 if not s.has_head(s0) else -1
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            # Account for deps we might lose between S0 and stack
 | 
					 | 
				
			||||||
            if not s.has_head(s0):
 | 
					 | 
				
			||||||
                cost += gold.n_kids_in_stack[s0]
 | 
					 | 
				
			||||||
                if is_head_in_buffer(gold, s0):
 | 
					 | 
				
			||||||
                    cost += 1
 | 
					 | 
				
			||||||
            return cost + pop_cost(s, gold, s.S(0)) + arc_cost(s, gold, s.B(0), s.S(0))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @staticmethod
 | 
					 | 
				
			||||||
    cdef inline weight_t label_cost(StateClass s, const GoldParseStateC* gold, attr_t label) nogil:
 | 
					 | 
				
			||||||
        return arc_is_gold(gold, s.B(0), s.S(0)) and not label_is_gold(gold, s.B(0), s.S(0), label)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
cdef class RightArc:
 | 
					 | 
				
			||||||
    @staticmethod
 | 
					 | 
				
			||||||
    cdef bint is_valid(const StateC* st, attr_t label) nogil:
 | 
					 | 
				
			||||||
        # If there's (perhaps partial) parse pre-set, don't allow cycle.
 | 
					 | 
				
			||||||
        if label == SUBTOK_LABEL and st.S(0) != (st.B(0)-1):
 | 
					 | 
				
			||||||
            return 0
 | 
					 | 
				
			||||||
        sent_start = st._sent[st.B_(0).l_edge].sent_start
 | 
					 | 
				
			||||||
        return sent_start != 1 and st.H(st.S(0)) != st.B(0)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @staticmethod
 | 
					 | 
				
			||||||
    cdef int transition(StateC* st, attr_t label) nogil:
 | 
					 | 
				
			||||||
        st.add_arc(st.S(0), st.B(0), label)
 | 
					 | 
				
			||||||
        st.push()
 | 
					 | 
				
			||||||
        st.fast_forward()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @staticmethod
 | 
					 | 
				
			||||||
    cdef inline weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
 | 
					 | 
				
			||||||
        gold = <const GoldParseStateC*>_gold
 | 
					 | 
				
			||||||
        return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @staticmethod
 | 
					 | 
				
			||||||
    cdef inline weight_t move_cost(StateClass s, const void* _gold) nogil:
 | 
					 | 
				
			||||||
        gold = <const GoldParseStateC*>_gold
 | 
					 | 
				
			||||||
        if arc_is_gold(gold, s.S(0), s.B(0)):
 | 
					 | 
				
			||||||
            return 0
 | 
					 | 
				
			||||||
        elif s.c.shifted[s.B(0)]:
 | 
					 | 
				
			||||||
            return push_cost(s, gold, s.B(0))
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            return push_cost(s, gold, s.B(0)) + arc_cost(s, gold, s.S(0), s.B(0))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @staticmethod
 | 
					 | 
				
			||||||
    cdef weight_t label_cost(StateClass s, const void* _gold, attr_t label) nogil:
 | 
					 | 
				
			||||||
        gold = <const GoldParseStateC*>_gold
 | 
					 | 
				
			||||||
        return arc_is_gold(gold, s.S(0), s.B(0)) and not label_is_gold(gold, s.S(0), s.B(0), label)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
cdef class Break:
 | 
					 | 
				
			||||||
    @staticmethod
 | 
					 | 
				
			||||||
    cdef bint is_valid(const StateC* st, attr_t label) nogil:
 | 
					 | 
				
			||||||
        cdef int i
 | 
					 | 
				
			||||||
        if not USE_BREAK:
 | 
					 | 
				
			||||||
            return False
 | 
					            return False
 | 
				
			||||||
        elif st.at_break():
 | 
					        elif st.buffer_length() == 0:
 | 
				
			||||||
            return False
 | 
					            return True
 | 
				
			||||||
        elif st.stack_depth() < 1:
 | 
					        elif st.stack_depth() == 1 and st.cannot_sent_start(st.l_edge(st.B(0))):
 | 
				
			||||||
            return False
 | 
					 | 
				
			||||||
        elif st.B_(0).l_edge < 0:
 | 
					 | 
				
			||||||
            return False
 | 
					 | 
				
			||||||
        elif st._sent[st.B_(0).l_edge].sent_start < 0:
 | 
					 | 
				
			||||||
            return False
 | 
					            return False
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            return True
 | 
					            return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    cdef int transition(StateC* st, attr_t label) nogil:
 | 
					    cdef int transition(StateC* st, attr_t label) nogil:
 | 
				
			||||||
        st.set_break(st.B_(0).l_edge)
 | 
					        if st.has_head(st.S(0)) or st.stack_depth() == 1:
 | 
				
			||||||
        st.fast_forward()
 | 
					            st.pop()
 | 
				
			||||||
 | 
					 | 
				
			||||||
    @staticmethod
 | 
					 | 
				
			||||||
    cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
 | 
					 | 
				
			||||||
        gold = <const GoldParseStateC*>_gold
 | 
					 | 
				
			||||||
        return Break.move_cost(s, gold) + Break.label_cost(s, gold, label)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @staticmethod
 | 
					 | 
				
			||||||
    cdef inline weight_t move_cost(StateClass s, const void* _gold) nogil:
 | 
					 | 
				
			||||||
        gold = <const GoldParseStateC*>_gold
 | 
					 | 
				
			||||||
        cost = 0
 | 
					 | 
				
			||||||
        for i in range(s.stack_depth()):
 | 
					 | 
				
			||||||
            S_i = s.S(i)
 | 
					 | 
				
			||||||
            cost += gold.n_kids_in_buffer[S_i]
 | 
					 | 
				
			||||||
            if is_head_in_buffer(gold, S_i):
 | 
					 | 
				
			||||||
                cost += 1
 | 
					 | 
				
			||||||
        # It's weird not to check the gold sentence boundaries but if we do,
 | 
					 | 
				
			||||||
        # we can't account for "sunk costs", i.e. situations where we're already
 | 
					 | 
				
			||||||
        # wrong.
 | 
					 | 
				
			||||||
        s0_root = _get_root(s.S(0), gold)
 | 
					 | 
				
			||||||
        b0_root = _get_root(s.B(0), gold)
 | 
					 | 
				
			||||||
        if s0_root != b0_root or s0_root == -1 or b0_root == -1:
 | 
					 | 
				
			||||||
            return cost
 | 
					 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            return cost + 1
 | 
					            st.unshift()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    cdef inline weight_t label_cost(StateClass s, const void* gold, attr_t label) nogil:
 | 
					    cdef weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil:
 | 
				
			||||||
 | 
					        gold = <const GoldParseStateC*>_gold
 | 
				
			||||||
 | 
					        if state.is_sent_start(state.B(0)):
 | 
				
			||||||
            return 0
 | 
					            return 0
 | 
				
			||||||
 | 
					        s0 = state.S(0)
 | 
				
			||||||
 | 
					        cost = gold.pop_cost
 | 
				
			||||||
 | 
					        if not state.has_head(s0):
 | 
				
			||||||
 | 
					            # Decrement cost for the arcs we save, as we'll be putting this
 | 
				
			||||||
 | 
					            # back to the buffer
 | 
				
			||||||
 | 
					            if is_head_in_stack(gold, s0):
 | 
				
			||||||
 | 
					                cost -= 1
 | 
				
			||||||
 | 
					            cost -= gold.n_kids_in_stack[s0]
 | 
				
			||||||
 | 
					        return cost
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef int _get_root(int word, const GoldParseStateC* gold) nogil:
 | 
					
 | 
				
			||||||
    if is_head_unknown(gold, word):
 | 
					cdef class LeftArc:
 | 
				
			||||||
        return -1
 | 
					    """Add an arc between B[0] and S[0], replacing the previous head of S[0] if
 | 
				
			||||||
    while gold.heads[word] != word and word >= 0:
 | 
					    one is set. Pop S[0] from the stack.
 | 
				
			||||||
        word = gold.heads[word]
 | 
					
 | 
				
			||||||
        if is_head_unknown(gold, word):
 | 
					    Validity:
 | 
				
			||||||
            return -1
 | 
					    * len(S) >= 1
 | 
				
			||||||
 | 
					    * len(B) >= 1
 | 
				
			||||||
 | 
					    * not is_sent_start(B[0])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Cost:
 | 
				
			||||||
 | 
					        pop_cost - Arc(B[0], S[0], label) + (Arc(S[1], S[0]) if H(S[0]) else Arcs(S, S[0]))
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    cdef bint is_valid(const StateC* st, attr_t label) nogil:
 | 
				
			||||||
 | 
					        if st.stack_depth() == 0:
 | 
				
			||||||
 | 
					            return 0
 | 
				
			||||||
 | 
					        elif st.buffer_length() == 0:
 | 
				
			||||||
 | 
					            return 0
 | 
				
			||||||
 | 
					        elif st.is_sent_start(st.B(0)):
 | 
				
			||||||
 | 
					            return 0
 | 
				
			||||||
 | 
					        elif label == SUBTOK_LABEL and st.S(0) != (st.B(0)-1):
 | 
				
			||||||
 | 
					            return 0
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
        return word
 | 
					            return 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    cdef int transition(StateC* st, attr_t label) nogil:
 | 
				
			||||||
 | 
					        st.add_arc(st.B(0), st.S(0), label)
 | 
				
			||||||
 | 
					        # If we change the stack, it's okay to remove the shifted mark, as
 | 
				
			||||||
 | 
					        # we can't get in an infinite loop this way.
 | 
				
			||||||
 | 
					        st.set_reshiftable(st.B(0))
 | 
				
			||||||
 | 
					        st.pop()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    cdef inline weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil:
 | 
				
			||||||
 | 
					        gold = <const GoldParseStateC*>_gold
 | 
				
			||||||
 | 
					        cdef weight_t cost = gold.pop_cost
 | 
				
			||||||
 | 
					        s0 = state.S(0)
 | 
				
			||||||
 | 
					        s1 = state.S(1)
 | 
				
			||||||
 | 
					        b0 = state.B(0)
 | 
				
			||||||
 | 
					        if state.has_head(s0):
 | 
				
			||||||
 | 
					            # Increment cost if we're clobbering a correct arc
 | 
				
			||||||
 | 
					            cost += gold.heads[s0] == s1
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            # If there's no head, we're losing arcs between S0 and S[1:].
 | 
				
			||||||
 | 
					            cost += is_head_in_stack(gold, s0)
 | 
				
			||||||
 | 
					            cost += gold.n_kids_in_stack[s0]
 | 
				
			||||||
 | 
					        if b0 != -1 and s0 != -1 and gold.heads[s0] == b0:
 | 
				
			||||||
 | 
					            cost -= 1
 | 
				
			||||||
 | 
					            cost += not label_is_gold(gold, s0, label)
 | 
				
			||||||
 | 
					        return cost
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cdef class RightArc:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Add an arc from S[0] to B[0]. Push B[0].
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Validity:
 | 
				
			||||||
 | 
					    * len(S) >= 1
 | 
				
			||||||
 | 
					    * len(B) >= 1
 | 
				
			||||||
 | 
					    * not is_sent_start(B[0])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Cost:
 | 
				
			||||||
 | 
					        push_cost + (not shifted[b0] and Arc(B[1:], B[0])) - Arc(S[0], B[0], label)
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    cdef bint is_valid(const StateC* st, attr_t label) nogil:
 | 
				
			||||||
 | 
					        if st.stack_depth() == 0:
 | 
				
			||||||
 | 
					            return 0
 | 
				
			||||||
 | 
					        elif st.buffer_length() == 0:
 | 
				
			||||||
 | 
					            return 0
 | 
				
			||||||
 | 
					        elif st.is_sent_start(st.B(0)):
 | 
				
			||||||
 | 
					            return 0
 | 
				
			||||||
 | 
					        elif label == SUBTOK_LABEL and st.S(0) != (st.B(0)-1):
 | 
				
			||||||
 | 
					            # If there's (perhaps partial) parse pre-set, don't allow cycle.
 | 
				
			||||||
 | 
					            return 0
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    cdef int transition(StateC* st, attr_t label) nogil:
 | 
				
			||||||
 | 
					        st.add_arc(st.S(0), st.B(0), label)
 | 
				
			||||||
 | 
					        st.push()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    cdef inline weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil:
 | 
				
			||||||
 | 
					        gold = <const GoldParseStateC*>_gold
 | 
				
			||||||
 | 
					        cost = gold.push_cost
 | 
				
			||||||
 | 
					        s0 = state.S(0)
 | 
				
			||||||
 | 
					        b0 = state.B(0)
 | 
				
			||||||
 | 
					        if s0 != -1 and b0 != -1 and gold.heads[b0] == s0:
 | 
				
			||||||
 | 
					            cost -= 1
 | 
				
			||||||
 | 
					            cost += not label_is_gold(gold, b0, label)
 | 
				
			||||||
 | 
					        elif is_head_in_buffer(gold, b0) and not state.is_unshiftable(b0):
 | 
				
			||||||
 | 
					            cost += 1
 | 
				
			||||||
 | 
					        return cost
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cdef class Break:
 | 
				
			||||||
 | 
					    """Mark the second word of the buffer as the start of a 
 | 
				
			||||||
 | 
					    sentence. 
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Validity:
 | 
				
			||||||
 | 
					    * len(buffer) >= 2
 | 
				
			||||||
 | 
					    * B[1] == B[0] + 1
 | 
				
			||||||
 | 
					    * not is_sent_start(B[1])
 | 
				
			||||||
 | 
					    * not cannot_sent_start(B[1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Action:
 | 
				
			||||||
 | 
					    * mark_sent_start(B[1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Cost:
 | 
				
			||||||
 | 
					    * not is_sent_start(B[1])
 | 
				
			||||||
 | 
					    * Arcs between B[0] and B[1:]
 | 
				
			||||||
 | 
					    * Arcs between S and B[1]
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    cdef bint is_valid(const StateC* st, attr_t label) nogil:
 | 
				
			||||||
 | 
					        cdef int i
 | 
				
			||||||
 | 
					        if st.buffer_length() < 2:
 | 
				
			||||||
 | 
					            return False
 | 
				
			||||||
 | 
					        elif st.B(1) != st.B(0) + 1:
 | 
				
			||||||
 | 
					            return False
 | 
				
			||||||
 | 
					        elif st.is_sent_start(st.B(1)):
 | 
				
			||||||
 | 
					            return False
 | 
				
			||||||
 | 
					        elif st.cannot_sent_start(st.B(1)):
 | 
				
			||||||
 | 
					            return False
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    cdef int transition(StateC* st, attr_t label) nogil:
 | 
				
			||||||
 | 
					        st.set_sent_start(st.B(1), 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    cdef weight_t cost(const StateC* state, const void* _gold, attr_t label) nogil:
 | 
				
			||||||
 | 
					        gold = <const GoldParseStateC*>_gold
 | 
				
			||||||
 | 
					        cdef int b0 = state.B(0)
 | 
				
			||||||
 | 
					        cdef int cost = 0
 | 
				
			||||||
 | 
					        cdef int si
 | 
				
			||||||
 | 
					        for i in range(state.stack_depth()):
 | 
				
			||||||
 | 
					            si = state.S(i)
 | 
				
			||||||
 | 
					            if is_head_in_buffer(gold, si):
 | 
				
			||||||
 | 
					                cost += 1
 | 
				
			||||||
 | 
					            cost += gold.n_kids_in_buffer[si]
 | 
				
			||||||
 | 
					            # We need to score into B[1:], so subtract deps that are at b0
 | 
				
			||||||
 | 
					            if gold.heads[b0] == si:
 | 
				
			||||||
 | 
					                cost -= 1
 | 
				
			||||||
 | 
					            if gold.heads[si] == b0:
 | 
				
			||||||
 | 
					                cost -= 1
 | 
				
			||||||
 | 
					        if not is_sent_start(gold, state.B(1)) \
 | 
				
			||||||
 | 
					        and not is_sent_start_unknown(gold, state.B(1)):
 | 
				
			||||||
 | 
					            cost += 1
 | 
				
			||||||
 | 
					        return cost
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
 | 
					cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
 | 
				
			||||||
    st = new StateC(<const TokenC*>tokens, length)
 | 
					    st = new StateC(<const TokenC*>tokens, length)
 | 
				
			||||||
    for i in range(st.length):
 | 
					 | 
				
			||||||
        if st._sent[i].dep == 0:
 | 
					 | 
				
			||||||
            st._sent[i].l_edge = i
 | 
					 | 
				
			||||||
            st._sent[i].r_edge = i
 | 
					 | 
				
			||||||
            st._sent[i].head = 0
 | 
					 | 
				
			||||||
            st._sent[i].dep = 0
 | 
					 | 
				
			||||||
            st._sent[i].l_kids = 0
 | 
					 | 
				
			||||||
            st._sent[i].r_kids = 0
 | 
					 | 
				
			||||||
    st.fast_forward()
 | 
					 | 
				
			||||||
    return <void*>st
 | 
					    return <void*>st
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -515,6 +540,8 @@ cdef int _del_state(Pool mem, void* state, void* x) except -1:
 | 
				
			||||||
cdef class ArcEager(TransitionSystem):
 | 
					cdef class ArcEager(TransitionSystem):
 | 
				
			||||||
    def __init__(self, *args, **kwargs):
 | 
					    def __init__(self, *args, **kwargs):
 | 
				
			||||||
        TransitionSystem.__init__(self, *args, **kwargs)
 | 
					        TransitionSystem.__init__(self, *args, **kwargs)
 | 
				
			||||||
 | 
					        self.init_beam_state = _init_state
 | 
				
			||||||
 | 
					        self.del_beam_state = _del_state
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
    def get_actions(cls, **kwargs):
 | 
					    def get_actions(cls, **kwargs):
 | 
				
			||||||
| 
						 | 
					@ -537,7 +564,7 @@ cdef class ArcEager(TransitionSystem):
 | 
				
			||||||
                    label = 'ROOT'
 | 
					                    label = 'ROOT'
 | 
				
			||||||
                if head == child:
 | 
					                if head == child:
 | 
				
			||||||
                    actions[BREAK][label] += 1
 | 
					                    actions[BREAK][label] += 1
 | 
				
			||||||
                elif head < child:
 | 
					                if head < child:
 | 
				
			||||||
                    actions[RIGHT][label] += 1
 | 
					                    actions[RIGHT][label] += 1
 | 
				
			||||||
                    actions[REDUCE][''] += 1
 | 
					                    actions[REDUCE][''] += 1
 | 
				
			||||||
                elif head > child:
 | 
					                elif head > child:
 | 
				
			||||||
| 
						 | 
					@ -567,8 +594,14 @@ cdef class ArcEager(TransitionSystem):
 | 
				
			||||||
        t.do(state.c, t.label)
 | 
					        t.do(state.c, t.label)
 | 
				
			||||||
        return state
 | 
					        return state
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def is_gold_parse(self, StateClass state, gold):
 | 
					    def is_gold_parse(self, StateClass state, ArcEagerGold gold):
 | 
				
			||||||
        raise NotImplementedError
 | 
					        for i in range(state.c.length):
 | 
				
			||||||
 | 
					            token = state.c.safe_get(i)
 | 
				
			||||||
 | 
					            if not arc_is_gold(&gold.c, i, i+token.head):
 | 
				
			||||||
 | 
					                return False
 | 
				
			||||||
 | 
					            elif not label_is_gold(&gold.c, i, token.dep):
 | 
				
			||||||
 | 
					                return False
 | 
				
			||||||
 | 
					        return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def init_gold(self, StateClass state, Example example):
 | 
					    def init_gold(self, StateClass state, Example example):
 | 
				
			||||||
        gold = ArcEagerGold(self, state, example)
 | 
					        gold = ArcEagerGold(self, state, example)
 | 
				
			||||||
| 
						 | 
					@ -576,6 +609,7 @@ cdef class ArcEager(TransitionSystem):
 | 
				
			||||||
        return gold
 | 
					        return gold
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def init_gold_batch(self, examples):
 | 
					    def init_gold_batch(self, examples):
 | 
				
			||||||
 | 
					        # TODO: Projectivitity?
 | 
				
			||||||
        all_states = self.init_batch([eg.predicted for eg in examples])
 | 
					        all_states = self.init_batch([eg.predicted for eg in examples])
 | 
				
			||||||
        golds = []
 | 
					        golds = []
 | 
				
			||||||
        states = []
 | 
					        states = []
 | 
				
			||||||
| 
						 | 
					@ -662,24 +696,13 @@ cdef class ArcEager(TransitionSystem):
 | 
				
			||||||
            raise ValueError(Errors.E019.format(action=move, src='arc_eager'))
 | 
					            raise ValueError(Errors.E019.format(action=move, src='arc_eager'))
 | 
				
			||||||
        return t
 | 
					        return t
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    cdef int initialize_state(self, StateC* st) nogil:
 | 
					    def set_annotations(self, StateClass state, Doc doc):
 | 
				
			||||||
        for i in range(st.length):
 | 
					        for arc in state.arcs:
 | 
				
			||||||
            if st._sent[i].dep == 0:
 | 
					            doc.c[arc["child"]].head = arc["head"] - arc["child"]
 | 
				
			||||||
                st._sent[i].l_edge = i
 | 
					            doc.c[arc["child"]].dep = arc["label"]
 | 
				
			||||||
                st._sent[i].r_edge = i
 | 
					        for i in range(doc.length):
 | 
				
			||||||
                st._sent[i].head = 0
 | 
					            if doc.c[i].head == 0:
 | 
				
			||||||
                st._sent[i].dep = 0
 | 
					                doc.c[i].dep = self.root_label
 | 
				
			||||||
                st._sent[i].l_kids = 0
 | 
					 | 
				
			||||||
                st._sent[i].r_kids = 0
 | 
					 | 
				
			||||||
        st.fast_forward()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef int finalize_state(self, StateC* st) nogil:
 | 
					 | 
				
			||||||
        cdef int i
 | 
					 | 
				
			||||||
        for i in range(st.length):
 | 
					 | 
				
			||||||
            if st._sent[i].head == 0:
 | 
					 | 
				
			||||||
                st._sent[i].dep = self.root_label
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def finalize_doc(self, Doc doc):
 | 
					 | 
				
			||||||
        set_children_from_heads(doc.c, 0, doc.length)
 | 
					        set_children_from_heads(doc.c, 0, doc.length)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def has_gold(self, Example eg, start=0, end=None):
 | 
					    def has_gold(self, Example eg, start=0, end=None):
 | 
				
			||||||
| 
						 | 
					@ -690,7 +713,7 @@ cdef class ArcEager(TransitionSystem):
 | 
				
			||||||
            return False
 | 
					            return False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    cdef int set_valid(self, int* output, const StateC* st) nogil:
 | 
					    cdef int set_valid(self, int* output, const StateC* st) nogil:
 | 
				
			||||||
        cdef bint[N_MOVES] is_valid
 | 
					        cdef int[N_MOVES] is_valid
 | 
				
			||||||
        is_valid[SHIFT] = Shift.is_valid(st, 0)
 | 
					        is_valid[SHIFT] = Shift.is_valid(st, 0)
 | 
				
			||||||
        is_valid[REDUCE] = Reduce.is_valid(st, 0)
 | 
					        is_valid[REDUCE] = Reduce.is_valid(st, 0)
 | 
				
			||||||
        is_valid[LEFT] = LeftArc.is_valid(st, 0)
 | 
					        is_valid[LEFT] = LeftArc.is_valid(st, 0)
 | 
				
			||||||
| 
						 | 
					@ -710,29 +733,31 @@ cdef class ArcEager(TransitionSystem):
 | 
				
			||||||
        gold_state = gold_.c
 | 
					        gold_state = gold_.c
 | 
				
			||||||
        n_gold = 0
 | 
					        n_gold = 0
 | 
				
			||||||
        if self.c[i].is_valid(stcls.c, self.c[i].label):
 | 
					        if self.c[i].is_valid(stcls.c, self.c[i].label):
 | 
				
			||||||
            cost = self.c[i].get_cost(stcls, &gold_state, self.c[i].label)
 | 
					            cost = self.c[i].get_cost(stcls.c, &gold_state, self.c[i].label)
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            cost = 9000
 | 
					            cost = 9000
 | 
				
			||||||
        return cost
 | 
					        return cost
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    cdef int set_costs(self, int* is_valid, weight_t* costs,
 | 
					    cdef int set_costs(self, int* is_valid, weight_t* costs,
 | 
				
			||||||
                       StateClass stcls, gold) except -1:
 | 
					                       const StateC* state, gold) except -1:
 | 
				
			||||||
        if not isinstance(gold, ArcEagerGold):
 | 
					        if not isinstance(gold, ArcEagerGold):
 | 
				
			||||||
            raise TypeError(Errors.E909.format(name="ArcEagerGold"))
 | 
					            raise TypeError(Errors.E909.format(name="ArcEagerGold"))
 | 
				
			||||||
        cdef ArcEagerGold gold_ = gold
 | 
					        cdef ArcEagerGold gold_ = gold
 | 
				
			||||||
        gold_.update(stcls)
 | 
					 | 
				
			||||||
        gold_state = gold_.c
 | 
					        gold_state = gold_.c
 | 
				
			||||||
 | 
					        update_gold_state(&gold_state, state)
 | 
				
			||||||
 | 
					        self.set_valid(is_valid, state)
 | 
				
			||||||
        cdef int n_gold = 0
 | 
					        cdef int n_gold = 0
 | 
				
			||||||
        for i in range(self.n_moves):
 | 
					        for i in range(self.n_moves):
 | 
				
			||||||
            if self.c[i].is_valid(stcls.c, self.c[i].label):
 | 
					            if is_valid[i]:
 | 
				
			||||||
                is_valid[i] = True
 | 
					                costs[i] = self.c[i].get_cost(state, &gold_state, self.c[i].label)
 | 
				
			||||||
                costs[i] = self.c[i].get_cost(stcls, &gold_state, self.c[i].label)
 | 
					 | 
				
			||||||
                if costs[i] <= 0:
 | 
					                if costs[i] <= 0:
 | 
				
			||||||
                    n_gold += 1
 | 
					                    n_gold += 1
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                is_valid[i] = False
 | 
					 | 
				
			||||||
                costs[i] = 9000
 | 
					                costs[i] = 9000
 | 
				
			||||||
        if n_gold < 1:
 | 
					        if n_gold < 1:
 | 
				
			||||||
 | 
					            for i in range(self.n_moves):
 | 
				
			||||||
 | 
					                print(self.get_class_name(i), is_valid[i], costs[i])
 | 
				
			||||||
 | 
					            print("Gold sent starts?", is_sent_start(&gold_state, state.B(0)), is_sent_start(&gold_state, state.B(1)))
 | 
				
			||||||
            raise ValueError
 | 
					            raise ValueError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_oracle_sequence_from_state(self, StateClass state, ArcEagerGold gold, _debug=None):
 | 
					    def get_oracle_sequence_from_state(self, StateClass state, ArcEagerGold gold, _debug=None):
 | 
				
			||||||
| 
						 | 
					@ -748,12 +773,13 @@ cdef class ArcEager(TransitionSystem):
 | 
				
			||||||
        failed = False
 | 
					        failed = False
 | 
				
			||||||
        while not state.is_final():
 | 
					        while not state.is_final():
 | 
				
			||||||
            try:
 | 
					            try:
 | 
				
			||||||
                self.set_costs(is_valid, costs, state, gold)
 | 
					                self.set_costs(is_valid, costs, state.c, gold)
 | 
				
			||||||
            except ValueError:
 | 
					            except ValueError:
 | 
				
			||||||
                failed = True
 | 
					                failed = True
 | 
				
			||||||
                break
 | 
					                break
 | 
				
			||||||
 | 
					            min_cost = min(costs[i] for i in range(self.n_moves))
 | 
				
			||||||
            for i in range(self.n_moves):
 | 
					            for i in range(self.n_moves):
 | 
				
			||||||
                if is_valid[i] and costs[i] <= 0:
 | 
					                if is_valid[i] and costs[i] <= min_cost:
 | 
				
			||||||
                    action = self.c[i]
 | 
					                    action = self.c[i]
 | 
				
			||||||
                    history.append(i)
 | 
					                    history.append(i)
 | 
				
			||||||
                    s0 = state.S(0)
 | 
					                    s0 = state.S(0)
 | 
				
			||||||
| 
						 | 
					@ -762,9 +788,7 @@ cdef class ArcEager(TransitionSystem):
 | 
				
			||||||
                        example = _debug
 | 
					                        example = _debug
 | 
				
			||||||
                        debug_log.append(" ".join((
 | 
					                        debug_log.append(" ".join((
 | 
				
			||||||
                            self.get_class_name(i),
 | 
					                            self.get_class_name(i),
 | 
				
			||||||
                            "S0=", (example.x[s0].text if s0 >= 0 else "__"),
 | 
					                            state.print_state()
 | 
				
			||||||
                            "B0=", (example.x[b0].text if b0 >= 0 else "__"),
 | 
					 | 
				
			||||||
                            "S0 head?", str(state.has_head(state.S(0))),
 | 
					 | 
				
			||||||
                        )))
 | 
					                        )))
 | 
				
			||||||
                    action.do(state.c, action.label)
 | 
					                    action.do(state.c, action.label)
 | 
				
			||||||
                    break
 | 
					                    break
 | 
				
			||||||
| 
						 | 
					@ -783,6 +807,8 @@ cdef class ArcEager(TransitionSystem):
 | 
				
			||||||
            print("Aligned heads")
 | 
					            print("Aligned heads")
 | 
				
			||||||
            for i, head in enumerate(aligned_heads):
 | 
					            for i, head in enumerate(aligned_heads):
 | 
				
			||||||
                print(example.x[i], example.x[head] if head is not None else "__")
 | 
					                print(example.x[i], example.x[head] if head is not None else "__")
 | 
				
			||||||
 | 
					            print("Aligned sent starts")
 | 
				
			||||||
 | 
					            print(example.get_aligned_sent_starts())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            print("Predicted tokens")
 | 
					            print("Predicted tokens")
 | 
				
			||||||
            print([(w.i, w.text) for w in example.x])
 | 
					            print([(w.i, w.text) for w in example.x])
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -3,9 +3,12 @@ from cymem.cymem cimport Pool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from collections import Counter
 | 
					from collections import Counter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ...tokens.doc cimport Doc
 | 
				
			||||||
 | 
					from ...tokens.span import Span
 | 
				
			||||||
from ...typedefs cimport weight_t, attr_t
 | 
					from ...typedefs cimport weight_t, attr_t
 | 
				
			||||||
from ...lexeme cimport Lexeme
 | 
					from ...lexeme cimport Lexeme
 | 
				
			||||||
from ...attrs cimport IS_SPACE
 | 
					from ...attrs cimport IS_SPACE
 | 
				
			||||||
 | 
					from ...structs cimport TokenC
 | 
				
			||||||
from ...training.example cimport Example
 | 
					from ...training.example cimport Example
 | 
				
			||||||
from .stateclass cimport StateClass
 | 
					from .stateclass cimport StateClass
 | 
				
			||||||
from ._state cimport StateC
 | 
					from ._state cimport StateC
 | 
				
			||||||
| 
						 | 
					@ -46,17 +49,17 @@ cdef class BiluoGold:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, BiluoPushDown moves, StateClass stcls, Example example):
 | 
					    def __init__(self, BiluoPushDown moves, StateClass stcls, Example example):
 | 
				
			||||||
        self.mem = Pool()
 | 
					        self.mem = Pool()
 | 
				
			||||||
        self.c = create_gold_state(self.mem, moves, stcls, example)
 | 
					        self.c = create_gold_state(self.mem, moves, stcls.c, example)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def update(self, StateClass stcls):
 | 
					    def update(self, StateClass stcls):
 | 
				
			||||||
        update_gold_state(&self.c, stcls)
 | 
					        update_gold_state(&self.c, stcls.c)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef GoldNERStateC create_gold_state(
 | 
					cdef GoldNERStateC create_gold_state(
 | 
				
			||||||
    Pool mem,
 | 
					    Pool mem,
 | 
				
			||||||
    BiluoPushDown moves,
 | 
					    BiluoPushDown moves,
 | 
				
			||||||
    StateClass stcls,
 | 
					    const StateC* stcls,
 | 
				
			||||||
    Example example
 | 
					    Example example
 | 
				
			||||||
) except *:
 | 
					) except *:
 | 
				
			||||||
    cdef GoldNERStateC gs
 | 
					    cdef GoldNERStateC gs
 | 
				
			||||||
| 
						 | 
					@ -67,7 +70,7 @@ cdef GoldNERStateC create_gold_state(
 | 
				
			||||||
    return gs
 | 
					    return gs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef void update_gold_state(GoldNERStateC* gs, StateClass stcls) except *:
 | 
					cdef void update_gold_state(GoldNERStateC* gs, const StateC* state) except *:
 | 
				
			||||||
    # We don't need to update each time, unlike the parser.
 | 
					    # We don't need to update each time, unlike the parser.
 | 
				
			||||||
    pass
 | 
					    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -75,14 +78,15 @@ cdef void update_gold_state(GoldNERStateC* gs, StateClass stcls) except *:
 | 
				
			||||||
cdef do_func_t[N_MOVES] do_funcs
 | 
					cdef do_func_t[N_MOVES] do_funcs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef bint _entity_is_sunk(StateClass st, Transition* golds) nogil:
 | 
					cdef bint _entity_is_sunk(const StateC* state, Transition* golds) nogil:
 | 
				
			||||||
    if not st.entity_is_open():
 | 
					    if not state.entity_is_open():
 | 
				
			||||||
        return False
 | 
					        return False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    cdef const Transition* gold = &golds[st.E(0)]
 | 
					    cdef const Transition* gold = &golds[state.E(0)]
 | 
				
			||||||
 | 
					    ent = state.get_ent()
 | 
				
			||||||
    if gold.move != BEGIN and gold.move != UNIT:
 | 
					    if gold.move != BEGIN and gold.move != UNIT:
 | 
				
			||||||
        return True
 | 
					        return True
 | 
				
			||||||
    elif gold.label != st.E_(0).ent_type:
 | 
					    elif gold.label != ent.label:
 | 
				
			||||||
        return True
 | 
					        return True
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        return False
 | 
					        return False
 | 
				
			||||||
| 
						 | 
					@ -228,15 +232,18 @@ cdef class BiluoPushDown(TransitionSystem):
 | 
				
			||||||
            self.labels[action][label_name] = -1
 | 
					            self.labels[action][label_name] = -1
 | 
				
			||||||
        return 1
 | 
					        return 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    cdef int initialize_state(self, StateC* st) nogil:
 | 
					    def set_annotations(self, StateClass state, Doc doc):
 | 
				
			||||||
        # This is especially necessary when we use limited training data.
 | 
					        cdef int i
 | 
				
			||||||
        for i in range(st.length):
 | 
					        ents = []
 | 
				
			||||||
            if st._sent[i].ent_type != 0:
 | 
					        for i in range(state.c._ents.size()):
 | 
				
			||||||
                with gil:
 | 
					            ent = state.c._ents.at(i)
 | 
				
			||||||
                    self.add_action(BEGIN, st._sent[i].ent_type)
 | 
					            if ent.start != -1 and ent.end != -1:
 | 
				
			||||||
                    self.add_action(IN, st._sent[i].ent_type)
 | 
					                ents.append(Span(doc, ent.start, ent.end, label=ent.label))
 | 
				
			||||||
                    self.add_action(UNIT, st._sent[i].ent_type)
 | 
					        doc.set_ents(ents, default="unmodified")
 | 
				
			||||||
                    self.add_action(LAST, st._sent[i].ent_type)
 | 
					        # Set non-blocked tokens to O
 | 
				
			||||||
 | 
					        for i in range(doc.length):
 | 
				
			||||||
 | 
					            if doc.c[i].ent_iob == 0:
 | 
				
			||||||
 | 
					                doc.c[i].ent_iob = 2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def init_gold(self, StateClass state, Example example):
 | 
					    def init_gold(self, StateClass state, Example example):
 | 
				
			||||||
        return BiluoGold(self, state, example)
 | 
					        return BiluoGold(self, state, example)
 | 
				
			||||||
| 
						 | 
					@ -255,26 +262,25 @@ cdef class BiluoPushDown(TransitionSystem):
 | 
				
			||||||
        gold_state = gold_.c
 | 
					        gold_state = gold_.c
 | 
				
			||||||
        n_gold = 0
 | 
					        n_gold = 0
 | 
				
			||||||
        if self.c[i].is_valid(stcls.c, self.c[i].label):
 | 
					        if self.c[i].is_valid(stcls.c, self.c[i].label):
 | 
				
			||||||
            cost = self.c[i].get_cost(stcls, &gold_state, self.c[i].label)
 | 
					            cost = self.c[i].get_cost(stcls.c, &gold_state, self.c[i].label)
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            cost = 9000
 | 
					            cost = 9000
 | 
				
			||||||
        return cost
 | 
					        return cost
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    cdef int set_costs(self, int* is_valid, weight_t* costs,
 | 
					    cdef int set_costs(self, int* is_valid, weight_t* costs,
 | 
				
			||||||
                       StateClass stcls, gold) except -1:
 | 
					                       const StateC* state, gold) except -1:
 | 
				
			||||||
        if not isinstance(gold, BiluoGold):
 | 
					        if not isinstance(gold, BiluoGold):
 | 
				
			||||||
            raise TypeError(Errors.E909.format(name="BiluoGold"))
 | 
					            raise TypeError(Errors.E909.format(name="BiluoGold"))
 | 
				
			||||||
        cdef BiluoGold gold_ = gold
 | 
					        cdef BiluoGold gold_ = gold
 | 
				
			||||||
        gold_.update(stcls)
 | 
					 | 
				
			||||||
        gold_state = gold_.c
 | 
					        gold_state = gold_.c
 | 
				
			||||||
 | 
					        update_gold_state(&gold_state, state)
 | 
				
			||||||
        n_gold = 0
 | 
					        n_gold = 0
 | 
				
			||||||
 | 
					        self.set_valid(is_valid, state)
 | 
				
			||||||
        for i in range(self.n_moves):
 | 
					        for i in range(self.n_moves):
 | 
				
			||||||
            if self.c[i].is_valid(stcls.c, self.c[i].label):
 | 
					            if is_valid[i]:
 | 
				
			||||||
                is_valid[i] = 1
 | 
					                costs[i] = self.c[i].get_cost(state, &gold_state, self.c[i].label)
 | 
				
			||||||
                costs[i] = self.c[i].get_cost(stcls, &gold_state, self.c[i].label)
 | 
					 | 
				
			||||||
                n_gold += costs[i] <= 0
 | 
					                n_gold += costs[i] <= 0
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                is_valid[i] = 0
 | 
					 | 
				
			||||||
                costs[i] = 9000
 | 
					                costs[i] = 9000
 | 
				
			||||||
        if n_gold < 1:
 | 
					        if n_gold < 1:
 | 
				
			||||||
            raise ValueError
 | 
					            raise ValueError
 | 
				
			||||||
| 
						 | 
					@ -290,7 +296,7 @@ cdef class Missing:
 | 
				
			||||||
        pass
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
 | 
					    cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
 | 
				
			||||||
        return 9000
 | 
					        return 9000
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -299,10 +305,10 @@ cdef class Begin:
 | 
				
			||||||
    cdef bint is_valid(const StateC* st, attr_t label) nogil:
 | 
					    cdef bint is_valid(const StateC* st, attr_t label) nogil:
 | 
				
			||||||
        cdef int preset_ent_iob = st.B_(0).ent_iob
 | 
					        cdef int preset_ent_iob = st.B_(0).ent_iob
 | 
				
			||||||
        cdef attr_t preset_ent_label = st.B_(0).ent_type
 | 
					        cdef attr_t preset_ent_label = st.B_(0).ent_type
 | 
				
			||||||
        # If we're the last token of the input, we can't B -- must U or O.
 | 
					        if st.entity_is_open():
 | 
				
			||||||
        if st.B(1) == -1:
 | 
					 | 
				
			||||||
            return False
 | 
					            return False
 | 
				
			||||||
        elif st.entity_is_open():
 | 
					        if st.buffer_length() < 2:
 | 
				
			||||||
 | 
					            # If we're the last token of the input, we can't B -- must U or O.
 | 
				
			||||||
            return False
 | 
					            return False
 | 
				
			||||||
        elif label == 0:
 | 
					        elif label == 0:
 | 
				
			||||||
            return False
 | 
					            return False
 | 
				
			||||||
| 
						 | 
					@ -337,12 +343,11 @@ cdef class Begin:
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    cdef int transition(StateC* st, attr_t label) nogil:
 | 
					    cdef int transition(StateC* st, attr_t label) nogil:
 | 
				
			||||||
        st.open_ent(label)
 | 
					        st.open_ent(label)
 | 
				
			||||||
        st.set_ent_tag(st.B(0), 3, label)
 | 
					 | 
				
			||||||
        st.push()
 | 
					        st.push()
 | 
				
			||||||
        st.pop()
 | 
					        st.pop()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
 | 
					    cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
 | 
				
			||||||
        gold = <GoldNERStateC*>_gold
 | 
					        gold = <GoldNERStateC*>_gold
 | 
				
			||||||
        cdef int g_act = gold.ner[s.B(0)].move
 | 
					        cdef int g_act = gold.ner[s.B(0)].move
 | 
				
			||||||
        cdef attr_t g_tag = gold.ner[s.B(0)].label
 | 
					        cdef attr_t g_tag = gold.ner[s.B(0)].label
 | 
				
			||||||
| 
						 | 
					@ -366,16 +371,17 @@ cdef class Begin:
 | 
				
			||||||
cdef class In:
 | 
					cdef class In:
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    cdef bint is_valid(const StateC* st, attr_t label) nogil:
 | 
					    cdef bint is_valid(const StateC* st, attr_t label) nogil:
 | 
				
			||||||
 | 
					        if not st.entity_is_open():
 | 
				
			||||||
 | 
					            return False
 | 
				
			||||||
 | 
					        if st.buffer_length() < 2:
 | 
				
			||||||
 | 
					            # If we're at the end, we can't I.
 | 
				
			||||||
 | 
					            return False
 | 
				
			||||||
 | 
					        ent = st.get_ent()
 | 
				
			||||||
        cdef int preset_ent_iob = st.B_(0).ent_iob
 | 
					        cdef int preset_ent_iob = st.B_(0).ent_iob
 | 
				
			||||||
        cdef attr_t preset_ent_label = st.B_(0).ent_type
 | 
					        cdef attr_t preset_ent_label = st.B_(0).ent_type
 | 
				
			||||||
        if label == 0:
 | 
					        if label == 0:
 | 
				
			||||||
            return False
 | 
					            return False
 | 
				
			||||||
        elif st.E_(0).ent_type != label:
 | 
					        elif ent.label != label:
 | 
				
			||||||
            return False
 | 
					 | 
				
			||||||
        elif not st.entity_is_open():
 | 
					 | 
				
			||||||
            return False
 | 
					 | 
				
			||||||
        elif st.B(1) == -1:
 | 
					 | 
				
			||||||
            # If we're at the end, we can't I.
 | 
					 | 
				
			||||||
            return False
 | 
					            return False
 | 
				
			||||||
        elif preset_ent_iob == 3:
 | 
					        elif preset_ent_iob == 3:
 | 
				
			||||||
            return False
 | 
					            return False
 | 
				
			||||||
| 
						 | 
					@ -401,12 +407,11 @@ cdef class In:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    cdef int transition(StateC* st, attr_t label) nogil:
 | 
					    cdef int transition(StateC* st, attr_t label) nogil:
 | 
				
			||||||
        st.set_ent_tag(st.B(0), 1, label)
 | 
					 | 
				
			||||||
        st.push()
 | 
					        st.push()
 | 
				
			||||||
        st.pop()
 | 
					        st.pop()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
 | 
					    cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
 | 
				
			||||||
        gold = <GoldNERStateC*>_gold
 | 
					        gold = <GoldNERStateC*>_gold
 | 
				
			||||||
        move = IN
 | 
					        move = IN
 | 
				
			||||||
        cdef int next_act = gold.ner[s.B(1)].move if s.B(1) >= 0 else OUT
 | 
					        cdef int next_act = gold.ner[s.B(1)].move if s.B(1) >= 0 else OUT
 | 
				
			||||||
| 
						 | 
					@ -457,7 +462,7 @@ cdef class Last:
 | 
				
			||||||
                # Otherwise, force acceptance, even if we're across a sentence
 | 
					                # Otherwise, force acceptance, even if we're across a sentence
 | 
				
			||||||
                # boundary or the token is whitespace.
 | 
					                # boundary or the token is whitespace.
 | 
				
			||||||
                return True
 | 
					                return True
 | 
				
			||||||
        elif st.E_(0).ent_type != label:
 | 
					        elif st.get_ent().label != label:
 | 
				
			||||||
            return False
 | 
					            return False
 | 
				
			||||||
        elif st.B_(1).ent_iob == 1:
 | 
					        elif st.B_(1).ent_iob == 1:
 | 
				
			||||||
            # If a preset entity has I next, we can't L here.
 | 
					            # If a preset entity has I next, we can't L here.
 | 
				
			||||||
| 
						 | 
					@ -468,12 +473,11 @@ cdef class Last:
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    cdef int transition(StateC* st, attr_t label) nogil:
 | 
					    cdef int transition(StateC* st, attr_t label) nogil:
 | 
				
			||||||
        st.close_ent()
 | 
					        st.close_ent()
 | 
				
			||||||
        st.set_ent_tag(st.B(0), 1, label)
 | 
					 | 
				
			||||||
        st.push()
 | 
					        st.push()
 | 
				
			||||||
        st.pop()
 | 
					        st.pop()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
 | 
					    cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
 | 
				
			||||||
        gold = <GoldNERStateC*>_gold
 | 
					        gold = <GoldNERStateC*>_gold
 | 
				
			||||||
        move = LAST
 | 
					        move = LAST
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -537,12 +541,11 @@ cdef class Unit:
 | 
				
			||||||
    cdef int transition(StateC* st, attr_t label) nogil:
 | 
					    cdef int transition(StateC* st, attr_t label) nogil:
 | 
				
			||||||
        st.open_ent(label)
 | 
					        st.open_ent(label)
 | 
				
			||||||
        st.close_ent()
 | 
					        st.close_ent()
 | 
				
			||||||
        st.set_ent_tag(st.B(0), 3, label)
 | 
					 | 
				
			||||||
        st.push()
 | 
					        st.push()
 | 
				
			||||||
        st.pop()
 | 
					        st.pop()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
 | 
					    cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
 | 
				
			||||||
        gold = <GoldNERStateC*>_gold
 | 
					        gold = <GoldNERStateC*>_gold
 | 
				
			||||||
        cdef int g_act = gold.ner[s.B(0)].move
 | 
					        cdef int g_act = gold.ner[s.B(0)].move
 | 
				
			||||||
        cdef attr_t g_tag = gold.ner[s.B(0)].label
 | 
					        cdef attr_t g_tag = gold.ner[s.B(0)].label
 | 
				
			||||||
| 
						 | 
					@ -578,12 +581,11 @@ cdef class Out:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    cdef int transition(StateC* st, attr_t label) nogil:
 | 
					    cdef int transition(StateC* st, attr_t label) nogil:
 | 
				
			||||||
        st.set_ent_tag(st.B(0), 2, 0)
 | 
					 | 
				
			||||||
        st.push()
 | 
					        st.push()
 | 
				
			||||||
        st.pop()
 | 
					        st.pop()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    cdef weight_t cost(StateClass s, const void* _gold, attr_t label) nogil:
 | 
					    cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
 | 
				
			||||||
        gold = <GoldNERStateC*>_gold
 | 
					        gold = <GoldNERStateC*>_gold
 | 
				
			||||||
        cdef int g_act = gold.ner[s.B(0)].move
 | 
					        cdef int g_act = gold.ner[s.B(0)].move
 | 
				
			||||||
        cdef attr_t g_tag = gold.ner[s.B(0)].label
 | 
					        cdef attr_t g_tag = gold.ner[s.B(0)].label
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -2,30 +2,24 @@ from cymem.cymem cimport Pool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ...structs cimport TokenC, SpanC
 | 
					from ...structs cimport TokenC, SpanC
 | 
				
			||||||
from ...typedefs cimport attr_t
 | 
					from ...typedefs cimport attr_t
 | 
				
			||||||
 | 
					from ...tokens.doc cimport Doc
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ._state cimport StateC
 | 
					from ._state cimport StateC
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef class StateClass:
 | 
					cdef class StateClass:
 | 
				
			||||||
    cdef Pool mem
 | 
					 | 
				
			||||||
    cdef StateC* c
 | 
					    cdef StateC* c
 | 
				
			||||||
 | 
					    cdef readonly Doc doc
 | 
				
			||||||
    cdef int _borrowed
 | 
					    cdef int _borrowed
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    cdef inline StateClass init(const TokenC* sent, int length):
 | 
					    cdef inline StateClass borrow(StateC* ptr, Doc doc):
 | 
				
			||||||
        cdef StateClass self = StateClass()
 | 
					        cdef StateClass self = StateClass()
 | 
				
			||||||
        self.c = new StateC(sent, length)
 | 
					 | 
				
			||||||
        return self
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    @staticmethod
 | 
					 | 
				
			||||||
    cdef inline StateClass borrow(StateC* ptr):
 | 
					 | 
				
			||||||
        cdef StateClass self = StateClass()
 | 
					 | 
				
			||||||
        del self.c
 | 
					 | 
				
			||||||
        self.c = ptr
 | 
					        self.c = ptr
 | 
				
			||||||
        self._borrowed = 1
 | 
					        self._borrowed = 1
 | 
				
			||||||
 | 
					        self.doc = doc
 | 
				
			||||||
        return self
 | 
					        return self
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    cdef inline StateClass init_offset(const TokenC* sent, int length, int
 | 
					    cdef inline StateClass init_offset(const TokenC* sent, int length, int
 | 
				
			||||||
                                       offset):
 | 
					                                       offset):
 | 
				
			||||||
| 
						 | 
					@ -33,105 +27,3 @@ cdef class StateClass:
 | 
				
			||||||
        self.c = new StateC(sent, length)
 | 
					        self.c = new StateC(sent, length)
 | 
				
			||||||
        self.c.offset = offset
 | 
					        self.c.offset = offset
 | 
				
			||||||
        return self
 | 
					        return self
 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline int S(self, int i) nogil:
 | 
					 | 
				
			||||||
        return self.c.S(i)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline int B(self, int i) nogil:
 | 
					 | 
				
			||||||
        return self.c.B(i)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline const TokenC* S_(self, int i) nogil:
 | 
					 | 
				
			||||||
        return self.c.S_(i)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline const TokenC* B_(self, int i) nogil:
 | 
					 | 
				
			||||||
        return self.c.B_(i)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline const TokenC* H_(self, int i) nogil:
 | 
					 | 
				
			||||||
        return self.c.H_(i)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline const TokenC* E_(self, int i) nogil:
 | 
					 | 
				
			||||||
        return self.c.E_(i)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline const TokenC* L_(self, int i, int idx) nogil:
 | 
					 | 
				
			||||||
        return self.c.L_(i, idx)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline const TokenC* R_(self, int i, int idx) nogil:
 | 
					 | 
				
			||||||
        return self.c.R_(i, idx)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline const TokenC* safe_get(self, int i) nogil:
 | 
					 | 
				
			||||||
        return self.c.safe_get(i)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline int H(self, int i) nogil:
 | 
					 | 
				
			||||||
        return self.c.H(i)
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    cdef inline int E(self, int i) nogil:
 | 
					 | 
				
			||||||
        return self.c.E(i)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline int L(self, int i, int idx) nogil:
 | 
					 | 
				
			||||||
        return self.c.L(i, idx)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline int R(self, int i, int idx) nogil:
 | 
					 | 
				
			||||||
        return self.c.R(i, idx)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline bint empty(self) nogil:
 | 
					 | 
				
			||||||
        return self.c.empty()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline bint eol(self) nogil:
 | 
					 | 
				
			||||||
        return self.c.eol()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline bint at_break(self) nogil:
 | 
					 | 
				
			||||||
        return self.c.at_break()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline bint has_head(self, int i) nogil:
 | 
					 | 
				
			||||||
        return self.c.has_head(i)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline int n_L(self, int i) nogil:
 | 
					 | 
				
			||||||
        return self.c.n_L(i)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline int n_R(self, int i) nogil:
 | 
					 | 
				
			||||||
        return self.c.n_R(i)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline bint stack_is_connected(self) nogil:
 | 
					 | 
				
			||||||
        return False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline bint entity_is_open(self) nogil:
 | 
					 | 
				
			||||||
        return self.c.entity_is_open()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline int stack_depth(self) nogil:
 | 
					 | 
				
			||||||
        return self.c.stack_depth()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline int buffer_length(self) nogil:
 | 
					 | 
				
			||||||
        return self.c.buffer_length()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline void push(self) nogil:
 | 
					 | 
				
			||||||
        self.c.push()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline void pop(self) nogil:
 | 
					 | 
				
			||||||
        self.c.pop()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline void unshift(self) nogil:
 | 
					 | 
				
			||||||
        self.c.unshift()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline void add_arc(self, int head, int child, attr_t label) nogil:
 | 
					 | 
				
			||||||
        self.c.add_arc(head, child, label)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline void del_arc(self, int head, int child) nogil:
 | 
					 | 
				
			||||||
        self.c.del_arc(head, child)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline void open_ent(self, attr_t label) nogil:
 | 
					 | 
				
			||||||
        self.c.open_ent(label)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline void close_ent(self) nogil:
 | 
					 | 
				
			||||||
        self.c.close_ent()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline void set_ent_tag(self, int i, int ent_iob, attr_t ent_type) nogil:
 | 
					 | 
				
			||||||
        self.c.set_ent_tag(i, ent_iob, ent_type)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline void set_break(self, int i) nogil:
 | 
					 | 
				
			||||||
        self.c.set_break(i)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline void clone(self, StateClass src) nogil:
 | 
					 | 
				
			||||||
        self.c.clone(src.c)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef inline void fast_forward(self) nogil:
 | 
					 | 
				
			||||||
        self.c.fast_forward()
 | 
					 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,17 +1,20 @@
 | 
				
			||||||
# cython: infer_types=True
 | 
					# cython: infer_types=True
 | 
				
			||||||
import numpy
 | 
					import numpy
 | 
				
			||||||
 | 
					from libcpp.vector cimport vector
 | 
				
			||||||
 | 
					from ._state cimport ArcC
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ...tokens.doc cimport Doc
 | 
					from ...tokens.doc cimport Doc
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef class StateClass:
 | 
					cdef class StateClass:
 | 
				
			||||||
    def __init__(self, Doc doc=None, int offset=0):
 | 
					    def __init__(self, Doc doc=None, int offset=0):
 | 
				
			||||||
        cdef Pool mem = Pool()
 | 
					 | 
				
			||||||
        self.mem = mem
 | 
					 | 
				
			||||||
        self._borrowed = 0
 | 
					        self._borrowed = 0
 | 
				
			||||||
        if doc is not None:
 | 
					        if doc is not None:
 | 
				
			||||||
            self.c = new StateC(doc.c, doc.length)
 | 
					            self.c = new StateC(doc.c, doc.length)
 | 
				
			||||||
            self.c.offset = offset
 | 
					            self.c.offset = offset
 | 
				
			||||||
 | 
					            self.doc = doc
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.doc = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __dealloc__(self):
 | 
					    def __dealloc__(self):
 | 
				
			||||||
        if self._borrowed != 1:
 | 
					        if self._borrowed != 1:
 | 
				
			||||||
| 
						 | 
					@ -19,36 +22,157 @@ cdef class StateClass:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def stack(self):
 | 
					    def stack(self):
 | 
				
			||||||
        return {self.S(i) for i in range(self.c._s_i)}
 | 
					        return [self.S(i) for i in range(self.c.stack_depth())]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def queue(self):
 | 
					    def queue(self):
 | 
				
			||||||
        return {self.B(i) for i in range(self.c.buffer_length())}
 | 
					        return [self.B(i) for i in range(self.c.buffer_length())]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def token_vector_lenth(self):
 | 
					    def token_vector_lenth(self):
 | 
				
			||||||
        return self.doc.tensor.shape[1]
 | 
					        return self.doc.tensor.shape[1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def history(self):
 | 
					    def arcs(self):
 | 
				
			||||||
        hist = numpy.ndarray((8,), dtype='i')
 | 
					        cdef vector[ArcC] arcs
 | 
				
			||||||
        for i in range(8):
 | 
					        self.c.get_arcs(&arcs)
 | 
				
			||||||
            hist[i] = self.c.get_hist(i+1)
 | 
					        return list(arcs)
 | 
				
			||||||
        return hist
 | 
					        #py_arcs = []
 | 
				
			||||||
 | 
					        #for arc in arcs:
 | 
				
			||||||
 | 
					        #    if arc.head != -1 and arc.child != -1:
 | 
				
			||||||
 | 
					        #        py_arcs.append((arc.head, arc.child, arc.label))
 | 
				
			||||||
 | 
					        #return arcs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def add_arc(self, int head, int child, int label):
 | 
				
			||||||
 | 
					        self.c.add_arc(head, child, label)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def del_arc(self, int head, int child):
 | 
				
			||||||
 | 
					        self.c.del_arc(head, child)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def H(self, int child):
 | 
				
			||||||
 | 
					        return self.c.H(child)
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    def L(self, int head, int idx):
 | 
				
			||||||
 | 
					        return self.c.L(head, idx)
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    def R(self, int head, int idx):
 | 
				
			||||||
 | 
					        return self.c.R(head, idx)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def _b_i(self):
 | 
				
			||||||
 | 
					        return self.c._b_i
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def length(self):
 | 
				
			||||||
 | 
					        return self.c.length
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def is_final(self):
 | 
					    def is_final(self):
 | 
				
			||||||
        return self.c.is_final()
 | 
					        return self.c.is_final()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def copy(self):
 | 
					    def copy(self):
 | 
				
			||||||
        cdef StateClass new_state = StateClass.init(self.c._sent, self.c.length)
 | 
					        cdef StateClass new_state = StateClass(doc=self.doc, offset=self.c.offset)
 | 
				
			||||||
        new_state.c.clone(self.c)
 | 
					        new_state.c.clone(self.c)
 | 
				
			||||||
        return new_state
 | 
					        return new_state
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def print_state(self, words):
 | 
					    def print_state(self):
 | 
				
			||||||
 | 
					        words = [token.text for token in self.doc]
 | 
				
			||||||
        words = list(words) + ['_']
 | 
					        words = list(words) + ['_']
 | 
				
			||||||
        top = f"{words[self.S(0)]}_{self.S_(0).head}"
 | 
					        bools = ["F", "T"]
 | 
				
			||||||
        second = f"{words[self.S(1)]}_{self.S_(1).head}"
 | 
					        sent_starts = [bools[self.c.is_sent_start(i)] for i in range(len(self.doc))]
 | 
				
			||||||
        third = f"{words[self.S(2)]}_{self.S_(2).head}"
 | 
					        shifted = [1 if self.c.is_unshiftable(i) else 0 for i in range(self.c.length)]
 | 
				
			||||||
        n0 = words[self.B(0)]
 | 
					        shifted.append("")
 | 
				
			||||||
        n1 = words[self.B(1)]
 | 
					        sent_starts.append("")
 | 
				
			||||||
        return ' '.join((third, second, top, '|', n0, n1))
 | 
					        top = f"{self.S(0)}{words[self.S(0)]}_{words[self.H(self.S(0))]}_{shifted[self.S(0)]}"
 | 
				
			||||||
 | 
					        second = f"{self.S(1)}{words[self.S(1)]}_{words[self.H(self.S(1))]}_{shifted[self.S(1)]}"
 | 
				
			||||||
 | 
					        third = f"{self.S(2)}{words[self.S(2)]}_{words[self.H(self.S(2))]}_{shifted[self.S(2)]}"
 | 
				
			||||||
 | 
					        n0 = f"{self.B(0)}{words[self.B(0)]}_{sent_starts[self.B(0)]}_{shifted[self.B(0)]}"
 | 
				
			||||||
 | 
					        n1 = f"{self.B(1)}{words[self.B(1)]}_{sent_starts[self.B(1)]}_{shifted[self.B(1)]}"
 | 
				
			||||||
 | 
					        return ' '.join((str(self.stack_depth()), str(self.buffer_length()), third, second, top, '|', n0, n1))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def S(self, int i):
 | 
				
			||||||
 | 
					        return self.c.S(i)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def B(self, int i):
 | 
				
			||||||
 | 
					        return self.c.B(i)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def H(self, int i):
 | 
				
			||||||
 | 
					        return self.c.H(i)
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    def E(self, int i):
 | 
				
			||||||
 | 
					        return self.c.E(i)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def L(self, int i, int idx):
 | 
				
			||||||
 | 
					        return self.c.L(i, idx)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def R(self, int i, int idx):
 | 
				
			||||||
 | 
					        return self.c.R(i, idx)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def S_(self, int i):
 | 
				
			||||||
 | 
					        return self.doc[self.c.S(i)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def B_(self, int i):
 | 
				
			||||||
 | 
					        return self.doc[self.c.B(i)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def H_(self, int i):
 | 
				
			||||||
 | 
					        return self.doc[self.c.H(i)]
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    def E_(self, int i):
 | 
				
			||||||
 | 
					        return self.doc[self.c.E(i)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def L_(self, int i, int idx):
 | 
				
			||||||
 | 
					        return self.doc[self.c.L(i, idx)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def R_(self, int i, int idx):
 | 
				
			||||||
 | 
					        return self.doc[self.c.R(i, idx)]
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					    def empty(self):
 | 
				
			||||||
 | 
					        return self.c.empty()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def eol(self):
 | 
				
			||||||
 | 
					        return self.c.eol()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def at_break(self):
 | 
				
			||||||
 | 
					        return False
 | 
				
			||||||
 | 
					        #return self.c.at_break()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def has_head(self, int i):
 | 
				
			||||||
 | 
					        return self.c.has_head(i)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def  n_L(self, int i):
 | 
				
			||||||
 | 
					        return self.c.n_L(i)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def n_R(self, int i):
 | 
				
			||||||
 | 
					        return self.c.n_R(i)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def entity_is_open(self):
 | 
				
			||||||
 | 
					        return self.c.entity_is_open()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def stack_depth(self):
 | 
				
			||||||
 | 
					        return self.c.stack_depth()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def buffer_length(self):
 | 
				
			||||||
 | 
					        return self.c.buffer_length()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def push(self):
 | 
				
			||||||
 | 
					        self.c.push()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def pop(self):
 | 
				
			||||||
 | 
					        self.c.pop()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def unshift(self):
 | 
				
			||||||
 | 
					        self.c.unshift()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def add_arc(self, int head, int child, attr_t label):
 | 
				
			||||||
 | 
					        self.c.add_arc(head, child, label)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def del_arc(self, int head, int child):
 | 
				
			||||||
 | 
					        self.c.del_arc(head, child)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def open_ent(self, attr_t label):
 | 
				
			||||||
 | 
					        self.c.open_ent(label)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def close_ent(self):
 | 
				
			||||||
 | 
					        self.c.close_ent()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def clone(self, StateClass src):
 | 
				
			||||||
 | 
					        self.c.clone(src.c)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -16,14 +16,14 @@ cdef struct Transition:
 | 
				
			||||||
    weight_t score
 | 
					    weight_t score
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    bint (*is_valid)(const StateC* state, attr_t label) nogil
 | 
					    bint (*is_valid)(const StateC* state, attr_t label) nogil
 | 
				
			||||||
    weight_t (*get_cost)(StateClass state, const void* gold, attr_t label) nogil
 | 
					    weight_t (*get_cost)(const StateC* state, const void* gold, attr_t label) nogil
 | 
				
			||||||
    int (*do)(StateC* state, attr_t label) nogil
 | 
					    int (*do)(StateC* state, attr_t label) nogil
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
ctypedef weight_t (*get_cost_func_t)(StateClass state, const void* gold,
 | 
					ctypedef weight_t (*get_cost_func_t)(const StateC* state, const void* gold,
 | 
				
			||||||
        attr_tlabel) nogil
 | 
					        attr_tlabel) nogil
 | 
				
			||||||
ctypedef weight_t (*move_cost_func_t)(StateClass state, const void* gold) nogil
 | 
					ctypedef weight_t (*move_cost_func_t)(const StateC* state, const void* gold) nogil
 | 
				
			||||||
ctypedef weight_t (*label_cost_func_t)(StateClass state, const void*
 | 
					ctypedef weight_t (*label_cost_func_t)(const StateC* state, const void*
 | 
				
			||||||
        gold, attr_t label) nogil
 | 
					        gold, attr_t label) nogil
 | 
				
			||||||
 | 
					
 | 
				
			||||||
ctypedef int (*do_func_t)(StateC* state, attr_t label) nogil
 | 
					ctypedef int (*do_func_t)(StateC* state, attr_t label) nogil
 | 
				
			||||||
| 
						 | 
					@ -41,9 +41,8 @@ cdef class TransitionSystem:
 | 
				
			||||||
    cdef public attr_t root_label
 | 
					    cdef public attr_t root_label
 | 
				
			||||||
    cdef public freqs
 | 
					    cdef public freqs
 | 
				
			||||||
    cdef public object labels
 | 
					    cdef public object labels
 | 
				
			||||||
 | 
					    cdef init_state_t init_beam_state
 | 
				
			||||||
    cdef int initialize_state(self, StateC* state) nogil
 | 
					    cdef del_state_t del_beam_state
 | 
				
			||||||
    cdef int finalize_state(self, StateC* state) nogil
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    cdef Transition lookup_transition(self, object name) except *
 | 
					    cdef Transition lookup_transition(self, object name) except *
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -52,4 +51,4 @@ cdef class TransitionSystem:
 | 
				
			||||||
    cdef int set_valid(self, int* output, const StateC* st) nogil
 | 
					    cdef int set_valid(self, int* output, const StateC* st) nogil
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    cdef int set_costs(self, int* is_valid, weight_t* costs,
 | 
					    cdef int set_costs(self, int* is_valid, weight_t* costs,
 | 
				
			||||||
                       StateClass state, gold) except -1
 | 
					                       const StateC* state, gold) except -1
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -5,6 +5,7 @@ from cymem.cymem cimport Pool
 | 
				
			||||||
from collections import Counter
 | 
					from collections import Counter
 | 
				
			||||||
import srsly
 | 
					import srsly
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from . cimport _beam_utils
 | 
				
			||||||
from ...typedefs cimport weight_t, attr_t
 | 
					from ...typedefs cimport weight_t, attr_t
 | 
				
			||||||
from ...tokens.doc cimport Doc
 | 
					from ...tokens.doc cimport Doc
 | 
				
			||||||
from ...structs cimport TokenC
 | 
					from ...structs cimport TokenC
 | 
				
			||||||
| 
						 | 
					@ -44,6 +45,8 @@ cdef class TransitionSystem:
 | 
				
			||||||
        if labels_by_action:
 | 
					        if labels_by_action:
 | 
				
			||||||
            self.initialize_actions(labels_by_action, min_freq=min_freq)
 | 
					            self.initialize_actions(labels_by_action, min_freq=min_freq)
 | 
				
			||||||
        self.root_label = self.strings.add('ROOT')
 | 
					        self.root_label = self.strings.add('ROOT')
 | 
				
			||||||
 | 
					        self.init_beam_state = _init_state
 | 
				
			||||||
 | 
					        self.del_beam_state = _del_state
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __reduce__(self):
 | 
					    def __reduce__(self):
 | 
				
			||||||
        return (self.__class__, (self.strings, self.labels), None, None)
 | 
					        return (self.__class__, (self.strings, self.labels), None, None)
 | 
				
			||||||
| 
						 | 
					@ -54,7 +57,6 @@ cdef class TransitionSystem:
 | 
				
			||||||
        offset = 0
 | 
					        offset = 0
 | 
				
			||||||
        for doc in docs:
 | 
					        for doc in docs:
 | 
				
			||||||
            state = StateClass(doc, offset=offset)
 | 
					            state = StateClass(doc, offset=offset)
 | 
				
			||||||
            self.initialize_state(state.c)
 | 
					 | 
				
			||||||
            states.append(state)
 | 
					            states.append(state)
 | 
				
			||||||
            offset += len(doc)
 | 
					            offset += len(doc)
 | 
				
			||||||
        return states
 | 
					        return states
 | 
				
			||||||
| 
						 | 
					@ -80,7 +82,7 @@ cdef class TransitionSystem:
 | 
				
			||||||
        history = []
 | 
					        history = []
 | 
				
			||||||
        debug_log = []
 | 
					        debug_log = []
 | 
				
			||||||
        while not state.is_final():
 | 
					        while not state.is_final():
 | 
				
			||||||
            self.set_costs(is_valid, costs, state, gold)
 | 
					            self.set_costs(is_valid, costs, state.c, gold)
 | 
				
			||||||
            for i in range(self.n_moves):
 | 
					            for i in range(self.n_moves):
 | 
				
			||||||
                if is_valid[i] and costs[i] <= 0:
 | 
					                if is_valid[i] and costs[i] <= 0:
 | 
				
			||||||
                    action = self.c[i]
 | 
					                    action = self.c[i]
 | 
				
			||||||
| 
						 | 
					@ -124,15 +126,6 @@ cdef class TransitionSystem:
 | 
				
			||||||
        action = self.lookup_transition(name)
 | 
					        action = self.lookup_transition(name)
 | 
				
			||||||
        action.do(state.c, action.label)
 | 
					        action.do(state.c, action.label)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    cdef int initialize_state(self, StateC* state) nogil:
 | 
					 | 
				
			||||||
        pass
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef int finalize_state(self, StateC* state) nogil:
 | 
					 | 
				
			||||||
        pass
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def finalize_doc(self, doc):
 | 
					 | 
				
			||||||
        pass
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cdef Transition lookup_transition(self, object name) except *:
 | 
					    cdef Transition lookup_transition(self, object name) except *:
 | 
				
			||||||
        raise NotImplementedError
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -151,7 +144,7 @@ cdef class TransitionSystem:
 | 
				
			||||||
            is_valid[i] = self.c[i].is_valid(st, self.c[i].label)
 | 
					            is_valid[i] = self.c[i].is_valid(st, self.c[i].label)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    cdef int set_costs(self, int* is_valid, weight_t* costs,
 | 
					    cdef int set_costs(self, int* is_valid, weight_t* costs,
 | 
				
			||||||
                       StateClass stcls, gold) except -1:
 | 
					                       const StateC* state, gold) except -1:
 | 
				
			||||||
        raise NotImplementedError
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_class_name(self, int clas):
 | 
					    def get_class_name(self, int clas):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -105,6 +105,93 @@ def make_parser(
 | 
				
			||||||
        update_with_oracle_cut_size=update_with_oracle_cut_size,
 | 
					        update_with_oracle_cut_size=update_with_oracle_cut_size,
 | 
				
			||||||
        multitasks=[],
 | 
					        multitasks=[],
 | 
				
			||||||
        learn_tokens=learn_tokens,
 | 
					        learn_tokens=learn_tokens,
 | 
				
			||||||
 | 
					        min_action_freq=min_action_freq,
 | 
				
			||||||
 | 
					        beam_width=1,
 | 
				
			||||||
 | 
					        beam_density=0.0,
 | 
				
			||||||
 | 
					        beam_update_prob=0.0,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@Language.factory(
 | 
				
			||||||
 | 
					    "beam_parser",
 | 
				
			||||||
 | 
					    assigns=["token.dep", "token.head", "token.is_sent_start", "doc.sents"],
 | 
				
			||||||
 | 
					    default_config={
 | 
				
			||||||
 | 
					        "beam_width": 8,
 | 
				
			||||||
 | 
					        "beam_density": 0.01,
 | 
				
			||||||
 | 
					        "beam_update_prob": 0.5,
 | 
				
			||||||
 | 
					        "moves": None,
 | 
				
			||||||
 | 
					        "update_with_oracle_cut_size": 100,
 | 
				
			||||||
 | 
					        "learn_tokens": False,
 | 
				
			||||||
 | 
					        "min_action_freq": 30,
 | 
				
			||||||
 | 
					        "model": DEFAULT_PARSER_MODEL,
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					    default_score_weights={
 | 
				
			||||||
 | 
					        "dep_uas": 0.5,
 | 
				
			||||||
 | 
					        "dep_las": 0.5,
 | 
				
			||||||
 | 
					        "dep_las_per_type": None,
 | 
				
			||||||
 | 
					        "sents_p": None,
 | 
				
			||||||
 | 
					        "sents_r": None,
 | 
				
			||||||
 | 
					        "sents_f": 0.0,
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					def make_beam_parser(
 | 
				
			||||||
 | 
					    nlp: Language,
 | 
				
			||||||
 | 
					    name: str,
 | 
				
			||||||
 | 
					    model: Model,
 | 
				
			||||||
 | 
					    moves: Optional[list],
 | 
				
			||||||
 | 
					    update_with_oracle_cut_size: int,
 | 
				
			||||||
 | 
					    learn_tokens: bool,
 | 
				
			||||||
 | 
					    min_action_freq: int,
 | 
				
			||||||
 | 
					    beam_width: int,
 | 
				
			||||||
 | 
					    beam_density: float,
 | 
				
			||||||
 | 
					    beam_update_prob: float,
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    """Create a transition-based DependencyParser component that uses beam-search.
 | 
				
			||||||
 | 
					    The dependency parser jointly learns sentence segmentation and labelled
 | 
				
			||||||
 | 
					    dependency parsing, and can optionally learn to merge tokens that had been
 | 
				
			||||||
 | 
					    over-segmented by the tokenizer.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    The parser uses a variant of the non-monotonic arc-eager transition-system
 | 
				
			||||||
 | 
					    described by Honnibal and Johnson (2014), with the addition of a "break"
 | 
				
			||||||
 | 
					    transition to perform the sentence segmentation. Nivre's pseudo-projective
 | 
				
			||||||
 | 
					    dependency transformation is used to allow the parser to predict
 | 
				
			||||||
 | 
					    non-projective parses.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    The parser is trained using a global objective. That is, it learns to assign
 | 
				
			||||||
 | 
					    probabilities to whole parses.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    model (Model): The model for the transition-based parser. The model needs
 | 
				
			||||||
 | 
					        to have a specific substructure of named components --- see the
 | 
				
			||||||
 | 
					        spacy.ml.tb_framework.TransitionModel for details.
 | 
				
			||||||
 | 
					    moves (List[str]): A list of transition names. Inferred from the data if not
 | 
				
			||||||
 | 
					        provided.
 | 
				
			||||||
 | 
					    beam_width (int): The number of candidate analyses to maintain.
 | 
				
			||||||
 | 
					    beam_density (float): The minimum ratio between the scores of the first and
 | 
				
			||||||
 | 
					        last candidates in the beam. This allows the parser to avoid exploring
 | 
				
			||||||
 | 
					        candidates that are too far behind. This is mostly intended to improve
 | 
				
			||||||
 | 
					        efficiency, but it can also improve accuracy as deeper search is not
 | 
				
			||||||
 | 
					        always better.
 | 
				
			||||||
 | 
					    beam_update_prob (float): The chance of making a beam update, instead of a
 | 
				
			||||||
 | 
					        greedy update. Greedy updates are an approximation for the beam updates,
 | 
				
			||||||
 | 
					        and are faster to compute.
 | 
				
			||||||
 | 
					    learn_tokens (bool): Whether to learn to merge subtokens that are split
 | 
				
			||||||
 | 
					        relative to the gold standard. Experimental.
 | 
				
			||||||
 | 
					    min_action_freq (int): The minimum frequency of labelled actions to retain.
 | 
				
			||||||
 | 
					        Rarer labelled actions have their label backed-off to "dep". While this
 | 
				
			||||||
 | 
					        primarily affects the label accuracy, it can also affect the attachment
 | 
				
			||||||
 | 
					        structure, as the labels are used to represent the pseudo-projectivity
 | 
				
			||||||
 | 
					        transformation.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    return DependencyParser(
 | 
				
			||||||
 | 
					        nlp.vocab,
 | 
				
			||||||
 | 
					        model,
 | 
				
			||||||
 | 
					        name,
 | 
				
			||||||
 | 
					        moves=moves,
 | 
				
			||||||
 | 
					        update_with_oracle_cut_size=update_with_oracle_cut_size,
 | 
				
			||||||
 | 
					        beam_width=beam_width,
 | 
				
			||||||
 | 
					        beam_density=beam_density,
 | 
				
			||||||
 | 
					        beam_update_prob=beam_update_prob,
 | 
				
			||||||
 | 
					        multitasks=[],
 | 
				
			||||||
 | 
					        learn_tokens=learn_tokens,
 | 
				
			||||||
        min_action_freq=min_action_freq
 | 
					        min_action_freq=min_action_freq
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -82,6 +82,79 @@ def make_ner(
 | 
				
			||||||
        multitasks=[],
 | 
					        multitasks=[],
 | 
				
			||||||
        min_action_freq=1,
 | 
					        min_action_freq=1,
 | 
				
			||||||
        learn_tokens=False,
 | 
					        learn_tokens=False,
 | 
				
			||||||
 | 
					        beam_width=1,
 | 
				
			||||||
 | 
					        beam_density=0.0,
 | 
				
			||||||
 | 
					        beam_update_prob=0.0,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@Language.factory(
 | 
				
			||||||
 | 
					    "beam_ner",
 | 
				
			||||||
 | 
					    assigns=["doc.ents", "token.ent_iob", "token.ent_type"],
 | 
				
			||||||
 | 
					    default_config={
 | 
				
			||||||
 | 
					        "moves": None,
 | 
				
			||||||
 | 
					        "update_with_oracle_cut_size": 100,
 | 
				
			||||||
 | 
					        "model": DEFAULT_NER_MODEL,
 | 
				
			||||||
 | 
					        "beam_density": 0.01,
 | 
				
			||||||
 | 
					        "beam_update_prob": 0.5,
 | 
				
			||||||
 | 
					        "beam_width": 32
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					    default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0, "ents_per_type": None},
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					def make_beam_ner(
 | 
				
			||||||
 | 
					    nlp: Language,
 | 
				
			||||||
 | 
					    name: str,
 | 
				
			||||||
 | 
					    model: Model,
 | 
				
			||||||
 | 
					    moves: Optional[list],
 | 
				
			||||||
 | 
					    update_with_oracle_cut_size: int,
 | 
				
			||||||
 | 
					    beam_width: int,
 | 
				
			||||||
 | 
					    beam_density: float,
 | 
				
			||||||
 | 
					    beam_update_prob: float,
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    """Create a transition-based EntityRecognizer component that uses beam-search.
 | 
				
			||||||
 | 
					    The entity recognizer identifies non-overlapping labelled spans of tokens.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    The transition-based algorithm used encodes certain assumptions that are
 | 
				
			||||||
 | 
					    effective for "traditional" named entity recognition tasks, but may not be
 | 
				
			||||||
 | 
					    a good fit for every span identification problem. Specifically, the loss
 | 
				
			||||||
 | 
					    function optimizes for whole entity accuracy, so if your inter-annotator
 | 
				
			||||||
 | 
					    agreement on boundary tokens is low, the component will likely perform poorly
 | 
				
			||||||
 | 
					    on your problem. The transition-based algorithm also assumes that the most
 | 
				
			||||||
 | 
					    decisive information about your entities will be close to their initial tokens.
 | 
				
			||||||
 | 
					    If your entities are long and characterised by tokens in their middle, the
 | 
				
			||||||
 | 
					    component will likely do poorly on your task.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    model (Model): The model for the transition-based parser. The model needs
 | 
				
			||||||
 | 
					        to have a specific substructure of named components --- see the
 | 
				
			||||||
 | 
					        spacy.ml.tb_framework.TransitionModel for details.
 | 
				
			||||||
 | 
					    moves (list[str]): A list of transition names. Inferred from the data if not
 | 
				
			||||||
 | 
					        provided.
 | 
				
			||||||
 | 
					    update_with_oracle_cut_size (int):
 | 
				
			||||||
 | 
					        During training, cut long sequences into shorter segments by creating
 | 
				
			||||||
 | 
					        intermediate states based on the gold-standard history. The model is
 | 
				
			||||||
 | 
					        not very sensitive to this parameter, so you usually won't need to change
 | 
				
			||||||
 | 
					        it. 100 is a good default.
 | 
				
			||||||
 | 
					    beam_width (int): The number of candidate analyses to maintain.
 | 
				
			||||||
 | 
					    beam_density (float): The minimum ratio between the scores of the first and
 | 
				
			||||||
 | 
					        last candidates in the beam. This allows the parser to avoid exploring
 | 
				
			||||||
 | 
					        candidates that are too far behind. This is mostly intended to improve
 | 
				
			||||||
 | 
					        efficiency, but it can also improve accuracy as deeper search is not
 | 
				
			||||||
 | 
					        always better.
 | 
				
			||||||
 | 
					    beam_update_prob (float): The chance of making a beam update, instead of a
 | 
				
			||||||
 | 
					        greedy update. Greedy updates are an approximation for the beam updates,
 | 
				
			||||||
 | 
					        and are faster to compute.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    return EntityRecognizer(
 | 
				
			||||||
 | 
					        nlp.vocab,
 | 
				
			||||||
 | 
					        model,
 | 
				
			||||||
 | 
					        name,
 | 
				
			||||||
 | 
					        moves=moves,
 | 
				
			||||||
 | 
					        update_with_oracle_cut_size=update_with_oracle_cut_size,
 | 
				
			||||||
 | 
					        multitasks=[],
 | 
				
			||||||
 | 
					        min_action_freq=1,
 | 
				
			||||||
 | 
					        learn_tokens=False,
 | 
				
			||||||
 | 
					        beam_width=beam_width,
 | 
				
			||||||
 | 
					        beam_density=beam_density,
 | 
				
			||||||
 | 
					        beam_update_prob=beam_update_prob,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -4,13 +4,14 @@ from cymem.cymem cimport Pool
 | 
				
			||||||
cimport numpy as np
 | 
					cimport numpy as np
 | 
				
			||||||
from itertools import islice
 | 
					from itertools import islice
 | 
				
			||||||
from libcpp.vector cimport vector
 | 
					from libcpp.vector cimport vector
 | 
				
			||||||
from libc.string cimport memset
 | 
					from libc.string cimport memset, memcpy
 | 
				
			||||||
from libc.stdlib cimport calloc, free
 | 
					from libc.stdlib cimport calloc, free
 | 
				
			||||||
import random
 | 
					import random
 | 
				
			||||||
from typing import Optional
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import srsly
 | 
					import srsly
 | 
				
			||||||
from thinc.api import set_dropout_rate
 | 
					from thinc.api import set_dropout_rate, CupyOps
 | 
				
			||||||
 | 
					from thinc.extra.search cimport Beam
 | 
				
			||||||
import numpy.random
 | 
					import numpy.random
 | 
				
			||||||
import numpy
 | 
					import numpy
 | 
				
			||||||
import warnings
 | 
					import warnings
 | 
				
			||||||
| 
						 | 
					@ -22,6 +23,8 @@ from ..ml.parser_model cimport WeightsC, ActivationsC, SizesC, cpu_log_loss
 | 
				
			||||||
from ..ml.parser_model cimport get_c_weights, get_c_sizes
 | 
					from ..ml.parser_model cimport get_c_weights, get_c_sizes
 | 
				
			||||||
from ..tokens.doc cimport Doc
 | 
					from ..tokens.doc cimport Doc
 | 
				
			||||||
from .trainable_pipe import TrainablePipe
 | 
					from .trainable_pipe import TrainablePipe
 | 
				
			||||||
 | 
					from ._parser_internals cimport _beam_utils
 | 
				
			||||||
 | 
					from ._parser_internals import _beam_utils
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ..training import validate_examples, validate_get_examples
 | 
					from ..training import validate_examples, validate_get_examples
 | 
				
			||||||
from ..errors import Errors, Warnings
 | 
					from ..errors import Errors, Warnings
 | 
				
			||||||
| 
						 | 
					@ -41,9 +44,12 @@ cdef class Parser(TrainablePipe):
 | 
				
			||||||
        moves=None,
 | 
					        moves=None,
 | 
				
			||||||
        *,
 | 
					        *,
 | 
				
			||||||
        update_with_oracle_cut_size,
 | 
					        update_with_oracle_cut_size,
 | 
				
			||||||
        multitasks=tuple(),
 | 
					 | 
				
			||||||
        min_action_freq,
 | 
					        min_action_freq,
 | 
				
			||||||
        learn_tokens,
 | 
					        learn_tokens,
 | 
				
			||||||
 | 
					        beam_width=1,
 | 
				
			||||||
 | 
					        beam_density=0.0,
 | 
				
			||||||
 | 
					        beam_update_prob=0.0,
 | 
				
			||||||
 | 
					        multitasks=tuple(),
 | 
				
			||||||
    ):
 | 
					    ):
 | 
				
			||||||
        """Create a Parser.
 | 
					        """Create a Parser.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -61,7 +67,10 @@ cdef class Parser(TrainablePipe):
 | 
				
			||||||
            "update_with_oracle_cut_size": update_with_oracle_cut_size,
 | 
					            "update_with_oracle_cut_size": update_with_oracle_cut_size,
 | 
				
			||||||
            "multitasks": list(multitasks),
 | 
					            "multitasks": list(multitasks),
 | 
				
			||||||
            "min_action_freq": min_action_freq,
 | 
					            "min_action_freq": min_action_freq,
 | 
				
			||||||
            "learn_tokens": learn_tokens
 | 
					            "learn_tokens": learn_tokens,
 | 
				
			||||||
 | 
					            "beam_width": beam_width,
 | 
				
			||||||
 | 
					            "beam_density": beam_density,
 | 
				
			||||||
 | 
					            "beam_update_prob": beam_update_prob
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        if moves is None:
 | 
					        if moves is None:
 | 
				
			||||||
            # defined by EntityRecognizer as a BiluoPushDown
 | 
					            # defined by EntityRecognizer as a BiluoPushDown
 | 
				
			||||||
| 
						 | 
					@ -183,7 +192,15 @@ cdef class Parser(TrainablePipe):
 | 
				
			||||||
            result = self.moves.init_batch(docs)
 | 
					            result = self.moves.init_batch(docs)
 | 
				
			||||||
            self._resize()
 | 
					            self._resize()
 | 
				
			||||||
            return result
 | 
					            return result
 | 
				
			||||||
 | 
					        if self.cfg["beam_width"] == 1:
 | 
				
			||||||
            return self.greedy_parse(docs, drop=0.0)
 | 
					            return self.greedy_parse(docs, drop=0.0)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return self.beam_parse(
 | 
				
			||||||
 | 
					                docs,
 | 
				
			||||||
 | 
					                drop=0.0,
 | 
				
			||||||
 | 
					                beam_width=self.cfg["beam_width"],
 | 
				
			||||||
 | 
					                beam_density=self.cfg["beam_density"]
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def greedy_parse(self, docs, drop=0.):
 | 
					    def greedy_parse(self, docs, drop=0.):
 | 
				
			||||||
        cdef vector[StateC*] states
 | 
					        cdef vector[StateC*] states
 | 
				
			||||||
| 
						 | 
					@ -207,6 +224,31 @@ cdef class Parser(TrainablePipe):
 | 
				
			||||||
        del model
 | 
					        del model
 | 
				
			||||||
        return batch
 | 
					        return batch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
 | 
				
			||||||
 | 
					        cdef Beam beam
 | 
				
			||||||
 | 
					        cdef Doc doc
 | 
				
			||||||
 | 
					        batch = _beam_utils.BeamBatch(
 | 
				
			||||||
 | 
					            self.moves,
 | 
				
			||||||
 | 
					            self.moves.init_batch(docs),
 | 
				
			||||||
 | 
					            None,
 | 
				
			||||||
 | 
					            beam_width,
 | 
				
			||||||
 | 
					            density=beam_density
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        # This is pretty dirty, but the NER can resize itself in init_batch,
 | 
				
			||||||
 | 
					        # if labels are missing. We therefore have to check whether we need to
 | 
				
			||||||
 | 
					        # expand our model output.
 | 
				
			||||||
 | 
					        self._resize()
 | 
				
			||||||
 | 
					        model = self.model.predict(docs)
 | 
				
			||||||
 | 
					        while not batch.is_done:
 | 
				
			||||||
 | 
					            states = batch.get_unfinished_states()
 | 
				
			||||||
 | 
					            if not states:
 | 
				
			||||||
 | 
					                break
 | 
				
			||||||
 | 
					            scores = model.predict(states)
 | 
				
			||||||
 | 
					            batch.advance(scores)
 | 
				
			||||||
 | 
					        model.clear_memory()
 | 
				
			||||||
 | 
					        del model
 | 
				
			||||||
 | 
					        return list(batch)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    cdef void _parseC(self, StateC** states,
 | 
					    cdef void _parseC(self, StateC** states,
 | 
				
			||||||
            WeightsC weights, SizesC sizes) nogil:
 | 
					            WeightsC weights, SizesC sizes) nogil:
 | 
				
			||||||
        cdef int i, j
 | 
					        cdef int i, j
 | 
				
			||||||
| 
						 | 
					@ -227,14 +269,13 @@ cdef class Parser(TrainablePipe):
 | 
				
			||||||
            unfinished.clear()
 | 
					            unfinished.clear()
 | 
				
			||||||
        free_activations(&activations)
 | 
					        free_activations(&activations)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def set_annotations(self, docs, states):
 | 
					    def set_annotations(self, docs, states_or_beams):
 | 
				
			||||||
        cdef StateClass state
 | 
					        cdef StateClass state
 | 
				
			||||||
 | 
					        cdef Beam beam
 | 
				
			||||||
        cdef Doc doc
 | 
					        cdef Doc doc
 | 
				
			||||||
 | 
					        states = _beam_utils.collect_states(states_or_beams, docs)
 | 
				
			||||||
        for i, (state, doc) in enumerate(zip(states, docs)):
 | 
					        for i, (state, doc) in enumerate(zip(states, docs)):
 | 
				
			||||||
            self.moves.finalize_state(state.c)
 | 
					            self.moves.set_annotations(state, doc)
 | 
				
			||||||
            for j in range(doc.length):
 | 
					 | 
				
			||||||
                doc.c[j] = state.c._sent[j]
 | 
					 | 
				
			||||||
            self.moves.finalize_doc(doc)
 | 
					 | 
				
			||||||
            for hook in self.postprocesses:
 | 
					            for hook in self.postprocesses:
 | 
				
			||||||
                hook(doc)
 | 
					                hook(doc)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -265,7 +306,6 @@ cdef class Parser(TrainablePipe):
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                action = self.moves.c[guess]
 | 
					                action = self.moves.c[guess]
 | 
				
			||||||
                action.do(states[i], action.label)
 | 
					                action.do(states[i], action.label)
 | 
				
			||||||
                states[i].push_hist(guess)
 | 
					 | 
				
			||||||
        free(is_valid)
 | 
					        free(is_valid)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def update(self, examples, *, drop=0., set_annotations=False, sgd=None, losses=None):
 | 
					    def update(self, examples, *, drop=0., set_annotations=False, sgd=None, losses=None):
 | 
				
			||||||
| 
						 | 
					@ -276,13 +316,23 @@ cdef class Parser(TrainablePipe):
 | 
				
			||||||
        validate_examples(examples, "Parser.update")
 | 
					        validate_examples(examples, "Parser.update")
 | 
				
			||||||
        for multitask in self._multitasks:
 | 
					        for multitask in self._multitasks:
 | 
				
			||||||
            multitask.update(examples, drop=drop, sgd=sgd)
 | 
					            multitask.update(examples, drop=drop, sgd=sgd)
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
        n_examples = len([eg for eg in examples if self.moves.has_gold(eg)])
 | 
					        n_examples = len([eg for eg in examples if self.moves.has_gold(eg)])
 | 
				
			||||||
        if n_examples == 0:
 | 
					        if n_examples == 0:
 | 
				
			||||||
            return losses
 | 
					            return losses
 | 
				
			||||||
        set_dropout_rate(self.model, drop)
 | 
					        set_dropout_rate(self.model, drop)
 | 
				
			||||||
        # Prepare the stepwise model, and get the callback for finishing the batch
 | 
					        # The probability we use beam update, instead of falling back to
 | 
				
			||||||
        model, backprop_tok2vec = self.model.begin_update(
 | 
					        # a greedy update
 | 
				
			||||||
            [eg.predicted for eg in examples])
 | 
					        beam_update_prob = self.cfg["beam_update_prob"]
 | 
				
			||||||
 | 
					        if self.cfg['beam_width'] >= 2 and numpy.random.random() < beam_update_prob:
 | 
				
			||||||
 | 
					            return self.update_beam(
 | 
				
			||||||
 | 
					                examples,
 | 
				
			||||||
 | 
					                beam_width=self.cfg["beam_width"],
 | 
				
			||||||
 | 
					                set_annotations=set_annotations,
 | 
				
			||||||
 | 
					                sgd=sgd,
 | 
				
			||||||
 | 
					                losses=losses,
 | 
				
			||||||
 | 
					                beam_density=self.cfg["beam_density"]
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
        max_moves = self.cfg["update_with_oracle_cut_size"]
 | 
					        max_moves = self.cfg["update_with_oracle_cut_size"]
 | 
				
			||||||
        if max_moves >= 1:
 | 
					        if max_moves >= 1:
 | 
				
			||||||
            # Chop sequences into lengths of this many words, to make the
 | 
					            # Chop sequences into lengths of this many words, to make the
 | 
				
			||||||
| 
						 | 
					@ -296,6 +346,8 @@ cdef class Parser(TrainablePipe):
 | 
				
			||||||
            states, golds, _ = self.moves.init_gold_batch(examples)
 | 
					            states, golds, _ = self.moves.init_gold_batch(examples)
 | 
				
			||||||
        if not states:
 | 
					        if not states:
 | 
				
			||||||
            return losses
 | 
					            return losses
 | 
				
			||||||
 | 
					        model, backprop_tok2vec = self.model.begin_update([eg.x for eg in examples])
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
        all_states = list(states)
 | 
					        all_states = list(states)
 | 
				
			||||||
        states_golds = list(zip(states, golds))
 | 
					        states_golds = list(zip(states, golds))
 | 
				
			||||||
        n_moves = 0
 | 
					        n_moves = 0
 | 
				
			||||||
| 
						 | 
					@ -379,6 +431,27 @@ cdef class Parser(TrainablePipe):
 | 
				
			||||||
        del tutor
 | 
					        del tutor
 | 
				
			||||||
        return losses
 | 
					        return losses
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def update_beam(self, examples, *, beam_width,
 | 
				
			||||||
 | 
					            drop=0., sgd=None, losses=None, set_annotations=False, beam_density=0.0):
 | 
				
			||||||
 | 
					        states, golds, _ = self.moves.init_gold_batch(examples)
 | 
				
			||||||
 | 
					        if not states:
 | 
				
			||||||
 | 
					            return losses
 | 
				
			||||||
 | 
					        # Prepare the stepwise model, and get the callback for finishing the batch
 | 
				
			||||||
 | 
					        model, backprop_tok2vec = self.model.begin_update(
 | 
				
			||||||
 | 
					            [eg.predicted for eg in examples])
 | 
				
			||||||
 | 
					        loss = _beam_utils.update_beam(
 | 
				
			||||||
 | 
					            self.moves,
 | 
				
			||||||
 | 
					            states,
 | 
				
			||||||
 | 
					            golds,
 | 
				
			||||||
 | 
					            model,
 | 
				
			||||||
 | 
					            beam_width,
 | 
				
			||||||
 | 
					            beam_density=beam_density,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        losses[self.name] += loss
 | 
				
			||||||
 | 
					        backprop_tok2vec(golds)
 | 
				
			||||||
 | 
					        if sgd is not None:
 | 
				
			||||||
 | 
					            self.finish_update(sgd)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_batch_loss(self, states, golds, float[:, ::1] scores, losses):
 | 
					    def get_batch_loss(self, states, golds, float[:, ::1] scores, losses):
 | 
				
			||||||
        cdef StateClass state
 | 
					        cdef StateClass state
 | 
				
			||||||
        cdef Pool mem = Pool()
 | 
					        cdef Pool mem = Pool()
 | 
				
			||||||
| 
						 | 
					@ -396,7 +469,7 @@ cdef class Parser(TrainablePipe):
 | 
				
			||||||
        for i, (state, gold) in enumerate(zip(states, golds)):
 | 
					        for i, (state, gold) in enumerate(zip(states, golds)):
 | 
				
			||||||
            memset(is_valid, 0, self.moves.n_moves * sizeof(int))
 | 
					            memset(is_valid, 0, self.moves.n_moves * sizeof(int))
 | 
				
			||||||
            memset(costs, 0, self.moves.n_moves * sizeof(float))
 | 
					            memset(costs, 0, self.moves.n_moves * sizeof(float))
 | 
				
			||||||
            self.moves.set_costs(is_valid, costs, state, gold)
 | 
					            self.moves.set_costs(is_valid, costs, state.c, gold)
 | 
				
			||||||
            for j in range(self.moves.n_moves):
 | 
					            for j in range(self.moves.n_moves):
 | 
				
			||||||
                if costs[j] <= 0.0 and j in unseen_classes:
 | 
					                if costs[j] <= 0.0 and j in unseen_classes:
 | 
				
			||||||
                    unseen_classes.remove(j)
 | 
					                    unseen_classes.remove(j)
 | 
				
			||||||
| 
						 | 
					@ -539,7 +612,6 @@ cdef class Parser(TrainablePipe):
 | 
				
			||||||
                for clas in oracle_actions[i:i+max_length]:
 | 
					                for clas in oracle_actions[i:i+max_length]:
 | 
				
			||||||
                    action = self.moves.c[clas]
 | 
					                    action = self.moves.c[clas]
 | 
				
			||||||
                    action.do(state.c, action.label)
 | 
					                    action.do(state.c, action.label)
 | 
				
			||||||
                    state.c.push_hist(action.clas)
 | 
					 | 
				
			||||||
                    if state.is_final():
 | 
					                    if state.is_final():
 | 
				
			||||||
                        break
 | 
					                        break
 | 
				
			||||||
                if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
 | 
					                if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -7,6 +7,7 @@ from spacy.tokens import Doc
 | 
				
			||||||
from spacy.pipeline._parser_internals.nonproj import projectivize
 | 
					from spacy.pipeline._parser_internals.nonproj import projectivize
 | 
				
			||||||
from spacy.pipeline._parser_internals.arc_eager import ArcEager
 | 
					from spacy.pipeline._parser_internals.arc_eager import ArcEager
 | 
				
			||||||
from spacy.pipeline.dep_parser import DEFAULT_PARSER_MODEL
 | 
					from spacy.pipeline.dep_parser import DEFAULT_PARSER_MODEL
 | 
				
			||||||
 | 
					from spacy.pipeline._parser_internals.stateclass import StateClass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_sequence_costs(M, words, heads, deps, transitions):
 | 
					def get_sequence_costs(M, words, heads, deps, transitions):
 | 
				
			||||||
| 
						 | 
					@ -47,15 +48,24 @@ def test_oracle_four_words(arc_eager, vocab):
 | 
				
			||||||
    for dep in deps:
 | 
					    for dep in deps:
 | 
				
			||||||
        arc_eager.add_action(2, dep)  # Left
 | 
					        arc_eager.add_action(2, dep)  # Left
 | 
				
			||||||
        arc_eager.add_action(3, dep)  # Right
 | 
					        arc_eager.add_action(3, dep)  # Right
 | 
				
			||||||
    actions = ["L-left", "B-ROOT", "L-left"]
 | 
					    actions = ["S", "L-left", "B-ROOT", "S", "D", "S", "L-left", "S", "D"]
 | 
				
			||||||
    state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions)
 | 
					    state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions)
 | 
				
			||||||
 | 
					    expected_gold = [
 | 
				
			||||||
 | 
					        ["S"],
 | 
				
			||||||
 | 
					        ["B-ROOT", "L-left"],
 | 
				
			||||||
 | 
					        ["B-ROOT"],
 | 
				
			||||||
 | 
					        ["S"],
 | 
				
			||||||
 | 
					        ["D"],
 | 
				
			||||||
 | 
					        ["S"],
 | 
				
			||||||
 | 
					        ["L-left"],
 | 
				
			||||||
 | 
					        ["S"],
 | 
				
			||||||
 | 
					        ["D"]
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
    assert state.is_final()
 | 
					    assert state.is_final()
 | 
				
			||||||
    for i, state_costs in enumerate(cost_history):
 | 
					    for i, state_costs in enumerate(cost_history):
 | 
				
			||||||
        # Check gold moves is 0 cost
 | 
					        # Check gold moves is 0 cost
 | 
				
			||||||
        assert state_costs[actions[i]] == 0.0, actions[i]
 | 
					        golds = [act for act, cost in state_costs.items() if cost < 1]
 | 
				
			||||||
        for other_action, cost in state_costs.items():
 | 
					        assert golds == expected_gold[i], (i, golds, expected_gold[i])
 | 
				
			||||||
            if other_action != actions[i]:
 | 
					 | 
				
			||||||
                assert cost >= 1, (i, other_action)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
annot_tuples = [
 | 
					annot_tuples = [
 | 
				
			||||||
| 
						 | 
					@ -169,12 +179,15 @@ def test_oracle_dev_sentence(vocab, arc_eager):
 | 
				
			||||||
        . punct said
 | 
					        . punct said
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    expected_transitions = [
 | 
					    expected_transitions = [
 | 
				
			||||||
 | 
					        "S",  # Shift "Rolls-Royce"
 | 
				
			||||||
        "S",  # Shift 'Motor'
 | 
					        "S",  # Shift 'Motor'
 | 
				
			||||||
        "S",  # Shift 'Cars'
 | 
					        "S",  # Shift 'Cars'
 | 
				
			||||||
        "L-nn",  # Attach 'Cars' to 'Inc.'
 | 
					        "L-nn",  # Attach 'Cars' to 'Inc.'
 | 
				
			||||||
        "L-nn",  # Attach 'Motor' to 'Inc.'
 | 
					        "L-nn",  # Attach 'Motor' to 'Inc.'
 | 
				
			||||||
        "L-nn",  # Attach 'Rolls-Royce' to 'Inc.', force shift
 | 
					        "L-nn",  # Attach 'Rolls-Royce' to 'Inc.'
 | 
				
			||||||
 | 
					        "S",     # Shift "Inc."
 | 
				
			||||||
        "L-nsubj",  # Attach 'Inc.' to 'said'
 | 
					        "L-nsubj",  # Attach 'Inc.' to 'said'
 | 
				
			||||||
 | 
					        "S",        # Shift 'said'
 | 
				
			||||||
        "S",  # Shift 'it'
 | 
					        "S",  # Shift 'it'
 | 
				
			||||||
        "L-nsubj",  # Attach 'it.' to 'expects'
 | 
					        "L-nsubj",  # Attach 'it.' to 'expects'
 | 
				
			||||||
        "R-ccomp",  # Attach 'expects' to 'said'
 | 
					        "R-ccomp",  # Attach 'expects' to 'said'
 | 
				
			||||||
| 
						 | 
					@ -204,6 +217,8 @@ def test_oracle_dev_sentence(vocab, arc_eager):
 | 
				
			||||||
        "D",  # Reduce "steady"
 | 
					        "D",  # Reduce "steady"
 | 
				
			||||||
        "D",  # Reduce "expects"
 | 
					        "D",  # Reduce "expects"
 | 
				
			||||||
        "R-punct",  # Attach "." to "said"
 | 
					        "R-punct",  # Attach "." to "said"
 | 
				
			||||||
 | 
					        "D",  # Reduce "."
 | 
				
			||||||
 | 
					        "D",  # Reduce "said"
 | 
				
			||||||
    ]
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    gold_words = []
 | 
					    gold_words = []
 | 
				
			||||||
| 
						 | 
					@ -221,10 +236,40 @@ def test_oracle_dev_sentence(vocab, arc_eager):
 | 
				
			||||||
    for dep in gold_deps:
 | 
					    for dep in gold_deps:
 | 
				
			||||||
        arc_eager.add_action(2, dep)  # Left
 | 
					        arc_eager.add_action(2, dep)  # Left
 | 
				
			||||||
        arc_eager.add_action(3, dep)  # Right
 | 
					        arc_eager.add_action(3, dep)  # Right
 | 
				
			||||||
 | 
					 | 
				
			||||||
    doc = Doc(Vocab(), words=gold_words)
 | 
					    doc = Doc(Vocab(), words=gold_words)
 | 
				
			||||||
    example = Example.from_dict(doc, {"heads": gold_heads, "deps": gold_deps})
 | 
					    example = Example.from_dict(doc, {"heads": gold_heads, "deps": gold_deps})
 | 
				
			||||||
 | 
					    ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False)
 | 
				
			||||||
    ae_oracle_actions = arc_eager.get_oracle_sequence(example)
 | 
					 | 
				
			||||||
    ae_oracle_actions = [arc_eager.get_class_name(i) for i in ae_oracle_actions]
 | 
					    ae_oracle_actions = [arc_eager.get_class_name(i) for i in ae_oracle_actions]
 | 
				
			||||||
    assert ae_oracle_actions == expected_transitions
 | 
					    assert ae_oracle_actions == expected_transitions
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_oracle_bad_tokenization(vocab, arc_eager):
 | 
				
			||||||
 | 
					    words_deps_heads = """
 | 
				
			||||||
 | 
					        [catalase] dep is
 | 
				
			||||||
 | 
					        : punct is
 | 
				
			||||||
 | 
					        that nsubj is
 | 
				
			||||||
 | 
					        is root is
 | 
				
			||||||
 | 
					        bad comp is
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					    gold_words = []
 | 
				
			||||||
 | 
					    gold_deps = []
 | 
				
			||||||
 | 
					    gold_heads = []
 | 
				
			||||||
 | 
					    for line in words_deps_heads.strip().split("\n"):
 | 
				
			||||||
 | 
					        line = line.strip()
 | 
				
			||||||
 | 
					        if not line:
 | 
				
			||||||
 | 
					            continue
 | 
				
			||||||
 | 
					        word, dep, head = line.split()
 | 
				
			||||||
 | 
					        gold_words.append(word)
 | 
				
			||||||
 | 
					        gold_deps.append(dep)
 | 
				
			||||||
 | 
					        gold_heads.append(head)
 | 
				
			||||||
 | 
					    gold_heads = [gold_words.index(head) for head in gold_heads]
 | 
				
			||||||
 | 
					    for dep in gold_deps:
 | 
				
			||||||
 | 
					        arc_eager.add_action(2, dep)  # Left
 | 
				
			||||||
 | 
					        arc_eager.add_action(3, dep)  # Right
 | 
				
			||||||
 | 
					    reference = Doc(Vocab(), words=gold_words, deps=gold_deps, heads=gold_heads)
 | 
				
			||||||
 | 
					    predicted = Doc(reference.vocab, words=["[", "catalase", "]", ":", "that", "is", "bad"])
 | 
				
			||||||
 | 
					    example = Example(predicted=predicted, reference=reference)
 | 
				
			||||||
 | 
					    ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False)
 | 
				
			||||||
 | 
					    ae_oracle_actions = [arc_eager.get_class_name(i) for i in ae_oracle_actions]
 | 
				
			||||||
 | 
					    assert ae_oracle_actions
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -54,7 +54,7 @@ def tsys(vocab, entity_types):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_get_oracle_moves(tsys, doc, entity_annots):
 | 
					def test_get_oracle_moves(tsys, doc, entity_annots):
 | 
				
			||||||
    example = Example.from_dict(doc, {"entities": entity_annots})
 | 
					    example = Example.from_dict(doc, {"entities": entity_annots})
 | 
				
			||||||
    act_classes = tsys.get_oracle_sequence(example)
 | 
					    act_classes = tsys.get_oracle_sequence(example, _debug=False)
 | 
				
			||||||
    names = [tsys.get_class_name(act) for act in act_classes]
 | 
					    names = [tsys.get_class_name(act) for act in act_classes]
 | 
				
			||||||
    assert names == ["U-PERSON", "O", "O", "B-GPE", "L-GPE", "O"]
 | 
					    assert names == ["U-PERSON", "O", "O", "B-GPE", "L-GPE", "O"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,144 @@
 | 
				
			||||||
 | 
					# coding: utf8
 | 
				
			||||||
 | 
					from __future__ import unicode_literals
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import pytest
 | 
				
			||||||
 | 
					import hypothesis
 | 
				
			||||||
 | 
					import hypothesis.strategies
 | 
				
			||||||
 | 
					import numpy
 | 
				
			||||||
 | 
					from spacy.vocab import Vocab
 | 
				
			||||||
 | 
					from spacy.language import Language
 | 
				
			||||||
 | 
					from spacy.pipeline import DependencyParser
 | 
				
			||||||
 | 
					from spacy.pipeline._parser_internals.arc_eager import ArcEager
 | 
				
			||||||
 | 
					from spacy.tokens import Doc
 | 
				
			||||||
 | 
					from spacy.pipeline._parser_internals._beam_utils import BeamBatch
 | 
				
			||||||
 | 
					from spacy.pipeline._parser_internals.stateclass import StateClass
 | 
				
			||||||
 | 
					from spacy.training import Example
 | 
				
			||||||
 | 
					from thinc.tests.strategies import ndarrays_of_shape
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.fixture(scope="module")
 | 
				
			||||||
 | 
					def vocab():
 | 
				
			||||||
 | 
					    return Vocab()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.fixture(scope="module")
 | 
				
			||||||
 | 
					def moves(vocab):
 | 
				
			||||||
 | 
					    aeager = ArcEager(vocab.strings, {})
 | 
				
			||||||
 | 
					    aeager.add_action(0, "")
 | 
				
			||||||
 | 
					    aeager.add_action(1, "")
 | 
				
			||||||
 | 
					    aeager.add_action(2, "nsubj")
 | 
				
			||||||
 | 
					    aeager.add_action(2, "punct")
 | 
				
			||||||
 | 
					    aeager.add_action(2, "aux")
 | 
				
			||||||
 | 
					    aeager.add_action(2, "nsubjpass")
 | 
				
			||||||
 | 
					    aeager.add_action(3, "dobj")
 | 
				
			||||||
 | 
					    aeager.add_action(2, "aux")
 | 
				
			||||||
 | 
					    aeager.add_action(4, "ROOT")
 | 
				
			||||||
 | 
					    return aeager
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.fixture(scope="module")
 | 
				
			||||||
 | 
					def docs(vocab):
 | 
				
			||||||
 | 
					    return [
 | 
				
			||||||
 | 
					        Doc(
 | 
				
			||||||
 | 
					            vocab,
 | 
				
			||||||
 | 
					            words=["Rats", "bite", "things"],
 | 
				
			||||||
 | 
					            heads=[1, 1, 1],
 | 
				
			||||||
 | 
					            deps=["nsubj", "ROOT", "dobj"],
 | 
				
			||||||
 | 
					            sent_starts=[True, False, False]
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.fixture(scope="module")
 | 
				
			||||||
 | 
					def examples(docs):
 | 
				
			||||||
 | 
					    return [Example(doc, doc.copy()) for doc in docs]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.fixture
 | 
				
			||||||
 | 
					def states(docs):
 | 
				
			||||||
 | 
					    return [StateClass(doc) for doc in docs]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.fixture
 | 
				
			||||||
 | 
					def tokvecs(docs, vector_size):
 | 
				
			||||||
 | 
					    output = []
 | 
				
			||||||
 | 
					    for doc in docs:
 | 
				
			||||||
 | 
					        vec = numpy.random.uniform(-0.1, 0.1, (len(doc), vector_size))
 | 
				
			||||||
 | 
					        output.append(numpy.asarray(vec))
 | 
				
			||||||
 | 
					    return output
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.fixture(scope="module")
 | 
				
			||||||
 | 
					def batch_size(docs):
 | 
				
			||||||
 | 
					    return len(docs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.fixture(scope="module")
 | 
				
			||||||
 | 
					def beam_width():
 | 
				
			||||||
 | 
					    return 4
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.fixture(params=[0.0, 0.5, 1.0])
 | 
				
			||||||
 | 
					def beam_density(request):
 | 
				
			||||||
 | 
					    return request.param
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.fixture
 | 
				
			||||||
 | 
					def vector_size():
 | 
				
			||||||
 | 
					    return 6
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.fixture
 | 
				
			||||||
 | 
					def beam(moves, examples, beam_width):
 | 
				
			||||||
 | 
					    states, golds, _ = moves.init_gold_batch(examples)
 | 
				
			||||||
 | 
					    return BeamBatch(moves, states, golds, width=beam_width, density=0.0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.fixture
 | 
				
			||||||
 | 
					def scores(moves, batch_size, beam_width):
 | 
				
			||||||
 | 
					    return numpy.asarray(
 | 
				
			||||||
 | 
					        numpy.concatenate(
 | 
				
			||||||
 | 
					            [
 | 
				
			||||||
 | 
					                numpy.random.uniform(-0.1, 0.1, (beam_width, moves.n_moves))
 | 
				
			||||||
 | 
					                for _ in range(batch_size)
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        ), dtype="float32")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_create_beam(beam):
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_beam_advance(beam, scores):
 | 
				
			||||||
 | 
					    beam.advance(scores)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_beam_advance_too_few_scores(beam, scores):
 | 
				
			||||||
 | 
					    n_state = sum(len(beam) for beam in beam)
 | 
				
			||||||
 | 
					    scores = scores[:n_state]
 | 
				
			||||||
 | 
					    with pytest.raises(IndexError):
 | 
				
			||||||
 | 
					        beam.advance(scores[:-1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_beam_parse(examples, beam_width):
 | 
				
			||||||
 | 
					    nlp = Language()
 | 
				
			||||||
 | 
					    parser = nlp.add_pipe("beam_parser")
 | 
				
			||||||
 | 
					    parser.cfg["beam_width"] = beam_width
 | 
				
			||||||
 | 
					    parser.add_label("nsubj")
 | 
				
			||||||
 | 
					    parser.initialize(lambda: examples)
 | 
				
			||||||
 | 
					    doc = nlp.make_doc("Australia is a country")
 | 
				
			||||||
 | 
					    parser(doc)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@hypothesis.given(hyp=hypothesis.strategies.data())
 | 
				
			||||||
 | 
					def test_beam_density(moves, examples, beam_width, hyp):
 | 
				
			||||||
 | 
					    beam_density = float(hyp.draw(hypothesis.strategies.floats(0.0, 1.0, width=32)))
 | 
				
			||||||
 | 
					    states, golds, _ = moves.init_gold_batch(examples)
 | 
				
			||||||
 | 
					    beam = BeamBatch(moves, states, golds, width=beam_width, density=beam_density)
 | 
				
			||||||
 | 
					    n_state = sum(len(beam) for beam in beam)
 | 
				
			||||||
 | 
					    scores = hyp.draw(ndarrays_of_shape((n_state, moves.n_moves)))
 | 
				
			||||||
 | 
					    beam.advance(scores)
 | 
				
			||||||
 | 
					    for b in beam:
 | 
				
			||||||
 | 
					        beam_probs = b.probs
 | 
				
			||||||
 | 
					        assert b.min_density == beam_density
 | 
				
			||||||
 | 
					        assert beam_probs[-1] >= beam_probs[0] * beam_density
 | 
				
			||||||
| 
						 | 
					@ -22,6 +22,7 @@ def _parser_example(parser):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@pytest.fixture
 | 
					@pytest.fixture
 | 
				
			||||||
def parser(vocab):
 | 
					def parser(vocab):
 | 
				
			||||||
 | 
					    vocab.strings.add("ROOT")
 | 
				
			||||||
    config = {
 | 
					    config = {
 | 
				
			||||||
        "learn_tokens": False,
 | 
					        "learn_tokens": False,
 | 
				
			||||||
        "min_action_freq": 30,
 | 
					        "min_action_freq": 30,
 | 
				
			||||||
| 
						 | 
					@ -76,13 +77,16 @@ def test_sents_1_2(parser):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_sents_1_3(parser):
 | 
					def test_sents_1_3(parser):
 | 
				
			||||||
    doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
 | 
					    doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
 | 
				
			||||||
    doc[1].sent_start = True
 | 
					    doc[0].is_sent_start = True
 | 
				
			||||||
    doc[3].sent_start = True
 | 
					    doc[1].is_sent_start = True
 | 
				
			||||||
 | 
					    doc[2].is_sent_start = None
 | 
				
			||||||
 | 
					    doc[3].is_sent_start = True
 | 
				
			||||||
    doc = parser(doc)
 | 
					    doc = parser(doc)
 | 
				
			||||||
    assert len(list(doc.sents)) >= 3
 | 
					    assert len(list(doc.sents)) >= 3
 | 
				
			||||||
    doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
 | 
					    doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
 | 
				
			||||||
    doc[1].sent_start = True
 | 
					    doc[0].is_sent_start = True
 | 
				
			||||||
    doc[2].sent_start = False
 | 
					    doc[1].is_sent_start = True
 | 
				
			||||||
    doc[3].sent_start = True
 | 
					    doc[2].is_sent_start = False
 | 
				
			||||||
 | 
					    doc[3].is_sent_start = True
 | 
				
			||||||
    doc = parser(doc)
 | 
					    doc = parser(doc)
 | 
				
			||||||
    assert len(list(doc.sents)) == 3
 | 
					    assert len(list(doc.sents)) == 3
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										74
									
								
								spacy/tests/parser/test_state.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										74
									
								
								spacy/tests/parser/test_state.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,74 @@
 | 
				
			||||||
 | 
					import pytest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from spacy.tokens.doc import Doc
 | 
				
			||||||
 | 
					from spacy.vocab import Vocab
 | 
				
			||||||
 | 
					from spacy.pipeline._parser_internals.stateclass import StateClass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.fixture
 | 
				
			||||||
 | 
					def vocab():
 | 
				
			||||||
 | 
					    return Vocab()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.fixture
 | 
				
			||||||
 | 
					def doc(vocab):
 | 
				
			||||||
 | 
					    return Doc(vocab, words=["a", "b", "c", "d"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_init_state(doc):
 | 
				
			||||||
 | 
					    state = StateClass(doc)
 | 
				
			||||||
 | 
					    assert state.stack == []
 | 
				
			||||||
 | 
					    assert state.queue == list(range(len(doc)))
 | 
				
			||||||
 | 
					    assert not state.is_final()
 | 
				
			||||||
 | 
					    assert state.buffer_length() == 4
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_push_pop(doc):
 | 
				
			||||||
 | 
					    state = StateClass(doc)
 | 
				
			||||||
 | 
					    state.push()
 | 
				
			||||||
 | 
					    assert state.buffer_length() == 3
 | 
				
			||||||
 | 
					    assert state.stack == [0]
 | 
				
			||||||
 | 
					    assert 0 not in state.queue
 | 
				
			||||||
 | 
					    state.push()
 | 
				
			||||||
 | 
					    assert state.stack == [1, 0]
 | 
				
			||||||
 | 
					    assert 1 not in state.queue
 | 
				
			||||||
 | 
					    assert state.buffer_length() == 2
 | 
				
			||||||
 | 
					    state.pop()
 | 
				
			||||||
 | 
					    assert state.stack == [0]
 | 
				
			||||||
 | 
					    assert 1 not in state.queue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_stack_depth(doc):
 | 
				
			||||||
 | 
					    state = StateClass(doc)
 | 
				
			||||||
 | 
					    assert state.stack_depth() == 0
 | 
				
			||||||
 | 
					    assert state.buffer_length() == len(doc)
 | 
				
			||||||
 | 
					    state.push()
 | 
				
			||||||
 | 
					    assert state.buffer_length() == 3
 | 
				
			||||||
 | 
					    assert state.stack_depth() == 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_H(doc):
 | 
				
			||||||
 | 
					    state = StateClass(doc)
 | 
				
			||||||
 | 
					    assert state.H(0) == -1
 | 
				
			||||||
 | 
					    state.add_arc(1, 0, 0)
 | 
				
			||||||
 | 
					    assert state.arcs == [{"head": 1, "child": 0, "label": 0}]
 | 
				
			||||||
 | 
					    assert state.H(0) == 1
 | 
				
			||||||
 | 
					    state.add_arc(3, 1, 0)
 | 
				
			||||||
 | 
					    assert state.H(1) == 3
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_L(doc):
 | 
				
			||||||
 | 
					    state = StateClass(doc)
 | 
				
			||||||
 | 
					    assert state.L(2, 1) == -1
 | 
				
			||||||
 | 
					    state.add_arc(2, 1, 0)
 | 
				
			||||||
 | 
					    assert state.arcs == [{"head": 2, "child": 1, "label": 0}]
 | 
				
			||||||
 | 
					    assert state.L(2, 1) == 1
 | 
				
			||||||
 | 
					    state.add_arc(2, 0, 0)
 | 
				
			||||||
 | 
					    assert state.L(2, 1) == 0
 | 
				
			||||||
 | 
					    assert state.n_L(2) == 2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_R(doc):
 | 
				
			||||||
 | 
					    state = StateClass(doc)
 | 
				
			||||||
 | 
					    assert state.R(0, 1) == -1
 | 
				
			||||||
 | 
					    state.add_arc(0, 1, 0)
 | 
				
			||||||
 | 
					    assert state.arcs == [{"head": 0, "child": 1, "label": 0}]
 | 
				
			||||||
 | 
					    assert state.R(0, 1) == 1
 | 
				
			||||||
 | 
					    state.add_arc(0, 2, 0)
 | 
				
			||||||
 | 
					    assert state.R(0, 1) == 2
 | 
				
			||||||
 | 
					    assert state.n_R(0) == 2
 | 
				
			||||||
| 
						 | 
					@ -122,7 +122,8 @@ def test_issue4042_bug2():
 | 
				
			||||||
    assert "SOME_LABEL" in ner1.labels
 | 
					    assert "SOME_LABEL" in ner1.labels
 | 
				
			||||||
    apple_ent = Span(doc1, 5, 6, label="MY_ORG")
 | 
					    apple_ent = Span(doc1, 5, 6, label="MY_ORG")
 | 
				
			||||||
    doc1.ents = list(doc1.ents) + [apple_ent]
 | 
					    doc1.ents = list(doc1.ents) + [apple_ent]
 | 
				
			||||||
    # reapply the NER - at this point it should resize itself
 | 
					    # Add the label explicitly. Previously we didn't require this.
 | 
				
			||||||
 | 
					    ner1.add_label("MY_ORG")
 | 
				
			||||||
    ner1(doc1)
 | 
					    ner1(doc1)
 | 
				
			||||||
    assert len(ner1.labels) == 2
 | 
					    assert len(ner1.labels) == 2
 | 
				
			||||||
    assert "SOME_LABEL" in ner1.labels
 | 
					    assert "SOME_LABEL" in ner1.labels
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -22,6 +22,9 @@ def parser(en_vocab):
 | 
				
			||||||
        "learn_tokens": False,
 | 
					        "learn_tokens": False,
 | 
				
			||||||
        "min_action_freq": 30,
 | 
					        "min_action_freq": 30,
 | 
				
			||||||
        "update_with_oracle_cut_size": 100,
 | 
					        "update_with_oracle_cut_size": 100,
 | 
				
			||||||
 | 
					        "beam_width": 1,
 | 
				
			||||||
 | 
					        "beam_update_prob": 1.0,
 | 
				
			||||||
 | 
					        "beam_density": 0.0
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    cfg = {"model": DEFAULT_PARSER_MODEL}
 | 
					    cfg = {"model": DEFAULT_PARSER_MODEL}
 | 
				
			||||||
    model = registry.resolve(cfg, validate=True)["model"]
 | 
					    model = registry.resolve(cfg, validate=True)["model"]
 | 
				
			||||||
| 
						 | 
					@ -36,6 +39,9 @@ def blank_parser(en_vocab):
 | 
				
			||||||
        "learn_tokens": False,
 | 
					        "learn_tokens": False,
 | 
				
			||||||
        "min_action_freq": 30,
 | 
					        "min_action_freq": 30,
 | 
				
			||||||
        "update_with_oracle_cut_size": 100,
 | 
					        "update_with_oracle_cut_size": 100,
 | 
				
			||||||
 | 
					        "beam_width": 1,
 | 
				
			||||||
 | 
					        "beam_update_prob": 1.0,
 | 
				
			||||||
 | 
					        "beam_density": 0.0
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    cfg = {"model": DEFAULT_PARSER_MODEL}
 | 
					    cfg = {"model": DEFAULT_PARSER_MODEL}
 | 
				
			||||||
    model = registry.resolve(cfg, validate=True)["model"]
 | 
					    model = registry.resolve(cfg, validate=True)["model"]
 | 
				
			||||||
| 
						 | 
					@ -58,6 +64,9 @@ def test_serialize_parser_roundtrip_bytes(en_vocab, Parser):
 | 
				
			||||||
        "learn_tokens": False,
 | 
					        "learn_tokens": False,
 | 
				
			||||||
        "min_action_freq": 0,
 | 
					        "min_action_freq": 0,
 | 
				
			||||||
        "update_with_oracle_cut_size": 100,
 | 
					        "update_with_oracle_cut_size": 100,
 | 
				
			||||||
 | 
					        "beam_width": 1,
 | 
				
			||||||
 | 
					        "beam_update_prob": 1.0,
 | 
				
			||||||
 | 
					        "beam_density": 0.0
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    cfg = {"model": DEFAULT_PARSER_MODEL}
 | 
					    cfg = {"model": DEFAULT_PARSER_MODEL}
 | 
				
			||||||
    model = registry.resolve(cfg, validate=True)["model"]
 | 
					    model = registry.resolve(cfg, validate=True)["model"]
 | 
				
			||||||
| 
						 | 
					@ -79,6 +88,9 @@ def test_serialize_parser_strings(Parser):
 | 
				
			||||||
        "learn_tokens": False,
 | 
					        "learn_tokens": False,
 | 
				
			||||||
        "min_action_freq": 0,
 | 
					        "min_action_freq": 0,
 | 
				
			||||||
        "update_with_oracle_cut_size": 100,
 | 
					        "update_with_oracle_cut_size": 100,
 | 
				
			||||||
 | 
					        "beam_width": 1,
 | 
				
			||||||
 | 
					        "beam_update_prob": 1.0,
 | 
				
			||||||
 | 
					        "beam_density": 0.0
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    cfg = {"model": DEFAULT_PARSER_MODEL}
 | 
					    cfg = {"model": DEFAULT_PARSER_MODEL}
 | 
				
			||||||
    model = registry.resolve(cfg, validate=True)["model"]
 | 
					    model = registry.resolve(cfg, validate=True)["model"]
 | 
				
			||||||
| 
						 | 
					@ -98,6 +110,9 @@ def test_serialize_parser_roundtrip_disk(en_vocab, Parser):
 | 
				
			||||||
        "learn_tokens": False,
 | 
					        "learn_tokens": False,
 | 
				
			||||||
        "min_action_freq": 0,
 | 
					        "min_action_freq": 0,
 | 
				
			||||||
        "update_with_oracle_cut_size": 100,
 | 
					        "update_with_oracle_cut_size": 100,
 | 
				
			||||||
 | 
					        "beam_width": 1,
 | 
				
			||||||
 | 
					        "beam_update_prob": 1.0,
 | 
				
			||||||
 | 
					        "beam_density": 0.0
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    cfg = {"model": DEFAULT_PARSER_MODEL}
 | 
					    cfg = {"model": DEFAULT_PARSER_MODEL}
 | 
				
			||||||
    model = registry.resolve(cfg, validate=True)["model"]
 | 
					    model = registry.resolve(cfg, validate=True)["model"]
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -191,6 +191,24 @@ cdef class Example:
 | 
				
			||||||
                    aligned_deps[cand_i] = deps[gold_i]
 | 
					                    aligned_deps[cand_i] = deps[gold_i]
 | 
				
			||||||
        return aligned_heads, aligned_deps
 | 
					        return aligned_heads, aligned_deps
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_aligned_sent_starts(self):
 | 
				
			||||||
 | 
					        """Get list of SENT_START attributes aligned to the predicted tokenization.
 | 
				
			||||||
 | 
					        If the reference has not sentence starts, return a list of None values.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        The aligned sentence starts use the get_aligned_spans method, rather
 | 
				
			||||||
 | 
					        than aligning the list of tags, so that it handles cases where a mistaken
 | 
				
			||||||
 | 
					        tokenization starts the sentence.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if self.y.has_annotation("SENT_START"):
 | 
				
			||||||
 | 
					            align = self.alignment.y2x
 | 
				
			||||||
 | 
					            sent_starts = [False] * len(self.x)
 | 
				
			||||||
 | 
					            for y_sent in self.y.sents:
 | 
				
			||||||
 | 
					                x_start = int(align[y_sent.start].dataXd[0])
 | 
				
			||||||
 | 
					                sent_starts[x_start] = True
 | 
				
			||||||
 | 
					            return sent_starts
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return [None] * len(self.x)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_aligned_spans_x2y(self, x_spans):
 | 
					    def get_aligned_spans_x2y(self, x_spans):
 | 
				
			||||||
        return self._get_aligned_spans(self.y, x_spans, self.alignment.x2y)
 | 
					        return self._get_aligned_spans(self.y, x_spans, self.alignment.x2y)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user