mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	Remove references to config
Replaced with model arguments
This commit is contained in:
		
							parent
							
								
									c0cd5025e3
								
							
						
					
					
						commit
						1c697b4011
					
				| 
						 | 
				
			
			@ -461,34 +461,43 @@ import .coref_util_wl as utils
 | 
			
		|||
# TODO rename to plain coref
 | 
			
		||||
@registry.architectures("spacy.WLCoref.v1")
 | 
			
		||||
def build_wl_coref_model(
 | 
			
		||||
        #TODO add other hyperparams
 | 
			
		||||
    tok2vec: Model[List[Doc], List[Floats2d]],
 | 
			
		||||
    embedding_size: int = 20,
 | 
			
		||||
    hidden_size: int = 1024,
 | 
			
		||||
    n_hidden_layers: int = 1, # TODO rename to "depth"?
 | 
			
		||||
    dropout: float = 0.3,
 | 
			
		||||
    # pairs to keep per mention after rough scoring
 | 
			
		||||
    # TODO change to meaningful name
 | 
			
		||||
    rough_k: int = 50,
 | 
			
		||||
    # TODO is this not a training loop setting?
 | 
			
		||||
    a_scoring_batch_size: int = 512,
 | 
			
		||||
    # span predictor embeddings
 | 
			
		||||
    sp_embedding_size: int = 64,
 | 
			
		||||
    ):
 | 
			
		||||
    
 | 
			
		||||
    # TODO change to use passed in values for config
 | 
			
		||||
    config = utils._load_config("/dev/null")
 | 
			
		||||
    with Model.define_operators({">>": chain}):
 | 
			
		||||
 | 
			
		||||
        coref_scorer, span_predictor = configure_pytorch_modules(config)
 | 
			
		||||
        # TODO chain tok2vec with these models
 | 
			
		||||
        # TODO fix device - should be automatic
 | 
			
		||||
        device = "gpu:0"
 | 
			
		||||
        coref_scorer = PyTorchWrapper(
 | 
			
		||||
            CorefScorer(
 | 
			
		||||
                config.device,
 | 
			
		||||
                config.embedding_size,
 | 
			
		||||
                config.hidden_size,
 | 
			
		||||
                config.n_hidden_layers,
 | 
			
		||||
                config.dropout_rate,
 | 
			
		||||
                config.rough_k,
 | 
			
		||||
                config.a_scoring_batch_size
 | 
			
		||||
                device,
 | 
			
		||||
                embedding_size,
 | 
			
		||||
                hidden_size,
 | 
			
		||||
                n_hidden_layers,
 | 
			
		||||
                dropout_rate,
 | 
			
		||||
                rough_k,
 | 
			
		||||
                a_scoring_batch_size
 | 
			
		||||
            ),
 | 
			
		||||
            convert_inputs=convert_coref_scorer_inputs,
 | 
			
		||||
            convert_outputs=convert_coref_scorer_outputs
 | 
			
		||||
        )
 | 
			
		||||
        span_predictor = PyTorchWrapper(
 | 
			
		||||
            SpanPredictor(
 | 
			
		||||
                1024,
 | 
			
		||||
                config.sp_embedding_size,
 | 
			
		||||
                config.device
 | 
			
		||||
                # TODO this was hardcoded to 1024, check
 | 
			
		||||
                hidden_size,
 | 
			
		||||
                sp_embedding_size,
 | 
			
		||||
                device
 | 
			
		||||
            ),
 | 
			
		||||
            convert_inputs=convert_span_predictor_inputs
 | 
			
		||||
        )
 | 
			
		||||
| 
						 | 
				
			
			@ -597,8 +606,6 @@ class CorefScorer(torch.nn.Module):
 | 
			
		|||
    """Combines all coref modules together to find coreferent spans.
 | 
			
		||||
 | 
			
		||||
    Attributes:
 | 
			
		||||
        config (coref.config.Config): the model's configuration,
 | 
			
		||||
            see config.toml for the details
 | 
			
		||||
        epochs_trained (int): number of epochs the model has been trained for
 | 
			
		||||
 | 
			
		||||
    Submodules (in the order of their usage in the pipeline):
 | 
			
		||||
| 
						 | 
				
			
			@ -622,8 +629,6 @@ class CorefScorer(torch.nn.Module):
 | 
			
		|||
        A newly created model is set to evaluation mode.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            config_path (str): the path to the toml file with the configuration
 | 
			
		||||
            section (str): the selected section of the config file
 | 
			
		||||
            epochs_trained (int): the number of epochs finished
 | 
			
		||||
                (useful for warm start)
 | 
			
		||||
        """
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user