Preserve missing entity annotation in augmenters (#11540)

Preserve both `-` and `O` annotation in augmenters rather than relying
on `Example.to_dict`'s default support for one option outside of labeled
entity spans.

This is intended as a temporary workaround for augmenters for v3.4.x.
The behavior of `Example` and related IOB utils could be improved in the
general case for v3.5.
This commit is contained in:
Adriane Boyd 2022-09-27 10:16:51 +02:00 committed by GitHub
parent 936a5f0506
commit 877671e09a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 26 additions and 3 deletions

View File

@ -31,7 +31,7 @@ def doc(nlp):
words = ["Sarah", "'s", "sister", "flew", "to", "Silicon", "Valley", "via", "London", "."] words = ["Sarah", "'s", "sister", "flew", "to", "Silicon", "Valley", "via", "London", "."]
tags = ["NNP", "POS", "NN", "VBD", "IN", "NNP", "NNP", "IN", "NNP", "."] tags = ["NNP", "POS", "NN", "VBD", "IN", "NNP", "NNP", "IN", "NNP", "."]
pos = ["PROPN", "PART", "NOUN", "VERB", "ADP", "PROPN", "PROPN", "ADP", "PROPN", "PUNCT"] pos = ["PROPN", "PART", "NOUN", "VERB", "ADP", "PROPN", "PROPN", "ADP", "PROPN", "PUNCT"]
ents = ["B-PERSON", "I-PERSON", "O", "O", "O", "B-LOC", "I-LOC", "O", "B-GPE", "O"] ents = ["B-PERSON", "I-PERSON", "O", "", "O", "B-LOC", "I-LOC", "O", "B-GPE", "O"]
cats = {"TRAVEL": 1.0, "BAKING": 0.0} cats = {"TRAVEL": 1.0, "BAKING": 0.0}
# fmt: on # fmt: on
doc = Doc(nlp.vocab, words=words, tags=tags, pos=pos, ents=ents) doc = Doc(nlp.vocab, words=words, tags=tags, pos=pos, ents=ents)
@ -106,6 +106,7 @@ def test_lowercase_augmenter(nlp, doc):
assert [(e.start, e.end, e.label) for e in eg.reference.ents] == ents assert [(e.start, e.end, e.label) for e in eg.reference.ents] == ents
for ref_ent, orig_ent in zip(eg.reference.ents, doc.ents): for ref_ent, orig_ent in zip(eg.reference.ents, doc.ents):
assert ref_ent.text == orig_ent.text.lower() assert ref_ent.text == orig_ent.text.lower()
assert [t.ent_iob for t in doc] == [t.ent_iob for t in eg.reference]
assert [t.pos_ for t in eg.reference] == [t.pos_ for t in doc] assert [t.pos_ for t in eg.reference] == [t.pos_ for t in doc]
# check that augmentation works when lowercasing leads to different # check that augmentation works when lowercasing leads to different
@ -166,7 +167,7 @@ def test_make_whitespace_variant(nlp):
lemmas = ["they", "fly", "to", "New", "York", "City", ".", "\n", "then", "they", "drive", "to", "Washington", ",", "D.C."] lemmas = ["they", "fly", "to", "New", "York", "City", ".", "\n", "then", "they", "drive", "to", "Washington", ",", "D.C."]
heads = [1, 1, 1, 4, 5, 2, 1, 10, 10, 10, 10, 10, 11, 12, 12] heads = [1, 1, 1, 4, 5, 2, 1, 10, 10, 10, 10, 10, 11, 12, 12]
deps = ["nsubj", "ROOT", "prep", "compound", "compound", "pobj", "punct", "dep", "advmod", "nsubj", "ROOT", "prep", "pobj", "punct", "appos"] deps = ["nsubj", "ROOT", "prep", "compound", "compound", "pobj", "punct", "dep", "advmod", "nsubj", "ROOT", "prep", "pobj", "punct", "appos"]
ents = ["O", "O", "O", "B-GPE", "I-GPE", "I-GPE", "O", "O", "O", "O", "O", "O", "B-GPE", "O", "B-GPE"] ents = ["O", "", "O", "B-GPE", "I-GPE", "I-GPE", "O", "O", "O", "O", "O", "O", "B-GPE", "O", "B-GPE"]
# fmt: on # fmt: on
doc = Doc( doc = Doc(
nlp.vocab, nlp.vocab,
@ -215,6 +216,8 @@ def test_make_whitespace_variant(nlp):
assert mod_ex2.reference[j].head.i == j - 1 assert mod_ex2.reference[j].head.i == j - 1
# entities are well-formed # entities are well-formed
assert len(doc.ents) == len(mod_ex.reference.ents) assert len(doc.ents) == len(mod_ex.reference.ents)
# there is one token with missing entity information
assert any(t.ent_iob == 0 for t in mod_ex.reference)
for ent in mod_ex.reference.ents: for ent in mod_ex.reference.ents:
assert not ent[0].is_space assert not ent[0].is_space
assert not ent[-1].is_space assert not ent[-1].is_space

View File

@ -6,7 +6,7 @@ from functools import partial
from ..util import registry from ..util import registry
from .example import Example from .example import Example
from .iob_utils import split_bilu_label from .iob_utils import split_bilu_label, _doc_to_biluo_tags_with_partial
if TYPE_CHECKING: if TYPE_CHECKING:
from ..language import Language # noqa: F401 from ..language import Language # noqa: F401
@ -62,6 +62,9 @@ def combined_augmenter(
if orth_variants and random.random() < orth_level: if orth_variants and random.random() < orth_level:
raw_text = example.text raw_text = example.text
orig_dict = example.to_dict() orig_dict = example.to_dict()
orig_dict["doc_annotation"]["entities"] = _doc_to_biluo_tags_with_partial(
example.reference
)
variant_text, variant_token_annot = make_orth_variants( variant_text, variant_token_annot = make_orth_variants(
nlp, nlp,
raw_text, raw_text,
@ -128,6 +131,9 @@ def lower_casing_augmenter(
def make_lowercase_variant(nlp: "Language", example: Example): def make_lowercase_variant(nlp: "Language", example: Example):
example_dict = example.to_dict() example_dict = example.to_dict()
example_dict["doc_annotation"]["entities"] = _doc_to_biluo_tags_with_partial(
example.reference
)
doc = nlp.make_doc(example.text.lower()) doc = nlp.make_doc(example.text.lower())
example_dict["token_annotation"]["ORTH"] = [t.lower_ for t in example.reference] example_dict["token_annotation"]["ORTH"] = [t.lower_ for t in example.reference]
return example.from_dict(doc, example_dict) return example.from_dict(doc, example_dict)
@ -146,6 +152,9 @@ def orth_variants_augmenter(
else: else:
raw_text = example.text raw_text = example.text
orig_dict = example.to_dict() orig_dict = example.to_dict()
orig_dict["doc_annotation"]["entities"] = _doc_to_biluo_tags_with_partial(
example.reference
)
variant_text, variant_token_annot = make_orth_variants( variant_text, variant_token_annot = make_orth_variants(
nlp, nlp,
raw_text, raw_text,
@ -248,6 +257,9 @@ def make_whitespace_variant(
RETURNS (Example): Example with one additional space token. RETURNS (Example): Example with one additional space token.
""" """
example_dict = example.to_dict() example_dict = example.to_dict()
example_dict["doc_annotation"]["entities"] = _doc_to_biluo_tags_with_partial(
example.reference
)
doc_dict = example_dict.get("doc_annotation", {}) doc_dict = example_dict.get("doc_annotation", {})
token_dict = example_dict.get("token_annotation", {}) token_dict = example_dict.get("token_annotation", {})
# returned unmodified if: # returned unmodified if:

View File

@ -60,6 +60,14 @@ def doc_to_biluo_tags(doc: Doc, missing: str = "O"):
) )
def _doc_to_biluo_tags_with_partial(doc: Doc) -> List[str]:
    """Create per-token BILUO tags that distinguish missing annotation ("-")
    from explicit outside annotation ("O"), rather than collapsing both into
    a single missing value.

    doc (Doc): The doc to convert.
    RETURNS (List[str]): BILUO tags where tokens with no entity annotation
        are "-" and tokens explicitly marked outside an entity are "O".
    """
    tags = doc_to_biluo_tags(doc, missing="-")
    # ent_iob == 2 is the explicit "O" (outside) value; restore it over the
    # "-" placeholder produced above.
    return ["O" if token.ent_iob == 2 else tag for token, tag in zip(doc, tags)]
def offsets_to_biluo_tags( def offsets_to_biluo_tags(
doc: Doc, entities: Iterable[Tuple[int, int, Union[str, int]]], missing: str = "O" doc: Doc, entities: Iterable[Tuple[int, int, Union[str, int]]], missing: str = "O"
) -> List[str]: ) -> List[str]: