Mirror of https://github.com/explosion/spaCy.git (synced 2025-11-04 09:57:26 +03:00)
	Fix mixture weights in fine_tune
This commit is contained in:

    parent 335fa8b05c
    commit 6259490347

Changed files:

    spacy/_ml.py (20 lines changed)
spacy/_ml.py

@@ -372,21 +372,25 @@ def fine_tune(embedding, combine=None):
         vecs, bp_vecs = embedding.begin_update(docs, drop=drop)
         flat_tokvecs = embedding.ops.flatten(tokvecs)
         flat_vecs = embedding.ops.flatten(vecs)
+        alpha = model.mix
+        minus = 1-model.mix
         output = embedding.ops.unflatten(
-                   (model.mix[0] * flat_vecs + model.mix[1] * flat_tokvecs),
-                    lengths)
+                   (alpha * flat_tokvecs + minus * flat_vecs), lengths)
 
         def fine_tune_bwd(d_output, sgd=None):
-            bp_vecs([d_o * model.d_mix[0] for d_o in d_output], sgd=sgd)
             flat_grad = model.ops.flatten(d_output)
-            model.d_mix[1] += flat_tokvecs.dot(flat_grad.T).sum()
-            model.d_mix[0] += flat_vecs.dot(flat_grad.T).sum()
+            model.d_mix += flat_tokvecs.dot(flat_grad.T).sum()
+            model.d_mix += 1-flat_vecs.dot(flat_grad.T).sum()
+
+            bp_vecs([d_o * minus for d_o in d_output], sgd=sgd)
+            d_output = [d_o * alpha for d_o in d_output]
             sgd(model._mem.weights, model._mem.gradient, key=model.id)
-            return [d_o * model.d_mix[1] for d_o in d_output]
+            model.mix = model.ops.xp.minimum(model.mix, 1.0)
+            return d_output
         return output, fine_tune_bwd
     model = wrap(fine_tune_fwd, embedding)
-    model.mix = model._mem.add((model.id, 'mix'), (2,))
-    model.mix.fill(0.5)
+    model.mix = model._mem.add((model.id, 'mix'), (1,))
+    model.mix.fill(0.0)
     model.d_mix = model._mem.add_gradient((model.id, 'd_mix'), (model.id, 'mix'))
     return model
 
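The updated code mixes the two representations with a single learned scalar: output = alpha * tokvecs + (1 - alpha) * vecs, where alpha is model.mix (initialised to 0.0 and clipped to at most 1.0 after each update). Below is a minimal NumPy sketch of that mixing step and its gradients, written outside Thinc's Model/_mem machinery; the function name mix_forward, the analytic gradient for the scalar weight, and the toy shapes are illustrative assumptions, not a reproduction of the commit's exact update rule.

import numpy as np

def mix_forward(alpha, tokvecs, vecs):
    # Convex combination of the fine-tuned token vectors and the static
    # vectors, as in the updated fine_tune forward pass.
    output = alpha * tokvecs + (1.0 - alpha) * vecs

    def mix_backward(d_output):
        # Analytic gradient for the scalar mixture weight: shifting weight
        # from `vecs` toward `tokvecs` changes the output by their difference.
        d_alpha = ((tokvecs - vecs) * d_output).sum()
        # Each branch receives the incoming gradient scaled by its weight.
        d_tokvecs = alpha * d_output
        d_vecs = (1.0 - alpha) * d_output
        return d_alpha, d_tokvecs, d_vecs

    return output, mix_backward

# Toy usage with hypothetical shapes: 5 tokens, width 8.
tokvecs = np.random.randn(5, 8)
vecs = np.random.randn(5, 8)
alpha = 0.0   # the commit initialises model.mix to 0.0
output, backward = mix_forward(alpha, tokvecs, vecs)
d_alpha, d_tokvecs, d_vecs = backward(np.ones_like(output))
print(output.shape, float(d_alpha))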