diff --git a/spacy/errors.py b/spacy/errors.py index c897c29ff..40cfa8d92 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -549,8 +549,8 @@ class Errors(metaclass=ErrorsWithCodes): "during training, make sure to include it in 'annotating components'") # New errors added in v3.x - E850 = ("The PretrainVectors objective currently only supports default " - "vectors, not {mode} vectors.") + E850 = ("The PretrainVectors objective currently only supports default or " + "floret vectors, not {mode} vectors.") E851 = ("The 'textcat' component labels should only have values of 0 or 1, " "but found value of '{val}'.") E852 = ("The tar file pulled from the remote attempted an unsafe path " diff --git a/spacy/ml/models/multi_task.py b/spacy/ml/models/multi_task.py index 826fddd4f..7eb13b608 100644 --- a/spacy/ml/models/multi_task.py +++ b/spacy/ml/models/multi_task.py @@ -1,5 +1,5 @@ from typing import Any, Optional, Iterable, Tuple, List, Callable, TYPE_CHECKING, cast -from thinc.types import Floats2d +from thinc.types import Floats2d, Ints1d from thinc.api import chain, Maxout, LayerNorm, Softmax, Linear, zero_init, Model from thinc.api import MultiSoftmax, list2array from thinc.api import to_categorical, CosineDistance, L2Distance @@ -7,7 +7,7 @@ from thinc.loss import Loss from ...util import registry, OOV_RANK from ...errors import Errors -from ...attrs import ID +from ...attrs import ID, ORTH from ...vectors import Mode as VectorsMode import numpy @@ -24,8 +24,6 @@ def create_pretrain_vectors( maxout_pieces: int, hidden_size: int, loss: str ) -> Callable[["Vocab", Model], Model]: def create_vectors_objective(vocab: "Vocab", tok2vec: Model) -> Model: - if vocab.vectors.mode != VectorsMode.default: - raise ValueError(Errors.E850.format(mode=vocab.vectors.mode)) if vocab.vectors.shape[1] == 0: raise ValueError(Errors.E875) model = build_cloze_multi_task_model( @@ -70,14 +68,23 @@ def get_vectors_loss(ops, docs, prediction, distance): """Compute a loss based on a distance between the documents' vectors and the prediction. """ - # The simplest way to implement this would be to vstack the - # token.vector values, but that's a bit inefficient, especially on GPU. - # Instead we fetch the index into the vectors table for each of our tokens, - # and look them up all at once. This prevents data copying. - ids = ops.flatten([doc.to_array(ID).ravel() for doc in docs]) - target = docs[0].vocab.vectors.data[ids] - target[ids == OOV_RANK] = 0 - d_target, loss = distance(prediction, target) + vocab = docs[0].vocab + if vocab.vectors.mode == VectorsMode.default: + # The simplest way to implement this would be to vstack the + # token.vector values, but that's a bit inefficient, especially on GPU. + # Instead we fetch the index into the vectors table for each of our + # tokens, and look them up all at once. This prevents data copying. + ids = ops.flatten([doc.to_array(ID).ravel() for doc in docs]) + target = docs[0].vocab.vectors.data[ids] + target[ids == OOV_RANK] = 0 + d_target, loss = distance(prediction, target) + elif vocab.vectors.mode == VectorsMode.floret: + keys = ops.flatten([cast(Ints1d, doc.to_array(ORTH)) for doc in docs]) + target = vocab.vectors.get_batch(keys) + target = ops.as_contig(target) + d_target, loss = distance(prediction, target) + else: + raise ValueError(Errors.E850.format(mode=vocab.vectors.mode)) return loss, d_target diff --git a/spacy/tests/training/test_pretraining.py b/spacy/tests/training/test_pretraining.py index c0d64f1e7..d1db92de5 100644 --- a/spacy/tests/training/test_pretraining.py +++ b/spacy/tests/training/test_pretraining.py @@ -359,19 +359,15 @@ def test_pretrain_default_vectors(): nlp.vocab.vectors = Vectors(shape=(10, 10)) create_pretrain_vectors(1, 1, "cosine")(nlp.vocab, nlp.get_pipe("tok2vec").model) + # floret vectors are supported + nlp.vocab.vectors = Vectors( + data=get_current_ops().xp.zeros((10, 10)), mode="floret", hash_count=1 + ) + create_pretrain_vectors(1, 1, "cosine")(nlp.vocab, nlp.get_pipe("tok2vec").model) + # error for no vectors with pytest.raises(ValueError, match="E875"): nlp.vocab.vectors = Vectors() create_pretrain_vectors(1, 1, "cosine")( nlp.vocab, nlp.get_pipe("tok2vec").model ) - - # error for floret vectors - with pytest.raises(ValueError, match="E850"): - ops = get_current_ops() - nlp.vocab.vectors = Vectors( - data=ops.xp.zeros((10, 10)), mode="floret", hash_count=1 - ) - create_pretrain_vectors(1, 1, "cosine")( - nlp.vocab, nlp.get_pipe("tok2vec").model - )