This commit is contained in:
Paul O'Leary McCann 2023-03-03 16:51:04 +01:00 committed by GitHub
commit 16145b2188
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 539 additions and 0 deletions

View File

@ -30,6 +30,9 @@ from .project.push import project_push # noqa: F401
from .project.pull import project_pull # noqa: F401
from .project.document import project_document # noqa: F401
from .find_threshold import find_threshold # noqa: F401
from .configure import use_transformer, use_tok2vec # noqa: F401
from .configure import configure_resume_cli # noqa: F401
from .merge import merge_pipelines # noqa: F401
@app.command("link", no_args_is_help=True, deprecated=True, hidden=True)

View File

@ -48,6 +48,7 @@ and custom model implementations.
"""
BENCHMARK_HELP = """Commands for benchmarking pipelines."""
INIT_HELP = """Commands for initializing configs and pipeline packages."""
CONFIGURE_HELP = """Commands for automatically modifying configs."""
# Wrappers for Typer's annotations. Initially created to set defaults and to
# keep the names short, but not needed at the moment.
@ -59,11 +60,13 @@ benchmark_cli = typer.Typer(name="benchmark", help=BENCHMARK_HELP, no_args_is_he
project_cli = typer.Typer(name="project", help=PROJECT_HELP, no_args_is_help=True)
debug_cli = typer.Typer(name="debug", help=DEBUG_HELP, no_args_is_help=True)
init_cli = typer.Typer(name="init", help=INIT_HELP, no_args_is_help=True)
configure_cli = typer.Typer(name="configure", help=CONFIGURE_HELP, no_args_is_help=True)
app.add_typer(project_cli)
app.add_typer(debug_cli)
app.add_typer(benchmark_cli)
app.add_typer(init_cli)
app.add_typer(configure_cli)
def setup_cli() -> None:

214
spacy/cli/configure.py Normal file
View File

@ -0,0 +1,214 @@
from pathlib import Path
from wasabi import msg
import typer
from thinc.api import Config
from typing import Any, Dict, Iterable, List, Union
import spacy
from spacy.language import Language
from ._util import configure_cli, Arg, Opt
# These are the architectures that are recognized as tok2vec/feature sources.
TOK2VEC_ARCHS = [
("spacy", "Tok2Vec"),
("spacy", "HashEmbedCNN"),
("spacy-transformers", "TransformerModel"),
]
# These are the listeners.
LISTENER_ARCHS = [
("spacy", "Tok2VecListener"),
("spacy-transformers", "TransformerListener"),
]
def _deep_get(
obj: Union[Dict[str, Any], Config], key: Iterable[str], default: Any
) -> Any:
"""Given a multi-part key, try to get the key. If at any point this isn't
possible, return the default.
"""
out = None
slot = obj
for notch in key:
if slot is None or notch not in slot:
return default
slot = slot[notch]
return slot
def _get_tok2vecs(config: Config) -> List[str]:
"""Given a pipeline config, return the names of components that are
tok2vecs (or Transformers).
"""
out = []
for name, comp in config["components"].items():
arch = _deep_get(comp, ("model", "@architectures"), False)
if not arch:
continue
ns, model, ver = arch.split(".")
if (ns, model) in TOK2VEC_ARCHS:
out.append(name)
return out
def _has_listener(nlp: Language, pipe_name: str):
"""Given a pipeline and a component name, check if it has a listener."""
arch = _deep_get(
nlp.config,
("components", pipe_name, "model", "tok2vec", "@architectures"),
False,
)
if not arch:
return False
ns, model, ver = arch.split(".")
return (ns, model) in LISTENER_ARCHS
def _get_listeners(nlp: Language) -> List[str]:
"""Get the name of every component that contains a listener.
Does not check that they listen to the same thing; assumes a pipeline has
only one feature source.
"""
out = []
for name in nlp.pipe_names:
if _has_listener(nlp, name):
out.append(name)
return out
def _check_single_tok2vec(name: str, config: Config) -> None:
"""Check if there is just one tok2vec in a config.
A very simple check, but used in multiple functions.
"""
tok2vecs = _get_tok2vecs(config)
fail_msg = f"""
Can't handle pipelines with more than one feature source,
but {name} has {len(tok2vecs)}."""
if len(tok2vecs) > 1:
msg.fail(fail_msg, exits=1)
@configure_cli.command("resume")
def configure_resume_cli(
# fmt: off
base_model: Path = Arg(..., help="Path or name of base model to use for config"),
output_file: Path = Arg(..., help="File to save the config to or - for stdout (will only output config and no additional logging info)", allow_dash=True),
# fmt: on
) -> Config:
"""Create a config for resuming training.
A config for resuming training is the same as the input config, but with
all components sourced.
DOCS: https://spacy.io/api/cli#configure-resume
"""
nlp = spacy.load(base_model)
conf = nlp.config
# Paths are not JSON serializable
path_str = str(base_model)
for comp in nlp.pipe_names:
conf["components"][comp] = {"source": path_str}
if str(output_file) == "-":
print(conf.to_str())
else:
conf.to_disk(output_file)
msg.good("Saved config", output_file)
return conf
@configure_cli.command("transformer")
def use_transformer(
base_model: str, output_file: Path, transformer_name: str = "roberta-base"
) -> Config:
"""Replace pipeline tok2vec with transformer.
DOCS: https://spacy.io/api/cli#configure-transformer
"""
# 1. identify tok2vec
# 2. replace tok2vec
# 3. replace listeners
nlp = spacy.load(base_model)
_check_single_tok2vec(base_model, nlp.config)
tok2vecs = _get_tok2vecs(nlp.config)
assert len(tok2vecs) > 0, "Must have tok2vec to replace!"
nlp.remove_pipe(tok2vecs[0])
# the rest can be default values
trf_config = {
"model": {
"name": transformer_name,
}
}
try:
trf = nlp.add_pipe("transformer", config=trf_config, first=True)
except ValueError:
fail_msg = (
"Configuring a transformer requires spacy-transformers. "
"Install with: pip install spacy-transformers"
)
msg.fail(fail_msg, exits=1)
# now update the listeners
listeners = _get_listeners(nlp)
for listener in listeners:
listener_config = {
"@architectures": "spacy-transformers.TransformerListener.v1",
"grad_factor": 1.0,
"upstream": "transformer",
"pooling": {"@layers": "reduce_mean.v1"},
}
nlp.config["components"][listener]["model"]["tok2vec"] = listener_config
if str(output_file) == "-":
print(nlp.config.to_str())
else:
nlp.config.to_disk(output_file)
msg.good("Saved config", output_file)
return nlp.config
@configure_cli.command("tok2vec")
def use_tok2vec(base_model: str, output_file: Path) -> Config:
"""Replace pipeline tok2vec with CNN tok2vec.
DOCS: https://spacy.io/api/cli#configure-tok2vec
"""
nlp = spacy.load(base_model)
_check_single_tok2vec(base_model, nlp.config)
tok2vecs = _get_tok2vecs(nlp.config)
assert len(tok2vecs) > 0, "Must have tok2vec to replace!"
nlp.remove_pipe(tok2vecs[0])
tok2vec = nlp.add_pipe("tok2vec", first=True)
width = "${components.tok2vec.model.encode:width}"
listeners = _get_listeners(nlp)
for listener in listeners:
listener_config = {
"@architectures": "spacy.Tok2VecListener.v1",
"width": width,
"upstream": "tok2vec",
}
nlp.config["components"][listener]["model"]["tok2vec"] = listener_config
if str(output_file) == "-":
print(nlp.config.to_str())
else:
nlp.config.to_disk(output_file)
msg.good("Saved config", output_file)
return nlp.config

160
spacy/cli/merge.py Normal file
View File

@ -0,0 +1,160 @@
from pathlib import Path
import re
from wasabi import msg
import spacy
from spacy.language import Language
from ._util import app, Arg, Opt, Dict
from .configure import _check_single_tok2vec, _get_listeners, _get_tok2vecs
from .configure import _has_listener
def _increment_suffix(name: str) -> str:
"""Given a name, return an incremented version.
If no numeric suffix is found, return the original with "2" appended.
This is used to avoid name collisions in pipelines.
"""
res = re.search(r"\d+$", name)
if res is None:
return f"{name}2"
else:
num = res.group()
prefix = name[0 : -len(num)]
return f"{prefix}{int(num) + 1}"
def _make_unique_pipe_names(nlp: Language, nlp2: Language) -> Dict[str, str]:
"""Given two pipelines, try to rename any collisions in component names.
If a simple increment of a numeric suffix doesn't work, will give up.
"""
fail_msg = """
Tried automatically renaming {name} to {new_name}, but still
had a collision, so bailing out. Please make your pipe names
unique.
"""
# map of components to be renamed
rename = {}
# check pipeline names
names = nlp.pipe_names
for name in nlp2.pipe_names:
if name in names:
inc = _increment_suffix(name)
if inc in names or inc in nlp2.pipe_names:
msg.fail(fail_msg.format(name=name, new_name=inc), exits=1)
rename[name] = inc
return rename
def _inner_merge(
nlp: Language, nlp2: Language, replace_listeners: bool = False
) -> Language:
"""Actually do the merge.
nlp (Language): Base pipeline to add components to.
nlp2 (Language): Pipeline to add components from.
replace_listeners (bool): Whether to replace listeners. Usually only true
if there's one listener.
returns: assembled pipeline.
"""
# The outer merge already verified there was exactly one tok2vec
tok2vec_name = _get_tok2vecs(nlp2.config)[0]
rename = _make_unique_pipe_names(nlp, nlp2)
if len(_get_listeners(nlp2)) > 1:
if replace_listeners:
msg.warn(
"""
Replacing listeners for multiple components. Note this can make
your pipeline large and slow. Consider chaining pipelines (like
nlp2(nlp(text))) instead.
"""
)
else:
# TODO provide a guide for what to do here
msg.warn(
"""
The result of this merge will have two feature sources
(tok2vecs) and multiple listeners. This will work for
inference, but will probably not work when training without
extra adjustment. If you continue to train the pipelines
separately this is not a problem.
"""
)
for comp in nlp2.pipe_names:
if replace_listeners and comp == tok2vec_name:
# the tok2vec should not be copied over
continue
if replace_listeners and _has_listener(nlp2, comp):
nlp2.replace_listeners(tok2vec_name, comp, ["model.tok2vec"])
nlp.add_pipe(comp, source=nlp2, name=rename.get(comp, comp))
if comp in rename:
msg.info(f"Renaming {comp} to {rename[comp]} to avoid collision...")
return nlp
@app.command("merge")
def merge_pipelines(
# fmt: off
base_model: str = Arg(..., help="Name or path of base model"),
added_model: str = Arg(..., help="Name or path of model to be merged"),
output_file: Path = Arg(..., help="Path to save merged model")
# fmt: on
) -> Language:
"""Combine components from multiple pipelines.
Given two pipelines, the components from them are merged into a single
pipeline. The exact way this works depends on whether the second pipeline
has one listener or more than one listener. In the single listener case
`replace_listeners` is used, otherwise components are simply appended to
the base pipeline.
DOCS: https://spacy.io/api/cli#merge
"""
nlp = spacy.load(base_model)
nlp2 = spacy.load(added_model)
# to merge models:
# - lang must be the same
# - vectors must be the same
# - vocabs must be the same
# - tokenizer must be the same (only partially checkable)
if nlp.lang != nlp2.lang:
msg.fail("Can't merge - languages don't match", exits=1)
# check vector equality
if (
nlp.vocab.vectors.shape != nlp2.vocab.vectors.shape
or nlp.vocab.vectors.key2row != nlp2.vocab.vectors.key2row
or nlp.vocab.vectors.to_bytes(exclude=["strings"])
!= nlp2.vocab.vectors.to_bytes(exclude=["strings"])
):
msg.fail("Can't merge - vectors don't match", exits=1)
if nlp.config["nlp"]["tokenizer"] != nlp2.config["nlp"]["tokenizer"]:
msg.fail("Can't merge - tokenizers don't match", exits=1)
# Check that each pipeline only has one feature source
_check_single_tok2vec(base_model, nlp.config)
_check_single_tok2vec(added_model, nlp2.config)
# Check how many listeners there are and replace based on that
# TODO: option to recognize frozen tok2vecs
# TODO: take list of pipe names to copy, ignore others
listeners = _get_listeners(nlp2)
replace_listeners = len(listeners) == 1
nlp_out = _inner_merge(nlp, nlp2, replace_listeners=replace_listeners)
# write the final pipeline
nlp.to_disk(output_file)
msg.info(f"Saved pipeline to: {output_file}")
return nlp

View File

@ -21,6 +21,7 @@ from spacy.cli._util import parse_config_overrides, string_to_list
from spacy.cli._util import substitute_project_variables
from spacy.cli._util import validate_project_commands
from spacy.cli._util import upload_file, download_file
from spacy.cli.configure import configure_resume_cli, use_tok2vec
from spacy.cli.debug_data import _compile_gold, _get_labels_from_model
from spacy.cli.debug_data import _get_labels_from_spancat
from spacy.cli.debug_data import _get_distribution, _get_kl_divergence
@ -29,6 +30,7 @@ from spacy.cli.debug_data import _print_span_characteristics
from spacy.cli.debug_data import _get_spans_length_freq_dist
from spacy.cli.download import get_compatibility, get_version
from spacy.cli.init_config import RECOMMENDATIONS, init_config, fill_config
from spacy.cli.merge import merge_pipelines
from spacy.cli.package import get_third_party_dependencies
from spacy.cli.package import _is_permitted_package_name
from spacy.cli.project.remote_storage import RemoteStorage
@ -1180,6 +1182,71 @@ def test_upload_download_local_file():
assert file_.read() == content
def test_configure_resume(tmp_path):
nlp = spacy.blank("en")
nlp.add_pipe("ner")
nlp.add_pipe("textcat")
base_path = tmp_path / "base"
nlp.to_disk(base_path)
out_path = tmp_path / "resume.cfg"
conf = configure_resume_cli(base_path, out_path)
assert out_path.exists(), "Didn't save config file"
for comp, val in conf["components"].items():
assert "source" in val, f"Non-sourced component: {comp}"
def test_use_tok2vec(tmp_path):
# Can't add a transformer here because spacy-transformers might not be present
nlp = spacy.blank("en")
nlp.add_pipe("tok2vec")
base_path = tmp_path / "tok2vec_sample_2"
nlp.to_disk(base_path)
out_path = tmp_path / "converted_to_tok2vec"
conf = use_tok2vec(base_path, out_path)
assert out_path.exists(), "No model saved"
assert "tok2vec" in conf["components"], "No tok2vec component"
def test_merge_pipelines(tmp_path):
# width is a placeholder, since we won't actually train this
listener_config = {
"model": {
"tok2vec": {"@architectures": "spacy.Tok2VecListener.v1", "width": "0"}
}
}
# base pipeline
base = spacy.blank("en")
base.add_pipe("tok2vec")
base.add_pipe("ner", config=listener_config)
base_path = tmp_path / "merge_base"
base.to_disk(base_path)
# added pipeline
added = spacy.blank("en")
added.add_pipe("tok2vec")
added.add_pipe("ner", config=listener_config)
added_path = tmp_path / "merge_added"
added.to_disk(added_path)
# these should combine and not have a name collision
out_path = tmp_path / "merge_result"
merged = merge_pipelines(base_path, added_path, out_path)
# will give a key error if not present
merged.get_pipe("ner")
merged.get_pipe("ner2")
ner2_conf = merged.config["components"]["ner2"]
arch = ner2_conf["model"]["tok2vec"]["@architectures"]
assert arch == "spacy.HashEmbedCNN.v2", "Wrong arch - listener not replaced?"
def test_walk_directory():
with make_tempdir() as d:
files = [

View File

@ -7,6 +7,8 @@ menu:
- ['info', 'info']
- ['validate', 'validate']
- ['init', 'init']
- ['configure', 'configure']
- ['merge', 'merge']
- ['convert', 'convert']
- ['debug', 'debug']
- ['train', 'train']
@ -250,6 +252,96 @@ $ python -m spacy init labels [config_path] [output_path] [--code] [--verbose] [
| overrides | Config parameters to override. Should be options starting with `--` that correspond to the config section and value to override, e.g. `--paths.train ./train.spacy`. ~~Any (option/flag)~~ |
| **CREATES** | The label files. |
## configure {id="configure", version="TODO"}
Modify or combine existing configs. Example uses include swapping feature
sources or creating configs to resume training. This may simplify parts of the
development cycle. For example, it allows starting off with a config for a
faster, less accurate model (using the CNN tok2vec) and conveniently switching
to transformers later, without having to manually adjust the config.
### configure resume {id="configure-resume", tag="command"}
Modify the input config for use in resuming training. When resuming training,
all components are sourced from the previously trained pipeline.
```cli
$ python -m spacy configure resume [base_model] [output_file]
```
| Name | Description |
| ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `base_model` | A trained pipeline to resume training (package name or path). ~~str (positional)~~ |
| `output_file` | Path to output `.cfg` file or `-` to write the config to stdout (so you can pipe it forward to a file or to the `train` command). Note that if you're writing to stdout, no additional logging info is printed. ~~Path (positional)~~ |
### configure transformer {id="configure-transformer", tag="command"}
Modify the base config to use a transformer component, optionally specifying the
base transformer to use. Useful for converting a CNN tok2vec pipeline to use
transformers.
During development of a model, you can use a CNN tok2vec for faster training
time and reduced hardware requirements, and then use this command to convert
your pipeline to use a transformer once you've verified a proof of concept. This
can also help isolate whether any training issues are transformer-related or
not.
```cli
$ python -m spacy configure transformer [base_model] [output_file] [--transformer_name]
```
| Name | Description |
| ------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `base_model` | A trained pipeline to resume training (package name or path). ~~str (positional)~~ |
| `output_file` | Path to output `.cfg` file or `-` to write the config to stdout (so you can pipe it forward to a file or to the `train` command). Note that if you're writing to stdout, no additional logging info is printed. ~~Path (positional)~~ |
| `transformer_name` | The name of the base HuggingFace model to use. Defaults to `roberta-base`. ~~str (option)~~ |
### configure tok2vec {id="configure-tok2vec", tag="command"}
Modify the base model config to use a CNN tok2vec component. Useful for
generating a config from a transformer-based model for faster training
iteration.
```cli
$ python -m spacy configure tok2vec [base_model] [output_file]
```
| Name | Description |
| ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `base_model` | A trained pipeline to resume training (package name or path). ~~str (positional)~~ |
| `output_file` | Path to output `.cfg` file or `-` to write the config to stdout (so you can pipe it forward to a file or to the `train` command). Note that if you're writing to stdout, no additional logging info is printed. ~~Path (positional)~~ |
## merge {id="merge", tag="command"}
Take two pipelines and create a new one with components from both of them,
handling the configuration of listeners. The output is a serialized pipeline.
Components in the final pipeline are in the same order as in the original
pipelines, with the base pipeline first and the added pipeline after. Because
pipeline names must be unique, if there is a name collision in components, the
later components will be automatically renamed.
For components with listeners, the resulting pipeline structure depends on the
number of listeners. If the second pipeline has only one listener, then
[`replace_listeners`](https://spacy.io/api/language/#replace_listeners) will be
used. If there is more than one listener, `replace_listeners` will not be used.
In the multi-listener case, the resulting pipeline may require more adjustment
for training to work.
This is useful if you have trained a specialized component, such as NER or
textcat, and want to provide with one of the official pretrained pipelines or
another pipeline.
```cli
$ python -m spacy merge [base_model] [added_model] [output_file]
```
| Name | Description |
| ------------- | ---------------------------------------------------------------------------------------- |
| `base_model` | A trained pipeline (package name or path) to use as a base. ~~str (positional)~~ |
| `added_model` | A trained pipeline (package name or path) to combine with the base. ~~str (positional)~~ |
| `output_file` | Path to output pipeline. ~~Path (positional)~~ |
## convert {id="convert",tag="command"}
Convert files into spaCy's