mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
28ba31e793
Add whitespace augmenter that inserts a single whitespace token into a doc containing annotation used in core trained pipelines. Add a combined augmenter that handles lowercasing, orth variants and whitespace augmentation.
340 lines
13 KiB
Python
340 lines
13 KiB
Python
from typing import Callable, Iterator, Dict, List, Tuple, TYPE_CHECKING
|
|
from typing import Optional
|
|
import random
|
|
import itertools
|
|
from functools import partial
|
|
from pydantic import BaseModel, StrictStr
|
|
|
|
from ..util import registry
|
|
from .example import Example
|
|
|
|
if TYPE_CHECKING:
|
|
from ..language import Language # noqa: F401
|
|
|
|
|
|
@registry.augmenters("spacy.combined_augmenter.v1")
def create_combined_augmenter(
    lower_level: float,
    orth_level: float,
    orth_variants: Optional[Dict[str, List[Dict]]],
    whitespace_level: float,
    whitespace_per_token: float,
    whitespace_variants: Optional[List[str]],
) -> Callable[["Language", Example], Iterator[Example]]:
    """Create a data augmentation callback that combines lowercasing,
    orth-variant replacement and whitespace insertion. The callback can be
    added to a corpus or other data iterator during training.

    The returned callable is `combined_augmenter` with all of the settings
    below bound as keyword arguments.

    lower_level (float): The probability that a text is lowercased.
    orth_level (float): The probability that orth variants are applied to a
        text.
    orth_variants (Optional[Dict[str, List[Dict]]]): A dictionary containing
        the single and paired orth variants. Typically loaded from a JSON
        file.
    whitespace_level (float): The probability that whitespace tokens are
        inserted into a text.
    whitespace_per_token (float): The number of whitespace tokens to insert
        in the modified doc, as a fraction of the doc length.
    whitespace_variants (Optional[List[str]]): The whitespace token texts.
    RETURNS (Callable[[Language, Example], Iterator[Example]]): The augmenter.
    """
    settings = {
        "lower_level": lower_level,
        "orth_level": orth_level,
        "orth_variants": orth_variants,
        "whitespace_level": whitespace_level,
        "whitespace_per_token": whitespace_per_token,
        "whitespace_variants": whitespace_variants,
    }
    return partial(combined_augmenter, **settings)
|
|
|
|
|
|
def combined_augmenter(
    nlp: "Language",
    example: Example,
    *,
    lower_level: float = 0.0,
    orth_level: float = 0.0,
    orth_variants: Optional[Dict[str, List[Dict]]] = None,
    whitespace_level: float = 0.0,
    whitespace_per_token: float = 0.0,
    whitespace_variants: Optional[List[str]] = None,
) -> Iterator[Example]:
    """Apply lowercasing, orth-variant replacement and whitespace insertion
    to an example, each stage being triggered independently at random.

    nlp (Language): Pipeline used to tokenize the modified texts.
    example (Example): The example to augment.
    lower_level (float): Probability of lowercasing the example.
    orth_level (float): Probability of applying orth variants. Skipped
        entirely if `orth_variants` is empty or None.
    orth_variants (Optional[Dict[str, List[Dict]]]): Single and paired orth
        variants.
    whitespace_level (float): Probability of inserting whitespace tokens.
        Skipped entirely if `whitespace_variants` is empty or None.
    whitespace_per_token (float): Number of insertions as a fraction of the
        number of tokens in the reference doc (rounded down).
    whitespace_variants (Optional[List[str]]): Candidate whitespace texts.
    YIELDS (Example): Exactly one example, augmented or not.
    """
    # Stage 1: lowercasing
    should_lowercase = random.random() < lower_level
    if should_lowercase:
        example = make_lowercase_variant(nlp, example)

    # Stage 2: orth variants (never lowercases, that was handled in stage 1)
    if orth_variants and random.random() < orth_level:
        text = example.text
        annotations = example.to_dict()
        new_text, new_token_annot = make_orth_variants(
            nlp,
            text,
            annotations["token_annotation"],
            orth_variants,
            lower=False,
        )
        annotations["token_annotation"] = new_token_annot
        example = example.from_dict(nlp.make_doc(new_text), annotations)

    # Stage 3: whitespace tokens, inserted one at a time so that every offset
    # is drawn against the current (already extended) doc length
    if whitespace_variants and random.random() < whitespace_level:
        n_insertions = int(len(example.reference) * whitespace_per_token)
        for _ in range(n_insertions):
            space_text = random.choice(whitespace_variants)
            offset = random.randrange(0, len(example.reference))
            example = make_whitespace_variant(nlp, example, space_text, offset)

    yield example
