diff --git a/examples/training/train_new_entity_type.py b/examples/training/train_new_entity_type.py
index 4eae11c75..ab69285a6 100644
--- a/examples/training/train_new_entity_type.py
+++ b/examples/training/train_new_entity_type.py
@@ -25,7 +25,7 @@ For more details, see the documentation:
 * Saving and loading models: https://spacy.io/docs/usage/saving-loading
 
 Developed for: spaCy 1.7.6
-Last tested for: spaCy 1.7.6
+Last updated for: spaCy 2.0.0a13
 """
 from __future__ import unicode_literals, print_function
 
@@ -34,55 +34,41 @@ from pathlib import Path
 import random
 
 import spacy
-from spacy.gold import GoldParse
-from spacy.tagger import Tagger
+from spacy.gold import GoldParse, minibatch
+from spacy.pipeline import NeuralEntityRecognizer
+from spacy.pipeline import TokenVectorEncoder
 
 
+def get_gold_parses(tokenizer, train_data):
+    '''Shuffle and create GoldParse objects'''
+    random.shuffle(train_data)
+    for raw_text, entity_offsets in train_data:
+        doc = tokenizer(raw_text)
+        gold = GoldParse(doc, entities=entity_offsets)
+        yield doc, gold
+
+
 def train_ner(nlp, train_data, output_dir):
-    # Add new words to vocab
-    for raw_text, _ in train_data:
-        doc = nlp.make_doc(raw_text)
-        for word in doc:
-            _ = nlp.vocab[word.orth]
     random.seed(0)
-    # You may need to change the learning rate. It's generally difficult to
-    # guess what rate you should set, especially when you have limited data.
-    nlp.entity.model.learn_rate = 0.001
-    for itn in range(1000):
-        random.shuffle(train_data)
-        loss = 0.
-        for raw_text, entity_offsets in train_data:
-            gold = GoldParse(doc, entities=entity_offsets)
-            # By default, the GoldParse class assumes that the entities
-            # described by offset are complete, and all other words should
-            # have the tag 'O'. You can tell it to make no assumptions
-            # about the tag of a word by giving it the tag '-'.
-            # However, this allows a trivial solution to the current
-            # learning problem: if words are either 'any tag' or 'ANIMAL',
-            # the model can learn that all words can be tagged 'ANIMAL'.
-            #for i in range(len(gold.ner)):
-            #    if not gold.ner[i].endswith('ANIMAL'):
-            #        gold.ner[i] = '-'
-            doc = nlp.make_doc(raw_text)
-            nlp.tagger(doc)
-            # As of 1.9, spaCy's parser now lets you supply a dropout probability
-            # This might help the model generalize better from only a few
-            # examples.
-            loss += nlp.entity.update(doc, gold, drop=0.9)
-        if loss == 0:
-            break
-    # This step averages the model's weights. This may or may not be good for
-    # your situation --- it's empirical.
-    nlp.end_training()
-    if output_dir:
-        if not output_dir.exists():
-            output_dir.mkdir()
-        nlp.save_to_directory(output_dir)
+    optimizer = nlp.begin_training(lambda: [])
+    nlp.meta['name'] = 'en_ent_animal'
+    for itn in range(50):
+        losses = {}
+        for batch in minibatch(get_gold_parses(nlp.make_doc, train_data), size=3):
+            docs, golds = zip(*batch)
+            nlp.update(docs, golds, losses=losses, sgd=optimizer, update_shared=True,
+                       drop=0.35)
+        print(losses)
+    if not output_dir:
+        return
+    elif not output_dir.exists():
+        output_dir.mkdir()
+    nlp.to_disk(output_dir)
 
 
 def main(model_name, output_directory=None):
-    print("Loading initial model", model_name)
-    nlp = spacy.load(model_name)
+    print("Creating initial model", model_name)
+    nlp = spacy.blank(model_name)
     if output_directory is not None:
         output_directory = Path(output_directory)
 
@@ -91,6 +77,11 @@ def main(model_name, output_directory=None):
             "Horses are too tall and they pretend to care about your feelings",
             [(0, 6, 'ANIMAL')],
         ),
+        (
+            "Do they bite?",
+            [],
+        ),
+
         (
             "horses are too tall and they pretend to care about your feelings",
             [(0, 6, 'ANIMAL')]
@@ -109,18 +100,20 @@ def main(model_name, output_directory=None):
         )
 
     ]
-    nlp.entity.add_label('ANIMAL')
+    nlp.pipeline.append(TokenVectorEncoder(nlp.vocab))
+    nlp.pipeline.append(NeuralEntityRecognizer(nlp.vocab))
+    nlp.pipeline[-1].add_label('ANIMAL')
     train_ner(nlp, train_data, output_directory)
 
     # Test that the entity is recognized
-    doc = nlp('Do you like horses?')
+    text = 'Do you like horses?'
     print("Ents in 'Do you like horses?':")
+    doc = nlp(text)
     for ent in doc.ents:
         print(ent.label_, ent.text)
     if output_directory:
         print("Loading from", output_directory)
-        nlp2 = spacy.load('en', path=output_directory)
-        nlp2.entity.add_label('ANIMAL')
+        nlp2 = spacy.load(output_directory)
         doc2 = nlp2('Do you like horses?')
         for ent in doc2.ents:
             print(ent.label_, ent.text)
diff --git a/spacy/_ml.py b/spacy/_ml.py
index e6437cdcf..003541f4b 100644
--- a/spacy/_ml.py
+++ b/spacy/_ml.py
@@ -229,20 +229,18 @@ def drop_layer(layer, factor=2.):
 def Tok2Vec(width, embed_size, preprocess=None):
     cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH]
     with Model.define_operators({'>>': chain, '|': concatenate, '**': clone, '+': add}):
-        norm = get_col(cols.index(NORM)) >> HashEmbed(width, embed_size, name='embed_lower')
-        prefix = get_col(cols.index(PREFIX)) >> HashEmbed(width, embed_size//2, name='embed_prefix')
-        suffix = get_col(cols.index(SUFFIX)) >> HashEmbed(width, embed_size//2, name='embed_suffix')
-        shape = get_col(cols.index(SHAPE)) >> HashEmbed(width, embed_size//2, name='embed_shape')
+        norm = HashEmbed(width, embed_size, column=cols.index(NORM), name='embed_norm')
+        prefix = HashEmbed(width, embed_size//2, column=cols.index(PREFIX), name='embed_prefix')
+        suffix = HashEmbed(width, embed_size//2, column=cols.index(SUFFIX), name='embed_suffix')
+        shape = HashEmbed(width, embed_size//2, column=cols.index(SHAPE), name='embed_shape')
 
         embed = (norm | prefix | suffix | shape ) >> LN(Maxout(width, width*4, pieces=3))
         tok2vec = (
             with_flatten(
                 asarray(Model.ops, dtype='uint64')
                 >> uniqued(embed, column=5)
-                >> drop_layer(
-                    Residual(
-                        (ExtractWindow(nW=1) >> LN(Maxout(width, width*3)))
-                    )
+                >> Residual(
+                    (ExtractWindow(nW=1) >> LN(Maxout(width, width*3)))
                 ) ** 4, pad=4
             )
         )
@@ -372,6 +370,7 @@ def fine_tune(embedding, combine=None):
             "fine_tune currently only supports addition. Set combine=None")
     def fine_tune_fwd(docs_tokvecs, drop=0.):
         docs, tokvecs = docs_tokvecs
+        lengths = model.ops.asarray([len(doc) for doc in docs], dtype='i')
 
         vecs, bp_vecs = embedding.begin_update(docs, drop=drop)
 
@@ -556,7 +555,7 @@ def build_text_classifier(nr_class, width=64, **cfg):
 
     cnn_model = (
         # TODO Make concatenate support lists
-        concatenate_lists(trained_vectors, static_vectors)
+        concatenate_lists(trained_vectors, static_vectors)
         >> with_flatten(
             LN(Maxout(width, width*2))
             >> Residual(
diff --git a/spacy/about.py b/spacy/about.py
index d566fbb1f..40444ffd1 100644
--- a/spacy/about.py
+++ b/spacy/about.py
@@ -3,7 +3,7 @@
 # https://github.com/pypa/warehouse/blob/master/warehouse/__about__.py
 
 __title__ = 'spacy-nightly'
-__version__ = '2.0.0a13'
+__version__ = '2.0.0a14'
 __summary__ = 'Industrial-strength Natural Language Processing (NLP) with Python and Cython'
 __uri__ = 'https://spacy.io'
 __author__ = 'Explosion AI'
diff --git a/spacy/cli/train.py b/spacy/cli/train.py
index b2c87d2b5..7ad94ce9c 100644
--- a/spacy/cli/train.py
+++ b/spacy/cli/train.py
@@ -72,8 +72,8 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
                                    util.env_opt('batch_compound', 1.001))
 
     if resume:
-        prints(output_path / 'model19.pickle', title="Resuming training")
-        nlp = dill.load((output_path / 'model19.pickle').open('rb'))
+        prints(output_path / 'model9.pickle', title="Resuming training")
+        nlp = dill.load((output_path / 'model9.pickle').open('rb'))
     else:
         nlp = lang_class(pipeline=pipeline)
     corpus = GoldCorpus(train_path, dev_path, limit=n_sents)
@@ -88,7 +88,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
         if resume:
             i += 20
         with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
-            train_docs = corpus.train_docs(nlp, projectivize=True,
+            train_docs = corpus.train_docs(nlp, projectivize=True, noise_level=0.0,
                                            gold_preproc=gold_preproc, max_length=0)
             losses = {}
             for batch in minibatch(train_docs, size=batch_sizes):
diff --git a/spacy/gold.pyx b/spacy/gold.pyx
index f00d04109..fc8d6622b 100644
--- a/spacy/gold.pyx
+++ b/spacy/gold.pyx
@@ -7,6 +7,7 @@ import re
 import ujson
 import random
 import cytoolz
+import itertools
 
 from .syntax import nonproj
 from .util import ensure_path
@@ -146,9 +147,13 @@ def minibatch(items, size=8):
     '''Iterate over batches of items. `size` may be an iterator,
     so that batch-size can vary on each step.
     '''
+    if isinstance(size, int):
+        size_ = itertools.repeat(8)
+    else:
+        size_ = size
     items = iter(items)
     while True:
-        batch_size = next(size) #if hasattr(size, '__next__') else size
+        batch_size = next(size_)
         batch = list(cytoolz.take(int(batch_size), items))
         if len(batch) == 0:
             break
diff --git a/spacy/language.py b/spacy/language.py
index 66b42ff94..e6a5304dd 100644
--- a/spacy/language.py
+++ b/spacy/language.py
@@ -347,15 +347,9 @@ class Language(object):
         """Allocate models, pre-process training data and acquire a trainer and
         optimizer. Used as a contextmanager.
 
