mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-26 13:41:21 +03:00 
			
		
		
		
	Add BiRNN for entailment
Hastily add bidirectional RNN to entailment example
This commit is contained in:
		
							parent
							
								
									f123f92e0c
								
							
						
					
					
						commit
						ca996fc01a
					
				|  | @ -12,7 +12,7 @@ from keras.optimizers import Adam | |||
| from keras.layers.normalization import BatchNormalization | ||||
| 
 | ||||
| 
 | ||||
| def build_model(vectors, shape, settings): | ||||
| def build_model(vectors, shape, settings, use_rnn_encoding=False): | ||||
|     '''Compile the model.''' | ||||
|     max_length, nr_hidden, nr_class = shape | ||||
|     # Declare inputs. | ||||
|  | @ -21,6 +21,8 @@ def build_model(vectors, shape, settings): | |||
| 
 | ||||
|     # Construct operations, which we'll chain together. | ||||
|     embed = _StaticEmbedding(vectors, max_length, nr_hidden) | ||||
|     if use_rnn_encoding: | ||||
|         encode = _BiLSTMEncode(max_length, nr_hidden) | ||||
|     attend = _Attention(max_length, nr_hidden) | ||||
|     align = _SoftAlignment(max_length, nr_hidden) | ||||
|     compare = _Comparison(max_length, nr_hidden) | ||||
|  | @ -29,6 +31,10 @@ def build_model(vectors, shape, settings): | |||
|     # Declare the model as a computational graph. | ||||
|     sent1 = embed(ids1) # Shape: (i, n) | ||||
|     sent2 = embed(ids2) # Shape: (j, n) | ||||
|      | ||||
|     if use_rnn_encoding: | ||||
|         sent1 = encode(sent1) | ||||
|         sent2 = encode(sent2) | ||||
| 
 | ||||
|     attention = attend(sent1, sent2)  # Shape: (i, j) | ||||
| 
 | ||||
|  | @ -72,7 +78,14 @@ class _StaticEmbedding(object): | |||
| 
 | ||||
|     def __call__(self, sentence): | ||||
|         return self.project(self.embed(sentence)) | ||||
|      | ||||
| class _BiRNNEncoding(object): | ||||
|     def __init__(self, max_length, nr_out): | ||||
|         self.model = Sequential() | ||||
|         self.model.add(Bidirectional(LSTM(nr_out, input_length=max_length))) | ||||
| 
 | ||||
|     def __call__(self, sentence): | ||||
|         return self.model(sentence) | ||||
| 
 | ||||
| class _Attention(object): | ||||
|     def __init__(self, max_length, nr_hidden, dropout=0.0, L2=1e-4, activation='relu'): | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user