mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
877671e09a
Preserve both `-` and `O` annotation in augmenters rather than relying on `Example.to_dict`'s default support for one option outside of labeled entity spans. This is intended as a temporary workaround for augmenters for v3.4.x. The behavior of `Example` and related IOB utils could be improved in the general case for v3.5.
350 lines
13 KiB
Python
350 lines
13 KiB
Python
from typing import Callable, Iterator, Dict, List, Tuple, TYPE_CHECKING
|
|
from typing import Optional
|
|
import random
|
|
import itertools
|
|
from functools import partial
|
|
|
|
from ..util import registry
|
|
from .example import Example
|
|
from .iob_utils import split_bilu_label, _doc_to_biluo_tags_with_partial
|
|
|
|
if TYPE_CHECKING:
|
|
from ..language import Language # noqa: F401
|
|
|
|
|
|
@registry.augmenters("spacy.combined_augmenter.v1")
def create_combined_augmenter(
    lower_level: float,
    orth_level: float,
    orth_variants: Optional[Dict[str, List[Dict]]],
    whitespace_level: float,
    whitespace_per_token: float,
    whitespace_variants: Optional[List[str]],
) -> Callable[["Language", Example], Iterator[Example]]:
    """Create a data augmentation callback that combines lowercasing,
    orth-variant replacement and whitespace insertion. The callback can be
    added to a corpus or other data iterator during training.

    lower_level (float): The percentage of texts that will be lowercased.
    orth_level (float): The percentage of texts that will get orth-variant
        replacement.
    orth_variants (Optional[Dict[str, List[Dict]]]): A dictionary containing the
        single and paired orth variants. Typically loaded from a JSON file.
    whitespace_level (float): The percentage of texts that will have whitespace
        tokens inserted.
    whitespace_per_token (float): The number of whitespace tokens to insert in
        the modified doc as a percentage of the doc length.
    whitespace_variants (Optional[List[str]]): The whitespace token texts.
    RETURNS (Callable[[Language, Example], Iterator[Example]]): The augmenter.
    """
    # All settings are bound as keyword arguments of `combined_augmenter`.
    settings = {
        "lower_level": lower_level,
        "orth_level": orth_level,
        "orth_variants": orth_variants,
        "whitespace_level": whitespace_level,
        "whitespace_per_token": whitespace_per_token,
        "whitespace_variants": whitespace_variants,
    }
    return partial(combined_augmenter, **settings)
|
|
|
|
|
|
def combined_augmenter(
    nlp: "Language",
    example: Example,
    *,
    lower_level: float = 0.0,
    orth_level: float = 0.0,
    orth_variants: Optional[Dict[str, List[Dict]]] = None,
    whitespace_level: float = 0.0,
    whitespace_per_token: float = 0.0,
    whitespace_variants: Optional[List[str]] = None,
) -> Iterator[Example]:
    """Apply lowercasing, orth-variant replacement and whitespace insertion
    to one example, each step gated by its own random draw, and yield the
    resulting example.

    nlp (Language): Pipeline whose `make_doc` tokenizes any modified text.
    example (Example): The example to augment.
    lower_level (float): Lowercase the example if `random.random()` is below
        this value.
    orth_level (float): Replace orth variants if `orth_variants` is non-empty
        and `random.random()` is below this value.
    orth_variants (Optional[Dict[str, List[Dict]]]): Single and paired orth
        variants.
    whitespace_level (float): Insert whitespace tokens if `whitespace_variants`
        is non-empty and `random.random()` is below this value.
    whitespace_per_token (float): Number of whitespace tokens to insert, as a
        fraction of the reference doc length.
    whitespace_variants (Optional[List[str]]): The whitespace token texts.
    YIELDS (Example): The (possibly) augmented example.
    """
    # Step 1: lowercasing.
    if random.random() < lower_level:
        example = make_lowercase_variant(nlp, example)

    # Step 2: orth variants. The random draw only happens when variants exist.
    if orth_variants and random.random() < orth_level:
        source_text = example.text
        annotations = example.to_dict()
        # Replace the default entity annotation so that both `-` and `O`
        # tags outside of entity spans are preserved.
        annotations["doc_annotation"]["entities"] = _doc_to_biluo_tags_with_partial(
            example.reference
        )
        new_text, new_token_annot = make_orth_variants(
            nlp,
            source_text,
            annotations["token_annotation"],
            orth_variants,
            lower=False,
        )
        annotations["token_annotation"] = new_token_annot
        example = example.from_dict(nlp.make_doc(new_text), annotations)

    # Step 3: whitespace tokens. The number of insertions is fixed up front
    # from the current doc length; each insertion picks the variant first
    # and the position second.
    if whitespace_variants and random.random() < whitespace_level:
        n_insertions = int(len(example.reference) * whitespace_per_token)
        for _ in range(n_insertions):
            ws_text = random.choice(whitespace_variants)
            ws_position = random.randrange(0, len(example.reference))
            example = make_whitespace_variant(nlp, example, ws_text, ws_position)

    yield example
