Raise error for non-default vectors with PretrainVectors (#12366)

This commit is contained in:
Adriane Boyd 2023-03-06 18:06:31 +01:00
parent a86ec1b2b1
commit 6177c87539
3 changed files with 43 additions and 9 deletions

View File

@ -549,6 +549,8 @@ class Errors(metaclass=ErrorsWithCodes):
"during training, make sure to include it in 'annotating components'") "during training, make sure to include it in 'annotating components'")
# New errors added in v3.x # New errors added in v3.x
E850 = ("The PretrainVectors objective currently only supports default "
"vectors, not {mode} vectors.")
E851 = ("The 'textcat' component labels should only have values of 0 or 1, " E851 = ("The 'textcat' component labels should only have values of 0 or 1, "
"but found value of '{val}'.") "but found value of '{val}'.")
E852 = ("The tar file pulled from the remote attempted an unsafe path " E852 = ("The tar file pulled from the remote attempted an unsafe path "

View File

@ -8,6 +8,7 @@ from thinc.loss import Loss
from ...util import registry, OOV_RANK from ...util import registry, OOV_RANK
from ...errors import Errors from ...errors import Errors
from ...attrs import ID from ...attrs import ID
from ...vectors import Mode as VectorsMode
import numpy import numpy
from functools import partial from functools import partial
@ -23,6 +24,8 @@ def create_pretrain_vectors(
maxout_pieces: int, hidden_size: int, loss: str maxout_pieces: int, hidden_size: int, loss: str
) -> Callable[["Vocab", Model], Model]: ) -> Callable[["Vocab", Model], Model]:
def create_vectors_objective(vocab: "Vocab", tok2vec: Model) -> Model: def create_vectors_objective(vocab: "Vocab", tok2vec: Model) -> Model:
if vocab.vectors.mode != VectorsMode.default:
raise ValueError(Errors.E850.format(mode=vocab.vectors.mode))
if vocab.vectors.shape[1] == 0: if vocab.vectors.shape[1] == 0:
raise ValueError(Errors.E875) raise ValueError(Errors.E875)
model = build_cloze_multi_task_model( model = build_cloze_multi_task_model(

View File

@ -2,17 +2,19 @@ from pathlib import Path
import numpy as np import numpy as np
import pytest import pytest
import srsly import srsly
from spacy.vocab import Vocab from thinc.api import Config, get_current_ops
from thinc.api import Config
from spacy import util
from spacy.lang.en import English
from spacy.training.initialize import init_nlp
from spacy.training.loop import train
from spacy.training.pretrain import pretrain
from spacy.tokens import Doc, DocBin
from spacy.language import DEFAULT_CONFIG_PRETRAIN_PATH, DEFAULT_CONFIG_PATH
from spacy.ml.models.multi_task import create_pretrain_vectors
from spacy.vectors import Vectors
from spacy.vocab import Vocab
from ..util import make_tempdir from ..util import make_tempdir
from ... import util
from ...lang.en import English
from ...training.initialize import init_nlp
from ...training.loop import train
from ...training.pretrain import pretrain
from ...tokens import Doc, DocBin
from ...language import DEFAULT_CONFIG_PRETRAIN_PATH, DEFAULT_CONFIG_PATH
pretrain_string_listener = """ pretrain_string_listener = """
[nlp] [nlp]
@ -346,3 +348,30 @@ def write_vectors_model(tmp_dir):
nlp = English(vocab) nlp = English(vocab)
nlp.to_disk(nlp_path) nlp.to_disk(nlp_path)
return str(nlp_path) return str(nlp_path)
def test_pretrain_default_vectors():
nlp = English()
nlp.add_pipe("tok2vec")
nlp.initialize()
# default vectors are supported
nlp.vocab.vectors = Vectors(shape=(10, 10))
create_pretrain_vectors(1, 1, "cosine")(nlp.vocab, nlp.get_pipe("tok2vec").model)
# error for no vectors
with pytest.raises(ValueError, match="E875"):
nlp.vocab.vectors = Vectors()
create_pretrain_vectors(1, 1, "cosine")(
nlp.vocab, nlp.get_pipe("tok2vec").model
)
# error for floret vectors
with pytest.raises(ValueError, match="E850"):
ops = get_current_ops()
nlp.vocab.vectors = Vectors(
data=ops.xp.zeros((10, 10)), mode="floret", hash_count=1
)
create_pretrain_vectors(1, 1, "cosine")(
nlp.vocab, nlp.get_pipe("tok2vec").model
)