spaCy/spacy/cli/configure.py
Paul O'Leary McCann c0a3e9a44a Code reorganization
2023-02-10 15:12:40 +09:00

207 lines
5.9 KiB
Python

from pathlib import Path
from wasabi import msg
import typer
from thinc.api import Config
from typing import Any, Dict, Iterable, List, Union
import spacy
from spacy.language import Language
from ._util import configure_cli, Arg, Opt
# Architectures treated as feature sources (tok2vec or transformer), written
# as (namespace, model name) pairs with the version suffix left off.
TOK2VEC_ARCHS = [
    ("spacy", "Tok2Vec"),
    ("spacy", "HashEmbedCNN"),
    ("spacy-transformers", "TransformerModel"),
]
# Architectures treated as listeners, in the same (namespace, model name) form.
LISTENER_ARCHS = [
    ("spacy", "Tok2VecListener"),
    ("spacy-transformers", "TransformerListener"),
]
def _deep_get(
obj: Union[Dict[str, Any], Config], key: Iterable[str], default: Any
) -> Any:
"""Given a multi-part key, try to get the key. If at any point this isn't
possible, return the default.
"""
out = None
slot = obj
for notch in key:
if slot is None or notch not in slot:
return default
slot = slot[notch]
return slot
def _get_tok2vecs(config: Config) -> List[str]:
    """Return the names of components that are tok2vecs (or Transformers).

    A component counts as a feature source if its model architecture, with
    the version suffix dropped, is listed in TOK2VEC_ARCHS.

    config (Config): A pipeline config with a "components" section.
    RETURNS (List[str]): Names of the matching components, in config order.
    """
    out = []
    for name, comp in config["components"].items():
        arch = _deep_get(comp, ("model", "@architectures"), False)
        if not arch:
            continue
        # Every architecture in TOK2VEC_ARCHS has the three-part form
        # "namespace.Model.version". A name with a different number of parts
        # can't match any of them, so skip it instead of failing to unpack.
        parts = arch.split(".")
        if len(parts) != 3:
            continue
        ns, model, _ = parts
        if (ns, model) in TOK2VEC_ARCHS:
            out.append(name)
    return out
def _has_listener(nlp: Language, pipe_name: str) -> bool:
    """Check whether a pipeline component has a listener.

    A component has a listener if the architecture of its model's "tok2vec"
    sublayer, with the version suffix dropped, is listed in LISTENER_ARCHS.

    nlp (Language): The pipeline whose config is inspected.
    pipe_name (str): Name of the component to check.
    RETURNS (bool): True if the component's tok2vec sublayer is a listener.
    """
    arch = _deep_get(
        nlp.config,
        ("components", pipe_name, "model", "tok2vec", "@architectures"),
        False,
    )
    if not arch:
        return False
    # Every architecture in LISTENER_ARCHS has the three-part form
    # "namespace.Model.version". A name with a different number of parts
    # can't match any of them, so report no listener instead of failing to
    # unpack.
    parts = arch.split(".")
    if len(parts) != 3:
        return False
    ns, model, _ = parts
    return (ns, model) in LISTENER_ARCHS
def _get_listeners(nlp: Language) -> List[str]:
    """Collect the names of all components that contain a listener.

    Components are returned in pipeline order. No check is made that the
    listeners share an upstream component; the pipeline is assumed to have a
    single feature source.
    """
    return [pipe for pipe in nlp.pipe_names if _has_listener(nlp, pipe)]
def _check_single_tok2vec(name: str, config: Config) -> None:
    """Fail if a config contains more than one tok2vec.

    The check is simple, but it is shared by several commands. `name` is only
    used to identify the pipeline in the failure message.
    """
    source_count = len(_get_tok2vecs(config))
    if source_count <= 1:
        return
    problem = (
        "\nCan't handle pipelines with more than one feature source,\n"
        f"but {name} has {source_count}."
    )
    msg.fail(problem, exits=1)
