mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
💫 Add Language.pipe_labels (#4276)
* Add Language.pipe_labels * Update spacy/language.py Co-Authored-By: Matthew Honnibal <honnibal+gh@gmail.com>
This commit is contained in:
parent
71909cdf22
commit
ac0e27a825
|
@ -248,6 +248,18 @@ class Language(object):
|
|||
"""
|
||||
return [pipe_name for pipe_name, _ in self.pipeline]
|
||||
|
||||
@property
|
||||
def pipe_labels(self):
|
||||
"""Get the labels set by the pipeline components, if available.
|
||||
|
||||
RETURNS (dict): Labels keyed by component name.
|
||||
"""
|
||||
labels = OrderedDict()
|
||||
for name, pipe in self.pipeline:
|
||||
if hasattr(pipe, "labels"):
|
||||
labels[name] = list(pipe.labels)
|
||||
return labels
|
||||
|
||||
def get_pipe(self, name):
|
||||
"""Get a pipeline component for a given component name.
|
||||
|
||||
|
|
|
@ -128,3 +128,19 @@ def test_pipe_base_class_add_label(nlp, component):
|
|||
assert label in pipe.labels
|
||||
else:
|
||||
assert pipe.labels == (label,)
|
||||
|
||||
|
||||
def test_pipe_labels(nlp):
|
||||
input_labels = {
|
||||
"ner": ["PERSON", "ORG", "GPE"],
|
||||
"textcat": ["POSITIVE", "NEGATIVE"],
|
||||
}
|
||||
for name, labels in input_labels.items():
|
||||
pipe = nlp.create_pipe(name)
|
||||
for label in labels:
|
||||
pipe.add_label(label)
|
||||
assert len(pipe.labels) == len(labels)
|
||||
nlp.add_pipe(pipe)
|
||||
assert len(nlp.pipe_labels) == len(input_labels)
|
||||
for name, labels in nlp.pipe_labels.items():
|
||||
assert sorted(input_labels[name]) == sorted(labels)
|
||||
|
|
Loading…
Reference in New Issue
Block a user