Ines Montani 2020-10-04 17:53:00 +02:00
parent 3c36a57e84
commit f5853e9b62
2 changed files with 75 additions and 0 deletions


@@ -2,6 +2,7 @@ import pytest
from spacy.training import Corpus
from spacy.training.augment import create_orth_variants_augmenter
from spacy.training.augment import create_lower_casing_augmenter
from spacy.training.augment import create_remove_punct_augmenter
from spacy.lang.en import English
from spacy.tokens import DocBin, Doc
from contextlib import contextmanager
@@ -67,6 +68,33 @@ def test_lowercase_augmenter(nlp, doc):
        assert [t.pos_ for t in eg.reference] == [t.pos_ for t in doc]
def test_remove_punct_augmenter(nlp):
    doc1 = Doc(
        nlp.vocab,
        words=["hello", ",", "world", ".", "."],
        spaces=[False, True, False, False, True],
        pos=["X", "PUNCT", "X", "PUNCT", "PUNCT"],
    )
    doc2 = Doc(
        nlp.vocab,
        words=[";", ".", "yo", "."],
        spaces=[False, True, False, False],
        pos=["PUNCT", "PUNCT", "X", "PUNCT"],
    )
    augmenter = create_remove_punct_augmenter(level=1.0, token_level=1.0)
    with make_docbin([doc1, doc2]) as output_file:
        reader = Corpus(output_file, augmenter=augmenter)
        corpus = list(reader(nlp))
        eg1 = corpus[0]
        assert [t.text for t in eg1.reference] == ["hello", "world"]
        assert [t.text for t in eg1.predicted] == ["hello", "world"]
        assert [t.pos_ for t in eg1.reference] == ["X", "X"]
        eg2 = corpus[1]
        assert [t.text for t in eg2.reference] == [";", ".", "yo"]
        assert [t.text for t in eg2.predicted] == [";", ".", "yo"]
        assert [t.pos_ for t in eg2.reference] == ["PUNCT", "PUNCT", "X"]
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_custom_data_augmentation(nlp, doc):
    def create_spongebob_augmenter(randomize: bool = False):
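For reference, a minimal sketch of how the new augmenter could be used outside the test suite with a serialized training corpus; the corpus path and the level/token_level values below are illustrative, not part of this commit:

from spacy.lang.en import English
from spacy.training import Corpus
from spacy.training.augment import create_remove_punct_augmenter

nlp = English()
# level: chance of augmenting a given example; token_level: chance of merging each matched punctuation run
augmenter = create_remove_punct_augmenter(level=0.3, token_level=0.5)
reader = Corpus("./corpus/train.spacy", augmenter=augmenter)  # illustrative path
examples = list(reader(nlp))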


@@ -6,6 +6,7 @@ from functools import partial
from pydantic import BaseModel, StrictStr
from ..util import registry, logger
from ..matcher import Matcher
from ..tokens import Doc
from .example import Example
@@ -59,6 +60,18 @@ def create_lower_casing_augmenter(
    return partial(lower_casing_augmenter, level=level)
@registry.augmenters("spacy.remove_punct.v1")
def create_remove_punct_augmenter(
    level: float, token_level: float, punct_tokens: List[str] = [".", ",", ";"],
) -> Callable[["Language", Example], Iterator[Example]]:
    return partial(
        remove_punct_augmenter,
        level=level,
        token_level=token_level,
        punct_tokens=punct_tokens,
    )
def dont_augment(nlp: "Language", example: Example) -> Iterator[Example]:
    yield example
@@ -75,6 +88,40 @@ def lower_casing_augmenter(
        yield example.from_dict(doc, example_dict)
def remove_punct_augmenter(
    nlp: "Language",
    example: Example,
    *,
    level: float,
    token_level: float,
    punct_tokens: List[str],
) -> Iterator[Example]:
    # One or more consecutive punctuation tokens, to be merged into the preceding token
    pattern = [{"ORTH": {"IN": punct_tokens}, "OP": "+"}]
    if random.random() >= level:
        yield example
    else:
        doc = example.reference
        # This is a bit unfortunate but we need the nlp.vocab in order to
        # create the matcher
        matcher = Matcher(nlp.vocab)
        matcher.add("PUNCT", [pattern], greedy="LONGEST")
        matches = matcher(doc)
        with doc.retokenize() as retokenizer:
            for _, start, end in matches:
                # Don't merge if the punctuation is at the start of the doc
                # (no preceding token to attach it to)
                if start > 0 and random.random() < token_level:
                    prev_idx = start - 1
                    span = doc[prev_idx:end]
                    retokenizer.merge(span, attrs={"NORM": doc[prev_idx].text})
        example_dict = example.to_dict()
        words = [t.norm_ for t in doc]
        spaces = [bool(t.whitespace_) for t in doc]
        example_dict["token_annotation"]["ORTH"] = words
        new_doc = Doc(nlp.vocab, words=words, spaces=spaces)
        yield example.from_dict(new_doc, example_dict)
def orth_variants_augmenter(
    nlp: "Language",
    example: Example,
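To make the retokenization logic above easier to follow, here is a self-contained sketch of the Matcher + retokenizer technique it relies on: match runs of punctuation and merge each run into the preceding token, keeping that token's text as the NORM. The blank pipeline and example text are assumptions for illustration only:

import spacy
from spacy.matcher import Matcher

nlp = spacy.blank("en")
doc = nlp("hello , world . .")  # whitespace-separated so the tokenizer yields 5 tokens
matcher = Matcher(nlp.vocab)
# Same pattern as the augmenter: one or more punctuation tokens, longest match wins
matcher.add("PUNCT", [[{"ORTH": {"IN": [".", ",", ";"]}, "OP": "+"}]], greedy="LONGEST")
with doc.retokenize() as retokenizer:
    for _, start, end in matcher(doc):
        if start > 0:  # only merge when there is a preceding token
            retokenizer.merge(doc[start - 1:end], attrs={"NORM": doc[start - 1].text})
print([t.norm_ for t in doc])  # expected: ['hello', 'world']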