-        gold_tuples (iterable): Gold-standard training data.
+        get_gold_tuples (function): Function returning gold data
         **cfg: Config parameters.
-        YIELDS (tuple): A trainer and an optimizer.
-
-        EXAMPLE:
-            >>> with nlp.begin_training(gold, use_gpu=True) as (trainer, optimizer):
-            >>>     for epoch in trainer.epochs(gold):
-            >>>         for docs, golds in epoch:
-            >>>             state = nlp.update(docs, golds, sgd=optimizer)
+        returns: An optimizer
         """
         if self.parser:
             self.pipeline.append(NeuralLabeller(self.vocab))
diff --git a/spacy/lemmatizer.py b/spacy/lemmatizer.py
index 4d534b50f..312c8db72 100644
--- a/spacy/lemmatizer.py
+++ b/spacy/lemmatizer.py
@@ -38,7 +38,8 @@ class Lemmatizer(object):
         avoid lemmatization entirely.
         """
         morphology = {} if morphology is None else morphology
-        others = [key for key in morphology if key not in (POS, 'number', 'pos', 'verbform')]
+        others = [key for key in morphology
+                  if key not in (POS, 'Number', 'POS', 'VerbForm', 'Tense')]
         true_morph_key = morphology.get('morph', 0)
         if univ_pos == 'noun' and morphology.get('Number') == 'sing':
             return True