|
|
|
|
|
|
@registry.augmenters("spacy.orth_variants.v1")
def create_orth_variants_augmenter(
    level: float, lower: float, orth_variants: Dict[str, List[Dict]]
) -> Callable[["Language", Example], Iterator[Example]]:
    """Create a data augmentation callback that performs orth-variant
    replacement. The callback can be added to a corpus or other data iterator
    during training.

    level (float): The percentage of texts that will be augmented.
    lower (float): The percentage of texts that will be lowercased.
    orth_variants (Dict[str, List[Dict]]): A dictionary containing
        the single and paired orth variants. Typically loaded from a JSON file.
    RETURNS (Callable[[Language, Example], Iterator[Example]]): The augmenter.
    """
    # Bound as keyword arguments of `orth_variants_augmenter`.
    bound_kwargs = dict(orth_variants=orth_variants, level=level, lower=lower)
    return partial(orth_variants_augmenter, **bound_kwargs)
|
|
|
|
|
|
@registry.augmenters("spacy.lower_case.v1")
def create_lower_casing_augmenter(
    level: float,
) -> Callable[["Language", Example], Iterator[Example]]:
    """Create a data augmentation callback that lowercases documents.
    The callback can be added to a corpus or other data iterator during
    training.

    level (float): The percentage of texts that will be augmented.
    RETURNS (Callable[[Language, Example], Iterator[Example]]): The augmenter.
    """
    augmenter = partial(lower_casing_augmenter, level=level)
    return augmenter
|
|
|
|
|
|
def dont_augment(nlp: "Language", example: Example) -> Iterator[Example]:
    """Pass-through augmenter: yield the given example without changes.

    nlp (Language): Unused.
    example (Example): The example to pass through.
    YIELDS (Example): The same example object.
    """
    yield from (example,)
|
|
|
|
|
|
def lower_casing_augmenter(
    nlp: "Language", example: Example, *, level: float
) -> Iterator[Example]:
    """Yield either the original example or its lowercased variant.

    nlp (Language): Pipeline passed on to `make_lowercase_variant`.
    example (Example): The example to augment.
    level (float): The example is lowercased when `random.random()` falls
        below this value.
    YIELDS (Example): The original or the lowercased example.
    """
    # Guard clause: the draw landed at or above `level`, so leave it as is.
    if random.random() >= level:
        yield example
        return
    yield make_lowercase_variant(nlp, example)
|
|
|
|
|
|
def make_lowercase_variant(nlp: "Language", example: Example):
    """Build a lowercased version of the example.

    The new doc is created by `nlp.make_doc` from the lowercased example
    text; the gold-standard annotation is taken from `example.to_dict()`
    with the token texts replaced by their lowercase forms.

    nlp (Language): Pipeline whose `make_doc` tokenizes the lowercased text.
    example (Example): The example to lowercase.
    RETURNS: The result of `example.from_dict` for the lowercased doc.
    """
    annotations = example.to_dict()
    # Replace the default entity annotation so that both `-` and `O` tags
    # outside of entity spans are preserved.
    annotations["doc_annotation"]["entities"] = _doc_to_biluo_tags_with_partial(
        example.reference
    )
    lowered_doc = nlp.make_doc(example.text.lower())
    lowered_words = []
    for token in example.reference:
        lowered_words.append(token.lower_)
    annotations["token_annotation"]["ORTH"] = lowered_words
    return example.from_dict(lowered_doc, annotations)
|
|
|
|
|
|
def orth_variants_augmenter(
    nlp: "Language",
    example: Example,
    orth_variants: Dict[str, List[Dict]],
    *,
    level: float = 0.0,
    lower: float = 0.0,
) -> Iterator[Example]:
    """Yield either the original example or a variant with orth variants
    replaced (and optionally lowercased).

    nlp (Language): Pipeline whose `make_doc` tokenizes the modified text.
    example (Example): The example to augment.
    orth_variants (Dict[str, List[Dict]]): Single and paired orth variants.
    level (float): The example is augmented when `random.random()` falls
        below this value.
    lower (float): An augmented example is also lowercased when its text is
        not None and a second `random.random()` draw falls below this value.
    YIELDS (Example): The original or the augmented example.
    """
    # Guard clause: the draw landed at or above `level`, so leave it as is.
    if random.random() >= level:
        yield example
        return

    source_text = example.text
    annotations = example.to_dict()
    # Replace the default entity annotation so that both `-` and `O` tags
    # outside of entity spans are preserved.
    annotations["doc_annotation"]["entities"] = _doc_to_biluo_tags_with_partial(
        example.reference
    )
    # The second random draw is skipped entirely when there is no text.
    lowercase_too = source_text is not None and random.random() < lower
    new_text, new_token_annot = make_orth_variants(
        nlp,
        source_text,
        annotations["token_annotation"],
        orth_variants,
        lower=lowercase_too,
    )
    annotations["token_annotation"] = new_token_annot
    yield example.from_dict(nlp.make_doc(new_text), annotations)
