Remove references to config

Replaced with model arguments
This commit is contained in:
Paul O'Leary McCann 2022-03-08 18:13:09 +09:00
parent c0cd5025e3
commit 1c697b4011

View File

@@ -461,34 +461,43 @@ import .coref_util_wl as utils
# TODO rename to plain coref # TODO rename to plain coref
@registry.architectures("spacy.WLCoref.v1") @registry.architectures("spacy.WLCoref.v1")
def build_wl_coref_model( def build_wl_coref_model(
#TODO add other hyperparams
tok2vec: Model[List[Doc], List[Floats2d]], tok2vec: Model[List[Doc], List[Floats2d]],
embedding_size: int = 20,
hidden_size: int = 1024,
n_hidden_layers: int = 1, # TODO rename to "depth"?
dropout: float = 0.3,
# pairs to keep per mention after rough scoring
# TODO change to meaningful name
rough_k: int = 50,
# TODO is this not a training loop setting?
a_scoring_batch_size: int = 512,
# span predictor embeddings
sp_embedding_size: int = 64,
): ):
# TODO change to use passed in values for config
config = utils._load_config("/dev/null")
with Model.define_operators({">>": chain}): with Model.define_operators({">>": chain}):
coref_scorer, span_predictor = configure_pytorch_modules(config)
# TODO chain tok2vec with these models # TODO chain tok2vec with these models
# TODO fix device - should be automatic
device = "gpu:0"
coref_scorer = PyTorchWrapper( coref_scorer = PyTorchWrapper(
CorefScorer( CorefScorer(
config.device, device,
config.embedding_size, embedding_size,
config.hidden_size, hidden_size,
config.n_hidden_layers, n_hidden_layers,
config.dropout_rate, dropout_rate,
config.rough_k, rough_k,
config.a_scoring_batch_size a_scoring_batch_size
), ),
convert_inputs=convert_coref_scorer_inputs, convert_inputs=convert_coref_scorer_inputs,
convert_outputs=convert_coref_scorer_outputs convert_outputs=convert_coref_scorer_outputs
) )
span_predictor = PyTorchWrapper( span_predictor = PyTorchWrapper(
SpanPredictor( SpanPredictor(
1024, # TODO this was hardcoded to 1024, check
config.sp_embedding_size, hidden_size,
config.device sp_embedding_size,
device
), ),
convert_inputs=convert_span_predictor_inputs convert_inputs=convert_span_predictor_inputs
) )
@@ -597,8 +606,6 @@ class CorefScorer(torch.nn.Module):
"""Combines all coref modules together to find coreferent spans. """Combines all coref modules together to find coreferent spans.
Attributes: Attributes:
config (coref.config.Config): the model's configuration,
see config.toml for the details
epochs_trained (int): number of epochs the model has been trained for epochs_trained (int): number of epochs the model has been trained for
Submodules (in the order of their usage in the pipeline): Submodules (in the order of their usage in the pipeline):
@@ -622,8 +629,6 @@ class CorefScorer(torch.nn.Module):
A newly created model is set to evaluation mode. A newly created model is set to evaluation mode.
Args: Args:
config_path (str): the path to the toml file with the configuration
section (str): the selected section of the config file
epochs_trained (int): the number of epochs finished epochs_trained (int): the number of epochs finished
(useful for warm start) (useful for warm start)
""" """