diff --git a/spacy/cli/__init__.py b/spacy/cli/__init__.py index 4041e01b6..003859777 100644 --- a/spacy/cli/__init__.py +++ b/spacy/cli/__init__.py @@ -29,7 +29,8 @@ from .project.push import project_push # noqa: F401 from .project.pull import project_pull # noqa: F401 from .project.document import project_document # noqa: F401 from .find_threshold import find_threshold # noqa: F401 -from .configure import use_tok2vec, use_transformer # noqa: F401 +from .configure import configure_tok2vec_feature_source # noqa: F401 +from .configure import configure_transformer_feature_source # noqa: F401 from .configure import configure_resume_cli # noqa: F401 from .merge import merge_pipelines # noqa: F401 diff --git a/spacy/cli/configure.py b/spacy/cli/configure.py index c7637cf04..cc36924bc 100644 --- a/spacy/cli/configure.py +++ b/spacy/cli/configure.py @@ -1,5 +1,4 @@ from pathlib import Path -import re from wasabi import msg import typer from thinc.api import Config @@ -80,23 +79,6 @@ def _get_listeners(nlp: Language) -> List[str]: return out -def _increment_suffix(name: str) -> str: - """Given a name, return an incremented version. - - If no numeric suffix is found, return the original with "2" appended. - - This is used to avoid name collisions in pipelines. - """ - - res = re.search(r"\d+$", name) - if res is None: - return f"{name}2" - else: - num = res.group() - prefix = name[0 : -len(num)] - return f"{prefix}{int(num) + 1}" - - def _check_single_tok2vec(name: str, config: Config) -> None: """Check if there is just one tok2vec in a config. @@ -110,31 +92,6 @@ def _check_single_tok2vec(name: str, config: Config) -> None: msg.fail(fail_msg, exits=1) -def _check_pipeline_names(nlp: Language, nlp2: Language) -> Dict[str, str]: - """Given two pipelines, try to rename any collisions in component names. - - If a simple increment of a numeric suffix doesn't work, will give up. - """ - - fail_msg = """ - Tried automatically renaming {name} to {new_name}, but still - had a collision, so bailing out. Please make your pipe names - unique. - """ - - # map of components to be renamed - rename = {} - # check pipeline names - names = nlp.pipe_names - for name in nlp2.pipe_names: - if name in names: - inc = _increment_suffix(name) - if inc in names or inc in nlp2.pipe_names: - msg.fail(fail_msg.format(name=name, new_name=inc), exits=1) - rename[name] = inc - return rename - - @configure_cli.command("resume") def configure_resume_cli( # fmt: off @@ -167,7 +124,7 @@ def configure_resume_cli( @configure_cli.command("transformer") -def use_transformer( +def configure_transformer_feature_source( base_model: str, output_file: Path, transformer_name: str = "roberta-base" ) -> Config: """Replace pipeline tok2vec with transformer.""" @@ -218,7 +175,7 @@ def use_transformer( @configure_cli.command("tok2vec") -def use_tok2vec(base_model: str, output_file: Path) -> Config: +def configure_tok2vec_feature_source(base_model: str, output_file: Path) -> Config: """Replace pipeline tok2vec with CNN tok2vec.""" nlp = spacy.load(base_model) _check_single_tok2vec(base_model, nlp.config) diff --git a/spacy/cli/merge.py b/spacy/cli/merge.py index 528237dae..8ad909bfc 100644 --- a/spacy/cli/merge.py +++ b/spacy/cli/merge.py @@ -1,12 +1,55 @@ from pathlib import Path +import re from wasabi import msg import spacy from spacy.language import Language -from ._util import app, Arg, Opt +from ._util import app, Arg, Opt, Dict from .configure import _check_single_tok2vec, _get_listeners, _get_tok2vecs -from .configure import _check_pipeline_names, _has_listener +from .configure import _has_listener + + +def _increment_suffix(name: str) -> str: + """Given a name, return an incremented version. + + If no numeric suffix is found, return the original with "2" appended. + + This is used to avoid name collisions in pipelines. + """ + + res = re.search(r"\d+$", name) + if res is None: + return f"{name}2" + else: + num = res.group() + prefix = name[0 : -len(num)] + return f"{prefix}{int(num) + 1}" + + +def _make_unique_pipeline_names(nlp: Language, nlp2: Language) -> Dict[str, str]: + """Given two pipelines, try to rename any collisions in component names. + + If a simple increment of a numeric suffix doesn't work, will give up. + """ + + fail_msg = """ + Tried automatically renaming {name} to {new_name}, but still + had a collision, so bailing out. Please make your pipe names + unique. + """ + + # map of components to be renamed + rename = {} + # check pipeline names + names = nlp.pipe_names + for name in nlp2.pipe_names: + if name in names: + inc = _increment_suffix(name) + if inc in names or inc in nlp2.pipe_names: + msg.fail(fail_msg.format(name=name, new_name=inc), exits=1) + rename[name] = inc + return rename def _inner_merge( @@ -21,9 +64,9 @@ def _inner_merge( returns: assembled pipeline. """ - # we checked earlier, so there's definitely just one + # The outer merge already verified there was exactly one tok2vec tok2vec_name = _get_tok2vecs(nlp2.config)[0] - rename = _check_pipeline_names(nlp, nlp2) + rename = _make_unique_pipeline_names(nlp, nlp2) if len(_get_listeners(nlp2)) > 1: if replace_listeners: diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py index 242a3a885..7ede6339f 100644 --- a/spacy/tests/test_cli.py +++ b/spacy/tests/test_cli.py @@ -21,7 +21,8 @@ from spacy.cli._util import parse_config_overrides, string_to_list from spacy.cli._util import substitute_project_variables from spacy.cli._util import validate_project_commands from spacy.cli._util import upload_file, download_file -from spacy.cli.configure import configure_resume_cli, use_tok2vec +from spacy.cli.configure import configure_resume_cli +from spacy.cli.configure import configure_tok2vec_feature_source from spacy.cli.debug_data import _compile_gold, _get_labels_from_model from spacy.cli.debug_data import _get_labels_from_spancat from spacy.cli.debug_data import _get_distribution, _get_kl_divergence @@ -1206,7 +1207,7 @@ def test_configure_resume(tmp_path): assert "source" in val, f"Non-sourced component: {comp}" -def test_use_tok2vec(tmp_path): +def test_configure_tok2vec_feature_source(tmp_path): # Can't add a transformer here because spacy-transformers might not be present nlp = spacy.blank("en") nlp.add_pipe("tok2vec") @@ -1214,7 +1215,7 @@ def test_use_tok2vec(tmp_path): nlp.to_disk(base_path) out_path = tmp_path / "converted_to_tok2vec" - conf = use_tok2vec(base_path, out_path) + conf = configure_tok2vec_feature_source(base_path, out_path) assert out_path.exists(), "No model saved" assert "tok2vec" in conf["components"], "No tok2vec component"