|
|
|
|
|
|
@registry.augmenters("spacy.orth_variants.v1")
def create_orth_variants_augmenter(
    level: float, lower: float, orth_variants: Dict[str, List[Dict]]
) -> Callable[["Language", Example], Iterator[Example]]:
    """Create a data augmentation callback that uses orth-variant replacement.
    The callback can be added to a corpus or other data iterator during
    training.

    The returned callable is `orth_variants_augmenter` with the settings below
    bound as keyword arguments.

    level (float): The probability that a text is augmented.
    lower (float): The probability that an augmented text is also lowercased.
    orth_variants (Dict[str, List[Dict]]): A dictionary containing the single
        and paired orth variants. Typically loaded from a JSON file.
    RETURNS (Callable[[Language, Example], Iterator[Example]]): The augmenter.
    """
    bound_settings = {
        "orth_variants": orth_variants,
        "level": level,
        "lower": lower,
    }
    return partial(orth_variants_augmenter, **bound_settings)
|
|
|
|
|
|
@registry.augmenters("spacy.lower_case.v1")
def create_lower_casing_augmenter(
    level: float,
) -> Callable[["Language", Example], Iterator[Example]]:
    """Create a data augmentation callback that converts documents to
    lowercase. The callback can be added to a corpus or other data iterator
    during training.

    level (float): The probability that a text is lowercased.
    RETURNS (Callable[[Language, Example], Iterator[Example]]): The augmenter,
        i.e. `lower_casing_augmenter` with `level` bound.
    """
    augmenter = partial(lower_casing_augmenter, level=level)
    return augmenter
|
|
|
|
|
|
def dont_augment(nlp: "Language", example: Example) -> Iterator[Example]:
    """No-op augmenter: yield the example exactly as it was passed in.

    nlp (Language): Unused; present to match the augmenter signature.
    example (Example): The example to pass through.
    YIELDS (Example): The unmodified example.
    """
    yield from (example,)
|
|
|
|
|
|
def lower_casing_augmenter(
    nlp: "Language", example: Example, *, level: float
) -> Iterator[Example]:
    """Yield either the original example or its lowercased variant.

    nlp (Language): Pipeline used to tokenize the lowercased text.
    example (Example): The example to augment.
    level (float): The probability that the example is lowercased.
    YIELDS (Example): Exactly one example.
    """
    # a draw at or above the threshold leaves the example untouched
    keep_original = random.random() >= level
    yield example if keep_original else make_lowercase_variant(nlp, example)
|
|
|
|
|
|
def make_lowercase_variant(nlp: "Language", example: Example):
    """Build a lowercased copy of the example.

    The predicted doc is re-tokenized from the lowercased text, and the ORTH
    annotation is replaced by the lowercase form of every reference token.
    All other annotation from `example.to_dict()` is carried over as is.

    nlp (Language): Pipeline used to tokenize the lowercased text.
    example (Example): The example to convert.
    RETURNS (Example): The lowercased example.
    """
    annotations = example.to_dict()
    lowered_doc = nlp.make_doc(example.text.lower())
    lowered_words = []
    for token in example.reference:
        lowered_words.append(token.lower_)
    annotations["token_annotation"]["ORTH"] = lowered_words
    return example.from_dict(lowered_doc, annotations)
|
|
|
|
|
|
def orth_variants_augmenter(
    nlp: "Language",
    example: Example,
    orth_variants: Dict[str, List[Dict]],
    *,
    level: float = 0.0,
    lower: float = 0.0,
) -> Iterator[Example]:
    """Yield either the original example or an orth-variant version of it.

    nlp (Language): Pipeline used to tokenize the modified text.
    example (Example): The example to augment.
    orth_variants (Dict[str, List[Dict]]): Single and paired orth variants.
    level (float): The probability that the example is augmented.
    lower (float): The probability that an augmented example is additionally
        lowercased. Only considered when the example has a text.
    YIELDS (Example): Exactly one example.
    """
    # guard: a draw at or above the threshold leaves the example untouched
    if random.random() >= level:
        yield example
        return

    text = example.text
    annotations = example.to_dict()
    # the second draw only happens when there is a text to lowercase
    also_lowercase = text is not None and random.random() < lower
    new_text, new_token_annot = make_orth_variants(
        nlp,
        text,
        annotations["token_annotation"],
        orth_variants,
        lower=also_lowercase,
    )
    annotations["token_annotation"] = new_token_annot
    yield example.from_dict(nlp.make_doc(new_text), annotations)
