Code reorganization

This commit is contained in:
Paul O'Leary McCann 2023-02-10 15:12:40 +09:00
parent a76fd0da99
commit c0a3e9a44a
4 changed files with 55 additions and 53 deletions

View File

@ -29,7 +29,8 @@ from .project.push import project_push # noqa: F401
from .project.pull import project_pull # noqa: F401
from .project.document import project_document # noqa: F401
from .find_threshold import find_threshold # noqa: F401
from .configure import use_tok2vec, use_transformer # noqa: F401
from .configure import configure_tok2vec_feature_source # noqa: F401
from .configure import configure_transformer_feature_source # noqa: F401
from .configure import configure_resume_cli # noqa: F401
from .merge import merge_pipelines # noqa: F401

View File

@ -1,5 +1,4 @@
from pathlib import Path
import re
from wasabi import msg
import typer
from thinc.api import Config
@ -80,23 +79,6 @@ def _get_listeners(nlp: Language) -> List[str]:
return out
def _increment_suffix(name: str) -> str:
"""Given a name, return an incremented version.
If no numeric suffix is found, return the original with "2" appended.
This is used to avoid name collisions in pipelines.
"""
res = re.search(r"\d+$", name)
if res is None:
return f"{name}2"
else:
num = res.group()
prefix = name[0 : -len(num)]
return f"{prefix}{int(num) + 1}"
def _check_single_tok2vec(name: str, config: Config) -> None:
"""Check if there is just one tok2vec in a config.
@ -110,31 +92,6 @@ def _check_single_tok2vec(name: str, config: Config) -> None:
msg.fail(fail_msg, exits=1)
def _check_pipeline_names(nlp: Language, nlp2: Language) -> Dict[str, str]:
"""Given two pipelines, try to rename any collisions in component names.
If a simple increment of a numeric suffix doesn't work, will give up.
"""
fail_msg = """
Tried automatically renaming {name} to {new_name}, but still
had a collision, so bailing out. Please make your pipe names
unique.
"""
# map of components to be renamed
rename = {}
# check pipeline names
names = nlp.pipe_names
for name in nlp2.pipe_names:
if name in names:
inc = _increment_suffix(name)
if inc in names or inc in nlp2.pipe_names:
msg.fail(fail_msg.format(name=name, new_name=inc), exits=1)
rename[name] = inc
return rename
@configure_cli.command("resume")
def configure_resume_cli(
# fmt: off
@ -167,7 +124,7 @@ def configure_resume_cli(
@configure_cli.command("transformer")
def use_transformer(
def configure_transformer_feature_source(
base_model: str, output_file: Path, transformer_name: str = "roberta-base"
) -> Config:
"""Replace pipeline tok2vec with transformer."""
@ -218,7 +175,7 @@ def use_transformer(
@configure_cli.command("tok2vec")
def use_tok2vec(base_model: str, output_file: Path) -> Config:
def configure_tok2vec_feature_source(base_model: str, output_file: Path) -> Config:
"""Replace pipeline tok2vec with CNN tok2vec."""
nlp = spacy.load(base_model)
_check_single_tok2vec(base_model, nlp.config)

View File

@ -1,12 +1,55 @@
from pathlib import Path
import re
from wasabi import msg
import spacy
from spacy.language import Language
from ._util import app, Arg, Opt
from ._util import app, Arg, Opt, Dict
from .configure import _check_single_tok2vec, _get_listeners, _get_tok2vecs
from .configure import _check_pipeline_names, _has_listener
from .configure import _has_listener
def _increment_suffix(name: str) -> str:
"""Given a name, return an incremented version.
If no numeric suffix is found, return the original with "2" appended.
This is used to avoid name collisions in pipelines.
"""
res = re.search(r"\d+$", name)
if res is None:
return f"{name}2"
else:
num = res.group()
prefix = name[0 : -len(num)]
return f"{prefix}{int(num) + 1}"
def _make_unique_pipeline_names(nlp: Language, nlp2: Language) -> Dict[str, str]:
"""Given two pipelines, try to rename any collisions in component names.
If a simple increment of a numeric suffix doesn't work, will give up.
"""
fail_msg = """
Tried automatically renaming {name} to {new_name}, but still
had a collision, so bailing out. Please make your pipe names
unique.
"""
# map of components to be renamed
rename = {}
# check pipeline names
names = nlp.pipe_names
for name in nlp2.pipe_names:
if name in names:
inc = _increment_suffix(name)
if inc in names or inc in nlp2.pipe_names:
msg.fail(fail_msg.format(name=name, new_name=inc), exits=1)
rename[name] = inc
return rename
def _inner_merge(
@ -21,9 +64,9 @@ def _inner_merge(
returns: assembled pipeline.
"""
# we checked earlier, so there's definitely just one
# The outer merge already verified there was exactly one tok2vec
tok2vec_name = _get_tok2vecs(nlp2.config)[0]
rename = _check_pipeline_names(nlp, nlp2)
rename = _make_unique_pipeline_names(nlp, nlp2)
if len(_get_listeners(nlp2)) > 1:
if replace_listeners:

View File

@ -21,7 +21,8 @@ from spacy.cli._util import parse_config_overrides, string_to_list
from spacy.cli._util import substitute_project_variables
from spacy.cli._util import validate_project_commands
from spacy.cli._util import upload_file, download_file
from spacy.cli.configure import configure_resume_cli, use_tok2vec
from spacy.cli.configure import configure_resume_cli
from spacy.cli.configure import configure_tok2vec_feature_source
from spacy.cli.debug_data import _compile_gold, _get_labels_from_model
from spacy.cli.debug_data import _get_labels_from_spancat
from spacy.cli.debug_data import _get_distribution, _get_kl_divergence
@ -1206,7 +1207,7 @@ def test_configure_resume(tmp_path):
assert "source" in val, f"Non-sourced component: {comp}"
def test_use_tok2vec(tmp_path):
def test_configure_tok2vec_feature_source(tmp_path):
# Can't add a transformer here because spacy-transformers might not be present
nlp = spacy.blank("en")
nlp.add_pipe("tok2vec")
@ -1214,7 +1215,7 @@ def test_use_tok2vec(tmp_path):
nlp.to_disk(base_path)
out_path = tmp_path / "converted_to_tok2vec"
conf = use_tok2vec(base_path, out_path)
conf = configure_tok2vec_feature_source(base_path, out_path)
assert out_path.exists(), "No model saved"
assert "tok2vec" in conf["components"], "No tok2vec component"