mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Support floret for PretrainVectors (#12435)
* Support floret for PretrainVectors * Format
This commit is contained in:
parent
d0bd3f5ee4
commit
fac457a509
|
@ -549,8 +549,8 @@ class Errors(metaclass=ErrorsWithCodes):
|
|||
"during training, make sure to include it in 'annotating components'")
|
||||
|
||||
# New errors added in v3.x
|
||||
E850 = ("The PretrainVectors objective currently only supports default "
|
||||
"vectors, not {mode} vectors.")
|
||||
E850 = ("The PretrainVectors objective currently only supports default or "
|
||||
"floret vectors, not {mode} vectors.")
|
||||
E851 = ("The 'textcat' component labels should only have values of 0 or 1, "
|
||||
"but found value of '{val}'.")
|
||||
E852 = ("The tar file pulled from the remote attempted an unsafe path "
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from typing import Any, Optional, Iterable, Tuple, List, Callable, TYPE_CHECKING, cast
|
||||
from thinc.types import Floats2d
|
||||
from thinc.types import Floats2d, Ints1d
|
||||
from thinc.api import chain, Maxout, LayerNorm, Softmax, Linear, zero_init, Model
|
||||
from thinc.api import MultiSoftmax, list2array
|
||||
from thinc.api import to_categorical, CosineDistance, L2Distance
|
||||
|
@ -7,7 +7,7 @@ from thinc.loss import Loss
|
|||
|
||||
from ...util import registry, OOV_RANK
|
||||
from ...errors import Errors
|
||||
from ...attrs import ID
|
||||
from ...attrs import ID, ORTH
|
||||
from ...vectors import Mode as VectorsMode
|
||||
|
||||
import numpy
|
||||
|
@ -24,8 +24,6 @@ def create_pretrain_vectors(
|
|||
maxout_pieces: int, hidden_size: int, loss: str
|
||||
) -> Callable[["Vocab", Model], Model]:
|
||||
def create_vectors_objective(vocab: "Vocab", tok2vec: Model) -> Model:
|
||||
if vocab.vectors.mode != VectorsMode.default:
|
||||
raise ValueError(Errors.E850.format(mode=vocab.vectors.mode))
|
||||
if vocab.vectors.shape[1] == 0:
|
||||
raise ValueError(Errors.E875)
|
||||
model = build_cloze_multi_task_model(
|
||||
|
@ -70,14 +68,23 @@ def get_vectors_loss(ops, docs, prediction, distance):
|
|||
"""Compute a loss based on a distance between the documents' vectors and
|
||||
the prediction.
|
||||
"""
|
||||
# The simplest way to implement this would be to vstack the
|
||||
# token.vector values, but that's a bit inefficient, especially on GPU.
|
||||
# Instead we fetch the index into the vectors table for each of our tokens,
|
||||
# and look them up all at once. This prevents data copying.
|
||||
ids = ops.flatten([doc.to_array(ID).ravel() for doc in docs])
|
||||
target = docs[0].vocab.vectors.data[ids]
|
||||
target[ids == OOV_RANK] = 0
|
||||
d_target, loss = distance(prediction, target)
|
||||
vocab = docs[0].vocab
|
||||
if vocab.vectors.mode == VectorsMode.default:
|
||||
# The simplest way to implement this would be to vstack the
|
||||
# token.vector values, but that's a bit inefficient, especially on GPU.
|
||||
# Instead we fetch the index into the vectors table for each of our
|
||||
# tokens, and look them up all at once. This prevents data copying.
|
||||
ids = ops.flatten([doc.to_array(ID).ravel() for doc in docs])
|
||||
target = docs[0].vocab.vectors.data[ids]
|
||||
target[ids == OOV_RANK] = 0
|
||||
d_target, loss = distance(prediction, target)
|
||||
elif vocab.vectors.mode == VectorsMode.floret:
|
||||
keys = ops.flatten([cast(Ints1d, doc.to_array(ORTH)) for doc in docs])
|
||||
target = vocab.vectors.get_batch(keys)
|
||||
target = ops.as_contig(target)
|
||||
d_target, loss = distance(prediction, target)
|
||||
else:
|
||||
raise ValueError(Errors.E850.format(mode=vocab.vectors.mode))
|
||||
return loss, d_target
|
||||
|
||||
|
||||
|
|
|
@ -359,19 +359,15 @@ def test_pretrain_default_vectors():
|
|||
nlp.vocab.vectors = Vectors(shape=(10, 10))
|
||||
create_pretrain_vectors(1, 1, "cosine")(nlp.vocab, nlp.get_pipe("tok2vec").model)
|
||||
|
||||
# floret vectors are supported
|
||||
nlp.vocab.vectors = Vectors(
|
||||
data=get_current_ops().xp.zeros((10, 10)), mode="floret", hash_count=1
|
||||
)
|
||||
create_pretrain_vectors(1, 1, "cosine")(nlp.vocab, nlp.get_pipe("tok2vec").model)
|
||||
|
||||
# error for no vectors
|
||||
with pytest.raises(ValueError, match="E875"):
|
||||
nlp.vocab.vectors = Vectors()
|
||||
create_pretrain_vectors(1, 1, "cosine")(
|
||||
nlp.vocab, nlp.get_pipe("tok2vec").model
|
||||
)
|
||||
|
||||
# error for floret vectors
|
||||
with pytest.raises(ValueError, match="E850"):
|
||||
ops = get_current_ops()
|
||||
nlp.vocab.vectors = Vectors(
|
||||
data=ops.xp.zeros((10, 10)), mode="floret", hash_count=1
|
||||
)
|
||||
create_pretrain_vectors(1, 1, "cosine")(
|
||||
nlp.vocab, nlp.get_pipe("tok2vec").model
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue
Block a user