Support floret for PretrainVectors (#12435)

* Support floret for PretrainVectors

* Format
This commit is contained in:
Adriane Boyd 2023-03-24 16:28:51 +01:00 committed by GitHub
parent d0bd3f5ee4
commit fac457a509
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 27 additions and 24 deletions

View File

@ -549,8 +549,8 @@ class Errors(metaclass=ErrorsWithCodes):
"during training, make sure to include it in 'annotating components'") "during training, make sure to include it in 'annotating components'")
# New errors added in v3.x # New errors added in v3.x
E850 = ("The PretrainVectors objective currently only supports default " E850 = ("The PretrainVectors objective currently only supports default or "
"vectors, not {mode} vectors.") "floret vectors, not {mode} vectors.")
E851 = ("The 'textcat' component labels should only have values of 0 or 1, " E851 = ("The 'textcat' component labels should only have values of 0 or 1, "
"but found value of '{val}'.") "but found value of '{val}'.")
E852 = ("The tar file pulled from the remote attempted an unsafe path " E852 = ("The tar file pulled from the remote attempted an unsafe path "

View File

@ -1,5 +1,5 @@
from typing import Any, Optional, Iterable, Tuple, List, Callable, TYPE_CHECKING, cast from typing import Any, Optional, Iterable, Tuple, List, Callable, TYPE_CHECKING, cast
from thinc.types import Floats2d from thinc.types import Floats2d, Ints1d
from thinc.api import chain, Maxout, LayerNorm, Softmax, Linear, zero_init, Model from thinc.api import chain, Maxout, LayerNorm, Softmax, Linear, zero_init, Model
from thinc.api import MultiSoftmax, list2array from thinc.api import MultiSoftmax, list2array
from thinc.api import to_categorical, CosineDistance, L2Distance from thinc.api import to_categorical, CosineDistance, L2Distance
@ -7,7 +7,7 @@ from thinc.loss import Loss
from ...util import registry, OOV_RANK from ...util import registry, OOV_RANK
from ...errors import Errors from ...errors import Errors
from ...attrs import ID from ...attrs import ID, ORTH
from ...vectors import Mode as VectorsMode from ...vectors import Mode as VectorsMode
import numpy import numpy
@ -24,8 +24,6 @@ def create_pretrain_vectors(
maxout_pieces: int, hidden_size: int, loss: str maxout_pieces: int, hidden_size: int, loss: str
) -> Callable[["Vocab", Model], Model]: ) -> Callable[["Vocab", Model], Model]:
def create_vectors_objective(vocab: "Vocab", tok2vec: Model) -> Model: def create_vectors_objective(vocab: "Vocab", tok2vec: Model) -> Model:
if vocab.vectors.mode != VectorsMode.default:
raise ValueError(Errors.E850.format(mode=vocab.vectors.mode))
if vocab.vectors.shape[1] == 0: if vocab.vectors.shape[1] == 0:
raise ValueError(Errors.E875) raise ValueError(Errors.E875)
model = build_cloze_multi_task_model( model = build_cloze_multi_task_model(
@ -70,14 +68,23 @@ def get_vectors_loss(ops, docs, prediction, distance):
"""Compute a loss based on a distance between the documents' vectors and """Compute a loss based on a distance between the documents' vectors and
the prediction. the prediction.
""" """
vocab = docs[0].vocab
if vocab.vectors.mode == VectorsMode.default:
# The simplest way to implement this would be to vstack the # The simplest way to implement this would be to vstack the
# token.vector values, but that's a bit inefficient, especially on GPU. # token.vector values, but that's a bit inefficient, especially on GPU.
# Instead we fetch the index into the vectors table for each of our tokens, # Instead we fetch the index into the vectors table for each of our
# and look them up all at once. This prevents data copying. # tokens, and look them up all at once. This prevents data copying.
ids = ops.flatten([doc.to_array(ID).ravel() for doc in docs]) ids = ops.flatten([doc.to_array(ID).ravel() for doc in docs])
target = docs[0].vocab.vectors.data[ids] target = docs[0].vocab.vectors.data[ids]
target[ids == OOV_RANK] = 0 target[ids == OOV_RANK] = 0
d_target, loss = distance(prediction, target) d_target, loss = distance(prediction, target)
elif vocab.vectors.mode == VectorsMode.floret:
keys = ops.flatten([cast(Ints1d, doc.to_array(ORTH)) for doc in docs])
target = vocab.vectors.get_batch(keys)
target = ops.as_contig(target)
d_target, loss = distance(prediction, target)
else:
raise ValueError(Errors.E850.format(mode=vocab.vectors.mode))
return loss, d_target return loss, d_target

View File

@ -359,19 +359,15 @@ def test_pretrain_default_vectors():
nlp.vocab.vectors = Vectors(shape=(10, 10)) nlp.vocab.vectors = Vectors(shape=(10, 10))
create_pretrain_vectors(1, 1, "cosine")(nlp.vocab, nlp.get_pipe("tok2vec").model) create_pretrain_vectors(1, 1, "cosine")(nlp.vocab, nlp.get_pipe("tok2vec").model)
# floret vectors are supported
nlp.vocab.vectors = Vectors(
data=get_current_ops().xp.zeros((10, 10)), mode="floret", hash_count=1
)
create_pretrain_vectors(1, 1, "cosine")(nlp.vocab, nlp.get_pipe("tok2vec").model)
# error for no vectors # error for no vectors
with pytest.raises(ValueError, match="E875"): with pytest.raises(ValueError, match="E875"):
nlp.vocab.vectors = Vectors() nlp.vocab.vectors = Vectors()
create_pretrain_vectors(1, 1, "cosine")( create_pretrain_vectors(1, 1, "cosine")(
nlp.vocab, nlp.get_pipe("tok2vec").model nlp.vocab, nlp.get_pipe("tok2vec").model
) )
# error for floret vectors
with pytest.raises(ValueError, match="E850"):
ops = get_current_ops()
nlp.vocab.vectors = Vectors(
data=ops.xp.zeros((10, 10)), mode="floret", hash_count=1
)
create_pretrain_vectors(1, 1, "cosine")(
nlp.vocab, nlp.get_pipe("tok2vec").model
)