From aa921ff130dde42ba84faff16f95935531ea6542 Mon Sep 17 00:00:00 2001 From: shademe Date: Fri, 18 Nov 2022 12:55:03 +0100 Subject: [PATCH] Add test --- spacy/tests/conftest.py | 7 ++ spacy/tests/training/test_training.py | 92 ++++++++++++++++++++++- spacy/tests/training/toy-en-corpus.spacy | Bin 0 -> 2703 bytes 3 files changed, 98 insertions(+), 1 deletion(-) create mode 100644 spacy/tests/training/toy-en-corpus.spacy diff --git a/spacy/tests/conftest.py b/spacy/tests/conftest.py index 0fc74243d..c17fde0e8 100644 --- a/spacy/tests/conftest.py +++ b/spacy/tests/conftest.py @@ -1,3 +1,4 @@ +from pathlib import Path import pytest from spacy.util import get_lang_class from hypothesis import settings @@ -47,6 +48,12 @@ def pytest_runtest_setup(item): pytest.skip("not referencing any issues") +@pytest.fixture +def test_dir(request): + print(request.fspath) + return Path(request.fspath).parent + + # Fixtures for language tokenizers (languages sorted alphabetically) diff --git a/spacy/tests/training/test_training.py b/spacy/tests/training/test_training.py index 4384a796d..d1313c3c9 100644 --- a/spacy/tests/training/test_training.py +++ b/spacy/tests/training/test_training.py @@ -1,4 +1,5 @@ import random +from confection import Config import numpy import pytest @@ -11,8 +12,10 @@ from spacy.training import offsets_to_biluo_tags from spacy.training.alignment_array import AlignmentArray from spacy.training.align import get_alignments from spacy.training.converters import json_to_docs +from spacy.training.initialize import init_nlp +from spacy.training.loop import train from spacy.util import get_words_and_spaces, load_model_from_path, minibatch -from spacy.util import load_config_from_str +from spacy.util import load_config_from_str, registry, load_model_from_config from thinc.api import compounding from ..util import make_tempdir @@ -1112,3 +1115,90 @@ def test_retokenized_docs(doc): retokenizer.merge(doc1[0:2]) retokenizer.merge(doc1[5:7]) assert example.get_aligned("ORTH", as_string=True) == expected2 + + +training_config_string = """ +[nlp] +lang = "en" +pipeline = ["tok2vec", "tagger"] + +[components] + +[components.tok2vec] +factory = "tok2vec" + +[components.tok2vec.model] +@architectures = "spacy.HashEmbedCNN.v1" +pretrained_vectors = null +width = 342 +depth = 4 +window_size = 1 +embed_size = 2000 +maxout_pieces = 3 +subword_features = true + +[components.tagger] +factory = "tagger" + +[components.tagger.model] +@architectures = "spacy.Tagger.v2" + +[components.tagger.model.tok2vec] +@architectures = "spacy.Tok2VecListener.v1" +width = ${components.tok2vec.model.width} + +[corpora] + +[corpora.train] +@readers = "spacy.Corpus.v1" +path = null + +[corpora.dev] +@readers = "spacy.Corpus.v1" +path = null + +[training] +train_corpus = "corpora.train" +dev_corpus = "corpora.dev" +seed = 1 +gpu_allocator = "pytorch" +dropout = 0.1 +accumulate_gradient = 3 +patience = 5000 +max_epochs = 1 +max_steps = 6 +eval_frequency = 10 + +[training.batcher] +@batchers = "spacy.batch_by_padded.v1" +discard_oversize = False +get_length = null +size = 1 +buffer = 256 +""" + + +def test_training_before_update(test_dir): + ran_before_update = False + + @registry.callbacks(f"test_training_before_update_callback") + def make_before_creation(): + def before_update(nlp, args): + nonlocal ran_before_update + ran_before_update = True + assert "step" in args + assert "epoch" in args + + return before_update + + config = Config().from_str(training_config_string, interpolate=False) + config["corpora"]["train"]["path"] = str(test_dir / "toy-en-corpus.spacy") + config["corpora"]["dev"]["path"] = str(test_dir / "toy-en-corpus.spacy") + config["training"]["before_update"] = { + "@callbacks": "test_training_before_update_callback" + } + nlp = load_model_from_config(config, auto_fill=True, validate=True) + + nlp = init_nlp(config) + train(nlp) + assert ran_before_update == True diff --git a/spacy/tests/training/toy-en-corpus.spacy b/spacy/tests/training/toy-en-corpus.spacy new file mode 100644 index 0000000000000000000000000000000000000000..9a771be712ce2b4b83c1e9137d12b48534630f79 GIT binary patch literal 2703 zcmV;A3UKv!ob8)=R1`-Z$3aP4FT52+ffZa_6%>!Gx*8(qf?T2zP$TJ?>4AZs?xwp3 z9F51SiP5aL;!1=t-E-j)lqjMgD&nzgAjUJU?(wK^B*xH=?wq)+n)hjzKbqe^{ z^eC6G^BmF6^G;sv2K08D02JnZ`sbC*owh6gub!R1`Y<-w1Lk1f z_A4CEg~+S^e+k&*4HV|X0_Liw_+D<@B4OtlqMbr!<4Rl5fX{%!9Qa$Fgq^2|c2*0* z%?|24Bv6=xb*j|g);;lYEtya-7AVXe7PNm~Xu)crMJpaqbhDxFct$ z;cjfh>+X-I0EPMGPu;3NJZapZ)UT7H)u$=rga)86pX1Z3?B08goRnXGk2r6UYfm>% zYVscr6z0GWb}IBYzG`R8{mYAsflBd15_TF9?SQ=Vs`uh$P*(DbNzHwF~uSC?hkN;{5Ct&p(O;v~OD%?nBQ>EN_4yX*Zu>blhS z>c8e^<8hE*KlZKiTAui{87Rzqe&4g`YEslurG71!7e<`XM6PPE#Y)w0sD9nE_TbIn z@4h(=6y_kGDyDKs`SpwW&jWt9=gxg*iAsO6%*7kC)G1 zc{p(#P?-0*4mG0Wx;`sn5 z%x{+89u?6es@*z;eqiSF+wWbndL&Soue-lEc6?ZLAF1n-M~L&bA&(tr)w-tvg}I~R z+eZ&K9Y{8ntN;pgN7Xydif@O*!?Ie-`c3inzDE0PsEmL)YJDJQXGRdcr+=5?c0ggi z_1W|hv$^Gq4=aoZO6vpE@u1cRQh6J7Jmqgx&L30|5eF3Ja(ST@(N65}+|_4pmv{n& zxufy|yK$-Ze7yHLpfGn-UQn91QS(Ab%Del!8zyfB3Uj$Q`fJ-iK#e;aCG33rD~mgF z`>@W6JO5v&|M~dsuoIUChL&Fh3Uj%-3bnr0TKLBB?yDkz!W{U4R2_m^Uk7fBSUGs* zfvG@YF1LSx>Ia>TUco0vy-5Ovxm?_F*8VkW9b#k$zoXt4GXp5h<@T>p{lHo4ROxjv z_z!Y+vRAmJwEKFS2`J3v@&ao9S#*0=bCd4%=FpP+=UrAO_JMhiiI`cf$^$6O<>HRB_TSDT+Oc2TJO6S+UM`jQINxu;gGQL6?uW?j zzjbTKYjoAGQUirKxS!>$e5!Px@2(e_7*5pRt9iS&@nXzw7Upu-Kd9@pV0~jmNY{r; zU$wqQjiXz4?sy7MPqow! z&+BXHZ7)6W4a`xm=j8H2)3%V$r=7po11QXaU&}pLBR3v)-dbkS>Hr)l%>UP(tJ$2< z+jG5sPzq3(JF1RXT3@5a(WFhQBC{5jTmlMnx%wKlj(65Rg-5UX?Ryq5%YedMu0BAW zH>GtV>T?cz79Gj%GJDGh?)7W`-l?U0C(J<|VpI71o1>m{kQ-0H_=H1GrZ)WsD9k~f zxK_drD3IjrIP3YH%Csx)&x^170EM}u{QAm6mu1NtZ?yvob4Ts_J8PdDHC{{k0qVZX zsa((YhEC`D0EM}u?mJ(K4@pqXylj0sXY`G)p13@NxufEav-VxAg;dwIRo@N*3Uf#8 zyC_|!sQV!+1iuUV4qN*Ig}I~lL(1kCUdlgxr3~mVwS($!F2^Sy@%h)@)}ItQ1?$}ck%ml5ZfrRbrV}BvGo_5jpY~} znZ{mq*;HuyooT4)_ofk+fSAOP082ok*qZ`{z`#I0Xp|)=XyTYLrb(t?J~WgM4HH7c zSbHc%P^1nc_^=uF@R{L>De-aPrYJrtj*rsvQJg87j}~7^LbQS4V-on-Fg`9$h>J@Q z+jKsj6ygnxkdT;Q7MmZR7-CM0kDtLOCh#-j`I#ZY%#bj#MT#v>YzbnU#V02VDU`v4 z37AG>!8F-w!YcN9N+Vzd!DD(e_Tzg}W5Z|~#~9WkD#Z#IgA15po?S&T43&dx>?#Jw zGR2WwR0bWPF{=y=DZVx0jG)FyUac0?8rG&JaEvt2cJ<3~ZdX%!Jv5233ZYH0awVl+tQ>iWJ1J&r^&*85lDq4o9FdqnU0SAx%+?*s-jc zW^j5Z%}`pLZTp)dj1Xvsv(N@o%~=?nP!j^9qpf12agB||nQWaJx3NZ2EuITAmyU4) zmyU~XaD7{66gzPuILgW~m^zcSF=nc@)1)EDdWnA1MoAu5^rVzP6qUDrXcBJtESsTz#KM7RWc J{Qw%gck$eeb?X2C literal 0 HcmV?d00001