|
|
|
|
|
|
def make_orth_variants(
    nlp: "Language",
    raw: str,
    token_dict: Dict[str, List[str]],
    orth_variants: Dict[str, List[Dict[str, List[str]]]],
    *,
    lower: bool = False,
) -> Tuple[str, Dict[str, List[str]]]:
    """Replace tokens by randomly chosen orthographic variants.

    For every "single" rule one variant is drawn up front and substituted for
    each token whose tag and text match the rule. For every "paired" rule one
    (left, right) pair is drawn up front and the matching side is substituted.
    The raw text is rebuilt from the modified words and their trailing spaces.

    nlp (Language): Not used by this function.
    raw (str): The original text.
    token_dict (Dict[str, List[str]]): Token annotation; "ORTH" and "TAG" are
        read and "ORTH" is overwritten in place.
    orth_variants (Dict): Rules under the keys "single" and "paired", each
        rule providing "tags" and "variants".
    lower (bool): Lowercase the words and the text before matching.
    RETURNS (Tuple[str, Dict[str, List[str]]]): The modified text and the
        (same) token dict. If no words are defined both are returned
        unchanged; if no tags are defined only lowercasing is applied.
    """
    words = token_dict.get("ORTH", [])
    tags = token_dict.get("TAG", [])

    # nothing to do without words
    if not words:
        return raw, token_dict

    if lower:
        words = [word.lower() for word in words]
        raw = raw.lower()

    # variants are keyed on tags, so without tags only lowercasing applies
    if not tags:
        token_dict["ORTH"] = words
        return raw, token_dict

    # single variants: one replacement is drawn per rule, for the whole text
    single_rules = orth_variants.get("single", [])
    single_picks = [random.choice(rule["variants"]) for rule in single_rules]
    for idx in range(len(words)):
        for rule, pick in zip(single_rules, single_picks):
            # words[idx] is re-read on purpose: an earlier rule may already
            # have replaced it
            if tags[idx] in rule["tags"] and words[idx] in rule["variants"]:
                words[idx] = pick

    # paired variants: one (left, right) pair is drawn per rule
    paired_rules = orth_variants.get("paired", [])
    paired_picks = [random.choice(rule["variants"]) for rule in paired_rules]
    for idx in range(len(words)):
        for rule, pick in zip(paired_rules, paired_picks):
            if tags[idx] not in rule["tags"]:
                continue
            if words[idx] not in itertools.chain.from_iterable(rule["variants"]):
                continue
            # fallback: random left vs. right from the pair
            side = random.choice([0, 1])
            if len(rule["tags"]) == 2:
                # best option: rely on paired POS tags like `` / ''
                side = rule["tags"].index(tags[idx])
            else:
                # next best option: rely on the position in the variants
                # (may be ambiguous; the last matching pair wins, so the
                # order of the variants matters)
                for pair in rule["variants"]:
                    if words[idx] in pair:
                        side = pair.index(words[idx])
            words[idx] = pick[side]

    token_dict["ORTH"] = words
    raw = construct_modified_raw_text(token_dict)
    return raw, token_dict
|
|
|
|
|
|
def _whitespace_entity_tag(entities: List[str], position: int, length: int) -> str:
    """Pick the entity tag for a space token inserted at `position`.

    entities (List[str]): The per-token entity tags before the insertion.
    position (int): The token offset at which the space token is inserted.
    length (int): The number of tokens before the insertion.
    RETURNS (str): "I-<type>" if the preceding tag is B-/I- and the following
        tag is I-/L- of the same type, otherwise "O".
    """
    # The space token needs a neighbour on both sides to be inside an entity.
    # `position - 1` is a valid preceding token for every position >= 1, so
    # position 1 has to be included: otherwise an entity starting at token 0
    # would be broken up by an "O" tag.
    if not 0 < position < length:
        return "O"
    ent_prev = entities[position - 1]
    ent_next = entities[position]
    if "-" not in ent_prev or "-" not in ent_next:
        return "O"
    ent_iob_prev, ent_type_prev = ent_prev.split("-", 1)
    ent_iob_next, ent_type_next = ent_next.split("-", 1)
    if (
        ent_iob_prev in ("B", "I")
        and ent_iob_next in ("I", "L")
        and ent_type_prev == ent_type_next
    ):
        return f"I-{ent_type_prev}"
    return "O"


