mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 09:26:27 +03:00
Merge pull request #1355 from explosion/feature/noshare
Make pipeline components independent
This commit is contained in:
commit
c2e2f81773
|
@ -24,7 +24,6 @@ install:
|
|||
- "%PYTHON%\\python.exe -m pip install wheel"
|
||||
- "%PYTHON%\\python.exe -m pip install cython"
|
||||
- "%PYTHON%\\python.exe -m pip install -r requirements.txt"
|
||||
- "%PYTHON%\\python.exe setup.py build_ext --inplace"
|
||||
- "%PYTHON%\\python.exe -m pip install -e ."
|
||||
|
||||
build: off
|
||||
|
|
19
spacy/_ml.py
19
spacy/_ml.py
|
@ -240,7 +240,6 @@ def link_vectors_to_models(vocab):
|
|||
# (unideal, I know)
|
||||
thinc.extra.load_nlp.VECTORS[(ops.device, VECTORS_KEY)] = data
|
||||
|
||||
|
||||
def Tok2Vec(width, embed_size, **kwargs):
|
||||
pretrained_dims = kwargs.get('pretrained_dims', 0)
|
||||
cnn_maxout_pieces = kwargs.get('cnn_maxout_pieces', 3)
|
||||
|
@ -271,7 +270,7 @@ def Tok2Vec(width, embed_size, **kwargs):
|
|||
tok2vec = (
|
||||
FeatureExtracter(cols)
|
||||
>> with_flatten(
|
||||
embed >> (convolution * 4), pad=4)
|
||||
embed >> (convolution ** 4), pad=4)
|
||||
)
|
||||
|
||||
# Work around thinc API limitations :(. TODO: Revise in Thinc 7
|
||||
|
@ -513,17 +512,17 @@ def build_tagger_model(nr_class, **cfg):
|
|||
token_vector_width = util.env_opt('token_vector_width', 128)
|
||||
pretrained_dims = cfg.get('pretrained_dims', 0)
|
||||
with Model.define_operators({'>>': chain, '+': add}):
|
||||
# Input: (doc, tensor) tuples
|
||||
private_tok2vec = Tok2Vec(token_vector_width, embed_size,
|
||||
pretrained_dims=pretrained_dims)
|
||||
if 'tok2vec' in cfg:
|
||||
tok2vec = cfg['tok2vec']
|
||||
else:
|
||||
tok2vec = Tok2Vec(token_vector_width, embed_size,
|
||||
pretrained_dims=pretrained_dims)
|
||||
model = (
|
||||
fine_tune(private_tok2vec)
|
||||
>> with_flatten(
|
||||
Maxout(token_vector_width, token_vector_width)
|
||||
>> Softmax(nr_class, token_vector_width)
|
||||
)
|
||||
tok2vec
|
||||
>> with_flatten(Softmax(nr_class, token_vector_width))
|
||||
)
|
||||
model.nI = None
|
||||
model.tok2vec = tok2vec
|
||||
return model
|
||||
|
||||
|
||||
|
|
|
@ -3,12 +3,13 @@
|
|||
# https://github.com/pypa/warehouse/blob/master/warehouse/__about__.py
|
||||
|
||||
__title__ = 'spacy-nightly'
|
||||
__version__ = '2.0.0a14'
|
||||
__version__ = '2.0.0a15'
|
||||
__summary__ = 'Industrial-strength Natural Language Processing (NLP) with Python and Cython'
|
||||
__uri__ = 'https://spacy.io'
|
||||
__author__ = 'Explosion AI'
|
||||
__email__ = 'contact@explosion.ai'
|
||||
__license__ = 'MIT'
|
||||
__release__ = False
|
||||
|
||||
__docs_models__ = 'https://spacy.io/docs/usage/models'
|
||||
__download_url__ = 'https://github.com/explosion/spacy-models/releases/download'
|
||||
|
|
|
@ -11,6 +11,8 @@ import tqdm
|
|||
from thinc.neural._classes.model import Model
|
||||
from thinc.neural.optimizers import linear_decay
|
||||
from timeit import default_timer as timer
|
||||
import random
|
||||
import numpy.random
|
||||
|
||||
from ..tokens.doc import Doc
|
||||
from ..scorer import Scorer
|
||||
|
@ -22,6 +24,9 @@ from .. import about
|
|||
from .. import displacy
|
||||
from ..compat import json_dumps
|
||||
|
||||
random.seed(0)
|
||||
numpy.random.seed(0)
|
||||
|
||||
|
||||
@plac.annotations(
|
||||
lang=("model language", "positional", None, str),
|
||||
|
@ -63,7 +68,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
|
|||
prints("Expected dict but got: {}".format(type(meta)),
|
||||
title="Not a valid meta.json format", exits=1)
|
||||
|
||||
pipeline = ['token_vectors', 'tags', 'dependencies', 'entities']
|
||||
pipeline = ['tags', 'dependencies', 'entities']
|
||||
if no_tagger and 'tags' in pipeline: pipeline.remove('tags')
|
||||
if no_parser and 'dependencies' in pipeline: pipeline.remove('dependencies')
|
||||
if no_entities and 'entities' in pipeline: pipeline.remove('entities')
|
||||
|
@ -99,8 +104,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
|
|||
for batch in minibatch(train_docs, size=batch_sizes):
|
||||
docs, golds = zip(*batch)
|
||||
nlp.update(docs, golds, sgd=optimizer,
|
||||
drop=next(dropout_rates), losses=losses,
|
||||
update_shared=True)
|
||||
drop=next(dropout_rates), losses=losses)
|
||||
pbar.update(sum(len(doc) for doc in docs))
|
||||
|
||||
with nlp.use_params(optimizer.averages):
|
||||
|
@ -109,10 +113,13 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
|
|||
nlp.to_disk(epoch_model_path)
|
||||
nlp_loaded = lang_class(pipeline=pipeline)
|
||||
nlp_loaded = nlp_loaded.from_disk(epoch_model_path)
|
||||
scorer = nlp.evaluate(
|
||||
corpus.dev_docs(
|
||||
nlp,
|
||||
gold_preproc=gold_preproc))
|
||||
scorer = nlp_loaded.evaluate(
|
||||
list(corpus.dev_docs(
|
||||
nlp_loaded,
|
||||
gold_preproc=gold_preproc)))
|
||||
acc_loc =(output_path / ('model%d' % i) / 'accuracy.json')
|
||||
with acc_loc.open('w') as file_:
|
||||
file_.write(json_dumps(scorer.scores))
|
||||
meta_loc = output_path / ('model%d' % i) / 'meta.json'
|
||||
meta['accuracy'] = scorer.scores
|
||||
meta['lang'] = nlp.lang
|
||||
|
|
|
@ -34,6 +34,7 @@ from .lang.tag_map import TAG_MAP
|
|||
from .lang.lex_attrs import LEX_ATTRS
|
||||
from . import util
|
||||
from .scorer import Scorer
|
||||
from ._ml import link_vectors_to_models
|
||||
|
||||
|
||||
class BaseDefaults(object):
|
||||
|
@ -278,8 +279,7 @@ class Language(object):
|
|||
def make_doc(self, text):
|
||||
return self.tokenizer(text)
|
||||
|
||||
def update(self, docs, golds, drop=0., sgd=None, losses=None,
|
||||
update_shared=False):
|
||||
def update(self, docs, golds, drop=0., sgd=None, losses=None):
|
||||
"""Update the models in the pipeline.
|
||||
|
||||
docs (iterable): A batch of `Doc` objects.
|
||||
|
@ -303,31 +303,17 @@ class Language(object):
|
|||
if self._optimizer is None:
|
||||
self._optimizer = Adam(Model.ops, 0.001)
|
||||
sgd = self._optimizer
|
||||
tok2vec = self.pipeline[0]
|
||||
grads = {}
|
||||
def get_grads(W, dW, key=None):
|
||||
grads[key] = (W, dW)
|
||||
pipes = list(self.pipeline[1:])
|
||||
pipes = list(self.pipeline)
|
||||
random.shuffle(pipes)
|
||||
tokvecses, bp_tokvecses = tok2vec.model.begin_update(docs, drop=drop)
|
||||
all_d_tokvecses = [tok2vec.model.ops.allocate(tv.shape) for tv in tokvecses]
|
||||
for proc in pipes:
|
||||
if not hasattr(proc, 'update'):
|
||||
continue
|
||||
d_tokvecses = proc.update((docs, tokvecses), golds,
|
||||
drop=drop, sgd=get_grads, losses=losses)
|
||||
if update_shared and d_tokvecses is not None:
|
||||
for i, d_tv in enumerate(d_tokvecses):
|
||||
all_d_tokvecses[i] += d_tv
|
||||
if update_shared and bp_tokvecses is not None:
|
||||
bp_tokvecses(all_d_tokvecses, sgd=sgd)
|
||||
proc.update(docs, golds, drop=drop, sgd=get_grads, losses=losses)
|
||||
for key, (W, dW) in grads.items():
|
||||
sgd(W, dW, key=key)
|
||||
# Clear the tensor variable, to free GPU memory.
|
||||
# If we don't do this, the memory leak gets pretty
|
||||
# bad, because we may be holding part of a batch.
|
||||
for doc in docs:
|
||||
doc.tensor = None
|
||||
|
||||
def preprocess_gold(self, docs_golds):
|
||||
"""Can be called before training to pre-process gold data. By default,
|
||||
|
@ -370,8 +356,6 @@ class Language(object):
|
|||
**cfg: Config parameters.
|
||||
returns: An optimizer
|
||||
"""
|
||||
if self.parser:
|
||||
self.pipeline.append(NeuralLabeller(self.vocab))
|
||||
# Populate vocab
|
||||
if get_gold_tuples is not None:
|
||||
for _, annots_brackets in get_gold_tuples():
|
||||
|
@ -386,6 +370,7 @@ class Language(object):
|
|||
self.vocab.vectors.data)
|
||||
else:
|
||||
device = None
|
||||
link_vectors_to_models(self.vocab)
|
||||
for proc in self.pipeline:
|
||||
if hasattr(proc, 'begin_training'):
|
||||
context = proc.begin_training(get_gold_tuples(),
|
||||
|
@ -417,7 +402,6 @@ class Language(object):
|
|||
assert len(docs) == len(golds)
|
||||
for doc, gold in zip(docs, golds):
|
||||
scorer.score(doc, gold)
|
||||
doc.tensor = None
|
||||
return scorer
|
||||
|
||||
@contextmanager
|
||||
|
@ -506,7 +490,6 @@ class Language(object):
|
|||
"""
|
||||
path = util.ensure_path(path)
|
||||
serializers = OrderedDict((
|
||||
('vocab', lambda p: self.vocab.to_disk(p)),
|
||||
('tokenizer', lambda p: self.tokenizer.to_disk(p, vocab=False)),
|
||||
('meta.json', lambda p: p.open('w').write(json_dumps(self.meta)))
|
||||
))
|
||||
|
@ -518,6 +501,7 @@ class Language(object):
|
|||
if not hasattr(proc, 'to_disk'):
|
||||
continue
|
||||
serializers[proc.name] = lambda p, proc=proc: proc.to_disk(p, vocab=False)
|
||||
serializers['vocab'] = lambda p: self.vocab.to_disk(p)
|
||||
util.to_disk(path, serializers, {p: False for p in disable})
|
||||
|
||||
def from_disk(self, path, disable=tuple()):
|
||||
|
|
|
@ -146,6 +146,8 @@ cdef class Morphology:
|
|||
self.add_special_case(tag_str, form_str, attrs)
|
||||
|
||||
def lemmatize(self, const univ_pos_t univ_pos, attr_t orth, morphology):
|
||||
if orth not in self.strings:
|
||||
return orth
|
||||
cdef unicode py_string = self.strings[orth]
|
||||
if self.lemmatizer is None:
|
||||
return self.strings.add(py_string.lower())
|
||||
|
|
|
@ -174,7 +174,7 @@ class BaseThincComponent(object):
|
|||
|
||||
deserialize = OrderedDict((
|
||||
('cfg', lambda b: self.cfg.update(ujson.loads(b))),
|
||||
('vocab', lambda b: self.vocab.from_bytes(b))
|
||||
('vocab', lambda b: self.vocab.from_bytes(b)),
|
||||
('model', load_model),
|
||||
))
|
||||
util.from_bytes(bytes_data, deserialize, exclude)
|
||||
|
@ -322,7 +322,7 @@ class TokenVectorEncoder(BaseThincComponent):
|
|||
if self.model is True:
|
||||
self.cfg['pretrained_dims'] = self.vocab.vectors_length
|
||||
self.model = self.Model(**self.cfg)
|
||||
link_vectors_to_models(self.vocab)
|
||||
link_vectors_to_models(self.vocab)
|
||||
|
||||
|
||||
class NeuralTagger(BaseThincComponent):
|
||||
|
@ -335,27 +335,25 @@ class NeuralTagger(BaseThincComponent):
|
|||
self.cfg.setdefault('pretrained_dims', self.vocab.vectors.data.shape[1])
|
||||
|
||||
def __call__(self, doc):
|
||||
tags = self.predict(([doc], [doc.tensor]))
|
||||
tags = self.predict([doc])
|
||||
self.set_annotations([doc], tags)
|
||||
return doc
|
||||
|
||||
def pipe(self, stream, batch_size=128, n_threads=-1):
|
||||
for docs in cytoolz.partition_all(batch_size, stream):
|
||||
docs = list(docs)
|
||||
tokvecs = [d.tensor for d in docs]
|
||||
tag_ids = self.predict((docs, tokvecs))
|
||||
tag_ids = self.predict(docs)
|
||||
self.set_annotations(docs, tag_ids)
|
||||
yield from docs
|
||||
|
||||
def predict(self, docs_tokvecs):
|
||||
scores = self.model(docs_tokvecs)
|
||||
def predict(self, docs):
|
||||
scores = self.model(docs)
|
||||
scores = self.model.ops.flatten(scores)
|
||||
guesses = scores.argmax(axis=1)
|
||||
if not isinstance(guesses, numpy.ndarray):
|
||||
guesses = guesses.get()
|
||||
tokvecs = docs_tokvecs[1]
|
||||
guesses = self.model.ops.unflatten(guesses,
|
||||
[tv.shape[0] for tv in tokvecs])
|
||||
[len(d) for d in docs])
|
||||
return guesses
|
||||
|
||||
def set_annotations(self, docs, batch_tag_ids):
|
||||
|
@ -375,20 +373,16 @@ class NeuralTagger(BaseThincComponent):
|
|||
idx += 1
|
||||
doc.is_tagged = True
|
||||
|
||||
def update(self, docs_tokvecs, golds, drop=0., sgd=None, losses=None):
|
||||
def update(self, docs, golds, drop=0., sgd=None, losses=None):
|
||||
if losses is not None and self.name not in losses:
|
||||
losses[self.name] = 0.
|
||||
docs, tokvecs = docs_tokvecs
|
||||
|
||||
if self.model.nI is None:
|
||||
self.model.nI = tokvecs[0].shape[1]
|
||||
tag_scores, bp_tag_scores = self.model.begin_update(docs_tokvecs, drop=drop)
|
||||
tag_scores, bp_tag_scores = self.model.begin_update(docs, drop=drop)
|
||||
loss, d_tag_scores = self.get_loss(docs, golds, tag_scores)
|
||||
bp_tag_scores(d_tag_scores, sgd=sgd)
|
||||
|
||||
d_tokvecs = bp_tag_scores(d_tag_scores, sgd=sgd)
|
||||
if losses is not None:
|
||||
losses[self.name] += loss
|
||||
return d_tokvecs
|
||||
|
||||
def get_loss(self, docs, golds, scores):
|
||||
scores = self.model.ops.flatten(scores)
|
||||
|
@ -432,7 +426,7 @@ class NeuralTagger(BaseThincComponent):
|
|||
if self.model is True:
|
||||
self.cfg['pretrained_dims'] = self.vocab.vectors.data.shape[1]
|
||||
self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg)
|
||||
link_vectors_to_models(self.vocab)
|
||||
link_vectors_to_models(self.vocab)
|
||||
|
||||
@classmethod
|
||||
def Model(cls, n_tags, **cfg):
|
||||
|
@ -514,9 +508,25 @@ class NeuralTagger(BaseThincComponent):
|
|||
|
||||
class NeuralLabeller(NeuralTagger):
|
||||
name = 'nn_labeller'
|
||||
def __init__(self, vocab, model=True, **cfg):
|
||||
def __init__(self, vocab, model=True, target='dep_tag_offset', **cfg):
|
||||
self.vocab = vocab
|
||||
self.model = model
|
||||
if target == 'dep':
|
||||
self.make_label = self.make_dep
|
||||
elif target == 'tag':
|
||||
self.make_label = self.make_tag
|
||||
elif target == 'ent':
|
||||
self.make_label = self.make_ent
|
||||
elif target == 'dep_tag_offset':
|
||||
self.make_label = self.make_dep_tag_offset
|
||||
elif target == 'ent_tag':
|
||||
self.make_label = self.make_ent_tag
|
||||
elif hasattr(target, '__call__'):
|
||||
self.make_label = target
|
||||
else:
|
||||
raise ValueError(
|
||||
"NeuralLabeller target should be function or one of "
|
||||
"['dep', 'tag', 'ent', 'dep_tag_offset', 'ent_tag']")
|
||||
self.cfg = dict(cfg)
|
||||
self.cfg.setdefault('cnn_maxout_pieces', 2)
|
||||
self.cfg.setdefault('pretrained_dims', self.vocab.vectors.data.shape[1])
|
||||
|
@ -532,43 +542,78 @@ class NeuralLabeller(NeuralTagger):
|
|||
def set_annotations(self, docs, dep_ids):
|
||||
pass
|
||||
|
||||
def begin_training(self, gold_tuples=tuple(), pipeline=None):
|
||||
def begin_training(self, gold_tuples=tuple(), pipeline=None, tok2vec=None):
|
||||
gold_tuples = nonproj.preprocess_training_data(gold_tuples)
|
||||
for raw_text, annots_brackets in gold_tuples:
|
||||
for annots, brackets in annots_brackets:
|
||||
ids, words, tags, heads, deps, ents = annots
|
||||
for dep in deps:
|
||||
if dep not in self.labels:
|
||||
self.labels[dep] = len(self.labels)
|
||||
token_vector_width = pipeline[0].model.nO
|
||||
for i in range(len(ids)):
|
||||
label = self.make_label(i, words, tags, heads, deps, ents)
|
||||
if label is not None and label not in self.labels:
|
||||
self.labels[label] = len(self.labels)
|
||||
print(len(self.labels))
|
||||
if self.model is True:
|
||||
self.cfg['pretrained_dims'] = self.vocab.vectors.data.shape[1]
|
||||
self.model = self.Model(len(self.labels), **self.cfg)
|
||||
link_vectors_to_models(self.vocab)
|
||||
self.model = chain(
|
||||
tok2vec,
|
||||
Softmax(len(self.labels), 128)
|
||||
)
|
||||
link_vectors_to_models(self.vocab)
|
||||
|
||||
@classmethod
|
||||
def Model(cls, n_tags, **cfg):
|
||||
return build_tagger_model(n_tags, **cfg)
|
||||
def Model(cls, n_tags, tok2vec=None, **cfg):
|
||||
return build_tagger_model(n_tags, tok2vec=tok2vec, **cfg)
|
||||
|
||||
def get_loss(self, docs, golds, scores):
|
||||
scores = self.model.ops.flatten(scores)
|
||||
cdef int idx = 0
|
||||
correct = numpy.zeros((scores.shape[0],), dtype='i')
|
||||
guesses = scores.argmax(axis=1)
|
||||
for gold in golds:
|
||||
for tag in gold.labels:
|
||||
if tag is None or tag not in self.labels:
|
||||
for i in range(len(gold.labels)):
|
||||
label = self.make_label(i, gold.words, gold.tags, gold.heads,
|
||||
gold.labels, gold.ents)
|
||||
if label is None or label not in self.labels:
|
||||
correct[idx] = guesses[idx]
|
||||
else:
|
||||
correct[idx] = self.labels[tag]
|
||||
correct[idx] = self.labels[label]
|
||||
idx += 1
|
||||
correct = self.model.ops.xp.array(correct, dtype='i')
|
||||
d_scores = scores - to_categorical(correct, nb_classes=scores.shape[1])
|
||||
d_scores /= d_scores.shape[0]
|
||||
loss = (d_scores**2).sum()
|
||||
d_scores = self.model.ops.unflatten(d_scores, [len(d) for d in docs])
|
||||
return float(loss), d_scores
|
||||
|
||||
@staticmethod
|
||||
def make_dep(i, words, tags, heads, deps, ents):
|
||||
if deps[i] is None or heads[i] is None:
|
||||
return None
|
||||
return deps[i]
|
||||
|
||||
@staticmethod
|
||||
def make_tag(i, words, tags, heads, deps, ents):
|
||||
return tags[i]
|
||||
|
||||
@staticmethod
|
||||
def make_ent(i, words, tags, heads, deps, ents):
|
||||
if ents is None:
|
||||
return None
|
||||
return ents[i]
|
||||
|
||||
@staticmethod
|
||||
def make_dep_tag_offset(i, words, tags, heads, deps, ents):
|
||||
if deps[i] is None or heads[i] is None:
|
||||
return None
|
||||
offset = heads[i] - i
|
||||
offset = min(offset, 2)
|
||||
offset = max(offset, -2)
|
||||
return '%s-%s:%d' % (deps[i], tags[i], offset)
|
||||
|
||||
@staticmethod
|
||||
def make_ent_tag(i, words, tags, heads, deps, ents):
|
||||
if ents is None or ents[i] is None:
|
||||
return None
|
||||
else:
|
||||
return '%s-%s' % (tags[i], ents[i])
|
||||
|
||||
|
||||
class SimilarityHook(BaseThincComponent):
|
||||
"""
|
||||
|
@ -605,15 +650,10 @@ class SimilarityHook(BaseThincComponent):
|
|||
yield self(doc)
|
||||
|
||||
def predict(self, doc1, doc2):
|
||||
return self.model.predict([(doc1.tensor, doc2.tensor)])
|
||||
return self.model.predict([(doc1, doc2)])
|
||||
|
||||
def update(self, doc1_tensor1_doc2_tensor2, golds, sgd=None, drop=0.):
|
||||
doc1s, tensor1s, doc2s, tensor2s = doc1_tensor1_doc2_tensor2
|
||||
sims, bp_sims = self.model.begin_update(zip(tensor1s, tensor2s),
|
||||
drop=drop)
|
||||
d_tensor1s, d_tensor2s = bp_sims(golds, sgd=sgd)
|
||||
|
||||
return d_tensor1s, d_tensor2s
|
||||
def update(self, doc1_doc2, golds, sgd=None, drop=0.):
|
||||
sims, bp_sims = self.model.begin_update(doc1_doc2, drop=drop)
|
||||
|
||||
def begin_training(self, _=tuple(), pipeline=None):
|
||||
"""
|
||||
|
@ -669,15 +709,13 @@ class TextCategorizer(BaseThincComponent):
|
|||
for j, label in enumerate(self.labels):
|
||||
doc.cats[label] = float(scores[i, j])
|
||||
|
||||
def update(self, docs_tensors, golds, state=None, drop=0., sgd=None, losses=None):
|
||||
docs, tensors = docs_tensors
|
||||
def update(self, docs, golds, state=None, drop=0., sgd=None, losses=None):
|
||||
scores, bp_scores = self.model.begin_update(docs, drop=drop)
|
||||
loss, d_scores = self.get_loss(docs, golds, scores)
|
||||
d_tensors = bp_scores(d_scores, sgd=sgd)
|
||||
bp_scores(d_scores, sgd=sgd)
|
||||
if losses is not None:
|
||||
losses.setdefault(self.name, 0.0)
|
||||
losses[self.name] += loss
|
||||
return d_tensors
|
||||
|
||||
def get_loss(self, docs, golds, scores):
|
||||
truths = numpy.zeros((len(golds), len(self.labels)), dtype='f')
|
||||
|
@ -739,6 +777,14 @@ cdef class NeuralDependencyParser(NeuralParser):
|
|||
name = 'parser'
|
||||
TransitionSystem = ArcEager
|
||||
|
||||
def init_multitask_objectives(self, gold_tuples, pipeline, **cfg):
|
||||
for target in ['dep', 'ent']:
|
||||
labeller = NeuralLabeller(self.vocab, target=target)
|
||||
tok2vec = self.model[0]
|
||||
labeller.begin_training(gold_tuples, pipeline=pipeline, tok2vec=tok2vec)
|
||||
pipeline.append(labeller)
|
||||
self._multitasks.append(labeller)
|
||||
|
||||
def __reduce__(self):
|
||||
return (NeuralDependencyParser, (self.vocab, self.moves, self.model), None, None)
|
||||
|
||||
|
@ -749,13 +795,13 @@ cdef class NeuralEntityRecognizer(NeuralParser):
|
|||
|
||||
nr_feature = 6
|
||||
|
||||
def predict_confidences(self, docs):
|
||||
tensors = [d.tensor for d in docs]
|
||||
samples = []
|
||||
for i in range(10):
|
||||
states = self.parse_batch(docs, tensors, drop=0.3)
|
||||
for state in states:
|
||||
samples.append(self._get_entities(state))
|
||||
def init_multitask_objectives(self, gold_tuples, pipeline, **cfg):
|
||||
for target in []:
|
||||
labeller = NeuralLabeller(self.vocab, target=target)
|
||||
tok2vec = self.model[0]
|
||||
labeller.begin_training(gold_tuples, pipeline=pipeline, tok2vec=tok2vec)
|
||||
pipeline.append(labeller)
|
||||
self._multitasks.append(labeller)
|
||||
|
||||
def __reduce__(self):
|
||||
return (NeuralEntityRecognizer, (self.vocab, self.moves, self.model), None, None)
|
||||
|
|
|
@ -147,10 +147,10 @@ def get_token_ids(states, int n_tokens):
|
|||
|
||||
nr_update = 0
|
||||
def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
|
||||
states, tokvecs, golds,
|
||||
states, golds,
|
||||
state2vec, vec2scores,
|
||||
int width, float density,
|
||||
sgd=None, losses=None, drop=0.):
|
||||
losses=None, drop=0.):
|
||||
global nr_update
|
||||
cdef MaxViolation violn
|
||||
nr_update += 1
|
||||
|
|
|
@ -13,6 +13,7 @@ cdef class Parser:
|
|||
cdef public object model
|
||||
cdef readonly TransitionSystem moves
|
||||
cdef readonly object cfg
|
||||
cdef public object _multitasks
|
||||
|
||||
cdef void _parse_step(self, StateC* state,
|
||||
const float* feat_weights,
|
||||
|
|
|
@ -7,6 +7,7 @@ from __future__ import unicode_literals, print_function
|
|||
|
||||
from collections import Counter, OrderedDict
|
||||
import ujson
|
||||
import json
|
||||
import contextlib
|
||||
|
||||
from libc.math cimport exp
|
||||
|
@ -48,7 +49,7 @@ from .. import util
|
|||
from ..util import get_async, get_cuda_stream
|
||||
from .._ml import zero_init, PrecomputableAffine, PrecomputableMaxouts
|
||||
from .._ml import Tok2Vec, doc2feats, rebatch, fine_tune
|
||||
from .._ml import Residual, drop_layer
|
||||
from .._ml import Residual, drop_layer, flatten
|
||||
from .._ml import link_vectors_to_models
|
||||
from ..compat import json_dumps
|
||||
|
||||
|
@ -245,8 +246,9 @@ cdef class Parser:
|
|||
hidden_width = util.env_opt('hidden_width', hidden_width)
|
||||
parser_maxout_pieces = util.env_opt('parser_maxout_pieces', 2)
|
||||
embed_size = util.env_opt('embed_size', 4000)
|
||||
tensors = fine_tune(Tok2Vec(token_vector_width, embed_size,
|
||||
pretrained_dims=cfg.get('pretrained_dims')))
|
||||
tok2vec = Tok2Vec(token_vector_width, embed_size,
|
||||
pretrained_dims=cfg.get('pretrained_dims', 0))
|
||||
tok2vec = chain(tok2vec, flatten)
|
||||
if parser_maxout_pieces == 1:
|
||||
lower = PrecomputableAffine(hidden_width if depth >= 1 else nr_class,
|
||||
nF=cls.nr_feature,
|
||||
|
@ -278,7 +280,7 @@ cdef class Parser:
|
|||
'hidden_width': hidden_width,
|
||||
'maxout_pieces': parser_maxout_pieces
|
||||
}
|
||||
return (tensors, lower, upper), cfg
|
||||
return (tok2vec, lower, upper), cfg
|
||||
|
||||
def __init__(self, Vocab vocab, moves=True, model=True, **cfg):
|
||||
"""
|
||||
|
@ -317,6 +319,7 @@ cdef class Parser:
|
|||
for label in labels:
|
||||
self.moves.add_action(action, label)
|
||||
self.model = model
|
||||
self._multitasks = []
|
||||
|
||||
def __reduce__(self):
|
||||
return (Parser, (self.vocab, self.moves, self.model), None, None)
|
||||
|
@ -336,11 +339,11 @@ cdef class Parser:
|
|||
beam_density = self.cfg.get('beam_density', 0.0)
|
||||
cdef Beam beam
|
||||
if beam_width == 1:
|
||||
states = self.parse_batch([doc], [doc.tensor])
|
||||
states = self.parse_batch([doc])
|
||||
self.set_annotations([doc], states)
|
||||
return doc
|
||||
else:
|
||||
beam = self.beam_parse([doc], [doc.tensor],
|
||||
beam = self.beam_parse([doc],
|
||||
beam_width=beam_width, beam_density=beam_density)[0]
|
||||
output = self.moves.get_beam_annot(beam)
|
||||
state = <StateClass>beam.at(0)
|
||||
|
@ -369,11 +372,11 @@ cdef class Parser:
|
|||
cdef Beam beam
|
||||
for docs in cytoolz.partition_all(batch_size, docs):
|
||||
docs = list(docs)
|
||||
tokvecs = [doc.tensor for doc in docs]
|
||||
if beam_width == 1:
|
||||
parse_states = self.parse_batch(docs, tokvecs)
|
||||
parse_states = self.parse_batch(docs)
|
||||
beams = []
|
||||
else:
|
||||
beams = self.beam_parse(docs, tokvecs,
|
||||
beams = self.beam_parse(docs,
|
||||
beam_width=beam_width, beam_density=beam_density)
|
||||
parse_states = []
|
||||
for beam in beams:
|
||||
|
@ -381,7 +384,7 @@ cdef class Parser:
|
|||
self.set_annotations(docs, parse_states)
|
||||
yield from docs
|
||||
|
||||
def parse_batch(self, docs, tokvecses):
|
||||
def parse_batch(self, docs):
|
||||
cdef:
|
||||
precompute_hiddens state2vec
|
||||
StateClass state
|
||||
|
@ -392,21 +395,15 @@ cdef class Parser:
|
|||
int nr_class, nr_feat, nr_piece, nr_dim, nr_state
|
||||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
if isinstance(tokvecses, np.ndarray):
|
||||
tokvecses = [tokvecses]
|
||||
|
||||
if USE_FINE_TUNE:
|
||||
tokvecs = self.model[0].ops.flatten(self.model[0]((docs, tokvecses)))
|
||||
else:
|
||||
tokvecs = self.model[0].ops.flatten(tokvecses)
|
||||
cuda_stream = get_cuda_stream()
|
||||
(tokvecs, bp_tokvecs), state2vec, vec2scores = self.get_batch_model(docs, cuda_stream,
|
||||
0.0)
|
||||
|
||||
nr_state = len(docs)
|
||||
nr_class = self.moves.n_moves
|
||||
nr_dim = tokvecs.shape[1]
|
||||
nr_feat = self.nr_feature
|
||||
|
||||
cuda_stream = get_cuda_stream()
|
||||
state2vec, vec2scores = self.get_batch_model(nr_state, tokvecs,
|
||||
cuda_stream, 0.0)
|
||||
nr_piece = state2vec.nP
|
||||
|
||||
states = self.moves.init_batch(docs)
|
||||
|
@ -422,21 +419,23 @@ cdef class Parser:
|
|||
c_token_ids = <int*>token_ids.data
|
||||
c_is_valid = <int*>is_valid.data
|
||||
cdef int has_hidden = not getattr(vec2scores, 'is_noop', False)
|
||||
cdef int nr_step
|
||||
while not next_step.empty():
|
||||
nr_step = next_step.size()
|
||||
if not has_hidden:
|
||||
for i in range(
|
||||
next_step.size(), num_threads=6, nogil=True):
|
||||
for i in cython.parallel.prange(nr_step, num_threads=6,
|
||||
nogil=True):
|
||||
self._parse_step(next_step[i],
|
||||
feat_weights, nr_class, nr_feat, nr_piece)
|
||||
else:
|
||||
for i in range(next_step.size()):
|
||||
for i in range(nr_step):
|
||||
st = next_step[i]
|
||||
st.set_context_tokens(&c_token_ids[i*nr_feat], nr_feat)
|
||||
self.moves.set_valid(&c_is_valid[i*nr_class], st)
|
||||
vectors = state2vec(token_ids[:next_step.size()])
|
||||
scores = vec2scores(vectors)
|
||||
c_scores = <float*>scores.data
|
||||
for i in range(next_step.size()):
|
||||
for i in range(nr_step):
|
||||
st = next_step[i]
|
||||
guess = arg_max_if_valid(
|
||||
&c_scores[i*nr_class], &c_is_valid[i*nr_class], nr_class)
|
||||
|
@ -449,19 +448,15 @@ cdef class Parser:
|
|||
next_step.push_back(st)
|
||||
return states
|
||||
|
||||
def beam_parse(self, docs, tokvecses, int beam_width=3, float beam_density=0.001):
|
||||
def beam_parse(self, docs, int beam_width=3, float beam_density=0.001):
|
||||
cdef Beam beam
|
||||
cdef np.ndarray scores
|
||||
cdef Doc doc
|
||||
cdef int nr_class = self.moves.n_moves
|
||||
cdef StateClass stcls, output
|
||||
if USE_FINE_TUNE:
|
||||
tokvecs = self.model[0].ops.flatten(self.model[0]((docs, tokvecses)))
|
||||
else:
|
||||
tokvecs = self.model[0].ops.flatten(tokvecses)
|
||||
cuda_stream = get_cuda_stream()
|
||||
state2vec, vec2scores = self.get_batch_model(len(docs), tokvecs,
|
||||
cuda_stream, 0.0)
|
||||
(tokvecs, bp_tokvecs), state2vec, vec2scores = self.get_batch_model(docs, cuda_stream,
|
||||
0.0)
|
||||
beams = []
|
||||
cdef int offset = 0
|
||||
cdef int j = 0
|
||||
|
@ -521,30 +516,24 @@ cdef class Parser:
|
|||
free(scores)
|
||||
free(token_ids)
|
||||
|
||||
def update(self, docs_tokvecs, golds, drop=0., sgd=None, losses=None):
|
||||
def update(self, docs, golds, drop=0., sgd=None, losses=None):
|
||||
if not any(self.moves.has_gold(gold) for gold in golds):
|
||||
return None
|
||||
if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() >= 0.5:
|
||||
return self.update_beam(docs_tokvecs, golds,
|
||||
return self.update_beam(docs, golds,
|
||||
self.cfg['beam_width'], self.cfg['beam_density'],
|
||||
drop=drop, sgd=sgd, losses=losses)
|
||||
if losses is not None and self.name not in losses:
|
||||
losses[self.name] = 0.
|
||||
docs, tokvec_lists = docs_tokvecs
|
||||
if isinstance(docs, Doc) and isinstance(golds, GoldParse):
|
||||
docs = [docs]
|
||||
golds = [golds]
|
||||
if USE_FINE_TUNE:
|
||||
my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop)
|
||||
tokvecs = self.model[0].ops.flatten(my_tokvecs)
|
||||
else:
|
||||
tokvecs = self.model[0].ops.flatten(docs_tokvecs[1])
|
||||
|
||||
cuda_stream = get_cuda_stream()
|
||||
|
||||
states, golds, max_steps = self._init_gold_batch(docs, golds)
|
||||
state2vec, vec2scores = self.get_batch_model(len(states), tokvecs, cuda_stream,
|
||||
0.0)
|
||||
(tokvecs, bp_tokvecs), state2vec, vec2scores = self.get_batch_model(docs, cuda_stream,
|
||||
0.0)
|
||||
todo = [(s, g) for (s, g) in zip(states, golds)
|
||||
if not s.is_final() and g is not None]
|
||||
if not todo:
|
||||
|
@ -588,13 +577,9 @@ cdef class Parser:
|
|||
if n_steps >= max_steps:
|
||||
break
|
||||
self._make_updates(d_tokvecs,
|
||||
backprops, sgd, cuda_stream)
|
||||
d_tokvecs = self.model[0].ops.unflatten(d_tokvecs, [len(d) for d in docs])
|
||||
if USE_FINE_TUNE:
|
||||
d_tokvecs = bp_my_tokvecs(d_tokvecs, sgd=sgd)
|
||||
return d_tokvecs
|
||||
bp_tokvecs, backprops, sgd, cuda_stream)
|
||||
|
||||
def update_beam(self, docs_tokvecs, golds, width=None, density=None,
|
||||
def update_beam(self, docs, golds, width=None, density=None,
|
||||
drop=0., sgd=None, losses=None):
|
||||
if not any(self.moves.has_gold(gold) for gold in golds):
|
||||
return None
|
||||
|
@ -606,26 +591,20 @@ cdef class Parser:
|
|||
density = self.cfg.get('beam_density', 0.0)
|
||||
if losses is not None and self.name not in losses:
|
||||
losses[self.name] = 0.
|
||||
docs, tokvecs = docs_tokvecs
|
||||
lengths = [len(d) for d in docs]
|
||||
assert min(lengths) >= 1
|
||||
if USE_FINE_TUNE:
|
||||
my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop)
|
||||
tokvecs = self.model[0].ops.flatten(my_tokvecs)
|
||||
else:
|
||||
tokvecs = self.model[0].ops.flatten(tokvecs)
|
||||
states = self.moves.init_batch(docs)
|
||||
for gold in golds:
|
||||
self.moves.preprocess_gold(gold)
|
||||
|
||||
cuda_stream = get_cuda_stream()
|
||||
state2vec, vec2scores = self.get_batch_model(len(states), tokvecs, cuda_stream, 0.0)
|
||||
(tokvecs, bp_tokvecs), state2vec, vec2scores = self.get_batch_model(docs, cuda_stream, 0.0)
|
||||
|
||||
states_d_scores, backprops = _beam_utils.update_beam(self.moves, self.nr_feature, 500,
|
||||
states, tokvecs, golds,
|
||||
states, golds,
|
||||
state2vec, vec2scores,
|
||||
width, density,
|
||||
sgd=sgd, drop=drop, losses=losses)
|
||||
drop=drop, losses=losses)
|
||||
backprop_lower = []
|
||||
cdef float batch_size = len(docs)
|
||||
for i, d_scores in enumerate(states_d_scores):
|
||||
|
@ -643,20 +622,7 @@ cdef class Parser:
|
|||
else:
|
||||
backprop_lower.append((ids, d_vector, bp_vectors))
|
||||
d_tokvecs = self.model[0].ops.allocate(tokvecs.shape)
|
||||
self._make_updates(d_tokvecs, backprop_lower, sgd, cuda_stream)
|
||||
d_tokvecs = self.model[0].ops.unflatten(d_tokvecs, lengths)
|
||||
if USE_FINE_TUNE:
|
||||
d_tokvecs = bp_my_tokvecs(d_tokvecs, sgd=sgd)
|
||||
return d_tokvecs
|
||||
|
||||
def _pad_tokvecs(self, tokvecs):
|
||||
# Add a vector for missing values at the start of tokvecs
|
||||
xp = get_array_module(tokvecs)
|
||||
pad = xp.zeros((1, tokvecs.shape[1]), dtype=tokvecs.dtype)
|
||||
return xp.vstack((pad, tokvecs))
|
||||
|
||||
def _unpad_tokvecs(self, d_tokvecs):
|
||||
return d_tokvecs[1:]
|
||||
self._make_updates(d_tokvecs, bp_tokvecs, backprop_lower, sgd, cuda_stream)
|
||||
|
||||
def _init_gold_batch(self, whole_docs, whole_golds):
|
||||
"""Make a square batch, of length equal to the shortest doc. A long
|
||||
|
@ -694,7 +660,7 @@ cdef class Parser:
|
|||
max_moves = max(max_moves, len(oracle_actions))
|
||||
return states, golds, max_moves
|
||||
|
||||
def _make_updates(self, d_tokvecs, backprops, sgd, cuda_stream=None):
|
||||
def _make_updates(self, d_tokvecs, bp_tokvecs, backprops, sgd, cuda_stream=None):
|
||||
# Tells CUDA to block, so our async copies complete.
|
||||
if cuda_stream is not None:
|
||||
cuda_stream.synchronize()
|
||||
|
@ -705,6 +671,7 @@ cdef class Parser:
|
|||
d_state_features *= mask.reshape(ids.shape + (1,))
|
||||
self.model[0].ops.scatter_add(d_tokvecs, ids * mask,
|
||||
d_state_features)
|
||||
bp_tokvecs(d_tokvecs, sgd=sgd)
|
||||
|
||||
@property
|
||||
def move_names(self):
|
||||
|
@ -714,11 +681,12 @@ cdef class Parser:
|
|||
names.append(name)
|
||||
return names
|
||||
|
||||
def get_batch_model(self, batch_size, tokvecs, stream, dropout):
|
||||
_, lower, upper = self.model
|
||||
state2vec = precompute_hiddens(batch_size, tokvecs,
|
||||
def get_batch_model(self, docs, stream, dropout):
|
||||
tok2vec, lower, upper = self.model
|
||||
tokvecs, bp_tokvecs = tok2vec.begin_update(docs, drop=dropout)
|
||||
state2vec = precompute_hiddens(len(docs), tokvecs,
|
||||
lower, stream, drop=dropout)
|
||||
return state2vec, upper
|
||||
return (tokvecs, bp_tokvecs), state2vec, upper
|
||||
|
||||
nr_feature = 8
|
||||
|
||||
|
@ -781,7 +749,7 @@ cdef class Parser:
|
|||
# order, or the model goes out of synch
|
||||
self.cfg.setdefault('extra_labels', []).append(label)
|
||||
|
||||
def begin_training(self, gold_tuples, **cfg):
|
||||
def begin_training(self, gold_tuples, pipeline=None, **cfg):
|
||||
if 'model' in cfg:
|
||||
self.model = cfg['model']
|
||||
gold_tuples = nonproj.preprocess_training_data(gold_tuples)
|
||||
|
@ -792,9 +760,20 @@ cdef class Parser:
|
|||
if self.model is True:
|
||||
cfg['pretrained_dims'] = self.vocab.vectors_length
|
||||
self.model, cfg = self.Model(self.moves.n_moves, **cfg)
|
||||
self.init_multitask_objectives(gold_tuples, pipeline, **cfg)
|
||||
link_vectors_to_models(self.vocab)
|
||||
self.cfg.update(cfg)
|
||||
|
||||
def init_multitask_objectives(self, gold_tuples, pipeline, **cfg):
|
||||
'''Setup models for secondary objectives, to benefit from multi-task
|
||||
learning. This method is intended to be overridden by subclasses.
|
||||
|
||||
For instance, the dependency parser can benefit from sharing
|
||||
an input representation with a label prediction model. These auxiliary
|
||||
models are discarded after training.
|
||||
'''
|
||||
pass
|
||||
|
||||
def preprocess_gold(self, docs_golds):
|
||||
for doc, gold in docs_golds:
|
||||
yield doc, gold
|
||||
|
@ -853,7 +832,7 @@ cdef class Parser:
|
|||
('upper_model', lambda: self.model[2].to_bytes()),
|
||||
('vocab', lambda: self.vocab.to_bytes()),
|
||||
('moves', lambda: self.moves.to_bytes(strings=False)),
|
||||
('cfg', lambda: ujson.dumps(self.cfg))
|
||||
('cfg', lambda: json.dumps(self.cfg, indent=2, sort_keys=True))
|
||||
))
|
||||
if 'model' in exclude:
|
||||
exclude['tok2vec_model'] = True
|
||||
|
@ -866,7 +845,7 @@ cdef class Parser:
|
|||
deserializers = OrderedDict((
|
||||
('vocab', lambda b: self.vocab.from_bytes(b)),
|
||||
('moves', lambda b: self.moves.from_bytes(b, strings=False)),
|
||||
('cfg', lambda b: self.cfg.update(ujson.loads(b))),
|
||||
('cfg', lambda b: self.cfg.update(json.loads(b))),
|
||||
('tok2vec_model', lambda b: None),
|
||||
('lower_model', lambda b: None),
|
||||
('upper_model', lambda b: None)
|
||||
|
|
|
@ -61,33 +61,22 @@ def test_predict_doc(parser, tok2vec, model, doc):
|
|||
parser(doc)
|
||||
|
||||
|
||||
def test_update_doc(parser, tok2vec, model, doc, gold):
|
||||
def test_update_doc(parser, model, doc, gold):
|
||||
parser.model = model
|
||||
tokvecs, bp_tokvecs = tok2vec.begin_update([doc])
|
||||
d_tokvecs = parser.update(([doc], tokvecs), [gold])
|
||||
assert d_tokvecs[0].shape == tokvecs[0].shape
|
||||
def optimize(weights, gradient, key=None):
|
||||
weights -= 0.001 * gradient
|
||||
bp_tokvecs(d_tokvecs, sgd=optimize)
|
||||
assert d_tokvecs[0].sum() == 0.
|
||||
parser.update([doc], [gold], sgd=optimize)
|
||||
|
||||
|
||||
def test_predict_doc_beam(parser, tok2vec, model, doc):
|
||||
doc.tensor = tok2vec([doc])[0]
|
||||
def test_predict_doc_beam(parser, model, doc):
|
||||
parser.model = model
|
||||
parser(doc, beam_width=32, beam_density=0.001)
|
||||
for word in doc:
|
||||
print(word.text, word.head, word.dep_)
|
||||
|
||||
|
||||
def test_update_doc_beam(parser, tok2vec, model, doc, gold):
|
||||
def test_update_doc_beam(parser, model, doc, gold):
|
||||
parser.model = model
|
||||
tokvecs, bp_tokvecs = tok2vec.begin_update([doc])
|
||||
d_tokvecs = parser.update_beam(([doc], tokvecs), [gold])
|
||||
assert d_tokvecs[0].shape == tokvecs[0].shape
|
||||
def optimize(weights, gradient, key=None):
|
||||
weights -= 0.001 * gradient
|
||||
bp_tokvecs(d_tokvecs, sgd=optimize)
|
||||
assert d_tokvecs[0].sum() == 0.
|
||||
parser.update_beam([doc], [gold], sgd=optimize)
|
||||
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ import pytest
|
|||
def taggers(en_vocab):
|
||||
tagger1 = Tagger(en_vocab)
|
||||
tagger2 = Tagger(en_vocab)
|
||||
tagger1.model = tagger1.Model(8, 8)
|
||||
tagger1.model = tagger1.Model(8)
|
||||
tagger2.model = tagger1.model
|
||||
return (tagger1, tagger2)
|
||||
|
||||
|
|
|
@ -54,7 +54,7 @@ cdef class Doc:
|
|||
|
||||
cdef public object noun_chunks_iterator
|
||||
|
||||
cdef int push_back(self, LexemeOrToken lex_or_tok, bint trailing_space) except -1
|
||||
cdef int push_back(self, LexemeOrToken lex_or_tok, bint has_space) except -1
|
||||
|
||||
cpdef np.ndarray to_array(self, object features)
|
||||
|
||||
|
|
|
@ -324,7 +324,6 @@ cdef class Vocab:
|
|||
self.lexemes_from_bytes(file_.read())
|
||||
if self.vectors is not None:
|
||||
self.vectors.from_disk(path, exclude='strings.json')
|
||||
link_vectors_to_models(self)
|
||||
return self
|
||||
|
||||
def to_bytes(self, **exclude):
|
||||
|
@ -364,7 +363,6 @@ cdef class Vocab:
|
|||
('vectors', lambda b: serialize_vectors(b))
|
||||
))
|
||||
util.from_bytes(bytes_data, setters, exclude)
|
||||
link_vectors_to_models(self)
|
||||
return self
|
||||
|
||||
def lexemes_to_bytes(self):
|
||||
|
|
Loading…
Reference in New Issue
Block a user