mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-27 09:44:36 +03:00
Dont hard-code for 'corpora' name
This commit is contained in:
parent
a023cf3ecc
commit
3a0a3b8db6
|
@ -77,12 +77,10 @@ def train(nlp: Language, output_path: Optional[Path]=None) -> None:
|
|||
# Create iterator, which yields out info after each optimization step.
|
||||
config = nlp.config.interpolate()
|
||||
T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
|
||||
dot_names = [T["train_corpus"], T["dev_corpus"], T["raw_text"]]
|
||||
train_corpus, dev_corpus, raw_text = resolve_dot_names(config, dot_names)
|
||||
optimizer T["optimizer"]
|
||||
score_weights = T["score_weights"]
|
||||
# TODO: This might not be called corpora
|
||||
corpora = registry.resolve(config["corpora"], schema=ConfigSchemaCorpora)
|
||||
train_corpus = dot_to_object({"corpora": corpora}, T["train_corpus"])
|
||||
dev_corpus = dot_to_object({"corpora": corpora}, T["dev_corpus"])
|
||||
batcher = T["batcher"]
|
||||
train_logger = T["logger"]
|
||||
before_to_disk = create_before_to_disk_callback(T["before_to_disk"])
|
||||
|
@ -101,7 +99,7 @@ def train(nlp: Language, output_path: Optional[Path]=None) -> None:
|
|||
patience=T["patience"],
|
||||
max_steps=T["max_steps"],
|
||||
eval_frequency=T["eval_frequency"],
|
||||
raw_text=None,
|
||||
raw_text=raw_text,
|
||||
exclude=frozen_components,
|
||||
)
|
||||
msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
|
||||
|
|
Loading…
Reference in New Issue
Block a user