|
|
|
|
|
|
def make_orth_variants(
    nlp: "Language",
    raw: str,
    token_dict: Dict[str, List[str]],
    orth_variants: Dict[str, List[Dict[str, List[str]]]],
    *,
    lower: bool = False,
) -> Tuple[str, Dict[str, List[str]]]:
    """Replace token texts with randomly chosen orth variants.

    nlp (Language): Not used by this function.
    raw (str): The raw text of the example.
    token_dict (Dict[str, List[str]]): Token annotation; "ORTH" and "TAG" are
        read, and "ORTH" is overwritten with the modified words.
    orth_variants (Dict[str, List[Dict]]): Variant specs under the optional
        keys "single" and "paired"; each spec has "tags" and "variants".
    lower (bool): Whether to lowercase the words and the raw text first.
    RETURNS (Tuple[str, Dict[str, List[str]]]): The (possibly rebuilt) raw
        text and the same `token_dict` object.
    """
    words = token_dict.get("ORTH", [])
    tags = token_dict.get("TAG", [])
    # Nothing to do without words.
    if not words:
        return raw, token_dict
    if lower:
        words = [word.lower() for word in words]
        raw = raw.lower()
    # Without tags, lowercasing is the only modification that can be applied.
    if not tags:
        token_dict["ORTH"] = words
        return raw, token_dict

    # Single variants: one replacement is drawn per spec up front and used
    # for every matching token. Words are addressed by index and re-read for
    # each spec because an earlier spec may already have replaced them.
    single_specs = orth_variants.get("single", [])
    single_picks = [random.choice(spec["variants"]) for spec in single_specs]
    for word_idx in range(len(words)):
        for spec, pick in zip(single_specs, single_picks):
            if tags[word_idx] not in spec["tags"]:
                continue
            if words[word_idx] in spec["variants"]:
                words[word_idx] = pick

    # Paired variants: one pair is drawn per spec up front.
    paired_specs = orth_variants.get("paired", [])
    paired_picks = [random.choice(spec["variants"]) for spec in paired_specs]
    for word_idx in range(len(words)):
        for spec, picked_pair in zip(paired_specs, paired_picks):
            if tags[word_idx] not in spec["tags"]:
                continue
            if words[word_idx] not in itertools.chain.from_iterable(
                spec["variants"]
            ):
                continue
            # Fallback: random left vs. right member of the pair.
            side = random.choice([0, 1])
            if len(spec["tags"]) == 2:
                # Best option: rely on paired POS tags like `` / ''.
                side = spec["tags"].index(tags[word_idx])
            else:
                # Next best option: rely on the position within the variant
                # pairs (may be ambiguous, so the order of variants matters;
                # the last matching pair wins).
                for pair in spec["variants"]:
                    if words[word_idx] in pair:
                        side = pair.index(words[word_idx])
            words[word_idx] = picked_pair[side]

    token_dict["ORTH"] = words
    return construct_modified_raw_text(token_dict), token_dict
