mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-23 06:29:48 +03:00
Init
This commit is contained in:
parent
a9910b6081
commit
60661ab0fa
|
@ -10,6 +10,7 @@ from .info import info # noqa: F401
|
||||||
from .package import package # noqa: F401
|
from .package import package # noqa: F401
|
||||||
from .profile import profile # noqa: F401
|
from .profile import profile # noqa: F401
|
||||||
from .train import train_cli # noqa: F401
|
from .train import train_cli # noqa: F401
|
||||||
|
from .rehearse import rehearse_cli # noqa: F401
|
||||||
from .assemble import assemble_cli # noqa: F401
|
from .assemble import assemble_cli # noqa: F401
|
||||||
from .pretrain import pretrain # noqa: F401
|
from .pretrain import pretrain # noqa: F401
|
||||||
from .debug_data import debug_data # noqa: F401
|
from .debug_data import debug_data # noqa: F401
|
||||||
|
|
83
spacy/cli/rehearse.py
Normal file
83
spacy/cli/rehearse.py
Normal file
|
@ -0,0 +1,83 @@
|
||||||
|
from typing import Optional, Dict, Any, Union
|
||||||
|
from pathlib import Path
|
||||||
|
from wasabi import msg
|
||||||
|
import typer
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
|
||||||
|
from ._util import import_code, setup_gpu
|
||||||
|
from ..training.loop import train as train_nlp
|
||||||
|
from ..training.initialize import init_nlp
|
||||||
|
from .. import util
|
||||||
|
|
||||||
|
|
||||||
|
@app.command(
    "rehearse",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
def rehearse_cli(
    # fmt: off
    ctx: typer.Context,  # This is only used to read additional arguments
    config_path: Path = Arg(..., help="Path to config file", exists=True, allow_dash=True),
    output_path: Optional[Path] = Opt(None, "--output", "--output-path", "-o", help="Output directory to store trained pipeline in"),
    code_path: Optional[Path] = Opt(None, "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
    verbose: bool = Opt(False, "--verbose", "-V", "-VV", help="Display more information for debugging purposes"),
    use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU")
    # fmt: on
):
    """
    Rehearse a spaCy pipeline. Requires data in spaCy's binary format. To
    convert data from other formats, use the `spacy convert` command. The
    config file includes all settings and hyperparameters used during training.
    To override settings in the config, e.g. settings that point to local
    paths or that you want to experiment with, you can override them as
    command line options. For instance, --training.batch_size 128 overrides
    the value of "batch_size" in the block "[training]". The --code argument
    lets you pass in a Python file that's imported before training. It can be
    used to register custom functions and architectures that can then be
    referenced in the config.

    DOCS: https://spacy.io/api/cli#rehearse
    """
    # NOTE: the docstring above doubles as the command's --help text, so it
    # is user-facing output rather than internal documentation.
    log_level = logging.DEBUG if verbose else logging.INFO
    util.logger.setLevel(log_level)
    # The extra, unrecognised CLI arguments collected on the context are
    # turned into config overrides before any user code is imported.
    config_overrides = parse_config_overrides(ctx.args)
    import_code(code_path)
    rehearse(
        config_path,
        output_path,
        use_gpu=use_gpu,
        overrides=config_overrides,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def rehearse(
    config_path: Union[str, Path],
    output_path: Optional[Union[str, Path]] = None,
    *,
    use_gpu: int = -1,
    overrides: Dict[str, Any] = util.SimpleFrozenDict(),
):
    """Load a config, initialize the pipeline and run the training loop with
    rehearsal enabled.

    config_path (Union[str, Path]): Path to the config file. The literal "-"
        skips the existence check.
    output_path (Optional[Union[str, Path]]): Directory handed to the training
        loop. Created (including parents) if it does not exist yet. If not
        set, only an informational message is printed.
    use_gpu (int): Value passed to setup_gpu, init_nlp and the training loop.
    overrides (Dict[str, Any]): Config overrides applied when loading the
        config.
    """
    config_path = util.ensure_path(config_path)
    output_path = util.ensure_path(output_path)
    # Make sure all files and paths exist if they are needed. "-" is exempt
    # from the existence check.
    config_missing = not config_path or not (
        str(config_path) == "-" or config_path.exists()
    )
    if config_missing:
        msg.fail("Config file not found", config_path, exits=1)
    if output_path:
        if not output_path.exists():
            output_path.mkdir(parents=True)
            msg.good(f"Created output directory: {output_path}")
        msg.info(f"Saving to output directory: {output_path}")
    else:
        msg.info("No output directory provided")
    setup_gpu(use_gpu)
    # The config is loaded without interpolation.
    with show_validation_error(config_path):
        config = util.load_config(config_path, overrides=overrides, interpolate=False)
    msg.divider("Initializing pipeline")
    with show_validation_error(config_path, hint_fill=False):
        nlp = init_nlp(config, use_gpu=use_gpu)
    msg.good("Initialized pipeline")
    msg.divider("Training pipeline")
    # Same training loop as regular training, with rehearsal switched on.
    loop_settings = {
        "use_gpu": use_gpu,
        "use_rehearse": True,
        "stdout": sys.stdout,
        "stderr": sys.stderr,
    }
    train_nlp(nlp, output_path, **loop_settings)
|
|
@ -1211,32 +1211,25 @@ class Language:
|
||||||
if isinstance(examples, list) and len(examples) == 0:
|
if isinstance(examples, list) and len(examples) == 0:
|
||||||
return losses
|
return losses
|
||||||
validate_examples(examples, "Language.rehearse")
|
validate_examples(examples, "Language.rehearse")
|
||||||
if sgd is None:
|
|
||||||
if self._optimizer is None:
|
|
||||||
self._optimizer = self.create_optimizer()
|
|
||||||
sgd = self._optimizer
|
|
||||||
pipes = list(self.pipeline)
|
pipes = list(self.pipeline)
|
||||||
random.shuffle(pipes)
|
|
||||||
if component_cfg is None:
|
if component_cfg is None:
|
||||||
component_cfg = {}
|
component_cfg = {}
|
||||||
grads = {}
|
|
||||||
|
|
||||||
def get_grads(key, W, dW):
|
|
||||||
grads[key] = (W, dW)
|
|
||||||
return W, dW
|
|
||||||
|
|
||||||
get_grads.learn_rate = sgd.learn_rate # type: ignore[attr-defined, union-attr]
|
|
||||||
get_grads.b1 = sgd.b1 # type: ignore[attr-defined, union-attr]
|
|
||||||
get_grads.b2 = sgd.b2 # type: ignore[attr-defined, union-attr]
|
|
||||||
for name, proc in pipes:
|
for name, proc in pipes:
|
||||||
if name in exclude or not hasattr(proc, "rehearse"):
|
if name in exclude or not hasattr(proc, "rehearse"):
|
||||||
continue
|
continue
|
||||||
grads = {}
|
|
||||||
proc.rehearse( # type: ignore[attr-defined]
|
proc.rehearse( # type: ignore[attr-defined]
|
||||||
examples, sgd=get_grads, losses=losses, **component_cfg.get(name, {})
|
examples, sgd=None, losses=losses, **component_cfg.get(name, {})
|
||||||
)
|
)
|
||||||
for key, (W, dW) in grads.items():
|
if isinstance(sgd, Optimizer):
|
||||||
sgd(key, W, dW) # type: ignore[call-arg, misc]
|
if (
|
||||||
|
name not in exclude
|
||||||
|
and isinstance(proc, ty.TrainableComponent)
|
||||||
|
and proc.is_trainable
|
||||||
|
and proc.model not in (True, False, None)
|
||||||
|
):
|
||||||
|
proc.finish_update(sgd)
|
||||||
|
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
def begin_training(
|
def begin_training(
|
||||||
|
|
|
@ -241,7 +241,8 @@ class Tagger(TrainablePipe):
|
||||||
tutor_tag_scores, _ = self._rehearsal_model.begin_update(docs)
|
tutor_tag_scores, _ = self._rehearsal_model.begin_update(docs)
|
||||||
grads, loss = loss_func(tag_scores, tutor_tag_scores)
|
grads, loss = loss_func(tag_scores, tutor_tag_scores)
|
||||||
bp_tag_scores(grads)
|
bp_tag_scores(grads)
|
||||||
self.finish_update(sgd)
|
if sgd is not None:
|
||||||
|
self.finish_update(sgd)
|
||||||
losses[self.name] += loss
|
losses[self.name] += loss
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,7 @@ def train(
|
||||||
output_path: Optional[Path] = None,
|
output_path: Optional[Path] = None,
|
||||||
*,
|
*,
|
||||||
use_gpu: int = -1,
|
use_gpu: int = -1,
|
||||||
|
use_rehearse: bool = False,
|
||||||
stdout: IO = sys.stdout,
|
stdout: IO = sys.stdout,
|
||||||
stderr: IO = sys.stderr,
|
stderr: IO = sys.stderr,
|
||||||
) -> Tuple["Language", Optional[Path]]:
|
) -> Tuple["Language", Optional[Path]]:
|
||||||
|
@ -35,6 +36,7 @@ def train(
|
||||||
output_path (Optional[Path]): Optional output path to save trained model to.
|
output_path (Optional[Path]): Optional output path to save trained model to.
|
||||||
use_gpu (int): Whether to train on GPU. Make sure to call require_gpu
|
use_gpu (int): Whether to train on GPU. Make sure to call require_gpu
|
||||||
before calling this function.
|
before calling this function.
|
||||||
|
use_rehearse (bool): Use nlp.rehearse after nlp.update
|
||||||
stdout (file): A file-like object to write output messages. To disable
|
stdout (file): A file-like object to write output messages. To disable
|
||||||
printing, set to io.StringIO.
|
printing, set to io.StringIO.
|
||||||
stderr (file): A second file-like object to write output messages. To disable
|
stderr (file): A second file-like object to write output messages. To disable
|
||||||
|
@ -54,7 +56,10 @@ def train(
|
||||||
T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
|
T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
|
||||||
dot_names = [T["train_corpus"], T["dev_corpus"]]
|
dot_names = [T["train_corpus"], T["dev_corpus"]]
|
||||||
train_corpus, dev_corpus = resolve_dot_names(config, dot_names)
|
train_corpus, dev_corpus = resolve_dot_names(config, dot_names)
|
||||||
optimizer = T["optimizer"]
|
if use_rehearse:
|
||||||
|
optimizer = nlp.resume_training()
|
||||||
|
else:
|
||||||
|
optimizer = T["optimizer"]
|
||||||
score_weights = T["score_weights"]
|
score_weights = T["score_weights"]
|
||||||
batcher = T["batcher"]
|
batcher = T["batcher"]
|
||||||
train_logger = T["logger"]
|
train_logger = T["logger"]
|
||||||
|
@ -88,6 +93,7 @@ def train(
|
||||||
patience=T["patience"],
|
patience=T["patience"],
|
||||||
max_steps=T["max_steps"],
|
max_steps=T["max_steps"],
|
||||||
eval_frequency=T["eval_frequency"],
|
eval_frequency=T["eval_frequency"],
|
||||||
|
use_rehearse=use_rehearse,
|
||||||
exclude=frozen_components,
|
exclude=frozen_components,
|
||||||
annotating_components=annotating_components,
|
annotating_components=annotating_components,
|
||||||
before_update=before_update,
|
before_update=before_update,
|
||||||
|
@ -150,6 +156,7 @@ def train_while_improving(
|
||||||
accumulate_gradient: int,
|
accumulate_gradient: int,
|
||||||
patience: int,
|
patience: int,
|
||||||
max_steps: int,
|
max_steps: int,
|
||||||
|
use_rehearse: bool = False,
|
||||||
exclude: List[str],
|
exclude: List[str],
|
||||||
annotating_components: List[str],
|
annotating_components: List[str],
|
||||||
before_update: Optional[Callable[["Language", Dict[str, Any]], None]],
|
before_update: Optional[Callable[["Language", Dict[str, Any]], None]],
|
||||||
|
@ -214,6 +221,12 @@ def train_while_improving(
|
||||||
exclude=exclude,
|
exclude=exclude,
|
||||||
annotates=annotating_components,
|
annotates=annotating_components,
|
||||||
)
|
)
|
||||||
|
nlp.rehearse(
|
||||||
|
subbatch,
|
||||||
|
losses=losses,
|
||||||
|
sgd=False, # type: ignore[arg-type]
|
||||||
|
exclude=exclude,
|
||||||
|
)
|
||||||
# TODO: refactor this so we don't have to run it separately in here
|
# TODO: refactor this so we don't have to run it separately in here
|
||||||
for name, proc in nlp.pipeline:
|
for name, proc in nlp.pipeline:
|
||||||
if (
|
if (
|
||||||
|
|
|
@ -11,6 +11,7 @@ menu:
|
||||||
- ['debug', 'debug']
|
- ['debug', 'debug']
|
||||||
- ['train', 'train']
|
- ['train', 'train']
|
||||||
- ['pretrain', 'pretrain']
|
- ['pretrain', 'pretrain']
|
||||||
|
- ['rehearse', 'rehearse']
|
||||||
- ['evaluate', 'evaluate']
|
- ['evaluate', 'evaluate']
|
||||||
- ['benchmark', 'benchmark']
|
- ['benchmark', 'benchmark']
|
||||||
- ['apply', 'apply']
|
- ['apply', 'apply']
|
||||||
|
@ -1134,6 +1135,39 @@ $ python -m spacy pretrain [config_path] [output_dir] [--code] [--resume-path] [
|
||||||
| overrides | Config parameters to override. Should be options starting with `--` that correspond to the config section and value to override, e.g. `--training.dropout 0.2`. ~~Any (option/flag)~~ |
|
| overrides | Config parameters to override. Should be options starting with `--` that correspond to the config section and value to override, e.g. `--training.dropout 0.2`. ~~Any (option/flag)~~ |
|
||||||
| **CREATES** | The pretrained weights that can be used to initialize `spacy train`. |
|
| **CREATES** | The pretrained weights that can be used to initialize `spacy train`. |
|
||||||
|
|
||||||
|
## rehearse {id="rehearse",tag="command, experimental"}
|
||||||
|
|
||||||
|
This command is designed to fine-tune pre-trained models while also trying to address the “catastrophic forgetting” problem.
|
||||||
|
It uses "rehearsal" updates that teach the current model to make predictions similar to an initial model. This feature is experimental.
|
||||||
|
|
||||||
|
<Infobox title="Please note" variant="warning">
|
||||||
|
|
||||||
|
The `rehearse` command outputs the sum of the losses from `TrainablePipe.update` and `TrainablePipe.rehearse`.
|
||||||
|
This can potentially cause the loss to increase drastically, even while the scores are also increasing. This is likely because the model's predictions diverge further from those of the initial model.
|
||||||
|
|
||||||
|
</Infobox>
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```bash
|
||||||
|
> $ python -m spacy rehearse config.cfg --output ./output --paths.train ./train --paths.dev ./dev
|
||||||
|
> ```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ python -m spacy rehearse [config_path] [--output] [--code] [--verbose] [--gpu-id] [overrides]
|
||||||
|
```
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ----------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||||
|
| `config_path` | Path to [training config](/api/data-formats#config) file containing all settings and hyperparameters. If `-`, the data will be [read from stdin](/usage/training#config-stdin). ~~Union[Path, str] \(positional)~~ |
|
||||||
|
| `--output`, `-o` | Directory to store trained pipeline in. Will be created if it doesn't exist. ~~Optional[Path] \(option)~~ |
|
||||||
|
| `--code`, `-c` | Path to Python file with additional code to be imported. Allows [registering custom functions](/usage/training#custom-functions) for new architectures. ~~Optional[Path] \(option)~~ |
|
||||||
|
| `--verbose`, `-V` | Show more detailed messages during training. ~~bool (flag)~~ |
|
||||||
|
| `--gpu-id`, `-g` | GPU ID or `-1` for CPU. Defaults to `-1`. ~~int (option)~~ |
|
||||||
|
| `--help`, `-h` | Show help message and available arguments. ~~bool (flag)~~ |
|
||||||
|
| overrides | Config parameters to override. Should be options starting with `--` that correspond to the config section and value to override, e.g. `--paths.train ./train.spacy`. ~~Any (option/flag)~~ |
|
||||||
|
| **CREATES** | The final rehearsed pipeline and the best rehearsed pipeline. |
|
||||||
|
|
||||||
## evaluate {id="evaluate",version="2",tag="command"}
|
## evaluate {id="evaluate",version="2",tag="command"}
|
||||||
|
|
||||||
The `evaluate` subcommand is superseded by
|
The `evaluate` subcommand is superseded by
|
||||||
|
|
|
@ -346,7 +346,7 @@ and custom registered functions if needed. See the
|
||||||
|
|
||||||
Perform a "rehearsal" update from a batch of data. Rehearsal updates teach the
|
Perform a "rehearsal" update from a batch of data. Rehearsal updates teach the
|
||||||
current model to make predictions similar to an initial model, to try to address
|
current model to make predictions similar to an initial model, to try to address
|
||||||
the "catastrophic forgetting" problem. This feature is experimental.
|
the "catastrophic forgetting" problem. Please note that this function needs to be used together with `Language.update`. This feature is experimental.
|
||||||
|
|
||||||
> #### Example
|
> #### Example
|
||||||
>
|
>
|
||||||
|
|
|
@ -244,7 +244,7 @@ predictions and gold-standard annotations, and update the component's model.
|
||||||
|
|
||||||
Perform a "rehearsal" update from a batch of data. Rehearsal updates teach the
|
Perform a "rehearsal" update from a batch of data. Rehearsal updates teach the
|
||||||
current model to make predictions similar to an initial model, to try to address
|
current model to make predictions similar to an initial model, to try to address
|
||||||
the "catastrophic forgetting" problem. This feature is experimental.
|
the "catastrophic forgetting" problem. Please note that this function needs to be used together with `TrainablePipe.update`. This feature is experimental.
|
||||||
|
|
||||||
> #### Example
|
> #### Example
|
||||||
>
|
>
|
||||||
|
|
Loading…
Reference in New Issue
Block a user