Mirror of https://github.com/explosion/spaCy.git (synced 2024-11-15 06:09:01 +03:00)

Commit ee56c6a4e1: Implement character-based pretraining objective
Parent: 36de9bf72a
@@ -13,12 +13,14 @@ from thinc.misc import LayerNorm as LN
 from thinc.neural.util import prefer_gpu
 from wasabi import Printer
 import srsly
+from thinc.neural.util import to_categorical

 from ..errors import Errors
 from ..tokens import Doc
 from ..attrs import ID, HEAD
 from .._ml import Tok2Vec, flatten, chain, create_default_optimizer
 from .._ml import masked_language_model, get_cossim_loss
+from .._ml import MultiSoftmax
 from .. import util
 from .train import _load_pretrained_tok2vec

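The two added imports carry the new objective: to_categorical turns integer byte IDs into one-hot rows, and MultiSoftmax (from spaCy's internal _ml module) gives the network a multi-block softmax output head. As a rough illustration of the one-hot step, here is a minimal NumPy equivalent of what to_categorical produces here (an illustration only, not thinc's actual implementation):

    import numpy

    def one_hot(ids, nb_classes):
        # One row per ID, all zeros except a single 1.0 in the ID's column,
        # matching the (n_ids, nb_classes) float target built in the loss below.
        out = numpy.zeros((len(ids), nb_classes), dtype="f")
        out[numpy.arange(len(ids)), ids] = 1.0
        return out

    print(one_hot(numpy.array([2, 0]), 4))  # [[0. 0. 1. 0.] [1. 0. 0. 0.]]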
@@ -121,11 +123,7 @@ def pretrain(
     msg = Printer()
     util.fix_random_seed(seed)

-    has_gpu = prefer_gpu()
-    if has_gpu:
-        import torch
-
-        torch.set_default_tensor_type("torch.cuda.FloatTensor")
+    has_gpu = prefer_gpu(gpu_id=1)
     msg.info("Using GPU" if has_gpu else "Not using GPU")

     output_dir = Path(output_dir)
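This hunk also drops the PyTorch default-tensor-type setup and pins pretraining to CUDA device 1 via gpu_id. A short usage sketch of the call, assuming the thinc 7.x API that spaCy 2.x uses (prefer_gpu returns a bool and falls back to CPU rather than raising):

    from thinc.neural.util import prefer_gpu

    # Try to activate the CUDA device with the given index; False means CPU.
    has_gpu = prefer_gpu(gpu_id=0)
    print("Using GPU" if has_gpu else "Not using GPU")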
@@ -167,6 +165,7 @@ def pretrain(
            cnn_maxout_pieces=3,  # You can try setting this higher
            subword_features=not use_chars,  # Set to False for Chinese etc
        ),
+        objective=loss_func
    )
    # Load in pretrained weights
    if init_tok2vec is not None:
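The loss_func value chosen on the command line now flows into model construction, selecting the output head built in create_pretraining_model below. A simplified sketch of the call site (argument list abridged; only the names visible in this hunk are taken from the source):

    model = create_pretraining_model(
        nlp,
        tok2vec,              # the Tok2Vec network configured above
        objective=loss_func,  # e.g. "cosine", "L2", or "characters"
    )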
@@ -257,7 +256,10 @@ def make_update(model, docs, optimizer, drop=0.0, objective="L2"):
     RETURNS loss: A float for the loss.
     """
     predictions, backprop = model.begin_update(docs, drop=drop)
-    loss, gradients = get_vectors_loss(model.ops, docs, predictions, objective)
+    if objective == "characters":
+        loss, gradients = get_characters_loss(model.ops, docs, predictions)
+    else:
+        loss, gradients = get_vectors_loss(model.ops, docs, predictions, objective)
     backprop(gradients, sgd=optimizer)
     # Don't want to return a cupy object here
     # The gradients are modified in-place by the BERT MLM,
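The trailing comments note that a cupy array must not escape this function when running on GPU; the next (unchanged) line of the source is presumably the usual conversion, along the lines of this assumed sketch:

    # Assumption: convert before returning so callers get a plain Python float
    # that is safe to log, compare, and accumulate on the CPU.
    return float(loss)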
@@ -295,6 +297,17 @@ def make_docs(nlp, batch, min_length, max_length):
     return docs, skip_count


+def get_characters_loss(ops, docs, prediction, nr_char=10):
+    target_ids = numpy.vstack([doc.to_utf8_array(nr_char=nr_char) for doc in docs])
+    target_ids = target_ids.reshape((-1,))
+    target = ops.asarray(to_categorical(target_ids, nb_classes=256), dtype="f")
+    target = target.reshape((-1, 256*nr_char))
+    diff = prediction - target
+    loss = (diff**2).sum()
+    d_target = diff / float(prediction.shape[0])
+    return loss, d_target
+
+
 def get_vectors_loss(ops, docs, prediction, objective="L2"):
     """Compute a mean-squared error loss between the documents' vectors and
     the prediction.
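The new loss one-hot encodes the first nr_char UTF-8 bytes of each token (via Doc.to_utf8_array), concatenates the per-character targets into one row of width 256 * nr_char per token, and scores the prediction with a summed squared error whose gradient is the scaled difference. A self-contained NumPy re-creation of the shapes involved (random stand-in data replaces doc.to_utf8_array and the model output):

    import numpy

    nr_char, n_tokens = 10, 2
    # Stand-in for numpy.vstack([doc.to_utf8_array(...) for doc in docs]):
    # one row of byte IDs (0..255) per token.
    target_ids = numpy.random.randint(0, 256, (n_tokens, nr_char))

    # One-hot every byte, then fold back to one row per token.
    flat = target_ids.reshape((-1,))
    target = numpy.zeros((flat.size, 256), dtype="f")
    target[numpy.arange(flat.size), flat] = 1.0
    target = target.reshape((-1, 256 * nr_char))      # (n_tokens, 2560)

    prediction = numpy.random.uniform(size=target.shape).astype("f")
    diff = prediction - target
    loss = (diff ** 2).sum()                          # summed squared error
    d_target = diff / float(prediction.shape[0])      # gradient, batch-scaled
    print(target.shape, float(loss))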
@@ -319,16 +332,23 @@ def get_vectors_loss(ops, docs, prediction, objective="L2"):
     return loss, d_target


-def create_pretraining_model(nlp, tok2vec):
+def create_pretraining_model(nlp, tok2vec, objective="cosine", nr_char=10):
     """Define a network for the pretraining. We simply add an output layer onto
     the tok2vec input model. The tok2vec input model needs to be a model that
     takes a batch of Doc objects (as a list), and returns a list of arrays.
     Each array in the output needs to have one row per token in the doc.
     """
-    output_size = nlp.vocab.vectors.data.shape[1]
-    output_layer = chain(
-        LN(Maxout(300, pieces=3)), Affine(output_size, drop_factor=0.0)
-    )
+    if objective == "characters":
+        out_sizes = [256] * nr_char
+        output_layer = chain(
+            LN(Maxout(300, pieces=3)),
+            MultiSoftmax(out_sizes, 300)
+        )
+    else:
+        output_size = nlp.vocab.vectors.data.shape[1]
+        output_layer = chain(
+            LN(Maxout(300, pieces=3)), Affine(output_size, drop_factor=0.0)
+        )
     # This is annoying, but the parser etc have the flatten step after
     # the tok2vec. To load the weights in cleanly, we need to match
     # the shape of the models' components exactly. So what we cann
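For the "characters" objective the head becomes a MultiSoftmax: a single affine layer whose 256 * nr_char outputs are split into nr_char blocks of 256, each normalized by its own softmax, so every block predicts one byte of the token's first nr_char characters. A rough NumPy sketch of that forward pass (an illustration of the idea, not spaCy's MultiSoftmax implementation):

    import numpy

    def multi_softmax(x, out_sizes):
        # Softmax each block of the last axis independently, then re-join.
        outs, start = [], 0
        for size in out_sizes:
            block = x[:, start:start + size]
            e = numpy.exp(block - block.max(axis=1, keepdims=True))
            outs.append(e / e.sum(axis=1, keepdims=True))
            start += size
        return numpy.concatenate(outs, axis=1)

    x = numpy.random.randn(2, 256 * 10).astype("f")
    y = multi_softmax(x, [256] * 10)
    print(y.shape, y[:, :256].sum(axis=1))  # (2, 2560) [1. 1.]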