mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 09:57:26 +03:00 
			
		
		
		
	Move new parser to nn_parser.pyx, and restore old parser, to make tests pass.
This commit is contained in:
		
							parent
							
								
									f8c02b4341
								
							
						
					
					
						commit
						5cac951a16
					
				| 
						 | 
					@ -10,6 +10,7 @@ cimport numpy as np
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .tokens.doc cimport Doc
 | 
					from .tokens.doc cimport Doc
 | 
				
			||||||
from .syntax.parser cimport Parser
 | 
					from .syntax.parser cimport Parser
 | 
				
			||||||
 | 
					from .syntax.parser import get_templates as get_feature_templates
 | 
				
			||||||
#from .syntax.beam_parser cimport BeamParser
 | 
					#from .syntax.beam_parser cimport BeamParser
 | 
				
			||||||
from .syntax.ner cimport BiluoPushDown
 | 
					from .syntax.ner cimport BiluoPushDown
 | 
				
			||||||
from .syntax.arc_eager cimport ArcEager
 | 
					from .syntax.arc_eager cimport ArcEager
 | 
				
			||||||
| 
						 | 
					@ -113,6 +114,7 @@ cdef class EntityRecognizer(Parser):
 | 
				
			||||||
    Annotate named entities on Doc objects.
 | 
					    Annotate named entities on Doc objects.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    TransitionSystem = BiluoPushDown
 | 
					    TransitionSystem = BiluoPushDown
 | 
				
			||||||
 | 
					    feature_templates = get_feature_templates('ner')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def add_label(self, label):
 | 
					    def add_label(self, label):
 | 
				
			||||||
        Parser.add_label(self, label)
 | 
					        Parser.add_label(self, label)
 | 
				
			||||||
