mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-05 05:33:15 +03:00
Code reorganization
This commit is contained in:
parent
a76fd0da99
commit
c0a3e9a44a
|
@ -29,7 +29,8 @@ from .project.push import project_push # noqa: F401
|
|||
from .project.pull import project_pull # noqa: F401
|
||||
from .project.document import project_document # noqa: F401
|
||||
from .find_threshold import find_threshold # noqa: F401
|
||||
from .configure import use_tok2vec, use_transformer # noqa: F401
|
||||
from .configure import configure_tok2vec_feature_source # noqa: F401
|
||||
from .configure import configure_transformer_feature_source # noqa: F401
|
||||
from .configure import configure_resume_cli # noqa: F401
|
||||
from .merge import merge_pipelines # noqa: F401
|
||||
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
from pathlib import Path
|
||||
import re
|
||||
from wasabi import msg
|
||||
import typer
|
||||
from thinc.api import Config
|
||||
|
@ -80,23 +79,6 @@ def _get_listeners(nlp: Language) -> List[str]:
|
|||
return out
|
||||
|
||||
|
||||
def _increment_suffix(name: str) -> str:
|
||||
"""Given a name, return an incremented version.
|
||||
|
||||
If no numeric suffix is found, return the original with "2" appended.
|
||||
|
||||
This is used to avoid name collisions in pipelines.
|
||||
"""
|
||||
|
||||
res = re.search(r"\d+$", name)
|
||||
if res is None:
|
||||
return f"{name}2"
|
||||
else:
|
||||
num = res.group()
|
||||
prefix = name[0 : -len(num)]
|
||||
return f"{prefix}{int(num) + 1}"
|
||||
|
||||
|
||||
def _check_single_tok2vec(name: str, config: Config) -> None:
|
||||
"""Check if there is just one tok2vec in a config.
|
||||
|
||||
|
@ -110,31 +92,6 @@ def _check_single_tok2vec(name: str, config: Config) -> None:
|
|||
msg.fail(fail_msg, exits=1)
|
||||
|
||||
|
||||
def _check_pipeline_names(nlp: Language, nlp2: Language) -> Dict[str, str]:
|
||||
"""Given two pipelines, try to rename any collisions in component names.
|
||||
|
||||
If a simple increment of a numeric suffix doesn't work, will give up.
|
||||
"""
|
||||
|
||||
fail_msg = """
|
||||
Tried automatically renaming {name} to {new_name}, but still
|
||||
had a collision, so bailing out. Please make your pipe names
|
||||
unique.
|
||||
"""
|
||||
|
||||
# map of components to be renamed
|
||||
rename = {}
|
||||
# check pipeline names
|
||||
names = nlp.pipe_names
|
||||
for name in nlp2.pipe_names:
|
||||
if name in names:
|
||||
inc = _increment_suffix(name)
|
||||
if inc in names or inc in nlp2.pipe_names:
|
||||
msg.fail(fail_msg.format(name=name, new_name=inc), exits=1)
|
||||
rename[name] = inc
|
||||
return rename
|
||||
|
||||
|
||||
@configure_cli.command("resume")
|
||||
def configure_resume_cli(
|
||||
# fmt: off
|
||||
|
@ -167,7 +124,7 @@ def configure_resume_cli(
|
|||
|
||||
|
||||
@configure_cli.command("transformer")
|
||||
def use_transformer(
|
||||
def configure_transformer_feature_source(
|
||||
base_model: str, output_file: Path, transformer_name: str = "roberta-base"
|
||||
) -> Config:
|
||||
"""Replace pipeline tok2vec with transformer."""
|
||||
|
@ -218,7 +175,7 @@ def use_transformer(
|
|||
|
||||
|
||||
@configure_cli.command("tok2vec")
|
||||
def use_tok2vec(base_model: str, output_file: Path) -> Config:
|
||||
def configure_tok2vec_feature_source(base_model: str, output_file: Path) -> Config:
|
||||
"""Replace pipeline tok2vec with CNN tok2vec."""
|
||||
nlp = spacy.load(base_model)
|
||||
_check_single_tok2vec(base_model, nlp.config)
|
||||
|
|
|
@ -1,12 +1,55 @@
|
|||
from pathlib import Path
|
||||
import re
|
||||
from wasabi import msg
|
||||
|
||||
import spacy
|
||||
from spacy.language import Language
|
||||
|
||||
from ._util import app, Arg, Opt
|
||||
from ._util import app, Arg, Opt, Dict
|
||||
from .configure import _check_single_tok2vec, _get_listeners, _get_tok2vecs
|
||||
from .configure import _check_pipeline_names, _has_listener
|
||||
from .configure import _has_listener
|
||||
|
||||
|
||||
def _increment_suffix(name: str) -> str:
|
||||
"""Given a name, return an incremented version.
|
||||
|
||||
If no numeric suffix is found, return the original with "2" appended.
|
||||
|
||||
This is used to avoid name collisions in pipelines.
|
||||
"""
|
||||
|
||||
res = re.search(r"\d+$", name)
|
||||
if res is None:
|
||||
return f"{name}2"
|
||||
else:
|
||||
num = res.group()
|
||||
prefix = name[0 : -len(num)]
|
||||
return f"{prefix}{int(num) + 1}"
|
||||
|
||||
|
||||
def _make_unique_pipeline_names(nlp: Language, nlp2: Language) -> Dict[str, str]:
|
||||
"""Given two pipelines, try to rename any collisions in component names.
|
||||
|
||||
If a simple increment of a numeric suffix doesn't work, will give up.
|
||||
"""
|
||||
|
||||
fail_msg = """
|
||||
Tried automatically renaming {name} to {new_name}, but still
|
||||
had a collision, so bailing out. Please make your pipe names
|
||||
unique.
|
||||
"""
|
||||
|
||||
# map of components to be renamed
|
||||
rename = {}
|
||||
# check pipeline names
|
||||
names = nlp.pipe_names
|
||||
for name in nlp2.pipe_names:
|
||||
if name in names:
|
||||
inc = _increment_suffix(name)
|
||||
if inc in names or inc in nlp2.pipe_names:
|
||||
msg.fail(fail_msg.format(name=name, new_name=inc), exits=1)
|
||||
rename[name] = inc
|
||||
return rename
|
||||
|
||||
|
||||
def _inner_merge(
|
||||
|
@ -21,9 +64,9 @@ def _inner_merge(
|
|||
returns: assembled pipeline.
|
||||
"""
|
||||
|
||||
# we checked earlier, so there's definitely just one
|
||||
# The outer merge already verified there was exactly one tok2vec
|
||||
tok2vec_name = _get_tok2vecs(nlp2.config)[0]
|
||||
rename = _check_pipeline_names(nlp, nlp2)
|
||||
rename = _make_unique_pipeline_names(nlp, nlp2)
|
||||
|
||||
if len(_get_listeners(nlp2)) > 1:
|
||||
if replace_listeners:
|
||||
|
|
|
@ -21,7 +21,8 @@ from spacy.cli._util import parse_config_overrides, string_to_list
|
|||
from spacy.cli._util import substitute_project_variables
|
||||
from spacy.cli._util import validate_project_commands
|
||||
from spacy.cli._util import upload_file, download_file
|
||||
from spacy.cli.configure import configure_resume_cli, use_tok2vec
|
||||
from spacy.cli.configure import configure_resume_cli
|
||||
from spacy.cli.configure import configure_tok2vec_feature_source
|
||||
from spacy.cli.debug_data import _compile_gold, _get_labels_from_model
|
||||
from spacy.cli.debug_data import _get_labels_from_spancat
|
||||
from spacy.cli.debug_data import _get_distribution, _get_kl_divergence
|
||||
|
@ -1206,7 +1207,7 @@ def test_configure_resume(tmp_path):
|
|||
assert "source" in val, f"Non-sourced component: {comp}"
|
||||
|
||||
|
||||
def test_use_tok2vec(tmp_path):
|
||||
def test_configure_tok2vec_feature_source(tmp_path):
|
||||
# Can't add a transformer here because spacy-transformers might not be present
|
||||
nlp = spacy.blank("en")
|
||||
nlp.add_pipe("tok2vec")
|
||||
|
@ -1214,7 +1215,7 @@ def test_use_tok2vec(tmp_path):
|
|||
nlp.to_disk(base_path)
|
||||
|
||||
out_path = tmp_path / "converted_to_tok2vec"
|
||||
conf = use_tok2vec(base_path, out_path)
|
||||
conf = configure_tok2vec_feature_source(base_path, out_path)
|
||||
assert out_path.exists(), "No model saved"
|
||||
|
||||
assert "tok2vec" in conf["components"], "No tok2vec component"
|
||||
|
|
Loading…
Reference in New Issue
Block a user