mirror of https://github.com/explosion/spaCy.git
	Fix use of dropout in sentiment analysis LSTM example
parent 1ed40682a3
commit 7793e2ad82
@@ -111,10 +111,9 @@ def compile_lstm(embeddings, shape, settings):
             mask_zero=True
         )
     )
-    model.add(TimeDistributed(Dense(shape['nr_hidden'] * 2, bias=False)))
-    model.add(Dropout(settings['dropout']))
-    model.add(Bidirectional(LSTM(shape['nr_hidden'])))
-    model.add(Dropout(settings['dropout']))
+    model.add(TimeDistributed(Dense(shape['nr_hidden'], bias=False)))
+    model.add(Bidirectional(LSTM(shape['nr_hidden'], dropout_U=settings['dropout'],
+                                 dropout_W=settings['dropout'])))
     model.add(Dense(shape['nr_class'], activation='sigmoid'))
     model.compile(optimizer=Adam(lr=settings['lr']), loss='binary_crossentropy',
                   metrics=['accuracy'])
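For quick reference, here is how the patched compile_lstm reads when the changed lines above are folded back into the function. Only the layer calls and the compile() call shown in the hunk come from the commit; the imports, the Embedding arguments, and the return statement are assumptions added so the sketch is self-contained, using the Keras 1.x API the example targets (bias=, dropout_W=, dropout_U=, lr=).

# Sketch only: the TimeDistributed/Bidirectional/Dense/compile lines come from
# the diff above; the imports and Embedding arguments are assumptions.
from keras.models import Sequential
from keras.layers import Embedding, TimeDistributed, Dense, Bidirectional, LSTM
from keras.optimizers import Adam


def compile_lstm(embeddings, shape, settings):
    model = Sequential()
    model.add(
        Embedding(
            embeddings.shape[0],               # assumed: vocabulary size
            embeddings.shape[1],               # assumed: embedding width
            input_length=shape['max_length'],  # assumed: padded sequence length
            trainable=False,
            weights=[embeddings],
            mask_zero=True
        )
    )
    # Project each timestep down to the hidden size (previously nr_hidden * 2).
    model.add(TimeDistributed(Dense(shape['nr_hidden'], bias=False)))
    # Dropout now lives inside the LSTM: dropout_W on the input weights,
    # dropout_U on the recurrent weights, replacing the two standalone
    # Dropout layers removed by this commit.
    model.add(Bidirectional(LSTM(shape['nr_hidden'],
                                 dropout_U=settings['dropout'],
                                 dropout_W=settings['dropout'])))
    model.add(Dense(shape['nr_class'], activation='sigmoid'))
    model.compile(optimizer=Adam(lr=settings['lr']),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    return model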
@@ -195,7 +194,7 @@ def main(model_dir, train_dir, dev_dir,
         dev_labels = numpy.asarray(dev_labels, dtype='int32')
         lstm = train(train_texts, train_labels, dev_texts, dev_labels,
                      {'nr_hidden': nr_hidden, 'max_length': max_length, 'nr_class': 1},
-                     {'dropout': 0.5, 'lr': learn_rate},
+                     {'dropout': dropout, 'lr': learn_rate},
                      {},
                      nb_epoch=nb_epoch, batch_size=batch_size)
         weights = lstm.get_weights()
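One porting note, not part of the commit: dropout_W and dropout_U are Keras 1.x keyword arguments. Keras 2 renamed them to dropout (applied to the input transformation) and recurrent_dropout (applied to the recurrent transformation), so a rough equivalent of the recurrent layer added above is the sketch below, where nr_hidden and dropout are placeholder values standing in for shape['nr_hidden'] and settings['dropout'].

# Keras 2.x equivalent of the recurrent layer introduced by this commit.
# Porting note only; the commit itself uses the Keras 1.x argument names.
from keras.layers import Bidirectional, LSTM

nr_hidden = 64   # placeholder for shape['nr_hidden']
dropout = 0.5    # placeholder for settings['dropout']

lstm_layer = Bidirectional(LSTM(nr_hidden,
                                dropout=dropout,              # Keras 1.x: dropout_W
                                recurrent_dropout=dropout))   # Keras 1.x: dropout_U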