Replace config string with spacy.blank

This commit is contained in:
shademe 2022-11-18 14:50:39 +01:00
parent d6d5c52135
commit 3fd19b8bd8
No known key found for this signature in database
GPG Key ID: 6FCA9FC635B2A402

View File

@@ -3,6 +3,7 @@ from confection import Config
 import numpy
 import pytest
+import spacy
 import srsly
 from spacy.lang.en import English
 from spacy.tokens import Doc, DocBin
@@ -14,7 +15,7 @@ from spacy.training.align import get_alignments
 from spacy.training.converters import json_to_docs
 from spacy.training.loop import train_while_improving
 from spacy.util import get_words_and_spaces, load_model_from_path, minibatch
-from spacy.util import load_config_from_str, load_model_from_config
+from spacy.util import load_config_from_str
 from thinc.api import compounding, Adam
 from ..util import make_tempdir
@@ -1116,38 +1117,6 @@ def test_retokenized_docs(doc):
 assert example.get_aligned("ORTH", as_string=True) == expected2
-training_config_string = """
-[nlp]
-lang = "en"
-pipeline = ["tok2vec", "tagger"]
-[components]
-[components.tok2vec]
-factory = "tok2vec"
-[components.tok2vec.model]
-@architectures = "spacy.HashEmbedCNN.v1"
-pretrained_vectors = null
-width = 342
-depth = 4
-window_size = 1
-embed_size = 2000
-maxout_pieces = 3
-subword_features = true
-[components.tagger]
-factory = "tagger"
-[components.tagger.model]
-@architectures = "spacy.Tagger.v2"
-[components.tagger.model.tok2vec]
-@architectures = "spacy.Tok2VecListener.v1"
-width = ${components.tok2vec.model.width}
-"""
 def test_training_before_update(doc):
 def before_update(nlp, args):
 assert args["step"] == 0
@@ -1161,8 +1130,8 @@ def test_training_before_update(doc):
 def generate_batch():
 yield 1, [Example(doc, doc)]
-config = Config().from_str(training_config_string, interpolate=False)
-nlp = load_model_from_config(config, auto_fill=True, validate=True)
+nlp = spacy.blank("en", config={"training": {}})
+nlp.add_pipe("tagger")
 optimizer = Adam()
 generator = train_while_improving(
 nlp,