Merge branch 'develop' of https://github.com/explosion/spaCy into develop

Matthew Honnibal 2017-09-14 09:21:20 -05:00
commit 8496d76224
14 changed files with 138 additions and 90 deletions

View File

@@ -25,7 +25,7 @@ For more details, see the documentation:
* Saving and loading models: https://spacy.io/docs/usage/saving-loading
Developed for: spaCy 1.7.6
Last tested for: spaCy 1.7.6
Last updated for: spaCy 2.0.0a13
"""
from __future__ import unicode_literals, print_function
@@ -34,55 +34,41 @@ from pathlib import Path
import random
import spacy
from spacy.gold import GoldParse
from spacy.tagger import Tagger
from spacy.gold import GoldParse, minibatch
from spacy.pipeline import NeuralEntityRecognizer
from spacy.pipeline import TokenVectorEncoder
def get_gold_parses(tokenizer, train_data):
'''Shuffle and create GoldParse objects'''
random.shuffle(train_data)
for raw_text, entity_offsets in train_data:
doc = tokenizer(raw_text)
gold = GoldParse(doc, entities=entity_offsets)
yield doc, gold
def train_ner(nlp, train_data, output_dir):
# Add new words to vocab
for raw_text, _ in train_data:
doc = nlp.make_doc(raw_text)
for word in doc:
_ = nlp.vocab[word.orth]
random.seed(0)
# You may need to change the learning rate. It's generally difficult to
# guess what rate you should set, especially when you have limited data.
nlp.entity.model.learn_rate = 0.001
for itn in range(1000):
random.shuffle(train_data)
loss = 0.
for raw_text, entity_offsets in train_data:
gold = GoldParse(doc, entities=entity_offsets)
# By default, the GoldParse class assumes that the entities
# described by offset are complete, and all other words should
# have the tag 'O'. You can tell it to make no assumptions
# about the tag of a word by giving it the tag '-'.
# However, this allows a trivial solution to the current
# learning problem: if words are either 'any tag' or 'ANIMAL',
# the model can learn that all words can be tagged 'ANIMAL'.
#for i in range(len(gold.ner)):
#if not gold.ner[i].endswith('ANIMAL'):
# gold.ner[i] = '-'
doc = nlp.make_doc(raw_text)
nlp.tagger(doc)
# As of 1.9, spaCy's parser now lets you supply a dropout probability
# This might help the model generalize better from only a few
# examples.
loss += nlp.entity.update(doc, gold, drop=0.9)
if loss == 0:
break
# This step averages the model's weights. This may or may not be good for
# your situation --- it's empirical.
nlp.end_training()
if output_dir:
if not output_dir.exists():
output_dir.mkdir()
nlp.save_to_directory(output_dir)
optimizer = nlp.begin_training(lambda: [])
nlp.meta['name'] = 'en_ent_animal'
for itn in range(50):
losses = {}
for batch in minibatch(get_gold_parses(nlp.make_doc, train_data), size=3):
docs, golds = zip(*batch)
nlp.update(docs, golds, losses=losses, sgd=optimizer, update_shared=True,
drop=0.35)
print(losses)
if not output_dir:
return
elif not output_dir.exists():
output_dir.mkdir()
nlp.to_disk(output_dir)
def main(model_name, output_directory=None):
print("Loading initial model", model_name)
nlp = spacy.load(model_name)
print("Creating initial model", model_name)
nlp = spacy.blank(model_name)
if output_directory is not None:
output_directory = Path(output_directory)
@@ -91,6 +77,11 @@ def main(model_name, output_directory=None):
"Horses are too tall and they pretend to care about your feelings",
[(0, 6, 'ANIMAL')],
),
(
"Do they bite?",
[],
),
(
"horses are too tall and they pretend to care about your feelings",
[(0, 6, 'ANIMAL')]
@@ -109,18 +100,20 @@ def main(model_name, output_directory=None):
)
]
nlp.entity.add_label('ANIMAL')
nlp.pipeline.append(TokenVectorEncoder(nlp.vocab))
nlp.pipeline.append(NeuralEntityRecognizer(nlp.vocab))
nlp.pipeline[-1].add_label('ANIMAL')
train_ner(nlp, train_data, output_directory)
# Test that the entity is recognized
doc = nlp('Do you like horses?')
text = 'Do you like horses?'
print("Ents in 'Do you like horses?':")
doc = nlp(text)
for ent in doc.ents:
print(ent.label_, ent.text)
if output_directory:
print("Loading from", output_directory)
nlp2 = spacy.load('en', path=output_directory)
nlp2.entity.add_label('ANIMAL')
nlp2 = spacy.load(output_directory)
doc2 = nlp2('Do you like horses?')
for ent in doc2.ents:
print(ent.label_, ent.text)
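
The example above replaces the spaCy 1.x loop (nlp.entity.update with a manual learning rate) with the 2.0 pipeline API shown in this diff: spacy.blank, a NeuralEntityRecognizer appended to nlp.pipeline, begin_training returning an optimizer, and minibatched calls to nlp.update. A condensed sketch of that new-style loop, assuming the alpha API from this commit; the toy data, batch size, and iteration count are illustrative:

import random
import spacy
from spacy.gold import GoldParse, minibatch
from spacy.pipeline import TokenVectorEncoder, NeuralEntityRecognizer

TRAIN_DATA = [
    ("Horses are too tall", [(0, 6, 'ANIMAL')]),   # offsets mark the new label
    ("Do they bite?", []),
]

nlp = spacy.blank('en')
nlp.pipeline.append(TokenVectorEncoder(nlp.vocab))
nlp.pipeline.append(NeuralEntityRecognizer(nlp.vocab))
nlp.pipeline[-1].add_label('ANIMAL')

optimizer = nlp.begin_training(lambda: [])          # 2.0: returns an optimizer
for itn in range(50):
    random.shuffle(TRAIN_DATA)
    losses = {}
    examples = []
    for text, offsets in TRAIN_DATA:
        doc = nlp.make_doc(text)
        examples.append((doc, GoldParse(doc, entities=offsets)))
    for batch in minibatch(examples, size=3):
        docs, golds = zip(*batch)
        nlp.update(docs, golds, sgd=optimizer, losses=losses, drop=0.35)
    print(itn, losses)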

View File

@@ -229,20 +229,18 @@ def drop_layer(layer, factor=2.):
def Tok2Vec(width, embed_size, preprocess=None):
cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH]
with Model.define_operators({'>>': chain, '|': concatenate, '**': clone, '+': add}):
norm = get_col(cols.index(NORM)) >> HashEmbed(width, embed_size, name='embed_lower')
prefix = get_col(cols.index(PREFIX)) >> HashEmbed(width, embed_size//2, name='embed_prefix')
suffix = get_col(cols.index(SUFFIX)) >> HashEmbed(width, embed_size//2, name='embed_suffix')
shape = get_col(cols.index(SHAPE)) >> HashEmbed(width, embed_size//2, name='embed_shape')
norm = HashEmbed(width, embed_size, column=cols.index(NORM), name='embed_norm')
prefix = HashEmbed(width, embed_size//2, column=cols.index(PREFIX), name='embed_prefix')
suffix = HashEmbed(width, embed_size//2, column=cols.index(SUFFIX), name='embed_suffix')
shape = HashEmbed(width, embed_size//2, column=cols.index(SHAPE), name='embed_shape')
embed = (norm | prefix | suffix | shape ) >> LN(Maxout(width, width*4, pieces=3))
tok2vec = (
with_flatten(
asarray(Model.ops, dtype='uint64')
>> uniqued(embed, column=5)
>> drop_layer(
Residual(
(ExtractWindow(nW=1) >> LN(Maxout(width, width*3)))
)
>> Residual(
(ExtractWindow(nW=1) >> LN(Maxout(width, width*3)))
) ** 4, pad=4
)
)
@@ -372,6 +370,7 @@ def fine_tune(embedding, combine=None):
"fine_tune currently only supports addition. Set combine=None")
def fine_tune_fwd(docs_tokvecs, drop=0.):
docs, tokvecs = docs_tokvecs
lengths = model.ops.asarray([len(doc) for doc in docs], dtype='i')
vecs, bp_vecs = embedding.begin_update(docs, drop=drop)
@@ -556,7 +555,7 @@ def build_text_classifier(nr_class, width=64, **cfg):
cnn_model = (
# TODO Make concatenate support lists
concatenate_lists(trained_vectors, static_vectors)
concatenate_lists(trained_vectors, static_vectors)
>> with_flatten(
LN(Maxout(width, width*2))
>> Residual(
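
The Tok2Vec change above moves column selection into HashEmbed itself (column=cols.index(...)) instead of piping each embedding through a separate get_col layer. A rough numpy sketch of what a hashed embedding over one feature column does; the table size, width, and single-bucket hashing are simplified stand-ins for thinc's implementation, which hashes uint64 keys into a learned table:

import numpy as np

def hash_embed(feature_ids, table):
    # Bucket each feature id into a fixed-size table and look the row up.
    rows = feature_ids % table.shape[0]
    return table[rows]

# One row of features per token; columns play the role of NORM, PREFIX, SUFFIX, SHAPE.
features = np.array([[31, 7], [52, 9], [31, 7]], dtype='int64')
norm_table = np.random.normal(size=(1000, 128)).astype('float32')

vectors = hash_embed(features[:, 0], norm_table)    # column 0 here stands in for NORM
print(vectors.shape)                                # (3, 128); identical ids share a row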

View File

@@ -3,7 +3,7 @@
# https://github.com/pypa/warehouse/blob/master/warehouse/__about__.py
__title__ = 'spacy-nightly'
__version__ = '2.0.0a13'
__version__ = '2.0.0a14'
__summary__ = 'Industrial-strength Natural Language Processing (NLP) with Python and Cython'
__uri__ = 'https://spacy.io'
__author__ = 'Explosion AI'

View File

@@ -72,8 +72,8 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
util.env_opt('batch_compound', 1.001))
if resume:
prints(output_path / 'model19.pickle', title="Resuming training")
nlp = dill.load((output_path / 'model19.pickle').open('rb'))
prints(output_path / 'model9.pickle', title="Resuming training")
nlp = dill.load((output_path / 'model9.pickle').open('rb'))
else:
nlp = lang_class(pipeline=pipeline)
corpus = GoldCorpus(train_path, dev_path, limit=n_sents)
@@ -88,7 +88,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
if resume:
i += 20
with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
train_docs = corpus.train_docs(nlp, projectivize=True,
train_docs = corpus.train_docs(nlp, projectivize=True, noise_level=0.0,
gold_preproc=gold_preproc, max_length=0)
losses = {}
for batch in minibatch(train_docs, size=batch_sizes):
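
minibatch above is fed batch_sizes built from the batch_compound setting (1.001), i.e. an iterator of slowly growing sizes rather than a fixed int. A small generator sketching such a compounding schedule; the start and stop values are made up, not the CLI defaults:

def compounding(start, stop, compound):
    """Yield batch sizes that grow by `compound` each step, capped at `stop`."""
    size = start
    while True:
        yield int(size)
        size = min(size * compound, stop)

batch_sizes = compounding(1., 32., 1.001)
print([next(batch_sizes) for _ in range(5)])   # stays at 1 for a while, then creeps up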

View File

@@ -7,6 +7,7 @@ import re
import ujson
import random
import cytoolz
import itertools
from .syntax import nonproj
from .util import ensure_path
@@ -146,9 +147,13 @@ def minibatch(items, size=8):
'''Iterate over batches of items. `size` may be an iterator,
so that batch-size can vary on each step.
'''
if isinstance(size, int):
size_ = itertools.repeat(8)
else:
size_ = size
items = iter(items)
while True:
batch_size = next(size) #if hasattr(size, '__next__') else size
batch_size = next(size_)
batch = list(cytoolz.take(int(batch_size), items))
if len(batch) == 0:
break
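
A plain-Python sketch of the batching helper above, usable outside Cython. Note one difference: the version above falls back to batches of 8 whenever `size` is given as a plain int, while this sketch repeats whatever integer was passed:

import itertools

def minibatch(items, size=8):
    """Yield lists of items; `size` may be an int or an iterator of ints."""
    size_ = itertools.repeat(size) if isinstance(size, int) else iter(size)
    items = iter(items)
    while True:
        batch = list(itertools.islice(items, int(next(size_))))
        if not batch:
            break
        yield batch

for batch in minibatch(range(10), size=itertools.cycle([2, 3, 5])):
    print(batch)    # [0, 1], then [2, 3, 4], then [5, 6, 7, 8, 9]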

View File

@@ -347,15 +347,9 @@ class Language(object):
"""Allocate models, pre-process training data and acquire a trainer and
optimizer. Used as a contextmanager.
gold_tuples (iterable): Gold-standard training data.
get_gold_tuples (function): Function returning gold data
**cfg: Config parameters.
YIELDS (tuple): A trainer and an optimizer.
EXAMPLE:
>>> with nlp.begin_training(gold, use_gpu=True) as (trainer, optimizer):
>>> for epoch in trainer.epochs(gold):
>>> for docs, golds in epoch:
>>> state = nlp.update(docs, golds, sgd=optimizer)
returns: An optimizer
"""
if self.parser:
self.pipeline.append(NeuralLabeller(self.vocab))

View File

@@ -38,7 +38,8 @@ class Lemmatizer(object):
avoid lemmatization entirely.
"""
morphology = {} if morphology is None else morphology
others = [key for key in morphology if key not in (POS, 'number', 'pos', 'verbform')]
others = [key for key in morphology
if key not in (POS, 'Number', 'POS', 'VerbForm', 'Tense')]
true_morph_key = morphology.get('morph', 0)
if univ_pos == 'noun' and morphology.get('Number') == 'sing':
return True
@@ -47,7 +48,9 @@
# This maps 'VBP' to base form -- probably just need 'IS_BASE'
# morphology
elif univ_pos == 'verb' and (morphology.get('VerbForm') == 'fin' and \
morphology.get('Tense') == 'pres'):
morphology.get('Tense') == 'pres' and \
morphology.get('Number') is None and \
not others):
return True
elif univ_pos == 'adj' and morphology.get('Degree') == 'pos':
return True
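
The tightened condition above keeps 3rd-person singular forms like 'works' (VerbForm=fin, Tense=pres, Number=sing) from being treated as base forms, which is what the new test_issue1305 further down checks. A plain-Python sketch of just the verb branch, with the morphology keys inlined; the real method also handles nouns, adjectives and other VerbForm values:

def looks_like_base_verb(morphology):
    # Mirrors the verb branch above: a present-tense finite form only counts as a
    # base form when no Number (or any other morphological feature) is set.
    others = [key for key in morphology
              if key not in ('Number', 'POS', 'VerbForm', 'Tense')]
    return (morphology.get('VerbForm') == 'fin'
            and morphology.get('Tense') == 'pres'
            and morphology.get('Number') is None
            and not others)

print(looks_like_base_verb({'VerbForm': 'fin', 'Tense': 'pres'}))                    # True: leave as-is
print(looks_like_base_verb({'VerbForm': 'fin', 'Tense': 'pres', 'Number': 'sing'}))  # False: lemmatize 'works'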

View File

@@ -101,9 +101,10 @@ cdef cppclass StateC:
elif n == 6:
if this.B(0) >= 0:
ids[0] = this.B(0)
ids[1] = this.B(0)-1
else:
ids[0] = -1
ids[1] = this.B(0)
ids[1] = -1
ids[2] = this.B(1)
ids[3] = this.E(0)
if ids[3] >= 1:
@@ -118,8 +119,12 @@ cdef cppclass StateC:
# TODO error =/
pass
for i in range(n):
# Token vectors should be padded, so that there's a vector for
# missing values at the start.
if ids[i] >= 0:
ids[i] += this.offset
ids[i] += this.offset + 1
else:
ids[i] = 0
int S(int i) nogil const:
if i >= this._s_i:
@@ -162,9 +167,9 @@ cdef cppclass StateC:
int E(int i) nogil const:
if this._e_i <= 0 or this._e_i >= this.length:
return 0
return -1
if i < 0 or i >= this._e_i:
return 0
return -1
return this._ents[this._e_i - (i+1)].start
int L(int i, int idx) nogil const:
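
The id shift above reserves row 0 of the token-vector matrix for missing values: real tokens map to offset + 1, empty feature slots to 0, matching the zero padding row added by _pad_tokvecs later in this commit. A small numpy illustration of the convention; the shapes and ids are arbitrary, and the per-doc offset is omitted:

import numpy as np

tokvecs = np.arange(12, dtype='float32').reshape(4, 3)            # 4 tokens, width 3
padded = np.vstack([np.zeros((1, 3), dtype='float32'), tokvecs])  # row 0 = "missing"

state_ids = np.array([2, -1, 0])                    # feature slots: token 2, empty, token 0
rows = np.where(state_ids >= 0, state_ids + 1, 0)   # shift real ids by 1, empty -> 0
print(padded[rows])                                 # token 2's vector, zeros, token 0's vector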

View File

@@ -220,6 +220,31 @@ cdef class BiluoPushDown(TransitionSystem):
raise Exception(move)
return t
def add_action(self, int action, label_name):
cdef attr_t label_id
if not isinstance(label_name, (int, long)):
label_id = self.strings.add(label_name)
else:
label_id = label_name
if action == OUT and label_id != 0:
return
if action == MISSING or action == ISNT:
return
# Check we're not creating a move we already have, so that this is
# idempotent
for trans in self.c[:self.n_moves]:
if trans.move == action and trans.label == label_id:
return 0
if self.n_moves >= self._size:
self._size *= 2
self.c = <Transition*>self.mem.realloc(self.c, self._size * sizeof(self.c[0]))
self.c[self.n_moves] = self.init_transition(self.n_moves, action, label_id)
assert self.c[self.n_moves].label == label_id
self.n_moves += 1
return 1
cdef int initialize_state(self, StateC* st) nogil:
# This is especially necessary when we use limited training data.
for i in range(st.length):
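
The add_action override above grows the C transition array but, crucially, is idempotent and rejects labelled OUT moves as well as MISSING/ISNT labels. A plain-Python sketch of the same bookkeeping over a list instead of a realloc'd C array; the move constants are stand-ins:

OUT, MISSING, ISNT = 'O', 'M', '!'        # stand-ins for the C-level move constants

def add_action(moves, action, label_id):
    """Append (action, label_id) unless it is disallowed or already present."""
    if action == OUT and label_id != 0:
        return 0
    if action in (MISSING, ISNT):
        return 0
    if (action, label_id) in moves:        # idempotent: adding twice is a no-op
        return 0
    moves.append((action, label_id))
    return 1

moves = []
print(add_action(moves, 'B', 42), add_action(moves, 'B', 42))   # 1 0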

View File

@@ -393,9 +393,8 @@ cdef class Parser:
tokvecs = self.model[0].ops.flatten(tokvecses)
if USE_FINE_TUNE:
# TODO: This is incorrect! Unhack when training next model
tokvecs += self.model[0].ops.flatten(self.model[0]((docs, tokvecses)))
tokvecs = self.model[0].ops.flatten(self.model[0]((docs, tokvecses)))
tokvecs = self._pad_tokvecs(tokvecs)
nr_state = len(docs)
nr_class = self.moves.n_moves
nr_dim = tokvecs.shape[1]
@@ -455,6 +454,7 @@ cdef class Parser:
tokvecs = self.model[0].ops.flatten(tokvecses)
if USE_FINE_TUNE:
tokvecs = self.model[0].ops.flatten(self.model[0]((docs, tokvecses)))
tokvecs = self._pad_tokvecs(tokvecs)
cuda_stream = get_cuda_stream()
state2vec, vec2scores = self.get_batch_model(len(docs), tokvecs,
cuda_stream, 0.0)
@@ -532,8 +532,10 @@ cdef class Parser:
docs = [docs]
golds = [golds]
if USE_FINE_TUNE:
my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop)
tokvecs += self.model[0].ops.flatten(my_tokvecs)
tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop)
tokvecs = self.model[0].ops.flatten(tokvecs)
tokvecs = self._pad_tokvecs(tokvecs)
cuda_stream = get_cuda_stream()
@@ -584,6 +586,7 @@ cdef class Parser:
break
self._make_updates(d_tokvecs,
backprops, sgd, cuda_stream)
d_tokvecs = self._unpad_tokvecs(d_tokvecs)
d_tokvecs = self.model[0].ops.unflatten(d_tokvecs, [len(d) for d in docs])
if USE_FINE_TUNE:
d_tokvecs = bp_my_tokvecs(d_tokvecs, sgd=sgd)
@@ -606,8 +609,8 @@ cdef class Parser:
assert min(lengths) >= 1
tokvecs = self.model[0].ops.flatten(tokvecs)
if USE_FINE_TUNE:
my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop)
tokvecs += self.model[0].ops.flatten(my_tokvecs)
tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop)
tokvecs = self.model[0].ops.flatten(tokvecs)
states = self.moves.init_batch(docs)
for gold in golds:
@@ -640,10 +643,20 @@ cdef class Parser:
d_tokvecs = self.model[0].ops.allocate(tokvecs.shape)
self._make_updates(d_tokvecs, backprop_lower, sgd, cuda_stream)
d_tokvecs = self.model[0].ops.unflatten(d_tokvecs, lengths)
d_tokvecs = self._unpad_tokvecs(d_tokvecs)
if USE_FINE_TUNE:
d_tokvecs = bp_my_tokvecs(d_tokvecs, sgd=sgd)
return d_tokvecs
def _pad_tokvecs(self, tokvecs):
# Add a vector for missing values at the start of tokvecs
xp = get_array_module(tokvecs)
pad = xp.zeros((1, tokvecs.shape[1]), dtype=tokvecs.dtype)
return xp.vstack((pad, tokvecs))
def _unpad_tokvecs(self, d_tokvecs):
return d_tokvecs[1:]
def _init_gold_batch(self, whole_docs, whole_golds):
"""Make a square batch, of length equal to the shortest doc. A long
doc will get multiple states. Let's say we have a doc of length 2*N,
@@ -706,7 +719,7 @@ cdef class Parser:
lower, stream, drop=dropout)
return state2vec, upper
nr_feature = 13
nr_feature = 8
def get_token_ids(self, states):
cdef StateClass state
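
_pad_tokvecs and _unpad_tokvecs above implement the other half of the padding convention from _state.pxd: a zero row is stacked in front of the flattened token vectors, and its gradient is dropped again on the way back. A numpy-only sketch of the pair; the real methods use get_array_module so the same code also runs on GPU arrays:

import numpy as np

def pad_tokvecs(tokvecs):
    # Prepend a zero row so index 0 can stand in for missing values.
    pad = np.zeros((1, tokvecs.shape[1]), dtype=tokvecs.dtype)
    return np.vstack((pad, tokvecs))

def unpad_tokvecs(d_tokvecs):
    # Drop the gradient for the padding row before unflattening per document.
    return d_tokvecs[1:]

tokvecs = np.ones((5, 64), dtype='float32')
assert unpad_tokvecs(pad_tokvecs(tokvecs)).shape == tokvecs.shape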

View File

@@ -148,7 +148,7 @@ cdef class TransitionSystem:
def add_action(self, int action, label_name):
cdef attr_t label_id
if not isinstance(label_name, int):
if not isinstance(label_name, (int, long)):
label_id = self.strings.add(label_name)
else:
label_id = label_name

View File

@@ -0,0 +1,8 @@
import pytest
@pytest.mark.models('en')
def test_issue1305(EN):
'''Test lemmatization of English VBZ'''
assert EN.vocab.morphology.lemmatizer('works', 'verb') == set(['work'])
doc = EN(u'This app works well')
assert doc[2].lemma_ == 'work'

View File

@@ -9,11 +9,14 @@ import pytest
@pytest.mark.models('en')
def test_issue429(EN):
def merge_phrases(matcher, doc, i, matches):
if i != len(matches) - 1:
return None
spans = [(ent_id, ent_id, doc[start:end]) for ent_id, start, end in matches]
for ent_id, label, span in spans:
span.merge('NNP' if label else span.root.tag_, span.text, EN.vocab.strings[label])
if i != len(matches) - 1:
return None
spans = [(ent_id, ent_id, doc[start:end]) for ent_id, start, end in matches]
for ent_id, label, span in spans:
span.merge(
tag=('NNP' if label else span.root.tag_),
lemma=span.text,
label='PERSON')
doc = EN('a')
matcher = Matcher(EN.vocab)
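
The test above is truncated after constructing the Matcher. A hedged sketch of how such a callback is typically wired up, assuming the 2.0 Matcher.add(key, on_match, *patterns) signature; the key and pattern here are illustrative, not the actual regression test:

import spacy
from spacy.matcher import Matcher

nlp = spacy.blank('en')
matcher = Matcher(nlp.vocab)

def on_match(matcher, doc, i, matches):
    if i != len(matches) - 1:      # only act once, on the last match, as above
        return None
    match_id, start, end = matches[i]
    print('matched:', doc[start:end].text)

matcher.add('TEST', on_match, [{'ORTH': 'horses'}])   # assumed add(key, on_match, *patterns)
matcher(nlp(u'Do you like horses?'))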

View File

@@ -282,7 +282,7 @@ p
def __call__(self, text):
words = text.split(' ')
# All tokens 'own' a subsequent space character in this tokenizer
spaces = [True] * len(word)
spaces = [True] * len(words)
return Doc(self.vocab, words=words, spaces=spaces)
p
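
The len(words) fix above belongs to the docs' custom whitespace-tokenizer example. For context, a self-contained version of that example, assuming the usual class wrapper from the documentation:

import spacy
from spacy.tokens import Doc

class WhitespaceTokenizer(object):
    def __init__(self, vocab):
        self.vocab = vocab

    def __call__(self, text):
        words = text.split(' ')
        # All tokens 'own' a subsequent space character in this tokenizer
        spaces = [True] * len(words)
        return Doc(self.vocab, words=words, spaces=spaces)

nlp = spacy.blank('en')
nlp.tokenizer = WhitespaceTokenizer(nlp.vocab)
doc = nlp(u"What's happened to me? he thought.")
print([token.text for token in doc])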