mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-25 13:11:03 +03:00 
			
		
		
		
	Improve integration of NN parser, to support unified training API
This commit is contained in:
		
							parent
							
								
									48de4ed49f
								
							
						
					
					
						commit
						a9edb3aa1d
					
				
							
								
								
									
										28
									
								
								spacy/_ml.py
									
									
									
									
									
								
							
							
						
						
									
										28
									
								
								spacy/_ml.py
									
									
									
									
									
								
							|  | @ -118,6 +118,29 @@ class PrecomputableMaxouts(Model): | |||
|             return dXf | ||||
|         return Yfp, backward | ||||
| 
 | ||||
| def Tok2Vec(width, embed_size, preprocess=None): | ||||
|     cols = [LOWER, PREFIX, SUFFIX, SHAPE] | ||||
|     with Model.define_operators({'>>': chain, '|': concatenate, '**': clone, '+': add}): | ||||
|         lower = get_col(cols.index(LOWER))   >> HashEmbed(width, embed_size) | ||||
|         prefix = get_col(cols.index(PREFIX)) >> HashEmbed(width, embed_size//2) | ||||
|         suffix = get_col(cols.index(SUFFIX)) >> HashEmbed(width, embed_size//2) | ||||
|         shape = get_col(cols.index(SHAPE))   >> HashEmbed(width, embed_size//2) | ||||
| 
 | ||||
|         tok2vec = ( | ||||
|             flatten | ||||
|             >> (lower | prefix | suffix | shape ) | ||||
|             >> Maxout(width, width*4, pieces=3) | ||||
|             >> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3)) | ||||
|             >> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3)) | ||||
|             >> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3)) | ||||
|             >> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3)) | ||||
|         ) | ||||
|         if preprocess is not None: | ||||
|             tok2vec = preprocess >> tok2vec | ||||
|         # Work around thinc API limitations :(. TODO: Revise in Thinc 7 | ||||
|         tok2vec.nO = width | ||||
|     return tok2vec | ||||
| 
 | ||||
| 
 | ||||
| def get_col(idx): | ||||
|     def forward(X, drop=0.): | ||||
|  | @ -125,7 +148,6 @@ def get_col(idx): | |||
|             ops = NumpyOps() | ||||
|         else: | ||||
|             ops = CupyOps() | ||||
|         assert len(X.shape) <= 3 | ||||
|         output = ops.xp.ascontiguousarray(X[:, idx]) | ||||
|         def backward(y, sgd=None): | ||||
|             dX = ops.allocate(X.shape) | ||||
|  | @ -171,8 +193,10 @@ def get_token_vectors(tokens_attrs_vectors, drop=0.): | |||
| def flatten(seqs, drop=0.): | ||||
|     if isinstance(seqs[0], numpy.ndarray): | ||||
|         ops = NumpyOps() | ||||
|     else: | ||||
|     elif hasattr(CupyOps.xp, 'ndarray') and isinstance(seqs[0], CupyOps.xp.ndarray): | ||||
|         ops = CupyOps() | ||||
|     else: | ||||
|         raise ValueError("Unable to flatten sequence of type %s" % type(seqs[0])) | ||||
|     lengths = [len(seq) for seq in seqs] | ||||
|     def finish_update(d_X, sgd=None): | ||||
|         return ops.unflatten(d_X, lengths) | ||||
|  |  | |||
|  | @ -64,10 +64,15 @@ def train_model(Language, train_data, dev_data, output_path, tagger_cfg, parser_ | |||
| 
 | ||||
|     with Language.train(output_path, train_data, | ||||
|                         pos=tagger_cfg, deps=parser_cfg, ner=entity_cfg) as trainer: | ||||
| 
 | ||||
|         for itn, epoch in enumerate(trainer.epochs(n_iter, augment_data=None)): | ||||
|             for doc, gold in epoch: | ||||
|                 trainer.update(doc, gold) | ||||
|             dev_scores = trainer.evaluate(dev_data).scores if dev_data else defaultdict(float) | ||||
|             for docs, golds in partition_all(12, epoch): | ||||
|                 trainer.update(docs, golds) | ||||
| 
 | ||||
|             if dev_data: | ||||
|                 dev_scores = trainer.evaluate(dev_data).scores | ||||
|             else: | ||||
|                 defaultdict(float) | ||||
|             print_progress(itn, trainer.nlp.parser.model.nr_weight, | ||||
|                            trainer.nlp.parser.model.nr_active_feat, | ||||
|                            **dev_scores) | ||||
|  |  | |||
|  | @ -247,6 +247,7 @@ class Language(object): | |||
|         self.tokenizer = self.Defaults.create_tokenizer(self) \ | ||||
|                          if 'tokenizer' not in overrides \ | ||||
|                          else overrides['tokenizer'] | ||||
|   | ||||
|         self.tagger    = self.Defaults.create_tagger(self) \ | ||||
|                          if 'tagger' not in overrides \ | ||||
|                          else overrides['tagger'] | ||||
|  |  | |||
|  | @ -27,40 +27,26 @@ from thinc.neural._classes.resnet import Residual | |||
| from thinc.neural._classes.batchnorm import BatchNorm as BN | ||||
| 
 | ||||
| from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP | ||||
| from ._ml import flatten, get_col, doc2feats | ||||
| from ._ml import Tok2Vec, flatten, get_col, doc2feats | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| class TokenVectorEncoder(object): | ||||
|     '''Assign position-sensitive vectors to tokens, using a CNN or RNN.''' | ||||
|     def __init__(self, vocab, token_vector_width, **cfg): | ||||
| 
 | ||||
|     @classmethod | ||||
|     def Model(cls, width=128, embed_size=5000, **cfg): | ||||
|         return Tok2Vec(width, embed_size, preprocess=False) | ||||
| 
 | ||||
|     def __init__(self, vocab, model=True, **cfg): | ||||
|         self.vocab = vocab | ||||
|         self.doc2feats = doc2feats() | ||||
|         self.model = self.build_model(vocab.lang, token_vector_width, **cfg) | ||||
|         self.tagger = chain( | ||||
|                         self.model, | ||||
|                         Softmax(self.vocab.morphology.n_tags, | ||||
|                                 token_vector_width)) | ||||
| 
 | ||||
|     def build_model(self, lang, width, embed_size=5000, **cfg): | ||||
|         cols = self.doc2feats.cols | ||||
|         with Model.define_operators({'>>': chain, '|': concatenate, '**': clone, '+': add}): | ||||
|             lower = get_col(cols.index(LOWER))   >> (HashEmbed(width, embed_size) | ||||
|                                                      +HashEmbed(width, embed_size)) | ||||
|             prefix = get_col(cols.index(PREFIX)) >> HashEmbed(width, embed_size//2) | ||||
|             suffix = get_col(cols.index(SUFFIX)) >> HashEmbed(width, embed_size//2) | ||||
|             shape = get_col(cols.index(SHAPE))   >> HashEmbed(width, embed_size//2) | ||||
| 
 | ||||
|             tok2vec = ( | ||||
|                 flatten | ||||
|                 >> (lower | prefix | suffix | shape ) | ||||
|                 >> Maxout(width, pieces=3) | ||||
|                 >> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3)) | ||||
|                 >> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3)) | ||||
|                 >> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3)) | ||||
|                 >> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3)) | ||||
|             ) | ||||
|         return tok2vec | ||||
|         self.model = self.Model() if model is True else model | ||||
|         if self.model not in (None, False): | ||||
|             self.tagger = chain( | ||||
|                             self.model, | ||||
|                             Softmax(self.vocab.morphology.n_tags, | ||||
|                                     self.model.nO)) | ||||
| 
 | ||||
