diff --git a/setup.py b/setup.py index 02d4fe0d9..0a3384ed5 100755 --- a/setup.py +++ b/setup.py @@ -36,7 +36,6 @@ MOD_NAMES = [ 'spacy.syntax.transition_system', 'spacy.syntax.arc_eager', 'spacy.syntax._parse_features', - 'spacy.syntax._beam_utils', 'spacy.gold', 'spacy.tokens.doc', 'spacy.tokens.span', diff --git a/spacy/_ml.py b/spacy/_ml.py index 91b530fad..f1ded666e 100644 --- a/spacy/_ml.py +++ b/spacy/_ml.py @@ -5,12 +5,10 @@ from thinc.neural._classes.hash_embed import HashEmbed from thinc.neural.ops import NumpyOps, CupyOps from thinc.neural.util import get_array_module import random -import cytoolz from thinc.neural._classes.convolution import ExtractWindow from thinc.neural._classes.static_vectors import StaticVectors from thinc.neural._classes.batchnorm import BatchNorm -from thinc.neural._classes.layernorm import LayerNorm as LN from thinc.neural._classes.resnet import Residual from thinc.neural import ReLu from thinc.neural._classes.selu import SELU @@ -21,12 +19,10 @@ from thinc.api import FeatureExtracter, with_getitem from thinc.neural.pooling import Pooling, max_pool, mean_pool, sum_pool from thinc.neural._classes.attention import ParametricAttention from thinc.linear.linear import LinearModel -from thinc.api import uniqued, wrap, flatten_add_lengths - +from thinc.api import uniqued, wrap from .attrs import ID, ORTH, LOWER, NORM, PREFIX, SUFFIX, SHAPE, TAG, DEP from .tokens.doc import Doc -from . import util import numpy import io @@ -57,27 +53,6 @@ def _logistic(X, drop=0.): return Y, logistic_bwd -@layerize -def add_tuples(X, drop=0.): - """Give inputs of sequence pairs, where each sequence is (vals, length), - sum the values, returning a single sequence. - - If input is: - ((vals1, length), (vals2, length) - Output is: - (vals1+vals2, length) - - vals are a single tensor for the whole batch. - """ - (vals1, length1), (vals2, length2) = X - assert length1 == length2 - - def add_tuples_bwd(dY, sgd=None): - return (dY, dY) - - return (vals1+vals2, length), add_tuples_bwd - - def _zero_init(model): def _zero_init_impl(self, X, y): self.W.fill(0) @@ -86,7 +61,6 @@ def _zero_init(model): model.W.fill(0.) return model - @layerize def _preprocess_doc(docs, drop=0.): keys = [doc.to_array([LOWER]) for doc in docs] @@ -98,6 +72,7 @@ def _preprocess_doc(docs, drop=0.): return (keys, vals, lengths), None + def _init_for_precomputed(W, ops): if (W**2).sum() != 0.: return @@ -105,7 +80,6 @@ def _init_for_precomputed(W, ops): ops.xavier_uniform_init(reshaped) W[:] = reshaped.reshape(W.shape) - @describe.on_data(_set_dimensions_if_needed) @describe.attributes( nI=Dimension("Input size"), @@ -210,36 +184,25 @@ class PrecomputableMaxouts(Model): return Yfp, backward -def drop_layer(layer, factor=2.): - def drop_layer_fwd(X, drop=0.): - drop *= factor - mask = layer.ops.get_dropout_mask((1,), drop) - if mask is None or mask > 0: - return layer.begin_update(X, drop=drop) - else: - return X, lambda dX, sgd=None: dX - return wrap(drop_layer_fwd, layer) - - def Tok2Vec(width, embed_size, preprocess=None): - cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH] + cols = [ID, NORM, PREFIX, SUFFIX, SHAPE] with Model.define_operators({'>>': chain, '|': concatenate, '**': clone, '+': add}): - norm = get_col(cols.index(NORM)) >> HashEmbed(width, embed_size, name='embed_lower') + norm = get_col(cols.index(NORM)) >> HashEmbed(width, embed_size, name='embed_lower') prefix = get_col(cols.index(PREFIX)) >> HashEmbed(width, embed_size//2, name='embed_prefix') suffix = get_col(cols.index(SUFFIX)) >> HashEmbed(width, embed_size//2, name='embed_suffix') shape = get_col(cols.index(SHAPE)) >> HashEmbed(width, embed_size//2, name='embed_shape') - embed = (norm | prefix | suffix | shape ) >> Maxout(width, width*4, pieces=3) + embed = (norm | prefix | suffix | shape ) tok2vec = ( with_flatten( asarray(Model.ops, dtype='uint64') - >> uniqued(embed, column=5) - >> drop_layer( - Residual( - (ExtractWindow(nW=1) >> ReLu(width, width*3)) - ) - ) ** 4, pad=4 - ) + >> embed + >> Maxout(width, width*4, pieces=3) + >> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3)) + >> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3)) + >> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3)) + >> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3)), + pad=4) ) if preprocess not in (False, None): tok2vec = preprocess >> tok2vec @@ -334,8 +297,7 @@ def zero_init(model): def doc2feats(cols=None): - if cols is None: - cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH] + cols = [ID, NORM, PREFIX, SUFFIX, SHAPE] def forward(docs, drop=0.): feats = [] for doc in docs: @@ -361,37 +323,6 @@ def get_token_vectors(tokens_attrs_vectors, drop=0.): return vectors, backward -def fine_tune(embedding, combine=None): - if combine is not None: - raise NotImplementedError( - "fine_tune currently only supports addition. Set combine=None") - def fine_tune_fwd(docs_tokvecs, drop=0.): - docs, tokvecs = docs_tokvecs - lengths = model.ops.asarray([len(doc) for doc in docs], dtype='i') - - vecs, bp_vecs = embedding.begin_update(docs, drop=drop) - flat_tokvecs = embedding.ops.flatten(tokvecs) - flat_vecs = embedding.ops.flatten(vecs) - output = embedding.ops.unflatten( - (model.mix[0] * flat_vecs + model.mix[1] * flat_tokvecs), - lengths) - - def fine_tune_bwd(d_output, sgd=None): - bp_vecs(d_output, sgd=sgd) - flat_grad = model.ops.flatten(d_output) - model.d_mix[1] += flat_tokvecs.dot(flat_grad.T).sum() - model.d_mix[0] += flat_vecs.dot(flat_grad.T).sum() - if sgd is not None: - sgd(model._mem.weights, model._mem.gradient, key=model.id) - return d_output - return output, fine_tune_bwd - model = wrap(fine_tune_fwd, embedding) - model.mix = model._mem.add((model.id, 'mix'), (2,)) - model.mix.fill(1.) - model.d_mix = model._mem.add_gradient((model.id, 'd_mix'), (model.id, 'mix')) - return model - - @layerize def flatten(seqs, drop=0.): if isinstance(seqs[0], numpy.ndarray): @@ -438,27 +369,6 @@ def preprocess_doc(docs, drop=0.): vals = ops.allocate(keys.shape[0]) + 1 return (keys, vals, lengths), None -def getitem(i): - def getitem_fwd(X, drop=0.): - return X[i], None - return layerize(getitem_fwd) - -def build_tagger_model(nr_class, token_vector_width, **cfg): - embed_size = util.env_opt('embed_size', 7500) - with Model.define_operators({'>>': chain, '+': add}): - # Input: (doc, tensor) tuples - private_tok2vec = Tok2Vec(token_vector_width, embed_size, preprocess=doc2feats()) - - model = ( - fine_tune(private_tok2vec) - >> with_flatten( - Maxout(token_vector_width, token_vector_width) - >> Softmax(nr_class, token_vector_width) - ) - ) - model.nI = None - return model - def build_text_classifier(nr_class, width=64, **cfg): nr_vector = cfg.get('nr_vector', 200) @@ -473,7 +383,7 @@ def build_text_classifier(nr_class, width=64, **cfg): >> _flatten_add_lengths >> with_getitem(0, uniqued( - (embed_lower | embed_prefix | embed_suffix | embed_shape) + (embed_lower | embed_prefix | embed_suffix | embed_shape) >> Maxout(width, width+(width//2)*3)) >> Residual(ExtractWindow(nW=1) >> ReLu(width, width*3)) >> Residual(ExtractWindow(nW=1) >> ReLu(width, width*3)) @@ -494,7 +404,7 @@ def build_text_classifier(nr_class, width=64, **cfg): >> zero_init(Affine(nr_class, nr_class*2, drop_factor=0.0)) >> logistic ) - + model.lsuv = False return model diff --git a/spacy/cli/convert.py b/spacy/cli/convert.py index fef6753e6..a0a76e5ec 100644 --- a/spacy/cli/convert.py +++ b/spacy/cli/convert.py @@ -21,10 +21,10 @@ CONVERTERS = { @plac.annotations( input_file=("input file", "positional", None, str), output_dir=("output directory for converted file", "positional", None, str), - n_sents=("Number of sentences per doc", "option", "n", int), + n_sents=("Number of sentences per doc", "option", "n", float), morphology=("Enable appending morphology to tags", "flag", "m", bool) ) -def convert(cmd, input_file, output_dir, n_sents=1, morphology=False): +def convert(cmd, input_file, output_dir, n_sents, morphology): """ Convert files into JSON format for use with train command and other experiment management functions. diff --git a/spacy/cli/train.py b/spacy/cli/train.py index 04aac8319..af028dae5 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -91,14 +91,15 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0, for batch in minibatch(train_docs, size=batch_sizes): docs, golds = zip(*batch) nlp.update(docs, golds, sgd=optimizer, - drop=next(dropout_rates), losses=losses, - update_tensors=True) + drop=next(dropout_rates), losses=losses) pbar.update(sum(len(doc) for doc in docs)) with nlp.use_params(optimizer.averages): util.set_env_log(False) epoch_model_path = output_path / ('model%d' % i) nlp.to_disk(epoch_model_path) + with (output_path / ('model%d.pickle' % i)).open('wb') as file_: + dill.dump(nlp, file_, -1) nlp_loaded = lang_class(pipeline=pipeline) nlp_loaded = nlp_loaded.from_disk(epoch_model_path) scorer = nlp_loaded.evaluate( diff --git a/spacy/language.py b/spacy/language.py index cb679a2bc..0284c4636 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -277,8 +277,7 @@ class Language(object): def make_doc(self, text): return self.tokenizer(text) - def update(self, docs, golds, drop=0., sgd=None, losses=None, - update_tensors=False): + def update(self, docs, golds, drop=0., sgd=None, losses=None): """Update the models in the pipeline. docs (iterable): A batch of `Doc` objects. @@ -305,17 +304,14 @@ class Language(object): grads[key] = (W, dW) pipes = list(self.pipeline[1:]) random.shuffle(pipes) - tokvecses, bp_tokvecses = tok2vec.model.begin_update(feats, drop=drop) - all_d_tokvecses = [tok2vec.model.ops.allocate(tv.shape) for tv in tokvecses] for proc in pipes: if not hasattr(proc, 'update'): continue + tokvecses, bp_tokvecses = tok2vec.model.begin_update(feats, drop=drop) d_tokvecses = proc.update((docs, tokvecses), golds, drop=drop, sgd=get_grads, losses=losses) - if update_tensors and d_tokvecses is not None: - for i, d_tv in enumerate(d_tokvecses): - all_d_tokvecses[i] += d_tv - bp_tokvecses(all_d_tokvecses, sgd=sgd) + if d_tokvecses is not None: + bp_tokvecses(d_tokvecses, sgd=sgd) for key, (W, dW) in grads.items(): sgd(W, dW, key=key) # Clear the tensor variable, to free GPU memory. @@ -385,18 +381,9 @@ class Language(object): return optimizer def evaluate(self, docs_golds): - scorer = Scorer() docs, golds = zip(*docs_golds) - docs = list(docs) - golds = list(golds) - for pipe in self.pipeline: - if not hasattr(pipe, 'pipe'): - for doc in docs: - pipe(doc) - else: - docs = list(pipe.pipe(docs)) - assert len(docs) == len(golds) - for doc, gold in zip(docs, golds): + scorer = Scorer() + for doc, gold in zip(self.pipe(docs, batch_size=32), golds): scorer.score(doc, gold) doc.tensor = None return scorer diff --git a/spacy/pipeline.pyx b/spacy/pipeline.pyx index 634d3e4b5..947f0a1f1 100644 --- a/spacy/pipeline.pyx +++ b/spacy/pipeline.pyx @@ -42,7 +42,7 @@ from .compat import json_dumps from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP, POS from ._ml import rebatch, Tok2Vec, flatten, get_col, doc2feats -from ._ml import build_text_classifier, build_tagger_model +from ._ml import build_text_classifier from .parts_of_speech import X @@ -138,7 +138,7 @@ class TokenVectorEncoder(BaseThincComponent): name = 'tensorizer' @classmethod - def Model(cls, width=128, embed_size=4000, **cfg): + def Model(cls, width=128, embed_size=7500, **cfg): """Create a new statistical model for the class. width (int): Output size of the model. @@ -253,25 +253,23 @@ class NeuralTagger(BaseThincComponent): self.cfg = dict(cfg) def __call__(self, doc): - tags = self.predict(([doc], [doc.tensor])) + tags = self.predict([doc.tensor]) self.set_annotations([doc], tags) return doc def pipe(self, stream, batch_size=128, n_threads=-1): for docs in cytoolz.partition_all(batch_size, stream): - docs = list(docs) tokvecs = [d.tensor for d in docs] - tag_ids = self.predict((docs, tokvecs)) + tag_ids = self.predict(tokvecs) self.set_annotations(docs, tag_ids) yield from docs - def predict(self, docs_tokvecs): - scores = self.model(docs_tokvecs) + def predict(self, tokvecs): + scores = self.model(tokvecs) scores = self.model.ops.flatten(scores) guesses = scores.argmax(axis=1) if not isinstance(guesses, numpy.ndarray): guesses = guesses.get() - tokvecs = docs_tokvecs[1] guesses = self.model.ops.unflatten(guesses, [tv.shape[0] for tv in tokvecs]) return guesses @@ -284,8 +282,6 @@ class NeuralTagger(BaseThincComponent): cdef Vocab vocab = self.vocab for i, doc in enumerate(docs): doc_tag_ids = batch_tag_ids[i] - if hasattr(doc_tag_ids, 'get'): - doc_tag_ids = doc_tag_ids.get() for j, tag_id in enumerate(doc_tag_ids): # Don't clobber preset POS tags if doc.c[j].tag == 0 and doc.c[j].pos == 0: @@ -298,7 +294,8 @@ class NeuralTagger(BaseThincComponent): if self.model.nI is None: self.model.nI = tokvecs[0].shape[1] - tag_scores, bp_tag_scores = self.model.begin_update(docs_tokvecs, drop=drop) + + tag_scores, bp_tag_scores = self.model.begin_update(tokvecs, drop=drop) loss, d_tag_scores = self.get_loss(docs, golds, tag_scores) d_tokvecs = bp_tag_scores(d_tag_scores, sgd=sgd) @@ -349,8 +346,10 @@ class NeuralTagger(BaseThincComponent): @classmethod def Model(cls, n_tags, token_vector_width): - return build_tagger_model(n_tags, token_vector_width) - + return with_flatten( + chain(Maxout(token_vector_width, token_vector_width), + Softmax(n_tags, token_vector_width))) + def use_params(self, params): with self.model.use_params(params): yield @@ -433,7 +432,7 @@ class NeuralLabeller(NeuralTagger): @property def labels(self): - return self.cfg.setdefault('labels', {}) + return self.cfg.get('labels', {}) @labels.setter def labels(self, value): @@ -456,8 +455,10 @@ class NeuralLabeller(NeuralTagger): @classmethod def Model(cls, n_tags, token_vector_width): - return build_tagger_model(n_tags, token_vector_width) - + return with_flatten( + chain(Maxout(token_vector_width, token_vector_width), + Softmax(n_tags, token_vector_width))) + def get_loss(self, docs, golds, scores): scores = self.model.ops.flatten(scores) cdef int idx = 0 diff --git a/spacy/syntax/_beam_utils.pyx b/spacy/syntax/_beam_utils.pyx deleted file mode 100644 index e77036e55..000000000 --- a/spacy/syntax/_beam_utils.pyx +++ /dev/null @@ -1,273 +0,0 @@ -# cython: infer_types=True -# cython: profile=True -cimport numpy as np -import numpy -from cpython.ref cimport PyObject, Py_INCREF, Py_XDECREF -from thinc.extra.search cimport Beam -from thinc.extra.search import MaxViolation -from thinc.typedefs cimport hash_t, class_t - -from .transition_system cimport TransitionSystem, Transition -from .stateclass cimport StateClass -from ..gold cimport GoldParse -from ..tokens.doc cimport Doc - - -# These are passed as callbacks to thinc.search.Beam -cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1: - dest = _dest - src = _src - moves = _moves - dest.clone(src) - moves[clas].do(dest.c, moves[clas].label) - - -cdef int _check_final_state(void* _state, void* extra_args) except -1: - return (_state).is_final() - - -def _cleanup(Beam beam): - for i in range(beam.width): - Py_XDECREF(beam._states[i].content) - Py_XDECREF(beam._parents[i].content) - - -cdef hash_t _hash_state(void* _state, void* _) except 0: - state = _state - if state.c.is_final(): - return 1 - else: - return state.c.hash() - - -cdef class ParserBeam(object): - cdef public TransitionSystem moves - cdef public object states - cdef public object golds - cdef public object beams - - def __init__(self, TransitionSystem moves, states, golds, - int width=4, float density=0.001): - self.moves = moves - self.states = states - self.golds = golds - self.beams = [] - cdef Beam beam - cdef StateClass state, st - for state in states: - beam = Beam(self.moves.n_moves, width, density) - beam.initialize(self.moves.init_beam_state, state.c.length, state.c._sent) - for i in range(beam.width): - st = beam.at(i) - st.c.offset = state.c.offset - self.beams.append(beam) - - def __dealloc__(self): - if self.beams is not None: - for beam in self.beams: - if beam is not None: - _cleanup(beam) - - @property - def is_done(self): - return all(b.is_done for b in self.beams) - - def __getitem__(self, i): - return self.beams[i] - - def __len__(self): - return len(self.beams) - - def advance(self, scores, follow_gold=False): - cdef Beam beam - for i, beam in enumerate(self.beams): - if beam.is_done or not scores[i].size: - continue - self._set_scores(beam, scores[i]) - if self.golds is not None: - self._set_costs(beam, self.golds[i], follow_gold=follow_gold) - if follow_gold: - assert self.golds is not None - beam.advance(_transition_state, NULL, self.moves.c) - else: - beam.advance(_transition_state, _hash_state, self.moves.c) - beam.check_done(_check_final_state, NULL) - if beam.is_done: - for j in range(beam.size): - if is_gold(beam.at(j), self.golds[i], self.moves.strings): - beam._states[j].loss = 0.0 - elif beam._states[j].loss == 0.0: - beam._states[j].loss = 1.0 - - def _set_scores(self, Beam beam, float[:, ::1] scores): - cdef float* c_scores = &scores[0, 0] - for i in range(beam.size): - state = beam.at(i) - if not state.is_final(): - for j in range(beam.nr_class): - beam.scores[i][j] = c_scores[i * beam.nr_class + j] - self.moves.set_valid(beam.is_valid[i], state.c) - - def _set_costs(self, Beam beam, GoldParse gold, int follow_gold=False): - for i in range(beam.size): - state = beam.at(i) - if not state.c.is_final(): - self.moves.set_costs(beam.is_valid[i], beam.costs[i], state, gold) - if follow_gold: - for j in range(beam.nr_class): - if beam.costs[i][j] >= 1: - beam.is_valid[i][j] = 0 - - -def is_gold(StateClass state, GoldParse gold, strings): - predicted = set() - truth = set() - for i in range(gold.length): - if gold.cand_to_gold[i] is None: - continue - if state.safe_get(i).dep: - predicted.add((i, state.H(i), strings[state.safe_get(i).dep])) - else: - predicted.add((i, state.H(i), 'ROOT')) - id_, word, tag, head, dep, ner = gold.orig_annot[gold.cand_to_gold[i]] - truth.add((id_, head, dep)) - return truth == predicted - - -def get_token_ids(states, int n_tokens): - cdef StateClass state - cdef np.ndarray ids = numpy.zeros((len(states), n_tokens), - dtype='int32', order='C') - c_ids = ids.data - for i, state in enumerate(states): - if not state.is_final(): - state.c.set_context_tokens(c_ids, n_tokens) - else: - ids[i] = -1 - c_ids += ids.shape[1] - return ids - -nr_update = 0 -def update_beam(TransitionSystem moves, int nr_feature, int max_steps, - states, tokvecs, golds, - state2vec, vec2scores, drop=0., sgd=None, - losses=None, int width=4, float density=0.001): - global nr_update - nr_update += 1 - pbeam = ParserBeam(moves, states, golds, - width=width, density=density) - gbeam = ParserBeam(moves, states, golds, - width=width, density=0.0) - cdef StateClass state - beam_maps = [] - backprops = [] - violns = [MaxViolation() for _ in range(len(states))] - for t in range(max_steps): - # The beam maps let us find the right row in the flattened scores - # arrays for each state. States are identified by (example id, history). - # We keep a different beam map for each step (since we'll have a flat - # scores array for each step). The beam map will let us take the per-state - # losses, and compute the gradient for each (step, state, class). - beam_maps.append({}) - # Gather all states from the two beams in a list. Some stats may occur - # in both beams. To figure out which beam each state belonged to, - # we keep two lists of indices, p_indices and g_indices - states, p_indices, g_indices = get_states(pbeam, gbeam, beam_maps[-1], nr_update) - if not states: - break - # Now that we have our flat list of states, feed them through the model - token_ids = get_token_ids(states, nr_feature) - vectors, bp_vectors = state2vec.begin_update(token_ids, drop=drop) - scores, bp_scores = vec2scores.begin_update(vectors, drop=drop) - - # Store the callbacks for the backward pass - backprops.append((token_ids, bp_vectors, bp_scores)) - - # Unpack the flat scores into lists for the two beams. The indices arrays - # tell us which example and state the scores-row refers to. - p_scores = [numpy.ascontiguousarray(scores[indices], dtype='f') for indices in p_indices] - g_scores = [numpy.ascontiguousarray(scores[indices], dtype='f') for indices in g_indices] - # Now advance the states in the beams. The gold beam is contrained to - # to follow only gold analyses. - pbeam.advance(p_scores) - gbeam.advance(g_scores, follow_gold=True) - # Track the "maximum violation", to use in the update. - for i, violn in enumerate(violns): - violn.check_crf(pbeam[i], gbeam[i]) - - # Only make updates if we have non-gold states - histories = [((v.p_hist + v.g_hist) if v.p_hist else []) for v in violns] - losses = [((v.p_probs + v.g_probs) if v.p_probs else []) for v in violns] - states_d_scores = get_gradient(moves.n_moves, beam_maps, - histories, losses) - assert len(states_d_scores) == len(backprops), (len(states_d_scores), len(backprops)) - return states_d_scores, backprops - - -def get_states(pbeams, gbeams, beam_map, nr_update): - seen = {} - states = [] - p_indices = [] - g_indices = [] - cdef Beam pbeam, gbeam - assert len(pbeams) == len(gbeams) - for eg_id, (pbeam, gbeam) in enumerate(zip(pbeams, gbeams)): - p_indices.append([]) - g_indices.append([]) - if pbeam.loss > 0 and pbeam.min_score > gbeam.score: - continue - for i in range(pbeam.size): - state = pbeam.at(i) - if not state.is_final(): - key = tuple([eg_id] + pbeam.histories[i]) - seen[key] = len(states) - p_indices[-1].append(len(states)) - states.append(state) - beam_map.update(seen) - for i in range(gbeam.size): - state = gbeam.at(i) - if not state.is_final(): - key = tuple([eg_id] + gbeam.histories[i]) - if key in seen: - g_indices[-1].append(seen[key]) - else: - g_indices[-1].append(len(states)) - beam_map[key] = len(states) - states.append(state) - p_idx = [numpy.asarray(idx, dtype='i') for idx in p_indices] - g_idx = [numpy.asarray(idx, dtype='i') for idx in g_indices] - return states, p_idx, g_idx - - -def get_gradient(nr_class, beam_maps, histories, losses): - """ - The global model assigns a loss to each parse. The beam scores - are additive, so the same gradient is applied to each action - in the history. This gives the gradient of a single *action* - for a beam state -- so we have "the gradient of loss for taking - action i given history H." - - Histories: Each hitory is a list of actions - Each candidate has a history - Each beam has multiple candidates - Each batch has multiple beams - So history is list of lists of lists of ints - """ - nr_step = len(beam_maps) - grads = [] - for beam_map in beam_maps: - if beam_map: - grads.append(numpy.zeros((max(beam_map.values())+1, nr_class), dtype='f')) - assert len(histories) == len(losses) - for eg_id, hists in enumerate(histories): - for loss, hist in zip(losses[eg_id], hists): - key = tuple([eg_id]) - for j, clas in enumerate(hist): - i = beam_maps[j][key] - # In step j, at state i action clas - # resulted in loss - grads[j][i, clas] += loss / len(histories) - key = key + tuple([clas]) - return grads - - diff --git a/spacy/syntax/_state.pxd b/spacy/syntax/_state.pxd index 9aeeba441..c06851978 100644 --- a/spacy/syntax/_state.pxd +++ b/spacy/syntax/_state.pxd @@ -37,7 +37,6 @@ cdef cppclass StateC: this.shifted = calloc(length + (PADDING * 2), sizeof(bint)) this._sent = calloc(length + (PADDING * 2), sizeof(TokenC)) this._ents = calloc(length + (PADDING * 2), sizeof(Entity)) - this.offset = 0 cdef int i for i in range(length + (PADDING * 2)): this._ents[i].end = -1 diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 9477449a5..29e8de0aa 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -385,7 +385,6 @@ cdef class ArcEager(TransitionSystem): for i in range(self.n_moves): if self.c[i].move == move and self.c[i].label == label: return self.c[i] - return Transition(clas=0, move=MISSING, label=0) def move_name(self, int move, attr_t label): label_str = self.strings[label] diff --git a/spacy/syntax/beam_parser.pyx b/spacy/syntax/beam_parser.pyx index f4f66f9fb..e96e28fcf 100644 --- a/spacy/syntax/beam_parser.pyx +++ b/spacy/syntax/beam_parser.pyx @@ -34,7 +34,6 @@ from ._parse_features cimport CONTEXT_SIZE from ._parse_features cimport fill_context from .stateclass cimport StateClass from .parser cimport Parser -from ._beam_utils import is_gold DEBUG = False @@ -238,3 +237,16 @@ def _check_train_integrity(Beam pred, Beam gold, GoldParse gold_parse, Transitio raise Exception("Gold parse is not gold-standard") +def is_gold(StateClass state, GoldParse gold, StringStore strings): + predicted = set() + truth = set() + for i in range(gold.length): + if gold.cand_to_gold[i] is None: + continue + if state.safe_get(i).dep: + predicted.add((i, state.H(i), strings[state.safe_get(i).dep])) + else: + predicted.add((i, state.H(i), 'ROOT')) + id_, word, tag, head, dep, ner = gold.orig_annot[gold.cand_to_gold[i]] + truth.add((id_, head, dep)) + return truth == predicted diff --git a/spacy/syntax/nn_parser.pxd b/spacy/syntax/nn_parser.pxd index 7ff4b9f9f..524718965 100644 --- a/spacy/syntax/nn_parser.pxd +++ b/spacy/syntax/nn_parser.pxd @@ -14,4 +14,8 @@ cdef class Parser: cdef readonly TransitionSystem moves cdef readonly object cfg + cdef void _parse_step(self, StateC* state, + const float* feat_weights, + int nr_class, int nr_feat, int nr_piece) nogil + #cdef int parseC(self, TokenC* tokens, int length, int nr_feat) nogil diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index a193c96a3..0b39e2216 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -37,17 +37,14 @@ from preshed.maps cimport MapStruct from preshed.maps cimport map_get from thinc.api import layerize, chain, noop, clone -from thinc.neural import Model, Affine, ReLu, Maxout -from thinc.neural._classes.selu import SELU -from thinc.neural._classes.layernorm import LayerNorm +from thinc.neural import Model, Affine, ELU, ReLu, Maxout from thinc.neural.ops import NumpyOps, CupyOps from thinc.neural.util import get_array_module from .. import util from ..util import get_async, get_cuda_stream from .._ml import zero_init, PrecomputableAffine, PrecomputableMaxouts -from .._ml import Tok2Vec, doc2feats, rebatch, fine_tune -from .._ml import Residual, drop_layer +from .._ml import Tok2Vec, doc2feats, rebatch from ..compat import json_dumps from . import _parse_features @@ -62,11 +59,8 @@ from ..structs cimport TokenC from ..tokens.doc cimport Doc from ..strings cimport StringStore from ..gold cimport GoldParse -from ..attrs cimport ID, TAG, DEP, ORTH, NORM, PREFIX, SUFFIX, TAG -from . import _beam_utils +from ..attrs cimport TAG, DEP -USE_FINE_TUNE = True -BEAM_PARSE = True def get_templates(*args, **kwargs): return [] @@ -238,14 +232,11 @@ cdef class Parser: Base class of the DependencyParser and EntityRecognizer. """ @classmethod - def Model(cls, nr_class, token_vector_width=128, hidden_width=300, depth=1, **cfg): + def Model(cls, nr_class, token_vector_width=128, hidden_width=128, depth=1, **cfg): depth = util.env_opt('parser_hidden_depth', depth) token_vector_width = util.env_opt('token_vector_width', token_vector_width) hidden_width = util.env_opt('hidden_width', hidden_width) parser_maxout_pieces = util.env_opt('parser_maxout_pieces', 2) - embed_size = util.env_opt('embed_size', 4000) - tensors = fine_tune(Tok2Vec(token_vector_width, embed_size, - preprocess=doc2feats())) if parser_maxout_pieces == 1: lower = PrecomputableAffine(hidden_width if depth >= 1 else nr_class, nF=cls.nr_feature, @@ -257,10 +248,15 @@ cdef class Parser: nI=token_vector_width) with Model.use_device('cpu'): - upper = chain( - clone(Residual(ReLu(hidden_width)), (depth-1)), - zero_init(Affine(nr_class, drop_factor=0.0)) - ) + if depth == 0: + upper = chain() + upper.is_noop = True + else: + upper = chain( + clone(Maxout(hidden_width), (depth-1)), + zero_init(Affine(nr_class, drop_factor=0.0)) + ) + upper.is_noop = False # TODO: This is an unfortunate hack atm! # Used to set input dimensions in network. lower.begin_training(lower.ops.allocate((500, token_vector_width))) @@ -272,7 +268,7 @@ cdef class Parser: 'hidden_width': hidden_width, 'maxout_pieces': parser_maxout_pieces } - return (tensors, lower, upper), cfg + return (lower, upper), cfg def __init__(self, Vocab vocab, moves=True, model=True, **cfg): """ @@ -348,21 +344,17 @@ cdef class Parser: The number of threads with which to work on the buffer in parallel. Yields (Doc): Documents, in order. """ - if BEAM_PARSE: - beam_width = 8 + cdef StateClass parse_state cdef Doc doc - cdef Beam beam + queue = [] for docs in cytoolz.partition_all(batch_size, docs): docs = list(docs) - tokvecs = [doc.tensor for doc in docs] + tokvecs = [d.tensor for d in docs] if beam_width == 1: parse_states = self.parse_batch(docs, tokvecs) else: - beams = self.beam_parse(docs, tokvecs, - beam_width=beam_width, beam_density=beam_density) - parse_states = [] - for beam in beams: - parse_states.append(beam.at(0)) + parse_states = self.beam_parse(docs, tokvecs, + beam_width=beam_width, beam_density=beam_density) self.set_annotations(docs, parse_states) yield from docs @@ -377,12 +369,8 @@ cdef class Parser: int nr_class, nr_feat, nr_piece, nr_dim, nr_state if isinstance(docs, Doc): docs = [docs] - if isinstance(tokvecses, np.ndarray): - tokvecses = [tokvecses] tokvecs = self.model[0].ops.flatten(tokvecses) - if USE_FINE_TUNE: - tokvecs += self.model[0].ops.flatten(self.model[0]((docs, tokvecses))) nr_state = len(docs) nr_class = self.moves.n_moves @@ -406,20 +394,27 @@ cdef class Parser: cdef np.ndarray scores c_token_ids = token_ids.data c_is_valid = is_valid.data + cdef int has_hidden = not getattr(vec2scores, 'is_noop', False) while not next_step.empty(): - for i in range(next_step.size()): - st = next_step[i] - st.set_context_tokens(&c_token_ids[i*nr_feat], nr_feat) - self.moves.set_valid(&c_is_valid[i*nr_class], st) + if not has_hidden: + for i in cython.parallel.prange( + next_step.size(), num_threads=6, nogil=True): + self._parse_step(next_step[i], + feat_weights, nr_class, nr_feat, nr_piece) + else: + for i in range(next_step.size()): + st = next_step[i] + st.set_context_tokens(&c_token_ids[i*nr_feat], nr_feat) + self.moves.set_valid(&c_is_valid[i*nr_class], st) vectors = state2vec(token_ids[:next_step.size()]) - scores = vec2scores(vectors) - c_scores = scores.data - for i in range(next_step.size()): - st = next_step[i] - guess = arg_max_if_valid( - &c_scores[i*nr_class], &c_is_valid[i*nr_class], nr_class) - action = self.moves.c[guess] - action.do(st, action.label) + scores = vec2scores(vectors) + c_scores = scores.data + for i in range(next_step.size()): + st = next_step[i] + guess = arg_max_if_valid( + &c_scores[i*nr_class], &c_is_valid[i*nr_class], nr_class) + action = self.moves.c[guess] + action.do(st, action.label) this_step, next_step = next_step, this_step next_step.clear() for st in this_step: @@ -434,15 +429,11 @@ cdef class Parser: cdef int nr_class = self.moves.n_moves cdef StateClass stcls, output tokvecs = self.model[0].ops.flatten(tokvecses) - if USE_FINE_TUNE: - tokvecs += self.model[0].ops.flatten(self.model[0]((docs, tokvecses))) cuda_stream = get_cuda_stream() state2vec, vec2scores = self.get_batch_model(len(docs), tokvecs, cuda_stream, 0.0) beams = [] cdef int offset = 0 - cdef int j = 0 - cdef int k for doc in docs: beam = Beam(nr_class, beam_width, min_density=beam_density) beam.initialize(self.moves.init_beam_state, doc.length, doc.c) @@ -455,31 +446,44 @@ cdef class Parser: states = [] for i in range(beam.size): stcls = beam.at(i) - # This way we avoid having to score finalized states - # We do have to take care to keep indexes aligned, though - if not stcls.is_final(): - states.append(stcls) + states.append(stcls) token_ids = self.get_token_ids(states) vectors = state2vec(token_ids) scores = vec2scores(vectors) - j = 0 - c_scores = scores.data for i in range(beam.size): stcls = beam.at(i) if not stcls.is_final(): self.moves.set_valid(beam.is_valid[i], stcls.c) - for k in range(nr_class): - beam.scores[i][k] = c_scores[j * scores.shape[1] + k] - j += 1 + for j in range(nr_class): + beam.scores[i][j] = scores[i, j] beam.advance(_transition_state, _hash_state, self.moves.c) beam.check_done(_check_final_state, NULL) beams.append(beam) return beams + cdef void _parse_step(self, StateC* state, + const float* feat_weights, + int nr_class, int nr_feat, int nr_piece) nogil: + '''This only works with no hidden layers -- fast but inaccurate''' + #for i in cython.parallel.prange(next_step.size(), num_threads=4, nogil=True): + # self._parse_step(next_step[i], feat_weights, nr_class, nr_feat) + token_ids = calloc(nr_feat, sizeof(int)) + scores = calloc(nr_class * nr_piece, sizeof(float)) + is_valid = calloc(nr_class, sizeof(int)) + + state.set_context_tokens(token_ids, nr_feat) + sum_state_features(scores, + feat_weights, token_ids, 1, nr_feat, nr_class * nr_piece) + self.moves.set_valid(is_valid, state) + guess = arg_maxout_if_valid(scores, is_valid, nr_class, nr_piece) + action = self.moves.c[guess] + action.do(state, action.label) + + free(is_valid) + free(scores) + free(token_ids) + def update(self, docs_tokvecs, golds, drop=0., sgd=None, losses=None): - if BEAM_PARSE: - return self.update_beam(docs_tokvecs, golds, drop=drop, sgd=sgd, - losses=losses) if losses is not None and self.name not in losses: losses[self.name] = 0. docs, tokvec_lists = docs_tokvecs @@ -487,10 +491,6 @@ cdef class Parser: if isinstance(docs, Doc) and isinstance(golds, GoldParse): docs = [docs] golds = [golds] - if USE_FINE_TUNE: - my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop) - my_tokvecs = self.model[0].ops.flatten(my_tokvecs) - tokvecs += my_tokvecs cuda_stream = get_cuda_stream() @@ -517,13 +517,13 @@ cdef class Parser: scores, bp_scores = vec2scores.begin_update(vector, drop=drop) d_scores = self.get_batch_loss(states, golds, scores) - d_vector = bp_scores(d_scores, sgd=sgd) + d_vector = bp_scores(d_scores / d_scores.shape[0], sgd=sgd) if drop != 0: d_vector *= mask if isinstance(self.model[0].ops, CupyOps) \ and not isinstance(token_ids, state2vec.ops.xp.ndarray): - # Move token_ids and d_vector to GPU, asynchronously + # Move token_ids and d_vector to CPU, asynchronously backprops.append(( get_async(cuda_stream, token_ids), get_async(cuda_stream, d_vector), @@ -540,55 +540,7 @@ cdef class Parser: break self._make_updates(d_tokvecs, backprops, sgd, cuda_stream) - d_tokvecs = self.model[0].ops.unflatten(d_tokvecs, [len(d) for d in docs]) - if USE_FINE_TUNE: - bp_my_tokvecs(d_tokvecs, sgd=sgd) - return d_tokvecs - - def update_beam(self, docs_tokvecs, golds, drop=0., sgd=None, losses=None): - if losses is not None and self.name not in losses: - losses[self.name] = 0. - docs, tokvecs = docs_tokvecs - lengths = [len(d) for d in docs] - assert min(lengths) >= 1 - tokvecs = self.model[0].ops.flatten(tokvecs) - if USE_FINE_TUNE: - my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop) - my_tokvecs = self.model[0].ops.flatten(my_tokvecs) - tokvecs += my_tokvecs - - states = self.moves.init_batch(docs) - for gold in golds: - self.moves.preprocess_gold(gold) - - cuda_stream = get_cuda_stream() - state2vec, vec2scores = self.get_batch_model(len(states), tokvecs, cuda_stream, 0.0) - - states_d_scores, backprops = _beam_utils.update_beam(self.moves, self.nr_feature, 500, - states, tokvecs, golds, - state2vec, vec2scores, - drop, sgd, losses, - width=8) - backprop_lower = [] - for i, d_scores in enumerate(states_d_scores): - if losses is not None: - losses[self.name] += (d_scores**2).sum() - ids, bp_vectors, bp_scores = backprops[i] - d_vector = bp_scores(d_scores, sgd=sgd) - if isinstance(self.model[0].ops, CupyOps) \ - and not isinstance(ids, state2vec.ops.xp.ndarray): - backprop_lower.append(( - get_async(cuda_stream, ids), - get_async(cuda_stream, d_vector), - bp_vectors)) - else: - backprop_lower.append((ids, d_vector, bp_vectors)) - d_tokvecs = self.model[0].ops.allocate(tokvecs.shape) - self._make_updates(d_tokvecs, backprop_lower, sgd, cuda_stream) - d_tokvecs = self.model[0].ops.unflatten(d_tokvecs, lengths) - if USE_FINE_TUNE: - bp_my_tokvecs(d_tokvecs, sgd=sgd) - return d_tokvecs + return self.model[0].ops.unflatten(d_tokvecs, [len(d) for d in docs]) def _init_gold_batch(self, whole_docs, whole_golds): """Make a square batch, of length equal to the shortest doc. A long @@ -633,10 +585,14 @@ cdef class Parser: xp = get_array_module(d_tokvecs) for ids, d_vector, bp_vector in backprops: d_state_features = bp_vector(d_vector, sgd=sgd) - mask = ids >= 0 - indices = xp.nonzero(mask) - self.model[0].ops.scatter_add(d_tokvecs, ids[indices], - d_state_features[indices]) + active_feats = ids * (ids >= 0) + active_feats = active_feats.reshape((ids.shape[0], ids.shape[1], 1)) + if hasattr(xp, 'scatter_add'): + xp.scatter_add(d_tokvecs, + ids, d_state_features * active_feats) + else: + xp.add.at(d_tokvecs, + ids, d_state_features * active_feats) @property def move_names(self): @@ -647,7 +603,7 @@ cdef class Parser: return names def get_batch_model(self, batch_size, tokvecs, stream, dropout): - _, lower, upper = self.model + lower, upper = self.model state2vec = precompute_hiddens(batch_size, tokvecs, lower, stream, drop=dropout) return state2vec, upper @@ -737,12 +693,10 @@ cdef class Parser: def to_disk(self, path, **exclude): serializers = { - 'tok2vec_model': lambda p: p.open('wb').write( - self.model[0].to_bytes()), 'lower_model': lambda p: p.open('wb').write( - self.model[1].to_bytes()), + self.model[0].to_bytes()), 'upper_model': lambda p: p.open('wb').write( - self.model[2].to_bytes()), + self.model[1].to_bytes()), 'vocab': lambda p: self.vocab.to_disk(p), 'moves': lambda p: self.moves.to_disk(p, strings=False), 'cfg': lambda p: p.open('w').write(json_dumps(self.cfg)) @@ -763,29 +717,24 @@ cdef class Parser: self.model, cfg = self.Model(**self.cfg) else: cfg = {} - with (path / 'tok2vec_model').open('rb') as file_: - bytes_data = file_.read() - self.model[0].from_bytes(bytes_data) with (path / 'lower_model').open('rb') as file_: bytes_data = file_.read() - self.model[1].from_bytes(bytes_data) + self.model[0].from_bytes(bytes_data) with (path / 'upper_model').open('rb') as file_: bytes_data = file_.read() - self.model[2].from_bytes(bytes_data) + self.model[1].from_bytes(bytes_data) self.cfg.update(cfg) return self def to_bytes(self, **exclude): serializers = OrderedDict(( - ('tok2vec_model', lambda: self.model[0].to_bytes()), - ('lower_model', lambda: self.model[1].to_bytes()), - ('upper_model', lambda: self.model[2].to_bytes()), + ('lower_model', lambda: self.model[0].to_bytes()), + ('upper_model', lambda: self.model[1].to_bytes()), ('vocab', lambda: self.vocab.to_bytes()), ('moves', lambda: self.moves.to_bytes(strings=False)), ('cfg', lambda: ujson.dumps(self.cfg)) )) if 'model' in exclude: - exclude['tok2vec_model'] = True exclude['lower_model'] = True exclude['upper_model'] = True exclude.pop('model') @@ -796,7 +745,6 @@ cdef class Parser: ('vocab', lambda b: self.vocab.from_bytes(b)), ('moves', lambda b: self.moves.from_bytes(b, strings=False)), ('cfg', lambda b: self.cfg.update(ujson.loads(b))), - ('tok2vec_model', lambda b: None), ('lower_model', lambda b: None), ('upper_model', lambda b: None) )) @@ -806,12 +754,10 @@ cdef class Parser: self.model, cfg = self.Model(self.moves.n_moves) else: cfg = {} - if 'tok2vec_model' in msg: - self.model[0].from_bytes(msg['tok2vec_model']) if 'lower_model' in msg: - self.model[1].from_bytes(msg['lower_model']) + self.model[0].from_bytes(msg['lower_model']) if 'upper_model' in msg: - self.model[2].from_bytes(msg['upper_model']) + self.model[1].from_bytes(msg['upper_model']) self.cfg.update(cfg) return self diff --git a/spacy/syntax/transition_system.pyx b/spacy/syntax/transition_system.pyx index d3f64f827..27b375bba 100644 --- a/spacy/syntax/transition_system.pyx +++ b/spacy/syntax/transition_system.pyx @@ -107,8 +107,6 @@ cdef class TransitionSystem: def is_valid(self, StateClass stcls, move_name): action = self.lookup_transition(move_name) - if action.move == 0: - return False return action.is_valid(stcls.c, action.label) cdef int set_valid(self, int* is_valid, const StateC* st) nogil: diff --git a/spacy/tests/parser/test_neural_parser.py b/spacy/tests/parser/test_neural_parser.py index 30a6367c8..42b55745f 100644 --- a/spacy/tests/parser/test_neural_parser.py +++ b/spacy/tests/parser/test_neural_parser.py @@ -78,16 +78,3 @@ def test_predict_doc_beam(parser, tok2vec, model, doc): parser(doc, beam_width=32, beam_density=0.001) for word in doc: print(word.text, word.head, word.dep_) - - -def test_update_doc_beam(parser, tok2vec, model, doc, gold): - parser.model = model - tokvecs, bp_tokvecs = tok2vec.begin_update([doc]) - d_tokvecs = parser.update_beam(([doc], tokvecs), [gold]) - assert d_tokvecs[0].shape == tokvecs[0].shape - def optimize(weights, gradient, key=None): - weights -= 0.001 * gradient - bp_tokvecs(d_tokvecs, sgd=optimize) - assert d_tokvecs[0].sum() == 0. - - diff --git a/spacy/tests/parser/test_nn_beam.py b/spacy/tests/parser/test_nn_beam.py deleted file mode 100644 index 45c85d969..000000000 --- a/spacy/tests/parser/test_nn_beam.py +++ /dev/null @@ -1,87 +0,0 @@ -from __future__ import unicode_literals -import pytest -import numpy -from thinc.api import layerize - -from ...vocab import Vocab -from ...syntax.arc_eager import ArcEager -from ...tokens import Doc -from ...gold import GoldParse -from ...syntax._beam_utils import ParserBeam, update_beam -from ...syntax.stateclass import StateClass - - -@pytest.fixture -def vocab(): - return Vocab() - -@pytest.fixture -def moves(vocab): - aeager = ArcEager(vocab.strings, {}) - aeager.add_action(2, 'nsubj') - aeager.add_action(3, 'dobj') - aeager.add_action(2, 'aux') - return aeager - - -@pytest.fixture -def docs(vocab): - return [Doc(vocab, words=['Rats', 'bite', 'things'])] - -@pytest.fixture -def states(docs): - return [StateClass(doc) for doc in docs] - -@pytest.fixture -def tokvecs(docs, vector_size): - output = [] - for doc in docs: - vec = numpy.random.uniform(-0.1, 0.1, (len(doc), vector_size)) - output.append(numpy.asarray(vec)) - return output - - -@pytest.fixture -def golds(docs): - return [GoldParse(doc) for doc in docs] - - -@pytest.fixture -def batch_size(docs): - return len(docs) - - -@pytest.fixture -def beam_width(): - return 4 - - -@pytest.fixture -def vector_size(): - return 6 - - -@pytest.fixture -def beam(moves, states, golds, beam_width): - return ParserBeam(moves, states, golds, width=beam_width) - -@pytest.fixture -def scores(moves, batch_size, beam_width): - return [ - numpy.asarray( - numpy.random.uniform(-0.1, 0.1, (batch_size, moves.n_moves)), - dtype='f') - for _ in range(batch_size)] - - -def test_create_beam(beam): - pass - - -def test_beam_advance(beam, scores): - beam.advance(scores) - - -def test_beam_advance_too_few_scores(beam, scores): - with pytest.raises(IndexError): - beam.advance(scores[:-1])