spaCy/spacy/analysis.py
2020-05-22 16:43:18 +02:00

197 lines
8.0 KiB
Python

from wasabi import Printer
import warnings
from .tokens import Doc, Token, Span
from .errors import Errors, Warnings
def analyze_pipes(pipeline, name, pipe, index, warn=True):
"""Analyze a pipeline component with respect to its position in the current
pipeline and the other components. Will check whether requirements are
fulfilled (e.g. if previous components assign the attributes).
pipeline (list): A list of (name, pipe) tuples e.g. nlp.pipeline.
name (unicode): The name of the pipeline component to analyze.
pipe (callable): The pipeline component function to analyze.
index (int): The index of the component in the pipeline.
warn (bool): Show user warning if problem is found.
RETURNS (list): The problems found for the given pipeline component.
"""
assert pipeline[index][0] == name
prev_pipes = pipeline[:index]
pipe_requires = getattr(pipe, "requires", [])
requires = {annot: False for annot in pipe_requires}
if requires:
for prev_name, prev_pipe in prev_pipes:
prev_assigns = getattr(prev_pipe, "assigns", [])
for annot in prev_assigns:
requires[annot] = True
problems = []
for annot, fulfilled in requires.items():
if not fulfilled:
problems.append(annot)
if warn:
warnings.warn(Warnings.W025.format(name=name, attr=annot))
return problems
def analyze_all_pipes(pipeline, warn=True):
"""Analyze all pipes in the pipeline in order.
pipeline (list): A list of (name, pipe) tuples e.g. nlp.pipeline.
warn (bool): Show user warning if problem is found.
RETURNS (dict): The problems found, keyed by component name.
"""
problems = {}
for i, (name, pipe) in enumerate(pipeline):
problems[name] = analyze_pipes(pipeline, name, pipe, i, warn=warn)
return problems
def dot_to_dict(values):
"""Convert dot notation to a dict. For example: ["token.pos", "token._.xyz"]
become {"token": {"pos": True, "_": {"xyz": True }}}.
values (iterable): The values to convert.
RETURNS (dict): The converted values.
"""
result = {}
for value in values:
path = result
parts = value.lower().split(".")
for i, item in enumerate(parts):
is_last = i == len(parts) - 1
path = path.setdefault(item, True if is_last else {})
return result
def validate_attrs(values):
"""Validate component attributes provided to "assigns", "requires" etc.
Raises error for invalid attributes and formatting. Doesn't check if
custom extension attributes are registered, since this is something the
user might want to do themselves later in the component.
values (iterable): The string attributes to check, e.g. `["token.pos"]`.
RETURNS (iterable): The checked attributes.
"""
data = dot_to_dict(values)
objs = {"doc": Doc, "token": Token, "span": Span}
for obj_key, attrs in data.items():
if obj_key == "span":
# Support Span only for custom extension attributes
span_attrs = [attr for attr in values if attr.startswith("span.")]
span_attrs = [attr for attr in span_attrs if not attr.startswith("span._.")]
if span_attrs:
raise ValueError(Errors.E180.format(attrs=", ".join(span_attrs)))
if obj_key not in objs: # first element is not doc/token/span
invalid_attrs = ", ".join(a for a in values if a.startswith(obj_key))
raise ValueError(Errors.E181.format(obj=obj_key, attrs=invalid_attrs))
if not isinstance(attrs, dict): # attr is something like "doc"
raise ValueError(Errors.E182.format(attr=obj_key))
for attr, value in attrs.items():
if attr == "_":
if value is True: # attr is something like "doc._"
raise ValueError(Errors.E182.format(attr="{}._".format(obj_key)))
for ext_attr, ext_value in value.items():
# We don't check whether the attribute actually exists
if ext_value is not True: # attr is something like doc._.x.y
good = f"{obj_key}._.{ext_attr}"
bad = f"{good}.{'.'.join(ext_value)}"
raise ValueError(Errors.E183.format(attr=bad, solution=good))
continue # we can't validate those further
if attr.endswith("_"): # attr is something like "token.pos_"
raise ValueError(Errors.E184.format(attr=attr, solution=attr[:-1]))
if value is not True: # attr is something like doc.x.y
good = f"{obj_key}.{attr}"
bad = f"{good}.{'.'.join(value)}"
raise ValueError(Errors.E183.format(attr=bad, solution=good))
obj = objs[obj_key]
if not hasattr(obj, attr):
raise ValueError(Errors.E185.format(obj=obj_key, attr=attr))
return values
def _get_feature_for_attr(pipeline, attr, feature):
assert feature in ["assigns", "requires"]
result = []
for pipe_name, pipe in pipeline:
pipe_assigns = getattr(pipe, feature, [])
if attr in pipe_assigns:
result.append((pipe_name, pipe))
return result
def get_assigns_for_attr(pipeline, attr):
"""Get all pipeline components that assign an attr, e.g. "doc.tensor".
pipeline (list): A list of (name, pipe) tuples e.g. nlp.pipeline.
attr (unicode): The attribute to check.
RETURNS (list): (name, pipeline) tuples of components that assign the attr.
"""
return _get_feature_for_attr(pipeline, attr, "assigns")
def get_requires_for_attr(pipeline, attr):
"""Get all pipeline components that require an attr, e.g. "doc.tensor".
pipeline (list): A list of (name, pipe) tuples e.g. nlp.pipeline.
attr (unicode): The attribute to check.
RETURNS (list): (name, pipeline) tuples of components that require the attr.
"""
return _get_feature_for_attr(pipeline, attr, "requires")
def print_summary(nlp, pretty=True, no_print=False):
"""Print a formatted summary for the current nlp object's pipeline. Shows
a table with the pipeline components and why they assign and require, as
well as any problems if available.
nlp (Language): The nlp object.
pretty (bool): Pretty-print the results (color etc).
no_print (bool): Don't print anything, just return the data.
RETURNS (dict): A dict with "overview" and "problems".
"""
msg = Printer(pretty=pretty, no_print=no_print)
overview = []
problems = {}
for i, (name, pipe) in enumerate(nlp.pipeline):
requires = getattr(pipe, "requires", [])
assigns = getattr(pipe, "assigns", [])
retok = getattr(pipe, "retokenizes", False)
overview.append((i, name, requires, assigns, retok))
problems[name] = analyze_pipes(nlp.pipeline, name, pipe, i, warn=False)
msg.divider("Pipeline Overview")
header = ("#", "Component", "Requires", "Assigns", "Retokenizes")
msg.table(overview, header=header, divider=True, multiline=True)
n_problems = sum(len(p) for p in problems.values())
if any(p for p in problems.values()):
msg.divider(f"Problems ({n_problems})")
for name, problem in problems.items():
if problem:
msg.warn(f"'{name}' requirements not met: {', '.join(problem)}")
else:
msg.good("No problems found.")
if no_print:
return {"overview": overview, "problems": problems}
def count_pipeline_interdependencies(pipeline):
"""Count how many subsequent components require an annotation set by each
component in the pipeline.
"""
pipe_assigns = []
pipe_requires = []
for name, pipe in pipeline:
pipe_assigns.append(set(getattr(pipe, "assigns", [])))
pipe_requires.append(set(getattr(pipe, "requires", [])))
counts = []
for i, assigns in enumerate(pipe_assigns):
count = 0
for requires in pipe_requires[i+1:]:
if assigns.intersection(requires):
count += 1
counts.append(count)
return counts