@@ -47,7 +48,9 @@ class Lemmatizer(object):
         # This maps 'VBP' to base form -- probably just need 'IS_BASE'
        # morphology
         elif univ_pos == 'verb' and (morphology.get('VerbForm') == 'fin' and \
-                                     morphology.get('Tense') == 'pres'):
+                                     morphology.get('Tense') == 'pres' and \
+                                     morphology.get('Number') is None and \
+                                     not others):
             return True
         elif univ_pos == 'adj' and morphology.get('Degree') == 'pos':
             return True
diff --git a/spacy/syntax/_state.pxd b/spacy/syntax/_state.pxd
index 3da9e5d4c..9a08691de 100644
--- a/spacy/syntax/_state.pxd
+++ b/spacy/syntax/_state.pxd
@@ -101,9 +101,10 @@ cdef cppclass StateC:
         elif n == 6:
             if this.B(0) >= 0:
                 ids[0] = this.B(0)
+                ids[1] = this.B(0)-1
             else:
                 ids[0] = -1
-            ids[1] = this.B(0)
+                ids[1] = -1
             ids[2] = this.B(1)
             ids[3] = this.E(0)
             if ids[3] >= 1:
@@ -118,8 +119,12 @@ cdef cppclass StateC:
             # TODO error =/
             pass
         for i in range(n):
