diff --git a/spacy/analysis.py b/spacy/analysis.py index c2600048f..41591661c 100644 --- a/spacy/analysis.py +++ b/spacy/analysis.py @@ -173,3 +173,24 @@ def print_summary(nlp, pretty=True, no_print=False): msg.good("No problems found.") if no_print: return {"overview": overview, "problems": problems} + + +def count_pipeline_interdependencies(pipeline): + """Count how many subsequent components require an annotation set by each + component in the pipeline. + """ + pipe_assigns = [] + pipe_requires = [] + for name, pipe in pipeline: + pipe_assigns.append(set(getattr(pipe, "assigns", []))) + pipe_requires.append(set(getattr(pipe, "requires", []))) + counts = [] + for i, assigns in enumerate(pipe_assigns): + count = 0 + for requires in pipe_requires[i+1:]: + if assigns.intersection(requires): + count += 1 + counts.append(count) + return counts + + diff --git a/spacy/language.py b/spacy/language.py index afc988583..b228c2155 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -18,6 +18,7 @@ from .vocab import Vocab from .lemmatizer import Lemmatizer from .lookups import Lookups from .analysis import analyze_pipes, analyze_all_pipes, validate_attrs +from .analysis import count_pipeline_interdependencies from .gold import Example from .scorer import Scorer from .util import link_vectors_to_models, create_default_optimizer, registry @@ -545,7 +546,7 @@ class Language(object): if component_cfg is None: component_cfg = {} - component_deps = _count_pipeline_inter_dependencies(self.pipeline) + component_deps = count_pipeline_interdependencies(self.pipeline) # Determine whether component should set annotations. In theory I guess # we should do this by inspecting the meta? Or we could just always # say "yes" @@ -1160,25 +1161,6 @@ class DisabledPipes(list): self[:] = [] -def _count_pipeline_inter_dependencies(pipeline): - """Count how many subsequent components require an annotation set by each - component in the pipeline. - """ - pipe_assigns = [] - pipe_requires = [] - for name, pipe in pipeline: - pipe_assigns.append(set(getattr(pipe, "assigns", []))) - pipe_requires.append(set(getattr(pipe, "requires", []))) - counts = [] - for i, assigns in enumerate(pipe_assigns): - count = 0 - for requires in pipe_requires[i+1:]: - if assigns.intersection(requires): - count += 1 - counts.append(count) - return counts - - def _pipe(examples, proc, kwargs): # We added some args for pipe that __call__ doesn't expect. kwargs = dict(kwargs) diff --git a/spacy/tests/pipeline/test_analysis.py b/spacy/tests/pipeline/test_analysis.py index cda39f6ee..e608f2c34 100644 --- a/spacy/tests/pipeline/test_analysis.py +++ b/spacy/tests/pipeline/test_analysis.py @@ -2,6 +2,7 @@ import spacy.language from spacy.language import Language, component from spacy.analysis import print_summary, validate_attrs from spacy.analysis import get_assigns_for_attr, get_requires_for_attr +from spacy.analysis import count_pipeline_interdependencies from mock import Mock, ANY import pytest @@ -161,3 +162,19 @@ def test_analysis_validate_attrs_remove_pipe(): with pytest.warns(None) as record: nlp.remove_pipe("c2") assert not record.list + + +def test_pipe_interdependencies(): + class Fancifier: + name = "fancifier" + assigns = ("doc._.fancy",) + requires = tuple() + + class FancyNeeder: + name = "needer" + assigns = tuple() + requires = ("doc._.fancy",) + + pipeline = [("fancifier", Fancifier()), ("needer", FancyNeeder())] + counts = count_pipeline_interdependencies(pipeline) + assert counts == [1, 0] diff --git a/spacy/tests/pipeline/test_pipe_methods.py b/spacy/tests/pipeline/test_pipe_methods.py index 0397d490d..d42216655 100644 --- a/spacy/tests/pipeline/test_pipe_methods.py +++ b/spacy/tests/pipeline/test_pipe_methods.py @@ -1,5 +1,5 @@ import pytest -from spacy.language import Language, _count_pipeline_inter_dependencies +from spacy.language import Language @pytest.fixture @@ -198,19 +198,3 @@ def test_pipe_labels(nlp): assert len(nlp.pipe_labels) == len(input_labels) for name, labels in nlp.pipe_labels.items(): assert sorted(input_labels[name]) == sorted(labels) - - -def test_pipe_inter_dependencies(): - class Fancifier: - name = "fancifier" - assigns = ("doc._.fancy",) - requires = tuple() - - class FancyNeeder: - name = "needer" - assigns = tuple() - requires = ("doc._.fancy",) - - pipeline = [("fancifier", Fancifier()), ("needer", FancyNeeder())] - counts = _count_pipeline_inter_dependencies(pipeline) - assert counts == [1, 0]