mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	Pass dropout through to embed tables
This commit is contained in:
		
							parent
							
								
									21d11936fe
								
							
						
					
					
						commit
						fbba7c517e
					
				| 
						 | 
					@ -266,15 +266,10 @@ def HistoryFeatures(nr_class, hist_size=8, nr_dim=8):
 | 
				
			||||||
    ops = embed.ops
 | 
					    ops = embed.ops
 | 
				
			||||||
    def add_history_fwd(vectors_hists, drop=0.):
 | 
					    def add_history_fwd(vectors_hists, drop=0.):
 | 
				
			||||||
        vectors, hist_ids = vectors_hists
 | 
					        vectors, hist_ids = vectors_hists
 | 
				
			||||||
        hist_feats, bp_hists = embed.begin_update(hist_ids)
 | 
					        hist_feats, bp_hists = embed.begin_update(hist_ids, drop=drop)
 | 
				
			||||||
        outputs = ops.xp.hstack((vectors, hist_feats))
 | 
					        outputs = ops.xp.hstack((vectors, hist_feats))
 | 
				
			||||||
        mask = ops.get_dropout_mask(outputs.shape, drop)
 | 
					 | 
				
			||||||
        if mask is not None:
 | 
					 | 
				
			||||||
            outputs *= mask
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        def add_history_bwd(d_outputs, sgd=None):
 | 
					        def add_history_bwd(d_outputs, sgd=None):
 | 
				
			||||||
            if mask is not None:
 | 
					 | 
				
			||||||
                d_outputs *= mask
 | 
					 | 
				
			||||||
            d_vectors = d_outputs[:, :vectors.shape[1]]
 | 
					            d_vectors = d_outputs[:, :vectors.shape[1]]
 | 
				
			||||||
            d_hists = d_outputs[:, vectors.shape[1]:]
 | 
					            d_hists = d_outputs[:, vectors.shape[1]:]
 | 
				
			||||||
            bp_hists(d_hists, sgd=sgd)
 | 
					            bp_hists(d_hists, sgd=sgd)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user