mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 01:04:34 +03:00
generalize corpora, dot notation for dev and train corpus
This commit is contained in:
parent
8cedb2f380
commit
0c35885751
|
@ -8,6 +8,22 @@ init_tok2vec = null
|
||||||
seed = 0
|
seed = 0
|
||||||
use_pytorch_for_gpu_memory = false
|
use_pytorch_for_gpu_memory = false
|
||||||
|
|
||||||
|
[corpora]
|
||||||
|
|
||||||
|
[corpora.train]
|
||||||
|
@readers = "spacy.Corpus.v1"
|
||||||
|
path = ${paths:train}
|
||||||
|
gold_preproc = true
|
||||||
|
max_length = 0
|
||||||
|
limit = 0
|
||||||
|
|
||||||
|
[corpora.dev]
|
||||||
|
@readers = "spacy.Corpus.v1"
|
||||||
|
path = ${paths:dev}
|
||||||
|
gold_preproc = ${corpora.train.gold_preproc}
|
||||||
|
max_length = 0
|
||||||
|
limit = 0
|
||||||
|
|
||||||
[training]
|
[training]
|
||||||
seed = ${system:seed}
|
seed = ${system:seed}
|
||||||
dropout = 0.1
|
dropout = 0.1
|
||||||
|
@ -20,22 +36,8 @@ patience = 10000
|
||||||
eval_frequency = 200
|
eval_frequency = 200
|
||||||
score_weights = {"dep_las": 0.4, "ents_f": 0.4, "tag_acc": 0.2}
|
score_weights = {"dep_las": 0.4, "ents_f": 0.4, "tag_acc": 0.2}
|
||||||
frozen_components = []
|
frozen_components = []
|
||||||
|
dev_corpus = "corpora.dev"
|
||||||
[training.corpus]
|
train_corpus = "corpora.train"
|
||||||
|
|
||||||
[training.corpus.train]
|
|
||||||
@readers = "spacy.Corpus.v1"
|
|
||||||
path = ${paths:train}
|
|
||||||
gold_preproc = true
|
|
||||||
max_length = 0
|
|
||||||
limit = 0
|
|
||||||
|
|
||||||
[training.corpus.dev]
|
|
||||||
@readers = "spacy.Corpus.v1"
|
|
||||||
path = ${paths:dev}
|
|
||||||
gold_preproc = ${training.read_train:gold_preproc}
|
|
||||||
max_length = 0
|
|
||||||
limit = 0
|
|
||||||
|
|
||||||
[training.batcher]
|
[training.batcher]
|
||||||
@batchers = "spacy.batch_by_words.v1"
|
@batchers = "spacy.batch_by_words.v1"
|
||||||
|
|
|
@ -8,6 +8,22 @@ init_tok2vec = null
|
||||||
seed = 0
|
seed = 0
|
||||||
use_pytorch_for_gpu_memory = false
|
use_pytorch_for_gpu_memory = false
|
||||||
|
|
||||||
|
[corpora]
|
||||||
|
|
||||||
|
[corpora.train]
|
||||||
|
@readers = "spacy.Corpus.v1"
|
||||||
|
path = ${paths:train}
|
||||||
|
gold_preproc = true
|
||||||
|
max_length = 0
|
||||||
|
limit = 0
|
||||||
|
|
||||||
|
[corpora.dev]
|
||||||
|
@readers = "spacy.Corpus.v1"
|
||||||
|
path = ${paths:dev}
|
||||||
|
gold_preproc = ${corpora.train.gold_preproc}
|
||||||
|
max_length = 0
|
||||||
|
limit = 0
|
||||||
|
|
||||||
[training]
|
[training]
|
||||||
seed = ${system:seed}
|
seed = ${system:seed}
|
||||||
dropout = 0.2
|
dropout = 0.2
|
||||||
|
@ -20,22 +36,6 @@ patience = 10000
|
||||||
eval_frequency = 200
|
eval_frequency = 200
|
||||||
score_weights = {"dep_las": 0.8, "tag_acc": 0.2}
|
score_weights = {"dep_las": 0.8, "tag_acc": 0.2}
|
||||||
|
|
||||||
[training.corpus]
|
|
||||||
|
|
||||||
[training.corpus.train]
|
|
||||||
@readers = "spacy.Corpus.v1"
|
|
||||||
path = ${paths:train}
|
|
||||||
gold_preproc = true
|
|
||||||
max_length = 0
|
|
||||||
limit = 0
|
|
||||||
|
|
||||||
[training.corpus.dev]
|
|
||||||
@readers = "spacy.Corpus.v1"
|
|
||||||
path = ${paths:dev}
|
|
||||||
gold_preproc = ${training.read_train:gold_preproc}
|
|
||||||
max_length = 0
|
|
||||||
limit = 0
|
|
||||||
|
|
||||||
[training.batcher]
|
[training.batcher]
|
||||||
@batchers = "spacy.batch_by_words.v1"
|
@batchers = "spacy.batch_by_words.v1"
|
||||||
discard_oversize = false
|
discard_oversize = false
|
||||||
|
|
|
@ -20,6 +20,7 @@ from ..ml.models.multi_task import build_cloze_characters_multi_task_model
|
||||||
from ..tokens import Doc
|
from ..tokens import Doc
|
||||||
from ..attrs import ID
|
from ..attrs import ID
|
||||||
from .. import util
|
from .. import util
|
||||||
|
from ..util import dot_to_object
|
||||||
|
|
||||||
|
|
||||||
@app.command(
|
@app.command(
|
||||||
|
@ -106,7 +107,7 @@ def pretrain(
|
||||||
use_pytorch_for_gpu_memory()
|
use_pytorch_for_gpu_memory()
|
||||||
nlp, config = util.load_model_from_config(config)
|
nlp, config = util.load_model_from_config(config)
|
||||||
P_cfg = config["pretraining"]
|
P_cfg = config["pretraining"]
|
||||||
corpus = P_cfg["corpus"]
|
corpus = dot_to_object(config, config["pretraining"]["corpus"])
|
||||||
batcher = P_cfg["batcher"]
|
batcher = P_cfg["batcher"]
|
||||||
model = create_pretraining_model(nlp, config["pretraining"])
|
model = create_pretraining_model(nlp, config["pretraining"])
|
||||||
optimizer = config["pretraining"]["optimizer"]
|
optimizer = config["pretraining"]["optimizer"]
|
||||||
|
|
|
@ -173,6 +173,18 @@ factory = "{{ pipe }}"
|
||||||
{% endif %}
|
{% endif %}
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
|
|
||||||
|
[corpora]
|
||||||
|
|
||||||
|
[corpora.train]
|
||||||
|
@readers = "spacy.Corpus.v1"
|
||||||
|
path = ${paths.train}
|
||||||
|
max_length = {{ 500 if hardware == "gpu" else 2000 }}
|
||||||
|
|
||||||
|
[corpora.dev]
|
||||||
|
@readers = "spacy.Corpus.v1"
|
||||||
|
path = ${paths.dev}
|
||||||
|
max_length = 0
|
||||||
|
|
||||||
[training]
|
[training]
|
||||||
{% if use_transformer or optimize == "efficiency" or not word_vectors -%}
|
{% if use_transformer or optimize == "efficiency" or not word_vectors -%}
|
||||||
vectors = null
|
vectors = null
|
||||||
|
@ -182,11 +194,12 @@ vectors = "{{ word_vectors }}"
|
||||||
{% if use_transformer -%}
|
{% if use_transformer -%}
|
||||||
accumulate_gradient = {{ transformer["size_factor"] }}
|
accumulate_gradient = {{ transformer["size_factor"] }}
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
dev_corpus = "corpora.dev"
|
||||||
|
train_corpus = "corpora.train"
|
||||||
|
|
||||||
[training.optimizer]
|
[training.optimizer]
|
||||||
@optimizers = "Adam.v1"
|
@optimizers = "Adam.v1"
|
||||||
|
|
||||||
|
|
||||||
{% if use_transformer -%}
|
{% if use_transformer -%}
|
||||||
[training.optimizer.learn_rate]
|
[training.optimizer.learn_rate]
|
||||||
@schedules = "warmup_linear.v1"
|
@schedules = "warmup_linear.v1"
|
||||||
|
@ -195,18 +208,6 @@ total_steps = 20000
|
||||||
initial_rate = 5e-5
|
initial_rate = 5e-5
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
[training.corpus]
|
|
||||||
|
|
||||||
[training.corpus.train]
|
|
||||||
@readers = "spacy.Corpus.v1"
|
|
||||||
path = ${paths.train}
|
|
||||||
max_length = {{ 500 if hardware == "gpu" else 2000 }}
|
|
||||||
|
|
||||||
[training.corpus.dev]
|
|
||||||
@readers = "spacy.Corpus.v1"
|
|
||||||
path = ${paths.dev}
|
|
||||||
max_length = 0
|
|
||||||
|
|
||||||
{% if use_transformer %}
|
{% if use_transformer %}
|
||||||
[training.batcher]
|
[training.batcher]
|
||||||
@batchers = "spacy.batch_by_padded.v1"
|
@batchers = "spacy.batch_by_padded.v1"
|
||||||
|
|
|
@ -18,6 +18,7 @@ from ..language import Language
|
||||||
from .. import util
|
from .. import util
|
||||||
from ..training.example import Example
|
from ..training.example import Example
|
||||||
from ..errors import Errors
|
from ..errors import Errors
|
||||||
|
from ..util import dot_to_object
|
||||||
|
|
||||||
|
|
||||||
@app.command(
|
@app.command(
|
||||||
|
@ -92,8 +93,8 @@ def train(
|
||||||
raw_text, tag_map, morph_rules, weights_data = load_from_paths(config)
|
raw_text, tag_map, morph_rules, weights_data = load_from_paths(config)
|
||||||
T_cfg = config["training"]
|
T_cfg = config["training"]
|
||||||
optimizer = T_cfg["optimizer"]
|
optimizer = T_cfg["optimizer"]
|
||||||
train_corpus = T_cfg["corpus"]["train"]
|
train_corpus = dot_to_object(config, config["training"]["train_corpus"])
|
||||||
dev_corpus = T_cfg["corpus"]["dev"]
|
dev_corpus = dot_to_object(config, config["training"]["dev_corpus"])
|
||||||
batcher = T_cfg["batcher"]
|
batcher = T_cfg["batcher"]
|
||||||
train_logger = T_cfg["logger"]
|
train_logger = T_cfg["logger"]
|
||||||
# Components that shouldn't be updated during training
|
# Components that shouldn't be updated during training
|
||||||
|
|
|
@ -22,6 +22,33 @@ after_pipeline_creation = null
|
||||||
|
|
||||||
[components]
|
[components]
|
||||||
|
|
||||||
|
# Readers for corpora like dev and train.
|
||||||
|
[corpora]
|
||||||
|
|
||||||
|
[corpora.train]
|
||||||
|
@readers = "spacy.Corpus.v1"
|
||||||
|
path = ${paths.train}
|
||||||
|
# Whether to train on sequences with 'gold standard' sentence boundaries
|
||||||
|
# and tokens. If you set this to true, take care to ensure your run-time
|
||||||
|
# data is passed in sentence-by-sentence via some prior preprocessing.
|
||||||
|
gold_preproc = false
|
||||||
|
# Limitations on training document length
|
||||||
|
max_length = 0
|
||||||
|
# Limitation on number of training examples
|
||||||
|
limit = 0
|
||||||
|
|
||||||
|
[corpora.dev]
|
||||||
|
@readers = "spacy.Corpus.v1"
|
||||||
|
path = ${paths.dev}
|
||||||
|
# Whether to train on sequences with 'gold standard' sentence boundaries
|
||||||
|
# and tokens. If you set this to true, take care to ensure your run-time
|
||||||
|
# data is passed in sentence-by-sentence via some prior preprocessing.
|
||||||
|
gold_preproc = false
|
||||||
|
# Limitations on training document length
|
||||||
|
max_length = 0
|
||||||
|
# Limitation on number of training examples
|
||||||
|
limit = 0
|
||||||
|
|
||||||
# Training hyper-parameters and additional features.
|
# Training hyper-parameters and additional features.
|
||||||
[training]
|
[training]
|
||||||
seed = ${system.seed}
|
seed = ${system.seed}
|
||||||
|
@ -40,35 +67,14 @@ eval_frequency = 200
|
||||||
score_weights = {}
|
score_weights = {}
|
||||||
# Names of pipeline components that shouldn't be updated during training
|
# Names of pipeline components that shouldn't be updated during training
|
||||||
frozen_components = []
|
frozen_components = []
|
||||||
|
# Location in the config where the dev corpus is defined
|
||||||
|
dev_corpus = "corpora.dev"
|
||||||
|
# Location in the config where the train corpus is defined
|
||||||
|
train_corpus = "corpora.train"
|
||||||
|
|
||||||
[training.logger]
|
[training.logger]
|
||||||
@loggers = "spacy.ConsoleLogger.v1"
|
@loggers = "spacy.ConsoleLogger.v1"
|
||||||
|
|
||||||
[training.corpus]
|
|
||||||
|
|
||||||
[training.corpus.train]
|
|
||||||
@readers = "spacy.Corpus.v1"
|
|
||||||
path = ${paths.train}
|
|
||||||
# Whether to train on sequences with 'gold standard' sentence boundaries
|
|
||||||
# and tokens. If you set this to true, take care to ensure your run-time
|
|
||||||
# data is passed in sentence-by-sentence via some prior preprocessing.
|
|
||||||
gold_preproc = false
|
|
||||||
# Limitations on training document length
|
|
||||||
max_length = 0
|
|
||||||
# Limitation on number of training examples
|
|
||||||
limit = 0
|
|
||||||
|
|
||||||
[training.corpus.dev]
|
|
||||||
@readers = "spacy.Corpus.v1"
|
|
||||||
path = ${paths.dev}
|
|
||||||
# Whether to train on sequences with 'gold standard' sentence boundaries
|
|
||||||
# and tokens. If you set this to true, take care to ensure your run-time
|
|
||||||
# data is passed in sentence-by-sentence via some prior preprocessing.
|
|
||||||
gold_preproc = false
|
|
||||||
# Limitations on training document length
|
|
||||||
max_length = 0
|
|
||||||
# Limitation on number of training examples
|
|
||||||
limit = 0
|
|
||||||
|
|
||||||
[training.batcher]
|
[training.batcher]
|
||||||
@batchers = "spacy.batch_by_words.v1"
|
@batchers = "spacy.batch_by_words.v1"
|
||||||
|
|
|
@ -4,6 +4,7 @@ dropout = 0.2
|
||||||
n_save_every = null
|
n_save_every = null
|
||||||
component = "tok2vec"
|
component = "tok2vec"
|
||||||
layer = ""
|
layer = ""
|
||||||
|
corpus = "corpora.pretrain"
|
||||||
|
|
||||||
[pretraining.batcher]
|
[pretraining.batcher]
|
||||||
@batchers = "spacy.batch_by_words.v1"
|
@batchers = "spacy.batch_by_words.v1"
|
||||||
|
@ -12,13 +13,6 @@ discard_oversize = false
|
||||||
tolerance = 0.2
|
tolerance = 0.2
|
||||||
get_length = null
|
get_length = null
|
||||||
|
|
||||||
[pretraining.corpus]
|
|
||||||
@readers = "spacy.JsonlReader.v1"
|
|
||||||
path = ${paths.raw}
|
|
||||||
min_length = 5
|
|
||||||
max_length = 500
|
|
||||||
limit = 0
|
|
||||||
|
|
||||||
[pretraining.objective]
|
[pretraining.objective]
|
||||||
type = "characters"
|
type = "characters"
|
||||||
n_characters = 4
|
n_characters = 4
|
||||||
|
@ -33,3 +27,12 @@ grad_clip = 1.0
|
||||||
use_averages = true
|
use_averages = true
|
||||||
eps = 1e-8
|
eps = 1e-8
|
||||||
learn_rate = 0.001
|
learn_rate = 0.001
|
||||||
|
|
||||||
|
[corpora]
|
||||||
|
|
||||||
|
[corpora.pretrain]
|
||||||
|
@readers = "spacy.JsonlReader.v1"
|
||||||
|
path = ${paths.raw}
|
||||||
|
min_length = 5
|
||||||
|
max_length = 500
|
||||||
|
limit = 0
|
||||||
|
|
|
@ -198,7 +198,8 @@ class ModelMetaSchema(BaseModel):
|
||||||
class ConfigSchemaTraining(BaseModel):
|
class ConfigSchemaTraining(BaseModel):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
vectors: Optional[StrictStr] = Field(..., title="Path to vectors")
|
vectors: Optional[StrictStr] = Field(..., title="Path to vectors")
|
||||||
corpus: Dict[str, Reader] = Field(..., title="Reader for the training and dev data")
|
dev_corpus: StrictStr = Field(..., title="Path in the config to the dev data")
|
||||||
|
train_corpus: StrictStr = Field(..., title="Path in the config to the training data")
|
||||||
batcher: Batcher = Field(..., title="Batcher for the training data")
|
batcher: Batcher = Field(..., title="Batcher for the training data")
|
||||||
dropout: StrictFloat = Field(..., title="Dropout rate")
|
dropout: StrictFloat = Field(..., title="Dropout rate")
|
||||||
patience: StrictInt = Field(..., title="How many steps to continue without improvement in evaluation score")
|
patience: StrictInt = Field(..., title="How many steps to continue without improvement in evaluation score")
|
||||||
|
@ -248,7 +249,7 @@ class ConfigSchemaPretrain(BaseModel):
|
||||||
dropout: StrictFloat = Field(..., title="Dropout rate")
|
dropout: StrictFloat = Field(..., title="Dropout rate")
|
||||||
n_save_every: Optional[StrictInt] = Field(..., title="Saving frequency")
|
n_save_every: Optional[StrictInt] = Field(..., title="Saving frequency")
|
||||||
optimizer: Optimizer = Field(..., title="The optimizer to use")
|
optimizer: Optimizer = Field(..., title="The optimizer to use")
|
||||||
corpus: Reader = Field(..., title="Reader for the training data")
|
corpus: StrictStr = Field(..., title="Path in the config to the training data")
|
||||||
batcher: Batcher = Field(..., title="Batcher for the training data")
|
batcher: Batcher = Field(..., title="Batcher for the training data")
|
||||||
component: str = Field(..., title="Component to find the layer to pretrain")
|
component: str = Field(..., title="Component to find the layer to pretrain")
|
||||||
layer: str = Field(..., title="Layer to pretrain. Whole model if empty.")
|
layer: str = Field(..., title="Layer to pretrain. Whole model if empty.")
|
||||||
|
@ -267,6 +268,7 @@ class ConfigSchema(BaseModel):
|
||||||
nlp: ConfigSchemaNlp
|
nlp: ConfigSchemaNlp
|
||||||
pretraining: Union[ConfigSchemaPretrain, ConfigSchemaPretrainEmpty] = {}
|
pretraining: Union[ConfigSchemaPretrain, ConfigSchemaPretrainEmpty] = {}
|
||||||
components: Dict[str, Dict[str, Any]]
|
components: Dict[str, Dict[str, Any]]
|
||||||
|
corpora: Dict[str, Reader]
|
||||||
|
|
||||||
@root_validator(allow_reuse=True)
|
@root_validator(allow_reuse=True)
|
||||||
def validate_config(cls, values):
|
def validate_config(cls, values):
|
||||||
|
|
|
@ -17,18 +17,18 @@ nlp_config_string = """
|
||||||
train = ""
|
train = ""
|
||||||
dev = ""
|
dev = ""
|
||||||
|
|
||||||
[training]
|
[corpora]
|
||||||
|
|
||||||
[training.corpus]
|
[corpora.train]
|
||||||
|
|
||||||
[training.corpus.train]
|
|
||||||
@readers = "spacy.Corpus.v1"
|
@readers = "spacy.Corpus.v1"
|
||||||
path = ${paths.train}
|
path = ${paths.train}
|
||||||
|
|
||||||
[training.corpus.dev]
|
[corpora.dev]
|
||||||
@readers = "spacy.Corpus.v1"
|
@readers = "spacy.Corpus.v1"
|
||||||
path = ${paths.dev}
|
path = ${paths.dev}
|
||||||
|
|
||||||
|
[training]
|
||||||
|
|
||||||
[training.batcher]
|
[training.batcher]
|
||||||
@batchers = "spacy.batch_by_words.v1"
|
@batchers = "spacy.batch_by_words.v1"
|
||||||
size = 666
|
size = 666
|
||||||
|
@ -302,20 +302,20 @@ def test_config_overrides():
|
||||||
|
|
||||||
def test_config_interpolation():
|
def test_config_interpolation():
|
||||||
config = Config().from_str(nlp_config_string, interpolate=False)
|
config = Config().from_str(nlp_config_string, interpolate=False)
|
||||||
assert config["training"]["corpus"]["train"]["path"] == "${paths.train}"
|
assert config["corpora"]["train"]["path"] == "${paths.train}"
|
||||||
interpolated = config.interpolate()
|
interpolated = config.interpolate()
|
||||||
assert interpolated["training"]["corpus"]["train"]["path"] == ""
|
assert interpolated["corpora"]["train"]["path"] == ""
|
||||||
nlp = English.from_config(config)
|
nlp = English.from_config(config)
|
||||||
assert nlp.config["training"]["corpus"]["train"]["path"] == "${paths.train}"
|
assert nlp.config["corpora"]["train"]["path"] == "${paths.train}"
|
||||||
# Ensure that variables are preserved in nlp config
|
# Ensure that variables are preserved in nlp config
|
||||||
width = "${components.tok2vec.model.width}"
|
width = "${components.tok2vec.model.width}"
|
||||||
assert config["components"]["tagger"]["model"]["tok2vec"]["width"] == width
|
assert config["components"]["tagger"]["model"]["tok2vec"]["width"] == width
|
||||||
assert nlp.config["components"]["tagger"]["model"]["tok2vec"]["width"] == width
|
assert nlp.config["components"]["tagger"]["model"]["tok2vec"]["width"] == width
|
||||||
interpolated2 = nlp.config.interpolate()
|
interpolated2 = nlp.config.interpolate()
|
||||||
assert interpolated2["training"]["corpus"]["train"]["path"] == ""
|
assert interpolated2["corpora"]["train"]["path"] == ""
|
||||||
assert interpolated2["components"]["tagger"]["model"]["tok2vec"]["width"] == 342
|
assert interpolated2["components"]["tagger"]["model"]["tok2vec"]["width"] == 342
|
||||||
nlp2 = English.from_config(interpolated)
|
nlp2 = English.from_config(interpolated)
|
||||||
assert nlp2.config["training"]["corpus"]["train"]["path"] == ""
|
assert nlp2.config["corpora"]["train"]["path"] == ""
|
||||||
assert nlp2.config["components"]["tagger"]["model"]["tok2vec"]["width"] == 342
|
assert nlp2.config["components"]["tagger"]["model"]["tok2vec"]["width"] == 342
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,57 @@
|
||||||
|
from typing import Dict, Iterable, Callable
|
||||||
import pytest
|
import pytest
|
||||||
from thinc.api import Config
|
from thinc.api import Config
|
||||||
from spacy.util import load_model_from_config
|
|
||||||
|
from spacy import Language
|
||||||
|
from spacy.util import load_model_from_config, registry, dot_to_object
|
||||||
|
from spacy.training import Example
|
||||||
|
|
||||||
|
|
||||||
|
def test_readers():
|
||||||
|
config_string = """
|
||||||
|
[training]
|
||||||
|
|
||||||
|
[corpora]
|
||||||
|
@readers = "myreader.v1"
|
||||||
|
|
||||||
|
[nlp]
|
||||||
|
lang = "en"
|
||||||
|
pipeline = ["tok2vec", "textcat"]
|
||||||
|
|
||||||
|
[components]
|
||||||
|
|
||||||
|
[components.tok2vec]
|
||||||
|
factory = "tok2vec"
|
||||||
|
|
||||||
|
[components.textcat]
|
||||||
|
factory = "textcat"
|
||||||
|
"""
|
||||||
|
@registry.readers.register("myreader.v1")
|
||||||
|
def myreader() -> Dict[str, Callable[[Language, str], Iterable[Example]]]:
|
||||||
|
annots = {"cats": {"POS": 1.0, "NEG": 0.0}}
|
||||||
|
def reader(nlp: Language):
|
||||||
|
doc = nlp.make_doc(f"This is an example")
|
||||||
|
return [Example.from_dict(doc, annots)]
|
||||||
|
return {"train": reader, "dev": reader, "extra": reader, "something": reader}
|
||||||
|
|
||||||
|
config = Config().from_str(config_string)
|
||||||
|
nlp, resolved = load_model_from_config(config, auto_fill=True)
|
||||||
|
|
||||||
|
train_corpus = dot_to_object(resolved, resolved["training"]["train_corpus"])
|
||||||
|
assert isinstance(train_corpus, Callable)
|
||||||
|
optimizer = resolved["training"]["optimizer"]
|
||||||
|
# simulate a training loop
|
||||||
|
nlp.begin_training(lambda: train_corpus(nlp), sgd=optimizer)
|
||||||
|
for example in train_corpus(nlp):
|
||||||
|
nlp.update([example], sgd=optimizer)
|
||||||
|
dev_corpus = dot_to_object(resolved, resolved["training"]["dev_corpus"])
|
||||||
|
scores = nlp.evaluate(list(dev_corpus(nlp)))
|
||||||
|
assert scores["cats_score"]
|
||||||
|
# ensure the pipeline runs
|
||||||
|
doc = nlp("Quick test")
|
||||||
|
assert doc.cats
|
||||||
|
extra_corpus = resolved["corpora"]["extra"]
|
||||||
|
assert isinstance(extra_corpus, Callable)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
|
@ -16,7 +67,7 @@ def test_cat_readers(reader, additional_config):
|
||||||
nlp_config_string = """
|
nlp_config_string = """
|
||||||
[training]
|
[training]
|
||||||
|
|
||||||
[training.corpus]
|
[corpora]
|
||||||
@readers = "PLACEHOLDER"
|
@readers = "PLACEHOLDER"
|
||||||
|
|
||||||
[nlp]
|
[nlp]
|
||||||
|
@ -32,11 +83,11 @@ def test_cat_readers(reader, additional_config):
|
||||||
factory = "textcat"
|
factory = "textcat"
|
||||||
"""
|
"""
|
||||||
config = Config().from_str(nlp_config_string)
|
config = Config().from_str(nlp_config_string)
|
||||||
config["training"]["corpus"]["@readers"] = reader
|
config["corpora"]["@readers"] = reader
|
||||||
config["training"]["corpus"].update(additional_config)
|
config["corpora"].update(additional_config)
|
||||||
nlp, resolved = load_model_from_config(config, auto_fill=True)
|
nlp, resolved = load_model_from_config(config, auto_fill=True)
|
||||||
|
|
||||||
train_corpus = resolved["training"]["corpus"]["train"]
|
train_corpus = dot_to_object(resolved, resolved["training"]["train_corpus"])
|
||||||
optimizer = resolved["training"]["optimizer"]
|
optimizer = resolved["training"]["optimizer"]
|
||||||
# simulate a training loop
|
# simulate a training loop
|
||||||
nlp.begin_training(lambda: train_corpus(nlp), sgd=optimizer)
|
nlp.begin_training(lambda: train_corpus(nlp), sgd=optimizer)
|
||||||
|
@ -46,7 +97,7 @@ def test_cat_readers(reader, additional_config):
|
||||||
assert sorted(list(set(example.y.cats.values()))) == [0.0, 1.0]
|
assert sorted(list(set(example.y.cats.values()))) == [0.0, 1.0]
|
||||||
nlp.update([example], sgd=optimizer)
|
nlp.update([example], sgd=optimizer)
|
||||||
# simulate performance benchmark on dev corpus
|
# simulate performance benchmark on dev corpus
|
||||||
dev_corpus = resolved["training"]["corpus"]["dev"]
|
dev_corpus = dot_to_object(resolved, resolved["training"]["dev_corpus"])
|
||||||
dev_examples = list(dev_corpus(nlp))
|
dev_examples = list(dev_corpus(nlp))
|
||||||
for example in dev_examples:
|
for example in dev_examples:
|
||||||
# this shouldn't fail if each dev example has at least one positive label
|
# this shouldn't fail if each dev example has at least one positive label
|
||||||
|
|
|
@ -355,6 +355,16 @@ Registry @architectures
|
||||||
Name spacy.MaxoutWindowEncoder.v1
|
Name spacy.MaxoutWindowEncoder.v1
|
||||||
Module spacy.ml.models.tok2vec
|
Module spacy.ml.models.tok2vec
|
||||||
File /path/to/spacy/ml/models/tok2vec.py (line 207)
|
File /path/to/spacy/ml/models/tok2vec.py (line 207)
|
||||||
|
ℹ [corpora.dev]
|
||||||
|
Registry @readers
|
||||||
|
Name spacy.Corpus.v1
|
||||||
|
Module spacy.training.corpus
|
||||||
|
File /path/to/spacy/training/corpus.py (line 18)
|
||||||
|
ℹ [corpora.train]
|
||||||
|
Registry @readers
|
||||||
|
Name spacy.Corpus.v1
|
||||||
|
Module spacy.training.corpus
|
||||||
|
File /path/to/spacy/training/corpus.py (line 18)
|
||||||
ℹ [training.logger]
|
ℹ [training.logger]
|
||||||
Registry @loggers
|
Registry @loggers
|
||||||
Name spacy.ConsoleLogger.v1
|
Name spacy.ConsoleLogger.v1
|
||||||
|
@ -370,16 +380,6 @@ Registry @schedules
|
||||||
Name compounding.v1
|
Name compounding.v1
|
||||||
Module thinc.schedules
|
Module thinc.schedules
|
||||||
File /path/to/thinc/thinc/schedules.py (line 43)
|
File /path/to/thinc/thinc/schedules.py (line 43)
|
||||||
ℹ [training.corpus.dev]
|
|
||||||
Registry @readers
|
|
||||||
Name spacy.Corpus.v1
|
|
||||||
Module spacy.training.corpus
|
|
||||||
File /path/to/spacy/training/corpus.py (line 18)
|
|
||||||
ℹ [training.corpus.train]
|
|
||||||
Registry @readers
|
|
||||||
Name spacy.Corpus.v1
|
|
||||||
Module spacy.training.corpus
|
|
||||||
File /path/to/spacy/training/corpus.py (line 18)
|
|
||||||
ℹ [training.optimizer]
|
ℹ [training.optimizer]
|
||||||
Registry @optimizers
|
Registry @optimizers
|
||||||
Name Adam.v1
|
Name Adam.v1
|
||||||
|
|
|
@ -26,7 +26,7 @@ streaming.
|
||||||
> [paths]
|
> [paths]
|
||||||
> train = "corpus/train.spacy"
|
> train = "corpus/train.spacy"
|
||||||
>
|
>
|
||||||
> [training.corpus.train]
|
> [corpora.train]
|
||||||
> @readers = "spacy.Corpus.v1"
|
> @readers = "spacy.Corpus.v1"
|
||||||
> path = ${paths.train}
|
> path = ${paths.train}
|
||||||
> gold_preproc = false
|
> gold_preproc = false
|
||||||
|
@ -135,7 +135,7 @@ Initialize the reader.
|
||||||
>
|
>
|
||||||
> ```ini
|
> ```ini
|
||||||
> ### Example config
|
> ### Example config
|
||||||
> [pretraining.corpus]
|
> [corpora.pretrain]
|
||||||
> @readers = "spacy.JsonlReader.v1"
|
> @readers = "spacy.JsonlReader.v1"
|
||||||
> path = "corpus/raw_text.jsonl"
|
> path = "corpus/raw_text.jsonl"
|
||||||
> min_length = 0
|
> min_length = 0
|
||||||
|
|
|
@ -121,28 +121,78 @@ that you don't want to hard-code in your config file.
|
||||||
$ python -m spacy train config.cfg --paths.train ./corpus/train.spacy
|
$ python -m spacy train config.cfg --paths.train ./corpus/train.spacy
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### corpora {#config-corpora tag="section"}
|
||||||
|
|
||||||
|
This section defines a dictionary mapping of string keys to `Callable`
|
||||||
|
functions. Each callable takes an `nlp` object and yields
|
||||||
|
[`Example`](/api/example) objects. By default, the two keys `train` and `dev`
|
||||||
|
are specified and each refer to a [`Corpus`](/api/top-level#Corpus). When
|
||||||
|
pretraining, an additional pretrain section is added that defaults to a
|
||||||
|
[`JsonlReader`](/api/top-level#JsonlReader).
|
||||||
|
|
||||||
|
These subsections can be expanded with additional subsections, each referring to
|
||||||
|
a callback of type `Callable[[Language], Iterator[Example]]`:
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```ini
|
||||||
|
> [corpora]
|
||||||
|
> [corpora.train]
|
||||||
|
> @readers = "spacy.Corpus.v1"
|
||||||
|
> path = ${paths:train}
|
||||||
|
>
|
||||||
|
> [corpora.dev]
|
||||||
|
> @readers = "spacy.Corpus.v1"
|
||||||
|
> path = ${paths:dev}
|
||||||
|
>
|
||||||
|
> [corpora.pretrain]
|
||||||
|
> @readers = "spacy.JsonlReader.v1"
|
||||||
|
> path = ${paths.raw}
|
||||||
|
> min_length = 5
|
||||||
|
> max_length = 500
|
||||||
|
>
|
||||||
|
> [corpora.mydata]
|
||||||
|
> @readers = "my_reader.v1"
|
||||||
|
> shuffle = true
|
||||||
|
> ```
|
||||||
|
|
||||||
|
Alternatively, the `corpora` block could refer to one function with return type
|
||||||
|
`Dict[str, Callable[[Language], Iterator[Example]]]`:
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```ini
|
||||||
|
> [corpora]
|
||||||
|
> @readers = "my_dict_reader.v1"
|
||||||
|
> train_path = ${paths:train}
|
||||||
|
> dev_path = ${paths:dev}
|
||||||
|
> shuffle = true
|
||||||
|
>
|
||||||
|
> ```
|
||||||
|
|
||||||
### training {#config-training tag="section"}
|
### training {#config-training tag="section"}
|
||||||
|
|
||||||
This section defines settings and controls for the training and evaluation
|
This section defines settings and controls for the training and evaluation
|
||||||
process that are used when you run [`spacy train`](/api/cli#train).
|
process that are used when you run [`spacy train`](/api/cli#train).
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
| --------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
| --------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||||
| `accumulate_gradient` | Whether to divide the batch up into substeps. Defaults to `1`. ~~int~~ |
|
| `accumulate_gradient` | Whether to divide the batch up into substeps. Defaults to `1`. ~~int~~ |
|
||||||
| `batcher` | Callable that takes an iterator of [`Doc`](/api/doc) objects and yields batches of `Doc`s. Defaults to [`batch_by_words`](/api/top-level#batch_by_words). ~~Callable[[Iterator[Doc], Iterator[List[Doc]]]]~~ |
|
| `batcher` | Callable that takes an iterator of [`Doc`](/api/doc) objects and yields batches of `Doc`s. Defaults to [`batch_by_words`](/api/top-level#batch_by_words). ~~Callable[[Iterator[Doc], Iterator[List[Doc]]]]~~ |
|
||||||
| `corpus` | Dictionary with `train` and `dev` keys, each referring to a callable that takes the current `nlp` object and yields [`Example`](/api/example) objects. Defaults to [`Corpus`](/api/top-level#Corpus). ~~Callable[[Language], Iterator[Example]]~~ |
|
| `dev_corpus` | Dot notation of the config location defining the dev corpus. Defaults to `corpora.dev`. ~~str~~ |
|
||||||
| `dropout` | The dropout rate. Defaults to `0.1`. ~~float~~ |
|
| `dropout` | The dropout rate. Defaults to `0.1`. ~~float~~ |
|
||||||
| `eval_frequency` | How often to evaluate during training (steps). Defaults to `200`. ~~int~~ |
|
| `eval_frequency` | How often to evaluate during training (steps). Defaults to `200`. ~~int~~ |
|
||||||
| `frozen_components` | Pipeline component names that are "frozen" and shouldn't be updated during training. See [here](/usage/training#config-components) for details. Defaults to `[]`. ~~List[str]~~ |
|
| `frozen_components` | Pipeline component names that are "frozen" and shouldn't be updated during training. See [here](/usage/training#config-components) for details. Defaults to `[]`. ~~List[str]~~ |
|
||||||
| `init_tok2vec` | Optional path to pretrained tok2vec weights created with [`spacy pretrain`](/api/cli#pretrain). Defaults to variable `${paths.init_tok2vec}`. ~~Optional[str]~~ |
|
| `init_tok2vec` | Optional path to pretrained tok2vec weights created with [`spacy pretrain`](/api/cli#pretrain). Defaults to variable `${paths.init_tok2vec}`. ~~Optional[str]~~ |
|
||||||
| `max_epochs` | Maximum number of epochs to train for. Defaults to `0`. ~~int~~ |
|
| `max_epochs` | Maximum number of epochs to train for. Defaults to `0`. ~~int~~ |
|
||||||
| `max_steps` | Maximum number of update steps to train for. Defaults to `20000`. ~~int~~ |
|
| `max_steps` | Maximum number of update steps to train for. Defaults to `20000`. ~~int~~ |
|
||||||
| `optimizer` | The optimizer. The learning rate schedule and other settings can be configured as part of the optimizer. Defaults to [`Adam`](https://thinc.ai/docs/api-optimizers#adam). ~~Optimizer~~ |
|
| `optimizer` | The optimizer. The learning rate schedule and other settings can be configured as part of the optimizer. Defaults to [`Adam`](https://thinc.ai/docs/api-optimizers#adam). ~~Optimizer~~ |
|
||||||
| `patience` | How many steps to continue without improvement in evaluation score. Defaults to `1600`. ~~int~~ |
|
| `patience` | How many steps to continue without improvement in evaluation score. Defaults to `1600`. ~~int~~ |
|
||||||
| `raw_text` | Optional path to a jsonl file with unlabelled text documents for a [rehearsal](/api/language#rehearse) step. Defaults to variable `${paths.raw}`. ~~Optional[str]~~ |
|
| `raw_text` | Optional path to a jsonl file with unlabelled text documents for a [rehearsal](/api/language#rehearse) step. Defaults to variable `${paths.raw}`. ~~Optional[str]~~ |
|
||||||
| `score_weights` | Score names shown in metrics mapped to their weight towards the final weighted score. See [here](/usage/training#metrics) for details. Defaults to `{}`. ~~Dict[str, float]~~ |
|
| `score_weights` | Score names shown in metrics mapped to their weight towards the final weighted score. See [here](/usage/training#metrics) for details. Defaults to `{}`. ~~Dict[str, float]~~ |
|
||||||
| `seed` | The random seed. Defaults to variable `${system.seed}`. ~~int~~ |
|
| `seed` | The random seed. Defaults to variable `${system.seed}`. ~~int~~ |
|
||||||
| `vectors` | Name or path of pipeline containing pretrained word vectors to use, e.g. created with [`init vocab`](/api/cli#init-vocab). Defaults to `null`. ~~Optional[str]~~ |
|
| `corpus` | Dot notation of the config location defining the train corpus. Defaults to `corpora.train`. ~~str~~ |
|
||||||
|
| `vectors` | Name or path of pipeline containing pretrained word vectors to use, e.g. created with [`init vocab`](/api/cli#init-vocab). Defaults to `null`. ~~Optional[str]~~ |
|
||||||
|
|
||||||
### pretraining {#config-pretraining tag="section,optional"}
|
### pretraining {#config-pretraining tag="section,optional"}
|
||||||
|
|
||||||
|
@ -150,17 +200,18 @@ This section is optional and defines settings and controls for
|
||||||
[language model pretraining](/usage/embeddings-transformers#pretraining). It's
|
[language model pretraining](/usage/embeddings-transformers#pretraining). It's
|
||||||
used when you run [`spacy pretrain`](/api/cli#pretrain).
|
used when you run [`spacy pretrain`](/api/cli#pretrain).
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
| -------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
| -------------- | ------------------------------------------------------------------------------------------------------ |
|
||||||
| `max_epochs` | Maximum number of epochs. Defaults to `1000`. ~~int~~ |
|
| `max_epochs` | Maximum number of epochs. Defaults to `1000`. ~~int~~ |
|
||||||
| `dropout` | The dropout rate. Defaults to `0.2`. ~~float~~ |
|
| `dropout` | The dropout rate. Defaults to `0.2`. ~~float~~ |
|
||||||
| `n_save_every` | Saving frequency. Defaults to `null`. ~~Optional[int]~~ |
|
| `n_save_every` | Saving frequency. Defaults to `null`. ~~Optional[int]~~ |
|
||||||
| `objective` | The pretraining objective. Defaults to `{"type": "characters", "n_characters": 4}`. ~~Dict[str, Any]~~ |
|
| `objective` | The pretraining objective. Defaults to `{"type": "characters", "n_characters": 4}`. ~~Dict[str, Any]~~ |
|
||||||
| `optimizer` | The optimizer. Defaults to [`Adam`](https://thinc.ai/docs/api-optimizers#adam). ~~Optimizer~~ |
|
| `optimizer` | The optimizer. Defaults to [`Adam`](https://thinc.ai/docs/api-optimizers#adam). ~~Optimizer~~ |
|
||||||
| `corpus` | Callable that takes the current `nlp` object and yields [`Doc`](/api/doc) objects. Defaults to [`JsonlReader`](/api/top-level#JsonlReader). ~~Callable[[Language, str], Iterable[Example]]~~ |
|
| `corpus` | Dot notation of the config location defining the train corpus. Defaults to `corpora.train`. ~~str~~ |
|
||||||
| `batcher` | Batcher for the training data. ~~Callable[[Iterator[Doc], Iterator[List[Doc]]]]~~ |
|
| `batcher` | Batcher for the training data. ~~Callable[[Iterator[Doc], Iterator[List[Doc]]]]~~ |
|
||||||
| `component` | Component to find the layer to pretrain. Defaults to `"tok2vec"`. ~~str~~ |
|
| `component` | Component to find the layer to pretrain. Defaults to `"tok2vec"`. ~~str~~ |
|
||||||
| `layer` | The layer to pretrain. If empty, the whole component model will be used. ~~str~~ |
|
| `layer` | The layer to pretrain. If empty, the whole component model will be used. ~~str~~ |
|
||||||
|
| |
|
||||||
|
|
||||||
## Training data {#training}
|
## Training data {#training}
|
||||||
|
|
||||||
|
|
|
@ -448,7 +448,7 @@ remain in the config file stored on your local system.
|
||||||
> [training.logger]
|
> [training.logger]
|
||||||
> @loggers = "spacy.WandbLogger.v1"
|
> @loggers = "spacy.WandbLogger.v1"
|
||||||
> project_name = "monitor_spacy_training"
|
> project_name = "monitor_spacy_training"
|
||||||
> remove_config_values = ["paths.train", "paths.dev", "training.corpus.train.path", "training.corpus.dev.path"]
|
> remove_config_values = ["paths.train", "paths.dev", "corpora.train.path", "corpora.dev.path"]
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
|
@ -478,7 +478,7 @@ the [`Corpus`](/api/corpus) class.
|
||||||
> [paths]
|
> [paths]
|
||||||
> train = "corpus/train.spacy"
|
> train = "corpus/train.spacy"
|
||||||
>
|
>
|
||||||
> [training.corpus.train]
|
> [corpora.train]
|
||||||
> @readers = "spacy.Corpus.v1"
|
> @readers = "spacy.Corpus.v1"
|
||||||
> path = ${paths.train}
|
> path = ${paths.train}
|
||||||
> gold_preproc = false
|
> gold_preproc = false
|
||||||
|
@ -506,7 +506,7 @@ JSONL file. Also see the [`JsonlReader`](/api/corpus#jsonlreader) class.
|
||||||
> [paths]
|
> [paths]
|
||||||
> pretrain = "corpus/raw_text.jsonl"
|
> pretrain = "corpus/raw_text.jsonl"
|
||||||
>
|
>
|
||||||
> [pretraining.corpus]
|
> [corpora.pretrain]
|
||||||
> @readers = "spacy.JsonlReader.v1"
|
> @readers = "spacy.JsonlReader.v1"
|
||||||
> path = ${paths.pretrain}
|
> path = ${paths.pretrain}
|
||||||
> min_length = 0
|
> min_length = 0
|
||||||
|
|
|
@ -969,7 +969,7 @@ your results.
|
||||||
> [training.logger]
|
> [training.logger]
|
||||||
> @loggers = "spacy.WandbLogger.v1"
|
> @loggers = "spacy.WandbLogger.v1"
|
||||||
> project_name = "monitor_spacy_training"
|
> project_name = "monitor_spacy_training"
|
||||||
> remove_config_values = ["paths.train", "paths.dev", "training.corpus.train.path", "training.corpus.dev.path"]
|
> remove_config_values = ["paths.train", "paths.dev", "corpora.train.path", "corpora.dev.path"]
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
![Screenshot: Visualized training results](../images/wandb1.jpg)
|
![Screenshot: Visualized training results](../images/wandb1.jpg)
|
||||||
|
|
|
@ -746,7 +746,7 @@ as **config settings** – in this case, `source`.
|
||||||
> #### config.cfg
|
> #### config.cfg
|
||||||
>
|
>
|
||||||
> ```ini
|
> ```ini
|
||||||
> [training.corpus.train]
|
> [corpora.train]
|
||||||
> @readers = "corpus_variants.v1"
|
> @readers = "corpus_variants.v1"
|
||||||
> source = "s3://your_bucket/path/data.csv"
|
> source = "s3://your_bucket/path/data.csv"
|
||||||
> ```
|
> ```
|
||||||
|
|
Loading…
Reference in New Issue
Block a user