shademe 2022-11-18 12:55:03 +01:00
parent e19b490d75
commit aa921ff130
3 changed files with 98 additions and 1 deletion

View File

@@ -1,3 +1,4 @@
from pathlib import Path
import pytest
from spacy.util import get_lang_class
from hypothesis import settings
@@ -47,6 +48,12 @@ def pytest_runtest_setup(item):
        pytest.skip("not referencing any issues")
@pytest.fixture
def test_dir(request):
    print(request.fspath)
    return Path(request.fspath).parent
# Fixtures for language tokenizers (languages sorted alphabetically)
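
As a usage sketch (not part of this commit), a test module in the same directory could resolve its data files through the new test_dir fixture; the test name below is illustrative:

def test_finds_local_corpus(test_dir):
    # test_dir is the directory containing the test module that requested the fixture
    corpus_path = test_dir / "toy-en-corpus.spacy"
    assert corpus_path.exists()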

View File

@@ -1,4 +1,5 @@
import random
from confection import Config
import numpy
import pytest
@@ -11,8 +12,10 @@ from spacy.training import offsets_to_biluo_tags
from spacy.training.alignment_array import AlignmentArray
from spacy.training.align import get_alignments
from spacy.training.converters import json_to_docs
from spacy.training.initialize import init_nlp
from spacy.training.loop import train
from spacy.util import get_words_and_spaces, load_model_from_path, minibatch
from spacy.util import load_config_from_str
from spacy.util import load_config_from_str, registry, load_model_from_config
from thinc.api import compounding
from ..util import make_tempdir
@@ -1112,3 +1115,90 @@ def test_retokenized_docs(doc):
        retokenizer.merge(doc1[0:2])
        retokenizer.merge(doc1[5:7])
    assert example.get_aligned("ORTH", as_string=True) == expected2
training_config_string = """
[nlp]
lang = "en"
pipeline = ["tok2vec", "tagger"]

[components]

[components.tok2vec]
factory = "tok2vec"

[components.tok2vec.model]
@architectures = "spacy.HashEmbedCNN.v1"
pretrained_vectors = null
width = 342
depth = 4
window_size = 1
embed_size = 2000
maxout_pieces = 3
subword_features = true

[components.tagger]
factory = "tagger"

[components.tagger.model]
@architectures = "spacy.Tagger.v2"

[components.tagger.model.tok2vec]
@architectures = "spacy.Tok2VecListener.v1"
width = ${components.tok2vec.model.width}

[corpora]

[corpora.train]
@readers = "spacy.Corpus.v1"
path = null

[corpora.dev]
@readers = "spacy.Corpus.v1"
path = null

[training]
train_corpus = "corpora.train"
dev_corpus = "corpora.dev"
seed = 1
gpu_allocator = "pytorch"
dropout = 0.1
accumulate_gradient = 3
patience = 5000
max_epochs = 1
max_steps = 6
eval_frequency = 10

[training.batcher]
@batchers = "spacy.batch_by_padded.v1"
discard_oversize = false
get_length = null
size = 1
buffer = 256
"""
def test_training_before_update(test_dir):
    ran_before_update = False

    @registry.callbacks("test_training_before_update_callback")
    def make_before_creation():
        def before_update(nlp, args):
            nonlocal ran_before_update
            ran_before_update = True
            assert "step" in args
            assert "epoch" in args

        return before_update

    config = Config().from_str(training_config_string, interpolate=False)
    config["corpora"]["train"]["path"] = str(test_dir / "toy-en-corpus.spacy")
    config["corpora"]["dev"]["path"] = str(test_dir / "toy-en-corpus.spacy")
    config["training"]["before_update"] = {
        "@callbacks": "test_training_before_update_callback"
    }
    nlp = load_model_from_config(config, auto_fill=True, validate=True)
    nlp = init_nlp(config)
    train(nlp)
    assert ran_before_update
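
Outside the test suite, the same hook would be registered by a project and referenced from the training config; a minimal sketch, assuming an illustrative callback name:

from spacy.util import registry

@registry.callbacks("my_project.before_update.v1")
def create_before_update():
    def before_update(nlp, args):
        # args carries the current training "step" and "epoch", as asserted in the test above
        print(f"epoch {args['epoch']}, step {args['step']}")

    return before_update

# referenced from the config as:
# [training.before_update]
# @callbacks = "my_project.before_update.v1"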

Binary file not shown.
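
The binary file is presumably the toy-en-corpus.spacy corpus the test points at; a minimal sketch of how such a .spacy file can be produced with DocBin (texts and tags are illustrative):

import spacy
from spacy.tokens import DocBin

nlp = spacy.blank("en")
doc_bin = DocBin()
for text, tags in [("I like cats", ["PRP", "VBP", "NNS"])]:
    doc = nlp.make_doc(text)
    for token, tag in zip(doc, tags):
        token.tag_ = tag
    doc_bin.add(doc)
doc_bin.to_disk("toy-en-corpus.spacy")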