mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Preserve missing entity annotation in augmenters (#11540)
Preserve both `-` (missing) and `O` (outside) entity annotation in augmenters, rather than relying on `Example.to_dict`, which by default supports only one of these values for tokens outside of labeled entity spans. This is intended as a temporary workaround for augmenters in v3.4.x. The behavior of `Example` and the related IOB utils could be improved for the general case in v3.5.
This commit is contained in:
parent
936a5f0506
commit
877671e09a
|
@ -31,7 +31,7 @@ def doc(nlp):
|
||||||
words = ["Sarah", "'s", "sister", "flew", "to", "Silicon", "Valley", "via", "London", "."]
|
words = ["Sarah", "'s", "sister", "flew", "to", "Silicon", "Valley", "via", "London", "."]
|
||||||
tags = ["NNP", "POS", "NN", "VBD", "IN", "NNP", "NNP", "IN", "NNP", "."]
|
tags = ["NNP", "POS", "NN", "VBD", "IN", "NNP", "NNP", "IN", "NNP", "."]
|
||||||
pos = ["PROPN", "PART", "NOUN", "VERB", "ADP", "PROPN", "PROPN", "ADP", "PROPN", "PUNCT"]
|
pos = ["PROPN", "PART", "NOUN", "VERB", "ADP", "PROPN", "PROPN", "ADP", "PROPN", "PUNCT"]
|
||||||
ents = ["B-PERSON", "I-PERSON", "O", "O", "O", "B-LOC", "I-LOC", "O", "B-GPE", "O"]
|
ents = ["B-PERSON", "I-PERSON", "O", "", "O", "B-LOC", "I-LOC", "O", "B-GPE", "O"]
|
||||||
cats = {"TRAVEL": 1.0, "BAKING": 0.0}
|
cats = {"TRAVEL": 1.0, "BAKING": 0.0}
|
||||||
# fmt: on
|
# fmt: on
|
||||||
doc = Doc(nlp.vocab, words=words, tags=tags, pos=pos, ents=ents)
|
doc = Doc(nlp.vocab, words=words, tags=tags, pos=pos, ents=ents)
|
||||||
|
@ -106,6 +106,7 @@ def test_lowercase_augmenter(nlp, doc):
|
||||||
assert [(e.start, e.end, e.label) for e in eg.reference.ents] == ents
|
assert [(e.start, e.end, e.label) for e in eg.reference.ents] == ents
|
||||||
for ref_ent, orig_ent in zip(eg.reference.ents, doc.ents):
|
for ref_ent, orig_ent in zip(eg.reference.ents, doc.ents):
|
||||||
assert ref_ent.text == orig_ent.text.lower()
|
assert ref_ent.text == orig_ent.text.lower()
|
||||||
|
assert [t.ent_iob for t in doc] == [t.ent_iob for t in eg.reference]
|
||||||
assert [t.pos_ for t in eg.reference] == [t.pos_ for t in doc]
|
assert [t.pos_ for t in eg.reference] == [t.pos_ for t in doc]
|
||||||
|
|
||||||
# check that augmentation works when lowercasing leads to different
|
# check that augmentation works when lowercasing leads to different
|
||||||
|
@ -166,7 +167,7 @@ def test_make_whitespace_variant(nlp):
|
||||||
lemmas = ["they", "fly", "to", "New", "York", "City", ".", "\n", "then", "they", "drive", "to", "Washington", ",", "D.C."]
|
lemmas = ["they", "fly", "to", "New", "York", "City", ".", "\n", "then", "they", "drive", "to", "Washington", ",", "D.C."]
|
||||||
heads = [1, 1, 1, 4, 5, 2, 1, 10, 10, 10, 10, 10, 11, 12, 12]
|
heads = [1, 1, 1, 4, 5, 2, 1, 10, 10, 10, 10, 10, 11, 12, 12]
|
||||||
deps = ["nsubj", "ROOT", "prep", "compound", "compound", "pobj", "punct", "dep", "advmod", "nsubj", "ROOT", "prep", "pobj", "punct", "appos"]
|
deps = ["nsubj", "ROOT", "prep", "compound", "compound", "pobj", "punct", "dep", "advmod", "nsubj", "ROOT", "prep", "pobj", "punct", "appos"]
|
||||||
ents = ["O", "O", "O", "B-GPE", "I-GPE", "I-GPE", "O", "O", "O", "O", "O", "O", "B-GPE", "O", "B-GPE"]
|
ents = ["O", "", "O", "B-GPE", "I-GPE", "I-GPE", "O", "O", "O", "O", "O", "O", "B-GPE", "O", "B-GPE"]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
doc = Doc(
|
doc = Doc(
|
||||||
nlp.vocab,
|
nlp.vocab,
|
||||||
|
@ -215,6 +216,8 @@ def test_make_whitespace_variant(nlp):
|
||||||
assert mod_ex2.reference[j].head.i == j - 1
|
assert mod_ex2.reference[j].head.i == j - 1
|
||||||
# entities are well-formed
|
# entities are well-formed
|
||||||
assert len(doc.ents) == len(mod_ex.reference.ents)
|
assert len(doc.ents) == len(mod_ex.reference.ents)
|
||||||
|
# there is one token with missing entity information
|
||||||
|
assert any(t.ent_iob == 0 for t in mod_ex.reference)
|
||||||
for ent in mod_ex.reference.ents:
|
for ent in mod_ex.reference.ents:
|
||||||
assert not ent[0].is_space
|
assert not ent[0].is_space
|
||||||
assert not ent[-1].is_space
|
assert not ent[-1].is_space
|
||||||
|
|
|
@ -6,7 +6,7 @@ from functools import partial
|
||||||
|
|
||||||
from ..util import registry
|
from ..util import registry
|
||||||
from .example import Example
|
from .example import Example
|
||||||
from .iob_utils import split_bilu_label
|
from .iob_utils import split_bilu_label, _doc_to_biluo_tags_with_partial
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..language import Language # noqa: F401
|
from ..language import Language # noqa: F401
|
||||||
|
@ -62,6 +62,9 @@ def combined_augmenter(
|
||||||
if orth_variants and random.random() < orth_level:
|
if orth_variants and random.random() < orth_level:
|
||||||
raw_text = example.text
|
raw_text = example.text
|
||||||
orig_dict = example.to_dict()
|
orig_dict = example.to_dict()
|
||||||
|
orig_dict["doc_annotation"]["entities"] = _doc_to_biluo_tags_with_partial(
|
||||||
|
example.reference
|
||||||
|
)
|
||||||
variant_text, variant_token_annot = make_orth_variants(
|
variant_text, variant_token_annot = make_orth_variants(
|
||||||
nlp,
|
nlp,
|
||||||
raw_text,
|
raw_text,
|
||||||
|
@ -128,6 +131,9 @@ def lower_casing_augmenter(
|
||||||
|
|
||||||
def make_lowercase_variant(nlp: "Language", example: Example):
|
def make_lowercase_variant(nlp: "Language", example: Example):
|
||||||
example_dict = example.to_dict()
|
example_dict = example.to_dict()
|
||||||
|
example_dict["doc_annotation"]["entities"] = _doc_to_biluo_tags_with_partial(
|
||||||
|
example.reference
|
||||||
|
)
|
||||||
doc = nlp.make_doc(example.text.lower())
|
doc = nlp.make_doc(example.text.lower())
|
||||||
example_dict["token_annotation"]["ORTH"] = [t.lower_ for t in example.reference]
|
example_dict["token_annotation"]["ORTH"] = [t.lower_ for t in example.reference]
|
||||||
return example.from_dict(doc, example_dict)
|
return example.from_dict(doc, example_dict)
|
||||||
|
@ -146,6 +152,9 @@ def orth_variants_augmenter(
|
||||||
else:
|
else:
|
||||||
raw_text = example.text
|
raw_text = example.text
|
||||||
orig_dict = example.to_dict()
|
orig_dict = example.to_dict()
|
||||||
|
orig_dict["doc_annotation"]["entities"] = _doc_to_biluo_tags_with_partial(
|
||||||
|
example.reference
|
||||||
|
)
|
||||||
variant_text, variant_token_annot = make_orth_variants(
|
variant_text, variant_token_annot = make_orth_variants(
|
||||||
nlp,
|
nlp,
|
||||||
raw_text,
|
raw_text,
|
||||||
|
@ -248,6 +257,9 @@ def make_whitespace_variant(
|
||||||
RETURNS (Example): Example with one additional space token.
|
RETURNS (Example): Example with one additional space token.
|
||||||
"""
|
"""
|
||||||
example_dict = example.to_dict()
|
example_dict = example.to_dict()
|
||||||
|
example_dict["doc_annotation"]["entities"] = _doc_to_biluo_tags_with_partial(
|
||||||
|
example.reference
|
||||||
|
)
|
||||||
doc_dict = example_dict.get("doc_annotation", {})
|
doc_dict = example_dict.get("doc_annotation", {})
|
||||||
token_dict = example_dict.get("token_annotation", {})
|
token_dict = example_dict.get("token_annotation", {})
|
||||||
# returned unmodified if:
|
# returned unmodified if:
|
||||||
|
|
|
@ -60,6 +60,14 @@ def doc_to_biluo_tags(doc: Doc, missing: str = "O"):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _doc_to_biluo_tags_with_partial(doc: Doc) -> List[str]:
    """Get per-token BILUO tags for a doc, keeping "-" and "O" distinct.

    The tags are first generated with "-" as the `missing` value, then every
    token whose `ent_iob` is 2 is explicitly tagged "O". Tokens that are
    neither part of an entity span nor marked with `ent_iob == 2` keep "-".

    doc (Doc): The doc to read the entity annotation from.
    RETURNS (List[str]): One tag per token in the doc.
    """
    partial_tags = doc_to_biluo_tags(doc, missing="-")
    # Build the result in one pass instead of patching the list by index:
    # ent_iob == 2 overrides whatever tag was generated for that token.
    return [
        "O" if token.ent_iob == 2 else tag
        for token, tag in zip(doc, partial_tags)
    ]
|
||||||
|
|
||||||
|
|
||||||
def offsets_to_biluo_tags(
|
def offsets_to_biluo_tags(
|
||||||
doc: Doc, entities: Iterable[Tuple[int, int, Union[str, int]]], missing: str = "O"
|
doc: Doc, entities: Iterable[Tuple[int, int, Union[str, int]]], missing: str = "O"
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user