Ines Montani 2020-10-04 17:53:00 +02:00
parent 3c36a57e84
commit f5853e9b62
2 changed files with 75 additions and 0 deletions


@@ -2,6 +2,7 @@ import pytest
from spacy.training import Corpus
from spacy.training.augment import create_orth_variants_augmenter
from spacy.training.augment import create_lower_casing_augmenter
from spacy.training.augment import create_remove_punct_augmenter
from spacy.lang.en import English
from spacy.tokens import DocBin, Doc
from contextlib import contextmanager
@@ -67,6 +68,33 @@ def test_lowercase_augmenter(nlp, doc):
        assert [t.pos_ for t in eg.reference] == [t.pos_ for t in doc]


def test_remove_punct_augmenter(nlp):
    doc1 = Doc(
        nlp.vocab,
        words=["hello", ",", "world", ".", "."],
        spaces=[False, True, False, False, True],
        pos=["X", "PUNCT", "X", "PUNCT", "PUNCT"],
    )
    doc2 = Doc(
        nlp.vocab,
        words=[";", ".", "yo", "."],
        spaces=[False, True, False, False],
        pos=["PUNCT", "PUNCT", "X", "PUNCT"],
    )
    augmenter = create_remove_punct_augmenter(level=1.0, token_level=1.0)
    with make_docbin([doc1, doc2]) as output_file:
        reader = Corpus(output_file, augmenter=augmenter)
        corpus = list(reader(nlp))
        eg1 = corpus[0]
        assert [t.text for t in eg1.reference] == ["hello", "world"]
        assert [t.text for t in eg1.predicted] == ["hello", "world"]
        assert [t.pos_ for t in eg1.reference] == ["X", "X"]
        eg2 = corpus[1]
        assert [t.text for t in eg2.reference] == [";", ".", "yo"]
        assert [t.text for t in eg2.predicted] == [";", ".", "yo"]
        assert [t.pos_ for t in eg2.reference] == ["PUNCT", "PUNCT", "X"]


@pytest.mark.filterwarnings("ignore::UserWarning")
def test_custom_data_augmentation(nlp, doc):
    def create_spongebob_augmenter(randomize: bool = False):
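
For reference, the new augmenter plugs into the data pipeline the same way as the existing ones: construct it and pass it to a Corpus reader. A minimal usage sketch mirroring the test above (not part of this commit; the corpus path and the level/token_level values are placeholders):

    import spacy
    from spacy.training import Corpus
    from spacy.training.augment import create_remove_punct_augmenter

    nlp = spacy.blank("en")
    # Augment roughly 30% of examples; merge each matched punctuation run 50% of the time
    augmenter = create_remove_punct_augmenter(level=0.3, token_level=0.5)
    reader = Corpus("./corpus/train.spacy", augmenter=augmenter)
    train_examples = list(reader(nlp))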


@@ -6,6 +6,7 @@ from functools import partial
from pydantic import BaseModel, StrictStr
from ..util import registry, logger
from ..matcher import Matcher
from ..tokens import Doc
from .example import Example
@@ -59,6 +60,18 @@ def create_lower_casing_augmenter(
    return partial(lower_casing_augmenter, level=level)


@registry.augmenters("spacy.remove_punct.v1")
def create_remove_punct_augmenter(
    level: float, token_level: float, punct_tokens: List[str] = [".", ",", ";"],
) -> Callable[["Language", Example], Iterator[Example]]:
    return partial(
        remove_punct_augmenter,
        level=level,
        token_level=token_level,
        punct_tokens=punct_tokens,
    )


def dont_augment(nlp: "Language", example: Example) -> Iterator[Example]:
    yield example
@@ -75,6 +88,40 @@ def lower_casing_augmenter(
        yield example.from_dict(doc, example_dict)


def remove_punct_augmenter(
    nlp: "Language",
    example: Example,
    *,
    level: float,
    token_level: float,
    punct_tokens: List[str],
) -> Iterator[Example]:
    # Match runs of one or more punctuation tokens; each run is later merged
    # into the token that precedes it
    pattern = [{"ORTH": {"IN": punct_tokens}, "OP": "+"}]
    # `level` is the chance that a given example is augmented at all,
    # `token_level` the chance that an individual punctuation run is merged
    if random.random() >= level:
        yield example
    else:
        doc = example.reference
        # This is a bit unfortunate, but we need the nlp.vocab in order to
        # create the matcher
        matcher = Matcher(nlp.vocab)
        matcher.add("PUNCT", [pattern], greedy="LONGEST")
        matches = matcher(doc)
        with doc.retokenize() as retokenizer:
            for _, start, end in matches:
                # Only merge if there's a preceding token to attach the
                # punctuation to
                if start > 0 and random.random() < token_level:
                    prev_idx = start - 1
                    span = doc[prev_idx:end]
                    retokenizer.merge(span, attrs={"NORM": doc[prev_idx].text})
        example_dict = example.to_dict()
        words = [t.norm_ for t in doc]
        spaces = [bool(t.whitespace_) for t in doc]
        example_dict["token_annotation"]["ORTH"] = words
        new_doc = Doc(nlp.vocab, words=words, spaces=spaces)
        yield example.from_dict(new_doc, example_dict)
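
Because the factory is registered as "spacy.remove_punct.v1", it can also be resolved through the registry instead of imported directly, e.g. when referenced by name from a training config. A rough sketch assuming the catalogue-based registry API (the level/token_level values are illustrative):

    from spacy.util import registry

    make_augmenter = registry.augmenters.get("spacy.remove_punct.v1")
    # level: chance an example is augmented at all; token_level: chance each
    # matched punctuation run is merged into the preceding token
    augmenter = make_augmenter(level=0.1, token_level=0.5)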
def orth_variants_augmenter(
    nlp: "Language",
    example: Example,