mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 02:36:32 +03:00
Fix lowercase augmentation (#7336)
* Fix aborted/skipped augmentation for `spacy.orth_variants.v1` if lowercasing was enabled for an example * Simplify `spacy.orth_variants.v1` for `Example` vs. `GoldParse` * Preserve reference tokenization in `spacy.lower_case.v1`
This commit is contained in:
parent
cd70c3cb79
commit
3f3e8110dc
|
@ -38,19 +38,59 @@ def doc(nlp):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.filterwarnings("ignore::UserWarning")
|
@pytest.mark.filterwarnings("ignore::UserWarning")
|
||||||
def test_make_orth_variants(nlp, doc):
|
def test_make_orth_variants(nlp):
|
||||||
single = [
|
single = [
|
||||||
{"tags": ["NFP"], "variants": ["…", "..."]},
|
{"tags": ["NFP"], "variants": ["…", "..."]},
|
||||||
{"tags": [":"], "variants": ["-", "—", "–", "--", "---", "——"]},
|
{"tags": [":"], "variants": ["-", "—", "–", "--", "---", "——"]},
|
||||||
]
|
]
|
||||||
|
# fmt: off
|
||||||
|
words = ["\n\n", "A", "\t", "B", "a", "b", "…", "...", "-", "—", "–", "--", "---", "——"]
|
||||||
|
tags = ["_SP", "NN", "\t", "NN", "NN", "NN", "NFP", "NFP", ":", ":", ":", ":", ":", ":"]
|
||||||
|
# fmt: on
|
||||||
|
spaces = [True] * len(words)
|
||||||
|
spaces[0] = False
|
||||||
|
spaces[2] = False
|
||||||
|
doc = Doc(nlp.vocab, words=words, spaces=spaces, tags=tags)
|
||||||
augmenter = create_orth_variants_augmenter(
|
augmenter = create_orth_variants_augmenter(
|
||||||
level=0.2, lower=0.5, orth_variants={"single": single}
|
level=0.2, lower=0.5, orth_variants={"single": single}
|
||||||
)
|
)
|
||||||
with make_docbin([doc]) as output_file:
|
with make_docbin([doc] * 10) as output_file:
|
||||||
reader = Corpus(output_file, augmenter=augmenter)
|
reader = Corpus(output_file, augmenter=augmenter)
|
||||||
# Due to randomness, only test that it works without errors for now
|
# Due to randomness, only test that it works without errors
|
||||||
list(reader(nlp))
|
list(reader(nlp))
|
||||||
|
|
||||||
|
# check that the following settings lowercase everything
|
||||||
|
augmenter = create_orth_variants_augmenter(
|
||||||
|
level=1.0, lower=1.0, orth_variants={"single": single}
|
||||||
|
)
|
||||||
|
with make_docbin([doc] * 10) as output_file:
|
||||||
|
reader = Corpus(output_file, augmenter=augmenter)
|
||||||
|
for example in reader(nlp):
|
||||||
|
for token in example.reference:
|
||||||
|
assert token.text == token.text.lower()
|
||||||
|
|
||||||
|
# check that lowercasing is applied without tags
|
||||||
|
doc = Doc(nlp.vocab, words=words, spaces=[True] * len(words))
|
||||||
|
augmenter = create_orth_variants_augmenter(
|
||||||
|
level=1.0, lower=1.0, orth_variants={"single": single}
|
||||||
|
)
|
||||||
|
with make_docbin([doc] * 10) as output_file:
|
||||||
|
reader = Corpus(output_file, augmenter=augmenter)
|
||||||
|
for example in reader(nlp):
|
||||||
|
for ex_token, doc_token in zip(example.reference, doc):
|
||||||
|
assert ex_token.text == doc_token.text.lower()
|
||||||
|
|
||||||
|
# check that no lowercasing is applied with lower=0.0
|
||||||
|
doc = Doc(nlp.vocab, words=words, spaces=[True] * len(words))
|
||||||
|
augmenter = create_orth_variants_augmenter(
|
||||||
|
level=1.0, lower=0.0, orth_variants={"single": single}
|
||||||
|
)
|
||||||
|
with make_docbin([doc] * 10) as output_file:
|
||||||
|
reader = Corpus(output_file, augmenter=augmenter)
|
||||||
|
for example in reader(nlp):
|
||||||
|
for ex_token, doc_token in zip(example.reference, doc):
|
||||||
|
assert ex_token.text == doc_token.text
|
||||||
|
|
||||||
|
|
||||||
def test_lowercase_augmenter(nlp, doc):
|
def test_lowercase_augmenter(nlp, doc):
|
||||||
augmenter = create_lower_casing_augmenter(level=1.0)
|
augmenter = create_lower_casing_augmenter(level=1.0)
|
||||||
|
@ -66,6 +106,21 @@ def test_lowercase_augmenter(nlp, doc):
|
||||||
assert ref_ent.text == orig_ent.text.lower()
|
assert ref_ent.text == orig_ent.text.lower()
|
||||||
assert [t.pos_ for t in eg.reference] == [t.pos_ for t in doc]
|
assert [t.pos_ for t in eg.reference] == [t.pos_ for t in doc]
|
||||||
|
|
||||||
|
# check that augmentation works when lowercasing leads to different
|
||||||
|
# predicted tokenization
|
||||||
|
words = ["A", "B", "CCC."]
|
||||||
|
doc = Doc(nlp.vocab, words=words)
|
||||||
|
with make_docbin([doc]) as output_file:
|
||||||
|
reader = Corpus(output_file, augmenter=augmenter)
|
||||||
|
corpus = list(reader(nlp))
|
||||||
|
eg = corpus[0]
|
||||||
|
assert eg.reference.text == doc.text.lower()
|
||||||
|
assert eg.predicted.text == doc.text.lower()
|
||||||
|
assert [t.text for t in eg.reference] == [t.lower() for t in words]
|
||||||
|
assert [t.text for t in eg.predicted] == [
|
||||||
|
t.text for t in nlp.make_doc(doc.text.lower())
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.filterwarnings("ignore::UserWarning")
|
@pytest.mark.filterwarnings("ignore::UserWarning")
|
||||||
def test_custom_data_augmentation(nlp, doc):
|
def test_custom_data_augmentation(nlp, doc):
|
||||||
|
|
|
@ -1,12 +1,10 @@
|
||||||
from typing import Callable, Iterator, Dict, List, Tuple, TYPE_CHECKING
|
from typing import Callable, Iterator, Dict, List, Tuple, TYPE_CHECKING
|
||||||
import random
|
import random
|
||||||
import itertools
|
import itertools
|
||||||
import copy
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pydantic import BaseModel, StrictStr
|
from pydantic import BaseModel, StrictStr
|
||||||
|
|
||||||
from ..util import registry
|
from ..util import registry
|
||||||
from ..tokens import Doc
|
|
||||||
from .example import Example
|
from .example import Example
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -71,7 +69,7 @@ def lower_casing_augmenter(
|
||||||
else:
|
else:
|
||||||
example_dict = example.to_dict()
|
example_dict = example.to_dict()
|
||||||
doc = nlp.make_doc(example.text.lower())
|
doc = nlp.make_doc(example.text.lower())
|
||||||
example_dict["token_annotation"]["ORTH"] = [t.lower_ for t in doc]
|
example_dict["token_annotation"]["ORTH"] = [t.lower_ for t in example.reference]
|
||||||
yield example.from_dict(doc, example_dict)
|
yield example.from_dict(doc, example_dict)
|
||||||
|
|
||||||
|
|
||||||
|
@ -88,9 +86,6 @@ def orth_variants_augmenter(
|
||||||
else:
|
else:
|
||||||
raw_text = example.text
|
raw_text = example.text
|
||||||
orig_dict = example.to_dict()
|
orig_dict = example.to_dict()
|
||||||
if not orig_dict["token_annotation"]:
|
|
||||||
yield example
|
|
||||||
else:
|
|
||||||
variant_text, variant_token_annot = make_orth_variants(
|
variant_text, variant_token_annot = make_orth_variants(
|
||||||
nlp,
|
nlp,
|
||||||
raw_text,
|
raw_text,
|
||||||
|
@ -98,14 +93,8 @@ def orth_variants_augmenter(
|
||||||
orth_variants,
|
orth_variants,
|
||||||
lower=raw_text is not None and random.random() < lower,
|
lower=raw_text is not None and random.random() < lower,
|
||||||
)
|
)
|
||||||
if variant_text:
|
|
||||||
doc = nlp.make_doc(variant_text)
|
|
||||||
else:
|
|
||||||
doc = Doc(nlp.vocab, words=variant_token_annot["ORTH"])
|
|
||||||
variant_token_annot["ORTH"] = [w.text for w in doc]
|
|
||||||
variant_token_annot["SPACY"] = [w.whitespace_ for w in doc]
|
|
||||||
orig_dict["token_annotation"] = variant_token_annot
|
orig_dict["token_annotation"] = variant_token_annot
|
||||||
yield example.from_dict(doc, orig_dict)
|
yield example.from_dict(nlp.make_doc(variant_text), orig_dict)
|
||||||
|
|
||||||
|
|
||||||
def make_orth_variants(
|
def make_orth_variants(
|
||||||
|
@ -116,16 +105,20 @@ def make_orth_variants(
|
||||||
*,
|
*,
|
||||||
lower: bool = False,
|
lower: bool = False,
|
||||||
) -> Tuple[str, Dict[str, List[str]]]:
|
) -> Tuple[str, Dict[str, List[str]]]:
|
||||||
orig_token_dict = copy.deepcopy(token_dict)
|
|
||||||
ndsv = orth_variants.get("single", [])
|
|
||||||
ndpv = orth_variants.get("paired", [])
|
|
||||||
words = token_dict.get("ORTH", [])
|
words = token_dict.get("ORTH", [])
|
||||||
tags = token_dict.get("TAG", [])
|
tags = token_dict.get("TAG", [])
|
||||||
# keep unmodified if words or tags are not defined
|
# keep unmodified if words are not defined
|
||||||
if words and tags:
|
if not words:
|
||||||
|
return raw, token_dict
|
||||||
if lower:
|
if lower:
|
||||||
words = [w.lower() for w in words]
|
words = [w.lower() for w in words]
|
||||||
|
raw = raw.lower()
|
||||||
|
# if no tags, only lowercase
|
||||||
|
if not tags:
|
||||||
|
token_dict["ORTH"] = words
|
||||||
|
return raw, token_dict
|
||||||
# single variants
|
# single variants
|
||||||
|
ndsv = orth_variants.get("single", [])
|
||||||
punct_choices = [random.choice(x["variants"]) for x in ndsv]
|
punct_choices = [random.choice(x["variants"]) for x in ndsv]
|
||||||
for word_idx in range(len(words)):
|
for word_idx in range(len(words)):
|
||||||
for punct_idx in range(len(ndsv)):
|
for punct_idx in range(len(ndsv)):
|
||||||
|
@ -135,6 +128,7 @@ def make_orth_variants(
|
||||||
):
|
):
|
||||||
words[word_idx] = punct_choices[punct_idx]
|
words[word_idx] = punct_choices[punct_idx]
|
||||||
# paired variants
|
# paired variants
|
||||||
|
ndpv = orth_variants.get("paired", [])
|
||||||
punct_choices = [random.choice(x["variants"]) for x in ndpv]
|
punct_choices = [random.choice(x["variants"]) for x in ndpv]
|
||||||
for word_idx in range(len(words)):
|
for word_idx in range(len(words)):
|
||||||
for punct_idx in range(len(ndpv)):
|
for punct_idx in range(len(ndpv)):
|
||||||
|
@ -154,50 +148,10 @@ def make_orth_variants(
|
||||||
pair_idx = pair.index(words[word_idx])
|
pair_idx = pair.index(words[word_idx])
|
||||||
words[word_idx] = punct_choices[punct_idx][pair_idx]
|
words[word_idx] = punct_choices[punct_idx][pair_idx]
|
||||||
token_dict["ORTH"] = words
|
token_dict["ORTH"] = words
|
||||||
token_dict["TAG"] = tags
|
# construct modified raw text from words and spaces
|
||||||
# modify raw
|
raw = ""
|
||||||
if raw is not None:
|
for orth, spacy in zip(token_dict["ORTH"], token_dict["SPACY"]):
|
||||||
variants = []
|
raw += orth
|
||||||
for single_variants in ndsv:
|
if spacy:
|
||||||
variants.extend(single_variants["variants"])
|
raw += " "
|
||||||
for paired_variants in ndpv:
|
|
||||||
variants.extend(
|
|
||||||
list(itertools.chain.from_iterable(paired_variants["variants"]))
|
|
||||||
)
|
|
||||||
# store variants in reverse length order to be able to prioritize
|
|
||||||
# longer matches (e.g., "---" before "--")
|
|
||||||
variants = sorted(variants, key=lambda x: len(x))
|
|
||||||
variants.reverse()
|
|
||||||
variant_raw = ""
|
|
||||||
raw_idx = 0
|
|
||||||
# add initial whitespace
|
|
||||||
while raw_idx < len(raw) and raw[raw_idx].isspace():
|
|
||||||
variant_raw += raw[raw_idx]
|
|
||||||
raw_idx += 1
|
|
||||||
for word in words:
|
|
||||||
match_found = False
|
|
||||||
# skip whitespace words
|
|
||||||
if word.isspace():
|
|
||||||
match_found = True
|
|
||||||
# add identical word
|
|
||||||
elif word not in variants and raw[raw_idx:].startswith(word):
|
|
||||||
variant_raw += word
|
|
||||||
raw_idx += len(word)
|
|
||||||
match_found = True
|
|
||||||
# add variant word
|
|
||||||
else:
|
|
||||||
for variant in variants:
|
|
||||||
if not match_found and raw[raw_idx:].startswith(variant):
|
|
||||||
raw_idx += len(variant)
|
|
||||||
variant_raw += word
|
|
||||||
match_found = True
|
|
||||||
# something went wrong, abort
|
|
||||||
# (add a warning message?)
|
|
||||||
if not match_found:
|
|
||||||
return raw, orig_token_dict
|
|
||||||
# add following whitespace
|
|
||||||
while raw_idx < len(raw) and raw[raw_idx].isspace():
|
|
||||||
variant_raw += raw[raw_idx]
|
|
||||||
raw_idx += 1
|
|
||||||
raw = variant_raw
|
|
||||||
return raw, token_dict
|
return raw, token_dict
|
||||||
|
|
Loading…
Reference in New Issue
Block a user