| 
						 | 
					@ -141,6 +143,7 @@ cdef class EntityRecognizer(Parser):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef class DependencyParser(Parser):
 | 
					cdef class DependencyParser(Parser):
 | 
				
			||||||
    TransitionSystem = ArcEager
 | 
					    TransitionSystem = ArcEager
 | 
				
			||||||
 | 
					    feature_templates = get_feature_templates('basic')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def add_label(self, label):
 | 
					    def add_label(self, label):
 | 
				
			||||||
        Parser.add_label(self, label)
 | 
					        Parser.add_label(self, label)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										18
									
								
								spacy/syntax/nn_parser.pxd
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								spacy/syntax/nn_parser.pxd
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,18 @@
 | 
				
			||||||
 | 
					from thinc.typedefs cimport atom_t
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .stateclass cimport StateClass
 | 
				
			||||||
 | 
					from .arc_eager cimport TransitionSystem
 | 
				
			||||||
 | 
					from ..vocab cimport Vocab
 | 
				
			||||||
 | 
					from ..tokens.doc cimport Doc
 | 
				
			||||||
 | 
					from ..structs cimport TokenC
 | 
				
			||||||
 | 
					from ._state cimport StateC
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cdef class Parser:
 | 
				
			||||||
 | 
					    cdef readonly Vocab vocab
 | 
				
			||||||
 | 
					    cdef readonly object model
 | 
				
			||||||
 | 
					    cdef readonly TransitionSystem moves
 | 
				
			||||||
 | 
					    cdef readonly object cfg
 | 
				
			||||||
 | 
					    cdef public object feature_maps
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #cdef int parseC(self, TokenC* tokens, int length, int nr_feat) nogil
 | 
				
			||||||
							
								
								
									
										677
									
								
								spacy/syntax/nn_parser.pyx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										677
									
								
								spacy/syntax/nn_parser.pyx
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,677 @@
 | 
				
			||||||
 | 
					# cython: infer_types=True
 | 
				
			||||||
 | 
					# cython: profile=True
 | 
				
			||||||
 | 
					# coding: utf-8
 | 
				
			||||||
 | 
					from __future__ import unicode_literals, print_function
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from collections import Counter
 | 
				
			||||||
 | 
					import ujson
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from libc.math cimport exp
 | 
				
			||||||
 | 
					cimport cython
 | 
				
			||||||
 | 
					cimport cython.parallel
 | 
				
			||||||
 | 
					import cytoolz
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import numpy.random
 | 
				
			||||||
 | 
					cimport numpy as np
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from cpython.ref cimport PyObject, Py_INCREF, Py_XDECREF
 | 
				
			||||||
 | 
					from cpython.exc cimport PyErr_CheckSignals
 | 
				
			||||||
 | 
					from libc.stdint cimport uint32_t, uint64_t
 | 
				
			||||||
 | 
					from libc.string cimport memset, memcpy
 | 
				
			||||||
 | 
					from libc.stdlib cimport malloc, calloc, free
 | 
				
			||||||
 | 
					from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t
 | 
				
			||||||
 | 
					from thinc.linear.avgtron cimport AveragedPerceptron
 | 
				
			||||||
 | 
					from thinc.linalg cimport VecVec
 | 
				
			||||||
 | 
					from thinc.structs cimport SparseArrayC, FeatureC, ExampleC
 | 
				
			||||||
 | 
					from thinc.extra.eg cimport Example
 | 
				
			||||||
 | 
					from cymem.cymem cimport Pool, Address
 | 
				
			||||||
 | 
					from murmurhash.mrmr cimport hash64
 | 
				
			||||||
 | 
					from preshed.maps cimport MapStruct
 | 
				
			||||||
 | 
					from preshed.maps cimport map_get
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from thinc.api import layerize, chain
 | 
				
			||||||
 | 
					from thinc.neural import BatchNorm, Model, Affine, ELU, ReLu, Maxout
 | 
				
			||||||
 | 
					from thinc.neural.ops import NumpyOps
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ..util import get_cuda_stream
 | 
				
			||||||
 | 
					from .._ml import zero_init, PrecomputableAffine, PrecomputableMaxouts
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from . import _parse_features
 | 
				
			||||||
 | 
					from ._parse_features cimport CONTEXT_SIZE
 | 
				
			||||||
 | 
					from ._parse_features cimport fill_context
 | 
				
			||||||
 | 
					from .stateclass cimport StateClass
 | 
				
			||||||
 | 
					from ._state cimport StateC
 | 
				
			||||||
 | 
					from .nonproj import PseudoProjectivity
 | 
				
			||||||
 | 
					from .transition_system import OracleError
 | 
				
			||||||
 | 
					from .transition_system cimport TransitionSystem, Transition
 | 
				
			||||||
 | 
					from ..structs cimport TokenC
 | 
				
			||||||
 | 
					from ..tokens.doc cimport Doc
 | 
				
			||||||
 | 
					from ..strings cimport StringStore
 | 
				
			||||||
 | 
					from ..gold cimport GoldParse
 | 
				
			||||||
 | 
					from ..attrs cimport TAG, DEP
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_templates(*args, **kwargs):
 | 
				
			||||||
 | 
					    return []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					USE_FTRL = True
 | 
				
			||||||
 | 
					DEBUG = False
 | 
				
			||||||
 | 
					def set_debug(val):
 | 
				
			||||||
 | 
					    global DEBUG
 | 
				
			||||||
 | 
					    DEBUG = val
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_greedy_model_for_batch(batch_size, tokvecs, lower_model, cuda_stream=None,
 | 
				
			||||||
 | 
					                               drop=0.):
 | 
				
			||||||
 | 
					    '''Allow a model to be "primed" by pre-computing input features in bulk.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    This is used for the parser, where we want to take a batch of documents,
 | 
				
			||||||
 | 
					    and compute vectors for each (token, position) pair. These vectors can then
 | 
				
			||||||
 | 
					    be reused, especially for beam-search.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Let's say we're using 12 features for each state, e.g. word at start of
 | 
				
			||||||
 | 
					    buffer, three words on stack, their children, etc. In the normal arc-eager
 | 
				
			||||||
 | 
					    system, a document of length N is processed in 2*N states. This means we'll
 | 
				
			||||||
 | 
					    create 2*N*12 feature vectors --- but if we pre-compute, we only need
 | 
				
			||||||
 | 
					    N*12 vector computations. The saving for beam-search is much better:
 | 
				
			||||||
 | 
					    if we have a beam of k, we'll normally make 2*N*12*K computations --
 | 
				
			||||||
 | 
					    so we can save the factor k. This also gives a nice CPU/GPU division:
 | 
				
			||||||
 | 
					    we can do all our hard maths up front, packed into large multiplications,
 | 
				
			||||||
 | 
					    and do the hard-to-program parsing on the CPU.
 | 
				
			||||||
 | 
					    '''
 | 
				
			||||||
 | 
					    gpu_cached, bp_features = lower_model.begin_update(tokvecs, drop=drop)
 | 
				
			||||||
 | 
					    cdef np.ndarray cached
 | 
				
			||||||
 | 
					    if not isinstance(gpu_cached, numpy.ndarray):
 | 
				
			||||||
 | 
					        cached = gpu_cached.get(stream=cuda_stream)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        cached = gpu_cached
 | 
				
			||||||
 | 
					    nF = gpu_cached.shape[1]
 | 
				
			||||||
 | 
					    nO = gpu_cached.shape[2]
 | 
				
			||||||
 | 
					    nP = gpu_cached.shape[3]
 | 
				
			||||||
 | 
					    ops = lower_model.ops
 | 
				
			||||||
 | 
					    features = numpy.zeros((batch_size, nO, nP), dtype='f')
 | 
				
			||||||
 | 
					    synchronized = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(token_ids, drop=0.):
 | 
				
			||||||
 | 
					        nonlocal synchronized
 | 
				
			||||||
 | 
					        if not synchronized and cuda_stream is not None:
 | 
				
			||||||
 | 
					            cuda_stream.synchronize()
 | 
				
			||||||
 | 
					            synchronized = True
 | 
				
			||||||
 | 
					        # This is tricky, but:
 | 
				
			||||||
 | 
					        # - Input to forward on CPU
 | 
				
			||||||
 | 
					        # - Output from forward on CPU
 | 
				
			||||||
 | 
					        # - Input to backward on GPU!
 | 
				
			||||||
 | 
					        # - Output from backward on GPU
 | 
				
			||||||
 | 
					        nonlocal features
 | 
				
			||||||
 | 
					        features = features[:len(token_ids)]
 | 
				
			||||||
 | 
					        features.fill(0)
 | 
				
			||||||
 | 
					        cdef float[:, :, ::1] feats = features
 | 
				
			||||||
 | 
					        cdef int[:, ::1] ids = token_ids
 | 
				
			||||||
 | 
					        _sum_features(<float*>&feats[0,0,0],
 | 
				
			||||||
 | 
					            <float*>cached.data, &ids[0,0],
 | 
				
			||||||
 | 
					            token_ids.shape[0], nF, nO*nP)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if nP >= 2:
 | 
				
			||||||
 | 
					            best, which = ops.maxout(features)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            best = features.reshape((features.shape[0], features.shape[1]))
 | 
				
			||||||
 | 
					            which = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def backward(d_best, sgd=None):
 | 
				
			||||||
 | 
					            # This will usually be on GPU
 | 
				
			||||||
 | 
					            if isinstance(d_best, numpy.ndarray):
 | 
				
			||||||
 | 
					                d_best = ops.xp.array(d_best)
 | 
				
			||||||
 | 
					            if nP >= 2:
 | 
				
			||||||
 | 
					                d_features = ops.backprop_maxout(d_best, which, nP)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                d_features = d_best.reshape((d_best.shape[0], d_best.shape[1], 1))
 | 
				
			||||||
 | 
					            d_tokens = bp_features((d_features, token_ids), sgd)
 | 
				
			||||||
 | 
					            return d_tokens
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return best, backward
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return forward
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cdef void _sum_features(float* output,
 | 
				
			||||||
 | 
					        const float* cached, const int* token_ids, int B, int F, int O) nogil:
 | 
				
			||||||
 | 
					    cdef int idx, b, f, i
 | 
				
			||||||
 | 
					    cdef const float* feature
 | 
				
			||||||
 | 
					    for b in range(B):
 | 
				
			||||||
 | 
					        for f in range(F):
 | 
				
			||||||
 | 
					            if token_ids[f] < 0:
 | 
				
			||||||
 | 
					                continue
 | 
				
			||||||
 | 
					            idx = token_ids[f] * F * O + f*O
 | 
				
			||||||
 | 
					            feature = &cached[idx]
 | 
				
			||||||
 | 
					            for i in range(O):
 | 
				
			||||||
 | 
					                output[i] += feature[i]
 | 
				
			||||||
 | 
					        output += O
 | 
				
			||||||
 | 
					        token_ids += F
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_batch_loss(TransitionSystem moves, states, golds, float[:, ::1] scores):
 | 
				
			||||||
 | 
					    cdef StateClass state
 | 
				
			||||||
 | 
					    cdef GoldParse gold
 | 
				
			||||||
 | 
					    cdef Pool mem = Pool()
 | 
				
			||||||
 | 
					    cdef int i
 | 
				
			||||||
 | 
					    is_valid = <int*>mem.alloc(moves.n_moves, sizeof(int))
 | 
				
			||||||
 | 
					    costs = <float*>mem.alloc(moves.n_moves, sizeof(float))
 | 
				
			||||||
 | 
					    cdef np.ndarray d_scores = numpy.zeros((len(states), moves.n_moves), dtype='f',
 | 
				
			||||||
 | 
					                                           order='c')
 | 
				
			||||||
 | 
					    c_d_scores = <float*>d_scores.data
 | 
				
			||||||
 | 
					    for i, (state, gold) in enumerate(zip(states, golds)):
 | 
				
			||||||
 | 
					        memset(is_valid, 0, moves.n_moves * sizeof(int))
 | 
				
			||||||
 | 
					        memset(costs, 0, moves.n_moves * sizeof(float))
 | 
				
			||||||
 | 
					        moves.set_costs(is_valid, costs, state, gold)
 | 
				
			||||||
 | 
					        cpu_log_loss(c_d_scores, costs, is_valid, &scores[i, 0], d_scores.shape[1])
 | 
				
			||||||
 | 
					        #cpu_regression_loss(c_d_scores,
 | 
				
			||||||
 | 
					        #    costs, is_valid, &scores[i, 0], d_scores.shape[1])
 | 
				
			||||||
 | 
					        c_d_scores += d_scores.shape[1]
 | 
				
			||||||
 | 
					    return d_scores
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cdef void cpu_log_loss(float* d_scores,
 | 
				
			||||||
 | 
					        const float* costs, const int* is_valid, const float* scores,
 | 
				
			||||||
 | 
					        int O) nogil:
 | 
				
			||||||
 | 
					    """Do multi-label log loss"""
 | 
				
			||||||
 | 
					    cdef double max_, gmax, Z, gZ
 | 
				
			||||||
 | 
					    best = arg_max_if_gold(scores, costs, is_valid, O)
 | 
				
			||||||
 | 
					    guess = arg_max_if_valid(scores, is_valid, O)
 | 
				
			||||||
 | 
					    Z = 1e-10
 | 
				
			||||||
 | 
					    gZ = 1e-10
 | 
				
			||||||
 | 
					    max_ = scores[guess]
 | 
				
			||||||
 | 
					    gmax = scores[best]
 | 
				
			||||||
 | 
					    for i in range(O):
 | 
				
			||||||
 | 
					        if is_valid[i]:
 | 
				
			||||||
 | 
					            Z += exp(scores[i] - max_)
 | 
				
			||||||
 | 
					            if costs[i] <= costs[best]:
 | 
				
			||||||
 | 
					                gZ += exp(scores[i] - gmax)
 | 
				
			||||||
 | 
					    for i in range(O):
 | 
				
			||||||
 | 
					        if not is_valid[i]:
 | 
				
			||||||
 | 
					            d_scores[i] = 0.
 | 
				
			||||||
 | 
					        elif costs[i] <= costs[best]:
 | 
				
			||||||
 | 
					            d_scores[i] = (exp(scores[i]-max_) / Z) - (exp(scores[i]-gmax)/gZ)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            d_scores[i] = exp(scores[i]-max_) / Z
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cdef void cpu_regression_loss(float* d_scores,
 | 
				
			||||||
 | 
					        const float* costs, const int* is_valid, const float* scores,
 | 
				
			||||||
 | 
					        int O) nogil:
 | 
				
			||||||
 | 
					    cdef float eps = 2.
 | 
				
			||||||
 | 
					    best = arg_max_if_gold(scores, costs, is_valid, O)
 | 
				
			||||||
 | 
					    for i in range(O):
 | 
				
			||||||
 | 
					        if not is_valid[i]:
 | 
				
			||||||
 | 
					            d_scores[i] = 0.
 | 
				
			||||||
 | 
					        elif scores[i] < scores[best]:
 | 
				
			||||||
 | 
					            d_scores[i] = 0.
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            # I doubt this is correct?
 | 
				
			||||||
 | 
					            # Looking for something like Huber loss
 | 
				
			||||||
 | 
					            diff = scores[i] - -costs[i]
 | 
				
			||||||
 | 
					            if diff > eps:
 | 
				
			||||||
 | 
					                d_scores[i] = eps
 | 
				
			||||||
 | 
					            elif diff < -eps:
 | 
				
			||||||
 | 
					                d_scores[i] = -eps
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                d_scores[i] = diff
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def init_states(TransitionSystem moves, docs):
 | 
				
			||||||
 | 
					    cdef Doc doc
 | 
				
			||||||
 | 
					    cdef StateClass state
 | 
				
			||||||
 | 
					    offsets = []
 | 
				
			||||||
 | 
					    states = []
 | 
				
			||||||
 | 
					    offset = 0
 | 
				
			||||||
 | 
					    for i, doc in enumerate(docs):
 | 
				
			||||||
 | 
					        state = StateClass.init(doc.c, doc.length)
 | 
				
			||||||
 | 
					        moves.initialize_state(state.c)
 | 
				
			||||||
 | 
					        states.append(state)
 | 
				
			||||||
 | 
					        offsets.append(offset)
 | 
				
			||||||
 | 
					        offset += len(doc)
 | 
				
			||||||
 | 
					    return states, offsets
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def extract_token_ids(states, offsets=None, nF=1, nB=0, nS=2, nL=0, nR=0):
 | 
				
			||||||
 | 
					    cdef StateClass state
 | 
				
			||||||
 | 
					    cdef int n_tokens = states[0].nr_context_tokens(nF, nB, nS, nL, nR)
 | 
				
			||||||
 | 
					    ids = numpy.zeros((len(states), n_tokens), dtype='i', order='c')
 | 
				
			||||||
 | 
					    if offsets is None:
 | 
				
			||||||
 | 
					        offsets = [0] * len(states)
 | 
				
			||||||
 | 
					    for i, (state, offset) in enumerate(zip(states, offsets)):
 | 
				
			||||||
 | 
					        state.set_context_tokens(ids[i], nF, nB, nS, nL, nR)
 | 
				
			||||||
 | 
					        ids[i] += (ids[i] >= 0) * offset
 | 
				
			||||||
 | 
					    return ids
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					_n_iter = 0
 | 
				
			||||||
 | 
					@layerize
 | 
				
			||||||
 | 
					def print_mean_variance(X, drop=0.):
 | 
				
			||||||
 | 
					    global _n_iter
 | 
				
			||||||
 | 
					    _n_iter += 1
 | 
				
			||||||
 | 
					    fwd_iter = _n_iter
 | 
				
			||||||
 | 
					    means = X.mean(axis=0)
 | 
				
			||||||
 | 
					    variance = X.var(axis=0)
 | 
				
			||||||
 | 
					    print(fwd_iter, "M", ', '.join(('%.2f' % m) for m in means))
 | 
				
			||||||
 | 
					    print(fwd_iter, "V", ', '.join(('%.2f' % m) for m in variance))
 | 
				
			||||||
 | 
					    def backward(dX, sgd=None):
 | 
				
			||||||
 | 
					        means = dX.mean(axis=0)
 | 
				
			||||||
 | 
					        variance = dX.var(axis=0)
 | 
				
			||||||
 | 
					        print(fwd_iter, "dM", ', '.join(('%.2f' % m) for m in means))
 | 
				
			||||||
 | 
					        print(fwd_iter, "dV", ', '.join(('%.2f' % m) for m in variance))
 | 
				
			||||||
 | 
					    return X, backward
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cdef class Parser:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Base class of the DependencyParser and EntityRecognizer.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def load(cls, path, Vocab vocab, TransitionSystem=None, require=False, **cfg):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Load the statistical model from the supplied path.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Arguments:
 | 
				
			||||||
 | 
					            path (Path):
 | 
				
			||||||
 | 
					                The path to load from.
 | 
				
			||||||
 | 
					            vocab (Vocab):
 | 
				
			||||||
 | 
					                The vocabulary. Must be shared by the documents to be processed.
 | 
				
			||||||
 | 
					            require (bool):
 | 
				
			||||||
 | 
					                Whether to raise an error if the files are not found.
 | 
				
			||||||
 | 
					        Returns (Parser):
 | 
				
			||||||
 | 
					            The newly constructed object.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        with (path / 'config.json').open() as file_:
 | 
				
			||||||
 | 
					            cfg = ujson.load(file_)
 | 
				
			||||||
 | 
					        self = cls(vocab, TransitionSystem=TransitionSystem, model=None, **cfg)
 | 
				
			||||||
 | 
					        if (path / 'model').exists():
 | 
				
			||||||
 | 
					            self.model.load(str(path / 'model'))
 | 
				
			||||||
 | 
					        elif require:
 | 
				
			||||||
 | 
					            raise IOError(
 | 
				
			||||||
 | 
					                "Required file %s/model not found when loading" % str(path))
 | 
				
			||||||
 | 
					        return self
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, Vocab vocab, TransitionSystem=None, model=None, **cfg):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Create a Parser.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Arguments:
 | 
				
			||||||
 | 
					            vocab (Vocab):
 | 
				
			||||||
 | 
					                The vocabulary object. Must be shared with documents to be processed.
 | 
				
			||||||
 | 
					            model (thinc Model):
 | 
				
			||||||
 | 
					                The statistical model.
 | 
				
			||||||
 | 
					        Returns (Parser):
 | 
				
			||||||
 | 
					            The newly constructed object.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if TransitionSystem is None:
 | 
				
			||||||
 | 
					            TransitionSystem = self.TransitionSystem
 | 
				
			||||||
 | 
					        self.vocab = vocab
 | 
				
			||||||
 | 
					        cfg['actions'] = TransitionSystem.get_actions(**cfg)
 | 
				
			||||||
 | 
					        self.moves = TransitionSystem(vocab.strings, cfg['actions'])
 | 
				
			||||||
 | 
					        if model is None:
 | 
				
			||||||
 | 
					            self.model, self.feature_maps = self.build_model(**cfg)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.model, self.feature_maps = model
 | 
				
			||||||
 | 
					        self.cfg = cfg
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __reduce__(self):
 | 
				
			||||||
 | 
					        return (Parser, (self.vocab, self.moves, self.model), None, None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def build_model(self,
 | 
				
			||||||
 | 
					            hidden_width=128, token_vector_width=96, nr_vector=1000,
 | 
				
			||||||
 | 
					            nF=1, nB=1, nS=1, nL=1, nR=1, **cfg):
 | 
				
			||||||
 | 
					        nr_context_tokens = StateClass.nr_context_tokens(nF, nB, nS, nL, nR)
 | 
				
			||||||
 | 
					        with Model.use_device('cpu'):
 | 
				
			||||||
 | 
					            upper = chain(
 | 
				
			||||||
 | 
					                        Maxout(hidden_width, hidden_width),
 | 
				
			||||||
 | 
					                        #print_mean_variance,
 | 
				
			||||||
 | 
					                        zero_init(Affine(self.moves.n_moves, hidden_width)))
 | 
				
			||||||
 | 
					        assert isinstance(upper.ops, NumpyOps)
 | 
				
			||||||
 | 
					        lower = PrecomputableMaxouts(hidden_width, nF=nr_context_tokens, nI=token_vector_width,
 | 
				
			||||||
 | 
					                                     pieces=cfg.get('maxout_pieces', 1))
 | 
				
			||||||
 | 
					        lower.begin_training(lower.ops.allocate((500, token_vector_width)))
 | 
				
			||||||
 | 
					        upper.begin_training(upper.ops.allocate((500, hidden_width)))
 | 
				
			||||||
 | 
					        return upper, lower
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __call__(self, Doc tokens):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Apply the parser or entity recognizer, setting the annotations onto the Doc object.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Arguments:
 | 
				
			||||||
 | 
					            doc (Doc): The document to be processed.
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            None
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        self.parse_batch([tokens])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def pipe(self, stream, int batch_size=1000, int n_threads=2):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Process a stream of documents.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Arguments:
 | 
				
			||||||
 | 
					            stream: The sequence of documents to process.
 | 
				
			||||||
 | 
					            batch_size (int):
 | 
				
			||||||
 | 
					                The number of documents to accumulate into a working set.
 | 
				
			||||||
 | 
					            n_threads (int):
 | 
				
			||||||
 | 
					                The number of threads with which to work on the buffer in parallel.
 | 
				
			||||||
 | 
					        Yields (Doc): Documents, in order.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        queue = []
 | 
				
			||||||
 | 
					        for doc in stream:
 | 
				
			||||||
 | 
					            queue.append(doc)
 | 
				
			||||||
 | 
					            if len(queue) == batch_size:
 | 
				
			||||||
 | 
					                self.parse_batch(queue)
 | 
				
			||||||
 | 
					                for doc in queue:
 | 
				
			||||||
 | 
					                    self.moves.finalize_doc(doc)
 | 
				
			||||||
 | 
					                    yield doc
 | 
				
			||||||
 | 
					                queue = []
 | 
				
			||||||
 | 
					        if queue:
 | 
				
			||||||
 | 
					            self.parse_batch(queue)
 | 
				
			||||||
 | 
					            for doc in queue:
 | 
				
			||||||
 | 
					                self.moves.finalize_doc(doc)
 | 
				
			||||||
 | 
					                yield doc
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def parse_batch(self, docs_tokvecs):
 | 
				
			||||||
 | 
					        cdef:
 | 
				
			||||||
 | 
					            int nC
 | 
				
			||||||
 | 
					            Doc doc
 | 
				
			||||||
 | 
					            StateClass state
 | 
				
			||||||
 | 
					            np.ndarray py_scores
 | 
				
			||||||
 | 
					            int[500] is_valid # Hacks for now
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cuda_stream = get_cuda_stream()
 | 
				
			||||||
 | 
					        docs, tokvecs = docs_tokvecs
 | 
				
			||||||
 | 
					        lower_model = get_greedy_model_for_batch(len(docs), tokvecs, self.feature_maps,
 | 
				
			||||||
 | 
					                                                 cuda_stream)
 | 
				
			||||||
 | 
					        upper_model = self.model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        states, offsets = init_states(self.moves, docs)
 | 
				
			||||||
 | 
					        all_states = list(states)
 | 
				
			||||||
 | 
					        todo = [st for st in zip(states, offsets) if not st[0].py_is_final()]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        while todo:
 | 
				
			||||||
 | 
					            states, offsets = zip(*todo)
 | 
				
			||||||
 | 
					            token_ids = extract_token_ids(states, offsets=offsets)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            py_scores = upper_model(lower_model(token_ids)[0])
 | 
				
			||||||
 | 
					            scores = <float*>py_scores.data
 | 
				
			||||||
 | 
					            nC = py_scores.shape[1]
 | 
				
			||||||
 | 
					            for state, offset in zip(states, offsets):
 | 
				
			||||||
 | 
					                self.moves.set_valid(is_valid, state.c)
 | 
				
			||||||
 | 
					                guess = arg_max_if_valid(scores, is_valid, nC)
 | 
				
			||||||
 | 
					                action = self.moves.c[guess]
 | 
				
			||||||
 | 
					                action.do(state.c, action.label)
 | 
				
			||||||
 | 
					                scores += nC
 | 
				
			||||||
 | 
					            todo = [st for st in todo if not st[0].py_is_final()]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for state, doc in zip(all_states, docs):
 | 
				
			||||||
 | 
					            self.moves.finalize_state(state.c)
 | 
				
			||||||
 | 
					            for i in range(doc.length):
 | 
				
			||||||
 | 
					                doc.c[i] = state.c._sent[i]
 | 
				
			||||||
 | 
					            self.moves.finalize_doc(doc)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def update(self, docs_tokvecs, golds, drop=0., sgd=None):
 | 
				
			||||||
 | 
					        cdef:
 | 
				
			||||||
 | 
					            int nC
 | 
				
			||||||
 | 
					            Doc doc
 | 
				
			||||||
 | 
					            StateClass state
 | 
				
			||||||
 | 
					            np.ndarray scores
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        docs, tokvecs = docs_tokvecs
 | 
				
			||||||
 | 
					        cuda_stream = get_cuda_stream()
 | 
				
			||||||
 | 
					        lower_model = get_greedy_model_for_batch(len(docs),
 | 
				
			||||||
 | 
					                        tokvecs, self.feature_maps, cuda_stream=cuda_stream,
 | 
				
			||||||
 | 
					                        drop=drop)
 | 
				
			||||||
 | 
					        if isinstance(docs, Doc) and isinstance(golds, GoldParse):
 | 
				
			||||||
 | 
					            return self.update(([docs], tokvecs), [golds], drop=drop)
 | 
				
			||||||
 | 
					        for gold in golds:
 | 
				
			||||||
 | 
					            self.moves.preprocess_gold(gold)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        states, offsets = init_states(self.moves, docs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        todo = zip(states, offsets, golds)
 | 
				
			||||||
 | 
					        todo = filter(lambda sp: not sp[0].py_is_final(), todo)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cdef Pool mem = Pool()
 | 
				
			||||||
 | 
					        is_valid = <int*>mem.alloc(len(states) * self.moves.n_moves, sizeof(int))
 | 
				
			||||||
 | 
					        costs = <float*>mem.alloc(len(states) * self.moves.n_moves, sizeof(float))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        upper_model = self.model
 | 
				
			||||||
 | 
					        d_tokens = self.feature_maps.ops.allocate(tokvecs.shape)
 | 
				
			||||||
 | 
					        backprops = []
 | 
				
			||||||
 | 
					        n_tokens = tokvecs.shape[0]
 | 
				
			||||||
 | 
					        nF = self.feature_maps.nF
 | 
				
			||||||
 | 
					        loss = 0.
 | 
				
			||||||
 | 
					        total = 1e-4
 | 
				
			||||||
 | 
					        follow_gold = False
 | 
				
			||||||
 | 
					        cupy = self.feature_maps.ops.xp
 | 
				
			||||||
 | 
					        while len(todo) >= 4:
 | 
				
			||||||
 | 
					            states, offsets, golds = zip(*todo)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            token_ids = extract_token_ids(states, offsets=offsets)
 | 
				
			||||||
 | 
					            lower, bp_lower = lower_model(token_ids, drop=drop)
 | 
				
			||||||
 | 
					            scores, bp_scores = upper_model.begin_update(lower, drop=drop)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            d_scores = get_batch_loss(self.moves, states, golds, scores)
 | 
				
			||||||
 | 
					            loss += numpy.abs(d_scores).sum()
 | 
				
			||||||
 | 
					            total += d_scores.shape[0]
 | 
				
			||||||
 | 
					            d_lower = bp_scores(d_scores, sgd=sgd)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if isinstance(tokvecs, cupy.ndarray):
 | 
				
			||||||
 | 
					                gpu_tok_ids = cupy.ndarray(token_ids.shape, dtype='i', order='C')
 | 
				
			||||||
 | 
					                gpu_d_lower = cupy.ndarray(d_lower.shape, dtype='f', order='C')
 | 
				
			||||||
 | 
					                gpu_tok_ids.set(token_ids, stream=cuda_stream)
 | 
				
			||||||
 | 
					                gpu_d_lower.set(d_lower, stream=cuda_stream)
 | 
				
			||||||
 | 
					                backprops.append((gpu_tok_ids, gpu_d_lower, bp_lower))
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                backprops.append((token_ids, d_lower, bp_lower))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            c_scores = <float*>scores.data
 | 
				
			||||||
 | 
					            for state, gold in zip(states, golds):
 | 
				
			||||||
 | 
					                if follow_gold:
 | 
				
			||||||
 | 
					                    self.moves.set_costs(is_valid, costs, state, gold)
 | 
				
			||||||
 | 
					                    guess = arg_max_if_gold(c_scores, costs, is_valid, scores.shape[1])
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    self.moves.set_valid(is_valid, state.c)
 | 
				
			||||||
 | 
					                    guess = arg_max_if_valid(c_scores, is_valid, scores.shape[1])
 | 
				
			||||||
 | 
					                action = self.moves.c[guess]
 | 
				
			||||||
 | 
					                action.do(state.c, action.label)
 | 
				
			||||||
 | 
					                c_scores += scores.shape[1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            todo = filter(lambda sp: not sp[0].py_is_final(), todo)
 | 
				
			||||||
 | 
					        # This tells CUDA to block --- so we know our copies are complete.
 | 
				
			||||||
 | 
					        cuda_stream.synchronize()
 | 
				
			||||||
 | 
					        for token_ids, d_lower, bp_lower in backprops:
 | 
				
			||||||
 | 
					            d_state_features = bp_lower(d_lower, sgd=sgd)
 | 
				
			||||||
 | 
					            active_feats = token_ids * (token_ids >= 0)
 | 
				
			||||||
 | 
					            active_feats = active_feats.reshape((token_ids.shape[0], token_ids.shape[1], 1))
 | 
				
			||||||
 | 
					            if hasattr(self.feature_maps.ops.xp, 'scatter_add'):
 | 
				
			||||||
 | 
					                self.feature_maps.ops.xp.scatter_add(d_tokens,
 | 
				
			||||||
 | 
					                    token_ids, d_state_features * active_feats)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                self.model.ops.xp.add.at(d_tokens,
 | 
				
			||||||
 | 
					                    token_ids, d_state_features * active_feats)
 | 
				
			||||||
 | 
					        return d_tokens, loss / total
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def step_through(self, Doc doc, GoldParse gold=None):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Set up a stepwise state, to introspect and control the transition sequence.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Arguments:
 | 
				
			||||||
 | 
					            doc (Doc): The document to step through.
 | 
				
			||||||
 | 
					            gold (GoldParse): Optional gold parse
 | 
				
			||||||
 | 
					        Returns (StepwiseState):
 | 
				
			||||||
 | 
					            A state object, to step through the annotation process.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        return StepwiseState(self, doc, gold=gold)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def from_transition_sequence(self, Doc doc, sequence):
 | 
				
			||||||
 | 
					        """Control the annotations on a document by specifying a transition sequence
 | 
				
			||||||
 | 
					        to follow.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Arguments:
 | 
				
			||||||
 | 
					            doc (Doc): The document to annotate.
 | 
				
			||||||
 | 
					            sequence: A sequence of action names, as unicode strings.
 | 
				
			||||||
 | 
					        Returns: None
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        with self.step_through(doc) as stepwise:
 | 
				
			||||||
 | 
					            for transition in sequence:
 | 
				
			||||||
 | 
					                stepwise.transition(transition)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def add_label(self, label):
 | 
				
			||||||
 | 
					        # Doesn't set label into serializer -- subclasses override it to do that.
 | 
				
			||||||
 | 
					        for action in self.moves.action_types:
 | 
				
			||||||
 | 
					            added = self.moves.add_action(action, label)
 | 
				
			||||||
 | 
					            if added:
 | 
				
			||||||
 | 
					                # Important that the labels be stored as a list! We need the
 | 
				
			||||||
 | 
					                # order, or the model goes out of synch
 | 
				
			||||||
 | 
					                self.cfg.setdefault('extra_labels', []).append(label)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cdef class StepwiseState:
 | 
				
			||||||
 | 
					    cdef readonly StateClass stcls
 | 
				
			||||||
 | 
					    cdef readonly Example eg
 | 
				
			||||||
 | 
					    cdef readonly Doc doc
 | 
				
			||||||
 | 
					    cdef readonly GoldParse gold
 | 
				
			||||||
 | 
					    cdef readonly Parser parser
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, Parser parser, Doc doc, GoldParse gold=None):
 | 
				
			||||||
 | 
					        self.parser = parser
 | 
				
			||||||
 | 
					        self.doc = doc
 | 
				
			||||||
 | 
					        if gold is not None:
 | 
				
			||||||
 | 
					            self.gold = gold
 | 
				
			||||||
 | 
					            self.parser.moves.preprocess_gold(self.gold)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.gold = GoldParse(doc)
 | 
				
			||||||
 | 
					        self.stcls = StateClass.init(doc.c, doc.length)
 | 
				
			||||||
 | 
					        self.parser.moves.initialize_state(self.stcls.c)
 | 
				
			||||||
 | 
					        self.eg = Example(
 | 
				
			||||||
 | 
					            nr_class=self.parser.moves.n_moves,
 | 
				
			||||||
 | 
					            nr_atom=CONTEXT_SIZE,
 | 
				
			||||||
 | 
					            nr_feat=self.parser.model.nr_feat)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __enter__(self):
 | 
				
			||||||
 | 
					        return self
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __exit__(self, type, value, traceback):
 | 
				
			||||||
 | 
					        self.finish()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def is_final(self):
 | 
				
			||||||
 | 
					        return self.stcls.is_final()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def stack(self):
 | 
				
			||||||
 | 
					        return self.stcls.stack
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def queue(self):
 | 
				
			||||||
 | 
					        return self.stcls.queue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def heads(self):
 | 
				
			||||||
 | 
					        return [self.stcls.H(i) for i in range(self.stcls.c.length)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def deps(self):
 | 
				
			||||||
 | 
					        return [self.doc.vocab.strings[self.stcls.c._sent[i].dep]
 | 
				
			||||||
 | 
					                for i in range(self.stcls.c.length)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def costs(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Find the action-costs for the current state.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if not self.gold:
 | 
				
			||||||
 | 
					            raise ValueError("Can't set costs: No GoldParse provided")
 | 
				
			||||||
 | 
					        self.parser.moves.set_costs(self.eg.c.is_valid, self.eg.c.costs,
 | 
				
			||||||
 | 
					                self.stcls, self.gold)
 | 
				
			||||||
 | 
					        costs = {}
 | 
				
			||||||
 | 
					        for i in range(self.parser.moves.n_moves):
 | 
				
			||||||
 | 
					            if not self.eg.c.is_valid[i]:
 | 
				
			||||||
 | 
					                continue
 | 
				
			||||||
 | 
					            transition = self.parser.moves.c[i]
 | 
				
			||||||
 | 
					            name = self.parser.moves.move_name(transition.move, transition.label)
 | 
				
			||||||
 | 
					            costs[name] = self.eg.c.costs[i]
 | 
				
			||||||
 | 
					        return costs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def predict(self):
 | 
				
			||||||
 | 
					        self.eg.reset()
 | 
				
			||||||
 | 
					        #self.eg.c.nr_feat = self.parser.model.set_featuresC(self.eg.c.atoms, self.eg.c.features,
 | 
				
			||||||
 | 
					        #                                                    self.stcls.c)
 | 
				
			||||||
 | 
					        self.parser.moves.set_valid(self.eg.c.is_valid, self.stcls.c)
 | 
				
			||||||
 | 
					        #self.parser.model.set_scoresC(self.eg.c.scores,
 | 
				
			||||||
 | 
					        #    self.eg.c.features, self.eg.c.nr_feat)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cdef Transition action = self.parser.moves.c[self.eg.guess]
 | 
				
			||||||
 | 
					        return self.parser.moves.move_name(action.move, action.label)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def transition(self, action_name=None):
 | 
				
			||||||
 | 
					        if action_name is None:
 | 
				
			||||||
 | 
					            action_name = self.predict()
 | 
				
			||||||
 | 
					        moves = {'S': 0, 'D': 1, 'L': 2, 'R': 3}
 | 
				
			||||||
 | 
					        if action_name == '_':
 | 
				
			||||||
 | 
					            action_name = self.predict()
 | 
				
			||||||
 | 
					            action = self.parser.moves.lookup_transition(action_name)
 | 
				
			||||||
 | 
					        elif action_name == 'L' or action_name == 'R':
 | 
				
			||||||
 | 
					            self.predict()
 | 
				
			||||||
 | 
					            move = moves[action_name]
 | 
				
			||||||
 | 
					            clas = _arg_max_clas(self.eg.c.scores, move, self.parser.moves.c,
 | 
				
			||||||
 | 
					                                 self.eg.c.nr_class)
 | 
				
			||||||
 | 
					            action = self.parser.moves.c[clas]
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            action = self.parser.moves.lookup_transition(action_name)
 | 
				
			||||||
 | 
					        action.do(self.stcls.c, action.label)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def finish(self):
 | 
				
			||||||
 | 
					        if self.stcls.is_final():
 | 
				
			||||||
 | 
					            self.parser.moves.finalize_state(self.stcls.c)
 | 
				
			||||||
 | 
					        self.doc.set_parse(self.stcls.c._sent)
 | 
				
			||||||
 | 
					        self.parser.moves.finalize_doc(self.doc)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ParserStateError(ValueError):
 | 
				
			||||||
 | 
					    def __init__(self, doc):
 | 
				
			||||||
 | 
					        ValueError.__init__(self,
 | 
				
			||||||
 | 
					            "Error analysing doc -- no valid actions available. This should "
 | 
				
			||||||
 | 
					            "never happen, so please report the error on the issue tracker. "
 | 
				
			||||||
 | 
					            "Here's the thread to do so --- reopen it if it's closed:\n"
 | 
				
			||||||
 | 
					            "https://github.com/spacy-io/spaCy/issues/429\n"
 | 
				
			||||||
 | 
					            "Please include the text that the parser failed on, which is:\n"
 | 
				
			||||||
 | 
					            "%s" % repr(doc.text))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cdef int arg_max_if_gold(const weight_t* scores, const weight_t* costs, const int* is_valid, int n) nogil:
 | 
				
			||||||
 | 
					    # Find minimum cost
 | 
				
			||||||
 | 
					    cdef float cost = 1
 | 
				
			||||||
 | 
					    for i in range(n):
 | 
				
			||||||
 | 
					        if is_valid[i] and costs[i] < cost:
 | 
				
			||||||
 | 
					            cost = costs[i]
 | 
				
			||||||
 | 
					    # Now find best-scoring with that cost
 | 
				
			||||||
 | 
					    cdef int best = -1
 | 
				
			||||||
 | 
					    for i in range(n):
 | 
				
			||||||
 | 
					        if costs[i] <= cost and is_valid[i]:
 | 
				
			||||||
 | 
					            if best == -1 or scores[i] > scores[best]:
 | 
				
			||||||
 | 
					                best = i
 | 
				
			||||||
 | 
					    return best
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) nogil:
 | 
				
			||||||
 | 
					    cdef int best = -1
 | 
				
			||||||
 | 
					    for i in range(n):
 | 
				
			||||||
 | 
					        if is_valid[i] >= 1:
 | 
				
			||||||
 | 
					            if best == -1 or scores[i] > scores[best]:
 | 
				
			||||||
 | 
					                best = i
 | 
				
			||||||
 | 
					    return best
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cdef int _arg_max_clas(const weight_t* scores, int move, const Transition* actions,
 | 
				
			||||||
 | 
					                       int nr_class) except -1:
 | 
				
			||||||
 | 
					    cdef weight_t score = 0
 | 
				
			||||||
 | 
					    cdef int mode = -1
 | 
				
			||||||
 | 
					    cdef int i
 | 
				
			||||||
 | 
					    for i in range(nr_class):
 | 
				
			||||||
 | 
					        if actions[i].move == move and (mode == -1 or scores[i] >= score):
 | 
				
			||||||
 | 
					            mode = i
 | 
				
			||||||
 | 
					            score = scores[i]
 | 
				
			||||||
 | 
					    return mode
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,6 @@
 | 
				
			||||||
 | 
					from thinc.linear.avgtron cimport AveragedPerceptron
 | 
				
			||||||
from thinc.typedefs cimport atom_t
 | 
					from thinc.typedefs cimport atom_t
 | 
				
			||||||
 | 
					from thinc.structs cimport FeatureC
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .stateclass cimport StateClass
 | 
					from .stateclass cimport StateClass
 | 
				
			||||||
from .arc_eager cimport TransitionSystem
 | 
					from .arc_eager cimport TransitionSystem
 | 
				
			||||||
| 
						 | 
					@ -8,11 +10,15 @@ from ..structs cimport TokenC
 | 
				
			||||||
from ._state cimport StateC
 | 
					from ._state cimport StateC
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cdef class ParserModel(AveragedPerceptron):
 | 
				
			||||||
 | 
					    cdef int set_featuresC(self, atom_t* context, FeatureC* features,
 | 
				
			||||||
 | 
					                            const StateC* state) nogil
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef class Parser:
 | 
					cdef class Parser:
 | 
				
			||||||
    cdef readonly Vocab vocab
 | 
					    cdef readonly Vocab vocab
 | 
				
			||||||
    cdef readonly object model
 | 
					    cdef readonly ParserModel model
 | 
				
			||||||
    cdef readonly TransitionSystem moves
 | 
					    cdef readonly TransitionSystem moves
 | 
				
			||||||
    cdef readonly object cfg
 | 
					    cdef readonly object cfg
 | 
				
			||||||
    cdef public object feature_maps
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    #cdef int parseC(self, TokenC* tokens, int length, int nr_feat) nogil
 | 
					    cdef int parseC(self, TokenC* tokens, int length, int nr_feat) nogil
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,18 +1,17 @@
 | 
				
			||||||
# cython: infer_types=True
 | 
					"""
 | 
				
			||||||
# cython: profile=True
 | 
					MALT-style dependency parser
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
# coding: utf-8
 | 
					# coding: utf-8
 | 
				
			||||||
from __future__ import unicode_literals, print_function
 | 
					# cython: infer_types=True
 | 
				
			||||||
 | 
					from __future__ import unicode_literals
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from collections import Counter
 | 
					from collections import Counter
 | 
				
			||||||
import ujson
 | 
					import ujson
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from libc.math cimport exp
 | 
					 | 
				
			||||||
cimport cython
 | 
					cimport cython
 | 
				
			||||||
cimport cython.parallel
 | 
					cimport cython.parallel
 | 
				
			||||||
import cytoolz
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
import numpy.random
 | 
					import numpy.random
 | 
				
			||||||
cimport numpy as np
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
from cpython.ref cimport PyObject, Py_INCREF, Py_XDECREF
 | 
					from cpython.ref cimport PyObject, Py_INCREF, Py_XDECREF
 | 
				
			||||||
from cpython.exc cimport PyErr_CheckSignals
 | 
					from cpython.exc cimport PyErr_CheckSignals
 | 
				
			||||||
| 
						 | 
					@ -29,13 +28,6 @@ from murmurhash.mrmr cimport hash64
 | 
				
			||||||
from preshed.maps cimport MapStruct
 | 
					from preshed.maps cimport MapStruct
 | 
				
			||||||
from preshed.maps cimport map_get
 | 
					from preshed.maps cimport map_get
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from thinc.api import layerize, chain
 | 
					 | 
				
			||||||
from thinc.neural import BatchNorm, Model, Affine, ELU, ReLu, Maxout
 | 
					 | 
				
			||||||
from thinc.neural.ops import NumpyOps
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from ..util import get_cuda_stream
 | 
					 | 
				
			||||||
from .._ml import zero_init, PrecomputableAffine, PrecomputableMaxouts
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from . import _parse_features
 | 
					from . import _parse_features
 | 
				
			||||||
from ._parse_features cimport CONTEXT_SIZE
 | 
					from ._parse_features cimport CONTEXT_SIZE
 | 
				
			||||||
from ._parse_features cimport fill_context
 | 
					from ._parse_features cimport fill_context
 | 
				
			||||||
| 
						 | 
					@ -48,12 +40,8 @@ from ..structs cimport TokenC
 | 
				
			||||||
from ..tokens.doc cimport Doc
 | 
					from ..tokens.doc cimport Doc
 | 
				
			||||||
from ..strings cimport StringStore
 | 
					from ..strings cimport StringStore
 | 
				
			||||||
from ..gold cimport GoldParse
 | 
					from ..gold cimport GoldParse
 | 
				
			||||||
from ..attrs cimport TAG, DEP
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_templates(*args, **kwargs):
 | 
					 | 
				
			||||||
    return []
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
USE_FTRL = True
 | 
					USE_FTRL = True
 | 
				
			||||||
DEBUG = False
 | 
					DEBUG = False
 | 
				
			||||||
def set_debug(val):
 | 
					def set_debug(val):
 | 
				
			||||||
| 
						 | 
					@ -61,205 +49,78 @@ def set_debug(val):
 | 
				
			||||||
    DEBUG = val
 | 
					    DEBUG = val
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_greedy_model_for_batch(batch_size, tokvecs, lower_model, cuda_stream=None,
 | 
					def get_templates(name):
 | 
				
			||||||
                               drop=0.):
 | 
					    pf = _parse_features
 | 
				
			||||||
    '''Allow a model to be "primed" by pre-computing input features in bulk.
 | 
					    if name == 'ner':
 | 
				
			||||||
 | 
					        return pf.ner
 | 
				
			||||||
    This is used for the parser, where we want to take a batch of documents,
 | 
					    elif name == 'debug':
 | 
				
			||||||
    and compute vectors for each (token, position) pair. These vectors can then
 | 
					        return pf.unigrams
 | 
				
			||||||
    be reused, especially for beam-search.
 | 
					    elif name.startswith('embed'):
 | 
				
			||||||
 | 
					        return (pf.words, pf.tags, pf.labels)
 | 
				
			||||||
    Let's say we're using 12 features for each state, e.g. word at start of
 | 
					 | 
				
			||||||
    buffer, three words on stack, their children, etc. In the normal arc-eager
 | 
					 | 
				
			||||||
    system, a document of length N is processed in 2*N states. This means we'll
 | 
					 | 
				
			||||||
    create 2*N*12 feature vectors --- but if we pre-compute, we only need
 | 
					 | 
				
			||||||
    N*12 vector computations. The saving for beam-search is much better:
 | 
					 | 
				
			||||||
    if we have a beam of k, we'll normally make 2*N*12*K computations --
 | 
					 | 
				
			||||||
    so we can save the factor k. This also gives a nice CPU/GPU division:
 | 
					 | 
				
			||||||
    we can do all our hard maths up front, packed into large multiplications,
 | 
					 | 
				
			||||||
    and do the hard-to-program parsing on the CPU.
 | 
					 | 
				
			||||||
    '''
 | 
					 | 
				
			||||||
    gpu_cached, bp_features = lower_model.begin_update(tokvecs, drop=drop)
 | 
					 | 
				
			||||||
    cdef np.ndarray cached
 | 
					 | 
				
			||||||
    if not isinstance(gpu_cached, numpy.ndarray):
 | 
					 | 
				
			||||||
        cached = gpu_cached.get(stream=cuda_stream)
 | 
					 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        cached = gpu_cached
 | 
					        return (pf.unigrams + pf.s0_n0 + pf.s1_n0 + pf.s1_s0 + pf.s0_n1 + pf.n0_n1 + \
 | 
				
			||||||
    nF = gpu_cached.shape[1]
 | 
					                pf.tree_shape + pf.trigrams)
 | 
				
			||||||
    nO = gpu_cached.shape[2]
 | 
					 | 
				
			||||||
    nP = gpu_cached.shape[3]
 | 
					 | 
				
			||||||
    ops = lower_model.ops
 | 
					 | 
				
			||||||
    features = numpy.zeros((batch_size, nO, nP), dtype='f')
 | 
					 | 
				
			||||||
    synchronized = False
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def forward(token_ids, drop=0.):
 | 
					 | 
				
			||||||
        nonlocal synchronized
 | 
					 | 
				
			||||||
        if not synchronized and cuda_stream is not None:
 | 
					 | 
				
			||||||
            cuda_stream.synchronize()
 | 
					 | 
				
			||||||
            synchronized = True
 | 
					 | 
				
			||||||
        # This is tricky, but:
 | 
					 | 
				
			||||||
        # - Input to forward on CPU
 | 
					 | 
				
			||||||
        # - Output from forward on CPU
 | 
					 | 
				
			||||||
        # - Input to backward on GPU!
 | 
					 | 
				
			||||||
        # - Output from backward on GPU
 | 
					 | 
				
			||||||
        nonlocal features
 | 
					 | 
				
			||||||
        features = features[:len(token_ids)]
 | 
					 | 
				
			||||||
        features.fill(0)
 | 
					 | 
				
			||||||
        cdef float[:, :, ::1] feats = features
 | 
					 | 
				
			||||||
        cdef int[:, ::1] ids = token_ids
 | 
					 | 
				
			||||||
        _sum_features(<float*>&feats[0,0,0],
 | 
					 | 
				
			||||||
            <float*>cached.data, &ids[0,0],
 | 
					 | 
				
			||||||
            token_ids.shape[0], nF, nO*nP)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if nP >= 2:
 | 
					cdef class ParserModel(AveragedPerceptron):
 | 
				
			||||||
            best, which = ops.maxout(features)
 | 
					    cdef int set_featuresC(self, atom_t* context, FeatureC* features,
 | 
				
			||||||
 | 
					            const StateC* state) nogil:
 | 
				
			||||||
 | 
					        fill_context(context, state)
 | 
				
			||||||
 | 
					        nr_feat = self.extracter.set_features(features, context)
 | 
				
			||||||
 | 
					        return nr_feat
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def update(self, Example eg, itn=0):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Does regression on negative cost. Sort of cute?
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        self.time += 1
 | 
				
			||||||
 | 
					        cdef int best = arg_max_if_gold(eg.c.scores, eg.c.costs, eg.c.nr_class)
 | 
				
			||||||
 | 
					        cdef int guess = eg.guess
 | 
				
			||||||
 | 
					        if guess == best or best == -1:
 | 
				
			||||||
 | 
					            return 0.0
 | 
				
			||||||
 | 
					        cdef FeatureC feat
 | 
				
			||||||
 | 
					        cdef int clas
 | 
				
			||||||
 | 
					        cdef weight_t gradient
 | 
				
			||||||
 | 
					        if USE_FTRL:
 | 
				
			||||||
 | 
					            for feat in eg.c.features[:eg.c.nr_feat]:
 | 
				
			||||||
 | 
					                for clas in range(eg.c.nr_class):
 | 
				
			||||||
 | 
					                    if eg.c.is_valid[clas] and eg.c.scores[clas] >= eg.c.scores[best]:
 | 
				
			||||||
 | 
					                        gradient = eg.c.scores[clas] + eg.c.costs[clas]
 | 
				
			||||||
 | 
					                        self.update_weight_ftrl(feat.key, clas, feat.value * gradient)
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            best = features.reshape((features.shape[0], features.shape[1]))
 | 
					            for feat in eg.c.features[:eg.c.nr_feat]:
 | 
				
			||||||
            which = None
 | 
					                self.update_weight(feat.key, guess, feat.value * eg.c.costs[guess])
 | 
				
			||||||
 | 
					                self.update_weight(feat.key, best, -feat.value * eg.c.costs[guess])
 | 
				
			||||||
 | 
					        return eg.c.costs[guess]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        def backward(d_best, sgd=None):
 | 
					    def update_from_histories(self, TransitionSystem moves, Doc doc, histories, weight_t min_grad=0.0):
 | 
				
			||||||
            # This will usually be on GPU
 | 
					        cdef Pool mem = Pool()
 | 
				
			||||||
            if isinstance(d_best, numpy.ndarray):
 | 
					        features = <FeatureC*>mem.alloc(self.nr_feat, sizeof(FeatureC))
 | 
				
			||||||
                d_best = ops.xp.array(d_best)
 | 
					 | 
				
			||||||
            if nP >= 2:
 | 
					 | 
				
			||||||
                d_features = ops.backprop_maxout(d_best, which, nP)
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                d_features = d_best.reshape((d_best.shape[0], d_best.shape[1], 1))
 | 
					 | 
				
			||||||
            d_tokens = bp_features((d_features, token_ids), sgd)
 | 
					 | 
				
			||||||
            return d_tokens
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return best, backward
 | 
					        cdef StateClass stcls
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return forward
 | 
					        cdef class_t clas
 | 
				
			||||||
 | 
					        self.time += 1
 | 
				
			||||||
 | 
					        cdef atom_t[CONTEXT_SIZE] atoms
 | 
				
			||||||
cdef void _sum_features(float* output,
 | 
					        histories = [(grad, hist) for grad, hist in histories if abs(grad) >= min_grad and hist]
 | 
				
			||||||
        const float* cached, const int* token_ids, int B, int F, int O) nogil:
 | 
					        if not histories:
 | 
				
			||||||
    cdef int idx, b, f, i
 | 
					            return None
 | 
				
			||||||
    cdef const float* feature
 | 
					        gradient = [Counter() for _ in range(max([max(h)+1 for _, h in histories]))]
 | 
				
			||||||
    for b in range(B):
 | 
					        for d_loss, history in histories:
 | 
				
			||||||
        for f in range(F):
 | 
					            stcls = StateClass.init(doc.c, doc.length)
 | 
				
			||||||
            if token_ids[f] < 0:
 | 
					            moves.initialize_state(stcls.c)
 | 
				
			||||||
                continue
 | 
					            for clas in history:
 | 
				
			||||||
            idx = token_ids[f] * F * O + f*O
 | 
					                nr_feat = self.set_featuresC(atoms, features, stcls.c)
 | 
				
			||||||
            feature = &cached[idx]
 | 
					                clas_grad = gradient[clas]
 | 
				
			||||||
            for i in range(O):
 | 
					                for feat in features[:nr_feat]:
 | 
				
			||||||
                output[i] += feature[i]
 | 
					                    clas_grad[feat.key] += d_loss * feat.value
 | 
				
			||||||
        output += O
 | 
					                moves.c[clas].do(stcls.c, moves.c[clas].label)
 | 
				
			||||||
        token_ids += F
 | 
					        cdef feat_t key
 | 
				
			||||||
 | 
					        cdef weight_t d_feat
 | 
				
			||||||
 | 
					        for clas, clas_grad in enumerate(gradient):
 | 
				
			||||||
def get_batch_loss(TransitionSystem moves, states, golds, float[:, ::1] scores):
 | 
					            for key, d_feat in clas_grad.items():
 | 
				
			||||||
    cdef StateClass state
 | 
					                if d_feat != 0:
 | 
				
			||||||
    cdef GoldParse gold
 | 
					                    self.update_weight_ftrl(key, clas, d_feat)
 | 
				
			||||||
    cdef Pool mem = Pool()
 | 
					 | 
				
			||||||
    cdef int i
 | 
					 | 
				
			||||||
    is_valid = <int*>mem.alloc(moves.n_moves, sizeof(int))
 | 
					 | 
				
			||||||
    costs = <float*>mem.alloc(moves.n_moves, sizeof(float))
 | 
					 | 
				
			||||||
    cdef np.ndarray d_scores = numpy.zeros((len(states), moves.n_moves), dtype='f',
 | 
					 | 
				
			||||||
                                           order='c')
 | 
					 | 
				
			||||||
    c_d_scores = <float*>d_scores.data
 | 
					 | 
				
			||||||
    for i, (state, gold) in enumerate(zip(states, golds)):
 | 
					 | 
				
			||||||
        memset(is_valid, 0, moves.n_moves * sizeof(int))
 | 
					 | 
				
			||||||
        memset(costs, 0, moves.n_moves * sizeof(float))
 | 
					 | 
				
			||||||
        moves.set_costs(is_valid, costs, state, gold)
 | 
					 | 
				
			||||||
        cpu_log_loss(c_d_scores, costs, is_valid, &scores[i, 0], d_scores.shape[1])
 | 
					 | 
				
			||||||
        #cpu_regression_loss(c_d_scores,
 | 
					 | 
				
			||||||
        #    costs, is_valid, &scores[i, 0], d_scores.shape[1])
 | 
					 | 
				
			||||||
        c_d_scores += d_scores.shape[1]
 | 
					 | 
				
			||||||
    return d_scores
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
cdef void cpu_log_loss(float* d_scores,
 | 
					 | 
				
			||||||
        const float* costs, const int* is_valid, const float* scores,
 | 
					 | 
				
			||||||
        int O) nogil:
 | 
					 | 
				
			||||||
    """Do multi-label log loss"""
 | 
					 | 
				
			||||||
    cdef double max_, gmax, Z, gZ
 | 
					 | 
				
			||||||
    best = arg_max_if_gold(scores, costs, is_valid, O)
 | 
					 | 
				
			||||||
    guess = arg_max_if_valid(scores, is_valid, O)
 | 
					 | 
				
			||||||
    Z = 1e-10
 | 
					 | 
				
			||||||
    gZ = 1e-10
 | 
					 | 
				
			||||||
    max_ = scores[guess]
 | 
					 | 
				
			||||||
    gmax = scores[best]
 | 
					 | 
				
			||||||
    for i in range(O):
 | 
					 | 
				
			||||||
        if is_valid[i]:
 | 
					 | 
				
			||||||
            Z += exp(scores[i] - max_)
 | 
					 | 
				
			||||||
            if costs[i] <= costs[best]:
 | 
					 | 
				
			||||||
                gZ += exp(scores[i] - gmax)
 | 
					 | 
				
			||||||
    for i in range(O):
 | 
					 | 
				
			||||||
        if not is_valid[i]:
 | 
					 | 
				
			||||||
            d_scores[i] = 0.
 | 
					 | 
				
			||||||
        elif costs[i] <= costs[best]:
 | 
					 | 
				
			||||||
            d_scores[i] = (exp(scores[i]-max_) / Z) - (exp(scores[i]-gmax)/gZ)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            d_scores[i] = exp(scores[i]-max_) / Z
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
cdef void cpu_regression_loss(float* d_scores,
 | 
					 | 
				
			||||||
        const float* costs, const int* is_valid, const float* scores,
 | 
					 | 
				
			||||||
        int O) nogil:
 | 
					 | 
				
			||||||
    cdef float eps = 2.
 | 
					 | 
				
			||||||
    best = arg_max_if_gold(scores, costs, is_valid, O)
 | 
					 | 
				
			||||||
    for i in range(O):
 | 
					 | 
				
			||||||
        if not is_valid[i]:
 | 
					 | 
				
			||||||
            d_scores[i] = 0.
 | 
					 | 
				
			||||||
        elif scores[i] < scores[best]:
 | 
					 | 
				
			||||||
            d_scores[i] = 0.
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            # I doubt this is correct?
 | 
					 | 
				
			||||||
            # Looking for something like Huber loss
 | 
					 | 
				
			||||||
            diff = scores[i] - -costs[i]
 | 
					 | 
				
			||||||
            if diff > eps:
 | 
					 | 
				
			||||||
                d_scores[i] = eps
 | 
					 | 
				
			||||||
            elif diff < -eps:
 | 
					 | 
				
			||||||
                d_scores[i] = -eps
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                d_scores[i] = diff
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def init_states(TransitionSystem moves, docs):
 | 
					 | 
				
			||||||
    cdef Doc doc
 | 
					 | 
				
			||||||
    cdef StateClass state
 | 
					 | 
				
			||||||
    offsets = []
 | 
					 | 
				
			||||||
    states = []
 | 
					 | 
				
			||||||
    offset = 0
 | 
					 | 
				
			||||||
    for i, doc in enumerate(docs):
 | 
					 | 
				
			||||||
        state = StateClass.init(doc.c, doc.length)
 | 
					 | 
				
			||||||
        moves.initialize_state(state.c)
 | 
					 | 
				
			||||||
        states.append(state)
 | 
					 | 
				
			||||||
        offsets.append(offset)
 | 
					 | 
				
			||||||
        offset += len(doc)
 | 
					 | 
				
			||||||
    return states, offsets
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def extract_token_ids(states, offsets=None, nF=1, nB=0, nS=2, nL=0, nR=0):
 | 
					 | 
				
			||||||
    cdef StateClass state
 | 
					 | 
				
			||||||
    cdef int n_tokens = states[0].nr_context_tokens(nF, nB, nS, nL, nR)
 | 
					 | 
				
			||||||
    ids = numpy.zeros((len(states), n_tokens), dtype='i', order='c')
 | 
					 | 
				
			||||||
    if offsets is None:
 | 
					 | 
				
			||||||
        offsets = [0] * len(states)
 | 
					 | 
				
			||||||
    for i, (state, offset) in enumerate(zip(states, offsets)):
 | 
					 | 
				
			||||||
        state.set_context_tokens(ids[i], nF, nB, nS, nL, nR)
 | 
					 | 
				
			||||||
        ids[i] += (ids[i] >= 0) * offset
 | 
					 | 
				
			||||||
    return ids
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
_n_iter = 0
 | 
					 | 
				
			||||||
@layerize
 | 
					 | 
				
			||||||
def print_mean_variance(X, drop=0.):
 | 
					 | 
				
			||||||
    global _n_iter
 | 
					 | 
				
			||||||
    _n_iter += 1
 | 
					 | 
				
			||||||
    fwd_iter = _n_iter
 | 
					 | 
				
			||||||
    means = X.mean(axis=0)
 | 
					 | 
				
			||||||
    variance = X.var(axis=0)
 | 
					 | 
				
			||||||
    print(fwd_iter, "M", ', '.join(('%.2f' % m) for m in means))
 | 
					 | 
				
			||||||
    print(fwd_iter, "V", ', '.join(('%.2f' % m) for m in variance))
 | 
					 | 
				
			||||||
    def backward(dX, sgd=None):
 | 
					 | 
				
			||||||
        means = dX.mean(axis=0)
 | 
					 | 
				
			||||||
        variance = dX.var(axis=0)
 | 
					 | 
				
			||||||
        print(fwd_iter, "dM", ', '.join(('%.2f' % m) for m in means))
 | 
					 | 
				
			||||||
        print(fwd_iter, "dV", ', '.join(('%.2f' % m) for m in variance))
 | 
					 | 
				
			||||||
    return X, backward
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef class Parser:
 | 
					cdef class Parser:
 | 
				
			||||||
| 
						 | 
					@ -283,6 +144,15 @@ cdef class Parser:
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        with (path / 'config.json').open() as file_:
 | 
					        with (path / 'config.json').open() as file_:
 | 
				
			||||||
            cfg = ujson.load(file_)
 | 
					            cfg = ujson.load(file_)
 | 
				
			||||||
 | 
					        # TODO: remove this shim when we don't have to support older data
 | 
				
			||||||
 | 
					        if 'labels' in cfg and 'actions' not in cfg:
 | 
				
			||||||
 | 
					            cfg['actions'] = cfg.pop('labels')
 | 
				
			||||||
 | 
					        # TODO: remove this shim when we don't have to support older data
 | 
				
			||||||
 | 
					        for action_name, labels in dict(cfg.get('actions', {})).items():
 | 
				
			||||||
 | 
					            # We need this to be sorted
 | 
				
			||||||
 | 
					            if isinstance(labels, dict):
 | 
				
			||||||
 | 
					                labels = list(sorted(labels.keys()))
 | 
				
			||||||
 | 
					            cfg['actions'][action_name] = labels
 | 
				
			||||||
        self = cls(vocab, TransitionSystem=TransitionSystem, model=None, **cfg)
 | 
					        self = cls(vocab, TransitionSystem=TransitionSystem, model=None, **cfg)
 | 
				
			||||||
        if (path / 'model').exists():
 | 
					        if (path / 'model').exists():
 | 
				
			||||||
            self.model.load(str(path / 'model'))
 | 
					            self.model.load(str(path / 'model'))
 | 
				
			||||||
| 
						 | 
					@ -291,14 +161,14 @@ cdef class Parser:
 | 
				
			||||||
                "Required file %s/model not found when loading" % str(path))
 | 
					                "Required file %s/model not found when loading" % str(path))
 | 
				
			||||||
        return self
 | 
					        return self
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, Vocab vocab, TransitionSystem=None, model=None, **cfg):
 | 
					    def __init__(self, Vocab vocab, TransitionSystem=None, ParserModel model=None, **cfg):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Create a Parser.
 | 
					        Create a Parser.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Arguments:
 | 
					        Arguments:
 | 
				
			||||||
            vocab (Vocab):
 | 
					            vocab (Vocab):
 | 
				
			||||||
                The vocabulary object. Must be shared with documents to be processed.
 | 
					                The vocabulary object. Must be shared with documents to be processed.
 | 
				
			||||||
            model (thinc Model):
 | 
					            model (thinc.linear.AveragedPerceptron):
 | 
				
			||||||
                The statistical model.
 | 
					                The statistical model.
 | 
				
			||||||
        Returns (Parser):
 | 
					        Returns (Parser):
 | 
				
			||||||
            The newly constructed object.
 | 
					            The newly constructed object.
 | 
				
			||||||
