decorator to require annotations on Doc

This commit is contained in:
kadarakos 2022-10-27 16:49:24 +00:00
parent 865691d169
commit 6f8a69152d
3 changed files with 147 additions and 2 deletions

View File

@ -952,6 +952,7 @@ class Errors(metaclass=ErrorsWithCodes):
"sure it's overwritten on the subclass.")
E1046 = ("{cls_name} is an abstract class and cannot be instantiated. If you are looking for spaCy's default "
"knowledge base, use `InMemoryLookupKB`.")
E1047 = ("The function `{name}` requires annotations {annotations} to be set, but did not find {missing}.")
# Deprecated model shortcuts, only used in errors and warnings

View File

@ -1,14 +1,17 @@
import pytest
import os
import ctypes
import spacy
from typing import Union
from pathlib import Path
from spacy.about import __version__ as spacy_version
from spacy import util
from spacy import prefer_gpu, require_gpu, require_cpu
from spacy.tokens import Doc, Span
from spacy.ml._precomputable_affine import PrecomputableAffine
from spacy.ml._precomputable_affine import _backprop_precomputable_affine_padding
from spacy.util import dot_to_object, SimpleFrozenList, import_file
from spacy.util import to_ternary_int
from spacy.util import to_ternary_int, require_annotation
from thinc.api import Config, Optimizer, ConfigValidationError
from thinc.api import get_current_ops, set_current_ops, NumpyOps, CupyOps, MPSOps
from thinc.compat import has_cupy_gpu, has_torch_mps_gpu
@ -434,3 +437,75 @@ def test_to_ternary_int():
assert to_ternary_int(-10) == -1
assert to_ternary_int("string") == -1
assert to_ternary_int([0, "string"]) == -1
def test_require_annotations():
    """Check that `require_annotation` lets annotated docs through and
    raises E1047 when a required annotation is missing or incomplete."""

    # Decorated helpers covering: no requirement, a single requirement,
    # a single "complete" requirement, and mixed requirements.
    @require_annotation([])
    def func_none(doc) -> bool:
        return True

    @require_annotation(["TAG"])
    def func_tag(doc) -> bool:
        return True

    @require_annotation(["ENT_IOB"])
    def func_ner(doc) -> bool:
        return True

    @require_annotation(["TAG"], require_complete=[True])
    def func_tag_complete(doc) -> bool:
        return True

    # `require_complete` belongs to the decorator, not to the decorated
    # function's own signature.
    @require_annotation(["ENT_IOB"], require_complete=[True])
    def func_ner_complete(doc) -> bool:
        return True

    @require_annotation(["TAG", "ENT_IOB"])
    def func_tag_ner(doc) -> bool:
        return True

    @require_annotation(["TAG", "ENT_IOB"], require_complete=[False, True])
    def func_tag_nercomplete(doc) -> bool:
        return True

    @require_annotation(["TAG", "ENT_IOB"], require_complete=[True, False])
    def func_tagcomplete_ner(doc) -> bool:
        return True

    text = "Bob is a person for sure."
    nlp = spacy.blank("en")
    blank = nlp(text)

    # Docs with (complete and partial) TAG annotation.
    tagger = nlp.add_pipe("tagger")
    tagger.add_label("A")
    nlp.initialize()
    tagged = nlp(text)
    tagged_partial = nlp(text)
    tagged_partial[-1].tag_ = ""

    # Docs with both TAG and ENT_IOB annotation.
    ner = nlp.add_pipe("ner")
    ner.add_label("CHEESE")
    nlp.initialize()
    tagged_nered = nlp(text)

    # Docs that only went through the NER component.
    nlp.remove_pipe("tagger")
    nered = nlp(text)
    tagged_partial_nered = nlp(tagged_partial)

    # Requirements are met: the wrapped function is called normally.
    assert func_none(blank)
    assert func_none(tagged)
    assert func_tag(tagged)
    assert func_ner(nered)
    assert func_tag_complete(tagged)
    assert func_ner_complete(nered)
    assert func_tag_ner(tagged_nered)
    assert func_tag_nercomplete(tagged_nered)

    # Requirements are not met. The brackets are escaped so the pattern
    # matches the literal error code instead of a regex character class.
    with pytest.raises(ValueError, match=r"\[E1047\]"):
        func_tag(blank)
    with pytest.raises(ValueError, match=r"\[E1047\]"):
        func_tag_complete(tagged_partial)
    with pytest.raises(ValueError, match=r"\[E1047\]"):
        func_tagcomplete_ner(tagged_partial_nered)

    # Mismatched argument lengths are rejected when the decorator is built.
    with pytest.raises(ValueError, match=r"Have to provide the same number"):
        require_annotation(["TAG", "NER"], require_complete=[True])

