Mirror of https://github.com/explosion/spaCy.git
WIP

commit f5853e9b62 (parent 3c36a57e84)
@@ -2,6 +2,7 @@ import pytest
 from spacy.training import Corpus
 from spacy.training.augment import create_orth_variants_augmenter
 from spacy.training.augment import create_lower_casing_augmenter
+from spacy.training.augment import create_remove_punct_augmenter
 from spacy.lang.en import English
 from spacy.tokens import DocBin, Doc
 from contextlib import contextmanager
@@ -67,6 +68,33 @@ def test_lowercase_augmenter(nlp, doc):
         assert [t.pos_ for t in eg.reference] == [t.pos_ for t in doc]


+def test_remove_punct_augmenter(nlp):
+    doc1 = Doc(
+        nlp.vocab,
+        words=["hello", ",", "world", ".", "."],
+        spaces=[False, True, False, False, True],
+        pos=["X", "PUNCT", "X", "PUNCT", "PUNCT"],
+    )
+    doc2 = Doc(
+        nlp.vocab,
+        words=[";", ".", "yo", "."],
+        spaces=[False, True, False, False],
+        pos=["PUNCT", "PUNCT", "X", "PUNCT"],
+    )
+    augmenter = create_remove_punct_augmenter(level=1.0, token_level=1.0)
+    with make_docbin([doc1, doc2]) as output_file:
+        reader = Corpus(output_file, augmenter=augmenter)
+        corpus = list(reader(nlp))
+        eg1 = corpus[0]
+        assert [t.text for t in eg1.reference] == ["hello", "world"]
+        assert [t.text for t in eg1.predicted] == ["hello", "world"]
+        assert [t.pos_ for t in eg1.reference] == ["X", "X"]
+        eg2 = corpus[1]
+        assert [t.text for t in eg2.reference] == [";", ".", "yo"]
+        assert [t.text for t in eg2.predicted] == [";", ".", "yo"]
+        assert [t.pos_ for t in eg2.reference] == ["PUNCT", "PUNCT", "X"]
+
+
 @pytest.mark.filterwarnings("ignore::UserWarning")
 def test_custom_data_augmentation(nlp, doc):
     def create_spongebob_augmenter(randomize: bool = False):
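Not part of the diff: a minimal sketch of what the new augmenter does to a single Example, assuming the changes from this commit are installed. The words, spaces, and level values below are illustrative only; the expected output mirrors the eg1 assertions in the test above.

    import spacy
    from spacy.tokens import Doc
    from spacy.training import Example
    from spacy.training.augment import create_remove_punct_augmenter

    nlp = spacy.blank("en")
    # Reference doc with punctuation after "hello" and "world" (made-up example)
    ref = Doc(
        nlp.vocab,
        words=["hello", ",", "world", "."],
        spaces=[False, True, False, False],
    )
    example = Example(nlp.make_doc(ref.text), ref)
    # level=1.0 augments every example; token_level=1.0 merges every match
    augmenter = create_remove_punct_augmenter(level=1.0, token_level=1.0)
    augmented = next(augmenter(nlp, example))
    print([t.text for t in augmented.reference])  # ["hello", "world"]
    print([t.text for t in augmented.predicted])  # ["hello", "world"]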
@@ -6,6 +6,7 @@ from functools import partial
 from pydantic import BaseModel, StrictStr

 from ..util import registry, logger
+from ..matcher import Matcher
 from ..tokens import Doc
 from .example import Example

@@ -59,6 +60,18 @@ def create_lower_casing_augmenter(
     return partial(lower_casing_augmenter, level=level)


+@registry.augmenters("spacy.remove_punct.v1")
+def create_remove_punct_augmenter(
+    level: float, token_level: float, punct_tokens: List[str] = [".", ",", ";"],
+) -> Callable[["Language", Example], Iterator[Example]]:
+    return partial(
+        remove_punct_augmenter,
+        level=level,
+        token_level=token_level,
+        punct_tokens=punct_tokens,
+    )
+
+
 def dont_augment(nlp: "Language", example: Example) -> Iterator[Example]:
     yield example

@@ -75,6 +88,40 @@ def lower_casing_augmenter(
         yield example.from_dict(doc, example_dict)


+def remove_punct_augmenter(
+    nlp: "Language",
+    example: Example,
+    *,
+    level: float,
+    token_level: float,
+    punct_tokens: List[str],
+) -> Iterator[Example]:
+    # One or more consecutive punctuation tokens, merged into the preceding token below
+    pattern = [{"ORTH": {"IN": punct_tokens}, "OP": "+"}]
+    if random.random() >= level:
+        yield example
+    else:
+        doc = example.reference
+        # This is a bit unfortunate, but we need the nlp.vocab in order to
+        # create the matcher
+        matcher = Matcher(nlp.vocab)
+        matcher.add("PUNCT", [pattern], greedy="LONGEST")
+        matches = matcher(doc)
+        with doc.retokenize() as retokenizer:
+            for _, start, end in matches:
+                # Don't merge if the match starts at the first token (nothing precedes it)
+                if start > 0 and random.random() < token_level:
+                    prev_idx = start - 1
+                    span = doc[prev_idx:end]
+                    retokenizer.merge(span, attrs={"NORM": doc[prev_idx].text})
+        example_dict = example.to_dict()
+        words = [t.norm_ for t in doc]
+        spaces = [bool(t.whitespace_) for t in doc]
+        example_dict["token_annotation"]["ORTH"] = words
+        new_doc = Doc(nlp.vocab, words=words, spaces=spaces)
+        yield example.from_dict(new_doc, example_dict)
+
+
 def orth_variants_augmenter(
     nlp: "Language",
     example: Example,
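Not part of the diff: a sketch of how the newly registered "spacy.remove_punct.v1" factory could be resolved and attached to a corpus reader in code, which is what a training config's [corpora.train.augmenter] block with @augmenters = "spacy.remove_punct.v1" would do under the hood. The level values and corpus path below are placeholders.

    from spacy.training import Corpus
    from spacy.util import registry

    # Look up the factory by its registered name, as the config system would
    make_augmenter = registry.augmenters.get("spacy.remove_punct.v1")
    augmenter = make_augmenter(level=0.1, token_level=0.5)

    # Attach the augmenter to a corpus reader ("corpus/train.spacy" is a placeholder path)
    train_corpus = Corpus("corpus/train.spacy", augmenter=augmenter)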