Mirror of https://github.com/explosion/spaCy.git (synced 2025-10-30 23:47:31 +03:00)
Get pre-computed version working

commit 10682d35ab
parent 35458987e8
@@ -144,7 +144,6 @@ def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None):
         docs = list(Xs)
         for doc in docs:
             encoder(doc)
-        parser.begin_training(docs, ys)
         nn_loss = [0.]
         def track_progress():
             scorer = score_model(vocab, encoder, tagger, parser, dev_Xs, dev_ys)
@@ -153,7 +152,7 @@ def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None):
             nn_loss.append(0.)
         trainer.each_epoch.append(track_progress)
         trainer.batch_size = 12
-        trainer.nb_epoch = 2
+        trainer.nb_epoch = 20
         for docs, golds in trainer.iterate(Xs, ys, progress_bar=False):
             docs = [Doc(vocab, words=[w.text for w in doc]) for doc in docs]
             tokvecs, upd_tokvecs = encoder.begin_update(docs)
@@ -161,9 +160,9 @@ def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None):
                 doc.tensor = tokvec
             for doc, gold in zip(docs, golds):
                 tagger.update(doc, gold)
-            d_tokvecs, loss = parser.update(docs, golds, sgd=optimizer)
+            d_tokvecs = parser.update(docs, golds, sgd=optimizer)
             upd_tokvecs(d_tokvecs, sgd=optimizer)
-            nn_loss[-1] += loss
+            #nn_loss[-1] += loss
     nlp = LangClass(vocab=vocab, tagger=tagger, parser=parser)
     #nlp.end_training(model_dir)
     #scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc))
@@ -173,7 +172,7 @@ def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None):
 if __name__ == '__main__':
     import cProfile
     import pstats
-    if 0:
+    if 1:
         plac.call(main)
     else:
         cProfile.runctx("plac.call(main)", globals(), locals(), "Profile.prof")
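For orientation, the training-loop changes above lean on thinc's begin_update contract: a layer maps its inputs to (output, finish_update), and calling the returned callback with a gradient backpropagates through the layer, exactly as upd_tokvecs is used above. A minimal sketch of that contract with toy names (make_layer and the sgd callable are illustrative, not spaCy's API):

import numpy as np

def make_layer(W):
    # Toy thinc-style layer: begin_update returns (output, backprop callback).
    def begin_update(X, drop=0.):
        Y = X @ W
        def finish_update(dY, sgd=None):
            if sgd is not None:
                sgd(W, X.T @ dY)   # weight gradient, applied in place
            return dY @ W.T        # gradient w.r.t. the inputs
        return Y, finish_update
    return begin_update

rng = np.random.default_rng(0)
encoder = make_layer(rng.normal(size=(8, 8)))
tokvecs, upd_tokvecs = encoder(rng.normal(size=(4, 8)))
d_tokvecs = np.ones_like(tokvecs)   # stand-in for the parser's gradient
upd_tokvecs(d_tokvecs)              # mirrors upd_tokvecs(d_tokvecs, sgd=optimizer)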
spacy/_ml.py (85 changed lines)
@@ -51,47 +51,6 @@ def doc2feats(cols):
     model = layerize(forward)
     return model

-
-def build_feature_precomputer(model, feat_maps):
-    '''Allow a model to be "primed" by pre-computing input features in bulk.
-
-    This is used for the parser, where we want to take a batch of documents,
-    and compute vectors for each (token, position) pair. These vectors can then
-    be reused, especially for beam-search.
-
-    Let's say we're using 12 features for each state, e.g. word at start of
-    buffer, three words on stack, their children, etc. In the normal arc-eager
-    system, a document of length N is processed in 2*N states. This means we'll
-    create 2*N*12 feature vectors --- but if we pre-compute, we only need
-    N*12 vector computations. The saving for beam-search is much better:
-    if we have a beam of k, we'll normally make 2*N*12*k computations --
-    so we can save the factor k. This also gives a nice CPU/GPU division:
-    we can do all our hard maths up front, packed into large multiplications,
-    and do the hard-to-program parsing on the CPU.
-    '''
-    def precompute(input_vectors):
-        cached, backprops = zip(*[lyr.begin_update(input_vectors)
-                                  for lyr in feat_maps])
-        def forward(batch_token_ids, drop=0.):
-            output = ops.allocate((batch_size, output_width))
-            # i: batch index
-            # j: position index (i.e. N0, S0, etc.)
-            # tok_i: index of the token within its document
-            for i, token_ids in enumerate(batch_token_ids):
-                for j, tok_i in enumerate(token_ids):
-                    output[i] += cached[j][tok_i]
-            def backward(d_vector, sgd=None):
-                d_inputs = ops.allocate((batch_size, n_feat, vec_width))
-                for i, token_ids in enumerate(batch_token_ids):
-                    for j in range(len(token_ids)):
-                        d_inputs[i][j] = backprops[j](d_vector, sgd)
-                # Return the IDs, so the caller can associate them with the right tokens
-                return (batch_token_ids, d_inputs)
-            return output, backward
-        return chain(layerize(forward), model)
-    return precompute
-
-
 def print_shape(prefix):
     def forward(X, drop=0.):
         return X, lambda dX, **kwargs: dX
@@ -114,3 +73,47 @@ def flatten(seqs, drop=0.):
         return d_X
     X = ops.xp.concatenate([ops.asarray(seq) for seq in seqs])
     return X, finish_update
+
+
+
+#def build_feature_precomputer(model, feat_maps):
+#    '''Allow a model to be "primed" by pre-computing input features in bulk.
+#
+#    This is used for the parser, where we want to take a batch of documents,
+#    and compute vectors for each (token, position) pair. These vectors can then
+#    be reused, especially for beam-search.
+#
+#    Let's say we're using 12 features for each state, e.g. word at start of
+#    buffer, three words on stack, their children, etc. In the normal arc-eager
+#    system, a document of length N is processed in 2*N states. This means we'll
+#    create 2*N*12 feature vectors --- but if we pre-compute, we only need
+#    N*12 vector computations. The saving for beam-search is much better:
+#    if we have a beam of k, we'll normally make 2*N*12*k computations --
+#    so we can save the factor k. This also gives a nice CPU/GPU division:
+#    we can do all our hard maths up front, packed into large multiplications,
+#    and do the hard-to-program parsing on the CPU.
+#    '''
+#    def precompute(input_vectors):
+#        cached, backprops = zip(*[lyr.begin_update(input_vectors)
+#                                  for lyr in feat_maps])
+#        def forward(batch_token_ids, drop=0.):
+#            output = ops.allocate((batch_size, output_width))
+#            # i: batch index
+#            # j: position index (i.e. N0, S0, etc.)
+#            # tok_i: index of the token within its document
+#            for i, token_ids in enumerate(batch_token_ids):
+#                for j, tok_i in enumerate(token_ids):
+#                    output[i] += cached[j][tok_i]
+#            def backward(d_vector, sgd=None):
+#                d_inputs = ops.allocate((batch_size, n_feat, vec_width))
+#                for i, token_ids in enumerate(batch_token_ids):
+#                    for j in range(len(token_ids)):
+#                        d_inputs[i][j] = backprops[j](d_vector, sgd)
+#                # Return the IDs, so the caller can associate them with the right tokens
+#                return (batch_token_ids, d_inputs)
+#            return output, backward
+#        return chain(layerize(forward), model)
+#    return precompute
+#
+#
+
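The docstring above is the heart of this commit: transform each token's vector once per feature slot, then build each parser state's input by summing pre-computed rows, so 2*N states re-use N pre-computed rows per slot instead of re-doing the matrix multiplications. A minimal numpy sketch of the idea (all shapes and names here are illustrative):

import numpy as np

rng = np.random.default_rng(0)
n_tokens, vec_width, hidden_width, n_feats = 10, 16, 8, 3

tokvecs = rng.normal(size=(n_tokens, vec_width))
# One weight matrix per feature slot (e.g. S0, B0, S0's left child).
maps = [rng.normal(size=(vec_width, hidden_width)) for _ in range(n_feats)]

# Pre-compute: one big matmul per slot, done once per batch of documents.
cached = [tokvecs @ W for W in maps]        # n_feats arrays of (n_tokens, hidden)

def state_vector(token_ids):
    # Hidden vector for one parser state: just sum the cached rows.
    out = np.zeros(hidden_width)
    for j, tok_i in enumerate(token_ids):
        if tok_i >= 0:                      # a negative id marks an empty slot
            out += cached[j][tok_i]
    return out

print(state_vector([0, 3, -1]))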
@@ -13,5 +13,6 @@ cdef class Parser:
     cdef readonly object model
     cdef readonly TransitionSystem moves
     cdef readonly object cfg
+    cdef public object feature_maps

     #cdef int parseC(self, TokenC* tokens, int length, int nr_feat) nogil
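A note on the declaration added above: in Cython, a `cdef readonly` attribute is exposed to Python with a getter only, while `cdef public` generates both a getter and a setter, so feature_maps can also be assigned from Python code. A minimal illustration (the class name here is hypothetical):

cdef class Example:
    cdef readonly object model         # readable from Python, not assignable
    cdef public object feature_maps    # readable and assignable from Python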
@@ -28,10 +28,11 @@ from murmurhash.mrmr cimport hash64
 from preshed.maps cimport MapStruct
 from preshed.maps cimport map_get

-from thinc.api import layerize
-
 from numpy import exp
+from thinc.api import layerize, chain
+from thinc.neural import Model, Maxout

+from .._ml import get_col
 from . import _parse_features
 from ._parse_features cimport CONTEXT_SIZE
 from ._parse_features cimport fill_context
@@ -46,8 +47,9 @@ from ..strings cimport StringStore
 from ..gold cimport GoldParse
 from ..attrs cimport TAG, DEP

-from .._ml import build_state2vec, build_model, precompute_hiddens
-
+def get_templates(*args, **kwargs):
+    return []
+

 USE_FTRL = True
 DEBUG = False
@@ -56,30 +58,39 @@ def set_debug(val):
     DEBUG = val


-def get_greedy_model_for_batch(tokvecs, TransitionSystem moves, feat_maps, upper_model):
+def get_greedy_model_for_batch(tokvecs, TransitionSystem moves, upper_model, feat_maps):
     cdef int[:, :] is_valid_
     cdef float[:, :] costs_
     cdef int[:, :] token_ids
-    lengths = [len(t) for t in tokvecs]
+    tokvecs = upper_model.ops.flatten(tokvecs)
     is_valid = upper_model.ops.allocate((len(tokvecs), moves.n_moves), dtype='i')
     costs = upper_model.ops.allocate((len(tokvecs), moves.n_moves), dtype='f')
-    token_ids = upper_model.ops.allocate((len(tokvecs), StateClass.nr_context_tokens()),
-                                         dtype='uint64')
+    token_ids = upper_model.ops.allocate((len(tokvecs), len(feat_maps)), dtype='i')
+    cached, backprops = zip(*[lyr.begin_update(tokvecs) for lyr in feat_maps])
     is_valid_ = is_valid
     costs_ = costs

-    def forward(states, drop=0.):
+    def forward(states_offsets, drop=0.):
         nonlocal is_valid, costs, token_ids, moves
+        states, offsets = states_offsets
         is_valid = is_valid[:len(states)]
         costs = costs[:len(states)]
         token_ids = token_ids[:len(states)]
-        is_valid = is_valid[:len(states)]
         cdef StateClass state
-        for i, state in enumerate(states):
+        cdef int i
+        for i, (offset, state) in enumerate(zip(offsets, states)):
             state.set_context_tokens(token_ids[i])
             moves.set_valid(&is_valid_[i, 0], state.c)

-        features = cached[token_ids].sum(axis=1)
+        adjusted_ids = token_ids.copy()
+        for i, offset in enumerate(offsets):
+            adjusted_ids[i] *= token_ids[i] >= 0
+            adjusted_ids[i] += offset
+        features = upper_model.ops.allocate((len(states), 64), dtype='f')
+        for i in range(len(states)):
+            for j, tok_i in enumerate(adjusted_ids[i]):
+                if tok_i >= 0:
+                    features[i] += cached[j][tok_i]

         scores, bp_scores = upper_model.begin_update(features, drop=drop)
         softmaxed = upper_model.ops.softmax(scores)
@@ -89,15 +100,16 @@ def get_greedy_model_for_batch(tokvecs, TransitionSystem moves, feat_maps, upper_model):

         def backward(golds, sgd=None):
             nonlocal costs_, is_valid_, moves
+            cdef int i
             for i, (state, gold) in enumerate(zip(states, golds)):
                 moves.set_costs(&is_valid_[i, 0], &costs_[i, 0],
                     state, gold)
             d_scores = scores.copy()
             d_scores.fill(0)
             set_log_loss(upper_model.ops, d_scores,
-                scores, is_valid_, costs_)
+                scores, is_valid, costs)
             d_tokens = bp_scores(d_scores, sgd)
-            return d_tokens
+            return (token_ids, d_tokens)

         return softmaxed, backward

@@ -127,14 +139,18 @@ def transition_batch(TransitionSystem moves, states, scores):


 def init_states(TransitionSystem moves, docs):
-    states = []
     cdef Doc doc
     cdef StateClass state
+    offsets = []
+    states = []
+    offset = 0
     for i, doc in enumerate(docs):
         state = StateClass.init(doc.c, doc.length)
         moves.initialize_state(state.c)
         states.append(state)
-    return states
+        offsets.append(offset)
+        offset += len(doc)
+    return states, offsets


 cdef class Parser:
@@ -184,18 +200,22 @@ cdef class Parser:
         cfg['actions'] = TransitionSystem.get_actions(**cfg)
         self.moves = TransitionSystem(vocab.strings, cfg['actions'])
         if model is None:
-            model = self.build_model(**cfg)
-        self.model = model
+            self.model, self.feature_maps = self.build_model(**cfg)
+        else:
+            self.model, self.feature_maps = model
         self.cfg = cfg

     def __reduce__(self):
         return (Parser, (self.vocab, self.moves, self.model), None, None)

-    def build_model(self, width=32, nr_vector=1000, nF=1, nB=1, nS=1, nL=1, nR=1, **_):
+    def build_model(self, width=64, nr_vector=1000, nF=1, nB=1, nS=1, nL=1, nR=1, **_):
         nr_context_tokens = StateClass.nr_context_tokens(nF, nB, nS, nL, nR)
-        self.model = build_model(width*2, 2, self.moves.n_moves)
+
+        model = chain(Maxout(width, width), Maxout(self.moves.n_moves, width))
         # TODO
-        self.feature_maps = [] #build_feature_maps(nr_context_tokens, width, nr_vector)
+        feature_maps = [Maxout(width, width)
+                        for i in range(nr_context_tokens)]
+        return model, feature_maps

     def __call__(self, Doc tokens):
         """
@@ -245,19 +265,21 @@ cdef class Parser:
         cdef Doc doc
         cdef StateClass state
         model = get_greedy_model_for_batch([d.tensor for d in docs],
-                    self.moves, self.model, self.feat_maps)
-        states = [StateClass.init(doc.c, doc.length) for doc in docs]
-        todo = list(states)
+                    self.moves, self.model, self.feature_maps)
+        states, offsets = init_states(self.moves, docs)
+        all_states = list(states)
+        todo = list(zip(states, offsets))
         while todo:
-            scores = model(todo)
-            transition_batch(self.moves, todo, scores)
-            todo = [st for st in states if not st.is_final()]
-        for state, doc in zip(states, docs):
+            states, offsets = zip(*todo)
+            scores = model((states, offsets))
+            transition_batch(self.moves, states, scores)
+            todo = [st for st in todo if not st[0].py_is_final()]
+        for state, doc in zip(all_states, docs):
             self.moves.finalize_state(state.c)
             for i in range(doc.length):
                 doc.c[i] = state.c._sent[i]
         for doc in docs:
-            self.moves.finalize_parse(doc)
+            self.moves.finalize_doc(doc)

     def update(self, docs, golds, drop=0., sgd=None):
         if isinstance(docs, Doc) and isinstance(golds, GoldParse):
@@ -266,33 +288,23 @@ cdef class Parser:
             self.moves.preprocess_gold(gold)

         model = get_greedy_model_for_batch([d.tensor for d in docs],
-                    self.moves, self.model, self.feat_maps)
-        states = init_states(self.moves, docs)
+                    self.moves, self.model, self.feature_maps)
+        states, offsets = init_states(self.moves, docs)

         d_tokens = [self.model.ops.allocate(d.tensor.shape) for d in docs]
         output = list(d_tokens)
-        todo = zip(states, golds, d_tokens)
+        todo = zip(states, offsets, golds, d_tokens)
         while todo:
-            states, golds, d_tokens = zip(*todo)
-            scores, finish_update = model.begin_update(token_ids)
-            d_state_features = finish_update(golds, sgd=sgd)
+            states, offsets, golds, d_tokens = zip(*todo)
+            scores, finish_update = model.begin_update((states, offsets))
+            (token_ids, d_state_features) = finish_update(golds, sgd=sgd)
             for i, token_ids in enumerate(token_ids):
                 d_tokens[i][token_ids] += d_state_features[i]
-            transition_batch(self.moves, states)
+            transition_batch(self.moves, states, scores)
             # Get unfinished states (and their matching gold and token gradients)
             todo = filter(lambda sp: not sp[0].py_is_final(), todo)
         return output

-    def begin_training(self, docs, golds):
-        for gold in golds:
-            self.moves.preprocess_gold(gold)
-        states = self._init_states(docs)
-        tokvecs = [d.tensor for d in docs]
-
-        features = self._get_features(states, tokvecs)
-        self.model.begin_training(features)
-
-
     def step_through(self, Doc doc, GoldParse gold=None):
         """
         Set up a stepwise state, to introspect and control the transition sequence.
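A note on the offset arithmetic in forward() above: the ids written by set_context_tokens are document-relative, with negative values marking empty feature slots, while the cached arrays hold one row per token of the whole flattened batch, so each state's ids must be shifted by its document's start offset before indexing. A minimal numpy sketch of that lookup (names illustrative; here the empty-slot check happens before the shift, so masked slots are genuinely skipped):

import numpy as np

def batch_features(cached, batch_token_ids, offsets):
    # cached: one (n_batch_tokens, hidden) array per feature slot.
    n_states = len(batch_token_ids)
    hidden = cached[0].shape[1]
    features = np.zeros((n_states, hidden))
    for i, (token_ids, offset) in enumerate(zip(batch_token_ids, offsets)):
        for j, tok_i in enumerate(token_ids):
            if tok_i >= 0:                       # negative id: no token here
                features[i] += cached[j][tok_i + offset]
    return features

rng = np.random.default_rng(0)
cached = [rng.normal(size=(12, 8)) for _ in range(3)]   # 3 feature slots
# Two states from a document that starts at position 5 of the flat batch:
print(batch_features(cached, [[0, 2, -1], [1, -1, 3]], [5, 5]))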