Mirror of https://github.com/explosion/spaCy.git (synced 2025-11-04 01:48:04 +03:00)
			
		
		
		
	Update textcat ensemble model
This commit is contained in:
		
		Parent: f50502dad7
		Commit: c2a18e4fa3
					
				| 
						 | 
					@ -72,23 +72,20 @@ def build_text_classifier_v2(
 | 
				
			||||||
        attention_layer = ParametricAttention(
 | 
					        attention_layer = ParametricAttention(
 | 
				
			||||||
            width
 | 
					            width
 | 
				
			||||||
        )  # TODO: benchmark performance difference of this layer
 | 
					        )  # TODO: benchmark performance difference of this layer
 | 
				
			||||||
        maxout_layer = Maxout(nO=width, nI=width)
 | 
					        maxout_layer = Maxout(nO=width, nI=width, dropout=0.0, normalize=True)
 | 
				
			||||||
        linear_layer = Linear(nO=nO, nI=width)
 | 
					 | 
				
			||||||
        cnn_model = (
 | 
					        cnn_model = (
 | 
				
			||||||
            tok2vec
 | 
					            tok2vec
 | 
				
			||||||
            >> list2ragged()
 | 
					            >> list2ragged()
 | 
				
			||||||
            >> attention_layer
 | 
					            >> attention_layer
 | 
				
			||||||
            >> reduce_sum()
 | 
					            >> reduce_sum()
 | 
				
			||||||
            >> residual(maxout_layer)
 | 
					            >> residual(maxout_layer)
 | 
				
			||||||
            >> linear_layer
 | 
					 | 
				
			||||||
            >> Dropout(0.0)
 | 
					 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        nO_double = nO * 2 if nO else None
 | 
					        nO_double = nO * 2 if nO else None
 | 
				
			||||||
        if exclusive_classes:
 | 
					        if exclusive_classes:
 | 
				
			||||||
            output_layer = Softmax(nO=nO, nI=nO_double)
 | 
					            output_layer = Softmax(nO=nO, nI=nO_double)
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            output_layer = Linear(nO=nO, nI=nO_double) >> Dropout(0.0) >> Logistic()
 | 
					            output_layer = Linear(nO=nO, nI=nO_double) >> Logistic()
 | 
				
			||||||
        model = (linear_model | cnn_model) >> output_layer
 | 
					        model = (linear_model | cnn_model) >> output_layer
 | 
				
			||||||
        model.set_ref("tok2vec", tok2vec)
 | 
					        model.set_ref("tok2vec", tok2vec)
 | 
				
			||||||
    if model.has_dim("nO") is not False:
 | 
					    if model.has_dim("nO") is not False:
 | 
				
			||||||
| 
						 | 
					@ -96,7 +93,6 @@ def build_text_classifier_v2(
 | 
				
			||||||
    model.set_ref("output_layer", linear_model.get_ref("output_layer"))
 | 
					    model.set_ref("output_layer", linear_model.get_ref("output_layer"))
 | 
				
			||||||
    model.set_ref("attention_layer", attention_layer)
 | 
					    model.set_ref("attention_layer", attention_layer)
 | 
				
			||||||
    model.set_ref("maxout_layer", maxout_layer)
 | 
					    model.set_ref("maxout_layer", maxout_layer)
 | 
				
			||||||
    model.set_ref("linear_layer", linear_layer)
 | 
					 | 
				
			||||||
    model.attrs["multi_label"] = not exclusive_classes
 | 
					    model.attrs["multi_label"] = not exclusive_classes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    model.init = init_ensemble_textcat
 | 
					    model.init = init_ensemble_textcat
 | 
				
			||||||
| 
						 | 
					@ -108,7 +104,6 @@ def init_ensemble_textcat(model, X, Y) -> Model:
 | 
				
			||||||
    model.get_ref("attention_layer").set_dim("nO", tok2vec_width)
 | 
					    model.get_ref("attention_layer").set_dim("nO", tok2vec_width)
 | 
				
			||||||
    model.get_ref("maxout_layer").set_dim("nO", tok2vec_width)
 | 
					    model.get_ref("maxout_layer").set_dim("nO", tok2vec_width)
 | 
				
			||||||
    model.get_ref("maxout_layer").set_dim("nI", tok2vec_width)
 | 
					    model.get_ref("maxout_layer").set_dim("nI", tok2vec_width)
 | 
				
			||||||
    model.get_ref("linear_layer").set_dim("nI", tok2vec_width)
 | 
					 | 
				
			||||||
    init_chain(model, X, Y)
 | 
					    init_chain(model, X, Y)
 | 
				
			||||||
    return model
 | 
					    return model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user