Code reorganization

This commit is contained in:
Paul O'Leary McCann 2023-02-10 15:12:40 +09:00
parent a76fd0da99
commit c0a3e9a44a
4 changed files with 55 additions and 53 deletions

View File

@ -29,7 +29,8 @@ from .project.push import project_push # noqa: F401
from .project.pull import project_pull # noqa: F401 from .project.pull import project_pull # noqa: F401
from .project.document import project_document # noqa: F401 from .project.document import project_document # noqa: F401
from .find_threshold import find_threshold # noqa: F401 from .find_threshold import find_threshold # noqa: F401
from .configure import use_tok2vec, use_transformer # noqa: F401 from .configure import configure_tok2vec_feature_source # noqa: F401
from .configure import configure_transformer_feature_source # noqa: F401
from .configure import configure_resume_cli # noqa: F401 from .configure import configure_resume_cli # noqa: F401
from .merge import merge_pipelines # noqa: F401 from .merge import merge_pipelines # noqa: F401

View File

@ -1,5 +1,4 @@
from pathlib import Path from pathlib import Path
import re
from wasabi import msg from wasabi import msg
import typer import typer
from thinc.api import Config from thinc.api import Config
@ -80,23 +79,6 @@ def _get_listeners(nlp: Language) -> List[str]:
return out return out
def _increment_suffix(name: str) -> str:
"""Given a name, return an incremented version.
If no numeric suffix is found, return the original with "2" appended.
This is used to avoid name collisions in pipelines.
"""
res = re.search(r"\d+$", name)
if res is None:
return f"{name}2"
else:
num = res.group()
prefix = name[0 : -len(num)]
return f"{prefix}{int(num) + 1}"
def _check_single_tok2vec(name: str, config: Config) -> None: def _check_single_tok2vec(name: str, config: Config) -> None:
"""Check if there is just one tok2vec in a config. """Check if there is just one tok2vec in a config.
@ -110,31 +92,6 @@ def _check_single_tok2vec(name: str, config: Config) -> None:
msg.fail(fail_msg, exits=1) msg.fail(fail_msg, exits=1)
def _check_pipeline_names(nlp: Language, nlp2: Language) -> Dict[str, str]:
"""Given two pipelines, try to rename any collisions in component names.
If a simple increment of a numeric suffix doesn't work, will give up.
"""
fail_msg = """
Tried automatically renaming {name} to {new_name}, but still
had a collision, so bailing out. Please make your pipe names
unique.
"""
# map of components to be renamed
rename = {}
# check pipeline names
names = nlp.pipe_names
for name in nlp2.pipe_names:
if name in names:
inc = _increment_suffix(name)
if inc in names or inc in nlp2.pipe_names:
msg.fail(fail_msg.format(name=name, new_name=inc), exits=1)
rename[name] = inc
return rename
@configure_cli.command("resume") @configure_cli.command("resume")
def configure_resume_cli( def configure_resume_cli(
# fmt: off # fmt: off
@ -167,7 +124,7 @@ def configure_resume_cli(
@configure_cli.command("transformer") @configure_cli.command("transformer")
def use_transformer( def configure_transformer_feature_source(
base_model: str, output_file: Path, transformer_name: str = "roberta-base" base_model: str, output_file: Path, transformer_name: str = "roberta-base"
) -> Config: ) -> Config:
"""Replace pipeline tok2vec with transformer.""" """Replace pipeline tok2vec with transformer."""
@ -218,7 +175,7 @@ def use_transformer(
@configure_cli.command("tok2vec") @configure_cli.command("tok2vec")
def use_tok2vec(base_model: str, output_file: Path) -> Config: def configure_tok2vec_feature_source(base_model: str, output_file: Path) -> Config:
"""Replace pipeline tok2vec with CNN tok2vec.""" """Replace pipeline tok2vec with CNN tok2vec."""
nlp = spacy.load(base_model) nlp = spacy.load(base_model)
_check_single_tok2vec(base_model, nlp.config) _check_single_tok2vec(base_model, nlp.config)

View File

@ -1,12 +1,55 @@
from pathlib import Path from pathlib import Path
import re
from wasabi import msg from wasabi import msg
import spacy import spacy
from spacy.language import Language from spacy.language import Language
from ._util import app, Arg, Opt from ._util import app, Arg, Opt, Dict
from .configure import _check_single_tok2vec, _get_listeners, _get_tok2vecs from .configure import _check_single_tok2vec, _get_listeners, _get_tok2vecs
from .configure import _check_pipeline_names, _has_listener from .configure import _has_listener
def _increment_suffix(name: str) -> str:
"""Given a name, return an incremented version.
If no numeric suffix is found, return the original with "2" appended.
This is used to avoid name collisions in pipelines.
"""
res = re.search(r"\d+$", name)
if res is None:
return f"{name}2"
else:
num = res.group()
prefix = name[0 : -len(num)]
return f"{prefix}{int(num) + 1}"
def _make_unique_pipeline_names(nlp: Language, nlp2: Language) -> Dict[str, str]:
"""Given two pipelines, try to rename any collisions in component names.
If a simple increment of a numeric suffix doesn't work, will give up.
"""
fail_msg = """
Tried automatically renaming {name} to {new_name}, but still
had a collision, so bailing out. Please make your pipe names
unique.
"""
# map of components to be renamed
rename = {}
# check pipeline names
names = nlp.pipe_names
for name in nlp2.pipe_names:
if name in names:
inc = _increment_suffix(name)
if inc in names or inc in nlp2.pipe_names:
msg.fail(fail_msg.format(name=name, new_name=inc), exits=1)
rename[name] = inc
return rename
def _inner_merge( def _inner_merge(
@ -21,9 +64,9 @@ def _inner_merge(
returns: assembled pipeline. returns: assembled pipeline.
""" """
# we checked earlier, so there's definitely just one # The outer merge already verified there was exactly one tok2vec
tok2vec_name = _get_tok2vecs(nlp2.config)[0] tok2vec_name = _get_tok2vecs(nlp2.config)[0]
rename = _check_pipeline_names(nlp, nlp2) rename = _make_unique_pipeline_names(nlp, nlp2)
if len(_get_listeners(nlp2)) > 1: if len(_get_listeners(nlp2)) > 1:
if replace_listeners: if replace_listeners:

View File

@ -21,7 +21,8 @@ from spacy.cli._util import parse_config_overrides, string_to_list
from spacy.cli._util import substitute_project_variables from spacy.cli._util import substitute_project_variables
from spacy.cli._util import validate_project_commands from spacy.cli._util import validate_project_commands
from spacy.cli._util import upload_file, download_file from spacy.cli._util import upload_file, download_file
from spacy.cli.configure import configure_resume_cli, use_tok2vec from spacy.cli.configure import configure_resume_cli
from spacy.cli.configure import configure_tok2vec_feature_source
from spacy.cli.debug_data import _compile_gold, _get_labels_from_model from spacy.cli.debug_data import _compile_gold, _get_labels_from_model
from spacy.cli.debug_data import _get_labels_from_spancat from spacy.cli.debug_data import _get_labels_from_spancat
from spacy.cli.debug_data import _get_distribution, _get_kl_divergence from spacy.cli.debug_data import _get_distribution, _get_kl_divergence
@ -1206,7 +1207,7 @@ def test_configure_resume(tmp_path):
assert "source" in val, f"Non-sourced component: {comp}" assert "source" in val, f"Non-sourced component: {comp}"
def test_use_tok2vec(tmp_path): def test_configure_tok2vec_feature_source(tmp_path):
# Can't add a transformer here because spacy-transformers might not be present # Can't add a transformer here because spacy-transformers might not be present
nlp = spacy.blank("en") nlp = spacy.blank("en")
nlp.add_pipe("tok2vec") nlp.add_pipe("tok2vec")
@ -1214,7 +1215,7 @@ def test_use_tok2vec(tmp_path):
nlp.to_disk(base_path) nlp.to_disk(base_path)
out_path = tmp_path / "converted_to_tok2vec" out_path = tmp_path / "converted_to_tok2vec"
conf = use_tok2vec(base_path, out_path) conf = configure_tok2vec_feature_source(base_path, out_path)
assert out_path.exists(), "No model saved" assert out_path.exists(), "No model saved"
assert "tok2vec" in conf["components"], "No tok2vec component" assert "tok2vec" in conf["components"], "No tok2vec component"