💫 Add Language.pipe_labels (#4276)

* Add Language.pipe_labels

* Update spacy/language.py

Co-Authored-By: Matthew Honnibal <honnibal+gh@gmail.com>
This commit is contained in:
Ines Montani 2019-09-12 10:56:28 +02:00 committed by GitHub
parent 71909cdf22
commit ac0e27a825
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 0 deletions

View File

@ -248,6 +248,18 @@ class Language(object):
""" """
return [pipe_name for pipe_name, _ in self.pipeline] return [pipe_name for pipe_name, _ in self.pipeline]
@property
def pipe_labels(self):
"""Get the labels set by the pipeline components, if available.
RETURNS (dict): Labels keyed by component name.
"""
labels = OrderedDict()
for name, pipe in self.pipeline:
if hasattr(pipe, "labels"):
labels[name] = list(pipe.labels)
return labels
def get_pipe(self, name): def get_pipe(self, name):
"""Get a pipeline component for a given component name. """Get a pipeline component for a given component name.

View File

@ -128,3 +128,19 @@ def test_pipe_base_class_add_label(nlp, component):
assert label in pipe.labels assert label in pipe.labels
else: else:
assert pipe.labels == (label,) assert pipe.labels == (label,)
def test_pipe_labels(nlp):
input_labels = {
"ner": ["PERSON", "ORG", "GPE"],
"textcat": ["POSITIVE", "NEGATIVE"],
}
for name, labels in input_labels.items():
pipe = nlp.create_pipe(name)
for label in labels:
pipe.add_label(label)
assert len(pipe.labels) == len(labels)
nlp.add_pipe(pipe)
assert len(nlp.pipe_labels) == len(input_labels)
for name, labels in nlp.pipe_labels.items():
assert sorted(input_labels[name]) == sorted(labels)