@configure_cli.command("resume")
def configure_resume_cli(
    # fmt: off
    base_model: Path = Arg(..., help="Path or name of base model to use for config"),
    output_file: Path = Arg(..., help="File to save the config to or - for stdout (will only output config and no additional logging info)", allow_dash=True),
    # fmt: on
) -> Config:
    """Create a config for resuming training.

    The result is the base pipeline's config with every component replaced by
    a reference that sources it from the base model. The config is printed
    when the output file is "-", otherwise written to disk, and is returned
    either way.
    """
    pipeline = spacy.load(base_model)
    config = pipeline.config
    # A Path is not JSON serializable, so store the source as a string.
    source = str(base_model)
    sourced = {pipe: {"source": source} for pipe in pipeline.pipe_names}
    config["components"].update(sourced)

    to_stdout = str(output_file) == "-"
    if not to_stdout:
        config.to_disk(output_file)
        msg.good("Saved config", output_file)
    else:
        print(config.to_str())
    return config
@configure_cli.command("transformer")
def configure_transformer_feature_source(
    base_model: str, output_file: Path, transformer_name: str = "roberta-base"
) -> Config:
    """Replace pipeline tok2vec with transformer.

    Loads the base pipeline, removes its feature source, adds a "transformer"
    component at the start of the pipeline, and points every listener at it.
    The command fails via msg.fail (exits=1) if the pipeline has more than
    one feature source, has none, or if the transformer component can't be
    added.

    base_model (str): Path or name of the pipeline to load.
    output_file (Path): File to save the config to, or "-" to print it.
    transformer_name (str): Model name passed to the transformer component.
    RETURNS (Config): The updated pipeline config.
    """
    # 1. identify tok2vec
    # 2. replace tok2vec
    # 3. replace listeners
    nlp = spacy.load(base_model)
    _check_single_tok2vec(base_model, nlp.config)
    tok2vecs = _get_tok2vecs(nlp.config)
    if not tok2vecs:
        # An assert would be stripped under `python -O` and is not a
        # user-facing error, so report this the same way as other failures.
        msg.fail("Must have tok2vec to replace!", exits=1)
    nlp.remove_pipe(tok2vecs[0])

    # Only the model name is set; the rest can be default values.
    trf_config = {
        "model": {
            "name": transformer_name,
        }
    }
    try:
        nlp.add_pipe("transformer", config=trf_config, first=True)
    except ValueError:
        fail_msg = (
            "Configuring a transformer requires spacy-transformers. "
            "Install with: pip install spacy-transformers"
        )
        msg.fail(fail_msg, exits=1)

    # Now update the listeners. Each one gets its own config dict.
    for listener in _get_listeners(nlp):
        listener_config = {
            "@architectures": "spacy-transformers.TransformerListener.v1",
            "grad_factor": 1.0,
            "upstream": "transformer",
            "pooling": {"@layers": "reduce_mean.v1"},
        }
        nlp.config["components"][listener]["model"]["tok2vec"] = listener_config

    if str(output_file) == "-":
        print(nlp.config.to_str())
    else:
        nlp.config.to_disk(output_file)
        msg.good("Saved config", output_file)
    return nlp.config
@configure_cli.command("tok2vec")
def configure_tok2vec_feature_source(base_model: str, output_file: Path) -> Config:
    """Replace pipeline tok2vec with CNN tok2vec.

    Loads the base pipeline, removes its feature source, adds a "tok2vec"
    component with default settings at the start of the pipeline, and points
    every listener at it. The command fails via msg.fail (exits=1) if the
    pipeline has more than one feature source or has none.

    base_model (str): Path or name of the pipeline to load.
    output_file (Path): File to save the config to, or "-" to print it.
    RETURNS (Config): The updated pipeline config.
    """
    nlp = spacy.load(base_model)
    _check_single_tok2vec(base_model, nlp.config)
    tok2vecs = _get_tok2vecs(nlp.config)
    if not tok2vecs:
        # An assert would be stripped under `python -O` and is not a
        # user-facing error, so report this the same way as other failures.
        msg.fail("Must have tok2vec to replace!", exits=1)
    nlp.remove_pipe(tok2vecs[0])
    nlp.add_pipe("tok2vec", first=True)

    # Config interpolation: listeners take their width from the new tok2vec
    # component's settings rather than from a hard-coded number.
    width = "${components.tok2vec.model.encode:width}"
    # Each listener gets its own config dict.
    for listener in _get_listeners(nlp):
        listener_config = {
            "@architectures": "spacy.Tok2VecListener.v1",
            "width": width,
            "upstream": "tok2vec",
        }
        nlp.config["components"][listener]["model"]["tok2vec"] = listener_config

    if str(output_file) == "-":
        print(nlp.config.to_str())
    else:
        nlp.config.to_disk(output_file)
        msg.good("Saved config", output_file)
    return nlp.config