mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
💫 Better support for semi-supervised learning (#3035)
The new spacy pretrain command implemented BERT/ULMFit/etc-like transfer learning, using our Language Modelling with Approximate Outputs version of BERT's cloze task. Pretraining is convenient, but in some ways it's a bit of a strange solution. All we're doing is initialising the weights. At the same time, we're putting a lot of work into our optimisation so that it's less sensitive to initial conditions, and more likely to find good optima. I discuss this a bit in the pseudo-rehearsal blog post: https://explosion.ai/blog/pseudo-rehearsal-catastrophic-forgetting Support semi-supervised learning in spacy train One obvious way to improve these pretraining methods is to do multi-task learning, instead of just transfer learning. This has been shown to work very well: https://arxiv.org/pdf/1809.08370.pdf . This patch makes it easy to do this sort of thing. Add a new argument to spacy train, --raw-text. This takes a jsonl file with unlabelled data that can be used in arbitrary ways to do semi-supervised learning. Add a new method to the Language class and to pipeline components, .rehearse(). This is like .update(), but doesn't expect GoldParse objects. It takes a batch of Doc objects, and performs an update on some semi-supervised objective. Move the BERT-LMAO objective out from spacy/cli/pretrain.py into spacy/_ml.py, so we can create a new pipeline component, ClozeMultitask. This can be specified as a parser or NER multitask in the spacy train command. Example usage: python -m spacy train en ./tmp ~/data/en-core-web/train/nw.json ~/data/en-core-web/dev/nw.json --pipeline parser --raw-textt ~/data/unlabelled/reddit-100k.jsonl --vectors en_vectors_web_lg --parser-multitasks cloze Implement rehearsal methods for pipeline components The new --raw-text argument and nlp.rehearse() method also gives us a good place to implement the the idea in the pseudo-rehearsal blog post in the parser. This works as follows: Add a new nlp.resume_training() method. This allocates copies of pre-trained models in the pipeline, setting things up for the rehearsal updates. It also returns an optimizer object. This also greatly reduces confusion around the nlp.begin_training() method, which randomises the weights, making it not suitable for adding new labels or otherwise fine-tuning a pre-trained model. Implement rehearsal updates on the Parser class, making it available for the dependency parser and NER. During rehearsal, the initial model is used to supervise the model being trained. The current model is asked to match the predictions of the initial model on some data. This minimises catastrophic forgetting, by keeping the model's predictions close to the original. See the blog post for details. Implement rehearsal updates for tagger Implement rehearsal updates for text categoriz
This commit is contained in:
parent
449b889454
commit
83ac227bd3
77
examples/training/rehearsal.py
Normal file
77
examples/training/rehearsal.py
Normal file
|
@ -0,0 +1,77 @@
|
|||
"""Prevent catastrophic forgetting with rehearsal updates."""
|
||||
import plac
|
||||
import random
|
||||
import srsly
|
||||
import spacy
|
||||
from spacy.gold import GoldParse
|
||||
from spacy.util import minibatch
|
||||
|
||||
|
||||
LABEL = "ANIMAL"
|
||||
TRAIN_DATA = [
|
||||
(
|
||||
"Horses are too tall and they pretend to care about your feelings",
|
||||
{"entities": [(0, 6, "ANIMAL")]},
|
||||
),
|
||||
("Do they bite?", {"entities": []}),
|
||||
(
|
||||
"horses are too tall and they pretend to care about your feelings",
|
||||
{"entities": [(0, 6, "ANIMAL")]},
|
||||
),
|
||||
("horses pretend to care about your feelings", {"entities": [(0, 6, "ANIMAL")]}),
|
||||
(
|
||||
"they pretend to care about your feelings, those horses",
|
||||
{"entities": [(48, 54, "ANIMAL")]},
|
||||
),
|
||||
("horses?", {"entities": [(0, 6, "ANIMAL")]}),
|
||||
]
|
||||
|
||||
|
||||
def read_raw_data(nlp, jsonl_loc):
|
||||
for json_obj in srsly.read_jsonl(jsonl_loc):
|
||||
if json_obj["text"].strip():
|
||||
doc = nlp.make_doc(json_obj["text"])
|
||||
yield doc
|
||||
|
||||
|
||||
def read_gold_data(nlp, gold_loc):
|
||||
docs = []
|
||||
golds = []
|
||||
for json_obj in srsly.read_jsonl(gold_loc):
|
||||
doc = nlp.make_doc(json_obj["text"])
|
||||
ents = [(ent["start"], ent["end"], ent["label"]) for ent in json_obj["spans"]]
|
||||
gold = GoldParse(doc, entities=ents)
|
||||
docs.append(doc)
|
||||
golds.append(gold)
|
||||
return list(zip(docs, golds))
|
||||
|
||||
|
||||
def main(model_name, unlabelled_loc):
|
||||
n_iter = 10
|
||||
dropout = 0.2
|
||||
batch_size = 4
|
||||
nlp = spacy.load(model_name)
|
||||
nlp.get_pipe("ner").add_label(LABEL)
|
||||
raw_docs = list(read_raw_data(nlp, unlabelled_loc))
|
||||
optimizer = nlp.resume_training()
|
||||
|
||||
# get names of other pipes to disable them during training
|
||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "ner"]
|
||||
with nlp.disable_pipes(*other_pipes):
|
||||
for itn in range(n_iter):
|
||||
random.shuffle(TRAIN_DATA)
|
||||
random.shuffle(raw_docs)
|
||||
losses = {}
|
||||
r_losses = {}
|
||||
# batch up the examples using spaCy's minibatch
|
||||
raw_batches = minibatch(raw_docs, size=batch_size)
|
||||
for doc, gold in TRAIN_DATA:
|
||||
nlp.update([doc], [gold], sgd=optimizer, drop=dropout, losses=losses)
|
||||
raw_batch = list(next(raw_batches))
|
||||
nlp.rehearse(raw_batch, sgd=optimizer, losses=r_losses)
|
||||
print("Losses", losses)
|
||||
print("R. Losses", r_losses)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
plac.call(main)
|
88
spacy/_ml.py
88
spacy/_ml.py
|
@ -586,16 +586,8 @@ def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=True,
|
|||
if exclusive_classes:
|
||||
output_layer = Softmax(nr_class, tok2vec.nO)
|
||||
else:
|
||||
output_layer = (
|
||||
zero_init(Affine(nr_class, tok2vec.nO))
|
||||
>> logistic
|
||||
)
|
||||
model = (
|
||||
tok2vec
|
||||
>> flatten_add_lengths
|
||||
>> Pooling(mean_pool)
|
||||
>> output_layer
|
||||
)
|
||||
output_layer = zero_init(Affine(nr_class, tok2vec.nO)) >> logistic
|
||||
model = tok2vec >> flatten_add_lengths >> Pooling(mean_pool) >> output_layer
|
||||
model.tok2vec = chain(tok2vec, flatten)
|
||||
model.nO = nr_class
|
||||
return model
|
||||
|
@ -637,3 +629,79 @@ def concatenate_lists(*layers, **kwargs): # pragma: no cover
|
|||
|
||||
model = wrap(concatenate_lists_fwd, concat)
|
||||
return model
|
||||
|
||||
|
||||
def masked_language_model(vocab, model, mask_prob=0.15):
|
||||
"""Convert a model into a BERT-style masked language model"""
|
||||
|
||||
random_words = _RandomWords(vocab)
|
||||
|
||||
def mlm_forward(docs, drop=0.0):
|
||||
mask, docs = _apply_mask(docs, random_words, mask_prob=mask_prob)
|
||||
mask = model.ops.asarray(mask).reshape((mask.shape[0], 1))
|
||||
output, backprop = model.begin_update(docs, drop=drop)
|
||||
|
||||
def mlm_backward(d_output, sgd=None):
|
||||
d_output *= 1 - mask
|
||||
return backprop(d_output, sgd=sgd)
|
||||
|
||||
return output, mlm_backward
|
||||
|
||||
return wrap(mlm_forward, model)
|
||||
|
||||
|
||||
class _RandomWords(object):
|
||||
def __init__(self, vocab):
|
||||
self.words = [lex.text for lex in vocab if lex.prob != 0.0]
|
||||
self.probs = [lex.prob for lex in vocab if lex.prob != 0.0]
|
||||
self.words = self.words[:10000]
|
||||
self.probs = self.probs[:10000]
|
||||
self.probs = numpy.exp(numpy.array(self.probs, dtype="f"))
|
||||
self.probs /= self.probs.sum()
|
||||
self._cache = []
|
||||
|
||||
def next(self):
|
||||
if not self._cache:
|
||||
self._cache.extend(
|
||||
numpy.random.choice(len(self.words), 10000, p=self.probs)
|
||||
)
|
||||
index = self._cache.pop()
|
||||
return self.words[index]
|
||||
|
||||
|
||||
def _apply_mask(docs, random_words, mask_prob=0.15):
|
||||
# This needs to be here to avoid circular imports
|
||||
from .tokens.doc import Doc
|
||||
|
||||
N = sum(len(doc) for doc in docs)
|
||||
mask = numpy.random.uniform(0.0, 1.0, (N,))
|
||||
mask = mask >= mask_prob
|
||||
i = 0
|
||||
masked_docs = []
|
||||
for doc in docs:
|
||||
words = []
|
||||
for token in doc:
|
||||
if not mask[i]:
|
||||
word = _replace_word(token.text, random_words)
|
||||
else:
|
||||
word = token.text
|
||||
words.append(word)
|
||||
i += 1
|
||||
spaces = [bool(w.whitespace_) for w in doc]
|
||||
# NB: If you change this implementation to instead modify
|
||||
# the docs in place, take care that the IDs reflect the original
|
||||
# words. Currently we use the original docs to make the vectors
|
||||
# for the target, so we don't lose the original tokens. But if
|
||||
# you modified the docs in place here, you would.
|
||||
masked_docs.append(Doc(doc.vocab, words=words, spaces=spaces))
|
||||
return mask, masked_docs
|
||||
|
||||
|
||||
def _replace_word(word, random_words, mask="[MASK]"):
|
||||
roll = numpy.random.random()
|
||||
if roll < 0.8:
|
||||
return mask
|
||||
elif roll < 0.9:
|
||||
return random_words.next()
|
||||
else:
|
||||
return word
|
||||
|
|
|
@ -17,6 +17,7 @@ import srsly
|
|||
from ..tokens import Doc
|
||||
from ..attrs import ID, HEAD
|
||||
from .._ml import Tok2Vec, flatten, chain, zero_init, create_default_optimizer
|
||||
from .._ml import masked_language_model
|
||||
from .. import util
|
||||
|
||||
|
||||
|
@ -212,79 +213,6 @@ def create_pretraining_model(nlp, tok2vec):
|
|||
return model
|
||||
|
||||
|
||||
def masked_language_model(vocab, model, mask_prob=0.15):
|
||||
"""Convert a model into a BERT-style masked language model"""
|
||||
|
||||
random_words = RandomWords(vocab)
|
||||
|
||||
def mlm_forward(docs, drop=0.0):
|
||||
mask, docs = apply_mask(docs, random_words, mask_prob=mask_prob)
|
||||
mask = model.ops.asarray(mask).reshape((mask.shape[0], 1))
|
||||
output, backprop = model.begin_update(docs, drop=drop)
|
||||
|
||||
def mlm_backward(d_output, sgd=None):
|
||||
d_output *= 1 - mask
|
||||
return backprop(d_output, sgd=sgd)
|
||||
|
||||
return output, mlm_backward
|
||||
|
||||
return wrap(mlm_forward, model)
|
||||
|
||||
|
||||
def apply_mask(docs, random_words, mask_prob=0.15):
|
||||
N = sum(len(doc) for doc in docs)
|
||||
mask = numpy.random.uniform(0.0, 1.0, (N,))
|
||||
mask = mask >= mask_prob
|
||||
i = 0
|
||||
masked_docs = []
|
||||
for doc in docs:
|
||||
words = []
|
||||
for token in doc:
|
||||
if not mask[i]:
|
||||
word = replace_word(token.text, random_words)
|
||||
else:
|
||||
word = token.text
|
||||
words.append(word)
|
||||
i += 1
|
||||
spaces = [bool(w.whitespace_) for w in doc]
|
||||
# NB: If you change this implementation to instead modify
|
||||
# the docs in place, take care that the IDs reflect the original
|
||||
# words. Currently we use the original docs to make the vectors
|
||||
# for the target, so we don't lose the original tokens. But if
|
||||
# you modified the docs in place here, you would.
|
||||
masked_docs.append(Doc(doc.vocab, words=words, spaces=spaces))
|
||||
return mask, masked_docs
|
||||
|
||||
|
||||
def replace_word(word, random_words, mask="[MASK]"):
|
||||
roll = random.random()
|
||||
if roll < 0.8:
|
||||
return mask
|
||||
elif roll < 0.9:
|
||||
return random_words.next()
|
||||
else:
|
||||
return word
|
||||
|
||||
|
||||
class RandomWords(object):
|
||||
def __init__(self, vocab):
|
||||
self.words = [lex.text for lex in vocab if lex.prob != 0.0]
|
||||
self.probs = [lex.prob for lex in vocab if lex.prob != 0.0]
|
||||
self.words = self.words[:10000]
|
||||
self.probs = self.probs[:10000]
|
||||
self.probs = numpy.exp(numpy.array(self.probs, dtype="f"))
|
||||
self.probs /= self.probs.sum()
|
||||
self._cache = []
|
||||
|
||||
def next(self):
|
||||
if not self._cache:
|
||||
self._cache.extend(
|
||||
numpy.random.choice(len(self.words), 10000, p=self.probs)
|
||||
)
|
||||
index = self._cache.pop()
|
||||
return self.words[index]
|
||||
|
||||
|
||||
class ProgressTracker(object):
|
||||
def __init__(self, frequency=1000000):
|
||||
self.loss = 0.0
|
||||
|
|
|
@ -25,6 +25,12 @@ from .. import about
|
|||
output_path=("Output directory to store model in", "positional", None, Path),
|
||||
train_path=("Location of JSON-formatted training data", "positional", None, Path),
|
||||
dev_path=("Location of JSON-formatted development data", "positional", None, Path),
|
||||
raw_text=(
|
||||
"Path to jsonl file with unlabelled text documents.",
|
||||
"option",
|
||||
"rt",
|
||||
Path,
|
||||
),
|
||||
base_model=("Name of model to update (optional)", "option", "b", str),
|
||||
pipeline=("Comma-separated names of pipeline components", "option", "p", str),
|
||||
vectors=("Model to load vectors from", "option", "v", str),
|
||||
|
@ -62,6 +68,7 @@ def train(
|
|||
output_path,
|
||||
train_path,
|
||||
dev_path,
|
||||
raw_text=None,
|
||||
base_model=None,
|
||||
pipeline="tagger,parser,ner",
|
||||
vectors=None,
|
||||
|
@ -92,6 +99,8 @@ def train(
|
|||
train_path = util.ensure_path(train_path)
|
||||
dev_path = util.ensure_path(dev_path)
|
||||
meta_path = util.ensure_path(meta_path)
|
||||
if raw_text is not None:
|
||||
raw_text = list(srsly.read_jsonl(raw_text))
|
||||
if not train_path or not train_path.exists():
|
||||
msg.fail("Training data not found", train_path, exits=1)
|
||||
if not dev_path or not dev_path.exists():
|
||||
|
@ -186,6 +195,8 @@ def train(
|
|||
optimizer.b1_decay = 0.0001
|
||||
optimizer.b2_decay = 0.0001
|
||||
nlp._optimizer = None
|
||||
optimizer.b1_decay = 0.003
|
||||
optimizer.b2_decay = 0.003
|
||||
|
||||
# Load in pre-trained weights
|
||||
if init_tok2vec is not None:
|
||||
|
@ -208,6 +219,11 @@ def train(
|
|||
train_docs = corpus.train_docs(
|
||||
nlp, noise_level=noise_level, gold_preproc=gold_preproc, max_length=0
|
||||
)
|
||||
if raw_text:
|
||||
random.shuffle(raw_text)
|
||||
raw_batches = util.minibatch(
|
||||
(nlp.make_doc(rt["text"]) for rt in raw_text), size=8
|
||||
)
|
||||
words_seen = 0
|
||||
with _create_progress_bar(n_train_words) as pbar:
|
||||
losses = {}
|
||||
|
@ -222,7 +238,12 @@ def train(
|
|||
drop=next(dropout_rates),
|
||||
losses=losses,
|
||||
)
|
||||
if not int(os.environ.get('LOG_FRIENDLY', 0)):
|
||||
if raw_text:
|
||||
# If raw text is available, perform 'rehearsal' updates,
|
||||
# which use unlabelled data to reduce overfitting.
|
||||
raw_batch = list(next(raw_batches))
|
||||
nlp.rehearse(raw_batch, sgd=optimizer, losses=losses)
|
||||
if not int(os.environ.get("LOG_FRIENDLY", 0)):
|
||||
pbar.update(sum(len(doc) for doc in docs))
|
||||
words_seen += sum(len(doc) for doc in docs)
|
||||
with nlp.use_params(optimizer.averages):
|
||||
|
@ -286,7 +307,7 @@ def train(
|
|||
|
||||
@contextlib.contextmanager
|
||||
def _create_progress_bar(total):
|
||||
if int(os.environ.get('LOG_FRIENDLY', 0)):
|
||||
if int(os.environ.get("LOG_FRIENDLY", 0)):
|
||||
yield
|
||||
else:
|
||||
pbar = tqdm.tqdm(total=total, leave=False)
|
||||
|
|
|
@ -7,7 +7,7 @@ import weakref
|
|||
import functools
|
||||
from collections import OrderedDict
|
||||
from contextlib import contextmanager
|
||||
from copy import copy
|
||||
from copy import copy, deepcopy
|
||||
from thinc.neural import Model
|
||||
import srsly
|
||||
|
||||
|
@ -453,6 +453,59 @@ class Language(object):
|
|||
for key, (W, dW) in grads.items():
|
||||
sgd(W, dW, key=key)
|
||||
|
||||
def rehearse(self, docs, sgd=None, losses=None, config=None):
|
||||
"""Make a "rehearsal" update to the models in the pipeline, to prevent
|
||||
forgetting. Rehearsal updates run an initial copy of the model over some
|
||||
data, and update the model so its current predictions are more like the
|
||||
initial ones. This is useful for keeping a pre-trained model on-track,
|
||||
even if you're updating it with a smaller set of examples.
|
||||
|
||||
docs (iterable): A batch of `Doc` objects.
|
||||
drop (float): The droput rate.
|
||||
sgd (callable): An optimizer.
|
||||
RETURNS (dict): Results from the update.
|
||||
|
||||
EXAMPLE:
|
||||
>>> raw_text_batches = minibatch(raw_texts)
|
||||
>>> for labelled_batch in minibatch(zip(train_docs, train_golds)):
|
||||
>>> docs, golds = zip(*train_docs)
|
||||
>>> nlp.update(docs, golds)
|
||||
>>> raw_batch = [nlp.make_doc(text) for text in next(raw_text_batches)]
|
||||
>>> nlp.rehearse(raw_batch)
|
||||
"""
|
||||
if len(docs) == 0:
|
||||
return
|
||||
if sgd is None:
|
||||
if self._optimizer is None:
|
||||
self._optimizer = create_default_optimizer(Model.ops)
|
||||
sgd = self._optimizer
|
||||
docs = list(docs)
|
||||
for i, doc in enumerate(docs):
|
||||
if isinstance(doc, basestring_):
|
||||
docs[i] = self.make_doc(doc)
|
||||
pipes = list(self.pipeline)
|
||||
random.shuffle(pipes)
|
||||
if config is None:
|
||||
config = {}
|
||||
grads = {}
|
||||
|
||||
def get_grads(W, dW, key=None):
|
||||
grads[key] = (W, dW)
|
||||
|
||||
get_grads.alpha = sgd.alpha
|
||||
get_grads.b1 = sgd.b1
|
||||
get_grads.b2 = sgd.b2
|
||||
|
||||
for name, proc in pipes:
|
||||
if not hasattr(proc, "rehearse"):
|
||||
continue
|
||||
grads = {}
|
||||
proc.rehearse(docs, sgd=get_grads, losses=losses, **config.get(name, {}))
|
||||
for key, (W, dW) in grads.items():
|
||||
sgd(W, dW, key=key)
|
||||
|
||||
return losses
|
||||
|
||||
def preprocess_gold(self, docs_golds):
|
||||
"""Can be called before training to pre-process gold data. By default,
|
||||
it handles nonprojectivity and adds missing tags to the tag map.
|
||||
|
@ -499,6 +552,30 @@ class Language(object):
|
|||
)
|
||||
return self._optimizer
|
||||
|
||||
def resume_training(self, sgd=None, **cfg):
|
||||
"""Continue training a pre-trained model.
|
||||
|
||||
Create and return an optimizer, and initialize "rehearsal" for any pipeline
|
||||
component that has a .rehearse() method. Rehearsal is used to prevent
|
||||
models from "forgetting" their initialised "knowledge". To perform
|
||||
rehearsal, collect samples of text you want the models to retain performance
|
||||
on, and call nlp.rehearse() with a batch of Doc objects.
|
||||
"""
|
||||
if cfg.get("device", -1) >= 0:
|
||||
util.use_gpu(cfg["device"])
|
||||
if self.vocab.vectors.data.shape[1] >= 1:
|
||||
self.vocab.vectors.data = Model.ops.asarray(self.vocab.vectors.data)
|
||||
link_vectors_to_models(self.vocab)
|
||||
if self.vocab.vectors.data.shape[1]:
|
||||
cfg["pretrained_vectors"] = self.vocab.vectors.name
|
||||
if sgd is None:
|
||||
sgd = create_default_optimizer(Model.ops)
|
||||
self._optimizer = sgd
|
||||
for name, proc in self.pipeline:
|
||||
if hasattr(proc, "_rehearsal_model"):
|
||||
proc._rehearsal_model = deepcopy(proc.model)
|
||||
return self._optimizer
|
||||
|
||||
def evaluate(self, docs_golds, verbose=False):
|
||||
scorer = Scorer()
|
||||
docs, golds = zip(*docs_golds)
|
||||
|
|
|
@ -33,6 +33,7 @@ from ._ml import Tok2Vec, build_text_classifier, build_tagger_model
|
|||
from ._ml import build_simple_cnn_text_classifier
|
||||
from ._ml import link_vectors_to_models, zero_init, flatten
|
||||
from ._ml import create_default_optimizer
|
||||
from ._ml import masked_language_model
|
||||
from .errors import Errors, TempErrors
|
||||
from .compat import basestring_
|
||||
from . import util
|
||||
|
@ -326,6 +327,9 @@ class Pipe(object):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def rehearse(self, docs, sgd=None, losses=None, **config):
|
||||
pass
|
||||
|
||||
def get_loss(self, docs, golds, scores):
|
||||
"""Find the loss and gradient of loss for the batch of
|
||||
documents and their predicted scores."""
|
||||
|
@ -568,6 +572,7 @@ class Tagger(Pipe):
|
|||
def __init__(self, vocab, model=True, **cfg):
|
||||
self.vocab = vocab
|
||||
self.model = model
|
||||
self._rehearsal_model = None
|
||||
self.cfg = OrderedDict(sorted(cfg.items()))
|
||||
self.cfg.setdefault('cnn_maxout_pieces', 2)
|
||||
|
||||
|
@ -649,6 +654,20 @@ class Tagger(Pipe):
|
|||
if losses is not None:
|
||||
losses[self.name] += loss
|
||||
|
||||
def rehearse(self, docs, drop=0., sgd=None, losses=None):
|
||||
"""Perform a 'rehearsal' update, where we try to match the output of
|
||||
an initial model.
|
||||
"""
|
||||
if self._rehearsal_model is None:
|
||||
return
|
||||
guesses, backprop = self.model.begin_update(docs, drop=drop)
|
||||
target = self._rehearsal_model(docs)
|
||||
gradient = guesses - target
|
||||
backprop(gradient, sgd=sgd)
|
||||
if losses is not None:
|
||||
losses.setdefault(self.name, 0.0)
|
||||
losses[self.name] += (gradient**2).sum()
|
||||
|
||||
def get_loss(self, docs, golds, scores):
|
||||
scores = self.model.ops.flatten(scores)
|
||||
tag_index = {tag: i for i, tag in enumerate(self.labels)}
|
||||
|
@ -986,6 +1005,69 @@ class MultitaskObjective(Tagger):
|
|||
return sent_tags[target]
|
||||
|
||||
|
||||
class ClozeMultitask(Pipe):
|
||||
@classmethod
|
||||
def Model(cls, vocab, tok2vec, **cfg):
|
||||
output_size = vocab.vectors.data.shape[1]
|
||||
output_layer = chain(
|
||||
LayerNorm(Maxout(output_size, tok2vec.nO, pieces=3)),
|
||||
zero_init(Affine(output_size, output_size, drop_factor=0.0))
|
||||
)
|
||||
model = chain(tok2vec, output_layer)
|
||||
model = masked_language_model(vocab, model)
|
||||
model.tok2vec = tok2vec
|
||||
model.output_layer = output_layer
|
||||
return model
|
||||
|
||||
def __init__(self, vocab, model=True, **cfg):
|
||||
self.vocab = vocab
|
||||
self.model = model
|
||||
self.cfg = cfg
|
||||
|
||||
def set_annotations(self, docs, dep_ids, tensors=None):
|
||||
pass
|
||||
|
||||
def begin_training(self, get_gold_tuples=lambda: [], pipeline=None,
|
||||
tok2vec=None, sgd=None, **kwargs):
|
||||
link_vectors_to_models(self.vocab)
|
||||
if self.model is True:
|
||||
self.model = self.Model(self.vocab, tok2vec)
|
||||
X = self.model.ops.allocate((5, self.model.tok2vec.nO))
|
||||
self.model.output_layer.begin_training(X)
|
||||
if sgd is None:
|
||||
sgd = self.create_optimizer()
|
||||
return sgd
|
||||
|
||||
def predict(self, docs):
|
||||
tokvecs = self.model.tok2vec(docs)
|
||||
vectors = self.model.output_layer(tokvecs)
|
||||
return tokvecs, vectors
|
||||
|
||||
def get_loss(self, docs, vectors, prediction):
|
||||
# The simplest way to implement this would be to vstack the
|
||||
# token.vector values, but that's a bit inefficient, especially on GPU.
|
||||
# Instead we fetch the index into the vectors table for each of our tokens,
|
||||
# and look them up all at once. This prevents data copying.
|
||||
ids = self.model.ops.flatten([doc.to_array(ID).ravel() for doc in docs])
|
||||
target = vectors[ids]
|
||||
gradient = (prediction - target) / prediction.shape[0]
|
||||
loss = (gradient**2).sum()
|
||||
return float(loss), gradient
|
||||
|
||||
def update(self, docs, golds, drop=0., sgd=None, losses=None):
|
||||
pass
|
||||
|
||||
def rehearse(self, docs, drop=0., sgd=None, losses=None):
|
||||
if losses is not None and self.name not in losses:
|
||||
losses[self.name] = 0.
|
||||
predictions, bp_predictions = self.model.begin_update(docs, drop=drop)
|
||||
loss, d_predictions = self.get_loss(docs, self.vocab.vectors.data, predictions)
|
||||
bp_predictions(d_predictions, sgd=sgd)
|
||||
|
||||
if losses is not None:
|
||||
losses[self.name] += loss
|
||||
|
||||
|
||||
class SimilarityHook(Pipe):
|
||||
"""
|
||||
Experimental: A pipeline component to install a hook for supervised
|
||||
|
@ -1062,6 +1144,7 @@ class TextCategorizer(Pipe):
|
|||
def __init__(self, vocab, model=True, **cfg):
|
||||
self.vocab = vocab
|
||||
self.model = model
|
||||
self._rehearsal_model = None
|
||||
self.cfg = dict(cfg)
|
||||
|
||||
@property
|
||||
|
@ -1103,6 +1186,17 @@ class TextCategorizer(Pipe):
|
|||
losses.setdefault(self.name, 0.0)
|
||||
losses[self.name] += loss
|
||||
|
||||
def rehearse(self, docs, drop=0., sgd=None, losses=None):
|
||||
if self._rehearsal_model is None:
|
||||
return
|
||||
scores, bp_scores = self.model.begin_update(docs, drop=drop)
|
||||
target = self._rehearsal_model(docs)
|
||||
gradient = scores - target
|
||||
bp_scores(gradient, sgd=sgd)
|
||||
if losses is not None:
|
||||
losses.setdefault(self.name, 0.0)
|
||||
losses[self.name] += (gradient**2).sum()
|
||||
|
||||
def get_loss(self, docs, golds, scores):
|
||||
truths = numpy.zeros((len(golds), len(self.labels)), dtype='f')
|
||||
not_missing = numpy.ones((len(golds), len(self.labels)), dtype='f')
|
||||
|
@ -1165,8 +1259,12 @@ cdef class DependencyParser(Parser):
|
|||
return [nonproj.deprojectivize]
|
||||
|
||||
def add_multitask_objective(self, target):
|
||||
labeller = MultitaskObjective(self.vocab, target=target)
|
||||
self._multitasks.append(labeller)
|
||||
if target == 'cloze':
|
||||
cloze = ClozeMultitask(self.vocab)
|
||||
self._multitasks.append(cloze)
|
||||
else:
|
||||
labeller = MultitaskObjective(self.vocab, target=target)
|
||||
self._multitasks.append(labeller)
|
||||
|
||||
def init_multitask_objectives(self, get_gold_tuples, pipeline, sgd=None, **cfg):
|
||||
for labeller in self._multitasks:
|
||||
|
@ -1186,8 +1284,12 @@ cdef class EntityRecognizer(Parser):
|
|||
nr_feature = 6
|
||||
|
||||
def add_multitask_objective(self, target):
|
||||
labeller = MultitaskObjective(self.vocab, target=target)
|
||||
self._multitasks.append(labeller)
|
||||
if target == 'cloze':
|
||||
cloze = ClozeMultitask(self.vocab)
|
||||
self._multitasks.append(cloze)
|
||||
else:
|
||||
labeller = MultitaskObjective(self.vocab, target=target)
|
||||
self._multitasks.append(labeller)
|
||||
|
||||
def init_multitask_objectives(self, get_gold_tuples, pipeline, sgd=None, **cfg):
|
||||
for labeller in self._multitasks:
|
||||
|
|
|
@ -193,10 +193,6 @@ class ParserModel(Model):
|
|||
Model.__init__(self)
|
||||
self._layers = [tok2vec, lower_model, upper_model]
|
||||
|
||||
@property
|
||||
def tok2vec(self):
|
||||
return self._layers[0]
|
||||
|
||||
def begin_update(self, docs, drop=0.):
|
||||
step_model = ParserStepModel(docs, self._layers, drop=drop)
|
||||
def finish_parser_update(golds, sgd=None):
|
||||
|
@ -205,13 +201,20 @@ class ParserModel(Model):
|
|||
return step_model, finish_parser_update
|
||||
|
||||
def resize_output(self, new_output):
|
||||
smaller = self.upper
|
||||
larger = Affine(new_output, smaller.nI)
|
||||
larger.W *= 0
|
||||
# It seems very unhappy if I pass these as smaller.W?
|
||||
# Seems to segfault. Maybe it's a descriptor protocol thing?
|
||||
smaller_W = smaller.W
|
||||
larger_W = larger.W
|
||||
smaller_b = smaller.b
|
||||
larger_b = larger.b
|
||||
# Weights are stored in (nr_out, nr_in) format, so we're basically
|
||||
# just adding rows here.
|
||||
smaller = self._layers[-1]._layers[-1]
|
||||
larger = Affine(self.moves.n_moves, smaller.nI)
|
||||
copy_array(larger.W[:smaller.nO], smaller.W)
|
||||
copy_array(larger.b[:smaller.nO], smaller.b)
|
||||
self._layers[-1]._layers[-1] = larger
|
||||
larger_W[:smaller.nO] = smaller_W
|
||||
larger_b[:smaller.nO] = smaller_b
|
||||
self._layers[-1] = larger
|
||||
|
||||
def begin_training(self, X, y=None):
|
||||
self.lower.begin_training(X, y=y)
|
||||
|
|
|
@ -12,6 +12,7 @@ from ._parser_model cimport WeightsC, ActivationsC, SizesC
|
|||
cdef class Parser:
|
||||
cdef readonly Vocab vocab
|
||||
cdef public object model
|
||||
cdef public object _rehearsal_model
|
||||
cdef readonly TransitionSystem moves
|
||||
cdef readonly object cfg
|
||||
cdef public object _multitasks
|
||||
|
@ -21,4 +22,3 @@ cdef class Parser:
|
|||
|
||||
cdef void c_transition_batch(self, StateC** states, const float* scores,
|
||||
int nr_class, int batch_size) nogil
|
||||
|
||||
|
|
|
@ -72,13 +72,15 @@ cdef class Parser:
|
|||
pretrained_vectors=pretrained_vectors,
|
||||
bilstm_depth=bilstm_depth)
|
||||
tok2vec = chain(tok2vec, flatten)
|
||||
tok2vec.nO = token_vector_width
|
||||
lower = PrecomputableAffine(hidden_width,
|
||||
nF=cls.nr_feature, nI=token_vector_width,
|
||||
nP=parser_maxout_pieces)
|
||||
lower.nP = parser_maxout_pieces
|
||||
|
||||
with Model.use_device('cpu'):
|
||||
upper = zero_init(Affine(nr_class, hidden_width, drop_factor=0.0))
|
||||
upper = Affine(nr_class, hidden_width, drop_factor=0.0)
|
||||
upper.W *= 0
|
||||
|
||||
cfg = {
|
||||
'nr_class': nr_class,
|
||||
|
@ -121,6 +123,7 @@ cdef class Parser:
|
|||
self.cfg = cfg
|
||||
self.model = model
|
||||
self._multitasks = []
|
||||
self._rehearsal_model = None
|
||||
|
||||
def __reduce__(self):
|
||||
return (Parser, (self.vocab, self.moves, self.model), None, None)
|
||||
|
@ -404,6 +407,43 @@ cdef class Parser:
|
|||
finish_update(golds, sgd=sgd)
|
||||
return losses
|
||||
|
||||
def rehearse(self, docs, sgd=None, losses=None, **cfg):
|
||||
"""Perform a "rehearsal" update, to prevent catastrophic forgetting."""
|
||||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
if losses is None:
|
||||
losses = {}
|
||||
for multitask in self._multitasks:
|
||||
if hasattr(multitask, 'rehearse'):
|
||||
multitask.rehearse(docs, losses=losses, sgd=sgd)
|
||||
if self._rehearsal_model is None:
|
||||
return None
|
||||
losses.setdefault(self.name, 0.)
|
||||
# Prepare the stepwise model, and get the callback for finishing the batch
|
||||
tutor = self._rehearsal_model(docs)
|
||||
model, finish_update = self.model.begin_update(docs, drop=0.0)
|
||||
states = self.moves.init_batch(docs)
|
||||
n_scores = 0.
|
||||
loss = 0.
|
||||
non_zeroed_classes = self._rehearsal_model.upper.W.any(axis=1)
|
||||
while states:
|
||||
targets, _ = tutor.begin_update(states)
|
||||
guesses, backprop = model.begin_update(states)
|
||||
d_scores = (targets - guesses) / targets.shape[0]
|
||||
d_scores *= non_zeroed_classes
|
||||
# If all weights for an output are 0 in the original model, don't
|
||||
# supervise that output. This allows us to add classes.
|
||||
loss += (d_scores**2).sum()
|
||||
backprop(d_scores, sgd=sgd)
|
||||
# Follow the predicted action
|
||||
self.transition_states(states, guesses)
|
||||
states = [state for state in states if not state.is_final()]
|
||||
n_scores += d_scores.size
|
||||
# Do the backprop
|
||||
finish_update(docs, sgd=sgd)
|
||||
losses[self.name] += loss / n_scores
|
||||
return losses
|
||||
|
||||
def update_beam(self, docs, golds, width, drop=0., sgd=None, losses=None,
|
||||
beam_density=0.0):
|
||||
lengths = [len(d) for d in docs]
|
||||
|
@ -416,7 +456,7 @@ cdef class Parser:
|
|||
model.vec2scores, width, drop=drop, losses=losses,
|
||||
beam_density=beam_density)
|
||||
for i, d_scores in enumerate(states_d_scores):
|
||||
losses[self.name] += (d_scores**2).sum()
|
||||
losses[self.name] += (d_scores**2).mean()
|
||||
ids, bp_vectors, bp_scores = backprops[i]
|
||||
d_vector = bp_scores(d_scores, sgd=sgd)
|
||||
if isinstance(model.ops, CupyOps) \
|
||||
|
|
Loading…
Reference in New Issue
Block a user