WIP
This commit is contained in:
parent 3c36a57e84
commit f5853e9b62
@@ -2,6 +2,7 @@ import pytest
from spacy.training import Corpus
from spacy.training.augment import create_orth_variants_augmenter
from spacy.training.augment import create_lower_casing_augmenter
from spacy.training.augment import create_remove_punct_augmenter
from spacy.lang.en import English
from spacy.tokens import DocBin, Doc
from contextlib import contextmanager
@@ -67,6 +68,33 @@ def test_lowercase_augmenter(nlp, doc):
    assert [t.pos_ for t in eg.reference] == [t.pos_ for t in doc]


def test_remove_punct_augmenter(nlp):
    doc1 = Doc(
        nlp.vocab,
        words=["hello", ",", "world", ".", "."],
        spaces=[False, True, False, False, True],
        pos=["X", "PUNCT", "X", "PUNCT", "PUNCT"],
    )
    doc2 = Doc(
        nlp.vocab,
        words=[";", ".", "yo", "."],
        spaces=[False, True, False, False],
        pos=["PUNCT", "PUNCT", "X", "PUNCT"],
    )
    augmenter = create_remove_punct_augmenter(level=1.0, token_level=1.0)
    with make_docbin([doc1, doc2]) as output_file:
        reader = Corpus(output_file, augmenter=augmenter)
        corpus = list(reader(nlp))
    eg1 = corpus[0]
    assert [t.text for t in eg1.reference] == ["hello", "world"]
    assert [t.text for t in eg1.predicted] == ["hello", "world"]
    assert [t.pos_ for t in eg1.reference] == ["X", "X"]
    eg2 = corpus[1]
    assert [t.text for t in eg2.reference] == [";", ".", "yo"]
    assert [t.text for t in eg2.predicted] == [";", ".", "yo"]
    assert [t.pos_ for t in eg2.reference] == ["PUNCT", "PUNCT", "X"]


@pytest.mark.filterwarnings("ignore::UserWarning")
def test_custom_data_augmentation(nlp, doc):
    def create_spongebob_augmenter(randomize: bool = False):
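The test relies on a make_docbin helper that is not part of this diff; it is assumed to be an existing context manager defined earlier in the test module that writes the docs to a temporary .spacy file and yields the path. A minimal sketch of such a helper under that assumption (the name and default filename here are illustrative):

import tempfile
from contextlib import contextmanager
from pathlib import Path

from spacy.tokens import DocBin


@contextmanager
def make_docbin(docs, name="roundtrip.spacy"):
    # Serialize the reference docs to a temporary DocBin file so that
    # Corpus can read them back in as training examples
    with tempfile.TemporaryDirectory() as tmpdir:
        output_file = Path(tmpdir) / name
        DocBin(docs=docs, store_user_data=True).to_disk(output_file)
        yield output_file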
@@ -6,6 +6,7 @@ from functools import partial
from pydantic import BaseModel, StrictStr

from ..util import registry, logger
from ..matcher import Matcher
from ..tokens import Doc
from .example import Example

@@ -59,6 +60,18 @@ def create_lower_casing_augmenter(
    return partial(lower_casing_augmenter, level=level)


@registry.augmenters("spacy.remove_punct.v1")
def create_remove_punct_augmenter(
    level: float, token_level: float, punct_tokens: List[str] = [".", ",", ";"],
) -> Callable[["Language", Example], Iterator[Example]]:
    return partial(
        remove_punct_augmenter,
        level=level,
        token_level=token_level,
        punct_tokens=punct_tokens,
    )


def dont_augment(nlp: "Language", example: Example) -> Iterator[Example]:
    yield example
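Because the factory is registered as "spacy.remove_punct.v1", it can be resolved from the augmenters registry and passed to a Corpus reader just like the existing augmenters. A minimal usage sketch, assuming this branch is installed (the corpus path is a placeholder, not part of the diff):

import spacy
from spacy.training import Corpus
from spacy.util import registry

nlp = spacy.blank("en")
# Look up the registered factory and build the augmenter callable
make_augmenter = registry.augmenters.get("spacy.remove_punct.v1")
augmenter = make_augmenter(level=0.3, token_level=0.5)
# "corpus/train.spacy" is a placeholder path to any DocBin file
train_corpus = Corpus("corpus/train.spacy", augmenter=augmenter)
examples = list(train_corpus(nlp))

In a training config this would correspond to pointing the corpus reader's augmenter block at @augmenters = "spacy.remove_punct.v1" with the level and token_level values filled in.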
@@ -75,6 +88,40 @@ def lower_casing_augmenter(
    yield example.from_dict(doc, example_dict)


def remove_punct_augmenter(
    nlp: "Language",
    example: Example,
    *,
    level: float,
    token_level: float,
    punct_tokens: List[str],
) -> Iterator[Example]:
    # Match one or more punctuation tokens (merged into the preceding token below)
    pattern = [{"ORTH": {"IN": punct_tokens}, "OP": "+"}]
    if random.random() >= level:
        yield example
    else:
        doc = example.reference
        # This is a bit unfortunate, but we need the nlp.vocab in order to
        # create the matcher
        matcher = Matcher(nlp.vocab)
        matcher.add("PUNCT", [pattern], greedy="LONGEST")
        matches = matcher(doc)
        with doc.retokenize() as retokenizer:
            for _, start, end in matches:
                # Skip matches at the start of the doc: there's no preceding
                # token to merge the punctuation into
                if start > 0 and random.random() < token_level:
                    prev_idx = start - 1
                    span = doc[prev_idx:end]
                    retokenizer.merge(span, attrs={"NORM": doc[prev_idx].text})
        example_dict = example.to_dict()
        words = [t.norm_ for t in doc]
        spaces = [bool(t.whitespace_) for t in doc]
        example_dict["token_annotation"]["ORTH"] = words
        new_doc = Doc(nlp.vocab, words=words, spaces=spaces)
        yield example.from_dict(new_doc, example_dict)


def orth_variants_augmenter(
    nlp: "Language",
    example: Example,
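For reference on the matching step in remove_punct_augmenter: with greedy="LONGEST", the single-token pattern with "OP": "+" collapses each run of consecutive punctuation tokens into one match, which is why the trailing ".", "." in the first test doc ends up merged into "world". A small standalone sketch of just that matcher behavior (not part of the diff):

from spacy.lang.en import English
from spacy.matcher import Matcher
from spacy.tokens import Doc

nlp = English()
doc = Doc(nlp.vocab, words=["hello", ",", "world", ".", "."])
matcher = Matcher(nlp.vocab)
# Same pattern as the augmenter: one or more punctuation tokens in a row
matcher.add("PUNCT", [[{"ORTH": {"IN": [".", ",", ";"]}, "OP": "+"}]], greedy="LONGEST")
print(sorted((start, end) for _, start, end in matcher(doc)))
# Expected: [(1, 2), (3, 5)], i.e. the "," and the two-token ".." run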