def make_whitespace_variant(
    nlp: "Language",
    example: Example,
    whitespace: str,
    position: int,
) -> Example:
    """Insert the whitespace token at the specified token offset in the doc.
    This is primarily intended for v2-compatible training data that doesn't
    include links or spans. If the document includes links, spans, or partial
    dependency annotation, it is returned without modifications.

    The augmentation follows the basics of the v2 space attachment policy, but
    without a distinction between "real" and other tokens, so space tokens
    may be attached to space tokens:
    - a space token inserted at position 0 is attached to the following token
    - otherwise the space token is attached to the preceding token

    The augmenter does not attempt to consolidate adjacent whitespace in the
    same way that the tokenizer would.

    The following annotation is used for the space token:
    TAG: "_SP"
    MORPH: ""
    POS: "SPACE"
    LEMMA: ORTH
    DEP: "dep"
    SENT_START: False
    ENT: "I-<type>" if inserted between two tokens of the same entity,
        otherwise "O"

    The annotation for each attribute is only set for the space token if there
    is already at least partial annotation for that attribute in the original
    example; otherwise the attribute is dropped from the annotation dict.

    nlp (Language): Pipeline used to tokenize the modified text.
    example (Example): The example to modify.
    whitespace (str): The text of the whitespace token.
    position (int): The token offset of the new token, 0 <= position <= the
        number of tokens.
    RETURNS (Example): Example with one additional space token.
    """
    example_dict = example.to_dict()
    doc_dict = example_dict.get("doc_annotation", {})
    token_dict = example_dict.get("token_annotation", {})
    reference = example.reference
    # returned unmodified if:
    # - doc is empty
    # - words are not defined
    # - links are defined (only character-based offsets, which is more a quirk
    #   of Example.to_dict than a technical constraint)
    # - spans are defined
    # - there are partial dependencies
    if (
        len(reference) == 0
        or "ORTH" not in token_dict
        or len(doc_dict.get("links", [])) > 0
        or len(reference.spans) > 0
        or (
            reference.has_annotation("DEP")
            and not reference.has_annotation("DEP", require_complete=True)
        )
    ):
        return example
    words = token_dict.get("ORTH", [])
    length = len(words)
    assert 0 <= position <= length
    if reference.has_annotation("ENT_TYPE"):
        entities = doc_dict["entities"]
        # the tag is computed from the neighbours before the list is extended
        entity = _whitespace_entity_tag(entities, position, length)
        entities.insert(position, entity)
    else:
        del doc_dict["entities"]
    token_dict["ORTH"].insert(position, whitespace)
    # the inserted token is never followed by a space of its own
    token_dict["SPACY"].insert(position, False)
    # attributes that get a fixed value for the space token if annotated, and
    # are dropped otherwise
    space_values = (
        ("TAG", "_SP"),
        ("LEMMA", whitespace),
        ("POS", "SPACE"),
        ("MORPH", ""),
    )
    for attr, value in space_values:
        if reference.has_annotation(attr):
            token_dict[attr].insert(position, value)
        else:
            del token_dict[attr]
    if reference.has_annotation("DEP", require_complete=True):
        heads = token_dict["HEAD"]
        # Position 0 gets a provisional head of 0, which the shift below turns
        # into 1 (the following token). Any other position attaches to the
        # preceding token, which is not affected by the shift.
        heads.insert(position, 0 if position == 0 else position - 1)
        # every head at or after the insertion point moved one token right
        for i, head in enumerate(heads):
            if head >= position:
                heads[i] = head + 1
        token_dict["DEP"].insert(position, "dep")
    else:
        del token_dict["HEAD"]
        del token_dict["DEP"]
    if reference.has_annotation("SENT_START"):
        token_dict["SENT_START"].insert(position, False)
    else:
        del token_dict["SENT_START"]
    raw = construct_modified_raw_text(token_dict)
    return Example.from_dict(nlp.make_doc(raw), example_dict)
|
|
|
|
|
|
def construct_modified_raw_text(token_dict):
    """Construct modified raw text from words and spaces.

    token_dict (dict): Token annotation with the keys "ORTH" (the token
        texts) and "SPACY" (one flag per token; a truthy flag means the token
        is followed by a single space).
    RETURNS (str): The concatenated text. Tokens without a matching flag (or
        flags without a token) are ignored, as the two lists are zipped.
    """
    # join once instead of growing a string with += inside the loop
    return "".join(
        orth + " " if spacy else orth
        for orth, spacy in zip(token_dict["ORTH"], token_dict["SPACY"])
    )
|