Add training option to set annotations on update (#7767)
* Add training option to set annotations on update

  Add a `[training]` option called `set_annotations_on_update` to specify a
  list of components for which the predicted annotations should be set on
  `example.predicted` immediately after that component has been updated. The
  predicted annotations can be accessed by later components in the pipeline
  during the processing of the batch in the same `update` call.

* Rename to annotates / annotating_components
* Add test for `annotating_components` when training from config
* Add documentation
This commit is contained in:
parent c105ed10fd
commit 95c0833656
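
For context, a minimal sketch of what the new `annotates` argument does at the
API level, adapted from the tests added in this commit (the texts and the
single-component pipeline are illustrative, not part of the change itself):

```python
import spacy
from spacy.training import Example

nlp = spacy.blank("en")
nlp.add_pipe("sentencizer")

texts = ["First sentence. Second one.", "Another example."]
# Example(predicted, reference): the predicted doc starts out unannotated.
examples = [Example(nlp.make_doc(t), nlp(t)) for t in texts]

# A plain update never writes predictions back to example.predicted, so
# later components can't see the sentencizer's sentence boundaries.
nlp.update(examples)

# With `annotates`, the sentencizer's predictions are set on
# example.predicted right after it is updated, making them available to
# components later in the pipeline within the same update call.
nlp.update(examples, annotates=["sentencizer"])
```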
@@ -80,6 +80,8 @@ eval_frequency = 200
 score_weights = {}
 # Names of pipeline components that shouldn't be updated during training
 frozen_components = []
+# Names of pipeline components that should set annotations during training
+annotating_components = []
 # Location in the config where the dev corpus is defined
 dev_corpus = "corpora.dev"
 # Location in the config where the train corpus is defined
@@ -1074,6 +1074,7 @@ class Language:
         losses: Optional[Dict[str, float]] = None,
         component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
         exclude: Iterable[str] = SimpleFrozenList(),
+        annotates: Iterable[str] = SimpleFrozenList(),
     ):
         """Update the models in the pipeline.

@@ -1081,10 +1082,13 @@ class Language:
         _: Should not be set - serves to catch backwards-incompatible scripts.
         drop (float): The dropout rate.
         sgd (Optimizer): An optimizer.
-        losses (Dict[str, float]): Dictionary to update with the loss, keyed by component.
+        losses (Dict[str, float]): Dictionary to update with the loss, keyed by
+            component.
         component_cfg (Dict[str, Dict]): Config parameters for specific pipeline
             components, keyed by component name.
         exclude (Iterable[str]): Names of components that shouldn't be updated.
+        annotates (Iterable[str]): Names of components that should set
+            annotations on the predicted examples after updating.
         RETURNS (Dict[str, float]): The updated losses dictionary

         DOCS: https://spacy.io/api/language#update

@@ -1103,15 +1107,16 @@ class Language:
             sgd = self._optimizer
         if component_cfg is None:
             component_cfg = {}
+        pipe_kwargs = {}
         for i, (name, proc) in enumerate(self.pipeline):
             component_cfg.setdefault(name, {})
+            pipe_kwargs[name] = deepcopy(component_cfg[name])
             component_cfg[name].setdefault("drop", drop)
+            pipe_kwargs[name].setdefault("batch_size", self.batch_size)
         for name, proc in self.pipeline:
-            if name in exclude or not hasattr(proc, "update"):
-                continue
-            proc.update(examples, sgd=None, losses=losses, **component_cfg[name])
-        if sgd not in (None, False):
-            for name, proc in self.pipeline:
+            if name not in exclude and hasattr(proc, "update"):
+                proc.update(examples, sgd=None, losses=losses, **component_cfg[name])
+            if sgd not in (None, False):
                 if (
                     name not in exclude
                     and hasattr(proc, "is_trainable")

@@ -1119,6 +1124,18 @@ class Language:
                     and proc.model not in (True, False, None)
                 ):
                     proc.finish_update(sgd)
+            if name in annotates:
+                for doc, eg in zip(
+                    _pipe(
+                        (eg.predicted for eg in examples),
+                        proc=proc,
+                        name=name,
+                        default_error_handler=self.default_error_handler,
+                        kwargs=pipe_kwargs[name],
+                    ),
+                    examples,
+                ):
+                    eg.predicted = doc
         return losses

     def rehearse(
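The flow above updates each component, finishes the update with the optimizer,
and then re-applies any component listed in `annotates` to `eg.predicted`. One
consequence is that an annotating component runs once to update its model and
once more to annotate. A sketch with a hypothetical counting component (not
part of this commit) that makes the two phases visible:

```python
from spacy.lang.en import English
from spacy.language import Language
from spacy.training import Example

calls = {"update": 0, "call": 0}


class CountingComponent:
    """Hypothetical component that counts its update and apply calls."""

    def __init__(self, name):
        self.name = name

    def __call__(self, doc):
        calls["call"] += 1  # runs when the component annotates docs
        return doc

    def update(self, examples, *, drop=0.0, sgd=None, losses=None):
        calls["update"] += 1  # runs once per nlp.update batch
        return losses if losses is not None else {}


@Language.factory("counting_component")
def make_counting_component(nlp, name):
    return CountingComponent(name)


nlp = English()
nlp.add_pipe("counting_component")
examples = [Example(nlp.make_doc("a a"), nlp.make_doc("a a"))]
# Updated once, then re-applied once to the predicted docs:
nlp.update(examples, annotates=["counting_component"])
assert calls == {"update": 1, "call": 1}
```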
@@ -313,6 +313,7 @@ class ConfigSchemaTraining(BaseModel):
     optimizer: Optimizer = Field(..., title="The optimizer to use")
     logger: Logger = Field(..., title="The logger to track training progress")
     frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training")
+    annotating_components: List[str] = Field(..., title="Pipeline components that should set annotations during training")
     before_to_disk: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after training, before it's saved to disk")
     # fmt: on
spacy/tests/pipeline/test_annotates_on_update.py (new file, 113 lines)
@@ -0,0 +1,113 @@
from typing import Callable, Iterable, Iterator
import pytest
import io

from thinc.api import Config
from spacy.language import Language
from spacy.training import Example
from spacy.training.loop import train
from spacy.lang.en import English
from spacy.util import registry, load_model_from_config


@pytest.fixture
def config_str():
    return """
    [nlp]
    lang = "en"
    pipeline = ["sentencizer","assert_sents"]
    disabled = []
    before_creation = null
    after_creation = null
    after_pipeline_creation = null
    batch_size = 1000
    tokenizer = {"@tokenizers":"spacy.Tokenizer.v1"}

    [components]

    [components.assert_sents]
    factory = "assert_sents"

    [components.sentencizer]
    factory = "sentencizer"
    punct_chars = null

    [training]
    dev_corpus = "corpora.dev"
    train_corpus = "corpora.train"
    annotating_components = ["sentencizer"]
    max_steps = 2

    [corpora]

    [corpora.dev]
    @readers = "unannotated_corpus"

    [corpora.train]
    @readers = "unannotated_corpus"
    """


def test_annotates_on_update():
    # The custom component checks for sentence annotation
    @Language.factory("assert_sents", default_config={})
    def assert_sents(nlp, name):
        return AssertSents(name)

    class AssertSents:
        def __init__(self, name, **cfg):
            self.name = name
            pass

        def __call__(self, doc):
            if not doc.has_annotation("SENT_START"):
                raise ValueError("No sents")
            return doc

        def update(self, examples, *, drop=0.0, sgd=None, losses=None):
            for example in examples:
                if not example.predicted.has_annotation("SENT_START"):
                    raise ValueError("No sents")
            return {}

    nlp = English()
    nlp.add_pipe("sentencizer")
    nlp.add_pipe("assert_sents")

    # When the pipeline runs, annotations are set
    doc = nlp("This is a sentence.")

    examples = []
    for text in ["a a", "b b", "c c"]:
        examples.append(Example(nlp.make_doc(text), nlp(text)))

    for example in examples:
        assert not example.predicted.has_annotation("SENT_START")

    # If updating without setting annotations, assert_sents will raise an error
    with pytest.raises(ValueError):
        nlp.update(examples)

    # Updating while setting annotations for the sentencizer succeeds
    nlp.update(examples, annotates=["sentencizer"])


def test_annotating_components_from_config(config_str):
    @registry.readers("unannotated_corpus")
    def create_unannotated_corpus() -> Callable[[Language], Iterable[Example]]:
        return UnannotatedCorpus()

    class UnannotatedCorpus:
        def __call__(self, nlp: Language) -> Iterator[Example]:
            for text in ["a a", "b b", "c c"]:
                doc = nlp.make_doc(text)
                yield Example(doc, doc)

    orig_config = Config().from_str(config_str)
    nlp = load_model_from_config(orig_config, auto_fill=True, validate=True)
    assert nlp.config["training"]["annotating_components"] == ["sentencizer"]
    train(nlp)

    nlp.config["training"]["annotating_components"] = []
    with pytest.raises(ValueError):
        train(nlp)
@@ -1,7 +1,9 @@
 import pytest
 from spacy.language import Language
 from spacy.pipeline import TrainablePipe
+from spacy.training import Example
 from spacy.util import SimpleFrozenList, get_arg_names
+from spacy.lang.en import English


 @pytest.fixture

@@ -417,3 +419,41 @@ def test_pipe_methods_initialize():
     assert "test" in nlp.config["initialize"]["components"]
     nlp.remove_pipe("test")
     assert "test" not in nlp.config["initialize"]["components"]
+
+
+def test_update_with_annotates():
+    name = "test_with_annotates"
+    results = {}
+
+    def make_component(name):
+        results[name] = ""
+
+        def component(doc):
+            nonlocal results
+            results[name] += doc.text
+            return doc
+
+        return component
+
+    c1 = Language.component(f"{name}1", func=make_component(f"{name}1"))
+    c2 = Language.component(f"{name}2", func=make_component(f"{name}2"))
+
+    components = set([f"{name}1", f"{name}2"])
+
+    nlp = English()
+    texts = ["a", "bb", "ccc"]
+    examples = []
+    for text in texts:
+        examples.append(Example(nlp.make_doc(text), nlp.make_doc(text)))
+
+    for components_to_annotate in [[], [f"{name}1"], [f"{name}1", f"{name}2"], [f"{name}2", f"{name}1"]]:
+        for key in results:
+            results[key] = ""
+        nlp = English(vocab=nlp.vocab)
+        nlp.add_pipe(f"{name}1")
+        nlp.add_pipe(f"{name}2")
+        nlp.update(examples, annotates=components_to_annotate)
+        for component in components_to_annotate:
+            assert results[component] == "".join(eg.predicted.text for eg in examples)
+        for component in components - set(components_to_annotate):
+            assert results[component] == ""
@@ -74,6 +74,8 @@ def train(

     # Components that shouldn't be updated during training
     frozen_components = T["frozen_components"]
+    # Components that should set annotations on update
+    annotating_components = T["annotating_components"]
     # Create iterator, which yields out info after each optimization step.
     training_step_iterator = train_while_improving(
         nlp,

@@ -86,11 +88,17 @@ def train(
         max_steps=T["max_steps"],
         eval_frequency=T["eval_frequency"],
         exclude=frozen_components,
+        annotating_components=annotating_components,
     )
     clean_output_dir(output_path)
     stdout.write(msg.info(f"Pipeline: {nlp.pipe_names}") + "\n")
     if frozen_components:
         stdout.write(msg.info(f"Frozen components: {frozen_components}") + "\n")
+    if annotating_components:
+        stdout.write(
+            msg.info(f"Set annotations on update for: {annotating_components}")
+            + "\n"
+        )
     stdout.write(msg.info(f"Initial learn rate: {optimizer.learn_rate}") + "\n")
     with nlp.select_pipes(disable=frozen_components):
         log_step, finalize_logger = train_logger(nlp, stdout, stderr)

@@ -142,6 +150,7 @@ def train_while_improving(
     patience: int,
     max_steps: int,
     exclude: List[str],
+    annotating_components: List[str],
 ):
     """Train until an evaluation stops improving. Works as a generator,
     with each iteration yielding a tuple `(batch, info, is_best_checkpoint)`,

@@ -193,7 +202,12 @@ def train_while_improving(
         dropout = next(dropouts)
         for subbatch in subdivide_batch(batch, accumulate_gradient):
             nlp.update(
-                subbatch, drop=dropout, losses=losses, sgd=False, exclude=exclude
+                subbatch,
+                drop=dropout,
+                losses=losses,
+                sgd=False,
+                exclude=exclude,
+                annotates=annotating_components,
             )
         # TODO: refactor this so we don't have to run it separately in here
         for name, proc in nlp.pipeline:
@@ -182,24 +182,25 @@ single corpus once and then divide it up into `train` and `dev` partitions.

 This section defines settings and controls for the training and evaluation
 process that are used when you run [`spacy train`](/api/cli#train).

 | Name | Description |
-| --------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| ----------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
 | `accumulate_gradient` | Whether to divide the batch up into substeps. Defaults to `1`. ~~int~~ |
 | `batcher` | Callable that takes an iterator of [`Doc`](/api/doc) objects and yields batches of `Doc`s. Defaults to [`batch_by_words`](/api/top-level#batch_by_words). ~~Callable[[Iterator[Doc], Iterator[List[Doc]]]]~~ |
 | `before_to_disk` | Optional callback to modify `nlp` object right before it is saved to disk during and after training. Can be used to remove or reset config values or disable components. Defaults to `null`. ~~Optional[Callable[[Language], Language]]~~ |
 | `dev_corpus` | Dot notation of the config location defining the dev corpus. Defaults to `corpora.dev`. ~~str~~ |
 | `dropout` | The dropout rate. Defaults to `0.1`. ~~float~~ |
 | `eval_frequency` | How often to evaluate during training (steps). Defaults to `200`. ~~int~~ |
 | `frozen_components` | Pipeline component names that are "frozen" and shouldn't be initialized or updated during training. See [here](/usage/training#config-components) for details. Defaults to `[]`. ~~List[str]~~ |
+| `annotating_components` | Pipeline component names that should set annotations on the predicted docs during training. See [here](/usage/training#annotating-components) for details. Defaults to `[]`. ~~List[str]~~ |
 | `gpu_allocator` | Library for cupy to route GPU memory allocation to. Can be `"pytorch"` or `"tensorflow"`. Defaults to variable `${system.gpu_allocator}`. ~~str~~ |
 | `logger` | Callable that takes the `nlp` and stdout and stderr `IO` objects, sets up the logger, and returns two new callables to log a training step and to finalize the logger. Defaults to [`ConsoleLogger`](/api/top-level#ConsoleLogger). ~~Callable[[Language, IO, IO], [Tuple[Callable[[Dict[str, Any]], None], Callable[[], None]]]]~~ |
 | `max_epochs` | Maximum number of epochs to train for. `0` means an unlimited number of epochs. `-1` means that the train corpus should be streamed rather than loaded into memory with no shuffling within the training loop. Defaults to `0`. ~~int~~ |
 | `max_steps` | Maximum number of update steps to train for. `0` means an unlimited number of steps. Defaults to `20000`. ~~int~~ |
 | `optimizer` | The optimizer. The learning rate schedule and other settings can be configured as part of the optimizer. Defaults to [`Adam`](https://thinc.ai/docs/api-optimizers#adam). ~~Optimizer~~ |
 | `patience` | How many steps to continue without improvement in evaluation score. `0` disables early stopping. Defaults to `1600`. ~~int~~ |
 | `score_weights` | Score names shown in metrics mapped to their weight towards the final weighted score. See [here](/usage/training#metrics) for details. Defaults to `{}`. ~~Dict[str, float]~~ |
 | `seed` | The random seed. Defaults to variable `${system.seed}`. ~~int~~ |
 | `train_corpus` | Dot notation of the config location defining the train corpus. Defaults to `corpora.train`. ~~str~~ |

 ### pretraining {#config-pretraining tag="section,optional"}
@@ -245,14 +245,14 @@ and call the optimizer, while the others simply increment the gradients.
 > losses = trf.update(examples, sgd=optimizer)
 > ```

 | Name | Description |
-| ----------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| -------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
 | `examples` | A batch of [`Example`](/api/example) objects. Only the [`Example.predicted`](/api/example#predicted) `Doc` object is used, the reference `Doc` is ignored. ~~Iterable[Example]~~ |
 | _keyword-only_ | |
 | `drop` | The dropout rate. ~~float~~ |
 | `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
 | `losses` | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
 | **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |

 ## Transformer.create_optimizer {#create_optimizer tag="method"}
@@ -493,6 +493,11 @@ This requires sentence boundaries to be set (e.g. by the
 depending on the sentence lengths. However, it does provide the transformer with
 more meaningful windows to attend over.

+To set sentence boundaries with the `sentencizer` during training, add a
+`sentencizer` to the beginning of the pipeline and include it in
+[`[training.annotating_components]`](/usage/training#annotating-components) to
+have it set the sentence boundaries before the `transformer` component runs.
+
 ### strided_spans.v1 {#strided_spans tag="registered function"}

 > #### Example config
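The note added above can be wired up at the API level as well. A sketch that
assumes `spacy-transformers` is installed (the transformer factory's default
model is loaded later, when the pipeline is initialized); none of this snippet
is part of the commit itself:

```python
import spacy

nlp = spacy.blank("en")
# The sentencizer runs first, so sentence boundaries are set on each doc
# before the transformer's span getter needs them.
nlp.add_pipe("sentencizer")
nlp.add_pipe("transformer")

# During training, list the sentencizer as an annotating component so the
# same boundaries are also set on the predicted docs:
nlp.config["training"]["annotating_components"] = ["sentencizer"]
```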
@@ -414,11 +414,11 @@ as-is. They are also excluded when calling
 > #### Note on frozen components
 >
 > Even though frozen components are not **updated** during training, they will
-> still **run** during training and evaluation. This is very important, because
-> they may still impact your model's performance – for instance, a sentence
-> boundary detector can impact what the parser or entity recognizer considers a
-> valid parse. So the evaluation results should always reflect what your
-> pipeline will produce at runtime.
+> still **run** during evaluation. This is very important, because they may
+> still impact your model's performance – for instance, a sentence boundary
+> detector can impact what the parser or entity recognizer considers a valid
+> parse. So the evaluation results should always reflect what your pipeline will
+> produce at runtime.

 ```ini
 [nlp]

@@ -455,6 +455,64 @@ replace_listeners = ["model.tok2vec"]

 </Infobox>

+### Using predictions from preceding components {#annotating-components new="3.1"}
+
+By default, components are updated in isolation during training, which means
+that they don't see the predictions of any earlier components in the pipeline. A
+component receives [`Example.predicted`](/api/example) as input and compares its
+predictions to [`Example.reference`](/api/example) without saving its
+annotations in the `predicted` doc.
+
+Instead, if certain components should **set their annotations** during training,
+use the setting `annotating_components` in the `[training]` block to specify a
+list of components. For example, the feature `DEP` from the parser could be used
+as a tagger feature by including `DEP` in the tok2vec `attrs` and including
+`parser` in `annotating_components`:
+
+```ini
+### config.cfg (excerpt) {highlight="7,12"}
+[nlp]
+pipeline = ["parser", "tagger"]
+
+[components.tagger.model.tok2vec.embed]
+@architectures = "spacy.MultiHashEmbed.v1"
+width = ${components.tagger.model.tok2vec.encode.width}
+attrs = ["NORM","DEP"]
+rows = [5000,2500]
+include_static_vectors = false
+
+[training]
+annotating_components = ["parser"]
+```
+
+Any component in the pipeline can be included as an annotating component,
+including frozen components. Frozen components can set annotations during
+training just as they would set annotations during evaluation or when the final
+pipeline is run. The config excerpt below shows how a frozen `ner` component and
+a `sentencizer` can provide the required `doc.sents` and `doc.ents` for the
+entity linker during training:
+
+```ini
+### config.cfg (excerpt)
+[nlp]
+pipeline = ["sentencizer", "ner", "entity_linker"]
+
+[components.ner]
+source = "en_core_web_sm"
+
+[training]
+frozen_components = ["ner"]
+annotating_components = ["sentencizer", "ner"]
+```
+
+<Infobox variant="warning" title="Training speed with annotating components" id="annotating-components-speed">
+
+Be aware that non-frozen annotating components with statistical models will
+**run twice** on each batch, once to update the model and once to apply the
+now-updated model to the predicted docs.
+
+</Infobox>
+
 ### Using registered functions {#config-functions}

 The training configuration defined in the config file doesn't have to only