mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 12:12:20 +03:00
Remove references to config
Replaced with model arguments
This commit is contained in:
parent
c0cd5025e3
commit
1c697b4011
|
@ -461,34 +461,43 @@ import .coref_util_wl as utils
|
|||
# TODO rename to plain coref
|
||||
@registry.architectures("spacy.WLCoref.v1")
|
||||
def build_wl_coref_model(
|
||||
#TODO add other hyperparams
|
||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||
embedding_size: int = 20,
|
||||
hidden_size: int = 1024,
|
||||
n_hidden_layers: int = 1, # TODO rename to "depth"?
|
||||
dropout: float = 0.3,
|
||||
# pairs to keep per mention after rough scoring
|
||||
# TODO change to meaningful name
|
||||
rough_k: int = 50,
|
||||
# TODO is this not a training loop setting?
|
||||
a_scoring_batch_size: int = 512,
|
||||
# span predictor embeddings
|
||||
sp_embedding_size: int = 64,
|
||||
):
|
||||
|
||||
# TODO change to use passed in values for config
|
||||
config = utils._load_config("/dev/null")
|
||||
with Model.define_operators({">>": chain}):
|
||||
|
||||
coref_scorer, span_predictor = configure_pytorch_modules(config)
|
||||
# TODO chain tok2vec with these models
|
||||
# TODO fix device - should be automatic
|
||||
device = "gpu:0"
|
||||
coref_scorer = PyTorchWrapper(
|
||||
CorefScorer(
|
||||
config.device,
|
||||
config.embedding_size,
|
||||
config.hidden_size,
|
||||
config.n_hidden_layers,
|
||||
config.dropout_rate,
|
||||
config.rough_k,
|
||||
config.a_scoring_batch_size
|
||||
device,
|
||||
embedding_size,
|
||||
hidden_size,
|
||||
n_hidden_layers,
|
||||
dropout_rate,
|
||||
rough_k,
|
||||
a_scoring_batch_size
|
||||
),
|
||||
convert_inputs=convert_coref_scorer_inputs,
|
||||
convert_outputs=convert_coref_scorer_outputs
|
||||
)
|
||||
span_predictor = PyTorchWrapper(
|
||||
SpanPredictor(
|
||||
1024,
|
||||
config.sp_embedding_size,
|
||||
config.device
|
||||
# TODO this was hardcoded to 1024, check
|
||||
hidden_size,
|
||||
sp_embedding_size,
|
||||
device
|
||||
),
|
||||
convert_inputs=convert_span_predictor_inputs
|
||||
)
|
||||
|
@ -597,8 +606,6 @@ class CorefScorer(torch.nn.Module):
|
|||
"""Combines all coref modules together to find coreferent spans.
|
||||
|
||||
Attributes:
|
||||
config (coref.config.Config): the model's configuration,
|
||||
see config.toml for the details
|
||||
epochs_trained (int): number of epochs the model has been trained for
|
||||
|
||||
Submodules (in the order of their usage in the pipeline):
|
||||
|
@ -622,8 +629,6 @@ class CorefScorer(torch.nn.Module):
|
|||
A newly created model is set to evaluation mode.
|
||||
|
||||
Args:
|
||||
config_path (str): the path to the toml file with the configuration
|
||||
section (str): the selected section of the config file
|
||||
epochs_trained (int): the number of epochs finished
|
||||
(useful for warm start)
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue
Block a user