mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-06 21:30:22 +03:00
Simplify test
This commit is contained in:
parent
aa921ff130
commit
d6d5c52135
|
@ -48,12 +48,6 @@ def pytest_runtest_setup(item):
|
|||
pytest.skip("not referencing any issues")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_dir(request):
|
||||
print(request.fspath)
|
||||
return Path(request.fspath).parent
|
||||
|
||||
|
||||
# Fixtures for language tokenizers (languages sorted alphabetically)
|
||||
|
||||
|
||||
|
|
|
@ -12,11 +12,10 @@ from spacy.training import offsets_to_biluo_tags
|
|||
from spacy.training.alignment_array import AlignmentArray
|
||||
from spacy.training.align import get_alignments
|
||||
from spacy.training.converters import json_to_docs
|
||||
from spacy.training.initialize import init_nlp
|
||||
from spacy.training.loop import train
|
||||
from spacy.training.loop import train_while_improving
|
||||
from spacy.util import get_words_and_spaces, load_model_from_path, minibatch
|
||||
from spacy.util import load_config_from_str, registry, load_model_from_config
|
||||
from thinc.api import compounding
|
||||
from spacy.util import load_config_from_str, load_model_from_config
|
||||
from thinc.api import compounding, Adam
|
||||
|
||||
from ..util import make_tempdir
|
||||
|
||||
|
@ -1146,59 +1145,40 @@ factory = "tagger"
|
|||
[components.tagger.model.tok2vec]
|
||||
@architectures = "spacy.Tok2VecListener.v1"
|
||||
width = ${components.tok2vec.model.width}
|
||||
|
||||
[corpora]
|
||||
|
||||
[corpora.train]
|
||||
@readers = "spacy.Corpus.v1"
|
||||
path = null
|
||||
|
||||
[corpora.dev]
|
||||
@readers = "spacy.Corpus.v1"
|
||||
path = null
|
||||
|
||||
[training]
|
||||
train_corpus = "corpora.train"
|
||||
dev_corpus = "corpora.dev"
|
||||
seed = 1
|
||||
gpu_allocator = "pytorch"
|
||||
dropout = 0.1
|
||||
accumulate_gradient = 3
|
||||
patience = 5000
|
||||
max_epochs = 1
|
||||
max_steps = 6
|
||||
eval_frequency = 10
|
||||
|
||||
[training.batcher]
|
||||
@batchers = "spacy.batch_by_padded.v1"
|
||||
discard_oversize = False
|
||||
get_length = null
|
||||
size = 1
|
||||
buffer = 256
|
||||
"""
|
||||
|
||||
|
||||
def test_training_before_update(test_dir):
|
||||
ran_before_update = False
|
||||
def test_training_before_update(doc):
|
||||
def before_update(nlp, args):
|
||||
assert args["step"] == 0
|
||||
assert args["epoch"] == 1
|
||||
|
||||
@registry.callbacks(f"test_training_before_update_callback")
|
||||
def make_before_creation():
|
||||
def before_update(nlp, args):
|
||||
nonlocal ran_before_update
|
||||
ran_before_update = True
|
||||
assert "step" in args
|
||||
assert "epoch" in args
|
||||
# Raise an error here as the rest of the loop
|
||||
# will not run to completion due to uninitialized
|
||||
# models.
|
||||
raise ValueError("ran_before_update")
|
||||
|
||||
return before_update
|
||||
def generate_batch():
|
||||
yield 1, [Example(doc, doc)]
|
||||
|
||||
config = Config().from_str(training_config_string, interpolate=False)
|
||||
config["corpora"]["train"]["path"] = str(test_dir / "toy-en-corpus.spacy")
|
||||
config["corpora"]["dev"]["path"] = str(test_dir / "toy-en-corpus.spacy")
|
||||
config["training"]["before_update"] = {
|
||||
"@callbacks": "test_training_before_update_callback"
|
||||
}
|
||||
nlp = load_model_from_config(config, auto_fill=True, validate=True)
|
||||
optimizer = Adam()
|
||||
generator = train_while_improving(
|
||||
nlp,
|
||||
optimizer,
|
||||
generate_batch(),
|
||||
lambda: None,
|
||||
dropout=0.1,
|
||||
eval_frequency=100,
|
||||
accumulate_gradient=10,
|
||||
patience=10,
|
||||
max_steps=100,
|
||||
exclude=[],
|
||||
annotating_components=[],
|
||||
before_update=before_update,
|
||||
)
|
||||
|
||||
nlp = init_nlp(config)
|
||||
train(nlp)
|
||||
assert ran_before_update == True
|
||||
with pytest.raises(ValueError, match="ran_before_update"):
|
||||
for _ in generator:
|
||||
pass
|
||||
|
|
Binary file not shown.
Loading…
Reference in New Issue
Block a user