mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-26 21:51:24 +03:00 
			
		
		
		
	Add BiRNN for entailment
Hastily add bidirectional RNN to entailment example
This commit is contained in:
		
							parent
							
								
									f123f92e0c
								
							
						
					
					
						commit
						ca996fc01a
					
				|  | @ -12,7 +12,7 @@ from keras.optimizers import Adam | ||||||
| from keras.layers.normalization import BatchNormalization | from keras.layers.normalization import BatchNormalization | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def build_model(vectors, shape, settings): | def build_model(vectors, shape, settings, use_rnn_encoding=False): | ||||||
|     '''Compile the model.''' |     '''Compile the model.''' | ||||||
|     max_length, nr_hidden, nr_class = shape |     max_length, nr_hidden, nr_class = shape | ||||||
|     # Declare inputs. |     # Declare inputs. | ||||||
|  | @ -21,6 +21,8 @@ def build_model(vectors, shape, settings): | ||||||
| 
 | 
 | ||||||
|     # Construct operations, which we'll chain together. |     # Construct operations, which we'll chain together. | ||||||
|     embed = _StaticEmbedding(vectors, max_length, nr_hidden) |     embed = _StaticEmbedding(vectors, max_length, nr_hidden) | ||||||
|  |     if use_rnn_encoding: | ||||||
|  |         encode = _BiLSTMEncode(max_length, nr_hidden) | ||||||
|     attend = _Attention(max_length, nr_hidden) |     attend = _Attention(max_length, nr_hidden) | ||||||
|     align = _SoftAlignment(max_length, nr_hidden) |     align = _SoftAlignment(max_length, nr_hidden) | ||||||
|     compare = _Comparison(max_length, nr_hidden) |     compare = _Comparison(max_length, nr_hidden) | ||||||
|  | @ -29,6 +31,10 @@ def build_model(vectors, shape, settings): | ||||||
|     # Declare the model as a computational graph. |     # Declare the model as a computational graph. | ||||||
|     sent1 = embed(ids1) # Shape: (i, n) |     sent1 = embed(ids1) # Shape: (i, n) | ||||||
|     sent2 = embed(ids2) # Shape: (j, n) |     sent2 = embed(ids2) # Shape: (j, n) | ||||||
|  |      | ||||||
|  |     if use_rnn_encoding: | ||||||
|  |         sent1 = encode(sent1) | ||||||
|  |         sent2 = encode(sent2) | ||||||
| 
 | 
 | ||||||
|     attention = attend(sent1, sent2)  # Shape: (i, j) |     attention = attend(sent1, sent2)  # Shape: (i, j) | ||||||
| 
 | 
 | ||||||
|  | @ -72,7 +78,14 @@ class _StaticEmbedding(object): | ||||||
| 
 | 
 | ||||||
|     def __call__(self, sentence): |     def __call__(self, sentence): | ||||||
|         return self.project(self.embed(sentence)) |         return self.project(self.embed(sentence)) | ||||||
|  |      | ||||||
|  | class _BiRNNEncoding(object): | ||||||
|  |     def __init__(self, max_length, nr_out): | ||||||
|  |         self.model = Sequential() | ||||||
|  |         self.model.add(Bidirectional(LSTM(nr_out, input_length=max_length))) | ||||||
| 
 | 
 | ||||||
|  |     def __call__(self, sentence): | ||||||
|  |         return self.model(sentence) | ||||||
| 
 | 
 | ||||||
| class _Attention(object): | class _Attention(object): | ||||||
|     def __init__(self, max_length, nr_hidden, dropout=0.0, L2=1e-4, activation='relu'): |     def __init__(self, max_length, nr_hidden, dropout=0.0, L2=1e-4, activation='relu'): | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user