mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 20:22:25 +03:00
Remove references to config
Replaced with model arguments
This commit is contained in:
parent
c0cd5025e3
commit
1c697b4011
|
@ -461,34 +461,43 @@ import .coref_util_wl as utils
|
||||||
# TODO rename to plain coref
|
# TODO rename to plain coref
|
||||||
@registry.architectures("spacy.WLCoref.v1")
|
@registry.architectures("spacy.WLCoref.v1")
|
||||||
def build_wl_coref_model(
|
def build_wl_coref_model(
|
||||||
#TODO add other hyperparams
|
|
||||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||||
|
embedding_size: int = 20,
|
||||||
|
hidden_size: int = 1024,
|
||||||
|
n_hidden_layers: int = 1, # TODO rename to "depth"?
|
||||||
|
dropout: float = 0.3,
|
||||||
|
# pairs to keep per mention after rough scoring
|
||||||
|
# TODO change to meaningful name
|
||||||
|
rough_k: int = 50,
|
||||||
|
# TODO is this not a training loop setting?
|
||||||
|
a_scoring_batch_size: int = 512,
|
||||||
|
# span predictor embeddings
|
||||||
|
sp_embedding_size: int = 64,
|
||||||
):
|
):
|
||||||
|
|
||||||
# TODO change to use passed in values for config
|
|
||||||
config = utils._load_config("/dev/null")
|
|
||||||
with Model.define_operators({">>": chain}):
|
with Model.define_operators({">>": chain}):
|
||||||
|
|
||||||
coref_scorer, span_predictor = configure_pytorch_modules(config)
|
|
||||||
# TODO chain tok2vec with these models
|
# TODO chain tok2vec with these models
|
||||||
|
# TODO fix device - should be automatic
|
||||||
|
device = "gpu:0"
|
||||||
coref_scorer = PyTorchWrapper(
|
coref_scorer = PyTorchWrapper(
|
||||||
CorefScorer(
|
CorefScorer(
|
||||||
config.device,
|
device,
|
||||||
config.embedding_size,
|
embedding_size,
|
||||||
config.hidden_size,
|
hidden_size,
|
||||||
config.n_hidden_layers,
|
n_hidden_layers,
|
||||||
config.dropout_rate,
|
dropout_rate,
|
||||||
config.rough_k,
|
rough_k,
|
||||||
config.a_scoring_batch_size
|
a_scoring_batch_size
|
||||||
),
|
),
|
||||||
convert_inputs=convert_coref_scorer_inputs,
|
convert_inputs=convert_coref_scorer_inputs,
|
||||||
convert_outputs=convert_coref_scorer_outputs
|
convert_outputs=convert_coref_scorer_outputs
|
||||||
)
|
)
|
||||||
span_predictor = PyTorchWrapper(
|
span_predictor = PyTorchWrapper(
|
||||||
SpanPredictor(
|
SpanPredictor(
|
||||||
1024,
|
# TODO this was hardcoded to 1024, check
|
||||||
config.sp_embedding_size,
|
hidden_size,
|
||||||
config.device
|
sp_embedding_size,
|
||||||
|
device
|
||||||
),
|
),
|
||||||
convert_inputs=convert_span_predictor_inputs
|
convert_inputs=convert_span_predictor_inputs
|
||||||
)
|
)
|
||||||
|
@ -597,8 +606,6 @@ class CorefScorer(torch.nn.Module):
|
||||||
"""Combines all coref modules together to find coreferent spans.
|
"""Combines all coref modules together to find coreferent spans.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
config (coref.config.Config): the model's configuration,
|
|
||||||
see config.toml for the details
|
|
||||||
epochs_trained (int): number of epochs the model has been trained for
|
epochs_trained (int): number of epochs the model has been trained for
|
||||||
|
|
||||||
Submodules (in the order of their usage in the pipeline):
|
Submodules (in the order of their usage in the pipeline):
|
||||||
|
@ -622,8 +629,6 @@ class CorefScorer(torch.nn.Module):
|
||||||
A newly created model is set to evaluation mode.
|
A newly created model is set to evaluation mode.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config_path (str): the path to the toml file with the configuration
|
|
||||||
section (str): the selected section of the config file
|
|
||||||
epochs_trained (int): the number of epochs finished
|
epochs_trained (int): the number of epochs finished
|
||||||
(useful for warm start)
|
(useful for warm start)
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue
Block a user