mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-27 17:54:39 +03:00
fix spancat initialize with labels (#8620)
This commit is contained in:
parent
608fc1d623
commit
733e8ceea9
|
@ -108,6 +108,10 @@ def init_labels_cli(
|
||||||
config = util.load_config(config_path, overrides=overrides)
|
config = util.load_config(config_path, overrides=overrides)
|
||||||
with show_validation_error(hint_fill=False):
|
with show_validation_error(hint_fill=False):
|
||||||
nlp = init_nlp(config, use_gpu=use_gpu)
|
nlp = init_nlp(config, use_gpu=use_gpu)
|
||||||
|
_init_labels(nlp, output_path)
|
||||||
|
|
||||||
|
|
||||||
|
def _init_labels(nlp, output_path):
|
||||||
for name, component in nlp.pipeline:
|
for name, component in nlp.pipeline:
|
||||||
if getattr(component, "label_data", None) is not None:
|
if getattr(component, "label_data", None) is not None:
|
||||||
output_file = output_path / f"{name}.json"
|
output_file = output_path / f"{name}.json"
|
||||||
|
|
|
@ -329,7 +329,7 @@ class SpanCategorizer(TrainablePipe):
|
||||||
get_examples: Callable[[], Iterable[Example]],
|
get_examples: Callable[[], Iterable[Example]],
|
||||||
*,
|
*,
|
||||||
nlp: Language = None,
|
nlp: Language = None,
|
||||||
labels: Optional[Dict] = None,
|
labels: Optional[List[str]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the pipe for training, using a representative set
|
"""Initialize the pipe for training, using a representative set
|
||||||
of data examples.
|
of data examples.
|
||||||
|
|
|
@ -19,6 +19,7 @@ import srsly
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from .util import make_tempdir
|
from .util import make_tempdir
|
||||||
|
from ..cli.init_pipeline import _init_labels
|
||||||
|
|
||||||
|
|
||||||
def test_cli_info():
|
def test_cli_info():
|
||||||
|
@ -501,3 +502,33 @@ def test_validate_compatibility_table():
|
||||||
current_compat = compat.get(spacy_version, {})
|
current_compat = compat.get(spacy_version, {})
|
||||||
assert len(current_compat) > 0
|
assert len(current_compat) > 0
|
||||||
assert "en_core_web_sm" in current_compat
|
assert "en_core_web_sm" in current_compat
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("component_name", ["ner", "textcat", "spancat", "tagger"])
def test_init_labels(component_name):
    """Check that labels exported by ``_init_labels`` can be read back.

    Four labels are added to a component of a blank Dutch pipeline and
    exported with ``_init_labels``. A second pipeline, whose ``initialize``
    config points the same component at the exported file through the
    ``spacy.read_labels.v1`` reader, must have no labels before
    ``initialize()`` and four labels afterwards.
    """
    # Source pipeline: a single component carrying four labels.
    source_nlp = Dutch()
    pipe = source_nlp.add_pipe(component_name)
    for tag in ["T1", "T2", "T3", "T4"]:
        pipe.add_label(tag)
    assert len(source_nlp.get_pipe(component_name).labels) == 4

    # Everything below stays inside the temp-dir context, since the label
    # file is referenced by path until ``initialize()`` has run.
    with make_tempdir() as tmp_dir:
        _init_labels(source_nlp, tmp_dir)

        config = init_config(
            lang="nl", pipeline=[component_name], optimize="efficiency", gpu=False
        )
        # Point the component's initialization at the exported label file.
        label_reader = {
            "@readers": "spacy.read_labels.v1",
            "path": f"{tmp_dir}/{component_name}.json",
        }
        config["initialize"]["components"][component_name] = {"labels": label_reader}

        reloaded_nlp = load_model_from_config(config, auto_fill=True)
        # Labels must only appear once the pipeline has been initialized.
        assert len(reloaded_nlp.get_pipe(component_name).labels) == 0
        reloaded_nlp.initialize()
        assert len(reloaded_nlp.get_pipe(component_name).labels) == 4
|
||||||
|
|
Loading…
Reference in New Issue
Block a user