mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	fix convolution layer
This commit is contained in:
		
							parent
							
								
									dd691d0053
								
							
						
					
					
						commit
						7edb2e1711
					
				| 
						 | 
				
			
			@ -12,9 +12,9 @@ from examples.pipeline.wiki_entity_linking import run_el, training_set_creator,
 | 
			
		|||
 | 
			
		||||
from spacy._ml import SpacyVectors, create_default_optimizer, zero_init, logistic
 | 
			
		||||
 | 
			
		||||
from thinc.api import chain, concatenate, flatten_add_lengths, clone
 | 
			
		||||
from thinc.api import chain, concatenate, flatten_add_lengths, clone, with_flatten
 | 
			
		||||
from thinc.v2v import Model, Maxout, Affine
 | 
			
		||||
from thinc.t2v import Pooling, mean_pool
 | 
			
		||||
from thinc.t2v import Pooling, mean_pool, sum_pool
 | 
			
		||||
from thinc.t2t import ParametricAttention
 | 
			
		||||
from thinc.misc import Residual
 | 
			
		||||
from thinc.misc import LayerNorm as LN
 | 
			
		||||
| 
						 | 
				
			
			@ -96,13 +96,13 @@ class EL_Model:
 | 
			
		|||
            try:
 | 
			
		||||
                # if to_print:
 | 
			
		||||
                    # print()
 | 
			
		||||
                    # print(article_count, "Training on article", article_id)
 | 
			
		||||
                print(article_count, "Training on article", article_id)
 | 
			
		||||
                article_count += 1
 | 
			
		||||
                article_docs = list()
 | 
			
		||||
                entities = list()
 | 
			
		||||
                golds = list()
 | 
			
		||||
                for inst_cluster in inst_cluster_set:
 | 
			
		||||
                    if instance_pos_count < 2:   # TODO remove
 | 
			
		||||
                    if instance_pos_count < 2:  # TODO del
 | 
			
		||||
                        article_docs.append(train_doc[article_id])
 | 
			
		||||
                        entities.append(train_pos.get(inst_cluster))
 | 
			
		||||
                        golds.append(float(1.0))
 | 
			
		||||
| 
						 | 
				
			
			@ -228,16 +228,23 @@ class EL_Model:
 | 
			
		|||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _encoder(in_width, hidden_width):
 | 
			
		||||
        conv_depth = 1
 | 
			
		||||
        cnn_maxout_pieces = 3
 | 
			
		||||
 | 
			
		||||
        with Model.define_operators({">>": chain}):
 | 
			
		||||
            convolution = Residual((ExtractWindow(nW=1) >> LN(Maxout(in_width, in_width * 3, pieces=cnn_maxout_pieces))))
 | 
			
		||||
 | 
			
		||||
            encoder = SpacyVectors \
 | 
			
		||||
                      >> with_flatten(LN(Maxout(in_width, in_width)) >> convolution ** conv_depth, pad=conv_depth) \
 | 
			
		||||
                      >> flatten_add_lengths \
 | 
			
		||||
                      >> ParametricAttention(in_width)\
 | 
			
		||||
                      >> Pooling(mean_pool) \
 | 
			
		||||
                >> (ExtractWindow(nW=1) >> LN(Maxout(in_width, in_width * 3)))  \
 | 
			
		||||
                      >> Residual(zero_init(Maxout(in_width, in_width))) \
 | 
			
		||||
                      >> zero_init(Affine(hidden_width, in_width, drop_factor=0.0))
 | 
			
		||||
 | 
			
		||||
            # TODO: ReLu instead of LN(Maxout)  ?
 | 
			
		||||
            # TODO: more convolutions ?
 | 
			
		||||
            # sum_pool or mean_pool ?
 | 
			
		||||
 | 
			
		||||
        return encoder
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -261,16 +268,17 @@ class EL_Model:
 | 
			
		|||
            print(doc_encoding)
 | 
			
		||||
 | 
			
		||||
        doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=self.DROP)
 | 
			
		||||
        entity_encodings, bp_encoding = self.entity_encoder.begin_update(entities, drop=self.DROP)
 | 
			
		||||
        concat_encodings = [list(entity_encodings[i]) + list(doc_encodings[i]) for i in range(len(entities))]
 | 
			
		||||
 | 
			
		||||
        print("doc_encodings", len(doc_encodings), doc_encodings)
 | 
			
		||||
 | 
			
		||||
        entity_encodings, bp_encoding = self.entity_encoder.begin_update(entities, drop=self.DROP)
 | 
			
		||||
        print("entity_encodings", len(entity_encodings), entity_encodings)
 | 
			
		||||
        print("concat_encodings", len(concat_encodings), concat_encodings)
 | 
			
		||||
 | 
			
		||||
        concat_encodings = [list(entity_encodings[i]) + list(doc_encodings[i]) for i in range(len(entities))]
 | 
			
		||||
        # print("concat_encodings", len(concat_encodings), concat_encodings)
 | 
			
		||||
 | 
			
		||||
        predictions, bp_model = self.model.begin_update(np.asarray(concat_encodings), drop=self.DROP)
 | 
			
		||||
        print("predictions", predictions)
 | 
			
		||||
        predictions = self.model.ops.flatten(predictions)
 | 
			
		||||
        print("predictions", predictions)
 | 
			
		||||
        golds = self.model.ops.asarray(golds)
 | 
			
		||||
 | 
			
		||||
        loss, d_scores = self.get_loss(predictions, golds)
 | 
			
		||||
| 
						 | 
				
			
			@ -287,15 +295,15 @@ class EL_Model:
 | 
			
		|||
 | 
			
		||||
        d_scores = d_scores.reshape((-1, 1))
 | 
			
		||||
        d_scores = d_scores.astype(np.float32)
 | 
			
		||||
        print("d_scores", d_scores)
 | 
			
		||||
        # print("d_scores", d_scores)
 | 
			
		||||
 | 
			
		||||
        model_gradient = bp_model(d_scores, sgd=self.sgd)
 | 
			
		||||
        print("model_gradient", model_gradient)
 | 
			
		||||
        # print("model_gradient", model_gradient)
 | 
			
		||||
 | 
			
		||||
        doc_gradient = [x[0:self.ARTICLE_WIDTH] for x in model_gradient]
 | 
			
		||||
        print("doc_gradient", doc_gradient)
 | 
			
		||||
        # print("doc_gradient", doc_gradient)
 | 
			
		||||
        entity_gradient = [x[self.ARTICLE_WIDTH:] for x in model_gradient]
 | 
			
		||||
        print("entity_gradient", entity_gradient)
 | 
			
		||||
        # print("entity_gradient", entity_gradient)
 | 
			
		||||
 | 
			
		||||
        bp_doc(doc_gradient)
 | 
			
		||||
        bp_encoding(entity_gradient)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user