diff --git a/spacy/analysis.py b/spacy/analysis.py index c2600048f..41591661c 100644 --- a/spacy/analysis.py +++ b/spacy/analysis.py @@ -173,3 +173,24 @@ def print_summary(nlp, pretty=True, no_print=False): msg.good("No problems found.") if no_print: return {"overview": overview, "problems": problems} + + +def count_pipeline_interdependencies(pipeline): + """Count how many subsequent components require an annotation set by each + component in the pipeline. + """ + pipe_assigns = [] + pipe_requires = [] + for name, pipe in pipeline: + pipe_assigns.append(set(getattr(pipe, "assigns", []))) + pipe_requires.append(set(getattr(pipe, "requires", []))) + counts = [] + for i, assigns in enumerate(pipe_assigns): + count = 0 + for requires in pipe_requires[i+1:]: + if assigns.intersection(requires): + count += 1 + counts.append(count) + return counts + + diff --git a/spacy/language.py b/spacy/language.py index f770cda2c..8c44cf26b 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -18,6 +18,7 @@ from .vocab import Vocab from .lemmatizer import Lemmatizer from .lookups import Lookups from .analysis import analyze_pipes, analyze_all_pipes, validate_attrs +from .analysis import count_pipeline_interdependencies from .gold import Example from .scorer import Scorer from .util import link_vectors_to_models, create_default_optimizer, registry @@ -545,13 +546,14 @@ class Language(object): if component_cfg is None: component_cfg = {} + component_deps = count_pipeline_interdependencies(self.pipeline) # Determine whether component should set annotations. In theory I guess # we should do this by inspecting the meta? Or we could just always # say "yes" - for name, proc in self.pipeline: + for i, (name, proc) in enumerate(self.pipeline): component_cfg.setdefault(name, {}) component_cfg[name].setdefault("drop", drop) - component_cfg[name].setdefault("set_annotations", False) + component_cfg[name]["set_annotations"] = bool(component_deps[i]) for name, proc in self.pipeline: if not hasattr(proc, "update"): continue diff --git a/spacy/tests/pipeline/test_analysis.py b/spacy/tests/pipeline/test_analysis.py index cda39f6ee..e608f2c34 100644 --- a/spacy/tests/pipeline/test_analysis.py +++ b/spacy/tests/pipeline/test_analysis.py @@ -2,6 +2,7 @@ import spacy.language from spacy.language import Language, component from spacy.analysis import print_summary, validate_attrs from spacy.analysis import get_assigns_for_attr, get_requires_for_attr +from spacy.analysis import count_pipeline_interdependencies from mock import Mock, ANY import pytest @@ -161,3 +162,19 @@ def test_analysis_validate_attrs_remove_pipe(): with pytest.warns(None) as record: nlp.remove_pipe("c2") assert not record.list + + +def test_pipe_interdependencies(): + class Fancifier: + name = "fancifier" + assigns = ("doc._.fancy",) + requires = tuple() + + class FancyNeeder: + name = "needer" + assigns = tuple() + requires = ("doc._.fancy",) + + pipeline = [("fancifier", Fancifier()), ("needer", FancyNeeder())] + counts = count_pipeline_interdependencies(pipeline) + assert counts == [1, 0]