diff --git a/setup.py b/setup.py index 02d4fe0d9..0a3384ed5 100755 --- a/setup.py +++ b/setup.py @@ -36,7 +36,6 @@ MOD_NAMES = [ 'spacy.syntax.transition_system', 'spacy.syntax.arc_eager', 'spacy.syntax._parse_features', - 'spacy.syntax._beam_utils', 'spacy.gold', 'spacy.tokens.doc', 'spacy.tokens.span', diff --git a/spacy/_ml.py b/spacy/_ml.py index 34f1e0a53..0d67ce01e 100644 --- a/spacy/_ml.py +++ b/spacy/_ml.py @@ -5,12 +5,10 @@ from thinc.neural._classes.hash_embed import HashEmbed from thinc.neural.ops import NumpyOps, CupyOps from thinc.neural.util import get_array_module import random -import cytoolz from thinc.neural._classes.convolution import ExtractWindow from thinc.neural._classes.static_vectors import StaticVectors from thinc.neural._classes.batchnorm import BatchNorm -from thinc.neural._classes.layernorm import LayerNorm as LN from thinc.neural._classes.resnet import Residual from thinc.neural import ReLu from thinc.neural._classes.selu import SELU @@ -21,12 +19,10 @@ from thinc.api import FeatureExtracter, with_getitem from thinc.neural.pooling import Pooling, max_pool, mean_pool, sum_pool from thinc.neural._classes.attention import ParametricAttention from thinc.linear.linear import LinearModel -from thinc.api import uniqued, wrap, flatten_add_lengths - +from thinc.api import uniqued, wrap from .attrs import ID, ORTH, LOWER, NORM, PREFIX, SUFFIX, SHAPE, TAG, DEP from .tokens.doc import Doc -from . import util import numpy import io @@ -57,27 +53,6 @@ def _logistic(X, drop=0.): return Y, logistic_bwd -@layerize -def add_tuples(X, drop=0.): - """Give inputs of sequence pairs, where each sequence is (vals, length), - sum the values, returning a single sequence. - - If input is: - ((vals1, length), (vals2, length) - Output is: - (vals1+vals2, length) - - vals are a single tensor for the whole batch. - """ - (vals1, length1), (vals2, length2) = X - assert length1 == length2 - - def add_tuples_bwd(dY, sgd=None): - return (dY, dY) - - return (vals1+vals2, length), add_tuples_bwd - - def _zero_init(model): def _zero_init_impl(self, X, y): self.W.fill(0) @@ -86,7 +61,6 @@ def _zero_init(model): model.W.fill(0.) return model - @layerize def _preprocess_doc(docs, drop=0.): keys = [doc.to_array([LOWER]) for doc in docs] @@ -98,6 +72,7 @@ def _preprocess_doc(docs, drop=0.): return (keys, vals, lengths), None + def _init_for_precomputed(W, ops): if (W**2).sum() != 0.: return @@ -105,7 +80,6 @@ def _init_for_precomputed(W, ops): ops.xavier_uniform_init(reshaped) W[:] = reshaped.reshape(W.shape) - @describe.on_data(_set_dimensions_if_needed) @describe.attributes( nI=Dimension("Input size"), @@ -210,36 +184,25 @@ class PrecomputableMaxouts(Model): return Yfp, backward -def drop_layer(layer, factor=2.): - def drop_layer_fwd(X, drop=0.): - drop *= factor - mask = layer.ops.get_dropout_mask((1,), drop) - if mask is None or mask > 0: - return layer.begin_update(X, drop=drop) - else: - return X, lambda dX, sgd=None: dX - return wrap(drop_layer_fwd, layer) - - def Tok2Vec(width, embed_size, preprocess=None): - cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH] + cols = [ID, NORM, PREFIX, SUFFIX, SHAPE] with Model.define_operators({'>>': chain, '|': concatenate, '**': clone, '+': add}): - norm = get_col(cols.index(NORM)) >> HashEmbed(width, embed_size, name='embed_lower') + norm = get_col(cols.index(NORM)) >> HashEmbed(width, embed_size, name='embed_lower') prefix = get_col(cols.index(PREFIX)) >> HashEmbed(width, embed_size//2, name='embed_prefix') suffix = get_col(cols.index(SUFFIX)) >> HashEmbed(width, embed_size//2, name='embed_suffix') shape = get_col(cols.index(SHAPE)) >> HashEmbed(width, embed_size//2, name='embed_shape') - embed = (norm | prefix | suffix | shape ) >> Maxout(width, width*4, pieces=3) + embed = (norm | prefix | suffix | shape ) tok2vec = ( with_flatten( asarray(Model.ops, dtype='uint64') - >> uniqued(embed, column=5) - >> drop_layer( - Residual( - (ExtractWindow(nW=1) >> ReLu(width, width*3)) - ) - ) ** 4, pad=4 - ) + >> embed + >> Maxout(width, width*4, pieces=3) + >> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3)) + >> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3)) + >> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3)) + >> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3)), + pad=4) ) if preprocess not in (False, None): tok2vec = preprocess >> tok2vec @@ -334,8 +297,7 @@ def zero_init(model): def doc2feats(cols=None): - if cols is None: - cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH] + cols = [ID, NORM, PREFIX, SUFFIX, SHAPE] def forward(docs, drop=0.): feats = [] for doc in docs: @@ -441,27 +403,6 @@ def preprocess_doc(docs, drop=0.): vals = ops.allocate(keys.shape[0]) + 1 return (keys, vals, lengths), None -def getitem(i): - def getitem_fwd(X, drop=0.): - return X[i], None - return layerize(getitem_fwd) - -def build_tagger_model(nr_class, token_vector_width, **cfg): - embed_size = util.env_opt('embed_size', 7500) - with Model.define_operators({'>>': chain, '+': add}): - # Input: (doc, tensor) tuples - private_tok2vec = Tok2Vec(token_vector_width, embed_size, preprocess=doc2feats()) - - model = ( - fine_tune(private_tok2vec) - >> with_flatten( - Maxout(token_vector_width, token_vector_width) - >> Softmax(nr_class, token_vector_width) - ) - ) - model.nI = None - return model - def build_text_classifier(nr_class, width=64, **cfg): nr_vector = cfg.get('nr_vector', 200) @@ -476,7 +417,7 @@ def build_text_classifier(nr_class, width=64, **cfg): >> _flatten_add_lengths >> with_getitem(0, uniqued( - (embed_lower | embed_prefix | embed_suffix | embed_shape) + (embed_lower | embed_prefix | embed_suffix | embed_shape) >> Maxout(width, width+(width//2)*3)) >> Residual(ExtractWindow(nW=1) >> ReLu(width, width*3)) >> Residual(ExtractWindow(nW=1) >> ReLu(width, width*3)) @@ -497,7 +438,7 @@ def build_text_classifier(nr_class, width=64, **cfg): >> zero_init(Affine(nr_class, nr_class*2, drop_factor=0.0)) >> logistic ) - + model.lsuv = False return model diff --git a/spacy/about.py b/spacy/about.py index bf44c31d5..9f62c769e 100644 --- a/spacy/about.py +++ b/spacy/about.py @@ -3,7 +3,7 @@ # https://github.com/pypa/warehouse/blob/master/warehouse/__about__.py __title__ = 'spacy-nightly' -__version__ = '2.0.0a7' +__version__ = '2.0.0a9' __summary__ = 'Industrial-strength Natural Language Processing (NLP) with Python and Cython' __uri__ = 'https://spacy.io' __author__ = 'Explosion AI' diff --git a/spacy/cli/convert.py b/spacy/cli/convert.py index fef6753e6..a0a76e5ec 100644 --- a/spacy/cli/convert.py +++ b/spacy/cli/convert.py @@ -21,10 +21,10 @@ CONVERTERS = { @plac.annotations( input_file=("input file", "positional", None, str), output_dir=("output directory for converted file", "positional", None, str), - n_sents=("Number of sentences per doc", "option", "n", int), + n_sents=("Number of sentences per doc", "option", "n", float), morphology=("Enable appending morphology to tags", "flag", "m", bool) ) -def convert(cmd, input_file, output_dir, n_sents=1, morphology=False): +def convert(cmd, input_file, output_dir, n_sents, morphology): """ Convert files into JSON format for use with train command and other experiment management functions. diff --git a/spacy/cli/train.py b/spacy/cli/train.py index 04aac8319..af028dae5 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -91,14 +91,15 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0, for batch in minibatch(train_docs, size=batch_sizes): docs, golds = zip(*batch) nlp.update(docs, golds, sgd=optimizer, - drop=next(dropout_rates), losses=losses, - update_tensors=True) + drop=next(dropout_rates), losses=losses) pbar.update(sum(len(doc) for doc in docs)) with nlp.use_params(optimizer.averages): util.set_env_log(False) epoch_model_path = output_path / ('model%d' % i) nlp.to_disk(epoch_model_path) + with (output_path / ('model%d.pickle' % i)).open('wb') as file_: + dill.dump(nlp, file_, -1) nlp_loaded = lang_class(pipeline=pipeline) nlp_loaded = nlp_loaded.from_disk(epoch_model_path) scorer = nlp_loaded.evaluate( diff --git a/spacy/deprecated.py b/spacy/deprecated.py index 77273d193..ad52bfe24 100644 --- a/spacy/deprecated.py +++ b/spacy/deprecated.py @@ -15,7 +15,7 @@ def depr_model_download(lang): lang (unicode): Language shortcut, 'en' or 'de'. """ prints("The spacy.%s.download command is now deprecated. Please use " - "python -m spacy download [model name or shortcut] instead. For " + "spacy download [model name or shortcut] instead. For " "more info, see the documentation:" % lang, about.__docs_models__, "Downloading default '%s' model now..." % lang, diff --git a/spacy/language.py b/spacy/language.py index cb679a2bc..0284c4636 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -277,8 +277,7 @@ class Language(object): def make_doc(self, text): return self.tokenizer(text) - def update(self, docs, golds, drop=0., sgd=None, losses=None, - update_tensors=False): + def update(self, docs, golds, drop=0., sgd=None, losses=None): """Update the models in the pipeline. docs (iterable): A batch of `Doc` objects. @@ -305,17 +304,14 @@ class Language(object): grads[key] = (W, dW) pipes = list(self.pipeline[1:]) random.shuffle(pipes) - tokvecses, bp_tokvecses = tok2vec.model.begin_update(feats, drop=drop) - all_d_tokvecses = [tok2vec.model.ops.allocate(tv.shape) for tv in tokvecses] for proc in pipes: if not hasattr(proc, 'update'): continue + tokvecses, bp_tokvecses = tok2vec.model.begin_update(feats, drop=drop) d_tokvecses = proc.update((docs, tokvecses), golds, drop=drop, sgd=get_grads, losses=losses) - if update_tensors and d_tokvecses is not None: - for i, d_tv in enumerate(d_tokvecses): - all_d_tokvecses[i] += d_tv - bp_tokvecses(all_d_tokvecses, sgd=sgd) + if d_tokvecses is not None: + bp_tokvecses(d_tokvecses, sgd=sgd) for key, (W, dW) in grads.items(): sgd(W, dW, key=key) # Clear the tensor variable, to free GPU memory. @@ -385,18 +381,9 @@ class Language(object): return optimizer def evaluate(self, docs_golds): - scorer = Scorer() docs, golds = zip(*docs_golds) - docs = list(docs) - golds = list(golds) - for pipe in self.pipeline: - if not hasattr(pipe, 'pipe'): - for doc in docs: - pipe(doc) - else: - docs = list(pipe.pipe(docs)) - assert len(docs) == len(golds) - for doc, gold in zip(docs, golds): + scorer = Scorer() + for doc, gold in zip(self.pipe(docs, batch_size=32), golds): scorer.score(doc, gold) doc.tensor = None return scorer diff --git a/spacy/pipeline.pyx b/spacy/pipeline.pyx index 634d3e4b5..947f0a1f1 100644 --- a/spacy/pipeline.pyx +++ b/spacy/pipeline.pyx @@ -42,7 +42,7 @@ from .compat import json_dumps from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP, POS from ._ml import rebatch, Tok2Vec, flatten, get_col, doc2feats -from ._ml import build_text_classifier, build_tagger_model +from ._ml import build_text_classifier from .parts_of_speech import X @@ -138,7 +138,7 @@ class TokenVectorEncoder(BaseThincComponent): name = 'tensorizer' @classmethod - def Model(cls, width=128, embed_size=4000, **cfg): + def Model(cls, width=128, embed_size=7500, **cfg): """Create a new statistical model for the class. width (int): Output size of the model. @@ -253,25 +253,23 @@ class NeuralTagger(BaseThincComponent): self.cfg = dict(cfg) def __call__(self, doc): - tags = self.predict(([doc], [doc.tensor])) + tags = self.predict([doc.tensor]) self.set_annotations([doc], tags) return doc def pipe(self, stream, batch_size=128, n_threads=-1): for docs in cytoolz.partition_all(batch_size, stream): - docs = list(docs) tokvecs = [d.tensor for d in docs] - tag_ids = self.predict((docs, tokvecs)) + tag_ids = self.predict(tokvecs) self.set_annotations(docs, tag_ids) yield from docs - def predict(self, docs_tokvecs): - scores = self.model(docs_tokvecs) + def predict(self, tokvecs): + scores = self.model(tokvecs) scores = self.model.ops.flatten(scores) guesses = scores.argmax(axis=1) if not isinstance(guesses, numpy.ndarray): guesses = guesses.get() - tokvecs = docs_tokvecs[1] guesses = self.model.ops.unflatten(guesses, [tv.shape[0] for tv in tokvecs]) return guesses @@ -284,8 +282,6 @@ class NeuralTagger(BaseThincComponent): cdef Vocab vocab = self.vocab for i, doc in enumerate(docs): doc_tag_ids = batch_tag_ids[i] - if hasattr(doc_tag_ids, 'get'): - doc_tag_ids = doc_tag_ids.get() for j, tag_id in enumerate(doc_tag_ids): # Don't clobber preset POS tags if doc.c[j].tag == 0 and doc.c[j].pos == 0: @@ -298,7 +294,8 @@ class NeuralTagger(BaseThincComponent): if self.model.nI is None: self.model.nI = tokvecs[0].shape[1] - tag_scores, bp_tag_scores = self.model.begin_update(docs_tokvecs, drop=drop) + + tag_scores, bp_tag_scores = self.model.begin_update(tokvecs, drop=drop) loss, d_tag_scores = self.get_loss(docs, golds, tag_scores) d_tokvecs = bp_tag_scores(d_tag_scores, sgd=sgd) @@ -349,8 +346,10 @@ class NeuralTagger(BaseThincComponent): @classmethod def Model(cls, n_tags, token_vector_width): - return build_tagger_model(n_tags, token_vector_width) - + return with_flatten( + chain(Maxout(token_vector_width, token_vector_width), + Softmax(n_tags, token_vector_width))) + def use_params(self, params): with self.model.use_params(params): yield @@ -433,7 +432,7 @@ class NeuralLabeller(NeuralTagger): @property def labels(self): - return self.cfg.setdefault('labels', {}) + return self.cfg.get('labels', {}) @labels.setter def labels(self, value): @@ -456,8 +455,10 @@ class NeuralLabeller(NeuralTagger): @classmethod def Model(cls, n_tags, token_vector_width): - return build_tagger_model(n_tags, token_vector_width) - + return with_flatten( + chain(Maxout(token_vector_width, token_vector_width), + Softmax(n_tags, token_vector_width))) + def get_loss(self, docs, golds, scores): scores = self.model.ops.flatten(scores) cdef int idx = 0 diff --git a/spacy/syntax/_beam_utils.pyx b/spacy/syntax/_beam_utils.pyx deleted file mode 100644 index e77036e55..000000000 --- a/spacy/syntax/_beam_utils.pyx +++ /dev/null @@ -1,273 +0,0 @@ -# cython: infer_types=True -# cython: profile=True -cimport numpy as np -import numpy -from cpython.ref cimport PyObject, Py_INCREF, Py_XDECREF -from thinc.extra.search cimport Beam -from thinc.extra.search import MaxViolation -from thinc.typedefs cimport hash_t, class_t - -from .transition_system cimport TransitionSystem, Transition -from .stateclass cimport StateClass -from ..gold cimport GoldParse -from ..tokens.doc cimport Doc - - -# These are passed as callbacks to thinc.search.Beam -cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1: - dest = _dest - src = _src - moves = _moves - dest.clone(src) - moves[clas].do(dest.c, moves[clas].label) - - -cdef int _check_final_state(void* _state, void* extra_args) except -1: - return (_state).is_final() - - -def _cleanup(Beam beam): - for i in range(beam.width): - Py_XDECREF(beam._states[i].content) - Py_XDECREF(beam._parents[i].content) - - -cdef hash_t _hash_state(void* _state, void* _) except 0: - state = _state - if state.c.is_final(): - return 1 - else: - return state.c.hash() - - -cdef class ParserBeam(object): - cdef public TransitionSystem moves - cdef public object states - cdef public object golds - cdef public object beams - - def __init__(self, TransitionSystem moves, states, golds, - int width=4, float density=0.001): - self.moves = moves - self.states = states - self.golds = golds - self.beams = [] - cdef Beam beam - cdef StateClass state, st - for state in states: - beam = Beam(self.moves.n_moves, width, density) - beam.initialize(self.moves.init_beam_state, state.c.length, state.c._sent) - for i in range(beam.width): - st = beam.at(i) - st.c.offset = state.c.offset - self.beams.append(beam) - - def __dealloc__(self): - if self.beams is not None: - for beam in self.beams: - if beam is not None: - _cleanup(beam) - - @property - def is_done(self): - return all(b.is_done for b in self.beams) - - def __getitem__(self, i): - return self.beams[i] - - def __len__(self): - return len(self.beams) - - def advance(self, scores, follow_gold=False): - cdef Beam beam - for i, beam in enumerate(self.beams): - if beam.is_done or not scores[i].size: - continue - self._set_scores(beam, scores[i]) - if self.golds is not None: - self._set_costs(beam, self.golds[i], follow_gold=follow_gold) - if follow_gold: - assert self.golds is not None - beam.advance(_transition_state, NULL, self.moves.c) - else: - beam.advance(_transition_state, _hash_state, self.moves.c) - beam.check_done(_check_final_state, NULL) - if beam.is_done: - for j in range(beam.size): - if is_gold(beam.at(j), self.golds[i], self.moves.strings): - beam._states[j].loss = 0.0 - elif beam._states[j].loss == 0.0: - beam._states[j].loss = 1.0 - - def _set_scores(self, Beam beam, float[:, ::1] scores): - cdef float* c_scores = &scores[0, 0] - for i in range(beam.size): - state = beam.at(i) - if not state.is_final(): - for j in range(beam.nr_class): - beam.scores[i][j] = c_scores[i * beam.nr_class + j] - self.moves.set_valid(beam.is_valid[i], state.c) - - def _set_costs(self, Beam beam, GoldParse gold, int follow_gold=False): - for i in range(beam.size): - state = beam.at(i) - if not state.c.is_final(): - self.moves.set_costs(beam.is_valid[i], beam.costs[i], state, gold) - if follow_gold: - for j in range(beam.nr_class): - if beam.costs[i][j] >= 1: - beam.is_valid[i][j] = 0 - - -def is_gold(StateClass state, GoldParse gold, strings): - predicted = set() - truth = set() - for i in range(gold.length): - if gold.cand_to_gold[i] is None: - continue - if state.safe_get(i).dep: - predicted.add((i, state.H(i), strings[state.safe_get(i).dep])) - else: - predicted.add((i, state.H(i), 'ROOT')) - id_, word, tag, head, dep, ner = gold.orig_annot[gold.cand_to_gold[i]] - truth.add((id_, head, dep)) - return truth == predicted - - -def get_token_ids(states, int n_tokens): - cdef StateClass state - cdef np.ndarray ids = numpy.zeros((len(states), n_tokens), - dtype='int32', order='C') - c_ids = ids.data - for i, state in enumerate(states): - if not state.is_final(): - state.c.set_context_tokens(c_ids, n_tokens) - else: - ids[i] = -1 - c_ids += ids.shape[1] - return ids - -nr_update = 0 -def update_beam(TransitionSystem moves, int nr_feature, int max_steps, - states, tokvecs, golds, - state2vec, vec2scores, drop=0., sgd=None, - losses=None, int width=4, float density=0.001): - global nr_update - nr_update += 1 - pbeam = ParserBeam(moves, states, golds, - width=width, density=density) - gbeam = ParserBeam(moves, states, golds, - width=width, density=0.0) - cdef StateClass state - beam_maps = [] - backprops = [] - violns = [MaxViolation() for _ in range(len(states))] - for t in range(max_steps): - # The beam maps let us find the right row in the flattened scores - # arrays for each state. States are identified by (example id, history). - # We keep a different beam map for each step (since we'll have a flat - # scores array for each step). The beam map will let us take the per-state - # losses, and compute the gradient for each (step, state, class). - beam_maps.append({}) - # Gather all states from the two beams in a list. Some stats may occur - # in both beams. To figure out which beam each state belonged to, - # we keep two lists of indices, p_indices and g_indices - states, p_indices, g_indices = get_states(pbeam, gbeam, beam_maps[-1], nr_update) - if not states: - break - # Now that we have our flat list of states, feed them through the model - token_ids = get_token_ids(states, nr_feature) - vectors, bp_vectors = state2vec.begin_update(token_ids, drop=drop) - scores, bp_scores = vec2scores.begin_update(vectors, drop=drop) - - # Store the callbacks for the backward pass - backprops.append((token_ids, bp_vectors, bp_scores)) - - # Unpack the flat scores into lists for the two beams. The indices arrays - # tell us which example and state the scores-row refers to. - p_scores = [numpy.ascontiguousarray(scores[indices], dtype='f') for indices in p_indices] - g_scores = [numpy.ascontiguousarray(scores[indices], dtype='f') for indices in g_indices] - # Now advance the states in the beams. The gold beam is contrained to - # to follow only gold analyses. - pbeam.advance(p_scores) - gbeam.advance(g_scores, follow_gold=True) - # Track the "maximum violation", to use in the update. - for i, violn in enumerate(violns): - violn.check_crf(pbeam[i], gbeam[i]) - - # Only make updates if we have non-gold states - histories = [((v.p_hist + v.g_hist) if v.p_hist else []) for v in violns] - losses = [((v.p_probs + v.g_probs) if v.p_probs else []) for v in violns] - states_d_scores = get_gradient(moves.n_moves, beam_maps, - histories, losses) - assert len(states_d_scores) == len(backprops), (len(states_d_scores), len(backprops)) - return states_d_scores, backprops - - -def get_states(pbeams, gbeams, beam_map, nr_update): - seen = {} - states = [] - p_indices = [] - g_indices = [] - cdef Beam pbeam, gbeam - assert len(pbeams) == len(gbeams) - for eg_id, (pbeam, gbeam) in enumerate(zip(pbeams, gbeams)): - p_indices.append([]) - g_indices.append([]) - if pbeam.loss > 0 and pbeam.min_score > gbeam.score: - continue - for i in range(pbeam.size): - state = pbeam.at(i) - if not state.is_final(): - key = tuple([eg_id] + pbeam.histories[i]) - seen[key] = len(states) - p_indices[-1].append(len(states)) - states.append(state) - beam_map.update(seen) - for i in range(gbeam.size): - state = gbeam.at(i) - if not state.is_final(): - key = tuple([eg_id] + gbeam.histories[i]) - if key in seen: - g_indices[-1].append(seen[key]) - else: - g_indices[-1].append(len(states)) - beam_map[key] = len(states) - states.append(state) - p_idx = [numpy.asarray(idx, dtype='i') for idx in p_indices] - g_idx = [numpy.asarray(idx, dtype='i') for idx in g_indices] - return states, p_idx, g_idx - - -def get_gradient(nr_class, beam_maps, histories, losses): - """ - The global model assigns a loss to each parse. The beam scores - are additive, so the same gradient is applied to each action - in the history. This gives the gradient of a single *action* - for a beam state -- so we have "the gradient of loss for taking - action i given history H." - - Histories: Each hitory is a list of actions - Each candidate has a history - Each beam has multiple candidates - Each batch has multiple beams - So history is list of lists of lists of ints - """ - nr_step = len(beam_maps) - grads = [] - for beam_map in beam_maps: - if beam_map: - grads.append(numpy.zeros((max(beam_map.values())+1, nr_class), dtype='f')) - assert len(histories) == len(losses) - for eg_id, hists in enumerate(histories): - for loss, hist in zip(losses[eg_id], hists): - key = tuple([eg_id]) - for j, clas in enumerate(hist): - i = beam_maps[j][key] - # In step j, at state i action clas - # resulted in loss - grads[j][i, clas] += loss / len(histories) - key = key + tuple([clas]) - return grads - - diff --git a/spacy/syntax/_state.pxd b/spacy/syntax/_state.pxd index 9aeeba441..c06851978 100644 --- a/spacy/syntax/_state.pxd +++ b/spacy/syntax/_state.pxd @@ -37,7 +37,6 @@ cdef cppclass StateC: this.shifted = calloc(length + (PADDING * 2), sizeof(bint)) this._sent = calloc(length + (PADDING * 2), sizeof(TokenC)) this._ents = calloc(length + (PADDING * 2), sizeof(Entity)) - this.offset = 0 cdef int i for i in range(length + (PADDING * 2)): this._ents[i].end = -1 diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 9477449a5..29e8de0aa 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -385,7 +385,6 @@ cdef class ArcEager(TransitionSystem): for i in range(self.n_moves): if self.c[i].move == move and self.c[i].label == label: return self.c[i] - return Transition(clas=0, move=MISSING, label=0) def move_name(self, int move, attr_t label): label_str = self.strings[label] diff --git a/spacy/syntax/beam_parser.pyx b/spacy/syntax/beam_parser.pyx index f4f66f9fb..e96e28fcf 100644 --- a/spacy/syntax/beam_parser.pyx +++ b/spacy/syntax/beam_parser.pyx @@ -34,7 +34,6 @@ from ._parse_features cimport CONTEXT_SIZE from ._parse_features cimport fill_context from .stateclass cimport StateClass from .parser cimport Parser -from ._beam_utils import is_gold DEBUG = False @@ -238,3 +237,16 @@ def _check_train_integrity(Beam pred, Beam gold, GoldParse gold_parse, Transitio raise Exception("Gold parse is not gold-standard") +def is_gold(StateClass state, GoldParse gold, StringStore strings): + predicted = set() + truth = set() + for i in range(gold.length): + if gold.cand_to_gold[i] is None: + continue + if state.safe_get(i).dep: + predicted.add((i, state.H(i), strings[state.safe_get(i).dep])) + else: + predicted.add((i, state.H(i), 'ROOT')) + id_, word, tag, head, dep, ner = gold.orig_annot[gold.cand_to_gold[i]] + truth.add((id_, head, dep)) + return truth == predicted diff --git a/spacy/syntax/nn_parser.pxd b/spacy/syntax/nn_parser.pxd index 7ff4b9f9f..524718965 100644 --- a/spacy/syntax/nn_parser.pxd +++ b/spacy/syntax/nn_parser.pxd @@ -14,4 +14,8 @@ cdef class Parser: cdef readonly TransitionSystem moves cdef readonly object cfg + cdef void _parse_step(self, StateC* state, + const float* feat_weights, + int nr_class, int nr_feat, int nr_piece) nogil + #cdef int parseC(self, TokenC* tokens, int length, int nr_feat) nogil diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index a193c96a3..0b39e2216 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -37,17 +37,14 @@ from preshed.maps cimport MapStruct from preshed.maps cimport map_get from thinc.api import layerize, chain, noop, clone -from thinc.neural import Model, Affine, ReLu, Maxout -from thinc.neural._classes.selu import SELU -from thinc.neural._classes.layernorm import LayerNorm +from thinc.neural import Model, Affine, ELU, ReLu, Maxout from thinc.neural.ops import NumpyOps, CupyOps from thinc.neural.util import get_array_module from .. import util from ..util import get_async, get_cuda_stream from .._ml import zero_init, PrecomputableAffine, PrecomputableMaxouts -from .._ml import Tok2Vec, doc2feats, rebatch, fine_tune -from .._ml import Residual, drop_layer +from .._ml import Tok2Vec, doc2feats, rebatch from ..compat import json_dumps from . import _parse_features @@ -62,11 +59,8 @@ from ..structs cimport TokenC from ..tokens.doc cimport Doc from ..strings cimport StringStore from ..gold cimport GoldParse -from ..attrs cimport ID, TAG, DEP, ORTH, NORM, PREFIX, SUFFIX, TAG -from . import _beam_utils +from ..attrs cimport TAG, DEP -USE_FINE_TUNE = True -BEAM_PARSE = True def get_templates(*args, **kwargs): return [] @@ -238,14 +232,11 @@ cdef class Parser: Base class of the DependencyParser and EntityRecognizer. """ @classmethod - def Model(cls, nr_class, token_vector_width=128, hidden_width=300, depth=1, **cfg): + def Model(cls, nr_class, token_vector_width=128, hidden_width=128, depth=1, **cfg): depth = util.env_opt('parser_hidden_depth', depth) token_vector_width = util.env_opt('token_vector_width', token_vector_width) hidden_width = util.env_opt('hidden_width', hidden_width) parser_maxout_pieces = util.env_opt('parser_maxout_pieces', 2) - embed_size = util.env_opt('embed_size', 4000) - tensors = fine_tune(Tok2Vec(token_vector_width, embed_size, - preprocess=doc2feats())) if parser_maxout_pieces == 1: lower = PrecomputableAffine(hidden_width if depth >= 1 else nr_class, nF=cls.nr_feature, @@ -257,10 +248,15 @@ cdef class Parser: nI=token_vector_width) with Model.use_device('cpu'): - upper = chain( - clone(Residual(ReLu(hidden_width)), (depth-1)), - zero_init(Affine(nr_class, drop_factor=0.0)) - ) + if depth == 0: + upper = chain() + upper.is_noop = True + else: + upper = chain( + clone(Maxout(hidden_width), (depth-1)), + zero_init(Affine(nr_class, drop_factor=0.0)) + ) + upper.is_noop = False # TODO: This is an unfortunate hack atm! # Used to set input dimensions in network. lower.begin_training(lower.ops.allocate((500, token_vector_width))) @@ -272,7 +268,7 @@ cdef class Parser: 'hidden_width': hidden_width, 'maxout_pieces': parser_maxout_pieces } - return (tensors, lower, upper), cfg + return (lower, upper), cfg def __init__(self, Vocab vocab, moves=True, model=True, **cfg): """ @@ -348,21 +344,17 @@ cdef class Parser: The number of threads with which to work on the buffer in parallel. Yields (Doc): Documents, in order. """ - if BEAM_PARSE: - beam_width = 8 + cdef StateClass parse_state cdef Doc doc - cdef Beam beam + queue = [] for docs in cytoolz.partition_all(batch_size, docs): docs = list(docs) - tokvecs = [doc.tensor for doc in docs] + tokvecs = [d.tensor for d in docs] if beam_width == 1: parse_states = self.parse_batch(docs, tokvecs) else: - beams = self.beam_parse(docs, tokvecs, - beam_width=beam_width, beam_density=beam_density) - parse_states = [] - for beam in beams: - parse_states.append(beam.at(0)) + parse_states = self.beam_parse(docs, tokvecs, + beam_width=beam_width, beam_density=beam_density) self.set_annotations(docs, parse_states) yield from docs @@ -377,12 +369,8 @@ cdef class Parser: int nr_class, nr_feat, nr_piece, nr_dim, nr_state if isinstance(docs, Doc): docs = [docs] - if isinstance(tokvecses, np.ndarray): - tokvecses = [tokvecses] tokvecs = self.model[0].ops.flatten(tokvecses) - if USE_FINE_TUNE: - tokvecs += self.model[0].ops.flatten(self.model[0]((docs, tokvecses))) nr_state = len(docs) nr_class = self.moves.n_moves @@ -406,20 +394,27 @@ cdef class Parser: cdef np.ndarray scores c_token_ids = token_ids.data c_is_valid = is_valid.data + cdef int has_hidden = not getattr(vec2scores, 'is_noop', False) while not next_step.empty(): - for i in range(next_step.size()): - st = next_step[i] - st.set_context_tokens(&c_token_ids[i*nr_feat], nr_feat) - self.moves.set_valid(&c_is_valid[i*nr_class], st) + if not has_hidden: + for i in cython.parallel.prange( + next_step.size(), num_threads=6, nogil=True): + self._parse_step(next_step[i], + feat_weights, nr_class, nr_feat, nr_piece) + else: + for i in range(next_step.size()): + st = next_step[i] + st.set_context_tokens(&c_token_ids[i*nr_feat], nr_feat) + self.moves.set_valid(&c_is_valid[i*nr_class], st) vectors = state2vec(token_ids[:next_step.size()]) - scores = vec2scores(vectors) - c_scores = scores.data - for i in range(next_step.size()): - st = next_step[i] - guess = arg_max_if_valid( - &c_scores[i*nr_class], &c_is_valid[i*nr_class], nr_class) - action = self.moves.c[guess] - action.do(st, action.label) + scores = vec2scores(vectors) + c_scores = scores.data + for i in range(next_step.size()): + st = next_step[i] + guess = arg_max_if_valid( + &c_scores[i*nr_class], &c_is_valid[i*nr_class], nr_class) + action = self.moves.c[guess] + action.do(st, action.label) this_step, next_step = next_step, this_step next_step.clear() for st in this_step: @@ -434,15 +429,11 @@ cdef class Parser: cdef int nr_class = self.moves.n_moves cdef StateClass stcls, output tokvecs = self.model[0].ops.flatten(tokvecses) - if USE_FINE_TUNE: - tokvecs += self.model[0].ops.flatten(self.model[0]((docs, tokvecses))) cuda_stream = get_cuda_stream() state2vec, vec2scores = self.get_batch_model(len(docs), tokvecs, cuda_stream, 0.0) beams = [] cdef int offset = 0 - cdef int j = 0 - cdef int k for doc in docs: beam = Beam(nr_class, beam_width, min_density=beam_density) beam.initialize(self.moves.init_beam_state, doc.length, doc.c) @@ -455,31 +446,44 @@ cdef class Parser: states = [] for i in range(beam.size): stcls = beam.at(i) - # This way we avoid having to score finalized states - # We do have to take care to keep indexes aligned, though - if not stcls.is_final(): - states.append(stcls) + states.append(stcls) token_ids = self.get_token_ids(states) vectors = state2vec(token_ids) scores = vec2scores(vectors) - j = 0 - c_scores = scores.data for i in range(beam.size): stcls = beam.at(i) if not stcls.is_final(): self.moves.set_valid(beam.is_valid[i], stcls.c) - for k in range(nr_class): - beam.scores[i][k] = c_scores[j * scores.shape[1] + k] - j += 1 + for j in range(nr_class): + beam.scores[i][j] = scores[i, j] beam.advance(_transition_state, _hash_state, self.moves.c) beam.check_done(_check_final_state, NULL) beams.append(beam) return beams + cdef void _parse_step(self, StateC* state, + const float* feat_weights, + int nr_class, int nr_feat, int nr_piece) nogil: + '''This only works with no hidden layers -- fast but inaccurate''' + #for i in cython.parallel.prange(next_step.size(), num_threads=4, nogil=True): + # self._parse_step(next_step[i], feat_weights, nr_class, nr_feat) + token_ids = calloc(nr_feat, sizeof(int)) + scores = calloc(nr_class * nr_piece, sizeof(float)) + is_valid = calloc(nr_class, sizeof(int)) + + state.set_context_tokens(token_ids, nr_feat) + sum_state_features(scores, + feat_weights, token_ids, 1, nr_feat, nr_class * nr_piece) + self.moves.set_valid(is_valid, state) + guess = arg_maxout_if_valid(scores, is_valid, nr_class, nr_piece) + action = self.moves.c[guess] + action.do(state, action.label) + + free(is_valid) + free(scores) + free(token_ids) + def update(self, docs_tokvecs, golds, drop=0., sgd=None, losses=None): - if BEAM_PARSE: - return self.update_beam(docs_tokvecs, golds, drop=drop, sgd=sgd, - losses=losses) if losses is not None and self.name not in losses: losses[self.name] = 0. docs, tokvec_lists = docs_tokvecs @@ -487,10 +491,6 @@ cdef class Parser: if isinstance(docs, Doc) and isinstance(golds, GoldParse): docs = [docs] golds = [golds] - if USE_FINE_TUNE: - my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop) - my_tokvecs = self.model[0].ops.flatten(my_tokvecs) - tokvecs += my_tokvecs cuda_stream = get_cuda_stream() @@ -517,13 +517,13 @@ cdef class Parser: scores, bp_scores = vec2scores.begin_update(vector, drop=drop) d_scores = self.get_batch_loss(states, golds, scores) - d_vector = bp_scores(d_scores, sgd=sgd) + d_vector = bp_scores(d_scores / d_scores.shape[0], sgd=sgd) if drop != 0: d_vector *= mask if isinstance(self.model[0].ops, CupyOps) \ and not isinstance(token_ids, state2vec.ops.xp.ndarray): - # Move token_ids and d_vector to GPU, asynchronously + # Move token_ids and d_vector to CPU, asynchronously backprops.append(( get_async(cuda_stream, token_ids), get_async(cuda_stream, d_vector), @@ -540,55 +540,7 @@ cdef class Parser: break self._make_updates(d_tokvecs, backprops, sgd, cuda_stream) - d_tokvecs = self.model[0].ops.unflatten(d_tokvecs, [len(d) for d in docs]) - if USE_FINE_TUNE: - bp_my_tokvecs(d_tokvecs, sgd=sgd) - return d_tokvecs - - def update_beam(self, docs_tokvecs, golds, drop=0., sgd=None, losses=None): - if losses is not None and self.name not in losses: - losses[self.name] = 0. - docs, tokvecs = docs_tokvecs - lengths = [len(d) for d in docs] - assert min(lengths) >= 1 - tokvecs = self.model[0].ops.flatten(tokvecs) - if USE_FINE_TUNE: - my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop) - my_tokvecs = self.model[0].ops.flatten(my_tokvecs) - tokvecs += my_tokvecs - - states = self.moves.init_batch(docs) - for gold in golds: - self.moves.preprocess_gold(gold) - - cuda_stream = get_cuda_stream() - state2vec, vec2scores = self.get_batch_model(len(states), tokvecs, cuda_stream, 0.0) - - states_d_scores, backprops = _beam_utils.update_beam(self.moves, self.nr_feature, 500, - states, tokvecs, golds, - state2vec, vec2scores, - drop, sgd, losses, - width=8) - backprop_lower = [] - for i, d_scores in enumerate(states_d_scores): - if losses is not None: - losses[self.name] += (d_scores**2).sum() - ids, bp_vectors, bp_scores = backprops[i] - d_vector = bp_scores(d_scores, sgd=sgd) - if isinstance(self.model[0].ops, CupyOps) \ - and not isinstance(ids, state2vec.ops.xp.ndarray): - backprop_lower.append(( - get_async(cuda_stream, ids), - get_async(cuda_stream, d_vector), - bp_vectors)) - else: - backprop_lower.append((ids, d_vector, bp_vectors)) - d_tokvecs = self.model[0].ops.allocate(tokvecs.shape) - self._make_updates(d_tokvecs, backprop_lower, sgd, cuda_stream) - d_tokvecs = self.model[0].ops.unflatten(d_tokvecs, lengths) - if USE_FINE_TUNE: - bp_my_tokvecs(d_tokvecs, sgd=sgd) - return d_tokvecs + return self.model[0].ops.unflatten(d_tokvecs, [len(d) for d in docs]) def _init_gold_batch(self, whole_docs, whole_golds): """Make a square batch, of length equal to the shortest doc. A long @@ -633,10 +585,14 @@ cdef class Parser: xp = get_array_module(d_tokvecs) for ids, d_vector, bp_vector in backprops: d_state_features = bp_vector(d_vector, sgd=sgd) - mask = ids >= 0 - indices = xp.nonzero(mask) - self.model[0].ops.scatter_add(d_tokvecs, ids[indices], - d_state_features[indices]) + active_feats = ids * (ids >= 0) + active_feats = active_feats.reshape((ids.shape[0], ids.shape[1], 1)) + if hasattr(xp, 'scatter_add'): + xp.scatter_add(d_tokvecs, + ids, d_state_features * active_feats) + else: + xp.add.at(d_tokvecs, + ids, d_state_features * active_feats) @property def move_names(self): @@ -647,7 +603,7 @@ cdef class Parser: return names def get_batch_model(self, batch_size, tokvecs, stream, dropout): - _, lower, upper = self.model + lower, upper = self.model state2vec = precompute_hiddens(batch_size, tokvecs, lower, stream, drop=dropout) return state2vec, upper @@ -737,12 +693,10 @@ cdef class Parser: def to_disk(self, path, **exclude): serializers = { - 'tok2vec_model': lambda p: p.open('wb').write( - self.model[0].to_bytes()), 'lower_model': lambda p: p.open('wb').write( - self.model[1].to_bytes()), + self.model[0].to_bytes()), 'upper_model': lambda p: p.open('wb').write( - self.model[2].to_bytes()), + self.model[1].to_bytes()), 'vocab': lambda p: self.vocab.to_disk(p), 'moves': lambda p: self.moves.to_disk(p, strings=False), 'cfg': lambda p: p.open('w').write(json_dumps(self.cfg)) @@ -763,29 +717,24 @@ cdef class Parser: self.model, cfg = self.Model(**self.cfg) else: cfg = {} - with (path / 'tok2vec_model').open('rb') as file_: - bytes_data = file_.read() - self.model[0].from_bytes(bytes_data) with (path / 'lower_model').open('rb') as file_: bytes_data = file_.read() - self.model[1].from_bytes(bytes_data) + self.model[0].from_bytes(bytes_data) with (path / 'upper_model').open('rb') as file_: bytes_data = file_.read() - self.model[2].from_bytes(bytes_data) + self.model[1].from_bytes(bytes_data) self.cfg.update(cfg) return self def to_bytes(self, **exclude): serializers = OrderedDict(( - ('tok2vec_model', lambda: self.model[0].to_bytes()), - ('lower_model', lambda: self.model[1].to_bytes()), - ('upper_model', lambda: self.model[2].to_bytes()), + ('lower_model', lambda: self.model[0].to_bytes()), + ('upper_model', lambda: self.model[1].to_bytes()), ('vocab', lambda: self.vocab.to_bytes()), ('moves', lambda: self.moves.to_bytes(strings=False)), ('cfg', lambda: ujson.dumps(self.cfg)) )) if 'model' in exclude: - exclude['tok2vec_model'] = True exclude['lower_model'] = True exclude['upper_model'] = True exclude.pop('model') @@ -796,7 +745,6 @@ cdef class Parser: ('vocab', lambda b: self.vocab.from_bytes(b)), ('moves', lambda b: self.moves.from_bytes(b, strings=False)), ('cfg', lambda b: self.cfg.update(ujson.loads(b))), - ('tok2vec_model', lambda b: None), ('lower_model', lambda b: None), ('upper_model', lambda b: None) )) @@ -806,12 +754,10 @@ cdef class Parser: self.model, cfg = self.Model(self.moves.n_moves) else: cfg = {} - if 'tok2vec_model' in msg: - self.model[0].from_bytes(msg['tok2vec_model']) if 'lower_model' in msg: - self.model[1].from_bytes(msg['lower_model']) + self.model[0].from_bytes(msg['lower_model']) if 'upper_model' in msg: - self.model[2].from_bytes(msg['upper_model']) + self.model[1].from_bytes(msg['upper_model']) self.cfg.update(cfg) return self diff --git a/spacy/syntax/transition_system.pyx b/spacy/syntax/transition_system.pyx index d3f64f827..27b375bba 100644 --- a/spacy/syntax/transition_system.pyx +++ b/spacy/syntax/transition_system.pyx @@ -107,8 +107,6 @@ cdef class TransitionSystem: def is_valid(self, StateClass stcls, move_name): action = self.lookup_transition(move_name) - if action.move == 0: - return False return action.is_valid(stcls.c, action.label) cdef int set_valid(self, int* is_valid, const StateC* st) nogil: diff --git a/spacy/tests/parser/test_neural_parser.py b/spacy/tests/parser/test_neural_parser.py index 30a6367c8..42b55745f 100644 --- a/spacy/tests/parser/test_neural_parser.py +++ b/spacy/tests/parser/test_neural_parser.py @@ -78,16 +78,3 @@ def test_predict_doc_beam(parser, tok2vec, model, doc): parser(doc, beam_width=32, beam_density=0.001) for word in doc: print(word.text, word.head, word.dep_) - - -def test_update_doc_beam(parser, tok2vec, model, doc, gold): - parser.model = model - tokvecs, bp_tokvecs = tok2vec.begin_update([doc]) - d_tokvecs = parser.update_beam(([doc], tokvecs), [gold]) - assert d_tokvecs[0].shape == tokvecs[0].shape - def optimize(weights, gradient, key=None): - weights -= 0.001 * gradient - bp_tokvecs(d_tokvecs, sgd=optimize) - assert d_tokvecs[0].sum() == 0. - - diff --git a/spacy/tests/parser/test_nn_beam.py b/spacy/tests/parser/test_nn_beam.py deleted file mode 100644 index 45c85d969..000000000 --- a/spacy/tests/parser/test_nn_beam.py +++ /dev/null @@ -1,87 +0,0 @@ -from __future__ import unicode_literals -import pytest -import numpy -from thinc.api import layerize - -from ...vocab import Vocab -from ...syntax.arc_eager import ArcEager -from ...tokens import Doc -from ...gold import GoldParse -from ...syntax._beam_utils import ParserBeam, update_beam -from ...syntax.stateclass import StateClass - - -@pytest.fixture -def vocab(): - return Vocab() - -@pytest.fixture -def moves(vocab): - aeager = ArcEager(vocab.strings, {}) - aeager.add_action(2, 'nsubj') - aeager.add_action(3, 'dobj') - aeager.add_action(2, 'aux') - return aeager - - -@pytest.fixture -def docs(vocab): - return [Doc(vocab, words=['Rats', 'bite', 'things'])] - -@pytest.fixture -def states(docs): - return [StateClass(doc) for doc in docs] - -@pytest.fixture -def tokvecs(docs, vector_size): - output = [] - for doc in docs: - vec = numpy.random.uniform(-0.1, 0.1, (len(doc), vector_size)) - output.append(numpy.asarray(vec)) - return output - - -@pytest.fixture -def golds(docs): - return [GoldParse(doc) for doc in docs] - - -@pytest.fixture -def batch_size(docs): - return len(docs) - - -@pytest.fixture -def beam_width(): - return 4 - - -@pytest.fixture -def vector_size(): - return 6 - - -@pytest.fixture -def beam(moves, states, golds, beam_width): - return ParserBeam(moves, states, golds, width=beam_width) - -@pytest.fixture -def scores(moves, batch_size, beam_width): - return [ - numpy.asarray( - numpy.random.uniform(-0.1, 0.1, (batch_size, moves.n_moves)), - dtype='f') - for _ in range(batch_size)] - - -def test_create_beam(beam): - pass - - -def test_beam_advance(beam, scores): - beam.advance(scores) - - -def test_beam_advance_too_few_scores(beam, scores): - with pytest.raises(IndexError): - beam.advance(scores[:-1]) diff --git a/spacy/util.py b/spacy/util.py index ccb81fbed..d83fe3416 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -113,7 +113,7 @@ def load_model(name, **overrides): def load_model_from_link(name, **overrides): """Load a model from a shortcut link, or directory in spaCy data path.""" init_file = get_data_path() / name / '__init__.py' - spec = importlib.util.spec_from_file_location(name, init_file) + spec = importlib.util.spec_from_file_location(name, str(init_file)) try: cls = importlib.util.module_from_spec(spec) except AttributeError: diff --git a/website/_includes/_mixins.jade b/website/_includes/_mixins.jade index 16514bcda..b140151b2 100644 --- a/website/_includes/_mixins.jade +++ b/website/_includes/_mixins.jade @@ -103,20 +103,20 @@ mixin button(url, trusted, ...style) label - [string] aside title (optional or false for no label) language - [string] language for syntax highlighting (default: "python") supports basic relevant languages available for PrismJS - icon - [string] icon to display next to code block, mostly used for old/new + prompt - [string] prompt or icon to display next to code block, (mostly used for old/new) height - [integer] optional height to clip code block to -mixin code(label, language, icon, height) +mixin code(label, language, prompt, height) pre.c-code-block.o-block(class="lang-#{(language || DEFAULT_SYNTAX)}" class=icon ? "c-code-block--has-icon" : null style=height ? "height: #{height}px" : null)&attributes(attributes) if label h4.u-text-label.u-text-label--dark=label - + - var icon = (prompt == 'accept' || prompt == 'reject') if icon - var classes = {'accept': 'u-color-green', 'reject': 'u-color-red'} .c-code-block__icon(class=classes[icon] || null class=classes[icon] ? "c-code-block__icon--border" : null) +icon(icon, 18) - code.c-code-block__content + code.c-code-block__content(data-prompt=icon ? null : prompt) block diff --git a/website/assets/css/_components/_code.sass b/website/assets/css/_components/_code.sass index 2e1856c0a..036c5358f 100644 --- a/website/assets/css/_components/_code.sass +++ b/website/assets/css/_components/_code.sass @@ -35,6 +35,13 @@ font: normal normal 1.1rem/#{2} $font-code padding: 1em 2em + &[data-prompt]:before, + content: attr(data-prompt) + margin-right: 0.65em + display: inline-block + vertical-align: middle + opacity: 0.5 + //- Inline code diff --git a/website/docs/api/cli.jade b/website/docs/api/cli.jade index e109e4b66..26aa1f883 100644 --- a/website/docs/api/cli.jade +++ b/website/docs/api/cli.jade @@ -5,16 +5,7 @@ include ../../_includes/_mixins p | As of v1.7.0, spaCy comes with new command line helpers to download and | link models and show useful debugging information. For a list of available - | commands, type #[code python -m spacy]. To make the command even more - | convenient, we recommend - | #[+a("https://askubuntu.com/questions/17536/how-do-i-create-a-permanent-bash-alias/17537#17537") creating an alias] - | mapping #[code python -m spacy] to #[code spacy]. - -+aside("Why python -m?") - | The problem with a global entry point is that it's resolved by looking up - | entries in your #[code PATH] environment variable. This can give you - | unexpected results, like executing the wrong spaCy installation. - | #[code python -m] prevents fallbacks to system modules. + | commands, type #[code spacy --help]. +infobox("⚠️ Deprecation note") | As of spaCy 2.0, the #[code model] command to initialise a model data @@ -33,8 +24,8 @@ p | Direct downloads don't perform any compatibility checks and require the | model name to be specified with its version (e.g., #[code en_core_web_sm-1.2.0]). -+code(false, "bash"). - python -m spacy download [model] [--direct] ++code(false, "bash", "$"). + spacy download [model] [--direct] +table(["Argument", "Type", "Description"]) +row @@ -80,8 +71,8 @@ p | or use the #[+api("cli#package") #[code package]] command to create a | model package. -+code(false, "bash"). - python -m spacy link [origin] [link_name] [--force] ++code(false, "bash", "$"). + spacy link [origin] [link_name] [--force] +table(["Argument", "Type", "Description"]) +row @@ -112,8 +103,8 @@ p | markup to copy-paste into #[+a(gh("spacy") + "/issues") GitHub issues]. +code(false, "bash"). - python -m spacy info [--markdown] - python -m spacy info [model] [--markdown] + spacy info [--markdown] + spacy info [model] [--markdown] +table(["Argument", "Type", "Description"]) +row @@ -139,8 +130,8 @@ p | functions. The right converter is chosen based on the file extension of | the input file. Currently only supports #[code .conllu]. -+code(false, "bash"). - python -m spacy convert [input_file] [output_dir] [--n-sents] [--morphology] ++code(false, "bash", "$"). + spacy convert [input_file] [output_dir] [--n-sents] [--morphology] +table(["Argument", "Type", "Description"]) +row @@ -174,8 +165,8 @@ p | Train a model. Expects data in spaCy's | #[+a("/docs/api/annotation#json-input") JSON format]. -+code(false, "bash"). - python -m spacy train [lang] [output_dir] [train_data] [dev_data] [--n-iter] [--n-sents] [--use-gpu] [--no-tagger] [--no-parser] [--no-entities] ++code(false, "bash", "$"). + spacy train [lang] [output_dir] [train_data] [dev_data] [--n-iter] [--n-sents] [--use-gpu] [--no-tagger] [--no-parser] [--no-entities] +table(["Argument", "Type", "Description"]) +row @@ -345,8 +336,8 @@ p | sure you're always using the latest versions. This means you need to be | connected to the internet to use this command. -+code(false, "bash"). - python -m spacy package [input_dir] [output_dir] [--meta] [--force] ++code(false, "bash", "$"). + spacy package [input_dir] [output_dir] [--meta] [--force] +table(["Argument", "Type", "Description"]) +row @@ -360,10 +351,17 @@ p +cell Directory to create package folder in. +row - +cell #[code meta] + +cell #[code --meta-path], #[code -m] +cell option +cell Path to meta.json file (optional). + +row + +cell #[code --create-meta], #[code -c] + +cell flag + +cell + | Create a meta.json file on the command line, even if one already + | exists in the directory. + +row +cell #[code --force], #[code -f] +cell flag diff --git a/website/docs/api/language-models.jade b/website/docs/api/language-models.jade index 74007f228..c6943b410 100644 --- a/website/docs/api/language-models.jade +++ b/website/docs/api/language-models.jade @@ -8,9 +8,9 @@ p +aside-code("Download language models", "bash"). - python -m spacy download en - python -m spacy download de - python -m spacy download fr + spacy download en + spacy download de + spacy download fr +table([ "Language", "Token", "SBD", "Lemma", "POS", "NER", "Dep", "Vector", "Sentiment"]) +row diff --git a/website/docs/usage/adding-languages.jade b/website/docs/usage/adding-languages.jade index 4cd65a62d..b341c9f9b 100644 --- a/website/docs/usage/adding-languages.jade +++ b/website/docs/usage/adding-languages.jade @@ -789,4 +789,4 @@ p | model use the using spaCy's #[+api("cli#train") #[code train]] command: +code(false, "bash"). - python -m spacy train [lang] [output_dir] [train_data] [dev_data] [--n-iter] [--n-sents] [--use-gpu] [--no-tagger] [--no-parser] [--no-entities] + spacy train [lang] [output_dir] [train_data] [dev_data] [--n-iter] [--n-sents] [--use-gpu] [--no-tagger] [--no-parser] [--no-entities] diff --git a/website/docs/usage/index.jade b/website/docs/usage/index.jade index 817b08ba9..60bc3cd7b 100644 --- a/website/docs/usage/index.jade +++ b/website/docs/usage/index.jade @@ -32,10 +32,10 @@ p +qs({package: 'source'}) pip install -r requirements.txt +qs({package: 'source'}) pip install -e . - +qs({model: 'en'}) python -m spacy download en - +qs({model: 'de'}) python -m spacy download de - +qs({model: 'fr'}) python -m spacy download fr - +qs({model: 'es'}) python -m spacy download es + +qs({model: 'en'}) spacy download en + +qs({model: 'de'}) spacy download de + +qs({model: 'fr'}) spacy download fr + +qs({model: 'es'}) spacy download es +h(2, "installation") Installation instructions @@ -52,7 +52,7 @@ p Using pip, spaCy releases are currently only available as source packages. | and available models, see the #[+a("/docs/usage/models") docs on models]. +code.o-no-block. - python -m spacy download en + spacy download en >>> import spacy >>> nlp = spacy.load('en') @@ -312,7 +312,9 @@ p | This error may occur when running the #[code spacy] command from the | command line. spaCy does not currently add an entry to our #[code PATH] | environment variable, as this can lead to unexpected results, especially - | when using #[code virtualenv]. Run the command with #[code python -m], + | when using #[code virtualenv]. Instead, spaCy adds an auto-alias that + | maps #[code spacy] to #[code python -m spacy]. If this is not working as + | expected, run the command with #[code python -m], yourself – | for example #[code python -m spacy download en]. For more info on this, | see #[+api("cli#download") download]. diff --git a/website/docs/usage/lightning-tour.jade b/website/docs/usage/lightning-tour.jade index 0be3a55be..2b0cf0880 100644 --- a/website/docs/usage/lightning-tour.jade +++ b/website/docs/usage/lightning-tour.jade @@ -10,8 +10,8 @@ p +h(2, "models") Install models and process text +code(false, "bash"). - python -m spacy download en - python -m spacy download de + spacy download en + spacy download de +code. import spacy diff --git a/website/docs/usage/models.jade b/website/docs/usage/models.jade index 39c37a816..bae80d2ad 100644 --- a/website/docs/usage/models.jade +++ b/website/docs/usage/models.jade @@ -20,7 +20,7 @@ p +quickstart(QUICKSTART_MODELS, "Quickstart", "Install a default model, get the code to load it from within spaCy and an example to test it. For more options, see the section on available models below.") for models, lang in MODELS - var package = (models.length == 1) ? models[0] : models.find(function(m) { return m.def }) - +qs({lang: lang}) python -m spacy download #{lang} + +qs({lang: lang}) spacy download #{lang} +qs({lang: lang}, "divider") +qs({lang: lang, load: "module"}, "python") import #{package.id} +qs({lang: lang, load: "module"}, "python") nlp = #{package.id}.load() @@ -52,16 +52,16 @@ p | #[+api("cli#download") #[code download]] command. It takes care of | finding the best-matching model compatible with your spaCy installation. -- var models = Object.keys(MODELS).map(function(lang) { return "python -m spacy download " + lang }) +- var models = Object.keys(MODELS).map(function(lang) { return "spacy download " + lang }) +code(false, "bash"). # out-of-the-box: download best-matching default model - #{Object.keys(MODELS).map(function(l) {return "python -m spacy download " + l}).join('\n')} + #{Object.keys(MODELS).map(function(l) {return "spacy download " + l}).join('\n')} # download best-matching version of specific model for your spaCy installation - python -m spacy download en_core_web_md + spacy download en_core_web_md # download exact model version (doesn't create shortcut link) - python -m spacy download en_core_web_md-1.2.0 --direct + spacy download en_core_web_md-1.2.0 --direct p | The download command will #[+a("#download-pip") install the model] via @@ -72,7 +72,7 @@ p +code(false, "bash"). pip install spacy - python -m spacy download en + spacy download en +code. import spacy @@ -179,8 +179,8 @@ p | model names or IDs. And your system already comes with a native solution | to mapping unicode aliases to file paths: symbolic links. -+code(false, "bash"). - python -m spacy link [package name or path] [shortcut] [--force] ++code(false, "bash", "$"). + spacy link [package name or path] [shortcut] [--force] p | The first argument is the #[strong package name] (if the model was diff --git a/website/docs/usage/saving-loading.jade b/website/docs/usage/saving-loading.jade index 827b54748..de7e4ed33 100644 --- a/website/docs/usage/saving-loading.jade +++ b/website/docs/usage/saving-loading.jade @@ -85,7 +85,7 @@ p } +code(false, "bash"). - python -m spacy package /home/me/data/en_example_model /home/me/my_models + spacy package /home/me/data/en_example_model /home/me/my_models p This command will create a model package directory that should look like this: diff --git a/website/docs/usage/training-ner.jade b/website/docs/usage/training-ner.jade index 3d732b16d..3c74f7a9d 100644 --- a/website/docs/usage/training-ner.jade +++ b/website/docs/usage/training-ner.jade @@ -102,7 +102,7 @@ p | CLI command to create all required files and directories. +code(false, "bash"). - python -m spacy package /home/me/data/en_technology /home/me/my_models + spacy package /home/me/data/en_technology /home/me/my_models p | To build the package and create a #[code .tar.gz] archive, run diff --git a/website/docs/usage/v2.jade b/website/docs/usage/v2.jade index d9727c62b..6d98e3f05 100644 --- a/website/docs/usage/v2.jade +++ b/website/docs/usage/v2.jade @@ -238,11 +238,11 @@ p +h(3, "features-models") Neural network models for English, German, French, Spanish and multi-language NER +aside-code("Example", "bash"). - python -m spacy download en # default English model - python -m spacy download de # default German model - python -m spacy download fr # default French model - python -m spacy download es # default Spanish model - python -m spacy download xx_ent_wiki_sm # multi-language NER + spacy download en # default English model + spacy download de # default German model + spacy download fr # default French model + spacy download es # default Spanish model + spacy download xx_ent_wiki_sm # multi-language NER p | spaCy v2.0 comes with new and improved neural network models for English, diff --git a/website/docs/usage/visualizers.jade b/website/docs/usage/visualizers.jade index b3cbd3b46..96a6bd49f 100644 --- a/website/docs/usage/visualizers.jade +++ b/website/docs/usage/visualizers.jade @@ -259,7 +259,7 @@ p | notebook, the visualizations will be included as HTML. +code("Jupyter Example"). - # don't forget to install a model, e.g.: python -m spacy download en + # don't forget to install a model, e.g.: spacy download en import spacy from spacy import displacy diff --git a/website/index.jade b/website/index.jade index 741db53cf..9336d5c34 100644 --- a/website/index.jade +++ b/website/index.jade @@ -68,7 +68,7 @@ include _includes/_mixins +grid +grid-col("two-thirds") +terminal("lightning_tour.py"). - # Install: pip install spacy && python -m spacy download en + # Install: pip install spacy && spacy download en import spacy # Load English tokenizer, tagger, parser, NER and word vectors