mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-27 17:54:39 +03:00
fix spancat initialize with labels (#8620)
This commit is contained in:
parent
608fc1d623
commit
733e8ceea9
|
@ -108,6 +108,10 @@ def init_labels_cli(
|
|||
config = util.load_config(config_path, overrides=overrides)
|
||||
with show_validation_error(hint_fill=False):
|
||||
nlp = init_nlp(config, use_gpu=use_gpu)
|
||||
_init_labels(nlp, output_path)
|
||||
|
||||
|
||||
def _init_labels(nlp, output_path):
|
||||
for name, component in nlp.pipeline:
|
||||
if getattr(component, "label_data", None) is not None:
|
||||
output_file = output_path / f"{name}.json"
|
||||
|
|
|
@ -329,7 +329,7 @@ class SpanCategorizer(TrainablePipe):
|
|||
get_examples: Callable[[], Iterable[Example]],
|
||||
*,
|
||||
nlp: Language = None,
|
||||
labels: Optional[Dict] = None,
|
||||
labels: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
"""Initialize the pipe for training, using a representative set
|
||||
of data examples.
|
||||
|
|
|
@ -19,6 +19,7 @@ import srsly
|
|||
import os
|
||||
|
||||
from .util import make_tempdir
|
||||
from ..cli.init_pipeline import _init_labels
|
||||
|
||||
|
||||
def test_cli_info():
|
||||
|
@ -501,3 +502,33 @@ def test_validate_compatibility_table():
|
|||
current_compat = compat.get(spacy_version, {})
|
||||
assert len(current_compat) > 0
|
||||
assert "en_core_web_sm" in current_compat
|
||||
|
||||
|
||||
@pytest.mark.parametrize("component_name", ["ner", "textcat", "spancat", "tagger"])
|
||||
def test_init_labels(component_name):
|
||||
nlp = Dutch()
|
||||
component = nlp.add_pipe(component_name)
|
||||
for label in ["T1", "T2", "T3", "T4"]:
|
||||
component.add_label(label)
|
||||
assert len(nlp.get_pipe(component_name).labels) == 4
|
||||
|
||||
with make_tempdir() as tmp_dir:
|
||||
_init_labels(nlp, tmp_dir)
|
||||
|
||||
config = init_config(
|
||||
lang="nl",
|
||||
pipeline=[component_name],
|
||||
optimize="efficiency",
|
||||
gpu=False,
|
||||
)
|
||||
config["initialize"]["components"][component_name] = {
|
||||
"labels": {
|
||||
"@readers": "spacy.read_labels.v1",
|
||||
"path": f"{tmp_dir}/{component_name}.json",
|
||||
}
|
||||
}
|
||||
|
||||
nlp2 = load_model_from_config(config, auto_fill=True)
|
||||
assert len(nlp2.get_pipe(component_name).labels) == 0
|
||||
nlp2.initialize()
|
||||
assert len(nlp2.get_pipe(component_name).labels) == 4
|
||||
|
|
Loading…
Reference in New Issue
Block a user