Merge pull request #5489 from explosion/feature/connected-components

This commit is contained in:
Ines Montani 2020-05-22 17:40:11 +02:00 committed by GitHub
commit 25d6ed3fb8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 42 additions and 2 deletions

View File

@ -173,3 +173,24 @@ def print_summary(nlp, pretty=True, no_print=False):
msg.good("No problems found.")
if no_print:
return {"overview": overview, "problems": problems}
def count_pipeline_interdependencies(pipeline):
"""Count how many subsequent components require an annotation set by each
component in the pipeline.
"""
pipe_assigns = []
pipe_requires = []
for name, pipe in pipeline:
pipe_assigns.append(set(getattr(pipe, "assigns", [])))
pipe_requires.append(set(getattr(pipe, "requires", [])))
counts = []
for i, assigns in enumerate(pipe_assigns):
count = 0
for requires in pipe_requires[i+1:]:
if assigns.intersection(requires):
count += 1
counts.append(count)
return counts

View File

@ -18,6 +18,7 @@ from .vocab import Vocab
from .lemmatizer import Lemmatizer
from .lookups import Lookups
from .analysis import analyze_pipes, analyze_all_pipes, validate_attrs
from .analysis import count_pipeline_interdependencies
from .gold import Example
from .scorer import Scorer
from .util import link_vectors_to_models, create_default_optimizer, registry
@ -545,13 +546,14 @@ class Language(object):
if component_cfg is None:
component_cfg = {}
component_deps = count_pipeline_interdependencies(self.pipeline)
# Determine whether component should set annotations. In theory I guess
# we should do this by inspecting the meta? Or we could just always
# say "yes"
for name, proc in self.pipeline:
for i, (name, proc) in enumerate(self.pipeline):
component_cfg.setdefault(name, {})
component_cfg[name].setdefault("drop", drop)
component_cfg[name].setdefault("set_annotations", False)
component_cfg[name]["set_annotations"] = bool(component_deps[i])
for name, proc in self.pipeline:
if not hasattr(proc, "update"):
continue

View File

@ -2,6 +2,7 @@ import spacy.language
from spacy.language import Language, component
from spacy.analysis import print_summary, validate_attrs
from spacy.analysis import get_assigns_for_attr, get_requires_for_attr
from spacy.analysis import count_pipeline_interdependencies
from mock import Mock, ANY
import pytest
@ -161,3 +162,19 @@ def test_analysis_validate_attrs_remove_pipe():
with pytest.warns(None) as record:
nlp.remove_pipe("c2")
assert not record.list
def test_pipe_interdependencies():
class Fancifier:
name = "fancifier"
assigns = ("doc._.fancy",)
requires = tuple()
class FancyNeeder:
name = "needer"
assigns = tuple()
requires = ("doc._.fancy",)
pipeline = [("fancifier", Fancifier()), ("needer", FancyNeeder())]
counts = count_pipeline_interdependencies(pipeline)
assert counts == [1, 0]