View File

@ -1,5 +1,5 @@
from typing import List, Mapping, NoReturn, Union, Dict, Any, Set, cast
from typing import Optional, Iterable, Callable, Tuple, Type
from typing import Optional, Iterable, Callable, Tuple, Type, Sequence
from typing import Iterator, Pattern, Generator, TYPE_CHECKING
from types import ModuleType
import os
@ -1735,3 +1735,72 @@ def all_equal(iterable):
(or if the input is an empty sequence), False otherwise."""
g = itertools.groupby(iterable)
return next(g, True) and not next(g, False)
def require_annotation(
    annotations: Sequence[Union[str, int]],
    *,
    require_complete: Optional[Sequence[bool]] = None,
) -> Callable:
    """Build a decorator that checks a Doc/Span for required annotations.

    To be used as a decorator for functions whose first argument
    is a Doc or a Span. For example:

        @require_annotation(["POS"])
        def extract_pos(doc: Doc):
            return [token.pos_ for token in doc]

    Here we check if the input Doc has "POS" annotation and
    otherwise we raise an informative error.

    annotations (Sequence[Union[str, int]]): The annotations to check
        for, e.g. "DEP", "POS".
    require_complete (Optional[Sequence[bool]]): One flag per annotation.
        If True, the annotation only fulfils the requirement when all
        tokens are annotated; otherwise partial annotation is accepted.
        Defaults to False for every annotation.
    RETURNS (Callable): A decorator. The decorated function raises a
        ValueError (E1047) when called with an input whose `.doc` lacks
        any of the required annotations.
    RAISES (ValueError): Immediately, if `require_complete` is given with
        a different length than `annotations`.
    """
    # Imported locally so the block does not depend on a module-level import.
    from functools import wraps

    # Snapshot the inputs so later mutation by the caller cannot change
    # what the decorated function checks.
    attrs = tuple(annotations)
    if require_complete is None:
        complete_flags = (False,) * len(attrs)
    else:
        complete_flags = tuple(require_complete)
        if len(complete_flags) != len(attrs):
            raise ValueError(
                "Have to provide the same number of values for "
                "`annotations` and `require_complete`, but found "
                f"{len(attrs)} values for `annotations` "
                f"and {len(complete_flags)} values for `require_complete`."
            )

    # Human-readable description of the requirement, e.g.
    # "TAG (complete), ENT_IOB". It does not depend on the input, so it
    # is built once here rather than on every failing call.
    requirement_msg = ", ".join(
        f"{attr} (complete)" if complete else str(attr)
        for attr, complete in zip(attrs, complete_flags)
    )

    def require_annotation_decorator(func: Callable) -> Callable:
        @wraps(func)  # keep the wrapped function's name and docstring
        def func_with_require(doclike, *args, **kwargs) -> Any:
            # Both Doc and Span expose `.doc`, which owns `has_annotation`.
            doc = doclike.doc
            missing = [
                attr
                for attr, complete in zip(attrs, complete_flags)
                if not doc.has_annotation(attr, require_complete=complete)
            ]
            if missing:
                raise ValueError(
                    Errors.E1047.format(
                        name=func.__name__,
                        annotations=requirement_msg,
                        missing=missing,
                    )
                )
            return func(doclike, *args, **kwargs)

        return func_with_require

    return require_annotation_decorator