From 1c697b40116ddb5276450bf0dc4e6b870f06205d Mon Sep 17 00:00:00 2001 From: Paul O'Leary McCann Date: Tue, 8 Mar 2022 18:13:09 +0900 Subject: [PATCH] Remove references to config Replaced with model arguments --- spacy/ml/models/coref.py | 43 ++++++++++++++++++++++------------------ 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py index 2e291aa2b..bfaa97060 100644 --- a/spacy/ml/models/coref.py +++ b/spacy/ml/models/coref.py @@ -461,34 +461,43 @@ import .coref_util_wl as utils # TODO rename to plain coref @registry.architectures("spacy.WLCoref.v1") def build_wl_coref_model( - #TODO add other hyperparams tok2vec: Model[List[Doc], List[Floats2d]], + embedding_size: int = 20, + hidden_size: int = 1024, + n_hidden_layers: int = 1, # TODO rename to "depth"? + dropout: float = 0.3, + # pairs to keep per mention after rough scoring + # TODO change to meaningful name + rough_k: int = 50, + # TODO is this not a training loop setting? + a_scoring_batch_size: int = 512, + # span predictor embeddings + sp_embedding_size: int = 64, ): - # TODO change to use passed in values for config - config = utils._load_config("/dev/null") with Model.define_operators({">>": chain}): - - coref_scorer, span_predictor = configure_pytorch_modules(config) # TODO chain tok2vec with these models + # TODO fix device - should be automatic + device = "gpu:0" coref_scorer = PyTorchWrapper( CorefScorer( - config.device, - config.embedding_size, - config.hidden_size, - config.n_hidden_layers, - config.dropout_rate, - config.rough_k, - config.a_scoring_batch_size + device, + embedding_size, + hidden_size, + n_hidden_layers, + dropout_rate, + rough_k, + a_scoring_batch_size ), convert_inputs=convert_coref_scorer_inputs, convert_outputs=convert_coref_scorer_outputs ) span_predictor = PyTorchWrapper( SpanPredictor( - 1024, - config.sp_embedding_size, - config.device + # TODO this was hardcoded to 1024, check + hidden_size, + sp_embedding_size, + device ), convert_inputs=convert_span_predictor_inputs ) @@ -597,8 +606,6 @@ class CorefScorer(torch.nn.Module): """Combines all coref modules together to find coreferent spans. Attributes: - config (coref.config.Config): the model's configuration, - see config.toml for the details epochs_trained (int): number of epochs the model has been trained for Submodules (in the order of their usage in the pipeline): @@ -622,8 +629,6 @@ class CorefScorer(torch.nn.Module): A newly created model is set to evaluation mode. Args: - config_path (str): the path to the toml file with the configuration - section (str): the selected section of the config file epochs_trained (int): the number of epochs finished (useful for warm start) """