	Remove references to config
Replaced with model arguments
parent c0cd5025e3
commit 1c697b4011
@@ -461,34 +461,43 @@ import .coref_util_wl as utils
 # TODO rename to plain coref
 @registry.architectures("spacy.WLCoref.v1")
 def build_wl_coref_model(
-        #TODO add other hyperparams
     tok2vec: Model[List[Doc], List[Floats2d]],
+    embedding_size: int = 20,
+    hidden_size: int = 1024,
+    n_hidden_layers: int = 1, # TODO rename to "depth"?
+    dropout: float = 0.3,
+    # pairs to keep per mention after rough scoring
+    # TODO change to meaningful name
+    rough_k: int = 50,
+    # TODO is this not a training loop setting?
+    a_scoring_batch_size: int = 512,
+    # span predictor embeddings
+    sp_embedding_size: int = 64,
     ):
-    # TODO change to use passed in values for config
-    config = utils._load_config("/dev/null")
     with Model.define_operators({">>": chain}):
-
-        coref_scorer, span_predictor = configure_pytorch_modules(config)
         # TODO chain tok2vec with these models
+        # TODO fix device - should be automatic
+        device = "gpu:0"
         coref_scorer = PyTorchWrapper(
             CorefScorer(
-                config.device,
-                config.embedding_size,
-                config.hidden_size,
-                config.n_hidden_layers,
-                config.dropout_rate,
-                config.rough_k,
-                config.a_scoring_batch_size
+                device,
+                embedding_size,
+                hidden_size,
+                n_hidden_layers,
+                dropout_rate,
+                rough_k,
+                a_scoring_batch_size
             ),
             convert_inputs=convert_coref_scorer_inputs,
             convert_outputs=convert_coref_scorer_outputs
         )
         span_predictor = PyTorchWrapper(
             SpanPredictor(
-                1024,
-                config.sp_embedding_size,
-                config.device
+                # TODO this was hardcoded to 1024, check
+                hidden_size,
+                sp_embedding_size,
+                device
             ),
             convert_inputs=convert_span_predictor_inputs
         )
@@ -597,8 +606,6 @@ class CorefScorer(torch.nn.Module):
     """Combines all coref modules together to find coreferent spans.

     Attributes:
-        config (coref.config.Config): the model's configuration,
-            see config.toml for the details
         epochs_trained (int): number of epochs the model has been trained for

     Submodules (in the order of their usage in the pipeline):
@@ -622,8 +629,6 @@ class CorefScorer(torch.nn.Module):
         A newly created model is set to evaluation mode.

         Args:
-            config_path (str): the path to the toml file with the configuration
-            section (str): the selected section of the config file
             epochs_trained (int): the number of epochs finished
                 (useful for warm start)
         """
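
With the hyperparameters exposed as plain arguments on the registered architecture, they can be filled in from a spaCy [model] config block rather than from the removed coref config.toml. The sketch below is illustrative only and is not part of this commit: it assumes the WL-coref code from this branch is installed so that "spacy.WLCoref.v1" is actually registered, and the [model.tok2vec] block (shown with spacy.Tok2Vec.v2, sub-blocks omitted) merely stands in for whatever tok2vec architecture the pipeline uses. Parsing the block with thinc's Config does not resolve the registry, so the snippet runs even without the branch; it only shows how the new keyword arguments map onto config settings.

from thinc.api import Config

# Hypothetical config block (not from the commit): each setting corresponds to
# one of the new keyword arguments on build_wl_coref_model.
CONFIG_STR = """
[model]
@architectures = "spacy.WLCoref.v1"
embedding_size = 20
hidden_size = 1024
n_hidden_layers = 1
dropout = 0.3
rough_k = 50
a_scoring_batch_size = 512
sp_embedding_size = 64

[model.tok2vec]
@architectures = "spacy.Tok2Vec.v2"
# embed/encode sub-blocks omitted; any tok2vec producing List[Floats2d] fits
"""

# Parsing only; resolving the [model] block (e.g. with registry.resolve)
# would require the coref code from this branch to be importable.
cfg = Config().from_str(CONFIG_STR)
print(cfg["model"]["hidden_size"])  # -> 1024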