Support span._. in component decorator attrs (#4555)

* Support span._. in component decorator attrs

* Adjust error [ci skip]
This commit is contained in:
Ines Montani 2019-10-30 17:19:36 +01:00 committed by GitHub
parent 4e1de85e43
commit 85f2b04c45
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 12 additions and 7 deletions

View File

@ -4,7 +4,7 @@ from __future__ import unicode_literals
from collections import OrderedDict
from wasabi import Printer
from .tokens import Doc, Token
from .tokens import Doc, Token, Span
from .errors import Errors, Warnings, user_warning
@ -78,12 +78,15 @@ def validate_attrs(values):
RETURNS (iterable): The checked attributes.
"""
data = dot_to_dict(values)
objs = {"doc": Doc, "token": Token}
objs = {"doc": Doc, "token": Token, "span": Span}
for obj_key, attrs in data.items():
if obj_key not in objs: # first element is not doc/token
if obj_key == "span":
# Support Span only for custom extension attributes
span_attrs = [attr for attr in values if attr.startswith("span.")]
span_attrs = [attr for attr in span_attrs if not attr.startswith("span._.")]
if span_attrs:
raise ValueError(Errors.E180.format(attrs=", ".join(span_attrs)))
if obj_key not in objs: # first element is not doc/token/span
invalid_attrs = ", ".join(a for a in values if a.startswith(obj_key))
raise ValueError(Errors.E181.format(obj=obj_key, attrs=invalid_attrs))
if not isinstance(attrs, dict): # attr is something like "doc"

View File

@ -515,7 +515,8 @@ class Errors(object):
"in a list. For example: matcher.add('{key}', [doc])")
E180 = ("Span attributes can't be declared as required or assigned by "
"components, since spans are only views of the Doc. Use Doc and "
"Token attributes only and remove the following: {attrs}")
"Token attributes (or custom extension attributes) only and remove "
"the following: {attrs}")
E181 = ("Received invalid attributes for unkown object {obj}: {attrs}. "
"Only Doc and Token attributes are supported.")
E182 = ("Received invalid attribute declaration: {attr}\nDid you forget "

View File

@ -121,7 +121,7 @@ def test_component_factories_from_nlp():
def test_analysis_validate_attrs_valid():
attrs = ["doc.sents", "doc.ents", "token.tag", "token._.xyz"]
attrs = ["doc.sents", "doc.ents", "token.tag", "token._.xyz", "span._.xyz"]
assert validate_attrs(attrs)
for attr in attrs:
assert validate_attrs([attr])
@ -139,6 +139,7 @@ def test_analysis_validate_attrs_valid():
"token.tag_",
"token.tag.xyz",
"token._.xyz.abc",
"span.label",
],
)
def test_analysis_validate_attrs_invalid(attr):