|     def pipe(self, docs): | ||||
|         docs = list(docs) | ||||
|  |  | |||
|  | @ -23,6 +23,7 @@ cdef cppclass StateC: | |||
|     Entity* _ents | ||||
|     TokenC _empty_token | ||||
|     int length | ||||
|     int offset | ||||
|     int _s_i | ||||
|     int _b_i | ||||
|     int _e_i | ||||
|  |  | |||
|  | @ -10,9 +10,8 @@ from ._state cimport StateC | |||
| 
 | ||||
| cdef class Parser: | ||||
|     cdef readonly Vocab vocab | ||||
|     cdef readonly object model | ||||
|     cdef public object model | ||||
|     cdef readonly TransitionSystem moves | ||||
|     cdef readonly object cfg | ||||
|     cdef public object feature_maps | ||||
| 
 | ||||
|     #cdef int parseC(self, TokenC* tokens, int length, int nr_feat) nogil | ||||
|  |  | |||
|  | @ -1,5 +1,7 @@ | |||
| # cython: infer_types=True | ||||
| # cython: profile=True | ||||
| # cython: cdivision=True | ||||
| # cython: boundscheck=False | ||||
| # coding: utf-8 | ||||
| from __future__ import unicode_literals, print_function | ||||
| 
 | ||||
|  | @ -30,11 +32,12 @@ from preshed.maps cimport MapStruct | |||
| from preshed.maps cimport map_get | ||||
| 
 | ||||
| from thinc.api import layerize, chain | ||||
| from thinc.neural import BatchNorm, Model, Affine, ELU, ReLu, Maxout | ||||
| from thinc.neural import Model, Affine, ELU, ReLu, Maxout | ||||
| from thinc.neural.ops import NumpyOps | ||||
| 
 | ||||
| from ..util import get_cuda_stream | ||||
| from ..util import get_async, get_cuda_stream | ||||
| from .._ml import zero_init, PrecomputableAffine, PrecomputableMaxouts | ||||
| from .._ml import Tok2Vec, doc2feats | ||||
| 
 | ||||
| from . import _parse_features | ||||
| from ._parse_features cimport CONTEXT_SIZE | ||||
|  | @ -61,8 +64,7 @@ def set_debug(val): | |||
|     DEBUG = val | ||||
| 
 | ||||
| 
 | ||||
| def get_greedy_model_for_batch(batch_size, tokvecs, lower_model, cuda_stream=None, | ||||
|                                drop=0.): | ||||
| cdef class precompute_hiddens: | ||||
|     '''Allow a model to be "primed" by pre-computing input features in bulk. | ||||
| 
 | ||||
|     This is used for the parser, where we want to take a batch of documents, | ||||
|  | @ -79,95 +81,88 @@ def get_greedy_model_for_batch(batch_size, tokvecs, lower_model, cuda_stream=Non | |||
|     we can do all our hard maths up front, packed into large multiplications, | ||||
|     and do the hard-to-program parsing on the CPU. | ||||
|     ''' | ||||
|     gpu_cached, bp_features = lower_model.begin_update(tokvecs, drop=drop) | ||||
|     cdef np.ndarray cached | ||||
|     if not isinstance(gpu_cached, numpy.ndarray): | ||||
|         cached = gpu_cached.get(stream=cuda_stream) | ||||
|     else: | ||||
|         cached = gpu_cached | ||||
|     nF = gpu_cached.shape[1] | ||||
|     nO = gpu_cached.shape[2] | ||||
|     nP = gpu_cached.shape[3] | ||||
|     ops = lower_model.ops | ||||
|     features = numpy.zeros((batch_size, nO, nP), dtype='f') | ||||
|     synchronized = False | ||||
|     cdef int nF, nO, nP | ||||
|     cdef bint _is_synchronized | ||||
|     cdef public object ops | ||||
|     cdef np.ndarray _features | ||||
|     cdef np.ndarray _cached | ||||
|     cdef object _cuda_stream | ||||
|     cdef object _bp_hiddens | ||||
| 
 | ||||
|     def forward(token_ids, drop=0.): | ||||
|         nonlocal synchronized | ||||
|         if not synchronized and cuda_stream is not None: | ||||
|             cuda_stream.synchronize() | ||||
|             synchronized = True | ||||
|         # This is tricky, but: | ||||
|     def __init__(self, batch_size, tokvecs, lower_model, cuda_stream=None, drop=0.): | ||||
|         gpu_cached, bp_features = lower_model.begin_update(tokvecs, drop=drop) | ||||
|         cdef np.ndarray cached | ||||
|         if not isinstance(gpu_cached, numpy.ndarray): | ||||
|             # Note the passing of cuda_stream here: it lets | ||||
|             # cupy make the copy asynchronously. | ||||
|             # We then have to block before first use. | ||||
|             cached = gpu_cached.get(stream=cuda_stream) | ||||
|         else: | ||||
|             cached = gpu_cached | ||||
|         self.nF = cached.shape[1] | ||||
|         self.nO = cached.shape[2] | ||||
|         self.nP = cached.shape[3] | ||||
|         self.ops = lower_model.ops | ||||
|         self._features = numpy.zeros((batch_size, self.nO, self.nP), dtype='f') | ||||
|         self._is_synchronized = False | ||||
|         self._cuda_stream = cuda_stream | ||||
|         self._cached = cached | ||||
|         self._bp_hiddens = bp_features | ||||
| 
 | ||||
|     def __call__(self, X): | ||||
|         return self.begin_update(X)[0] | ||||
| 
 | ||||
|     def begin_update(self, token_ids, drop=0.): | ||||
|         self._features.fill(0) | ||||
|         if not self._is_synchronized \ | ||||
|         and self._cuda_stream is not None: | ||||
|             self._cuda_stream.synchronize() | ||||
|             self._synchronized = True | ||||
|         # This is tricky, but (assuming GPU available); | ||||
|         # - Input to forward on CPU | ||||
|         # - Output from forward on CPU | ||||
|         # - Input to backward on GPU! | ||||
|         # - Output from backward on GPU | ||||
|         nonlocal features | ||||
|         features = features[:len(token_ids)] | ||||
|         features.fill(0) | ||||
|         cdef float[:, :, ::1] feats = features | ||||
|         cdef np.ndarray state_vector = self._features[:len(token_ids)] | ||||
|         cdef np.ndarray hiddens = self._cached | ||||
|         bp_hiddens = self._bp_hiddens | ||||
| 
 | ||||
|         cdef int[:, ::1] ids = token_ids | ||||
|         _sum_features(<float*>&feats[0,0,0], | ||||
|             <float*>cached.data, &ids[0,0], | ||||
|             token_ids.shape[0], nF, nO*nP) | ||||
|         self._sum_features(<float*>state_vector.data, | ||||
|             <float*>hiddens.data, &ids[0,0], | ||||
|             token_ids.shape[0], self.nF, self.nO*self.nP) | ||||
| 
 | ||||
|         if nP >= 2: | ||||
|             best, which = ops.maxout(features) | ||||
|         else: | ||||
|             best = features.reshape((features.shape[0], features.shape[1])) | ||||
|             which = None | ||||
|         output, bp_output = self._apply_nonlinearity(state_vector)  | ||||
| 
 | ||||