|
|
|
|
|
|
def make_whitespace_variant(
    nlp: "Language",
    example: Example,
    whitespace: str,
    position: int,
) -> Example:
    """Insert the whitespace token at the specified token offset in the doc.
    This is primarily intended for v2-compatible training data that doesn't
    include links or spans. If the document includes links, spans, or partial
    dependency annotation, it is returned without modifications.

    The augmentation follows the basics of the v2 space attachment policy, but
    without a distinction between "real" and other tokens, so space tokens
    may be attached to space tokens:
    - at position 0 attach the space token to the following token
    - otherwise attach the space token to the preceding token

    The augmenter does not attempt to consolidate adjacent whitespace in the
    same way that the tokenizer would.

    The following annotation is used for the space token:
    TAG: "_SP"
    MORPH: ""
    POS: "SPACE"
    LEMMA: ORTH
    DEP: "dep"
    SENT_START: False
    ENT: "I-<type>" if inserted between two tokens of the same entity,
        otherwise "O"

    The annotation for each attribute is only set for the space token if there
    is already at least partial annotation for that attribute in the original
    example (for HEAD/DEP: complete annotation); otherwise the attribute is
    removed from the annotation dict.

    nlp (Language): Pipeline whose `make_doc` tokenizes the modified text.
    example (Example): The example to augment.
    whitespace (str): The text of the whitespace token to insert.
    position (int): The token offset to insert at; must satisfy
        0 <= position <= number of words.
    RETURNS (Example): Example with one additional space token.
    """
    example_dict = example.to_dict()
    # Replace the default entity annotation so that both `-` and `O` tags
    # outside of entity spans are preserved.
    example_dict["doc_annotation"]["entities"] = _doc_to_biluo_tags_with_partial(
        example.reference
    )
    doc_dict = example_dict.get("doc_annotation", {})
    token_dict = example_dict.get("token_annotation", {})
    # returned unmodified if:
    # - doc is empty
    # - words are not defined
    # - links are defined (only character-based offsets, which is more a quirk
    #   of Example.to_dict than a technical constraint)
    # - spans are defined
    # - there are partial dependencies
    if (
        len(example.reference) == 0
        or "ORTH" not in token_dict
        or len(doc_dict.get("links", [])) > 0
        or len(example.reference.spans) > 0
        or (
            example.reference.has_annotation("DEP")
            and not example.reference.has_annotation("DEP", require_complete=True)
        )
    ):
        return example
    words = token_dict.get("ORTH", [])
    length = len(words)
    assert 0 <= position <= length
    if example.reference.has_annotation("ENT_TYPE"):
        # I-ENTITY if between B/I-ENTITY and I/L-ENTITY otherwise O
        entity = "O"
        # The space token has a neighbour on both sides for every position
        # strictly between 0 and length. This includes position 1 (between
        # token 0 and token 1): excluding it would split an entity that
        # starts at the first token into `B-X O L-X`.
        if 0 < position < length:
            ent_prev = doc_dict["entities"][position - 1]
            ent_next = doc_dict["entities"][position]
            if "-" in ent_prev and "-" in ent_next:
                ent_iob_prev, ent_type_prev = split_bilu_label(ent_prev)
                ent_iob_next, ent_type_next = split_bilu_label(ent_next)
                if (
                    ent_iob_prev in ("B", "I")
                    and ent_iob_next in ("I", "L")
                    and ent_type_prev == ent_type_next
                ):
                    entity = f"I-{ent_type_prev}"
        doc_dict["entities"].insert(position, entity)
    else:
        del doc_dict["entities"]
    token_dict["ORTH"].insert(position, whitespace)
    token_dict["SPACY"].insert(position, False)
    # Per-token attributes that take a fixed value for the space token: set
    # the value if the reference has annotation for the attribute, otherwise
    # drop the attribute from the annotation.
    space_token_values = (
        ("TAG", "_SP"),
        ("LEMMA", whitespace),
        ("POS", "SPACE"),
        ("MORPH", ""),
    )
    for attr, value in space_token_values:
        if example.reference.has_annotation(attr):
            token_dict[attr].insert(position, value)
        else:
            del token_dict[attr]
    if example.reference.has_annotation("DEP", require_complete=True):
        heads = token_dict["HEAD"]
        # At position 0 the inserted head (0) is shifted to 1 by the loop
        # below, i.e. the following token; elsewhere the head is the
        # preceding token, which is not shifted.
        heads.insert(position, 0 if position == 0 else position - 1)
        # Heads pointing at or past the insertion point move one to the right.
        for i, head in enumerate(heads):
            if head >= position:
                heads[i] = head + 1
        token_dict["DEP"].insert(position, "dep")
    else:
        del token_dict["HEAD"]
        del token_dict["DEP"]
    if example.reference.has_annotation("SENT_START"):
        token_dict["SENT_START"].insert(position, False)
    else:
        del token_dict["SENT_START"]
    raw = construct_modified_raw_text(token_dict)
    return Example.from_dict(nlp.make_doc(raw), example_dict)
|
|
|
|
|
|
def construct_modified_raw_text(token_dict):
    """Construct modified raw text from words and spaces.

    token_dict (dict): Token annotation with the word texts under "ORTH" and
        the trailing-space flags under "SPACY". The two lists are paired with
        `zip`, so surplus items in the longer list are ignored.
    RETURNS (str): Each word followed by a single space where its flag is
        truthy.
    """
    # Build the text with one join instead of repeated string concatenation.
    return "".join(
        orth + " " if spacy else orth
        for orth, spacy in zip(token_dict["ORTH"], token_dict["SPACY"])
    )
|