mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-25 05:01:02 +03:00 
			
		
		
		
	TextCatParametricAttention.v1: set key transform dimensions
This is necessary for tok2vec implementations that initialize lazily (e.g. curated transformers).
This commit is contained in:
		
							parent
							
								
									e2a3952de5
								
							
						
					
					
						commit
						ba15b8edd0
					
				|  | @ -264,6 +264,7 @@ def _build_parametric_attention_with_residual_nonlinear( | |||
| 
 | ||||
|         parametric_attention.set_ref("tok2vec", tok2vec) | ||||
|         parametric_attention.set_ref("attention_layer", attention_layer) | ||||
|         parametric_attention.set_ref("key_transform", key_transform) | ||||
|         parametric_attention.set_ref("nonlinear_layer", nonlinear_layer) | ||||
|         parametric_attention.set_ref("norm_layer", norm_layer) | ||||
| 
 | ||||
|  | @ -273,8 +274,10 @@ def _build_parametric_attention_with_residual_nonlinear( | |||
| def _init_parametric_attention_with_residual_nonlinear(model, X, Y) -> Model: | ||||
|     tok2vec_width = get_tok2vec_width(model) | ||||
|     model.get_ref("attention_layer").set_dim("nO", tok2vec_width) | ||||
|     model.get_ref("nonlinear_layer").set_dim("nO", tok2vec_width) | ||||
|     model.get_ref("key_transform").set_dim("nI", tok2vec_width) | ||||
|     model.get_ref("key_transform").set_dim("nO", tok2vec_width) | ||||
|     model.get_ref("nonlinear_layer").set_dim("nI", tok2vec_width) | ||||
|     model.get_ref("nonlinear_layer").set_dim("nO", tok2vec_width) | ||||
|     model.get_ref("norm_layer").set_dim("nI", tok2vec_width) | ||||
|     model.get_ref("norm_layer").set_dim("nO", tok2vec_width) | ||||
|     init_chain(model, X, Y) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user