Mirror of https://github.com/explosion/spaCy.git
commit f99f5b75dc
parent b439e04f8d

    working residual net
@@ -36,8 +36,7 @@ def read_conllx(loc, n=0):
                 try:
                     id_ = int(id_) - 1
                     head = (int(head) - 1) if head != '0' else id_
-                    dep = 'ROOT' if dep == 'root' else 'unlabelled'
-                    # Hack for efficiency
+                    dep = 'ROOT' if dep == 'root' else dep #'unlabelled'
                     tokens.append((id_, word, pos+'__'+morph, head, dep, 'O'))
                 except:
                     raise
@@ -82,6 +81,7 @@ def organize_data(vocab, train_sents):
 def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None):
     LangClass = spacy.util.get_lang_class(lang_name)
     train_sents = list(read_conllx(train_loc))
+    dev_sents = list(read_conllx(dev_loc))
     train_sents = PseudoProjectivity.preprocess_training_data(train_sents)
 
     actions = ArcEager.get_actions(gold_parses=train_sents)
@@ -136,8 +136,11 @@ def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None):
     parser = DependencyParser(vocab, actions=actions, features=features, L1=0.0)
 
     Xs, ys = organize_data(vocab, train_sents)
-    Xs = Xs[:100]
-    ys = ys[:100]
+    dev_Xs, dev_ys = organize_data(vocab, dev_sents)
+    Xs = Xs[:500]
+    ys = ys[:500]
+    dev_Xs = dev_Xs[:100]
+    dev_ys = dev_ys[:100]
     with encoder.model.begin_training(Xs[:100], ys[:100]) as (trainer, optimizer):
         docs = list(Xs)
         for doc in docs:
@@ -145,7 +148,8 @@ def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None):
         parser.begin_training(docs, ys)
         nn_loss = [0.]
         def track_progress():
-            scorer = score_model(vocab, encoder, tagger, parser, Xs, ys)
+            with encoder.tagger.use_params(optimizer.averages):
+                scorer = score_model(vocab, encoder, tagger, parser, dev_Xs, dev_ys)
             itn = len(nn_loss)
             print('%d:\t%.3f\t%.3f\t%.3f' % (itn, nn_loss[-1], scorer.uas, scorer.tags_acc))
             nn_loss.append(0.)
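The progress hook now scores on the held-out dev data rather than on the training set, and it does so with the optimizer's running parameter averages temporarily swapped in (Polyak averaging), which tends to give steadier evaluation numbers than the raw weights. A minimal sketch of how such a parameter swap can work, assuming the parameters live in a dict of numpy arrays (hypothetical names; thinc's real use_params is a method on its model classes):

from contextlib import contextmanager

@contextmanager
def use_params(model_params, averaged_params):
    """Temporarily overwrite a model's weights with their running
    averages, restoring the originals on exit (sketch only)."""
    backup = {key: w.copy() for key, w in model_params.items()}
    for key in model_params:
        if key in averaged_params:
            # In-place write, so any references the model holds see it.
            model_params[key][...] = averaged_params[key]
    try:
        yield
    finally:
        for key, w in backup.items():
            model_params[key][...] = w

# usage: with use_params(params, averages): evaluate(...)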
@@ -161,6 +165,7 @@ def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None):
                 tagger.update(doc, gold)
             d_tokvecs, loss = parser.update(docs, golds, sgd=optimizer)
             upd_tokvecs(d_tokvecs, sgd=optimizer)
+            encoder.update(docs, golds, optimizer)
             nn_loss[-1] += loss
     nlp = LangClass(vocab=vocab, tagger=tagger, parser=parser)
     nlp.end_training(model_dir)
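This is the multi-task piece of the commit: the parser's backward pass already sends a gradient into the shared token-vector encoder via upd_tokvecs, and the new encoder.update call lets the tagging objective update the same encoder, so both tasks shape one shared representation. Schematically, one training step now looks like this (a sketch using the script's names; what each update call does internally is an assumption):

def train_step(docs, golds, encoder, parser, upd_tokvecs, optimizer):
    # Parser loss: returns the gradient w.r.t. the shared token vectors.
    d_tokvecs, loss = parser.update(docs, golds, sgd=optimizer)
    # Push that gradient back through the tok2vec encoder's weights.
    upd_tokvecs(d_tokvecs, sgd=optimizer)
    # The tagging loss also backpropagates into the same shared encoder.
    encoder.update(docs, golds, optimizer)
    return loss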
spacy/_ml.py (25 changed lines)
@@ -5,6 +5,7 @@ from thinc.neural._classes.hash_embed import HashEmbed
 from thinc.neural._classes.convolution import ExtractWindow
 from thinc.neural._classes.static_vectors import StaticVectors
 from thinc.neural._classes.batchnorm import BatchNorm
+from thinc.neural._classes.resnet import Residual
 
 from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP
 
@@ -36,8 +37,7 @@ def build_debug_model(state2vec, width, depth, nr_class):
     with Model.define_operators({'>>': chain, '**': clone}):
         model = (
             state2vec
-            >> Maxout(width)
-            >> Affine(nr_class)
+            >> Maxout(nr_class)
         )
     return model
 
@@ -64,13 +64,8 @@ def build_debug_state2vec(width, nr_vector=1000, nF=1, nB=0, nS=1, nL=2, nR=2):
 def build_state2vec(nr_context_tokens, width, nr_vector=1000):
     ops = Model.ops
     with Model.define_operators({'|': concatenate, '+': add, '>>': chain}):
-
-        hiddens = [get_col(i) >> Affine(width) for i in range(nr_context_tokens)]
-        model = (
-            get_token_vectors
-            >> add(*hiddens)
-            >> Maxout(width)
-        )
+        hiddens = [get_col(i) >> Maxout(width) for i in range(nr_context_tokens)]
+        model = get_token_vectors >> add(*hiddens)
     return model
 
 
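build_state2vec is simplified here: instead of projecting each context token with Affine and mixing through a trailing Maxout, each slot now gets its own Maxout and the state vector is just the elementwise sum of the per-slot outputs. In plain numpy the computation looks roughly like this (random stand-in weights; not thinc's implementation):

import numpy as np

def state2vec(token_vectors, Ws, bs):
    """token_vectors: (batch, nr_context_tokens, width)
    Ws[i]: (width, width, pieces), bs[i]: (width, pieces) maxout params.
    Returns the sum of per-slot maxout projections, shape (batch, width)."""
    total = None
    for i, (W, b) in enumerate(zip(Ws, bs)):
        pieces = np.einsum('bw,wop->bop', token_vectors[:, i], W) + b
        h = pieces.max(axis=-1)  # maxout: keep the best piece per unit
        total = h if total is None else total + h
    return total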
@@ -78,7 +73,7 @@ def print_shape(prefix):
     def forward(X, drop=0.):
         return X, lambda dX, **kwargs: dX
     return layerize(forward)
-    
+
 
 @layerize
 def get_token_vectors(tokens_attrs_vectors, drop=0.):
@@ -173,9 +168,10 @@ def _reshape(layer):
 @layerize
 def flatten(seqs, drop=0.):
     ops = Model.ops
+    lengths = [len(seq) for seq in seqs]
     def finish_update(d_X, sgd=None):
-        return d_X
-    X = ops.xp.concatenate([ops.asarray(seq) for seq in seqs])
+        return ops.unflatten(d_X, lengths)
+    X = ops.xp.vstack(seqs)
     return X, finish_update
 
 
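The real fix in this hunk is the backward pass: flatten concatenates per-doc arrays into one matrix, so its finish_update must split the flat gradient back into per-doc pieces with ops.unflatten; previously it returned d_X unchanged, which has the wrong shape for the upstream per-doc layers. A self-contained numpy equivalent, assuming each seq is an (n_tokens, width) array:

import numpy as np

def flatten(seqs):
    """Stack per-doc arrays into one 2d array; the backward pass
    splits the flat gradient back into per-doc pieces."""
    lengths = [len(seq) for seq in seqs]
    X = np.vstack(seqs)
    def finish_update(d_X):
        boundaries = np.cumsum(lengths)[:-1]
        return np.split(d_X, boundaries, axis=0)
    return X, finish_update

Three docs of 3, 5 and 2 tokens flatten to a (10, width) matrix, and the backward pass splits a (10, width) gradient back into (3, width), (5, width) and (2, width) pieces.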
@@ -194,8 +190,9 @@ def build_tok2vec(lang, width, depth=2, embed_size=1000):
                 #(static | prefix | suffix | shape)
                 (lower | prefix | suffix | shape | tag)
                 >> Maxout(width, width*5)
-                #>> (ExtractWindow(nW=1) >> Maxout(width, width*3))
-                #>> (ExtractWindow(nW=1) >> Maxout(width, width*3))
+                >> Residual((ExtractWindow(nW=1) >> Maxout(width, width*3)))
+                >> Residual((ExtractWindow(nW=1) >> Maxout(width, width*3)))
+                >> Residual((ExtractWindow(nW=1) >> Maxout(width, width*3)))
             )
         )
     return tok2vec
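This is the change the commit message names: the two convolutional window layers that had been commented out come back wrapped in Residual, plus a third. Each block computes y = x + f(x), where f is ExtractWindow >> Maxout, so the gradient has a direct path around f, which is what makes stacking three of these layers trainable. A minimal sketch of such a residual wrapper (not thinc's Residual class; `layer` is assumed to return an (output, backprop) pair):

def residual(layer):
    """Wrap a layer f so the block computes y = x + f(x).
    f must preserve shape, which is why each Maxout above maps the
    width*3 window features back down to width."""
    def forward(X):
        Y, bp_layer = layer(X)
        def backprop(dY):
            # Gradient flows through f and, unchanged, along the skip path.
            return bp_layer(dY) + dY
        return X + Y, backprop
    return forward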
@@ -9,7 +9,7 @@ from .syntax.parser cimport Parser
 from .syntax.ner cimport BiluoPushDown
 from .syntax.arc_eager cimport ArcEager
 from .tagger import Tagger
-from ._ml import build_tok2vec
+from ._ml import build_tok2vec, flatten
 
 # TODO: The disorganization here is pretty embarrassing. At least it's only
 # internals.
@@ -24,7 +24,8 @@ class TokenVectorEncoder(object):
         self.model = build_tok2vec(vocab.lang, 64, **cfg)
         self.tagger = chain(
                         self.model,
-                        Softmax(self.vocab.morphology.n_tags))
+                        flatten,
+                        Softmax(self.vocab.morphology.n_tags, 64))
 
     def __call__(self, doc):
         doc.tensor = self.model([doc])[0]
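flatten now sits between the shared tok2vec model and the Softmax tag predictor because tok2vec produces one array per Doc while Softmax expects a single 2d batch; the extra 64 argument pins the Softmax input width to the encoder width chosen above. The shapes through the chain, roughly (n_i is the token count of doc i, N their sum):

# tok2vec: [doc_1, ..., doc_k]      -> [(n_1, 64), ..., (n_k, 64)]
# flatten: [(n_1, 64), ...]         -> (N, 64)
# Softmax: (N, 64)                  -> (N, n_tags) tag probabilities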
@@ -48,7 +48,7 @@ cdef class StateClass:
 
     @classmethod
     def nr_context_tokens(cls, int nF, int nB, int nS, int nL, int nR):
-        return 4
+        return 11
 
     def set_context_tokens(self, int[:] output, nF=1, nB=0, nS=2,
             nL=2, nR=2):
@@ -56,14 +56,14 @@ cdef class StateClass:
         output[1] = self.B(1)
         output[2] = self.S(0)
         output[3] = self.S(1)
-        #output[4] = self.L(self.S(0), 1)
-        #output[5] = self.L(self.S(0), 2)
-        #output[6] = self.R(self.S(0), 1)
-        #output[7] = self.R(self.S(0), 2)
-        #output[7] = self.L(self.S(1), 1)
-        #output[8] = self.L(self.S(1), 2)
-        #output[9] = self.R(self.S(1), 1)
-        #output[10] = self.R(self.S(1), 2)
+        output[4] = self.L(self.S(0), 1)
+        output[5] = self.L(self.S(0), 2)
+        output[6] = self.R(self.S(0), 1)
+        output[7] = self.R(self.S(0), 2)
+        output[7] = self.L(self.S(1), 1)
+        output[8] = self.L(self.S(1), 2)
+        output[9] = self.R(self.S(1), 1)
+        output[10] = self.R(self.S(1), 2)
 
     def set_attributes(self, uint64_t[:, :] vals, int[:] tokens, int[:] names):
         cdef int i, j, tok_i
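The parser's state-feature template grows from 4 context tokens to 11: the first two buffer tokens, the top two stack items, and the leftmost, second-leftmost, rightmost and second-rightmost children of S0 and S1. One caveat worth flagging: output[7] is assigned twice in this hunk, so R(S0, 2) is immediately overwritten by L(S1, 1), which is presumably why nr_context_tokens returns 11 rather than 12. The slot layout:

# Slots written by set_context_tokens after this change
# (B = buffer token, S = stack token, L/R = nth left/rightmost child):
#   output[0..1]  : B(0), B(1)
#   output[2..3]  : S(0), S(1)
#   output[4..7]  : L(S0,1), L(S0,2), R(S0,1), R(S0,2)
#   output[7..10] : L(S1,1), L(S1,2), R(S1,1), R(S1,2)
# output[7] is written twice, so the template fills 11 distinct slots.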