mirror of
https://github.com/explosion/spaCy.git
synced 2025-04-22 18:12:00 +03:00
Merge e668c2c21f
into 6aa6b86d49
This commit is contained in:
commit
16145b2188
|
@ -30,6 +30,9 @@ from .project.push import project_push # noqa: F401
|
|||
from .project.pull import project_pull # noqa: F401
|
||||
from .project.document import project_document # noqa: F401
|
||||
from .find_threshold import find_threshold # noqa: F401
|
||||
from .configure import use_transformer, use_tok2vec # noqa: F401
|
||||
from .configure import configure_resume_cli # noqa: F401
|
||||
from .merge import merge_pipelines # noqa: F401
|
||||
|
||||
|
||||
@app.command("link", no_args_is_help=True, deprecated=True, hidden=True)
|
||||
|
|
|
@ -48,6 +48,7 @@ and custom model implementations.
|
|||
"""
|
||||
BENCHMARK_HELP = """Commands for benchmarking pipelines."""
|
||||
INIT_HELP = """Commands for initializing configs and pipeline packages."""
|
||||
CONFIGURE_HELP = """Commands for automatically modifying configs."""
|
||||
|
||||
# Wrappers for Typer's annotations. Initially created to set defaults and to
|
||||
# keep the names short, but not needed at the moment.
|
||||
|
@ -59,11 +60,13 @@ benchmark_cli = typer.Typer(name="benchmark", help=BENCHMARK_HELP, no_args_is_he
|
|||
project_cli = typer.Typer(name="project", help=PROJECT_HELP, no_args_is_help=True)
|
||||
debug_cli = typer.Typer(name="debug", help=DEBUG_HELP, no_args_is_help=True)
|
||||
init_cli = typer.Typer(name="init", help=INIT_HELP, no_args_is_help=True)
|
||||
configure_cli = typer.Typer(name="configure", help=CONFIGURE_HELP, no_args_is_help=True)
|
||||
|
||||
app.add_typer(project_cli)
|
||||
app.add_typer(debug_cli)
|
||||
app.add_typer(benchmark_cli)
|
||||
app.add_typer(init_cli)
|
||||
app.add_typer(configure_cli)
|
||||
|
||||
|
||||
def setup_cli() -> None:
|
||||
|
|
214
spacy/cli/configure.py
Normal file
214
spacy/cli/configure.py
Normal file
|
@ -0,0 +1,214 @@
|
|||
from pathlib import Path
|
||||
from wasabi import msg
|
||||
import typer
|
||||
from thinc.api import Config
|
||||
from typing import Any, Dict, Iterable, List, Union
|
||||
|
||||
import spacy
|
||||
from spacy.language import Language
|
||||
|
||||
from ._util import configure_cli, Arg, Opt
|
||||
|
||||
# Architectures treated as feature sources (tok2vec or transformer). Each entry
# is a (registry namespace, architecture name) pair; the version suffix of the
# registered name is deliberately left out.
TOK2VEC_ARCHS = [
    ("spacy", "Tok2Vec"),
    ("spacy", "HashEmbedCNN"),
    ("spacy-transformers", "TransformerModel"),
]
# Architectures treated as listeners, in the same (namespace, name) form.
LISTENER_ARCHS = [
    ("spacy", "Tok2VecListener"),
    ("spacy-transformers", "TransformerListener"),
]
|
||||
|
||||
|
||||
def _deep_get(
|
||||
obj: Union[Dict[str, Any], Config], key: Iterable[str], default: Any
|
||||
) -> Any:
|
||||
"""Given a multi-part key, try to get the key. If at any point this isn't
|
||||
possible, return the default.
|
||||
"""
|
||||
out = None
|
||||
slot = obj
|
||||
for notch in key:
|
||||
if slot is None or notch not in slot:
|
||||
return default
|
||||
slot = slot[notch]
|
||||
return slot
|
||||
|
||||
|
||||
def _get_tok2vecs(config: Config) -> List[str]:
    """Given a pipeline config, return the names of components that are
    tok2vecs (or Transformers).

    config (Config): Pipeline config with a "components" section.
    returns: names of the components whose model architecture is listed in
        TOK2VEC_ARCHS, in config order.
    """
    out = []
    for name, comp in config["components"].items():
        arch = _deep_get(comp, ("model", "@architectures"), False)
        if not arch:
            continue
        # Only the first two dot-separated parts (namespace and name) are
        # compared. Names with fewer or more parts are simply not matched
        # instead of breaking a fixed three-way unpacking.
        parts = arch.split(".")
        if len(parts) >= 2 and (parts[0], parts[1]) in TOK2VEC_ARCHS:
            out.append(name)
    return out
|
||||
|
||||
|
||||
def _has_listener(nlp: Language, pipe_name: str) -> bool:
    """Given a pipeline and a component name, check if it has a listener.

    nlp (Language): Pipeline whose config is inspected.
    pipe_name (str): Name of the component to check.
    returns: True if the component's model.tok2vec architecture is listed in
        LISTENER_ARCHS, otherwise False.
    """
    arch = _deep_get(
        nlp.config,
        ("components", pipe_name, "model", "tok2vec", "@architectures"),
        False,
    )
    if not arch:
        return False
    # Compare namespace and name only. Architecture names that don't have the
    # usual three dot-separated parts are treated as "not a listener" rather
    # than raising on unpacking.
    parts = arch.split(".")
    return len(parts) >= 2 and (parts[0], parts[1]) in LISTENER_ARCHS
|
||||
|
||||
|
||||
def _get_listeners(nlp: Language) -> List[str]:
    """Collect the names of all components that contain a listener.

    No check is made that the listeners share an upstream component; a
    pipeline is assumed to have a single feature source.
    """
    return [pipe for pipe in nlp.pipe_names if _has_listener(nlp, pipe)]
|
||||
|
||||
|
||||
def _check_single_tok2vec(name: str, config: Config) -> None:
    """Exit with an error if a config has more than one tok2vec.

    A small check shared by several commands. Configs with zero or one feature
    source pass silently.
    """
    count = len(_get_tok2vecs(config))
    if count <= 1:
        return
    msg.fail(
        "\nCan't handle pipelines with more than one feature source,\n"
        f"but {name} has {count}.",
        exits=1,
    )
|
||||
|
||||
|
||||
@configure_cli.command("resume")
def configure_resume_cli(
    # fmt: off
    base_model: Path = Arg(..., help="Path or name of base model to use for config"),
    output_file: Path = Arg(..., help="File to save the config to or - for stdout (will only output config and no additional logging info)", allow_dash=True),
    # fmt: on
) -> Config:
    """Create a config for resuming training.

    A config for resuming training is the same as the input config, but with
    all components sourced.

    DOCS: https://spacy.io/api/cli#configure-resume
    """
    # Paths are not JSON serializable, so the config stores the string form.
    # The same string is passed to spacy.load: a str is resolved as either an
    # installed package name or a path, whereas a Path object is only treated
    # as a directory on disk, so package names would fail to load.
    path_str = str(base_model)
    nlp = spacy.load(path_str)
    conf = nlp.config

    # Replace every component's settings with a reference to the base model.
    for comp in nlp.pipe_names:
        conf["components"][comp] = {"source": path_str}

    if str(output_file) == "-":
        # Only the config goes to stdout so the output can be piped onwards.
        print(conf.to_str())
    else:
        conf.to_disk(output_file)
        msg.good("Saved config", output_file)

    return conf
|
||||
|
||||
|
||||
@configure_cli.command("transformer")
def use_transformer(
    base_model: str, output_file: Path, transformer_name: str = "roberta-base"
) -> Config:
    """Replace pipeline tok2vec with transformer.

    DOCS: https://spacy.io/api/cli#configure-transformer
    """
    # Steps: identify the tok2vec, swap it for a transformer component, then
    # point every listener at the new component.
    nlp = spacy.load(base_model)
    _check_single_tok2vec(base_model, nlp.config)

    tok2vecs = _get_tok2vecs(nlp.config)
    # Validate explicitly rather than with assert, which is stripped under -O.
    if not tok2vecs:
        msg.fail("Must have tok2vec to replace!", exits=1)

    nlp.remove_pipe(tok2vecs[0])
    # Only the model name is set; the rest can be default values.
    trf_config = {
        "model": {
            "name": transformer_name,
        }
    }
    try:
        nlp.add_pipe("transformer", config=trf_config, first=True)
    except ValueError:
        fail_msg = (
            "Configuring a transformer requires spacy-transformers. "
            "Install with: pip install spacy-transformers"
        )
        msg.fail(fail_msg, exits=1)

    # Now update the listeners so they listen to the transformer.
    for listener in _get_listeners(nlp):
        listener_config = {
            "@architectures": "spacy-transformers.TransformerListener.v1",
            "grad_factor": 1.0,
            "upstream": "transformer",
            "pooling": {"@layers": "reduce_mean.v1"},
        }
        nlp.config["components"][listener]["model"]["tok2vec"] = listener_config

    if str(output_file) == "-":
        # Only the config goes to stdout so the output can be piped onwards.
        print(nlp.config.to_str())
    else:
        nlp.config.to_disk(output_file)
        msg.good("Saved config", output_file)

    return nlp.config
|
||||
|
||||
|
||||
@configure_cli.command("tok2vec")
def use_tok2vec(base_model: str, output_file: Path) -> Config:
    """Replace pipeline tok2vec with CNN tok2vec.

    DOCS: https://spacy.io/api/cli#configure-tok2vec
    """
    nlp = spacy.load(base_model)
    _check_single_tok2vec(base_model, nlp.config)

    tok2vecs = _get_tok2vecs(nlp.config)
    # Validate explicitly rather than with assert, which is stripped under -O.
    if not tok2vecs:
        msg.fail("Must have tok2vec to replace!", exits=1)

    # Swap the existing feature source for a default tok2vec at the front of
    # the pipeline.
    nlp.remove_pipe(tok2vecs[0])
    nlp.add_pipe("tok2vec", first=True)

    # Listeners take their width from the new tok2vec via a config variable.
    width = "${components.tok2vec.model.encode:width}"

    for listener in _get_listeners(nlp):
        listener_config = {
            "@architectures": "spacy.Tok2VecListener.v1",
            "width": width,
            "upstream": "tok2vec",
        }
        nlp.config["components"][listener]["model"]["tok2vec"] = listener_config

    if str(output_file) == "-":
        # Only the config goes to stdout so the output can be piped onwards.
        print(nlp.config.to_str())
    else:
        nlp.config.to_disk(output_file)
        msg.good("Saved config", output_file)

    return nlp.config
|
160
spacy/cli/merge.py
Normal file
160
spacy/cli/merge.py
Normal file
|
@ -0,0 +1,160 @@
|
|||
from pathlib import Path
|
||||
import re
|
||||
from wasabi import msg
|
||||
|
||||
import spacy
|
||||
from spacy.language import Language
|
||||
|
||||
from ._util import app, Arg, Opt, Dict
|
||||
from .configure import _check_single_tok2vec, _get_listeners, _get_tok2vecs
|
||||
from .configure import _has_listener
|
||||
|
||||
|
||||
def _increment_suffix(name: str) -> str:
|
||||
"""Given a name, return an incremented version.
|
||||
|
||||
If no numeric suffix is found, return the original with "2" appended.
|
||||
|
||||
This is used to avoid name collisions in pipelines.
|
||||
"""
|
||||
|
||||
res = re.search(r"\d+$", name)
|
||||
if res is None:
|
||||
return f"{name}2"
|
||||
else:
|
||||
num = res.group()
|
||||
prefix = name[0 : -len(num)]
|
||||
return f"{prefix}{int(num) + 1}"
|
||||
|
||||
|
||||
def _make_unique_pipe_names(nlp: Language, nlp2: Language) -> Dict[str, str]:
    """Given two pipelines, try to rename any collisions in component names.

    If a simple increment of a numeric suffix doesn't work, will give up.

    nlp (Language): Base pipeline whose component names are kept as-is.
    nlp2 (Language): Pipeline whose colliding component names get renamed.
    returns: mapping from original name in nlp2 to the new name. Components
        that don't collide are not included.
    """
    fail_msg = (
        "\nTried automatically renaming {name} to {new_name}, but still\n"
        "had a collision, so bailing out. Please make your pipe names\n"
        "unique.\n"
    )

    # map of components to be renamed
    rename: Dict[str, str] = {}
    base_names = set(nlp.pipe_names)
    # Every name that is already spoken for: the components of both pipelines
    # plus any new names handed out below. Tracking the latter prevents two
    # different components (e.g. "ner" and "ner1") both becoming "ner2".
    taken = base_names | set(nlp2.pipe_names)
    for name in nlp2.pipe_names:
        if name not in base_names:
            continue
        inc = _increment_suffix(name)
        if inc in taken:
            msg.fail(fail_msg.format(name=name, new_name=inc), exits=1)
        rename[name] = inc
        taken.add(inc)
    return rename
|
||||
|
||||
|
||||
def _inner_merge(
    nlp: Language, nlp2: Language, replace_listeners: bool = False
) -> Language:
    """Actually do the merge.

    nlp (Language): Base pipeline to add components to.
    nlp2 (Language): Pipeline to add components from.
    replace_listeners (bool): Whether to replace listeners. Usually only true
        if there's one listener.
    returns: assembled pipeline.
    """
    # The single-tok2vec check only rejects configs with more than one feature
    # source, so the added pipeline may have none at all.
    tok2vecs = _get_tok2vecs(nlp2.config)
    tok2vec_name = tok2vecs[0] if tok2vecs else None
    rename = _make_unique_pipe_names(nlp, nlp2)

    if len(_get_listeners(nlp2)) > 1:
        if replace_listeners:
            msg.warn(
                "\nReplacing listeners for multiple components. Note this can make\n"
                "your pipeline large and slow. Consider chaining pipelines (like\n"
                "nlp2(nlp(text))) instead.\n"
            )
        else:
            # TODO provide a guide for what to do here
            msg.warn(
                "\nThe result of this merge will have two feature sources\n"
                "(tok2vecs) and multiple listeners. This will work for\n"
                "inference, but will probably not work when training without\n"
                "extra adjustment. If you continue to train the pipelines\n"
                "separately this is not a problem.\n"
            )

    for comp in nlp2.pipe_names:
        # Listener replacement only makes sense if there is a tok2vec to
        # replace them with.
        if replace_listeners and tok2vec_name is not None:
            if comp == tok2vec_name:
                # the tok2vec should not be copied over
                continue
            if _has_listener(nlp2, comp):
                nlp2.replace_listeners(tok2vec_name, comp, ["model.tok2vec"])
        nlp.add_pipe(comp, source=nlp2, name=rename.get(comp, comp))
        if comp in rename:
            msg.info(f"Renaming {comp} to {rename[comp]} to avoid collision...")
    return nlp
|
||||
|
||||
|
||||
@app.command("merge")
def merge_pipelines(
    # fmt: off
    base_model: str = Arg(..., help="Name or path of base model"),
    added_model: str = Arg(..., help="Name or path of model to be merged"),
    output_file: Path = Arg(..., help="Path to save merged model")
    # fmt: on
) -> Language:
    """Combine components from multiple pipelines.

    Given two pipelines, the components from them are merged into a single
    pipeline. The exact way this works depends on whether the second pipeline
    has one listener or more than one listener. In the single listener case
    `replace_listeners` is used, otherwise components are simply appended to
    the base pipeline.

    DOCS: https://spacy.io/api/cli#merge
    """
    nlp = spacy.load(base_model)
    nlp2 = spacy.load(added_model)

    # to merge models:
    # - lang must be the same
    # - vectors must be the same
    # - vocabs must be the same
    # - tokenizer must be the same (only partially checkable)
    if nlp.lang != nlp2.lang:
        msg.fail("Can't merge - languages don't match", exits=1)

    # check vector equality: shape and key mapping first, then the serialized
    # data itself (without the string store)
    if (
        nlp.vocab.vectors.shape != nlp2.vocab.vectors.shape
        or nlp.vocab.vectors.key2row != nlp2.vocab.vectors.key2row
        or nlp.vocab.vectors.to_bytes(exclude=["strings"])
        != nlp2.vocab.vectors.to_bytes(exclude=["strings"])
    ):
        msg.fail("Can't merge - vectors don't match", exits=1)

    if nlp.config["nlp"]["tokenizer"] != nlp2.config["nlp"]["tokenizer"]:
        msg.fail("Can't merge - tokenizers don't match", exits=1)

    # Check that each pipeline has at most one feature source
    _check_single_tok2vec(base_model, nlp.config)
    _check_single_tok2vec(added_model, nlp2.config)

    # Check how many listeners there are and replace based on that
    # TODO: option to recognize frozen tok2vecs
    # TODO: take list of pipe names to copy, ignore others
    listeners = _get_listeners(nlp2)
    replace_listeners = len(listeners) == 1
    nlp_out = _inner_merge(nlp, nlp2, replace_listeners=replace_listeners)

    # write the final pipeline, using the merge result rather than relying on
    # the base pipeline having been modified in place
    nlp_out.to_disk(output_file)
    msg.info(f"Saved pipeline to: {output_file}")

    return nlp_out
|
|
@ -21,6 +21,7 @@ from spacy.cli._util import parse_config_overrides, string_to_list
|
|||
from spacy.cli._util import substitute_project_variables
|
||||
from spacy.cli._util import validate_project_commands
|
||||
from spacy.cli._util import upload_file, download_file
|
||||
from spacy.cli.configure import configure_resume_cli, use_tok2vec
|
||||
from spacy.cli.debug_data import _compile_gold, _get_labels_from_model
|
||||
from spacy.cli.debug_data import _get_labels_from_spancat
|
||||
from spacy.cli.debug_data import _get_distribution, _get_kl_divergence
|
||||
|
@ -29,6 +30,7 @@ from spacy.cli.debug_data import _print_span_characteristics
|
|||
from spacy.cli.debug_data import _get_spans_length_freq_dist
|
||||
from spacy.cli.download import get_compatibility, get_version
|
||||
from spacy.cli.init_config import RECOMMENDATIONS, init_config, fill_config
|
||||
from spacy.cli.merge import merge_pipelines
|
||||
from spacy.cli.package import get_third_party_dependencies
|
||||
from spacy.cli.package import _is_permitted_package_name
|
||||
from spacy.cli.project.remote_storage import RemoteStorage
|
||||
|
@ -1180,6 +1182,71 @@ def test_upload_download_local_file():
|
|||
assert file_.read() == content
|
||||
|
||||
|
||||
def test_configure_resume(tmp_path):
    """A resume config is written to disk and sources every component."""
    # Save a small untrained pipeline to use as the base model.
    pipeline = spacy.blank("en")
    for factory in ("ner", "textcat"):
        pipeline.add_pipe(factory)
    base_path = tmp_path / "base"
    pipeline.to_disk(base_path)

    out_path = tmp_path / "resume.cfg"
    conf = configure_resume_cli(base_path, out_path)

    assert out_path.exists(), "Didn't save config file"

    for comp in conf["components"]:
        assert "source" in conf["components"][comp], f"Non-sourced component: {comp}"
|
||||
|
||||
|
||||
def test_use_tok2vec(tmp_path):
    """Converting a tok2vec pipeline saves output and keeps a tok2vec."""
    # Can't add a transformer here because spacy-transformers might not be present
    base_path = tmp_path / "tok2vec_sample_2"
    pipeline = spacy.blank("en")
    pipeline.add_pipe("tok2vec")
    pipeline.to_disk(base_path)

    out_path = tmp_path / "converted_to_tok2vec"
    conf = use_tok2vec(base_path, out_path)

    assert out_path.exists(), "No model saved"
    assert "tok2vec" in conf["components"], "No tok2vec component"
|
||||
|
||||
|
||||
def test_merge_pipelines(tmp_path):
    """Merging two tok2vec + ner pipelines renames the second ner and
    replaces its listener."""
    # width is a placeholder, since we won't actually train this
    listener_config = {
        "model": {
            "tok2vec": {"@architectures": "spacy.Tok2VecListener.v1", "width": "0"}
        }
    }

    # The base and the added pipeline are built the same way and saved to
    # separate directories, base first.
    saved = {}
    for dirname in ("merge_base", "merge_added"):
        pipeline = spacy.blank("en")
        pipeline.add_pipe("tok2vec")
        pipeline.add_pipe("ner", config=listener_config)
        saved[dirname] = tmp_path / dirname
        pipeline.to_disk(saved[dirname])

    # these should combine and not have a name collision
    out_path = tmp_path / "merge_result"
    merged = merge_pipelines(saved["merge_base"], saved["merge_added"], out_path)

    # will give a key error if not present
    for pipe_name in ("ner", "ner2"):
        merged.get_pipe(pipe_name)

    arch = merged.config["components"]["ner2"]["model"]["tok2vec"]["@architectures"]
    assert arch == "spacy.HashEmbedCNN.v2", "Wrong arch - listener not replaced?"
|
||||
|
||||
|
||||
def test_walk_directory():
|
||||
with make_tempdir() as d:
|
||||
files = [
|
||||
|
|
|
@ -7,6 +7,8 @@ menu:
|
|||
- ['info', 'info']
|
||||
- ['validate', 'validate']
|
||||
- ['init', 'init']
|
||||
- ['configure', 'configure']
|
||||
- ['merge', 'merge']
|
||||
- ['convert', 'convert']
|
||||
- ['debug', 'debug']
|
||||
- ['train', 'train']
|
||||
|
@ -250,6 +252,96 @@ $ python -m spacy init labels [config_path] [output_path] [--code] [--verbose] [
|
|||
| overrides | Config parameters to override. Should be options starting with `--` that correspond to the config section and value to override, e.g. `--paths.train ./train.spacy`. ~~Any (option/flag)~~ |
|
||||
| **CREATES** | The label files. |
|
||||
|
||||
## configure {id="configure", version="TODO"}
|
||||
|
||||
Modify or combine existing configs. Example uses include swapping feature
|
||||
sources or creating configs to resume training. This may simplify parts of the
|
||||
development cycle. For example, it allows starting off with a config for a
|
||||
faster, less accurate model (using the CNN tok2vec) and conveniently switching
|
||||
to transformers later, without having to manually adjust the config.
|
||||
|
||||
### configure resume {id="configure-resume", tag="command"}
|
||||
|
||||
Modify the input config for use in resuming training. When resuming training,
|
||||
all components are sourced from the previously trained pipeline.
|
||||
|
||||
```cli
|
||||
$ python -m spacy configure resume [base_model] [output_file]
|
||||
```
|
||||
|
||||
| Name | Description |
|
||||
| ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `base_model` | A trained pipeline to resume training (package name or path). ~~str (positional)~~ |
|
||||
| `output_file` | Path to output `.cfg` file or `-` to write the config to stdout (so you can pipe it forward to a file or to the `train` command). Note that if you're writing to stdout, no additional logging info is printed. ~~Path (positional)~~ |
|
||||
|
||||
### configure transformer {id="configure-transformer", tag="command"}
|
||||
|
||||
Modify the base config to use a transformer component, optionally specifying the
|
||||
base transformer to use. Useful for converting a CNN tok2vec pipeline to use
|
||||
transformers.
|
||||
|
||||
During development of a model, you can use a CNN tok2vec for faster training
|
||||
time and reduced hardware requirements, and then use this command to convert
|
||||
your pipeline to use a transformer once you've verified a proof of concept. This
|
||||
can also help isolate whether any training issues are transformer-related or
|
||||
not.
|
||||
|
||||
```cli
|
||||
$ python -m spacy configure transformer [base_model] [output_file] [--transformer-name]
|
||||
```
|
||||
|
||||
| Name | Description |
|
||||
| ------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `base_model` | A trained pipeline whose config is converted to use a transformer (package name or path). ~~str (positional)~~ |
|
||||
| `output_file` | Path to output `.cfg` file or `-` to write the config to stdout (so you can pipe it forward to a file or to the `train` command). Note that if you're writing to stdout, no additional logging info is printed. ~~Path (positional)~~ |
|
||||
| `--transformer-name` | The name of the base HuggingFace model to use. Defaults to `roberta-base`. ~~str (option)~~ |
|
||||
|
||||
### configure tok2vec {id="configure-tok2vec", tag="command"}
|
||||
|
||||
Modify the base model config to use a CNN tok2vec component. Useful for
|
||||
generating a config from a transformer-based model for faster training
|
||||
iteration.
|
||||
|
||||
```cli
|
||||
$ python -m spacy configure tok2vec [base_model] [output_file]
|
||||
```
|
||||
|
||||
| Name | Description |
|
||||
| ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `base_model` | A trained pipeline whose config is converted to use a CNN tok2vec (package name or path). ~~str (positional)~~ |
|
||||
| `output_file` | Path to output `.cfg` file or `-` to write the config to stdout (so you can pipe it forward to a file or to the `train` command). Note that if you're writing to stdout, no additional logging info is printed. ~~Path (positional)~~ |
|
||||
|
||||
## merge {id="merge", tag="command"}
|
||||
|
||||
Take two pipelines and create a new one with components from both of them,
|
||||
handling the configuration of listeners. The output is a serialized pipeline.
|
||||
|
||||
Components in the final pipeline are in the same order as in the original
|
||||
pipelines, with the base pipeline first and the added pipeline after. Because
|
||||
pipeline names must be unique, if there is a name collision in components, the
|
||||
later components will be automatically renamed.
|
||||
|
||||
For components with listeners, the resulting pipeline structure depends on the
|
||||
number of listeners. If the second pipeline has only one listener, then
|
||||
[`replace_listeners`](https://spacy.io/api/language/#replace_listeners) will be
|
||||
used. If there is more than one listener, `replace_listeners` will not be used.
|
||||
In the multi-listener case, the resulting pipeline may require more adjustment
|
||||
for training to work.
|
||||
|
||||
This is useful if you have trained a specialized component, such as NER or
|
||||
textcat, and want to combine it with one of the official pretrained pipelines or
|
||||
another pipeline.
|
||||
|
||||
```cli
|
||||
$ python -m spacy merge [base_model] [added_model] [output_file]
|
||||
```
|
||||
|
||||
| Name | Description |
|
||||
| ------------- | ---------------------------------------------------------------------------------------- |
|
||||
| `base_model` | A trained pipeline (package name or path) to use as a base. ~~str (positional)~~ |
|
||||
| `added_model` | A trained pipeline (package name or path) to combine with the base. ~~str (positional)~~ |
|
||||
| `output_file` | Path to output pipeline. ~~Path (positional)~~ |
|
||||
|
||||
## convert {id="convert",tag="command"}
|
||||
|
||||
Convert files into spaCy's
|
||||
|
|
Loading…
Reference in New Issue
Block a user