From ac0e27a825c1b26cb016b7107b18f0de1c7969ff Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Thu, 12 Sep 2019 10:56:28 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=92=AB=20Add=20Language.pipe=5Flabels=20(?= =?UTF-8?q?#4276)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add Language.pipe_labels * Update spacy/language.py Co-Authored-By: Matthew Honnibal --- spacy/language.py | 12 ++++++++++++ spacy/tests/pipeline/test_pipe_methods.py | 16 ++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/spacy/language.py b/spacy/language.py index 10381573d..9dc48ca6f 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -248,6 +248,18 @@ class Language(object): """ return [pipe_name for pipe_name, _ in self.pipeline] + @property + def pipe_labels(self): + """Get the labels set by the pipeline components, if available. + + RETURNS (dict): Labels keyed by component name. + """ + labels = OrderedDict() + for name, pipe in self.pipeline: + if hasattr(pipe, "labels"): + labels[name] = list(pipe.labels) + return labels + def get_pipe(self, name): """Get a pipeline component for a given component name. diff --git a/spacy/tests/pipeline/test_pipe_methods.py b/spacy/tests/pipeline/test_pipe_methods.py index 8755cc27a..5f1fa5cfe 100644 --- a/spacy/tests/pipeline/test_pipe_methods.py +++ b/spacy/tests/pipeline/test_pipe_methods.py @@ -128,3 +128,19 @@ def test_pipe_base_class_add_label(nlp, component): assert label in pipe.labels else: assert pipe.labels == (label,) + + +def test_pipe_labels(nlp): + input_labels = { + "ner": ["PERSON", "ORG", "GPE"], + "textcat": ["POSITIVE", "NEGATIVE"], + } + for name, labels in input_labels.items(): + pipe = nlp.create_pipe(name) + for label in labels: + pipe.add_label(label) + assert len(pipe.labels) == len(labels) + nlp.add_pipe(pipe) + assert len(nlp.pipe_labels) == len(input_labels) + for name, labels in nlp.pipe_labels.items(): + assert sorted(input_labels[name]) == sorted(labels)