| 
						 | 
					@ -308,41 +178,43 @@ cdef class Parser:
 | 
				
			||||||
        self.vocab = vocab
 | 
					        self.vocab = vocab
 | 
				
			||||||
        cfg['actions'] = TransitionSystem.get_actions(**cfg)
 | 
					        cfg['actions'] = TransitionSystem.get_actions(**cfg)
 | 
				
			||||||
        self.moves = TransitionSystem(vocab.strings, cfg['actions'])
 | 
					        self.moves = TransitionSystem(vocab.strings, cfg['actions'])
 | 
				
			||||||
        if model is None:
 | 
					        # TODO: Remove this when we no longer need to support old-style models
 | 
				
			||||||
            self.model, self.feature_maps = self.build_model(**cfg)
 | 
					        if isinstance(cfg.get('features'), basestring):
 | 
				
			||||||
        else:
 | 
					            cfg['features'] = get_templates(cfg['features'])
 | 
				
			||||||
            self.model, self.feature_maps = model
 | 
					        elif 'features' not in cfg:
 | 
				
			||||||
 | 
					            cfg['features'] = self.feature_templates
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.model = ParserModel(cfg['features'])
 | 
				
			||||||
 | 
					        self.model.l1_penalty = cfg.get('L1', 0.0)
 | 
				
			||||||
 | 
					        self.model.learn_rate = cfg.get('learn_rate', 0.001)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.cfg = cfg
 | 
					        self.cfg = cfg
 | 
				
			||||||
 | 
					        # TODO: This is a pretty hacky fix to the problem of adding more
 | 
				
			||||||
 | 
					        # labels. The issue is they come in out of order, if labels are
 | 
				
			||||||
 | 
					        # added during training
 | 
				
			||||||
 | 
					        for label in cfg.get('extra_labels', []):
 | 
				
			||||||
 | 
					            self.add_label(label)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __reduce__(self):
 | 
					    def __reduce__(self):
 | 
				
			||||||
        return (Parser, (self.vocab, self.moves, self.model), None, None)
 | 
					        return (Parser, (self.vocab, self.moves, self.model), None, None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def build_model(self,
 | 
					 | 
				
			||||||
            hidden_width=128, token_vector_width=96, nr_vector=1000,
 | 
					 | 
				
			||||||
            nF=1, nB=1, nS=1, nL=1, nR=1, **cfg):
 | 
					 | 
				
			||||||
        nr_context_tokens = StateClass.nr_context_tokens(nF, nB, nS, nL, nR)
 | 
					 | 
				
			||||||
        with Model.use_device('cpu'):
 | 
					 | 
				
			||||||
            upper = chain(
 | 
					 | 
				
			||||||
                        Maxout(hidden_width, hidden_width),
 | 
					 | 
				
			||||||
                        #print_mean_variance,
 | 
					 | 
				
			||||||
                        zero_init(Affine(self.moves.n_moves, hidden_width)))
 | 
					 | 
				
			||||||
        assert isinstance(upper.ops, NumpyOps)
 | 
					 | 
				
			||||||
        lower = PrecomputableMaxouts(hidden_width, nF=nr_context_tokens, nI=token_vector_width,
 | 
					 | 
				
			||||||
                                     pieces=cfg.get('maxout_pieces', 1))
 | 
					 | 
				
			||||||
        lower.begin_training(lower.ops.allocate((500, token_vector_width)))
 | 
					 | 
				
			||||||
        upper.begin_training(upper.ops.allocate((500, hidden_width)))
 | 
					 | 
				
			||||||
        return upper, lower
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __call__(self, Doc tokens):
 | 
					    def __call__(self, Doc tokens):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Apply the parser or entity recognizer, setting the annotations onto the Doc object.
 | 
					        Apply the entity recognizer, setting the annotations onto the Doc object.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Arguments:
 | 
					        Arguments:
 | 
				
			||||||
            doc (Doc): The document to be processed.
 | 
					            doc (Doc): The document to be processed.
 | 
				
			||||||
        Returns:
 | 
					        Returns:
 | 
				
			||||||
            None
 | 
					            None
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        self.parse_batch([tokens])
 | 
					        cdef int nr_feat = self.model.nr_feat
 | 
				
			||||||
 | 
					        with nogil:
 | 
				
			||||||
 | 
					            status = self.parseC(tokens.c, tokens.length, nr_feat)
 | 
				
			||||||
 | 
					        # Check for KeyboardInterrupt etc. Untested
 | 
				
			||||||
 | 
					        PyErr_CheckSignals()
 | 
				
			||||||
 | 
					        if status != 0:
 | 
				
			||||||
 | 
					            raise ParserStateError(tokens)
 | 
				
			||||||
 | 
					        self.moves.finalize_doc(tokens)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def pipe(self, stream, int batch_size=1000, int n_threads=2):
 | 
					    def pipe(self, stream, int batch_size=1000, int n_threads=2):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
| 
						 | 
					@ -356,142 +228,124 @@ cdef class Parser:
 | 
				
			||||||
                The number of threads with which to work on the buffer in parallel.
 | 
					                The number of threads with which to work on the buffer in parallel.
 | 
				
			||||||
        Yields (Doc): Documents, in order.
 | 
					        Yields (Doc): Documents, in order.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
 | 
					        cdef Pool mem = Pool()
 | 
				
			||||||
 | 
					        cdef TokenC** doc_ptr = <TokenC**>mem.alloc(batch_size, sizeof(TokenC*))
 | 
				
			||||||
 | 
					        cdef int* lengths = <int*>mem.alloc(batch_size, sizeof(int))
 | 
				
			||||||
 | 
					        cdef Doc doc
 | 
				
			||||||
 | 
					        cdef int i
 | 
				
			||||||
 | 
					        cdef int nr_feat = self.model.nr_feat
 | 
				
			||||||
 | 
					        cdef int status
 | 
				
			||||||
        queue = []
 | 
					        queue = []
 | 
				
			||||||
        for doc in stream:
 | 
					        for doc in stream:
 | 
				
			||||||
 | 
					            doc_ptr[len(queue)] = doc.c
 | 
				
			||||||
 | 
					            lengths[len(queue)] = doc.length
 | 
				
			||||||
            queue.append(doc)
 | 
					            queue.append(doc)
 | 
				
			||||||
            if len(queue) == batch_size:
 | 
					            if len(queue) == batch_size:
 | 
				
			||||||
                self.parse_batch(queue)
 | 
					                with nogil:
 | 
				
			||||||
 | 
					                    for i in cython.parallel.prange(batch_size, num_threads=n_threads):
 | 
				
			||||||
 | 
					                        status = self.parseC(doc_ptr[i], lengths[i], nr_feat)
 | 
				
			||||||
 | 
					                        if status != 0:
 | 
				
			||||||
 | 
					                            with gil:
 | 
				
			||||||
 | 
					                                raise ParserStateError(queue[i])
 | 
				
			||||||
 | 
					                PyErr_CheckSignals()
 | 
				
			||||||
                for doc in queue:
 | 
					                for doc in queue:
 | 
				
			||||||
                    self.moves.finalize_doc(doc)
 | 
					                    self.moves.finalize_doc(doc)
 | 
				
			||||||
                    yield doc
 | 
					                    yield doc
 | 
				
			||||||
                queue = []
 | 
					                queue = []
 | 
				
			||||||
        if queue:
 | 
					        batch_size = len(queue)
 | 
				
			||||||
            self.parse_batch(queue)
 | 
					        with nogil:
 | 
				
			||||||
            for doc in queue:
 | 
					            for i in cython.parallel.prange(batch_size, num_threads=n_threads):
 | 
				
			||||||
                self.moves.finalize_doc(doc)
 | 
					                status = self.parseC(doc_ptr[i], lengths[i], nr_feat)
 | 
				
			||||||
                yield doc
 | 
					                if status != 0:
 | 
				
			||||||
 | 
					                    with gil:
 | 
				
			||||||
    def parse_batch(self, docs_tokvecs):
 | 
					                        raise ParserStateError(queue[i])
 | 
				
			||||||
        cdef:
 | 
					        PyErr_CheckSignals()
 | 
				
			||||||
            int nC
 | 
					        for doc in queue:
 | 
				
			||||||
            Doc doc
 | 
					 | 
				
			||||||
            StateClass state
 | 
					 | 
				
			||||||
            np.ndarray py_scores
 | 
					 | 
				
			||||||
            int[500] is_valid # Hacks for now
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        cuda_stream = get_cuda_stream()
 | 
					 | 
				
			||||||
        docs, tokvecs = docs_tokvecs
 | 
					 | 
				
			||||||
        lower_model = get_greedy_model_for_batch(len(docs), tokvecs, self.feature_maps,
 | 
					 | 
				
			||||||
                                                 cuda_stream)
 | 
					 | 
				
			||||||
        upper_model = self.model
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        states, offsets = init_states(self.moves, docs)
 | 
					 | 
				
			||||||
        all_states = list(states)
 | 
					 | 
				
			||||||
        todo = [st for st in zip(states, offsets) if not st[0].py_is_final()]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        while todo:
 | 
					 | 
				
			||||||
            states, offsets = zip(*todo)
 | 
					 | 
				
			||||||
            token_ids = extract_token_ids(states, offsets=offsets)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            py_scores = upper_model(lower_model(token_ids)[0])
 | 
					 | 
				
			||||||
            scores = <float*>py_scores.data
 | 
					 | 
				
			||||||
            nC = py_scores.shape[1]
 | 
					 | 
				
			||||||
            for state, offset in zip(states, offsets):
 | 
					 | 
				
			||||||
                self.moves.set_valid(is_valid, state.c)
 | 
					 | 
				
			||||||
                guess = arg_max_if_valid(scores, is_valid, nC)
 | 
					 | 
				
			||||||
                action = self.moves.c[guess]
 | 
					 | 
				
			||||||
                action.do(state.c, action.label)
 | 
					 | 
				
			||||||
                scores += nC
 | 
					 | 
				
			||||||
            todo = [st for st in todo if not st[0].py_is_final()]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        for state, doc in zip(all_states, docs):
 | 
					 | 
				
			||||||
            self.moves.finalize_state(state.c)
 | 
					 | 
				
			||||||
            for i in range(doc.length):
 | 
					 | 
				
			||||||
                doc.c[i] = state.c._sent[i]
 | 
					 | 
				
			||||||
            self.moves.finalize_doc(doc)
 | 
					            self.moves.finalize_doc(doc)
 | 
				
			||||||
 | 
					            yield doc
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def update(self, docs_tokvecs, golds, drop=0., sgd=None):
 | 
					    cdef int parseC(self, TokenC* tokens, int length, int nr_feat) nogil:
 | 
				
			||||||
        cdef:
 | 
					        state = new StateC(tokens, length)
 | 
				
			||||||
            int nC
 | 
					        # NB: This can change self.moves.n_moves!
 | 
				
			||||||
            Doc doc
 | 
					        # I think this causes memory errors if called by .pipe()
 | 
				
			||||||
            StateClass state
 | 
					        self.moves.initialize_state(state)
 | 
				
			||||||
            np.ndarray scores
 | 
					        nr_class = self.moves.n_moves
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        docs, tokvecs = docs_tokvecs
 | 
					        cdef ExampleC eg
 | 
				
			||||||
        cuda_stream = get_cuda_stream()
 | 
					        eg.nr_feat = nr_feat
 | 
				
			||||||
        lower_model = get_greedy_model_for_batch(len(docs),
 | 
					        eg.nr_atom = CONTEXT_SIZE
 | 
				
			||||||
                        tokvecs, self.feature_maps, cuda_stream=cuda_stream,
 | 
					        eg.nr_class = nr_class
 | 
				
			||||||
                        drop=drop)
 | 
					        eg.features = <FeatureC*>calloc(sizeof(FeatureC), nr_feat)
 | 
				
			||||||
        if isinstance(docs, Doc) and isinstance(golds, GoldParse):
 | 
					        eg.atoms = <atom_t*>calloc(sizeof(atom_t), CONTEXT_SIZE)
 | 
				
			||||||
            return self.update(([docs], tokvecs), [golds], drop=drop)
 | 
					        eg.scores = <weight_t*>calloc(sizeof(weight_t), nr_class)
 | 
				
			||||||
        for gold in golds:
 | 
					        eg.is_valid = <int*>calloc(sizeof(int), nr_class)
 | 
				
			||||||
            self.moves.preprocess_gold(gold)
 | 
					        cdef int i
 | 
				
			||||||
 | 
					        while not state.is_final():
 | 
				
			||||||
 | 
					            eg.nr_feat = self.model.set_featuresC(eg.atoms, eg.features, state)
 | 
				
			||||||
 | 
					            self.moves.set_valid(eg.is_valid, state)
 | 
				
			||||||
 | 
					            self.model.set_scoresC(eg.scores, eg.features, eg.nr_feat)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        states, offsets = init_states(self.moves, docs)
 | 
					            guess = VecVec.arg_max_if_true(eg.scores, eg.is_valid, eg.nr_class)
 | 
				
			||||||
 | 
					            if guess < 0:
 | 
				
			||||||
 | 
					                return 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        todo = zip(states, offsets, golds)
 | 
					            action = self.moves.c[guess]
 | 
				
			||||||
        todo = filter(lambda sp: not sp[0].py_is_final(), todo)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            action.do(state, action.label)
 | 
				
			||||||
 | 
					            memset(eg.scores, 0, sizeof(eg.scores[0]) * eg.nr_class)
 | 
				
			||||||
 | 
					            for i in range(eg.nr_class):
 | 
				
			||||||
 | 
					                eg.is_valid[i] = 1
 | 
				
			||||||
 | 
					        self.moves.finalize_state(state)
 | 
				
			||||||
 | 
					        for i in range(length):
 | 
				
			||||||
 | 
					            tokens[i] = state._sent[i]
 | 
				
			||||||
 | 
					        del state
 | 
				
			||||||
 | 
					        free(eg.features)
 | 
				
			||||||
 | 
					        free(eg.atoms)
 | 
				
			||||||
 | 
					        free(eg.scores)
 | 
				
			||||||
 | 
					        free(eg.is_valid)
 | 
				
			||||||
 | 
					        return 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def update(self, Doc tokens, GoldParse gold, itn=0, double drop=0.0):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Update the statistical model.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Arguments:
 | 
				
			||||||
 | 
					            doc (Doc):
 | 
				
			||||||
 | 
					                The example document for the update.
 | 
				
			||||||
 | 
					            gold (GoldParse):
 | 
				
			||||||
 | 
					                The gold-standard annotations, to calculate the loss.
 | 
				
			||||||
 | 
					        Returns (float):
 | 
				
			||||||
 | 
					            The loss on this example.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        self.moves.preprocess_gold(gold)
 | 
				
			||||||
 | 
					        cdef StateClass stcls = StateClass.init(tokens.c, tokens.length)
 | 
				
			||||||
 | 
					        self.moves.initialize_state(stcls.c)
 | 
				
			||||||
        cdef Pool mem = Pool()
 | 
					        cdef Pool mem = Pool()
 | 
				
			||||||
        is_valid = <int*>mem.alloc(len(states) * self.moves.n_moves, sizeof(int))
 | 
					        cdef Example eg = Example(
 | 
				
			||||||
        costs = <float*>mem.alloc(len(states) * self.moves.n_moves, sizeof(float))
 | 
					                nr_class=self.moves.n_moves,
 | 
				
			||||||
 | 
					                nr_atom=CONTEXT_SIZE,
 | 
				
			||||||
 | 
					                nr_feat=self.model.nr_feat)
 | 
				
			||||||
 | 
					        cdef weight_t loss = 0
 | 
				
			||||||
 | 
					        cdef Transition action
 | 
				
			||||||
 | 
					        cdef double dropout_rate = self.cfg.get('dropout', drop)
 | 
				
			||||||
 | 
					        while not stcls.is_final():
 | 
				
			||||||
 | 
					            eg.c.nr_feat = self.model.set_featuresC(eg.c.atoms, eg.c.features,
 | 
				
			||||||
 | 
					                                                    stcls.c)
 | 
				
			||||||
 | 
					            dropout(eg.c.features, eg.c.nr_feat, dropout_rate)
 | 
				
			||||||
 | 
					            self.moves.set_costs(eg.c.is_valid, eg.c.costs, stcls, gold)
 | 
				
			||||||
 | 
					            self.model.set_scoresC(eg.c.scores, eg.c.features, eg.c.nr_feat)
 | 
				
			||||||
 | 
					            guess = VecVec.arg_max_if_true(eg.c.scores, eg.c.is_valid, eg.c.nr_class)
 | 
				
			||||||
 | 
					            self.model.update(eg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        upper_model = self.model
 | 
					            action = self.moves.c[guess]
 | 
				
			||||||
        d_tokens = self.feature_maps.ops.allocate(tokvecs.shape)
 | 
					            action.do(stcls.c, action.label)
 | 
				
			||||||
        backprops = []
 | 
					            loss += eg.costs[guess]
 | 
				
			||||||
        n_tokens = tokvecs.shape[0]
 | 
					            eg.fill_scores(0, eg.c.nr_class)
 | 
				
			||||||
        nF = self.feature_maps.nF
 | 
					            eg.fill_costs(0, eg.c.nr_class)
 | 
				
			||||||
        loss = 0.
 | 
					            eg.fill_is_valid(1, eg.c.nr_class)
 | 
				
			||||||
        total = 1e-4
 | 
					 | 
				
			||||||
        follow_gold = False
 | 
					 | 
				
			||||||
        cupy = self.feature_maps.ops.xp
 | 
					 | 
				
			||||||
        while len(todo) >= 4:
 | 
					 | 
				
			||||||
            states, offsets, golds = zip(*todo)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
            token_ids = extract_token_ids(states, offsets=offsets)
 | 
					        self.moves.finalize_state(stcls.c)
 | 
				
			||||||
            lower, bp_lower = lower_model(token_ids, drop=drop)
 | 
					        return loss
 | 
				
			||||||
            scores, bp_scores = upper_model.begin_update(lower, drop=drop)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            d_scores = get_batch_loss(self.moves, states, golds, scores)
 | 
					 | 
				
			||||||
            loss += numpy.abs(d_scores).sum()
 | 
					 | 
				
			||||||
            total += d_scores.shape[0]
 | 
					 | 
				
			||||||
            d_lower = bp_scores(d_scores, sgd=sgd)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if isinstance(tokvecs, cupy.ndarray):
 | 
					 | 
				
			||||||
                gpu_tok_ids = cupy.ndarray(token_ids.shape, dtype='i', order='C')
 | 
					 | 
				
			||||||
                gpu_d_lower = cupy.ndarray(d_lower.shape, dtype='f', order='C')
 | 
					 | 
				
			||||||
                gpu_tok_ids.set(token_ids, stream=cuda_stream)
 | 
					 | 
				
			||||||
                gpu_d_lower.set(d_lower, stream=cuda_stream)
 | 
					 | 
				
			||||||
                backprops.append((gpu_tok_ids, gpu_d_lower, bp_lower))
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                backprops.append((token_ids, d_lower, bp_lower))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            c_scores = <float*>scores.data
 | 
					 | 
				
			||||||
            for state, gold in zip(states, golds):
 | 
					 | 
				
			||||||
                if follow_gold:
 | 
					 | 
				
			||||||
                    self.moves.set_costs(is_valid, costs, state, gold)
 | 
					 | 
				
			||||||
                    guess = arg_max_if_gold(c_scores, costs, is_valid, scores.shape[1])
 | 
					 | 
				
			||||||
                else:
 | 
					 | 
				
			||||||
                    self.moves.set_valid(is_valid, state.c)
 | 
					 | 
				
			||||||
                    guess = arg_max_if_valid(c_scores, is_valid, scores.shape[1])
 | 
					 | 
				
			||||||
                action = self.moves.c[guess]
 | 
					 | 
				
			||||||
                action.do(state.c, action.label)
 | 
					 | 
				
			||||||
                c_scores += scores.shape[1]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            todo = filter(lambda sp: not sp[0].py_is_final(), todo)
 | 
					 | 
				
			||||||
        # This tells CUDA to block --- so we know our copies are complete.
 | 
					 | 
				
			||||||
        cuda_stream.synchronize()
 | 
					 | 
				
			||||||
        for token_ids, d_lower, bp_lower in backprops:
 | 
					 | 
				
			||||||
            d_state_features = bp_lower(d_lower, sgd=sgd)
 | 
					 | 
				
			||||||
            active_feats = token_ids * (token_ids >= 0)
 | 
					 | 
				
			||||||
            active_feats = active_feats.reshape((token_ids.shape[0], token_ids.shape[1], 1))
 | 
					 | 
				
			||||||
            if hasattr(self.feature_maps.ops.xp, 'scatter_add'):
 | 
					 | 
				
			||||||
                self.feature_maps.ops.xp.scatter_add(d_tokens,
 | 
					 | 
				
			||||||
                    token_ids, d_state_features * active_feats)
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                self.model.ops.xp.add.at(d_tokens,
 | 
					 | 
				
			||||||
                    token_ids, d_state_features * active_feats)
 | 
					 | 
				
			||||||
        return d_tokens, loss / total
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def step_through(self, Doc doc, GoldParse gold=None):
 | 
					    def step_through(self, Doc doc, GoldParse gold=None):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
| 
						 | 
					@ -528,6 +382,18 @@ cdef class Parser:
 | 
				
			||||||
                self.cfg.setdefault('extra_labels', []).append(label)
 | 
					                self.cfg.setdefault('extra_labels', []).append(label)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cdef int dropout(FeatureC* feats, int nr_feat, float prob) except -1:
 | 
				
			||||||
 | 
					    if prob <= 0 or prob >= 1.:
 | 
				
			||||||
 | 
					        return 0
 | 
				
			||||||
 | 
					    cdef double[::1] py_probs = numpy.random.uniform(0., 1., nr_feat)
 | 
				
			||||||
 | 
					    cdef double* probs = &py_probs[0]
 | 
				
			||||||
 | 
					    for i in range(nr_feat):
 | 
				
			||||||
 | 
					        if probs[i] >= prob:
 | 
				
			||||||
 | 
					            feats[i].value /= prob
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            feats[i].value = 0.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef class StepwiseState:
 | 
					cdef class StepwiseState:
 | 
				
			||||||
    cdef readonly StateClass stcls
 | 
					    cdef readonly StateClass stcls
 | 
				
			||||||
    cdef readonly Example eg
 | 
					    cdef readonly Example eg
 | 
				
			||||||
| 
						 | 
					@ -597,11 +463,11 @@ cdef class StepwiseState:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def predict(self):
 | 
					    def predict(self):
 | 
				
			||||||
        self.eg.reset()
 | 
					        self.eg.reset()
 | 
				
			||||||
        #self.eg.c.nr_feat = self.parser.model.set_featuresC(self.eg.c.atoms, self.eg.c.features,
 | 
					        self.eg.c.nr_feat = self.parser.model.set_featuresC(self.eg.c.atoms, self.eg.c.features,
 | 
				
			||||||
        #                                                    self.stcls.c)
 | 
					                                                            self.stcls.c)
 | 
				
			||||||
        self.parser.moves.set_valid(self.eg.c.is_valid, self.stcls.c)
 | 
					        self.parser.moves.set_valid(self.eg.c.is_valid, self.stcls.c)
 | 
				
			||||||
        #self.parser.model.set_scoresC(self.eg.c.scores,
 | 
					        self.parser.model.set_scoresC(self.eg.c.scores,
 | 
				
			||||||
        #    self.eg.c.features, self.eg.c.nr_feat)
 | 
					            self.eg.c.features, self.eg.c.nr_feat)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        cdef Transition action = self.parser.moves.c[self.eg.guess]
 | 
					        cdef Transition action = self.parser.moves.c[self.eg.guess]
 | 
				
			||||||
        return self.parser.moves.move_name(action.move, action.label)
 | 
					        return self.parser.moves.move_name(action.move, action.label)
 | 
				
			||||||
| 
						 | 
					@ -640,26 +506,10 @@ class ParserStateError(ValueError):
 | 
				
			||||||
            "Please include the text that the parser failed on, which is:\n"
 | 
					            "Please include the text that the parser failed on, which is:\n"
 | 
				
			||||||
            "%s" % repr(doc.text))
 | 
					            "%s" % repr(doc.text))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cdef int arg_max_if_gold(const weight_t* scores, const weight_t* costs, int n) nogil:
 | 
				
			||||||
