mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	Fix PyTorch BiLSTM
This commit is contained in:
		
							parent
							
								
									a26fe8e7bb
								
							
						
					
					
						commit
						afeddfff26
					
				| 
						 | 
					@ -29,6 +29,7 @@ from . import util
 | 
				
			||||||
 | 
					
 | 
				
			||||||
try:
 | 
					try:
 | 
				
			||||||
    import torch.nn
 | 
					    import torch.nn
 | 
				
			||||||
 | 
					    from thinc.extra.wrappers import PyTorchWrapperRNN
 | 
				
			||||||
except:
 | 
					except:
 | 
				
			||||||
    torch = None
 | 
					    torch = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -252,7 +253,7 @@ def link_vectors_to_models(vocab):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def PyTorchBiLSTM(nO, nI, depth, dropout=0.2):
 | 
					def PyTorchBiLSTM(nO, nI, depth, dropout=0.2):
 | 
				
			||||||
    if depth == 0:
 | 
					    if depth == 0:
 | 
				
			||||||
        return noop()
 | 
					        return layerize(noop())
 | 
				
			||||||
    model = torch.nn.LSTM(nI, nO//2, depth, bidirectional=True, dropout=dropout)
 | 
					    model = torch.nn.LSTM(nI, nO//2, depth, bidirectional=True, dropout=dropout)
 | 
				
			||||||
    return with_square_sequences(PyTorchWrapperRNN(model))
 | 
					    return with_square_sequences(PyTorchWrapperRNN(model))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -299,7 +300,6 @@ def Tok2Vec(width, embed_size, **kwargs):
 | 
				
			||||||
            ExtractWindow(nW=1)
 | 
					            ExtractWindow(nW=1)
 | 
				
			||||||
            >> LN(Maxout(width, width*3, pieces=cnn_maxout_pieces))
 | 
					            >> LN(Maxout(width, width*3, pieces=cnn_maxout_pieces))
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					 | 
				
			||||||
        tok2vec = (
 | 
					        tok2vec = (
 | 
				
			||||||
            FeatureExtracter(cols)
 | 
					            FeatureExtracter(cols)
 | 
				
			||||||
            >> with_flatten(
 | 
					            >> with_flatten(
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user