Add whitespace and combined augmenters (#10170)
Add a whitespace augmenter that inserts a single whitespace token into a doc containing the annotation used in core trained pipelines. Add a combined augmenter that handles lowercasing, orth variants, and whitespace augmentation.
parent aa93b471a1
commit 28ba31e793
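
For orientation, here is a minimal sketch (not part of the commit) of how the combined augmenter registered below might be created and attached to a corpus reader; the orth-variant dict, the rates, and the corpus path are all invented for illustration:

# Sketch: wiring up the combined augmenter (values and path are invented).
from spacy.lang.en import English
from spacy.training import Corpus
from spacy.training.augment import create_combined_augmenter

augmenter = create_combined_augmenter(
    lower_level=0.1,            # lowercase ~10% of texts
    orth_level=0.1,             # apply orth variants to ~10% of texts
    orth_variants={"single": [{"tags": ["NFP"], "variants": ["…", "..."]}]},
    whitespace_level=0.1,       # insert whitespace into ~10% of texts
    whitespace_per_token=0.05,  # ~5% of the doc length in space tokens
    whitespace_variants=[" ", "\t", "\n"],
)
corpus = Corpus("train.spacy", augmenter=augmenter)  # hypothetical path
nlp = English()
augmented_examples = list(corpus(nlp))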
spacy/tests/training/test_augmenters.py
@@ -1,9 +1,11 @@
 import pytest
-from spacy.training import Corpus
+from spacy.pipeline._parser_internals.nonproj import contains_cycle
+from spacy.training import Corpus, Example
 from spacy.training.augment import create_orth_variants_augmenter
 from spacy.training.augment import create_lower_casing_augmenter
+from spacy.training.augment import make_whitespace_variant
 from spacy.lang.en import English
-from spacy.tokens import DocBin, Doc
+from spacy.tokens import DocBin, Doc, Span
 from contextlib import contextmanager
 import random

@@ -153,3 +155,84 @@ def test_custom_data_augmentation(nlp, doc):
     ents = [(e.start, e.end, e.label) for e in doc.ents]
     assert [(e.start, e.end, e.label) for e in corpus[0].reference.ents] == ents
     assert [(e.start, e.end, e.label) for e in corpus[1].reference.ents] == ents
+
+
+def test_make_whitespace_variant(nlp):
+    # fmt: off
+    text = "They flew to New York City.\nThen they drove to Washington, D.C."
+    words = ["They", "flew", "to", "New", "York", "City", ".", "\n", "Then", "they", "drove", "to", "Washington", ",", "D.C."]
+    spaces = [True, True, True, True, True, False, False, False, True, True, True, True, False, True, False]
+    tags = ["PRP", "VBD", "IN", "NNP", "NNP", "NNP", ".", "_SP", "RB", "PRP", "VBD", "IN", "NNP", ",", "NNP"]
+    lemmas = ["they", "fly", "to", "New", "York", "City", ".", "\n", "then", "they", "drive", "to", "Washington", ",", "D.C."]
+    heads = [1, 1, 1, 4, 5, 2, 1, 10, 10, 10, 10, 10, 11, 12, 12]
+    deps = ["nsubj", "ROOT", "prep", "compound", "compound", "pobj", "punct", "dep", "advmod", "nsubj", "ROOT", "prep", "pobj", "punct", "appos"]
+    ents = ["O", "O", "O", "B-GPE", "I-GPE", "I-GPE", "O", "O", "O", "O", "O", "O", "B-GPE", "O", "B-GPE"]
+    # fmt: on
+    doc = Doc(
+        nlp.vocab,
+        words=words,
+        spaces=spaces,
+        tags=tags,
+        lemmas=lemmas,
+        heads=heads,
+        deps=deps,
+        ents=ents,
+    )
+    assert doc.text == text
+    example = Example(nlp.make_doc(text), doc)
+    # whitespace is only added internally in entity spans
+    mod_ex = make_whitespace_variant(nlp, example, " ", 3)
+    assert mod_ex.reference.ents[0].text == "New York City"
+    mod_ex = make_whitespace_variant(nlp, example, " ", 4)
+    assert mod_ex.reference.ents[0].text == "New  York City"
+    mod_ex = make_whitespace_variant(nlp, example, " ", 5)
+    assert mod_ex.reference.ents[0].text == "New York  City"
+    mod_ex = make_whitespace_variant(nlp, example, " ", 6)
+    assert mod_ex.reference.ents[0].text == "New York City"
+    # add a space at every possible position
+    for i in range(len(doc) + 1):
+        mod_ex = make_whitespace_variant(nlp, example, " ", i)
+        assert mod_ex.reference[i].is_space
+        # adds annotation when the doc contains at least partial annotation
+        assert [t.tag_ for t in mod_ex.reference] == tags[:i] + ["_SP"] + tags[i:]
+        assert [t.lemma_ for t in mod_ex.reference] == lemmas[:i] + [" "] + lemmas[i:]
+        assert [t.dep_ for t in mod_ex.reference] == deps[:i] + ["dep"] + deps[i:]
+        # does not add partial annotation if doc does not contain this feature
+        assert not mod_ex.reference.has_annotation("POS")
+        assert not mod_ex.reference.has_annotation("MORPH")
+        # produces well-formed trees
+        assert not contains_cycle([t.head.i for t in mod_ex.reference])
+        assert len(list(doc.sents)) == 2
+        if i == 0:
+            assert mod_ex.reference[i].head.i == 1
+        else:
+            assert mod_ex.reference[i].head.i == i - 1
+        # adding another space also produces well-formed trees
+        for j in (3, 8, 10):
+            mod_ex2 = make_whitespace_variant(nlp, mod_ex, "\t\t\n", j)
+            assert not contains_cycle([t.head.i for t in mod_ex2.reference])
+            assert len(list(doc.sents)) == 2
+            assert mod_ex2.reference[j].head.i == j - 1
+        # entities are well-formed
+        assert len(doc.ents) == len(mod_ex.reference.ents)
+        for ent in mod_ex.reference.ents:
+            assert not ent[0].is_space
+            assert not ent[-1].is_space
+
+    # no modifications if:
+    # - partial dependencies
+    example.reference[0].dep_ = ""
+    mod_ex = make_whitespace_variant(nlp, example, " ", 5)
+    assert mod_ex.text == example.reference.text
+    example.reference[0].dep_ = "nsubj"  # reset
+
+    # - spans
+    example.reference.spans["spans"] = [example.reference[0:5]]
+    mod_ex = make_whitespace_variant(nlp, example, " ", 5)
+    assert mod_ex.text == example.reference.text
+    del example.reference.spans["spans"]  # reset
+
+    # - links
+    example.reference.ents = [Span(doc, 0, 2, label="ENT", kb_id="Q123")]
+    mod_ex = make_whitespace_variant(nlp, example, " ", 5)
+    assert mod_ex.text == example.reference.text
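
The loop above checks every offset; for a single call, the behavior of make_whitespace_variant looks like this (a sketch with invented tokens and tags, assuming the commit is applied):

# Sketch: one direct call to make_whitespace_variant (tokens/tags invented).
from spacy.lang.en import English
from spacy.tokens import Doc
from spacy.training import Example
from spacy.training.augment import make_whitespace_variant

nlp = English()
doc = Doc(nlp.vocab, words=["Hello", "world", "!"],
          spaces=[True, False, False], tags=["UH", "NN", "."])
example = Example(nlp.make_doc(doc.text), doc)
mod_ex = make_whitespace_variant(nlp, example, "  ", 1)
# The space token is inserted with the "_SP" tag; annotation layers the
# doc never had (lemmas, deps, ...) are dropped rather than partially filled.
assert [t.text for t in mod_ex.reference] == ["Hello", "  ", "world", "!"]
assert [t.tag_ for t in mod_ex.reference] == ["UH", "_SP", "NN", "."]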

spacy/training/augment.py
@@ -1,4 +1,5 @@
 from typing import Callable, Iterator, Dict, List, Tuple, TYPE_CHECKING
+from typing import Optional
 import random
 import itertools
 from functools import partial
@@ -11,32 +12,87 @@ if TYPE_CHECKING:
     from ..language import Language  # noqa: F401


-class OrthVariantsSingle(BaseModel):
-    tags: List[StrictStr]
-    variants: List[StrictStr]
+@registry.augmenters("spacy.combined_augmenter.v1")
+def create_combined_augmenter(
+    lower_level: float,
+    orth_level: float,
+    orth_variants: Optional[Dict[str, List[Dict]]],
+    whitespace_level: float,
+    whitespace_per_token: float,
+    whitespace_variants: Optional[List[str]],
+) -> Callable[["Language", Example], Iterator[Example]]:
+    """Create a data augmentation callback that uses orth-variant replacement.
+    The callback can be added to a corpus or other data iterator during training.
+
+    lower_level (float): The percentage of texts that will be lowercased.
+    orth_level (float): The percentage of texts that will be augmented.
+    orth_variants (Optional[Dict[str, List[Dict]]]): A dictionary containing the
+        single and paired orth variants. Typically loaded from a JSON file.
+    whitespace_level (float): The percentage of texts that will have whitespace
+        tokens inserted.
+    whitespace_per_token (float): The number of whitespace tokens to insert in
+        the modified doc as a percentage of the doc length.
+    whitespace_variants (Optional[List[str]]): The whitespace token texts.
+    RETURNS (Callable[[Language, Example], Iterator[Example]]): The augmenter.
+    """
+    return partial(
+        combined_augmenter,
+        lower_level=lower_level,
+        orth_level=orth_level,
+        orth_variants=orth_variants,
+        whitespace_level=whitespace_level,
+        whitespace_per_token=whitespace_per_token,
+        whitespace_variants=whitespace_variants,
+    )


-class OrthVariantsPaired(BaseModel):
-    tags: List[StrictStr]
-    variants: List[List[StrictStr]]
-
-
-class OrthVariants(BaseModel):
-    paired: List[OrthVariantsPaired] = []
-    single: List[OrthVariantsSingle] = []
+def combined_augmenter(
+    nlp: "Language",
+    example: Example,
+    *,
+    lower_level: float = 0.0,
+    orth_level: float = 0.0,
+    orth_variants: Optional[Dict[str, List[Dict]]] = None,
+    whitespace_level: float = 0.0,
+    whitespace_per_token: float = 0.0,
+    whitespace_variants: Optional[List[str]] = None,
+) -> Iterator[Example]:
+    if random.random() < lower_level:
+        example = make_lowercase_variant(nlp, example)
+    if orth_variants and random.random() < orth_level:
+        raw_text = example.text
+        orig_dict = example.to_dict()
+        variant_text, variant_token_annot = make_orth_variants(
+            nlp,
+            raw_text,
+            orig_dict["token_annotation"],
+            orth_variants,
+            lower=False,
+        )
+        orig_dict["token_annotation"] = variant_token_annot
+        example = example.from_dict(nlp.make_doc(variant_text), orig_dict)
+    if whitespace_variants and random.random() < whitespace_level:
+        for _ in range(int(len(example.reference) * whitespace_per_token)):
+            example = make_whitespace_variant(
+                nlp,
+                example,
+                random.choice(whitespace_variants),
+                random.randrange(0, len(example.reference)),
+            )
+    yield example


 @registry.augmenters("spacy.orth_variants.v1")
 def create_orth_variants_augmenter(
-    level: float, lower: float, orth_variants: OrthVariants
+    level: float, lower: float, orth_variants: Dict[str, List[Dict]]
 ) -> Callable[["Language", Example], Iterator[Example]]:
     """Create a data augmentation callback that uses orth-variant replacement.
     The callback can be added to a corpus or other data iterator during training.

     level (float): The percentage of texts that will be augmented.
     lower (float): The percentage of texts that will be lowercased.
-    orth_variants (Dict[str, dict]): A dictionary containing the single and
-        paired orth variants. Typically loaded from a JSON file.
+    orth_variants (Dict[str, List[Dict]]): A dictionary containing
+        the single and paired orth variants. Typically loaded from a JSON file.
     RETURNS (Callable[[Language, Example], Iterator[Example]]): The augmenter.
     """
     return partial(
@@ -67,16 +123,20 @@ def lower_casing_augmenter(
     if random.random() >= level:
         yield example
     else:
-        example_dict = example.to_dict()
-        doc = nlp.make_doc(example.text.lower())
-        example_dict["token_annotation"]["ORTH"] = [t.lower_ for t in example.reference]
-        yield example.from_dict(doc, example_dict)
+        yield make_lowercase_variant(nlp, example)
+
+
+def make_lowercase_variant(nlp: "Language", example: Example):
+    example_dict = example.to_dict()
+    doc = nlp.make_doc(example.text.lower())
+    example_dict["token_annotation"]["ORTH"] = [t.lower_ for t in example.reference]
+    return example.from_dict(doc, example_dict)


 def orth_variants_augmenter(
     nlp: "Language",
     example: Example,
-    orth_variants: Dict,
+    orth_variants: Dict[str, List[Dict]],
     *,
     level: float = 0.0,
     lower: float = 0.0,
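
With the lowercasing logic factored out of the augmenter, the helper can also be called on its own; a quick sketch (example text invented):

# Sketch: using the extracted make_lowercase_variant helper directly.
from spacy.lang.en import English
from spacy.training import Example
from spacy.training.augment import make_lowercase_variant

nlp = English()
doc = nlp.make_doc("They flew to New York.")
example = Example(nlp.make_doc(doc.text), doc)
lowered = make_lowercase_variant(nlp, example)
assert lowered.reference.text == "they flew to new york."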
@@ -148,10 +208,132 @@ def make_orth_variants(
                 pair_idx = pair.index(words[word_idx])
             words[word_idx] = punct_choices[punct_idx][pair_idx]
     token_dict["ORTH"] = words
-    # construct modified raw text from words and spaces
+    raw = construct_modified_raw_text(token_dict)
+    return raw, token_dict
+
+
+def make_whitespace_variant(
+    nlp: "Language",
+    example: Example,
+    whitespace: str,
+    position: int,
+) -> Example:
+    """Insert the whitespace token at the specified token offset in the doc.
+    This is primarily intended for v2-compatible training data that doesn't
+    include links or spans. If the document includes links, spans, or partial
+    dependency annotation, it is returned without modifications.
+
+    The augmentation follows the basics of the v2 space attachment policy, but
+    without a distinction between "real" and other tokens, so space tokens
+    may be attached to space tokens:
+    - at the beginning of a sentence attach the space token to the following
+      token
+    - otherwise attach the space token to the preceding token
+
+    The augmenter does not attempt to consolidate adjacent whitespace in the
+    same way that the tokenizer would.
+
+    The following annotation is used for the space token:
+    TAG: "_SP"
+    MORPH: ""
+    POS: "SPACE"
+    LEMMA: ORTH
+    DEP: "dep"
+    SENT_START: False
+
+    The annotation for each attribute is only set for the space token if there
+    is already at least partial annotation for that attribute in the original
+    example.
+
+    RETURNS (Example): Example with one additional space token.
+    """
+    example_dict = example.to_dict()
+    doc_dict = example_dict.get("doc_annotation", {})
+    token_dict = example_dict.get("token_annotation", {})
+    # returned unmodified if:
+    # - doc is empty
+    # - words are not defined
+    # - links are defined (only character-based offsets, which is more a quirk
+    #   of Example.to_dict than a technical constraint)
+    # - spans are defined
+    # - there are partial dependencies
+    if (
+        len(example.reference) == 0
+        or "ORTH" not in token_dict
+        or len(doc_dict.get("links", [])) > 0
+        or len(example.reference.spans) > 0
+        or (
+            example.reference.has_annotation("DEP")
+            and not example.reference.has_annotation("DEP", require_complete=True)
+        )
+    ):
+        return example
+    words = token_dict.get("ORTH", [])
+    length = len(words)
+    assert 0 <= position <= length
+    if example.reference.has_annotation("ENT_TYPE"):
+        # I-ENTITY if between B/I-ENTITY and I/L-ENTITY otherwise O
+        entity = "O"
+        if position > 1 and position < length:
+            ent_prev = doc_dict["entities"][position - 1]
+            ent_next = doc_dict["entities"][position]
+            if "-" in ent_prev and "-" in ent_next:
+                ent_iob_prev = ent_prev.split("-")[0]
+                ent_type_prev = ent_prev.split("-", 1)[1]
+                ent_iob_next = ent_next.split("-")[0]
+                ent_type_next = ent_next.split("-", 1)[1]
+                if (
+                    ent_iob_prev in ("B", "I")
+                    and ent_iob_next in ("I", "L")
+                    and ent_type_prev == ent_type_next
+                ):
+                    entity = f"I-{ent_type_prev}"
+        doc_dict["entities"].insert(position, entity)
+    else:
+        del doc_dict["entities"]
+    token_dict["ORTH"].insert(position, whitespace)
+    token_dict["SPACY"].insert(position, False)
+    if example.reference.has_annotation("TAG"):
+        token_dict["TAG"].insert(position, "_SP")
+    else:
+        del token_dict["TAG"]
+    if example.reference.has_annotation("LEMMA"):
+        token_dict["LEMMA"].insert(position, whitespace)
+    else:
+        del token_dict["LEMMA"]
+    if example.reference.has_annotation("POS"):
+        token_dict["POS"].insert(position, "SPACE")
+    else:
+        del token_dict["POS"]
+    if example.reference.has_annotation("MORPH"):
+        token_dict["MORPH"].insert(position, "")
+    else:
+        del token_dict["MORPH"]
+    if example.reference.has_annotation("DEP", require_complete=True):
+        if position == 0:
+            token_dict["HEAD"].insert(position, 0)
+        else:
+            token_dict["HEAD"].insert(position, position - 1)
+        for i in range(len(token_dict["HEAD"])):
+            if token_dict["HEAD"][i] >= position:
+                token_dict["HEAD"][i] += 1
+        token_dict["DEP"].insert(position, "dep")
+    else:
+        del token_dict["HEAD"]
+        del token_dict["DEP"]
+    if example.reference.has_annotation("SENT_START"):
+        token_dict["SENT_START"].insert(position, False)
+    else:
+        del token_dict["SENT_START"]
+    raw = construct_modified_raw_text(token_dict)
+    return Example.from_dict(nlp.make_doc(raw), example_dict)
+
+
+def construct_modified_raw_text(token_dict):
+    """Construct modified raw text from words and spaces."""
     raw = ""
     for orth, spacy in zip(token_dict["ORTH"], token_dict["SPACY"]):
         raw += orth
         if spacy:
             raw += " "
-    return raw, token_dict
+    return raw
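
To make the HEAD bookkeeping above concrete, here is a standalone re-statement of the index shifting (a sketch, not from the commit): the space token gets head position - 1 (or 0 at the doc start), and every head index at or past the insertion point is then shifted up by one. At position 0 this also bumps the inserted head from 0 to 1, which is why a leading space token ends up attached to the following token.

# Sketch: the head reindexing performed by make_whitespace_variant,
# restated as a pure function over a list of head indices.
from typing import List

def shift_heads(heads: List[int], position: int) -> List[int]:
    heads = list(heads)
    # attach the space token to the preceding token (0 at the doc start)
    heads.insert(position, 0 if position == 0 else position - 1)
    # shift every head that points at or beyond the insertion point;
    # at position 0 this includes the inserted head itself (0 -> 1)
    return [h + 1 if h >= position else h for h in heads]

assert shift_heads([1, 1, 1], 0) == [1, 2, 2, 2]  # leading space -> next token
assert shift_heads([1, 1, 1], 2) == [1, 1, 1, 1]  # space -> preceding token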