mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-24 23:13:04 +03:00
Code reorganization
This commit is contained in:
parent
a76fd0da99
commit
c0a3e9a44a
|
@ -29,7 +29,8 @@ from .project.push import project_push # noqa: F401
|
||||||
from .project.pull import project_pull # noqa: F401
|
from .project.pull import project_pull # noqa: F401
|
||||||
from .project.document import project_document # noqa: F401
|
from .project.document import project_document # noqa: F401
|
||||||
from .find_threshold import find_threshold # noqa: F401
|
from .find_threshold import find_threshold # noqa: F401
|
||||||
from .configure import use_tok2vec, use_transformer # noqa: F401
|
from .configure import configure_tok2vec_feature_source # noqa: F401
|
||||||
|
from .configure import configure_transformer_feature_source # noqa: F401
|
||||||
from .configure import configure_resume_cli # noqa: F401
|
from .configure import configure_resume_cli # noqa: F401
|
||||||
from .merge import merge_pipelines # noqa: F401
|
from .merge import merge_pipelines # noqa: F401
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import re
|
|
||||||
from wasabi import msg
|
from wasabi import msg
|
||||||
import typer
|
import typer
|
||||||
from thinc.api import Config
|
from thinc.api import Config
|
||||||
|
@ -80,23 +79,6 @@ def _get_listeners(nlp: Language) -> List[str]:
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def _increment_suffix(name: str) -> str:
|
|
||||||
"""Given a name, return an incremented version.
|
|
||||||
|
|
||||||
If no numeric suffix is found, return the original with "2" appended.
|
|
||||||
|
|
||||||
This is used to avoid name collisions in pipelines.
|
|
||||||
"""
|
|
||||||
|
|
||||||
res = re.search(r"\d+$", name)
|
|
||||||
if res is None:
|
|
||||||
return f"{name}2"
|
|
||||||
else:
|
|
||||||
num = res.group()
|
|
||||||
prefix = name[0 : -len(num)]
|
|
||||||
return f"{prefix}{int(num) + 1}"
|
|
||||||
|
|
||||||
|
|
||||||
def _check_single_tok2vec(name: str, config: Config) -> None:
|
def _check_single_tok2vec(name: str, config: Config) -> None:
|
||||||
"""Check if there is just one tok2vec in a config.
|
"""Check if there is just one tok2vec in a config.
|
||||||
|
|
||||||
|
@ -110,31 +92,6 @@ def _check_single_tok2vec(name: str, config: Config) -> None:
|
||||||
msg.fail(fail_msg, exits=1)
|
msg.fail(fail_msg, exits=1)
|
||||||
|
|
||||||
|
|
||||||
def _check_pipeline_names(nlp: Language, nlp2: Language) -> Dict[str, str]:
    """Given two pipelines, try to rename any collisions in component names.

    Each component name in nlp2 that is also used in nlp gets a new name from
    _increment_suffix. If a simple increment of a numeric suffix doesn't work,
    will give up.

    nlp (Language): Pipeline whose component names are left as they are.
    nlp2 (Language): Pipeline whose colliding component names get renamed.
    RETURNS (Dict[str, str]): Map of current names in nlp2 to their new names.
    """

    fail_msg = """
        Tried automatically renaming {name} to {new_name}, but still
        had a collision, so bailing out. Please make your pipe names
        unique.
        """

    names = set(nlp.pipe_names)
    # Everything a new name must not clash with: the names in both pipelines,
    # plus the new names handed out so far. Without the latter, two different
    # components could both be renamed to the same name.
    taken = names | set(nlp2.pipe_names)
    # map of components to be renamed
    rename: Dict[str, str] = {}
    for name in nlp2.pipe_names:
        if name not in names:
            continue
        inc = _increment_suffix(name)
        if inc in taken:
            msg.fail(fail_msg.format(name=name, new_name=inc), exits=1)
        taken.add(inc)
        rename[name] = inc
    return rename