|         def backward(d_best, sgd=None): | ||||
|         def backward(d_output, sgd=None): | ||||
|             # This will usually be on GPU | ||||
|             if isinstance(d_best, numpy.ndarray): | ||||
|                 d_best = ops.xp.array(d_best) | ||||
|             if nP >= 2: | ||||
|                 d_features = ops.backprop_maxout(d_best, which, nP) | ||||
|             else: | ||||
|                 d_features = d_best.reshape((d_best.shape[0], d_best.shape[1], 1)) | ||||
|             d_tokens = bp_features((d_features, token_ids), sgd) | ||||
|             if isinstance(d_output, numpy.ndarray): | ||||
|                 d_output = self.ops.xp.array(d_output) | ||||
|             d_state_vector = bp_output(d_output, sgd) | ||||
|             d_tokens = bp_hiddens((d_state_vector, token_ids), sgd) | ||||
|             return d_tokens | ||||
|         return output, backward | ||||
| 
 | ||||
|         return best, backward | ||||
|     def _apply_nonlinearity(self, X): | ||||
|         if self.nP < 2: | ||||
|             return X.reshape(X.shape[:2]), lambda dX, sgd=None: dX.reshape(X.shape) | ||||
|         best, which = self.ops.maxout(X) | ||||
|         return best, lambda dX, sgd=None: self.ops.backprop_maxout(dX, which, self.nP) | ||||
| 
 | ||||
|     return forward | ||||
| 
 | ||||
| 
 | ||||
| cdef void _sum_features(float* output, | ||||
|         const float* cached, const int* token_ids, int B, int F, int O) nogil: | ||||
|     cdef int idx, b, f, i | ||||
|     cdef const float* feature | ||||
|     for b in range(B): | ||||
|         for f in range(F): | ||||
|             if token_ids[f] < 0: | ||||
|                 continue | ||||
|             idx = token_ids[f] * F * O + f*O | ||||
|             feature = &cached[idx] | ||||
|             for i in range(O): | ||||
|                 output[i] += feature[i] | ||||
|         output += O | ||||
|         token_ids += F | ||||
| 
 | ||||
| 
 | ||||
| def get_batch_loss(TransitionSystem moves, states, golds, float[:, ::1] scores): | ||||
|     cdef StateClass state | ||||
|     cdef GoldParse gold | ||||
|     cdef Pool mem = Pool() | ||||
|     cdef int i | ||||
|     is_valid = <int*>mem.alloc(moves.n_moves, sizeof(int)) | ||||
|     costs = <float*>mem.alloc(moves.n_moves, sizeof(float)) | ||||
|     cdef np.ndarray d_scores = numpy.zeros((len(states), moves.n_moves), dtype='f', | ||||
|                                            order='c') | ||||
|     c_d_scores = <float*>d_scores.data | ||||
|     for i, (state, gold) in enumerate(zip(states, golds)): | ||||
|         memset(is_valid, 0, moves.n_moves * sizeof(int)) | ||||
|         memset(costs, 0, moves.n_moves * sizeof(float)) | ||||
|         moves.set_costs(is_valid, costs, state, gold) | ||||
|         cpu_log_loss(c_d_scores, costs, is_valid, &scores[i, 0], d_scores.shape[1]) | ||||
|         #cpu_regression_loss(c_d_scores, | ||||
|         #    costs, is_valid, &scores[i, 0], d_scores.shape[1]) | ||||
|         c_d_scores += d_scores.shape[1] | ||||
|     return d_scores | ||||
|     cdef void _sum_features(self, float* output, | ||||
|             const float* cached, const int* token_ids, int B, int F, int O) nogil: | ||||
|         cdef int idx, b, f, i | ||||
|         cdef const float* feature | ||||
|         for b in range(B): | ||||
|             for f in range(F): | ||||
|                 if token_ids[f] < 0: | ||||
|                     continue | ||||
|                 idx = token_ids[f] * F * O + f*O | ||||
|                 feature = &cached[idx] | ||||
|                 for i in range(O): | ||||
|                     output[i] += feature[i] | ||||
|             output += O | ||||
|             token_ids += F | ||||
| 
 | ||||
| 
 | ||||
