mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-21 22:10:34 +03:00
Move get_characters_loss
This commit is contained in:
parent
892f2552e0
commit
cd2fa89d93
14
spacy/_ml.py
14
spacy/_ml.py
|
@ -988,3 +988,17 @@ def get_cossim_loss(yh, y, ignore_zeros=False):
|
|||
losses[zero_indices] = 0
|
||||
loss = losses.sum()
|
||||
return loss, -d_yh
|
||||
|
||||
|
||||
def get_characters_loss(ops, docs, prediction, nr_char=10):
    """Compute a squared-error loss between one-hot encoded character
    targets and the prediction.

    ops: Backend object; only its `asarray` method is used here.
    docs: Iterable of objects exposing `to_utf8_array(nr_char=...)`
        (presumably spaCy `Doc` objects -- confirm against the caller).
    prediction: Array of scores. It must be compatible with the target once
        the target is reshaped to `(-1, 256 * nr_char)`.
    nr_char (int): Number of character slots requested per row from
        `to_utf8_array`.
    RETURNS (tuple): The summed squared difference (not normalised), and
        the difference divided by `prediction.shape[0]`.
    """
    # Each target id is expanded into a one-hot vector with 256 classes.
    n_classes = 256
    id_rows = [doc.to_utf8_array(nr_char=nr_char) for doc in docs]
    flat_ids = numpy.vstack(id_rows).reshape((-1,))
    one_hot = to_categorical(flat_ids, nb_classes=n_classes)
    # Regroup the one-hot vectors so each row holds all `nr_char` slots.
    target = ops.asarray(one_hot, dtype="f").reshape((-1, n_classes * nr_char))
    error = prediction - target
    loss = (error ** 2).sum()
    # Only the returned difference is scaled by the number of prediction
    # rows; the loss value itself stays a plain sum.
    d_prediction = error / float(prediction.shape[0])
    return loss, d_prediction
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ from ..errors import Errors
|
|||
from ..tokens import Doc
|
||||
from ..attrs import ID, HEAD
|
||||
from .._ml import Tok2Vec, flatten, chain, create_default_optimizer
|
||||
from .._ml import masked_language_model, get_cossim_loss
|
||||
from .._ml import masked_language_model, get_cossim_loss, get_characters_loss
|
||||
from .._ml import MultiSoftmax
|
||||
from .. import util
|
||||
from .train import _load_pretrained_tok2vec
|
||||
|
@ -304,17 +304,6 @@ def make_docs(nlp, batch, min_length, max_length):
|
|||
return docs, skip_count
|
||||
|
||||
|
||||
def get_characters_loss(ops, docs, prediction, nr_char=10):
    """Compute a squared-error loss between one-hot encoded character
    targets and the prediction.

    ops: Backend object; only its `asarray` method is used here.
    docs: Iterable of objects exposing `to_utf8_array(nr_char=...)`
        (presumably spaCy `Doc` objects -- confirm against the caller).
    prediction: Array of scores. It must be compatible with the target once
        the target is reshaped to `(-1, 256 * nr_char)`.
    nr_char (int): Number of character slots requested per row from
        `to_utf8_array`.
    RETURNS (tuple): The summed squared difference (not normalised), and
        the difference divided by `prediction.shape[0]`.
    """
    # Each target id is expanded into a one-hot vector with 256 classes.
    n_classes = 256
    id_rows = [doc.to_utf8_array(nr_char=nr_char) for doc in docs]
    flat_ids = numpy.vstack(id_rows).reshape((-1,))
    one_hot = to_categorical(flat_ids, nb_classes=n_classes)
    # Regroup the one-hot vectors so each row holds all `nr_char` slots.
    target = ops.asarray(one_hot, dtype="f").reshape((-1, n_classes * nr_char))
    error = prediction - target
    loss = (error ** 2).sum()
    # Only the returned difference is scaled by the number of prediction
    # rows; the loss value itself stays a plain sum.
    d_prediction = error / float(prediction.shape[0])
    return loss, d_prediction
|
||||
|
||||
|
||||
def get_vectors_loss(ops, docs, prediction, objective="L2"):
|
||||
"""Compute a mean-squared error loss between the documents' vectors and
|
||||
the prediction.
|
||||
|
|
Loading…
Reference in New Issue
Block a user