Remove references to config

Replaced with model arguments
This commit is contained in:
Paul O'Leary McCann 2022-03-08 18:13:09 +09:00
parent c0cd5025e3
commit 1c697b4011

View File

@ -461,34 +461,43 @@ import .coref_util_wl as utils
# TODO rename to plain coref
@registry.architectures("spacy.WLCoref.v1")
def build_wl_coref_model(
# TODO add other hyperparams
tok2vec: Model[List[Doc], List[Floats2d]],
embedding_size: int = 20,
hidden_size: int = 1024,
n_hidden_layers: int = 1, # TODO rename to "depth"?
dropout: float = 0.3,
# pairs to keep per mention after rough scoring
# TODO change to meaningful name
rough_k: int = 50,
# TODO is this not a training loop setting?
a_scoring_batch_size: int = 512,
# span predictor embeddings
sp_embedding_size: int = 64,
):
# TODO change to use passed in values for config
config = utils._load_config("/dev/null")
with Model.define_operators({">>": chain}):
coref_scorer, span_predictor = configure_pytorch_modules(config)
# TODO chain tok2vec with these models
# TODO fix device - should be automatic
device = "gpu:0"
coref_scorer = PyTorchWrapper(
CorefScorer(
config.device,
config.embedding_size,
config.hidden_size,
config.n_hidden_layers,
config.dropout_rate,
config.rough_k,
config.a_scoring_batch_size
device,
embedding_size,
hidden_size,
n_hidden_layers,
dropout_rate,
rough_k,
a_scoring_batch_size
),
convert_inputs=convert_coref_scorer_inputs,
convert_outputs=convert_coref_scorer_outputs
)
span_predictor = PyTorchWrapper(
SpanPredictor(
1024,
config.sp_embedding_size,
config.device
# TODO this was hardcoded to 1024; check that hidden_size is the right value here
hidden_size,
sp_embedding_size,
device
),
convert_inputs=convert_span_predictor_inputs
)
@ -597,8 +606,6 @@ class CorefScorer(torch.nn.Module):
"""Combines all coref modules together to find coreferent spans.
Attributes:
config (coref.config.Config): the model's configuration,
see config.toml for the details
epochs_trained (int): number of epochs the model has been trained for
Submodules (in the order of their usage in the pipeline):
@ -622,8 +629,6 @@ class CorefScorer(torch.nn.Module):
A newly created model is set to evaluation mode.
Args:
config_path (str): the path to the toml file with the configuration
section (str): the selected section of the config file
epochs_trained (int): the number of epochs finished
(useful for warm start)
"""