cdef int arg_max_if_gold(const weight_t* scores, const weight_t* costs, const int* is_valid, int n) nogil:
 | 
					 | 
				
			||||||
    # Find minimum cost
 | 
					 | 
				
			||||||
    cdef float cost = 1
 | 
					 | 
				
			||||||
    for i in range(n):
 | 
					 | 
				
			||||||
        if is_valid[i] and costs[i] < cost:
 | 
					 | 
				
			||||||
            cost = costs[i]
 | 
					 | 
				
			||||||
    # Now find best-scoring with that cost
 | 
					 | 
				
			||||||
    cdef int best = -1
 | 
					    cdef int best = -1
 | 
				
			||||||
    for i in range(n):
 | 
					    for i in range(n):
 | 
				
			||||||
        if costs[i] <= cost and is_valid[i]:
 | 
					        if costs[i] <= 0:
 | 
				
			||||||
            if best == -1 or scores[i] > scores[best]:
 | 
					 | 
				
			||||||
                best = i
 | 
					 | 
				
			||||||
    return best
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) nogil:
 | 
					 | 
				
			||||||
    cdef int best = -1
 | 
					 | 
				
			||||||
    for i in range(n):
 | 
					 | 
				
			||||||
        if is_valid[i] >= 1:
 | 
					 | 
				
			||||||
            if best == -1 or scores[i] > scores[best]:
 | 
					            if best == -1 or scores[i] > scores[best]:
 | 
				
			||||||
                best = i
 | 
					                best = i
 | 
				
			||||||
    return best
 | 
					    return best
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user