|
|
||||||
|
|
||||||
|
|
||||||
@configure_cli.command("resume")
|
@configure_cli.command("resume")
|
||||||
def configure_resume_cli(
|
def configure_resume_cli(
|
||||||
# fmt: off
|
# fmt: off
|
||||||
|
@ -167,7 +124,7 @@ def configure_resume_cli(
|
||||||
|
|
||||||
|
|
||||||
@configure_cli.command("transformer")
|
@configure_cli.command("transformer")
|
||||||
def use_transformer(
|
def configure_transformer_feature_source(
|
||||||
base_model: str, output_file: Path, transformer_name: str = "roberta-base"
|
base_model: str, output_file: Path, transformer_name: str = "roberta-base"
|
||||||
) -> Config:
|
) -> Config:
|
||||||
"""Replace pipeline tok2vec with transformer."""
|
"""Replace pipeline tok2vec with transformer."""
|
||||||
|
@ -218,7 +175,7 @@ def use_transformer(
|
||||||
|
|
||||||
|
|
||||||
@configure_cli.command("tok2vec")
|
@configure_cli.command("tok2vec")
|
||||||
def use_tok2vec(base_model: str, output_file: Path) -> Config:
|
def configure_tok2vec_feature_source(base_model: str, output_file: Path) -> Config:
|
||||||
"""Replace pipeline tok2vec with CNN tok2vec."""
|
"""Replace pipeline tok2vec with CNN tok2vec."""
|
||||||
nlp = spacy.load(base_model)
|
nlp = spacy.load(base_model)
|
||||||
_check_single_tok2vec(base_model, nlp.config)
|
_check_single_tok2vec(base_model, nlp.config)
|
||||||
|
|
|
@ -1,12 +1,55 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import re
|
||||||
from wasabi import msg
|
from wasabi import msg
|
||||||
|
|
||||||
import spacy
|
import spacy
|
||||||
from spacy.language import Language
|
from spacy.language import Language
|
||||||
|
|
||||||
from ._util import app, Arg, Opt
|
from ._util import app, Arg, Opt, Dict
|
||||||
from .configure import _check_single_tok2vec, _get_listeners, _get_tok2vecs
|
from .configure import _check_single_tok2vec, _get_listeners, _get_tok2vecs
|
||||||
from .configure import _check_pipeline_names, _has_listener
|
from .configure import _has_listener
|
||||||
|
|
||||||
|
|
||||||
|
def _increment_suffix(name: str) -> str:
|
||||||
|
"""Given a name, return an incremented version.
|
||||||
|
|
||||||
|
If no numeric suffix is found, return the original with "2" appended.
|
||||||
|
|
||||||
|
This is used to avoid name collisions in pipelines.
|
||||||
|
"""
|
||||||
|
|
||||||
|
res = re.search(r"\d+$", name)
|
||||||
|
if res is None:
|
||||||
|
return f"{name}2"
|
||||||
|
else:
|
||||||
|
num = res.group()
|
||||||
|
prefix = name[0 : -len(num)]
|
||||||
|
return f"{prefix}{int(num) + 1}"
|
||||||
|
|
||||||
|
|
||||||
|
def _make_unique_pipeline_names(nlp: Language, nlp2: Language) -> Dict[str, str]:
    """Given two pipelines, try to rename any collisions in component names.

    Each component name in nlp2 that is also used in nlp gets a new name from
    _increment_suffix. If a simple increment of a numeric suffix doesn't work,
    will give up.

    nlp (Language): Pipeline whose component names are left as they are.
    nlp2 (Language): Pipeline whose colliding component names get renamed.
    RETURNS (Dict[str, str]): Map of current names in nlp2 to their new names.
    """

    fail_msg = """
        Tried automatically renaming {name} to {new_name}, but still
        had a collision, so bailing out. Please make your pipe names
        unique.
        """

    names = set(nlp.pipe_names)
    # Everything a new name must not clash with: the names in both pipelines,
    # plus the new names handed out so far. Without the latter, two different
    # components could both be renamed to the same name.
    taken = names | set(nlp2.pipe_names)
    # map of components to be renamed
    rename: Dict[str, str] = {}
    for name in nlp2.pipe_names:
        if name not in names:
            continue
        inc = _increment_suffix(name)
        if inc in taken:
            msg.fail(fail_msg.format(name=name, new_name=inc), exits=1)
        taken.add(inc)
        rename[name] = inc
    return rename
|
||||||
|
|
||||||
|
|
||||||
def _inner_merge(
|
def _inner_merge(
|
||||||
|
@ -21,9 +64,9 @@ def _inner_merge(
|
||||||
returns: assembled pipeline.
|
returns: assembled pipeline.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# we checked earlier, so there's definitely just one
|
# The outer merge already verified there was exactly one tok2vec
|
||||||
tok2vec_name = _get_tok2vecs(nlp2.config)[0]
|
tok2vec_name = _get_tok2vecs(nlp2.config)[0]
|
||||||
rename = _check_pipeline_names(nlp, nlp2)
|
rename = _make_unique_pipeline_names(nlp, nlp2)
|
||||||
|
|
||||||
if len(_get_listeners(nlp2)) > 1:
|
if len(_get_listeners(nlp2)) > 1:
|
||||||
if replace_listeners:
|
if replace_listeners:
|
||||||
|
|
|
@ -21,7 +21,8 @@ from spacy.cli._util import parse_config_overrides, string_to_list
|
||||||
from spacy.cli._util import substitute_project_variables
|
from spacy.cli._util import substitute_project_variables
|
||||||
from spacy.cli._util import validate_project_commands
|
from spacy.cli._util import validate_project_commands
|
||||||
from spacy.cli._util import upload_file, download_file
|
from spacy.cli._util import upload_file, download_file
|
||||||
from spacy.cli.configure import configure_resume_cli, use_tok2vec
|
from spacy.cli.configure import configure_resume_cli
|
||||||
|
from spacy.cli.configure import configure_tok2vec_feature_source
|
||||||
from spacy.cli.debug_data import _compile_gold, _get_labels_from_model
|
from spacy.cli.debug_data import _compile_gold, _get_labels_from_model
|
||||||
from spacy.cli.debug_data import _get_labels_from_spancat
|
from spacy.cli.debug_data import _get_labels_from_spancat
|
||||||
from spacy.cli.debug_data import _get_distribution, _get_kl_divergence
|
from spacy.cli.debug_data import _get_distribution, _get_kl_divergence
|
||||||
|
@ -1206,7 +1207,7 @@ def test_configure_resume(tmp_path):
|
||||||
assert "source" in val, f"Non-sourced component: {comp}"
|
assert "source" in val, f"Non-sourced component: {comp}"
|
||||||
|
|
||||||
|
|
||||||
def test_use_tok2vec(tmp_path):
|
def test_configure_tok2vec_feature_source(tmp_path):
|
||||||
# Can't add a transformer here because spacy-transformers might not be present
|
# Can't add a transformer here because spacy-transformers might not be present
|
||||||
nlp = spacy.blank("en")
|
nlp = spacy.blank("en")
|
||||||
nlp.add_pipe("tok2vec")
|
nlp.add_pipe("tok2vec")
|
||||||
|
@ -1214,7 +1215,7 @@ def test_use_tok2vec(tmp_path):
|
||||||
nlp.to_disk(base_path)
|
nlp.to_disk(base_path)
|
||||||
|
|
||||||
out_path = tmp_path / "converted_to_tok2vec"
|
out_path = tmp_path / "converted_to_tok2vec"
|
||||||
conf = use_tok2vec(base_path, out_path)
|
conf = configure_tok2vec_feature_source(base_path, out_path)
|
||||||
assert out_path.exists(), "No model saved"
|
assert out_path.exists(), "No model saved"
|
||||||
|
|
||||||
assert "tok2vec" in conf["components"], "No tok2vec component"
|
assert "tok2vec" in conf["components"], "No tok2vec component"
|
||||||
|
|
Loading…
Reference in New Issue
Block a user