fix spancat initialize with labels (#8620)

This commit is contained in:
Sofie Van Landeghem 2021-07-06 19:08:25 +02:00 committed by GitHub
parent 608fc1d623
commit 733e8ceea9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 36 additions and 1 deletions

View File

@ -108,6 +108,10 @@ def init_labels_cli(
config = util.load_config(config_path, overrides=overrides) config = util.load_config(config_path, overrides=overrides)
with show_validation_error(hint_fill=False): with show_validation_error(hint_fill=False):
nlp = init_nlp(config, use_gpu=use_gpu) nlp = init_nlp(config, use_gpu=use_gpu)
_init_labels(nlp, output_path)
def _init_labels(nlp, output_path):
for name, component in nlp.pipeline: for name, component in nlp.pipeline:
if getattr(component, "label_data", None) is not None: if getattr(component, "label_data", None) is not None:
output_file = output_path / f"{name}.json" output_file = output_path / f"{name}.json"

View File

@ -329,7 +329,7 @@ class SpanCategorizer(TrainablePipe):
get_examples: Callable[[], Iterable[Example]], get_examples: Callable[[], Iterable[Example]],
*, *,
nlp: Language = None, nlp: Language = None,
labels: Optional[Dict] = None, labels: Optional[List[str]] = None,
) -> None: ) -> None:
"""Initialize the pipe for training, using a representative set """Initialize the pipe for training, using a representative set
of data examples. of data examples.

View File

@ -19,6 +19,7 @@ import srsly
import os import os
from .util import make_tempdir from .util import make_tempdir
from ..cli.init_pipeline import _init_labels
def test_cli_info(): def test_cli_info():
@ -501,3 +502,33 @@ def test_validate_compatibility_table():
current_compat = compat.get(spacy_version, {}) current_compat = compat.get(spacy_version, {})
assert len(current_compat) > 0 assert len(current_compat) > 0
assert "en_core_web_sm" in current_compat assert "en_core_web_sm" in current_compat
@pytest.mark.parametrize("component_name", ["ner", "textcat", "spancat", "tagger"])
def test_init_labels(component_name):
    """Check that exported label files can seed a fresh pipeline.

    A component is given four labels, ``_init_labels`` is run against a
    temporary directory, and a second pipeline is then configured to read
    ``<component_name>.json`` from that directory through the
    ``spacy.read_labels.v1`` reader. The second pipeline must have no
    labels before ``initialize()`` and all four afterwards.
    """
    label_names = ["T1", "T2", "T3", "T4"]

    # Source pipeline: a single component carrying the labels to export.
    nlp = Dutch()
    pipe = nlp.add_pipe(component_name)
    for label_name in label_names:
        pipe.add_label(label_name)
    assert len(nlp.get_pipe(component_name).labels) == 4

    # Everything below stays inside the tempdir context so the exported
    # label file still exists when the second pipeline is initialized.
    with make_tempdir() as tmp_dir:
        _init_labels(nlp, tmp_dir)

        config = init_config(
            lang="nl",
            pipeline=[component_name],
            optimize="efficiency",
            gpu=False,
        )
        labels_reader = {
            "@readers": "spacy.read_labels.v1",
            "path": f"{tmp_dir}/{component_name}.json",
        }
        config["initialize"]["components"][component_name] = {
            "labels": labels_reader,
        }

        nlp2 = load_model_from_config(config, auto_fill=True)
        # Labels are only read in during initialization.
        assert len(nlp2.get_pipe(component_name).labels) == 0
        nlp2.initialize()
        assert len(nlp2.get_pipe(component_name).labels) == 4