| cdef void cpu_log_loss(float* d_scores, | ||||
|  | @ -217,121 +212,62 @@ cdef void cpu_regression_loss(float* d_scores, | |||
|                 d_scores[i] = diff | ||||
| 
 | ||||
| 
 | ||||
| def init_states(TransitionSystem moves, docs): | ||||
|     cdef Doc doc | ||||
|     cdef StateClass state | ||||
|     offsets = [] | ||||
|     states = [] | ||||
|     offset = 0 | ||||
|     for i, doc in enumerate(docs): | ||||
|         state = StateClass.init(doc.c, doc.length) | ||||
|         moves.initialize_state(state.c) | ||||
|         states.append(state) | ||||
|         offsets.append(offset) | ||||
|         offset += len(doc) | ||||
|     return states, offsets | ||||
| 
 | ||||
| 
 | ||||
| def extract_token_ids(states, offsets=None, nF=1, nB=0, nS=2, nL=0, nR=0): | ||||
|     cdef StateClass state | ||||
|     cdef int n_tokens = states[0].nr_context_tokens(nF, nB, nS, nL, nR) | ||||
|     ids = numpy.zeros((len(states), n_tokens), dtype='i', order='c') | ||||
|     if offsets is None: | ||||
|         offsets = [0] * len(states) | ||||
|     for i, (state, offset) in enumerate(zip(states, offsets)): | ||||
|         state.set_context_tokens(ids[i], nF, nB, nS, nL, nR) | ||||
|         ids[i] += (ids[i] >= 0) * offset | ||||
|     return ids | ||||
| 
 | ||||
| 
 | ||||
| _n_iter = 0 | ||||
| @layerize | ||||
| def print_mean_variance(X, drop=0.): | ||||
|     global _n_iter | ||||
|     _n_iter += 1 | ||||
|     fwd_iter = _n_iter | ||||
|     means = X.mean(axis=0) | ||||
|     variance = X.var(axis=0) | ||||
|     print(fwd_iter, "M", ', '.join(('%.2f' % m) for m in means)) | ||||
|     print(fwd_iter, "V", ', '.join(('%.2f' % m) for m in variance)) | ||||
|     def backward(dX, sgd=None): | ||||
|         means = dX.mean(axis=0) | ||||
|         variance = dX.var(axis=0) | ||||
|         print(fwd_iter, "dM", ', '.join(('%.2f' % m) for m in means)) | ||||
|         print(fwd_iter, "dV", ', '.join(('%.2f' % m) for m in variance)) | ||||
|     return X, backward | ||||
| 
 | ||||
| 
 | ||||
| cdef class Parser: | ||||
|     """ | ||||
|     Base class of the DependencyParser and EntityRecognizer. | ||||
|     """ | ||||
|     @classmethod | ||||
|     def load(cls, path, Vocab vocab, TransitionSystem=None, require=False, **cfg): | ||||
|         """ | ||||
|         Load the statistical model from the supplied path. | ||||
|     def Model(cls, nr_class, tok2vec=None, hidden_width=128, **cfg): | ||||
|         if tok2vec is None: | ||||
|             tok2vec = Tok2Vec(hidden_width, 5000, preprocess=doc2feats()) | ||||
|         token_vector_width = tok2vec.nO | ||||
|         nr_context_tokens = StateClass.nr_context_tokens() | ||||
|         lower = PrecomputableMaxouts(hidden_width, | ||||
|                     nF=nr_context_tokens, | ||||
|                     nI=token_vector_width, | ||||
|                     pieces=cfg.get('maxout_pieces', 1)) | ||||
| 
 | ||||
|         Arguments: | ||||
|             path (Path): | ||||
|                 The path to load from. | ||||
|             vocab (Vocab): | ||||
|                 The vocabulary. Must be shared by the documents to be processed. | ||||
|             require (bool): | ||||
|                 Whether to raise an error if the files are not found. | ||||
|         Returns (Parser): | ||||
|             The newly constructed object. | ||||
|         """ | ||||
|         with (path / 'config.json').open() as file_: | ||||
|             cfg = ujson.load(file_) | ||||
|         self = cls(vocab, TransitionSystem=TransitionSystem, model=None, **cfg) | ||||
|         if (path / 'model').exists(): | ||||
|             self.model.load(str(path / 'model')) | ||||
|         elif require: | ||||
|             raise IOError( | ||||
|                 "Required file %s/model not found when loading" % str(path)) | ||||
|         return self | ||||
|         with Model.use_device('cpu'): | ||||
|             upper = chain( | ||||
|                         Maxout(hidden_width), | ||||
|                         zero_init(Affine(nr_class)) | ||||
|                     ) | ||||
|         # TODO: This is an unfortunate hack atm! | ||||
|         # Used to set input dimensions in network. | ||||
|         lower.begin_training(lower.ops.allocate((500, token_vector_width))) | ||||
|         upper.begin_training(upper.ops.allocate((500, hidden_width))) | ||||
|         return tok2vec, lower, upper | ||||
| 
 | ||||
|     def __init__(self, Vocab vocab, TransitionSystem=None, model=None, **cfg): | ||||
|     @classmethod | ||||
|     def Moves(cls): | ||||
|         return TransitionSystem() | ||||
| 
 | ||||
|     def __init__(self, Vocab vocab, moves=True, model=True, **cfg): | ||||
|         """ | ||||
|         Create a Parser. | ||||
| 
 | ||||
|         Arguments: | ||||
|             vocab (Vocab): | ||||
|                 The vocabulary object. Must be shared with documents to be processed. | ||||
|             model (thinc Model): | ||||
|                 The statistical model. | ||||
|         Returns (Parser): | ||||
|             The newly constructed object. | ||||
|                 The value is set to the .vocab attribute. | ||||
|             moves (TransitionSystem): | ||||
|                 Defines how the parse-state is created, updated and evaluated. | ||||
|                 The value is set to the .moves attribute unless True (default), | ||||
|                 in which case a new instance is created with Parser.Moves(). | ||||
|             model (object): | ||||
|                 Defines how the parse-state is created, updated and evaluated. | ||||
|                 The value is set to the .model attribute unless True (default), | ||||
|                 in which case a new instance is created with Parser.Model(). | ||||
|             **cfg: | ||||
|                 Arbitrary configuration parameters. Set to the .cfg attribute | ||||
|         """ | ||||
|         if TransitionSystem is None: | ||||
|             TransitionSystem = self.TransitionSystem | ||||
|         self.vocab = vocab | ||||
|         cfg['actions'] = TransitionSystem.get_actions(**cfg) | ||||
|         self.moves = TransitionSystem(vocab.strings, cfg['actions']) | ||||
|         if model is None: | ||||
|             self.model, self.feature_maps = self.build_model(**cfg) | ||||
|         else: | ||||
|             self.model, self.feature_maps = model | ||||
|         self.moves = self.Moves(self.vocab) if moves is True else moves | ||||
|         self.model = self.Model(self.moves.n_moves) if model is True else model | ||||
|         self.cfg = cfg | ||||
| 
 | ||||
|     def __reduce__(self): | ||||
|         return (Parser, (self.vocab, self.moves, self.model), None, None) | ||||
| 
 | ||||
|     def build_model(self, | ||||
|             hidden_width=128, token_vector_width=96, nr_vector=1000, | ||||
|             nF=1, nB=1, nS=1, nL=1, nR=1, **cfg): | ||||
|         nr_context_tokens = StateClass.nr_context_tokens(nF, nB, nS, nL, nR) | ||||
|         with Model.use_device('cpu'): | ||||
|             upper = chain( | ||||
|                         Maxout(hidden_width, hidden_width), | ||||
|                         #print_mean_variance, | ||||
|                         zero_init(Affine(self.moves.n_moves, hidden_width))) | ||||
|         assert isinstance(upper.ops, NumpyOps) | ||||
|         lower = PrecomputableMaxouts(hidden_width, nF=nr_context_tokens, nI=token_vector_width, | ||||
|                                      pieces=cfg.get('maxout_pieces', 1)) | ||||
|         lower.begin_training(lower.ops.allocate((500, token_vector_width))) | ||||
|         upper.begin_training(upper.ops.allocate((500, hidden_width))) | ||||
|         return upper, lower | ||||
|         return (Parser, (self.vocab, self.moves, self.model, self.cfg), None, None) | ||||
| 
 | ||||
|     def __call__(self, Doc tokens): | ||||
|         """ | ||||
|  | @ -356,168 +292,145 @@ cdef class Parser: | |||
|                 The number of threads with which to work on the buffer in parallel. | ||||
|         Yields (Doc): Documents, in order. | ||||
|         """ | ||||
|         cdef StateClass state | ||||
|         cdef Doc doc | ||||
|         queue = [] | ||||
|         for doc in stream: | ||||
|             queue.append(doc) | ||||
|             if len(queue) == batch_size: | ||||
|                 self.parse_batch(queue) | ||||
|                 for doc in queue: | ||||
|                     self.moves.finalize_doc(doc) | ||||
|                     yield doc | ||||
|                 queue = [] | ||||
|         if queue: | ||||
|             self.parse_batch(queue) | ||||
|             for doc in queue: | ||||
|         for docs in cytoolz.partition_all(batch_size, stream): | ||||
|             docs = list(docs) | ||||
|             states = self.parse_batch(docs) | ||||
|             for state, doc in zip(states, docs): | ||||
|                 self.moves.finalize_state(state.c) | ||||
|                 for i in range(doc.length): | ||||
|                     doc.c[i] = state.c._sent[i] | ||||
|                 self.moves.finalize_doc(doc) | ||||
|                 yield doc | ||||
| 
 | ||||
|     def parse_batch(self, docs_tokvecs): | ||||
|         cdef: | ||||
|             int nC | ||||
|             Doc doc | ||||
|             StateClass state | ||||
|             np.ndarray py_scores | ||||
|             int[500] is_valid # Hacks for now | ||||
|     def parse_batch(self, docs): | ||||
|         cuda_stream = get_cuda_stream() | ||||
| 
 | ||||
|         tokvecs = self.model[0](docs) | ||||
|         states = self.moves.init_batch(docs) | ||||
|         state2vec, vec2scores = self.get_batch_model(len(states), tokvecs, | ||||
|                                                      cuda_stream, 0.0) | ||||
| 
 | ||||
|         todo = [st for st in states if not st.is_final()] | ||||
|         while todo: | ||||
|             token_ids = self.get_token_ids(states) | ||||
|             vectors = state2vec(token_ids) | ||||
|             scores = vec2scores(vectors) | ||||
|             self.transition_batch(states, scores) | ||||
|             todo = [st for st in states if not st.is_final()] | ||||
|         self.finish_batch(states, docs) | ||||
| 
 | ||||
|     def update(self, docs, golds, drop=0., sgd=None): | ||||
|         if isinstance(docs, Doc) and isinstance(golds, GoldParse): | ||||
|             return self.update([docs], [golds], drop=drop, sgd=sgd) | ||||
| 
 | ||||
|         cuda_stream = get_cuda_stream() | ||||
|         docs, tokvecs = docs_tokvecs | ||||
|         lower_model = get_greedy_model_for_batch(len(docs), tokvecs, self.feature_maps, | ||||
|                                                  cuda_stream) | ||||
|         upper_model = self.model | ||||
|         for gold in golds: | ||||
|             self.moves.preprocess_gold(gold) | ||||
| 
 | ||||
|         states, offsets = init_states(self.moves, docs) | ||||
|         all_states = list(states) | ||||
|         todo = [st for st in zip(states, offsets) if not st[0].py_is_final()] | ||||
|         tokvecs, bp_tokvecs = self.model[0].begin_update(docs, drop=drop) | ||||
|         states = self.moves.init_batch(docs) | ||||
|         state2vec, vec2scores = self.get_batch_model(len(states), tokvecs, cuda_stream, | ||||
|                                                       drop) | ||||
| 
 | ||||
|         todo = [(s, g) for s, g in zip(states, golds) if not s.is_final()] | ||||
| 
 | ||||
|         backprops = [] | ||||
|         cdef float loss = 0. | ||||
|         while todo: | ||||
|             states, offsets = zip(*todo) | ||||
|             token_ids = extract_token_ids(states, offsets=offsets) | ||||
|             states, golds = zip(*todo) | ||||
| 
 | ||||
|             py_scores = upper_model(lower_model(token_ids)[0]) | ||||
|             scores = <float*>py_scores.data | ||||
|             nC = py_scores.shape[1] | ||||
|             for state, offset in zip(states, offsets): | ||||
|                 self.moves.set_valid(is_valid, state.c) | ||||
|                 guess = arg_max_if_valid(scores, is_valid, nC) | ||||
|                 action = self.moves.c[guess] | ||||
|                 action.do(state.c, action.label) | ||||
|                 scores += nC | ||||
|             todo = [st for st in todo if not st[0].py_is_final()] | ||||
|             token_ids = self.get_token_ids(states) | ||||
|             vector, bp_vector = state2vec.begin_update(token_ids, drop=drop) | ||||
|             scores, bp_scores = vec2scores.begin_update(vector, drop=drop) | ||||
| 
 | ||||
|         for state, doc in zip(all_states, docs): | ||||
|             d_scores = self.get_batch_loss(states, golds, scores) | ||||
|             d_vector = bp_scores(d_scores, sgd=sgd) | ||||
|             loss += (d_scores**2).sum() | ||||
| 
 | ||||
|             if not isinstance(tokvecs, state2vec.ops.xp.ndarray): | ||||
|                 backprops.append((token_ids, d_vector, bp_vector)) | ||||
|             else: | ||||
|                 # Move token_ids and d_vector to CPU, asynchronously | ||||
|                 backprops.append(( | ||||
|                     get_async(cuda_stream, token_ids), | ||||
|                     get_async(cuda_stream, d_vector), | ||||
|                     bp_vector | ||||
|                 )) | ||||
|             self.transition_batch(states, scores) | ||||
|             todo = [st for st in todo if not st[0].is_final()] | ||||
|         # Tells CUDA to block, so our async copies complete. | ||||
|         if cuda_stream is not None: | ||||
|             cuda_stream.synchronize() | ||||
|         d_tokvecs = state2vec.ops.allocate(tokvecs.shape) | ||||
|         xp = state2vec.ops.xp # Handle for numpy/cupy | ||||
|         for token_ids, d_vector, bp_vector in backprops: | ||||
|             d_state_features = bp_vector(d_vector, sgd=sgd) | ||||
|             active_feats = token_ids * (token_ids >= 0) | ||||
|             active_feats = active_feats.reshape((token_ids.shape[0], token_ids.shape[1], 1)) | ||||
|             if hasattr(xp, 'scatter_add'): | ||||
|                 xp.scatter_add(d_tokvecs, | ||||
|                     token_ids, d_state_features * active_feats) | ||||
|             else: | ||||
|                 xp.add.at(d_tokvecs, | ||||
|                     token_ids, d_state_features * active_feats) | ||||
|         bp_tokvecs(d_tokvecs, sgd) | ||||
|         return loss | ||||
| 
 | ||||
|     def get_batch_model(self, batch_size, tokvecs, stream, dropout): | ||||
|         state2vec = precompute_hiddens(batch_size, tokvecs, | ||||
|                         self.model[1], stream, drop=dropout) | ||||
|         return state2vec, self.model[-1] | ||||
| 
 | ||||
|     def get_token_ids(self, states): | ||||
|         cdef StateClass state | ||||
|         cdef int n_tokens = states[0].nr_context_tokens() | ||||
|         ids = numpy.zeros((len(states), n_tokens), dtype='i', order='c') | ||||
|         for i, state in enumerate(states): | ||||
|             state.set_context_tokens(ids[i]) | ||||
|         return ids | ||||
| 
 | ||||
|     def transition_batch(self, states, float[:, ::1] scores): | ||||
|         cdef StateClass state | ||||
|         cdef int[500] is_valid # TODO: Unhack | ||||
|         cdef float* c_scores = &scores[0, 0] | ||||
|         for state in states: | ||||
|             self.moves.set_valid(is_valid, state.c) | ||||
|             guess = arg_max_if_valid(c_scores, is_valid, scores.shape[1]) | ||||
|             action = self.moves.c[guess] | ||||
|             action.do(state.c, action.label) | ||||
|             c_scores += scores.shape[1] | ||||
| 
 | ||||
|     def get_batch_loss(self, states, golds, float[:, ::1] scores): | ||||
|         cdef StateClass state | ||||
|         cdef GoldParse gold | ||||
|         cdef Pool mem = Pool() | ||||
|         cdef int i | ||||
|         is_valid = <int*>mem.alloc(self.moves.n_moves, sizeof(int)) | ||||
|         costs = <float*>mem.alloc(self.moves.n_moves, sizeof(float)) | ||||
|         cdef np.ndarray d_scores = numpy.zeros((len(states), self.moves.n_moves), | ||||
|                                         dtype='f', order='C') | ||||
|         c_d_scores = <float*>d_scores.data | ||||
|         for i, (state, gold) in enumerate(zip(states, golds)): | ||||
|             memset(is_valid, 0, self.moves.n_moves * sizeof(int)) | ||||
|             memset(costs, 0, self.moves.n_moves * sizeof(float)) | ||||
|             self.moves.set_costs(is_valid, costs, state, gold) | ||||
|             cpu_log_loss(c_d_scores, | ||||
|                 costs, is_valid, &scores[i, 0], d_scores.shape[1]) | ||||
|             c_d_scores += d_scores.shape[1] | ||||
|         return d_scores | ||||
| 
 | ||||
|     def finish_batch(self, states, docs): | ||||
|         cdef StateClass state | ||||
|         cdef Doc doc | ||||
|         for state, doc in zip(states, docs): | ||||
|             self.moves.finalize_state(state.c) | ||||
|             for i in range(doc.length): | ||||
|                 doc.c[i] = state.c._sent[i] | ||||
|             self.moves.finalize_doc(doc) | ||||
| 
 | ||||
|     def update(self, docs_tokvecs, golds, drop=0., sgd=None): | ||||
|         cdef: | ||||
|             int nC | ||||
|             Doc doc | ||||
|             StateClass state | ||||
|             np.ndarray scores | ||||
| 
 | ||||
|         docs, tokvecs = docs_tokvecs | ||||
|         cuda_stream = get_cuda_stream() | ||||
|         lower_model = get_greedy_model_for_batch(len(docs), | ||||
|                         tokvecs, self.feature_maps, cuda_stream=cuda_stream, | ||||
|                         drop=drop) | ||||
|         if isinstance(docs, Doc) and isinstance(golds, GoldParse): | ||||
|             return self.update(([docs], tokvecs), [golds], drop=drop) | ||||
|         for gold in golds: | ||||
|             self.moves.preprocess_gold(gold) | ||||
| 
 | ||||
|         states, offsets = init_states(self.moves, docs) | ||||
| 
 | ||||
|         todo = zip(states, offsets, golds) | ||||
|         todo = filter(lambda sp: not sp[0].py_is_final(), todo) | ||||
| 
 | ||||
|         cdef Pool mem = Pool() | ||||
|         is_valid = <int*>mem.alloc(len(states) * self.moves.n_moves, sizeof(int)) | ||||
|         costs = <float*>mem.alloc(len(states) * self.moves.n_moves, sizeof(float)) | ||||
| 
 | ||||
|         upper_model = self.model | ||||
|         d_tokens = self.feature_maps.ops.allocate(tokvecs.shape) | ||||
|         backprops = [] | ||||
|         n_tokens = tokvecs.shape[0] | ||||
|         nF = self.feature_maps.nF | ||||
|         loss = 0. | ||||
|         total = 1e-4 | ||||
|         follow_gold = False | ||||
|         cupy = self.feature_maps.ops.xp | ||||
|         while len(todo) >= 4: | ||||
|             states, offsets, golds = zip(*todo) | ||||
| 
 | ||||
|             token_ids = extract_token_ids(states, offsets=offsets) | ||||
|             lower, bp_lower = lower_model(token_ids, drop=drop) | ||||
|             scores, bp_scores = upper_model.begin_update(lower, drop=drop) | ||||
| 
 | ||||
|             d_scores = get_batch_loss(self.moves, states, golds, scores) | ||||
|             loss += numpy.abs(d_scores).sum() | ||||
|             total += d_scores.shape[0] | ||||
|             d_lower = bp_scores(d_scores, sgd=sgd) | ||||
| 
 | ||||
|             if isinstance(tokvecs, cupy.ndarray): | ||||
|                 gpu_tok_ids = cupy.ndarray(token_ids.shape, dtype='i', order='C') | ||||
|                 gpu_d_lower = cupy.ndarray(d_lower.shape, dtype='f', order='C') | ||||
|                 gpu_tok_ids.set(token_ids, stream=cuda_stream) | ||||
|                 gpu_d_lower.set(d_lower, stream=cuda_stream) | ||||
|                 backprops.append((gpu_tok_ids, gpu_d_lower, bp_lower)) | ||||
|             else: | ||||
|                 backprops.append((token_ids, d_lower, bp_lower)) | ||||
| 
 | ||||
|             c_scores = <float*>scores.data | ||||
|             for state, gold in zip(states, golds): | ||||
|                 if follow_gold: | ||||
|                     self.moves.set_costs(is_valid, costs, state, gold) | ||||
|                     guess = arg_max_if_gold(c_scores, costs, is_valid, scores.shape[1]) | ||||
|                 else: | ||||
|                     self.moves.set_valid(is_valid, state.c) | ||||
|                     guess = arg_max_if_valid(c_scores, is_valid, scores.shape[1]) | ||||
|                 action = self.moves.c[guess] | ||||
|                 action.do(state.c, action.label) | ||||
|                 c_scores += scores.shape[1] | ||||
| 
 | ||||
|             todo = filter(lambda sp: not sp[0].py_is_final(), todo) | ||||
|         # This tells CUDA to block --- so we know our copies are complete. | ||||
|         cuda_stream.synchronize() | ||||
|         for token_ids, d_lower, bp_lower in backprops: | ||||
|             d_state_features = bp_lower(d_lower, sgd=sgd) | ||||
|             active_feats = token_ids * (token_ids >= 0) | ||||
|             active_feats = active_feats.reshape((token_ids.shape[0], token_ids.shape[1], 1)) | ||||
|             if hasattr(self.feature_maps.ops.xp, 'scatter_add'): | ||||
|                 self.feature_maps.ops.xp.scatter_add(d_tokens, | ||||
|                     token_ids, d_state_features * active_feats) | ||||
|             else: | ||||
|                 self.model.ops.xp.add.at(d_tokens, | ||||
|                     token_ids, d_state_features * active_feats) | ||||
|         return d_tokens, loss / total | ||||
| 
 | ||||
|     def step_through(self, Doc doc, GoldParse gold=None): | ||||
|         """ | ||||
|         Set up a stepwise state, to introspect and control the transition sequence. | ||||
| 
 | ||||
|         Arguments: | ||||
|             doc (Doc): The document to step through. | ||||
|             gold (GoldParse): Optional gold parse | ||||
|         Returns (StepwiseState): | ||||
|             A state object, to step through the annotation process. | ||||
|         """ | ||||
|         return StepwiseState(self, doc, gold=gold) | ||||
| 
 | ||||
|     def from_transition_sequence(self, Doc doc, sequence): | ||||
|         """Control the annotations on a document by specifying a transition sequence | ||||
|         to follow. | ||||
| 
 | ||||
|         Arguments: | ||||
|             doc (Doc): The document to annotate. | ||||
|             sequence: A sequence of action names, as unicode strings. | ||||
|         Returns: None | ||||
|         """ | ||||
|         with self.step_through(doc) as stepwise: | ||||
|             for transition in sequence: | ||||
|                 stepwise.transition(transition) | ||||
| 
 | ||||
|     def add_label(self, label): | ||||
|         # Doesn't set label into serializer -- subclasses override it to do that. | ||||
|         for action in self.moves.action_types: | ||||
|  | @ -528,108 +441,6 @@ cdef class Parser: | |||
|                 self.cfg.setdefault('extra_labels', []).append(label) | ||||
| 
 | ||||
| 
 | ||||
| cdef class StepwiseState: | ||||
|     cdef readonly StateClass stcls | ||||
|     cdef readonly Example eg | ||||
|     cdef readonly Doc doc | ||||
|     cdef readonly GoldParse gold | ||||
|     cdef readonly Parser parser | ||||
| 
 | ||||
|     def __init__(self, Parser parser, Doc doc, GoldParse gold=None): | ||||
|         self.parser = parser | ||||
|         self.doc = doc | ||||
|         if gold is not None: | ||||
|             self.gold = gold | ||||
|             self.parser.moves.preprocess_gold(self.gold) | ||||
|         else: | ||||
|             self.gold = GoldParse(doc) | ||||
|         self.stcls = StateClass.init(doc.c, doc.length) | ||||
|         self.parser.moves.initialize_state(self.stcls.c) | ||||
|         self.eg = Example( | ||||
|             nr_class=self.parser.moves.n_moves, | ||||
|             nr_atom=CONTEXT_SIZE, | ||||
|             nr_feat=self.parser.model.nr_feat) | ||||
| 
 | ||||
|     def __enter__(self): | ||||
|         return self | ||||
| 
 | ||||
|     def __exit__(self, type, value, traceback): | ||||
|         self.finish() | ||||
| 
 | ||||
|     @property | ||||
|     def is_final(self): | ||||
|         return self.stcls.is_final() | ||||
| 
 | ||||
|     @property | ||||
|     def stack(self): | ||||
|         return self.stcls.stack | ||||
| 
 | ||||
|     @property | ||||
|     def queue(self): | ||||
|         return self.stcls.queue | ||||
| 
 | ||||
|     @property | ||||
|     def heads(self): | ||||
|         return [self.stcls.H(i) for i in range(self.stcls.c.length)] | ||||
| 
 | ||||
|     @property | ||||
|     def deps(self): | ||||
|         return [self.doc.vocab.strings[self.stcls.c._sent[i].dep] | ||||
|                 for i in range(self.stcls.c.length)] | ||||
| 
 | ||||
|     @property | ||||
|     def costs(self): | ||||
|         """ | ||||
|         Find the action-costs for the current state. | ||||
|         """ | ||||
|         if not self.gold: | ||||
|             raise ValueError("Can't set costs: No GoldParse provided") | ||||
|         self.parser.moves.set_costs(self.eg.c.is_valid, self.eg.c.costs, | ||||
|                 self.stcls, self.gold) | ||||
|         costs = {} | ||||
|         for i in range(self.parser.moves.n_moves): | ||||
|             if not self.eg.c.is_valid[i]: | ||||
|                 continue | ||||
|             transition = self.parser.moves.c[i] | ||||
|             name = self.parser.moves.move_name(transition.move, transition.label) | ||||
|             costs[name] = self.eg.c.costs[i] | ||||
|         return costs | ||||
| 
 | ||||
|     def predict(self): | ||||
|         self.eg.reset() | ||||
|         #self.eg.c.nr_feat = self.parser.model.set_featuresC(self.eg.c.atoms, self.eg.c.features, | ||||
|         #                                                    self.stcls.c) | ||||
|         self.parser.moves.set_valid(self.eg.c.is_valid, self.stcls.c) | ||||
|         #self.parser.model.set_scoresC(self.eg.c.scores, | ||||
|         #    self.eg.c.features, self.eg.c.nr_feat) | ||||
| 
 | ||||
|         cdef Transition action = self.parser.moves.c[self.eg.guess] | ||||
|         return self.parser.moves.move_name(action.move, action.label) | ||||
| 
 | ||||
|     def transition(self, action_name=None): | ||||
|         if action_name is None: | ||||
|             action_name = self.predict() | ||||
|         moves = {'S': 0, 'D': 1, 'L': 2, 'R': 3} | ||||
|         if action_name == '_': | ||||
|             action_name = self.predict() | ||||
|             action = self.parser.moves.lookup_transition(action_name) | ||||
|         elif action_name == 'L' or action_name == 'R': | ||||
|             self.predict() | ||||
|             move = moves[action_name] | ||||
|             clas = _arg_max_clas(self.eg.c.scores, move, self.parser.moves.c, | ||||
|                                  self.eg.c.nr_class) | ||||
|             action = self.parser.moves.c[clas] | ||||
|         else: | ||||
|             action = self.parser.moves.lookup_transition(action_name) | ||||
|         action.do(self.stcls.c, action.label) | ||||
| 
 | ||||
|     def finish(self): | ||||
|         if self.stcls.is_final(): | ||||
|             self.parser.moves.finalize_state(self.stcls.c) | ||||
|         self.doc.set_parse(self.stcls.c._sent) | ||||
|         self.parser.moves.finalize_doc(self.doc) | ||||
| 
 | ||||
| 
 | ||||
| class ParserStateError(ValueError): | ||||
|     def __init__(self, doc): | ||||
|         ValueError.__init__(self, | ||||
|  |  | |||
|  | @ -9,17 +9,24 @@ from ..vocab cimport EMPTY_LEXEME | |||
| from ._state cimport StateC | ||||
| 
 | ||||
| 
 | ||||
| @cython.final | ||||
| cdef class StateClass: | ||||
|     cdef Pool mem | ||||
|     cdef StateC* c | ||||
| 
 | ||||
|     @staticmethod | ||||
|     cdef inline StateClass init(const TokenC* sent, int length): | ||||
|         cdef StateClass self = StateClass(length) | ||||
|         cdef StateClass self = StateClass() | ||||
|         self.c = new StateC(sent, length) | ||||
|         return self | ||||
| 
 | ||||
|     @staticmethod | ||||
|     cdef inline StateClass init_offset(const TokenC* sent, int length, int | ||||
|                                        offset): | ||||
|         cdef StateClass self = StateClass() | ||||
|         self.c = new StateC(sent, length) | ||||
|         self.c.offset = offset | ||||
|         return self | ||||
| 
 | ||||
|     cdef inline int S(self, int i) nogil: | ||||
|         return self.c.S(i) | ||||
| 
 | ||||
|  | @ -68,9 +75,6 @@ cdef class StateClass: | |||
|     cdef inline bint at_break(self) nogil: | ||||
|         return self.c.at_break() | ||||
| 
 | ||||
|     cdef inline bint is_final(self) nogil: | ||||
|         return self.c.is_final() | ||||
| 
 | ||||
|     cdef inline bint has_head(self, int i) nogil: | ||||
|         return self.c.has_head(i) | ||||
| 
 | ||||
|  | @ -97,22 +101,22 @@ cdef class StateClass: | |||
| 
 | ||||
|     cdef inline void pop(self) nogil: | ||||
|         self.c.pop() | ||||
|      | ||||
| 
 | ||||
|     cdef inline void unshift(self) nogil: | ||||
|         self.c.unshift() | ||||
| 
 | ||||
|     cdef inline void add_arc(self, int head, int child, int label) nogil: | ||||
|         self.c.add_arc(head, child, label) | ||||
|      | ||||
| 
 | ||||
|     cdef inline void del_arc(self, int head, int child) nogil: | ||||
|         self.c.del_arc(head, child) | ||||
| 
 | ||||
|     cdef inline void open_ent(self, int label) nogil: | ||||
|         self.c.open_ent(label) | ||||
|      | ||||
| 
 | ||||
|     cdef inline void close_ent(self) nogil: | ||||
|         self.c.close_ent() | ||||
|      | ||||
| 
 | ||||
|     cdef inline void set_ent_tag(self, int i, int ent_iob, int ent_type) nogil: | ||||
|         self.c.set_ent_tag(i, ent_iob, ent_type) | ||||
| 
 | ||||
|  |  | |||
|  | @ -12,12 +12,16 @@ from ..symbols cimport punct | |||
| from ..attrs cimport IS_SPACE | ||||
| from ..attrs cimport attr_id_t | ||||
| from ..tokens.token cimport Token | ||||
| from ..tokens.doc cimport Doc | ||||
| 
 | ||||
| 
 | ||||
| cdef class StateClass: | ||||
|     def __init__(self, int length): | ||||
|     def __init__(self, Doc doc=None, int offset=0): | ||||
|         cdef Pool mem = Pool() | ||||
|         self.mem = mem | ||||
|         if doc is not None: | ||||
|             self.c = new StateC(doc.c, doc.length) | ||||
|             self.c.offset = offset | ||||
| 
 | ||||
|     def __dealloc__(self): | ||||
|         del self.c | ||||
|  | @ -34,7 +38,7 @@ cdef class StateClass: | |||
|     def token_vector_lenth(self): | ||||
|         return self.doc.tensor.shape[1] | ||||
| 
 | ||||
|     def py_is_final(self): | ||||
|     def is_final(self): | ||||
|         return self.c.is_final() | ||||
| 
 | ||||
|     def print_state(self, words): | ||||
|  | @ -47,11 +51,10 @@ cdef class StateClass: | |||
|         return ' '.join((third, second, top, '|', n0, n1)) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def nr_context_tokens(cls, int nF, int nB, int nS, int nL, int nR): | ||||
|     def nr_context_tokens(cls): | ||||
|         return 13 | ||||
| 
 | ||||
|     def set_context_tokens(self, int[:] output, nF=1, nB=0, nS=2, | ||||
|             nL=2, nR=2): | ||||
|     def set_context_tokens(self, int[::1] output): | ||||
|         output[0] = self.B(0) | ||||
|         output[1] = self.B(1) | ||||
|         output[2] = self.S(0) | ||||
|  | @ -67,21 +70,6 @@ cdef class StateClass: | |||
|         output[11] = self.R(self.S(1), 1) | ||||
|         output[12] = self.R(self.S(1), 2) | ||||
| 
 | ||||
|     def set_attributes(self, uint64_t[:, :] vals, int[:] tokens, int[:] names): | ||||
|         cdef int i, j, tok_i | ||||
|         for i in range(tokens.shape[0]): | ||||
|             tok_i = tokens[i] | ||||
|             if tok_i >= 0: | ||||
|                 token = &self.c._sent[tok_i] | ||||
|                 for j in range(names.shape[0]): | ||||
|                     vals[i, j] = Token.get_struct_attr(token, <attr_id_t>names[j]) | ||||
|             else: | ||||
|                 vals[i] = 0 | ||||
| 
 | ||||
|     def set_token_vectors(self, tokvecs, | ||||
|             all_tokvecs, int[:] indices): | ||||
|         for i in range(indices.shape[0]): | ||||
|             if indices[i] >= 0: | ||||
|                 tokvecs[i] = all_tokvecs[indices[i]] | ||||
|             else: | ||||
|                 tokvecs[i] = 0 | ||||
|         for i in range(13): | ||||
|             if output[i] != -1: | ||||
|                 output[i] += self.c.offset | ||||
|  |  | |||
|  | @ -58,6 +58,17 @@ cdef class TransitionSystem: | |||
|                 (self.strings, labels_by_action, self.freqs), | ||||
|                 None, None) | ||||
| 
 | ||||
|     def init_batch(self, docs): | ||||
|         cdef StateClass state | ||||
|         states = [] | ||||
|         offset = 0 | ||||
|         for doc in docs: | ||||
|             state = StateClass(doc, offset=offset) | ||||
|             self.initialize_state(state.c) | ||||
|             states.append(state) | ||||
|             offset += len(doc) | ||||
|         return states | ||||
| 
 | ||||
|     cdef int initialize_state(self, StateC* state) nogil: | ||||
|         pass | ||||
| 
 | ||||
|  |  | |||
							
								
								
									
										70
									
								
								spacy/tests/parser/test_neural_parser.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										70
									
								
								spacy/tests/parser/test_neural_parser.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,70 @@ | |||
| from thinc.neural import Model | ||||
| from mock import Mock | ||||
| import pytest | ||||
| import numpy | ||||
| 
 | ||||
| from ..._ml import chain, Tok2Vec, doc2feats | ||||
| from ...vocab import Vocab | ||||
| from ...pipeline import TokenVectorEncoder | ||||
| from ...syntax.arc_eager import ArcEager | ||||
| from ...syntax.nn_parser import Parser | ||||
| from ...tokens.doc import Doc | ||||
| from ...gold import GoldParse | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
| def vocab(): | ||||
|     return Vocab() | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
| def arc_eager(vocab): | ||||
|     actions = ArcEager.get_actions(left_labels=['L'], right_labels=['R']) | ||||
|     return ArcEager(vocab.strings, actions) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
| def tok2vec(): | ||||
|     return Tok2Vec(8, 100, preprocess=doc2feats()) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
| def parser(vocab, arc_eager): | ||||
|     return Parser(vocab, moves=arc_eager, model=None) | ||||
| 
 | ||||
| @pytest.fixture | ||||
| def model(arc_eager, tok2vec): | ||||
|     return Parser.Model(arc_eager.n_moves, tok2vec) | ||||
| 
 | ||||
| @pytest.fixture | ||||
| def doc(vocab): | ||||
|     return Doc(vocab, words=['a', 'b', 'c']) | ||||
| 
 | ||||
| @pytest.fixture | ||||
| def gold(doc): | ||||
|     return GoldParse(doc, heads=[1, 1, 1], deps=['L', 'ROOT', 'R']) | ||||
| def test_can_init_nn_parser(parser): | ||||
|     assert parser.model is None | ||||
| 
 | ||||
| 
 | ||||
| def test_build_model(parser, tok2vec): | ||||
|     parser.model = Parser.Model(parser.moves.n_moves, tok2vec) | ||||
|     assert parser.model is not None | ||||
| 
 | ||||
| 
 | ||||
| def test_predict_doc(parser, model, doc): | ||||
|     parser.model = model | ||||
|     parser(doc) | ||||
| 
 | ||||
| 
 | ||||
| def test_update_doc(parser, model, doc, gold): | ||||
|     parser.model = model | ||||
|     loss1 = parser.update(doc, gold) | ||||
|     assert loss1 > 0 | ||||
|     loss2 = parser.update(doc, gold) | ||||
|     assert loss2 == loss1 | ||||
|     def optimize(weights, gradient, key=None): | ||||
|         weights -= 0.001 * gradient | ||||
|     loss3 = parser.update(doc, gold, sgd=optimize) | ||||
|     loss4 = parser.update(doc, gold, sgd=optimize) | ||||
|     assert loss3 < loss2 | ||||
|  | @ -3,6 +3,10 @@ from __future__ import absolute_import, unicode_literals | |||
| 
 | ||||
| import random | ||||
| import tqdm | ||||
| 
 | ||||
| from thinc.neural.optimizers import Adam | ||||
| from thinc.neural.ops import NumpyOps, CupyOps | ||||
| 
 | ||||
| from .gold import GoldParse, merge_sents | ||||
| from .scorer import Scorer | ||||
| 
 | ||||
|  | @ -44,10 +48,12 @@ class Trainer(object): | |||
|             yield _epoch(indices) | ||||
|             self.nr_epoch += 1 | ||||
| 
 | ||||
|     def update(self, doc, gold): | ||||
|     def update(self, docs, golds, drop=0.): | ||||
|         for process in self.nlp.pipeline: | ||||
|             if hasattr(process, 'update'): | ||||
|                 loss = process.update(doc, gold, itn=self.nr_epoch) | ||||
|                 loss = process.update(doc, gold, sgd=self.sgd, drop=drop, | ||||
|                                       itn=self.nr_epoch) | ||||
|                 self.sgd.finish_update() | ||||
|             else: | ||||
|                 process(doc) | ||||
|         return doc | ||||
|  |  | |||
|  | @ -15,7 +15,15 @@ from .compat import path2str, basestring_, input_, unicode_ | |||
| 
 | ||||
| LANGUAGES = {} | ||||
| _data_path = Path(__file__).parent / 'data' | ||||
| try: | ||||
|     from cupy.cuda.stream import Stream as CudaStream | ||||
| except ImportError: | ||||
|     CudaStream = None | ||||
| 
 | ||||
| try: | ||||
|     import cupy | ||||
| except ImportError: | ||||
|     cupy = None | ||||
| 
 | ||||
| def set_lang_class(name, cls): | ||||
|     global LANGUAGES | ||||
|  | @ -152,11 +160,14 @@ def parse_package_meta(package_path, require=True): | |||
| def get_cuda_stream(require=False): | ||||
|     # TODO: Error and tell to install chainer if not found | ||||
|     # Requires GPU | ||||
|     try: | ||||
|         from cupy.cuda.stream import Stream | ||||
|     except ImportError: | ||||
|         return None | ||||
|     return Stream() | ||||
|     return CudaStream() if CudaStream is not None else None | ||||
| 
 | ||||
| 
 | ||||
| def get_async(stream, numpy_array): | ||||
|     if cupy is None: | ||||
|         return numpy_array | ||||
|     else: | ||||
|         return cupy.array(numpy_array, stream=stream) | ||||
| 
 | ||||
| 
 | ||||
| def read_regex(path): | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user