+            # Token vectors should be padded, so that there's a vector for
+            # missing values at the start.
            if ids[i] >= 0:
-                ids[i] += this.offset
+                ids[i] += this.offset + 1
+            else:
+                ids[i] = 0
 
     int S(int i) nogil const:
         if i >= this._s_i:
@@ -162,9 +167,9 @@ cdef cppclass StateC:
 
     int E(int i) nogil const:
         if this._e_i <= 0 or this._e_i >= this.length:
-            return 0
+            return -1
         if i < 0 or i >= this._e_i:
-            return 0
+            return -1
         return this._ents[this._e_i - (i+1)].start
 
     int L(int i, int idx) nogil const:
diff --git a/spacy/syntax/ner.pyx b/spacy/syntax/ner.pyx
index 2f5cd4e48..11b429aa2 100644
--- a/spacy/syntax/ner.pyx
+++ b/spacy/syntax/ner.pyx
@@ -220,6 +220,31 @@ cdef class BiluoPushDown(TransitionSystem):
             raise Exception(move)
         return t
 
+    def add_action(self, int action, label_name):
+        cdef attr_t label_id
+        if not isinstance(label_name, (int, long)):
+            label_id = self.strings.add(label_name)
+        else:
+            label_id = label_name
+        if action == OUT and label_id != 0:
+            return
+        if action == MISSING or action == ISNT:
+            return
+        # Check we're not creating a move we already have, so that this is
+        # idempotent
+        for trans in self.c[:self.n_moves]:
+            if trans.move == action and trans.label == label_id:
+                return 0
+        if self.n_moves >= self._size:
+            self._size *= 2
+            self.c = <Transition*>self.mem.realloc(self.c, self._size * sizeof(self.c[0]))
+        self.c[self.n_moves] = self.init_transition(self.n_moves, action, label_id)
+        assert self.c[self.n_moves].label == label_id
+        self.n_moves += 1
+        return 1
+
+
+
     cdef int initialize_state(self, StateC* st) nogil:
         # This is especially necessary when we use limited training data.
         for i in range(st.length):
diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx
index 34e504da9..ad6ed280e 100644
--- a/spacy/syntax/nn_parser.pyx
+++ b/spacy/syntax/nn_parser.pyx
@@ -393,9 +393,8 @@ cdef class Parser:
 
         tokvecs = self.model[0].ops.flatten(tokvecses)
         if USE_FINE_TUNE:
-            # TODO: This is incorrect! Unhack when training next model
-            tokvecs += self.model[0].ops.flatten(self.model[0]((docs, tokvecses)))
-
+            tokvecs = self.model[0].ops.flatten(self.model[0]((docs, tokvecses)))
+        tokvecs = self._pad_tokvecs(tokvecs)
         nr_state = len(docs)
         nr_class = self.moves.n_moves
         nr_dim = tokvecs.shape[1]
@@ -455,6 +454,7 @@ cdef class Parser:
         tokvecs = self.model[0].ops.flatten(tokvecses)
         if USE_FINE_TUNE:
             tokvecs = self.model[0].ops.flatten(self.model[0]((docs, tokvecses)))
+        tokvecs = self._pad_tokvecs(tokvecs)
         cuda_stream = get_cuda_stream()
         state2vec, vec2scores = self.get_batch_model(len(docs), tokvecs,
                                                      cuda_stream, 0.0)
@@ -532,8 +532,10 @@ cdef class Parser:
             docs = [docs]
             golds = [golds]
         if USE_FINE_TUNE:
-            my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop)
-            tokvecs += self.model[0].ops.flatten(my_tokvecs)
+            tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop)
+            tokvecs = self.model[0].ops.flatten(tokvecs)
+
+        tokvecs = self._pad_tokvecs(tokvecs)
 
         cuda_stream = get_cuda_stream()
 
@@ -584,6 +586,7 @@ cdef class Parser:
                 break
         self._make_updates(d_tokvecs,
             backprops, sgd, cuda_stream)
+        d_tokvecs = self._unpad_tokvecs(d_tokvecs)
         d_tokvecs = self.model[0].ops.unflatten(d_tokvecs, [len(d) for d in docs])
         if USE_FINE_TUNE:
             d_tokvecs = bp_my_tokvecs(d_tokvecs, sgd=sgd)
@@ -606,8 +609,8 @@ cdef class Parser:
         assert min(lengths) >= 1
         tokvecs = self.model[0].ops.flatten(tokvecs)
         if USE_FINE_TUNE:
