diff --git a/spacy/errors.py b/spacy/errors.py index 2ebc49e8c..4f61cf098 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -487,7 +487,10 @@ class Errors: E202 = ("Unsupported alignment mode '{mode}'. Supported modes: {modes}.") # New errors added in v3.x - + E874 = ("Could not initialize the tok2vec model from component " + "'{component}' and layer '{layer}'.") + E875 = ("To use the PretrainVectors objective, make sure that static vectors are loaded. " + "In the config, these are defined by the initialize.vectors setting.") E879 = ("Unexpected type for 'spans' data. Provide a dictionary mapping keys to " "a list of spans, with each span represented by a tuple (start_char, end_char). " "The tuple can be optionally extended with a label and a KB ID.") diff --git a/spacy/language.py b/spacy/language.py index 2a9b50bcc..80de94278 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -1222,10 +1222,6 @@ class Language: init_vocab( self, data=I["vocab_data"], lookups=I["lookups"], vectors=I["vectors"] ) - pretrain_cfg = config.get("pretraining") - if pretrain_cfg: - P = registry.resolve(pretrain_cfg, schema=ConfigSchemaPretrain) - init_tok2vec(self, P, I) if self.vocab.vectors.data.shape[1] >= 1: ops = get_current_ops() self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data) @@ -1244,6 +1240,10 @@ class Language: proc.initialize, p_settings, section="components", name=name ) proc.initialize(get_examples, nlp=self, **p_settings) + pretrain_cfg = config.get("pretraining") + if pretrain_cfg: + P = registry.resolve(pretrain_cfg, schema=ConfigSchemaPretrain) + init_tok2vec(self, P, I) self._link_components() self._optimizer = sgd if sgd is not None: @@ -1592,6 +1592,7 @@ class Language: # using the nlp.config with all defaults. config = util.copy_config(config) orig_pipeline = config.pop("components", {}) + orig_pretraining = config.pop("pretraining", None) config["components"] = {} if auto_fill: filled = registry.fill(config, validate=validate, schema=ConfigSchema) @@ -1599,6 +1600,9 @@ class Language: filled = config filled["components"] = orig_pipeline config["components"] = orig_pipeline + if orig_pretraining is not None: + filled["pretraining"] = orig_pretraining + config["pretraining"] = orig_pretraining resolved_nlp = registry.resolve( filled["nlp"], validate=validate, schema=ConfigSchemaNlp ) diff --git a/spacy/ml/models/multi_task.py b/spacy/ml/models/multi_task.py index 8aa0f3c2b..cbfa59eea 100644 --- a/spacy/ml/models/multi_task.py +++ b/spacy/ml/models/multi_task.py @@ -21,6 +21,8 @@ def create_pretrain_vectors( maxout_pieces: int, hidden_size: int, loss: str ) -> Callable[["Vocab", Model], Model]: def create_vectors_objective(vocab: "Vocab", tok2vec: Model) -> Model: + if vocab.vectors.data.shape[1] == 0: + raise ValueError(Errors.E875) model = build_cloze_multi_task_model( vocab, tok2vec, hidden_size=hidden_size, maxout_pieces=maxout_pieces ) @@ -134,7 +136,7 @@ def build_cloze_characters_multi_task_model( ) -> Model: output_layer = chain( list2array(), - Maxout(hidden_size, nP=maxout_pieces), + Maxout(nO=hidden_size, nP=maxout_pieces), LayerNorm(nI=hidden_size), MultiSoftmax([256] * nr_char, nI=hidden_size), ) diff --git a/spacy/tests/serialize/test_serialize_config.py b/spacy/tests/serialize/test_serialize_config.py index 86f726c43..66b66b744 100644 --- a/spacy/tests/serialize/test_serialize_config.py +++ b/spacy/tests/serialize/test_serialize_config.py @@ -293,7 +293,7 @@ def test_serialize_parser(parser_config_string): def test_config_nlp_roundtrip(): - """Test that a config prduced by the nlp object passes training config + """Test that a config produced by the nlp object passes training config validation.""" nlp = English() nlp.add_pipe("entity_ruler") diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py index a3834f31a..c36be9c57 100644 --- a/spacy/tests/test_cli.py +++ b/spacy/tests/test_cli.py @@ -4,7 +4,7 @@ from spacy.training import docs_to_json, offsets_to_biluo_tags from spacy.training.converters import iob_to_docs, conll_ner_to_docs, conllu_to_docs from spacy.schemas import ProjectConfigSchema, RecommendationSchema, validate from spacy.lang.nl import Dutch -from spacy.util import ENV_VARS +from spacy.util import ENV_VARS, load_model_from_config from spacy.cli import info from spacy.cli.init_config import init_config, RECOMMENDATIONS from spacy.cli._util import validate_project_commands, parse_config_overrides @@ -397,10 +397,14 @@ def test_parse_cli_overrides(): "pipeline", [["tagger", "parser", "ner"], [], ["ner", "textcat", "sentencizer"]] ) @pytest.mark.parametrize("optimize", ["efficiency", "accuracy"]) -def test_init_config(lang, pipeline, optimize): +@pytest.mark.parametrize("pretraining", [True, False]) +def test_init_config(lang, pipeline, optimize, pretraining): # TODO: add more tests and also check for GPU with transformers - config = init_config(lang=lang, pipeline=pipeline, optimize=optimize, gpu=False) + config = init_config(lang=lang, pipeline=pipeline, optimize=optimize, pretraining=pretraining, gpu=False) assert isinstance(config, Config) + if pretraining: + config["paths"]["raw_text"] = "my_data.jsonl" + nlp = load_model_from_config(config, auto_fill=True) def test_model_recommendations(): diff --git a/spacy/tests/training/test_pretraining.py b/spacy/tests/training/test_pretraining.py new file mode 100644 index 000000000..bd8810a5c --- /dev/null +++ b/spacy/tests/training/test_pretraining.py @@ -0,0 +1,345 @@ +from pathlib import Path +import numpy as np +import pytest +import srsly +from spacy.vocab import Vocab +from thinc.api import Config + +from ..util import make_tempdir +from ... import util +from ...lang.en import English +from ...training.initialize import init_nlp +from ...training.loop import train +from ...training.pretrain import pretrain +from ...tokens import Doc, DocBin +from ...language import DEFAULT_CONFIG_PRETRAIN_PATH, DEFAULT_CONFIG_PATH + +pretrain_string_listener = """ +[nlp] +lang = "en" +pipeline = ["tok2vec", "tagger"] + +[components] + +[components.tok2vec] +factory = "tok2vec" + +[components.tok2vec.model] +@architectures = "spacy.HashEmbedCNN.v1" +pretrained_vectors = null +width = 342 +depth = 4 +window_size = 1 +embed_size = 2000 +maxout_pieces = 3 +subword_features = true + +[components.tagger] +factory = "tagger" + +[components.tagger.model] +@architectures = "spacy.Tagger.v1" + +[components.tagger.model.tok2vec] +@architectures = "spacy.Tok2VecListener.v1" +width = ${components.tok2vec.model.width} + +[pretraining] +max_epochs = 5 + +[training] +max_epochs = 5 +""" + +pretrain_string_internal = """ +[nlp] +lang = "en" +pipeline = ["tagger"] + +[components] + +[components.tagger] +factory = "tagger" + +[components.tagger.model] +@architectures = "spacy.Tagger.v1" + +[components.tagger.model.tok2vec] +@architectures = "spacy.HashEmbedCNN.v1" +pretrained_vectors = null +width = 342 +depth = 4 +window_size = 1 +embed_size = 2000 +maxout_pieces = 3 +subword_features = true + +[pretraining] +max_epochs = 5 + +[training] +max_epochs = 5 +""" + + +pretrain_string_vectors = """ +[nlp] +lang = "en" +pipeline = ["tok2vec", "tagger"] + +[components] + +[components.tok2vec] +factory = "tok2vec" + +[components.tok2vec.model] +@architectures = "spacy.HashEmbedCNN.v1" +pretrained_vectors = null +width = 342 +depth = 4 +window_size = 1 +embed_size = 2000 +maxout_pieces = 3 +subword_features = true + +[components.tagger] +factory = "tagger" + +[components.tagger.model] +@architectures = "spacy.Tagger.v1" + +[components.tagger.model.tok2vec] +@architectures = "spacy.Tok2VecListener.v1" +width = ${components.tok2vec.model.width} + +[pretraining] +max_epochs = 5 + +[pretraining.objective] +@architectures = spacy.PretrainVectors.v1 +maxout_pieces = 3 +hidden_size = 300 +loss = cosine + +[training] +max_epochs = 5 +""" + +CHAR_OBJECTIVES = [ + {}, + {"@architectures": "spacy.PretrainCharacters.v1"}, + { + "@architectures": "spacy.PretrainCharacters.v1", + "maxout_pieces": 5, + "hidden_size": 42, + "n_characters": 2, + }, +] + +VECTOR_OBJECTIVES = [ + { + "@architectures": "spacy.PretrainVectors.v1", + "maxout_pieces": 3, + "hidden_size": 300, + "loss": "cosine", + }, + { + "@architectures": "spacy.PretrainVectors.v1", + "maxout_pieces": 2, + "hidden_size": 200, + "loss": "L2", + }, +] + + +def test_pretraining_default(): + """Test that pretraining defaults to a character objective""" + config = Config().from_str(pretrain_string_internal) + nlp = util.load_model_from_config(config, auto_fill=True, validate=False) + filled = nlp.config + pretrain_config = util.load_config(DEFAULT_CONFIG_PRETRAIN_PATH) + filled = pretrain_config.merge(filled) + assert "PretrainCharacters" in filled["pretraining"]["objective"]["@architectures"] + + +@pytest.mark.parametrize("objective", CHAR_OBJECTIVES) +def test_pretraining_tok2vec_characters(objective): + """Test that pretraining works with the character objective""" + config = Config().from_str(pretrain_string_listener) + config["pretraining"]["objective"] = objective + nlp = util.load_model_from_config(config, auto_fill=True, validate=False) + filled = nlp.config + pretrain_config = util.load_config(DEFAULT_CONFIG_PRETRAIN_PATH) + filled = pretrain_config.merge(filled) + with make_tempdir() as tmp_dir: + file_path = write_sample_jsonl(tmp_dir) + filled["paths"]["raw_text"] = file_path + filled = filled.interpolate() + assert filled["pretraining"]["component"] == "tok2vec" + pretrain(filled, tmp_dir) + assert Path(tmp_dir / "model0.bin").exists() + assert Path(tmp_dir / "model4.bin").exists() + assert not Path(tmp_dir / "model5.bin").exists() + + +@pytest.mark.parametrize("objective", VECTOR_OBJECTIVES) +def test_pretraining_tok2vec_vectors_fail(objective): + """Test that pretraining doesn't works with the vectors objective if there are no static vectors""" + config = Config().from_str(pretrain_string_listener) + config["pretraining"]["objective"] = objective + nlp = util.load_model_from_config(config, auto_fill=True, validate=False) + filled = nlp.config + pretrain_config = util.load_config(DEFAULT_CONFIG_PRETRAIN_PATH) + filled = pretrain_config.merge(filled) + with make_tempdir() as tmp_dir: + file_path = write_sample_jsonl(tmp_dir) + filled["paths"]["raw_text"] = file_path + filled = filled.interpolate() + assert filled["initialize"]["vectors"] is None + with pytest.raises(ValueError): + pretrain(filled, tmp_dir) + + +@pytest.mark.parametrize("objective", VECTOR_OBJECTIVES) +def test_pretraining_tok2vec_vectors(objective): + """Test that pretraining works with the vectors objective and static vectors defined""" + config = Config().from_str(pretrain_string_listener) + config["pretraining"]["objective"] = objective + nlp = util.load_model_from_config(config, auto_fill=True, validate=False) + filled = nlp.config + pretrain_config = util.load_config(DEFAULT_CONFIG_PRETRAIN_PATH) + filled = pretrain_config.merge(filled) + with make_tempdir() as tmp_dir: + file_path = write_sample_jsonl(tmp_dir) + filled["paths"]["raw_text"] = file_path + nlp_path = write_vectors_model(tmp_dir) + filled["initialize"]["vectors"] = nlp_path + filled = filled.interpolate() + pretrain(filled, tmp_dir) + + +@pytest.mark.parametrize("config", [pretrain_string_internal, pretrain_string_listener]) +def test_pretraining_tagger_tok2vec(config): + """Test pretraining of the tagger's tok2vec layer (via a listener)""" + config = Config().from_str(pretrain_string_listener) + nlp = util.load_model_from_config(config, auto_fill=True, validate=False) + filled = nlp.config + pretrain_config = util.load_config(DEFAULT_CONFIG_PRETRAIN_PATH) + filled = pretrain_config.merge(filled) + with make_tempdir() as tmp_dir: + file_path = write_sample_jsonl(tmp_dir) + filled["paths"]["raw_text"] = file_path + filled["pretraining"]["component"] = "tagger" + filled["pretraining"]["layer"] = "tok2vec" + filled = filled.interpolate() + pretrain(filled, tmp_dir) + assert Path(tmp_dir / "model0.bin").exists() + assert Path(tmp_dir / "model4.bin").exists() + assert not Path(tmp_dir / "model5.bin").exists() + + +def test_pretraining_tagger(): + """Test pretraining of the tagger itself will throw an error (not an appropriate tok2vec layer)""" + config = Config().from_str(pretrain_string_internal) + nlp = util.load_model_from_config(config, auto_fill=True, validate=False) + filled = nlp.config + pretrain_config = util.load_config(DEFAULT_CONFIG_PRETRAIN_PATH) + filled = pretrain_config.merge(filled) + with make_tempdir() as tmp_dir: + file_path = write_sample_jsonl(tmp_dir) + filled["paths"]["raw_text"] = file_path + filled["pretraining"]["component"] = "tagger" + filled = filled.interpolate() + with pytest.raises(ValueError): + pretrain(filled, tmp_dir) + + +def test_pretraining_training(): + """Test that training can use a pretrained Tok2Vec model""" + config = Config().from_str(pretrain_string_internal) + nlp = util.load_model_from_config(config, auto_fill=True, validate=False) + filled = nlp.config + pretrain_config = util.load_config(DEFAULT_CONFIG_PRETRAIN_PATH) + filled = pretrain_config.merge(filled) + train_config = util.load_config(DEFAULT_CONFIG_PATH) + filled = train_config.merge(filled) + with make_tempdir() as tmp_dir: + pretrain_dir = tmp_dir / "pretrain" + pretrain_dir.mkdir() + file_path = write_sample_jsonl(pretrain_dir) + filled["paths"]["raw_text"] = file_path + filled["pretraining"]["component"] = "tagger" + filled["pretraining"]["layer"] = "tok2vec" + train_dir = tmp_dir / "train" + train_dir.mkdir() + train_path, dev_path = write_sample_training(train_dir) + filled["paths"]["train"] = train_path + filled["paths"]["dev"] = dev_path + filled = filled.interpolate() + P = filled["pretraining"] + nlp_base = init_nlp(filled) + model_base = nlp_base.get_pipe(P["component"]).model.get_ref(P["layer"]).get_ref("embed") + embed_base = None + for node in model_base.walk(): + if node.name == "hashembed": + embed_base = node + pretrain(filled, pretrain_dir) + pretrained_model = Path(pretrain_dir / "model3.bin") + assert pretrained_model.exists() + filled["initialize"]["init_tok2vec"] = str(pretrained_model) + nlp = init_nlp(filled) + model = nlp.get_pipe(P["component"]).model.get_ref(P["layer"]).get_ref("embed") + embed = None + for node in model.walk(): + if node.name == "hashembed": + embed = node + # ensure that the tok2vec weights are actually changed by the pretraining + assert np.any(np.not_equal(embed.get_param("E"), embed_base.get_param("E"))) + train(nlp, train_dir) + + +def write_sample_jsonl(tmp_dir): + data = [ + { + "meta": {"id": "1"}, + "text": "This is the best TV you'll ever buy!", + "cats": {"pos": 1, "neg": 0}, + }, + { + "meta": {"id": "2"}, + "text": "I wouldn't buy this again.", + "cats": {"pos": 0, "neg": 1}, + }, + ] + file_path = f"{tmp_dir}/text.jsonl" + srsly.write_jsonl(file_path, data) + return file_path + + +def write_sample_training(tmp_dir): + words = ["The", "players", "start", "."] + tags = ["DT", "NN", "VBZ", "."] + doc = Doc(English().vocab, words=words, tags=tags) + doc_bin = DocBin() + doc_bin.add(doc) + train_path = f"{tmp_dir}/train.spacy" + dev_path = f"{tmp_dir}/dev.spacy" + doc_bin.to_disk(train_path) + doc_bin.to_disk(dev_path) + return train_path, dev_path + + +def write_vectors_model(tmp_dir): + import numpy + vocab = Vocab() + vector_data = { + "dog": numpy.random.uniform(-1, 1, (300,)), + "cat": numpy.random.uniform(-1, 1, (300,)), + "orange": numpy.random.uniform(-1, 1, (300,)) + } + for word, vector in vector_data.items(): + vocab.set_vector(word, vector) + nlp_path = tmp_dir / "vectors_model" + nlp = English(vocab) + nlp.to_disk(nlp_path) + return str(nlp_path) diff --git a/spacy/training/initialize.py b/spacy/training/initialize.py index 25bb73c78..f7f2f21a4 100644 --- a/spacy/training/initialize.py +++ b/spacy/training/initialize.py @@ -9,6 +9,7 @@ import gzip import zipfile import tqdm +from .pretrain import get_tok2vec_ref from ..lookups import Lookups from ..vectors import Vectors from ..errors import Errors, Warnings @@ -147,10 +148,6 @@ def init_tok2vec( weights_data = None init_tok2vec = ensure_path(I["init_tok2vec"]) if init_tok2vec is not None: - if P["objective"].get("type") == "vectors" and not I["vectors"]: - err = 'need initialize.vectors if pretraining.objective.type is "vectors"' - errors = [{"loc": ["initialize"], "msg": err}] - raise ConfigValidationError(config=nlp.config, errors=errors) if not init_tok2vec.exists(): err = f"can't find pretrained tok2vec: {init_tok2vec}" errors = [{"loc": ["initialize", "init_tok2vec"], "msg": err}] @@ -158,21 +155,9 @@ def init_tok2vec( with init_tok2vec.open("rb") as file_: weights_data = file_.read() if weights_data is not None: - tok2vec_component = P["component"] - if tok2vec_component is None: - desc = ( - f"To use pretrained tok2vec weights, [pretraining.component] " - f"needs to specify the component that should load them." - ) - err = "component can't be null" - errors = [{"loc": ["pretraining", "component"], "msg": err}] - raise ConfigValidationError( - config=nlp.config["pretraining"], errors=errors, desc=desc - ) - layer = nlp.get_pipe(tok2vec_component).model - if P["layer"]: - layer = layer.get_ref(P["layer"]) + layer = get_tok2vec_ref(nlp, P) layer.from_bytes(weights_data) + logger.info(f"Loaded pretrained weights from {init_tok2vec}") return True return False diff --git a/spacy/training/pretrain.py b/spacy/training/pretrain.py index 152d849e9..c791732db 100644 --- a/spacy/training/pretrain.py +++ b/spacy/training/pretrain.py @@ -6,9 +6,12 @@ from collections import Counter import srsly import time import re + +from thinc.config import ConfigValidationError from wasabi import Printer from .example import Example +from ..errors import Errors from ..tokens import Doc from ..schemas import ConfigSchemaPretrain from ..util import registry, load_model_from_config, dot_to_object @@ -133,12 +136,21 @@ def create_pretraining_model(nlp, pretrain_config): The actual tok2vec layer is stored as a reference, and only this bit will be serialized to file and read back in when calling the 'train' command. """ - nlp.initialize() - component = nlp.get_pipe(pretrain_config["component"]) - if pretrain_config.get("layer"): - tok2vec = component.model.get_ref(pretrain_config["layer"]) - else: - tok2vec = component.model + with nlp.select_pipes(enable=[]): + nlp.initialize() + tok2vec = get_tok2vec_ref(nlp, pretrain_config) + # If the config referred to a Tok2VecListener, grab the original model instead + if type(tok2vec).__name__ == "Tok2VecListener": + original_tok2vec = ( + tok2vec.upstream_name if tok2vec.upstream_name is not "*" else "tok2vec" + ) + tok2vec = nlp.get_pipe(original_tok2vec).model + try: + tok2vec.initialize(X=[nlp.make_doc("Give it a doc to infer shapes")]) + except ValueError: + component = pretrain_config["component"] + layer = pretrain_config["layer"] + raise ValueError(Errors.E874.format(component=component, layer=layer)) create_function = pretrain_config["objective"] model = create_function(nlp.vocab, tok2vec) @@ -147,6 +159,24 @@ def create_pretraining_model(nlp, pretrain_config): return model +def get_tok2vec_ref(nlp, pretrain_config): + tok2vec_component = pretrain_config["component"] + if tok2vec_component is None: + desc = ( + f"To use pretrained tok2vec weights, [pretraining.component] " + f"needs to specify the component that should load them." + ) + err = "component can't be null" + errors = [{"loc": ["pretraining", "component"], "msg": err}] + raise ConfigValidationError( + config=nlp.config["pretraining"], errors=errors, desc=desc + ) + layer = nlp.get_pipe(tok2vec_component).model + if pretrain_config["layer"]: + layer = layer.get_ref(pretrain_config["layer"]) + return layer + + class ProgressTracker: def __init__(self, frequency=1000000): self.loss = 0.0 diff --git a/website/docs/api/architectures.md b/website/docs/api/architectures.md index 1739836ed..793855d18 100644 --- a/website/docs/api/architectures.md +++ b/website/docs/api/architectures.md @@ -447,6 +447,9 @@ For more information, see the section on > ```ini > [pretraining] > component = "tok2vec" +> +> [initialize] +> vectors = "en_core_web_lg" > ... > > [pretraining.objective] @@ -457,7 +460,9 @@ For more information, see the section on > ``` Predict the word's vector from a static embeddings table as pretraining -objective for a Tok2Vec layer. +objective for a Tok2Vec layer. To use this objective, make sure that the +`initialize.vectors` section in the config refers to a model with static +vectors. | Name | Description | | --------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------- |