mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-07 05:40:20 +03:00
decorator to require annotations on Doc
This commit is contained in:
parent
865691d169
commit
6f8a69152d
|
@ -952,6 +952,7 @@ class Errors(metaclass=ErrorsWithCodes):
|
|||
"sure it's overwritten on the subclass.")
|
||||
E1046 = ("{cls_name} is an abstract class and cannot be instantiated. If you are looking for spaCy's default "
|
||||
"knowledge base, use `InMemoryLookupKB`.")
|
||||
E1047 = ("The function `{name}` requires annotations {annotations} to be set, but did not find {missing}.")
|
||||
|
||||
|
||||
# Deprecated model shortcuts, only used in errors and warnings
|
||||
|
|
|
@ -1,14 +1,17 @@
|
|||
import pytest
|
||||
import os
|
||||
import ctypes
|
||||
import spacy
|
||||
from typing import Union
|
||||
from pathlib import Path
|
||||
from spacy.about import __version__ as spacy_version
|
||||
from spacy import util
|
||||
from spacy import prefer_gpu, require_gpu, require_cpu
|
||||
from spacy.tokens import Doc, Span
|
||||
from spacy.ml._precomputable_affine import PrecomputableAffine
|
||||
from spacy.ml._precomputable_affine import _backprop_precomputable_affine_padding
|
||||
from spacy.util import dot_to_object, SimpleFrozenList, import_file
|
||||
from spacy.util import to_ternary_int
|
||||
from spacy.util import to_ternary_int, require_annotation
|
||||
from thinc.api import Config, Optimizer, ConfigValidationError
|
||||
from thinc.api import get_current_ops, set_current_ops, NumpyOps, CupyOps, MPSOps
|
||||
from thinc.compat import has_cupy_gpu, has_torch_mps_gpu
|
||||
|
@ -434,3 +437,75 @@ def test_to_ternary_int():
|
|||
assert to_ternary_int(-10) == -1
|
||||
assert to_ternary_int("string") == -1
|
||||
assert to_ternary_int([0, "string"]) == -1
|
||||
|
||||
|
||||
def test_require_annotations():
    """Exercise `require_annotation` on blank, tagged and NER-annotated docs.

    Covers the pass-through case (all requested annotations present), the
    failure case (E1047 is raised when an annotation is missing or not
    complete) and the argument validation of the decorator factory itself.
    """

    @require_annotation([])
    def func_none(doc) -> bool:
        return True

    @require_annotation(["TAG"])
    def func_tag(doc) -> bool:
        return True

    @require_annotation(["ENT_IOB"])
    def func_ner(doc) -> bool:
        return True

    @require_annotation(["TAG"], require_complete=[True])
    def func_tag_complete(doc) -> bool:
        return True

    # `require_complete` belongs to the decorator, not to the decorated
    # function's signature; otherwise completeness is never checked.
    @require_annotation(["ENT_IOB"], require_complete=[True])
    def func_ner_complete(doc) -> bool:
        return True

    @require_annotation(["TAG", "ENT_IOB"])
    def func_tag_ner(doc) -> bool:
        return True

    @require_annotation(
        ["TAG", "ENT_IOB"], require_complete=[False, True]
    )
    def func_tag_nercomplete(doc):
        return True

    @require_annotation(
        ["TAG", "ENT_IOB"], require_complete=[True, False]
    )
    def func_tagcomplete_ner(doc):
        return True

    text = "Bob is a person for sure."
    nlp = spacy.blank("en")
    blank = nlp(text)
    tagger = nlp.add_pipe("tagger")
    tagger.add_label("A")
    nlp.initialize()
    tagged = nlp(text)
    tagged_partial = nlp(text)
    # Clear the tag on the last token so the TAG annotation is only partial.
    tagged_partial[-1].tag_ = ""
    ner = nlp.add_pipe("ner")
    ner.add_label("CHEESE")
    nlp.initialize()
    tagged_nered = nlp(text)
    nlp.remove_pipe("tagger")
    nered = nlp(text)
    tagged_partial_nered = nlp(tagged_partial)

    assert func_none(blank)
    assert func_none(tagged)
    assert func_tag(tagged)
    assert func_ner(nered)
    assert func_tag_complete(tagged)
    assert func_ner_complete(nered)
    assert func_tag_ner(tagged_nered)
    # NOTE: `match` is applied with `re.search`; an unescaped "[E1047]" would
    # be a character class matching nearly any message, so match the code
    # literally instead.
    with pytest.raises(ValueError, match="E1047"):
        func_tag(blank)
    with pytest.raises(ValueError, match="E1047"):
        func_tag_complete(tagged_partial)
    with pytest.raises(ValueError, match="E1047"):
        func_tagcomplete_ner(tagged_partial_nered)
    with pytest.raises(ValueError, match=r"Have to provide the same number"):
        require_annotation(["TAG", "NER"], require_complete=[True])
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from typing import List, Mapping, NoReturn, Union, Dict, Any, Set, cast
|
||||
from typing import Optional, Iterable, Callable, Tuple, Type
|
||||
from typing import Optional, Iterable, Callable, Tuple, Type, Sequence
|
||||
from typing import Iterator, Pattern, Generator, TYPE_CHECKING
|
||||
from types import ModuleType
|
||||
import os
|
||||
|
@ -1735,3 +1735,72 @@ def all_equal(iterable):
|
|||
(or if the input is an empty sequence), False otherwise."""
|
||||
g = itertools.groupby(iterable)
|
||||
return next(g, True) and not next(g, False)
|
||||
|
||||
|
||||
def require_annotation(
    annotations: Sequence[Union[str, int]],
    *,
    require_complete: Optional[Sequence[bool]] = None
) -> Callable:
    """
    To be used as a decorator for functions whose first argument
    is a Doc or a Span. For example:

        @require_annotation(["POS"])
        def extract_pos(doc: Doc):
            return [token.pos_ for token in doc]

    Here we check if the input Doc has "POS" annotation and
    otherwise we raise an informative error.

    annotations: Sequence[Union[str, int]]
        The annotations to check for e.g.: "DEP", "POS".
    require_complete: Optional[Sequence[bool]]
        One flag per entry in `annotations`. If a flag is True, the
        corresponding annotation only fulfils the requirement if all
        tokens are annotated. Otherwise partial annotation is accepted.
        Defaults to False for every annotation.

    Returns a decorator. The decorated function raises a ValueError (E1047)
    when called on input whose `.doc` is missing any required annotation;
    otherwise the call is forwarded unchanged.

    Raises a ValueError right away if `annotations` and `require_complete`
    differ in length.
    """
    # Function-scope import so the block does not rely on a module-level one.
    import functools

    # Check input.
    if require_complete is None:
        require_complete = [False] * len(annotations)
    elif len(require_complete) != len(annotations):
        raise ValueError(
            "Have to provide the same number of values for "
            "`annotations` and `require_complete`, but found "
            f"{len(annotations)} values for `annotations` "
            f"and {len(require_complete)} values for `require_complete`"
        )
    # Pair every annotation with its completeness flag once, at decoration
    # time, so later mutation of the caller's sequences cannot change the
    # requirements of an already decorated function.
    requirements = tuple(zip(annotations, require_complete))

    def require_annotation_decorator(func: Callable) -> Callable:
        # Preserve the name and docstring of the decorated function.
        @functools.wraps(func)
        def func_with_require(doclike, *args, **kwargs) -> Any:
            # Check for missing annotations.
            missing = [
                attr
                for attr, complete in requirements
                if not doclike.doc.has_annotation(attr, require_complete=complete)
            ]
            if missing:
                # List everything that was required, e.g.
                # "TAG (complete), ENT_IOB", then report what was not found.
                required = ", ".join(
                    f"{attr} (complete)" if complete else str(attr)
                    for attr, complete in requirements
                )
                raise ValueError(
                    Errors.E1047.format(
                        name=func.__name__, annotations=required, missing=missing
                    )
                )
            return func(doclike, *args, **kwargs)

        return func_with_require

    return require_annotation_decorator
|
||||
|
|
Loading…
Reference in New Issue
Block a user