-            my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop)
-            tokvecs += self.model[0].ops.flatten(my_tokvecs)
+            tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop)
+            tokvecs = self.model[0].ops.flatten(tokvecs)
 
         states = self.moves.init_batch(docs)
         for gold in golds:
@@ -640,10 +643,20 @@ cdef class Parser:
         d_tokvecs = self.model[0].ops.allocate(tokvecs.shape)
         self._make_updates(d_tokvecs, backprop_lower, sgd, cuda_stream)
         d_tokvecs = self.model[0].ops.unflatten(d_tokvecs, lengths)
+        d_tokvecs = self._unpad_tokvecs(d_tokvecs)
         if USE_FINE_TUNE:
             d_tokvecs = bp_my_tokvecs(d_tokvecs, sgd=sgd)
         return d_tokvecs
 
+    def _pad_tokvecs(self, tokvecs):
+        # Add a vector for missing values at the start of tokvecs
+        xp = get_array_module(tokvecs)
+        pad = xp.zeros((1, tokvecs.shape[1]), dtype=tokvecs.dtype)
+        return xp.vstack((pad, tokvecs))
+
+    def _unpad_tokvecs(self, d_tokvecs):
+        return d_tokvecs[1:]
+
     def _init_gold_batch(self, whole_docs, whole_golds):
         """Make a square batch, of length equal to the shortest doc. A long
         doc will get multiple states. Let's say we have a doc of length 2*N,
@@ -706,7 +719,7 @@ cdef class Parser:
                               lower, stream, drop=dropout)
         return state2vec, upper
 
-    nr_feature = 13
+    nr_feature = 8
 
     def get_token_ids(self, states):
         cdef StateClass state
diff --git a/spacy/syntax/transition_system.pyx b/spacy/syntax/transition_system.pyx
index 9cf82e0c7..055129c8b 100644
--- a/spacy/syntax/transition_system.pyx
+++ b/spacy/syntax/transition_system.pyx
@@ -148,7 +148,7 @@ cdef class TransitionSystem:
 
     def add_action(self, int action, label_name):
         cdef attr_t label_id
-        if not isinstance(label_name, int):
+        if not isinstance(label_name, (int, long)):
             label_id = self.strings.add(label_name)
         else:
             label_id = label_name
diff --git a/spacy/tests/regression/test_issue1305.py b/spacy/tests/regression/test_issue1305.py
new file mode 100644
index 000000000..e123ce0ba
--- /dev/null
+++ b/spacy/tests/regression/test_issue1305.py
@@ -0,0 +1,8 @@
+import pytest
+
+@pytest.mark.models('en')
+def test_issue1305(EN):
+    '''Test lemmatization of English VBZ'''
+    assert EN.vocab.morphology.lemmatizer('works', 'verb') == set(['work'])
+    doc = EN(u'This app works well')
+    assert doc[2].lemma_ == 'work'
diff --git a/spacy/tests/regression/test_issue429.py b/spacy/tests/regression/test_issue429.py
index 1baa9a1db..74f12bd9f 100644
--- a/spacy/tests/regression/test_issue429.py
+++ b/spacy/tests/regression/test_issue429.py
@@ -9,11 +9,14 @@ import pytest
 @pytest.mark.models('en')
 def test_issue429(EN):
     def merge_phrases(matcher, doc, i, matches):
-      if i != len(matches) - 1:
-          return None
-      spans = [(ent_id, ent_id, doc[start:end]) for ent_id, start, end in matches]
-      for ent_id, label, span in spans:
-          span.merge('NNP' if label else span.root.tag_, span.text, EN.vocab.strings[label])
+        if i != len(matches) - 1:
+            return None
+        spans = [(ent_id, ent_id, doc[start:end]) for ent_id, start, end in matches]
+        for ent_id, label, span in spans:
+            span.merge(
+                tag=('NNP' if label else span.root.tag_),
+                lemma=span.text,
+                label='PERSON')
 
     doc = EN('a')
     matcher = Matcher(EN.vocab)
diff --git a/website/docs/usage/customizing-tokenizer.jade b/website/docs/usage/customizing-tokenizer.jade
index 7e0b4b479..0bc81771d 100644
--- a/website/docs/usage/customizing-tokenizer.jade
+++ b/website/docs/usage/customizing-tokenizer.jade
@@ -282,7 +282,7 @@ p
         def __call__(self, text):
             words = text.split(' ')
             # All tokens 'own' a subsequent space character in this tokenizer
-            spaces = [True] * len(word)
+            spaces = [True] * len(words)
             return Doc(self.vocab, words=words, spaces=spaces)
 
 p