Merge branch 'master' of ssh://github.com/explosion/spaCy

Commit f0a079dc0b by Matthew Honnibal, 2016-10-24 20:32:06 +02:00
8 changed files with 129 additions and 85 deletions

View File

@ -16,18 +16,18 @@ import spacy
class SentimentAnalyser(object):
@classmethod
def load(cls, path, nlp):
def load(cls, path, nlp, max_length=100):
with (path / 'config.json').open() as file_:
model = model_from_json(file_.read())
with (path / 'model').open('rb') as file_:
lstm_weights = pickle.load(file_)
embeddings = get_embeddings(nlp.vocab)
model.set_weights([embeddings] + lstm_weights)
return cls(model)
return cls(model, max_length=max_length)
def __init__(self, model):
def __init__(self, model, max_length=100):
self._model = model
self.max_length = max_length
def __call__(self, doc):
X = get_features([doc], self.max_length)
@ -36,10 +36,16 @@ class SentimentAnalyser(object):
def pipe(self, docs, batch_size=1000, n_threads=2):
for minibatch in cytoolz.partition_all(batch_size, docs):
Xs = get_features(minibatch, self.max_length)
minibatch = list(minibatch)
sentences = []
for doc in minibatch:
sentences.extend(doc.sents)
Xs = get_features(sentences, self.max_length)
ys = self._model.predict(Xs)
for i, doc in enumerate(minibatch):
doc.user_data['sentiment'] = ys[i]
for sent, label in zip(sentences, ys):
sent.doc.sentiment += label - 0.5
for doc in minibatch:
yield doc
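
Under the new pipe(), a document's score is no longer written directly from ys[i]; instead each sentence's predicted probability shifts its parent doc's sentiment by (label - 0.5). A toy illustration of that accumulation, using made-up probabilities:

    # Made-up per-sentence probabilities from the LSTM for one document.
    sentence_probs = [0.9, 0.7, 0.2]
    doc_sentiment = 0.0
    for p in sentence_probs:
        doc_sentiment += p - 0.5   # mirrors `sent.doc.sentiment += label - 0.5`
    print(doc_sentiment)           # roughly 0.3: net-positive over three sentences
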
def set_sentiment(self, doc, y):
doc.sentiment = float(y[0])
@ -48,6 +54,16 @@ class SentimentAnalyser(object):
# doc.user_data['my_data'] = y
def get_labelled_sentences(docs, doc_labels):
labels = []
sentences = []
for doc, y in zip(docs, doc_labels):
for sent in doc.sents:
sentences.append(sent)
labels.append(y)
return sentences, numpy.asarray(labels, dtype='int32')
def get_features(docs, max_length):
docs = list(docs)
Xs = numpy.zeros((len(docs), max_length), dtype='int32')
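
The rest of get_features() lies outside this hunk. A minimal sketch of how such a matrix is usually filled for this kind of model, assuming each token contributes its vocabulary rank (the row of its vector in the embedding matrix) and that padding stays at zero so the Embedding layer's mask_zero=True can skip it:

    import numpy

    def get_features(docs, max_length):
        # One row per doc or sentence, one column per token position, zero-padded.
        docs = list(docs)
        Xs = numpy.zeros((len(docs), max_length), dtype='int32')
        for i, doc in enumerate(docs):
            for j, token in enumerate(doc):
                if j >= max_length:
                    break
                # token.rank is an assumption about the unshown body; tokens
                # without a vector are left at the 0 padding value.
                Xs[i, j] = token.rank if token.has_vector else 0
        return Xs
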
@ -63,12 +79,21 @@ def get_features(docs, max_length):
def train(train_texts, train_labels, dev_texts, dev_labels,
lstm_shape, lstm_settings, lstm_optimizer, batch_size=100, nb_epoch=5):
nlp = spacy.load('en', parser=False, tagger=False, entity=False)
lstm_shape, lstm_settings, lstm_optimizer, batch_size=100, nb_epoch=5,
by_sentence=True):
print("Loading spaCy")
nlp = spacy.load('en', entity=False)
embeddings = get_embeddings(nlp.vocab)
model = compile_lstm(embeddings, lstm_shape, lstm_settings)
train_X = get_features(nlp.pipe(train_texts), lstm_shape['max_length'])
dev_X = get_features(nlp.pipe(dev_texts), lstm_shape['max_length'])
print("Parsing texts...")
train_docs = list(nlp.pipe(train_texts, batch_size=5000, n_threads=3))
dev_docs = list(nlp.pipe(dev_texts, batch_size=5000, n_threads=3))
if by_sentence:
train_docs, train_labels = get_labelled_sentences(train_docs, train_labels)
dev_docs, dev_labels = get_labelled_sentences(dev_docs, dev_labels)
train_X = get_features(train_docs, lstm_shape['max_length'])
dev_X = get_features(dev_docs, lstm_shape['max_length'])
model.fit(train_X, train_labels, validation_data=(dev_X, dev_labels),
nb_epoch=nb_epoch, batch_size=batch_size)
return model
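
For orientation, a hypothetical call to train(); only the 'nr_hidden', 'max_length' and 'dropout' keys are visible in this diff, so the other names and values below are placeholders:

    # Hypothetical invocation; any keys beyond those read above are guesses.
    lstm = train(train_texts, train_labels, dev_texts, dev_labels,
                 lstm_shape={'nr_hidden': 64, 'max_length': 100},
                 lstm_settings={'dropout': 0.5},
                 lstm_optimizer=None,
                 batch_size=32, nb_epoch=5,
                 by_sentence=True)
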
@ -86,7 +111,7 @@ def compile_lstm(embeddings, shape, settings):
mask_zero=True
)
)
model.add(TimeDistributed(Dense(shape['nr_hidden'] * 2)))
model.add(TimeDistributed(Dense(shape['nr_hidden'] * 2, bias=False)))
model.add(Dropout(settings['dropout']))
model.add(Bidirectional(LSTM(shape['nr_hidden'])))
model.add(Dropout(settings['dropout']))
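
The tail of compile_lstm() falls outside this hunk. For a binary sentiment label it would plausibly end with a sigmoid output unit and a compile step along these lines; the layer, optimiser and loss here are assumptions, not the file's actual code:

    from keras.layers import Dense
    from keras.optimizers import Adam

    # Assumed ending: one sigmoid unit for the 0/1 sentiment label.
    model.add(Dense(1, activation='sigmoid'))
    model.compile(optimizer=Adam(), loss='binary_crossentropy',
                  metrics=['accuracy'])
    return model
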
@ -105,25 +130,23 @@ def get_embeddings(vocab):
return vectors
def demonstrate_runtime(model_dir, texts):
'''Demonstrate runtime usage of the custom sentiment model with spaCy.
Here we return a dictionary mapping entities to the average sentiment of the
documents they occurred in.
'''
def evaluate(model_dir, texts, labels, max_length=100):
def create_pipeline(nlp):
'''
This could be a lambda, but named functions are easier to read in Python.
'''
return [nlp.tagger, nlp.entity, SentimentAnalyser.load(model_dir, nlp)]
return [nlp.tagger, nlp.parser, SentimentAnalyser.load(model_dir, nlp,
max_length=max_length)]
nlp = spacy.load('en', create_pipeline=create_pipeline)
nlp = spacy.load('en')
nlp.pipeline = create_pipeline(nlp)
entity_sentiments = collections.Counter(float)
correct = 0
i = 0
for doc in nlp.pipe(texts, batch_size=1000, n_threads=4):
for ent in doc.ents:
entity_sentiments[ent.text] += doc.sentiment
return entity_sentiments
correct += bool(doc.sentiment >= 0.5) == bool(labels[i])
i += 1
return float(correct) / i
def read_data(data_dir, limit=0):
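
read_data() is unchanged here and its body is not shown. A sketch of the kind of reader the script assumes, with an IMDB-style layout of one text file per review; the directory names, shuffling and return shape are assumptions:

    import random

    def read_data(data_dir, limit=0):
        # Assumed layout: data_dir/pos/* are label 1, data_dir/neg/* are label 0.
        examples = []
        for subdir, label in (('pos', 1), ('neg', 0)):
            for filename in (data_dir / subdir).iterdir():
                with filename.open() as file_:
                    examples.append((file_.read(), label))
        random.shuffle(examples)
        if limit >= 1:
            examples = examples[:limit]
        texts, labels = zip(*examples)
        return texts, labels
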
@ -162,10 +185,12 @@ def main(model_dir, train_dir, dev_dir,
dev_dir = pathlib.Path(dev_dir)
if is_runtime:
dev_texts, dev_labels = read_data(dev_dir)
demonstrate_runtime(model_dir, dev_texts)
acc = evaluate(model_dir, dev_texts, dev_labels, max_length=max_length)
print(acc)
else:
print("Read data")
train_texts, train_labels = read_data(train_dir, limit=nr_examples)
dev_texts, dev_labels = read_data(dev_dir)
dev_texts, dev_labels = read_data(dev_dir, limit=nr_examples)
train_labels = numpy.asarray(train_labels, dtype='int32')
dev_labels = numpy.asarray(dev_labels, dtype='int32')
lstm = train(train_texts, train_labels, dev_texts, dev_labels,
@ -175,7 +200,9 @@ def main(model_dir, train_dir, dev_dir,
nb_epoch=nb_epoch, batch_size=batch_size)
weights = lstm.get_weights()
with (model_dir / 'model').open('wb') as file_:
pickle.dump(file_, weights[1:])
pickle.dump(weights[1:], file_)
with (model_dir / 'config.json').open('wb') as file_:
file_.write(lstm.to_json())
if __name__ == '__main__':
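
Saving weights[1:] leaves out the first weight array, the embedding matrix taken from spaCy's vectors; SentimentAnalyser.load() at the top of the file puts get_embeddings(nlp.vocab) back in front of the pickled weights. The round trip in miniature:

    # Serialise everything except the embedding matrix (weights[0]) ...
    weights = lstm.get_weights()
    trainable_weights = weights[1:]
    # ... and on load, re-attach the static spaCy vectors in front,
    # exactly as SentimentAnalyser.load() does.
    model.set_weights([get_embeddings(nlp.vocab)] + trainable_weights)
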

View File

@ -6,6 +6,7 @@ import random
import spacy
from spacy.pipeline import EntityRecognizer
from spacy.gold import GoldParse
from spacy.tagger import Tagger
def train_ner(nlp, train_data, entity_types):
@ -27,7 +28,16 @@ def main(model_dir=None):
model_dir.mkdir()
assert model_dir.is_dir()
nlp = spacy.load('en', parser=False, entity=False, vectors=False)
nlp = spacy.load('en', parser=False, entity=False, add_vectors=False)
# v1.1.2 onwards
if nlp.tagger is None:
print('---- WARNING ----')
print('Data directory not found')
print('please run: `python -m spacy.en.download force all` for better performance')
print('Using feature templates for tagging')
print('-----------------')
nlp.tagger = Tagger(nlp.vocab, features=Tagger.feature_templates)
train_data = [
(
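
The train_data entries are cut off at the end of this hunk; for this style of NER updating they are typically (text, [(start_char, end_char, label), ...]) tuples that a GoldParse can be built from. A purely illustrative example, not the file's actual data:

    # Hypothetical examples; offsets are character positions into the text.
    train_data = [
        ('Who is Chaka Khan?', [(7, 17, 'PERSON')]),
        ('I like London and Berlin.', [(7, 13, 'GPE'), (18, 24, 'GPE')]),
    ]
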

View File

@ -32,7 +32,7 @@ def main(model_dir=None):
model_dir.mkdir()
assert model_dir.is_dir()
nlp = spacy.load('en', tagger=False, parser=False, entity=False, vectors=False)
nlp = spacy.load('en', tagger=False, parser=False, entity=False, add_vectors=False)
train_data = [
(

View File

@ -10,8 +10,9 @@ from pathlib import Path
from spacy.vocab import Vocab
from spacy.tagger import Tagger
from spacy.tokens import Doc
import random
from spacy.gold import GoldParse
import random
# You need to define a mapping from your data's part-of-speech tag names to the
# Universal Part-of-Speech tag set, as spaCy includes an enum of these tags.
@ -20,24 +21,25 @@ import random
# You may also specify morphological features for your tags, from the universal
# scheme.
TAG_MAP = {
'N': {"pos": "NOUN"},
'V': {"pos": "VERB"},
'J': {"pos": "ADJ"}
}
'N': {"pos": "NOUN"},
'V': {"pos": "VERB"},
'J': {"pos": "ADJ"}
}
# Usually you'll read this in, of course. Data formats vary.
# Ensure your strings are unicode.
DATA = [
(
["I", "like", "green", "eggs"],
["N", "V", "J", "N"]
["N", "V", "J", "N"]
),
(
["Eat", "blue", "ham"],
["V", "J", "N"]
["V", "J", "N"]
)
]
def ensure_dir(path):
if not path.exists():
path.mkdir()
@ -49,18 +51,19 @@ def main(output_dir=None):
ensure_dir(output_dir)
ensure_dir(output_dir / "pos")
ensure_dir(output_dir / "vocab")
vocab = Vocab(tag_map=TAG_MAP)
# The default_templates argument is where features are specified. See
# spacy/tagger.pyx for the defaults.
tagger = Tagger(vocab)
for i in range(5):
for i in range(25):
for words, tags in DATA:
doc = Doc(vocab, words=words)
tagger.update(doc, tags)
gold = GoldParse(doc, tags=tags)
tagger.update(doc, gold)
random.shuffle(DATA)
tagger.model.end_training()
doc = Doc(vocab, orths_and_spaces=zip(["I", "like", "blue", "eggs"], [True]*4))
doc = Doc(vocab, orths_and_spaces=zip(["I", "like", "blue", "eggs"], [True] * 4))
tagger(doc)
for word in doc:
print(word.text, word.tag_, word.pos_)

View File

@ -6,54 +6,56 @@
+h(2, "matcher", "https://github.com/" + SOCIAL.github + "/spaCy/blob/master/spacy/matcher.pyx")
| #[+tag class] Matcher
p A full example can be found #[a(href="https://github.com/" + SOCIAL.github + "blob/master/examples/matcher_example.py") here].
p A full example can be found #[a(href="https://github.com/" + SOCIAL.github + "/spaCy/blob/master/examples/matcher_example.py") here].
+table(["Usage", "Description"])
+row
+cell #[code.lang-python nlp(doc)]
+cell As part of annotation pipeline.
+row
+cell #[code.lang-python nlp(doc)]
+cell As part of annotation pipeline.
+row
+cell #[code.lang-python nlp.matcher(doc)]
+cell Explicit invocation.
+row
+cell #[code.lang-python nlp.matcher(doc)]
+cell Explicit invocation.
+row
+cell #[code.lang-python nlp.matcher.add(u'FooCorp', u'ORG', {}, [[{u'ORTH': u'Foo'}]])]
+cell Add a pattern to match.
+row
+cell #[code.lang-python nlp.matcher.add(u'FooCorp', u'ORG', {}, [[{u'ORTH': u'Foo'}]])]
+cell Add a pattern to match.
+section("matcher-init")
+h(3, "matcher-init") __init__(self, vocab, patterns)
+table(["Name", "Type", "Description"])
+row
+cell vocab
+cell #[code.lang-python spacy.vocab.Vocab]
+cell Reference to the shared vocabulary object.
+row
+cell patterns
+cell #[code {entity_key: (etype, attrs, specs)}]
+cell.
Initial patterns to match. See #[code Matcher.add]
+table(["Name", "Type", "Description"])
+row
+cell vocab
+cell #[code.lang-python spacy.vocab.Vocab]
+cell Reference to the shared vocabulary object.
+row
+cell patterns
+cell #[code {entity_key: (etype, attrs, specs)}]
+cell.
Initial patterns to match. See #[code Matcher.add]
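p A minimal, illustrative construction with one pre-loaded pattern; the key and pattern below are examples, not special values, and #[code nlp] is assumed to be a loaded pipeline.
+code('python', 'matcher init').
    from spacy.matcher import Matcher
    # One entry per entity key: (entity type, attrs placeholder, surface forms).
    patterns = {u'FooCorp': (u'ORG', {}, [[{u'ORTH': u'Foo'}]])}
    matcher = Matcher(nlp.vocab, patterns)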
+section("matcher-add")
+h(3, "matcher-add") add(self, entity_key, etype, attrs, specs)
+table(["Name", "Type", "Description"])
+row
+cell entity_key
+cell unicode or int
+cell Your arbitrary ID string (or its integer encoding)
+row
+cell etype
+cell unicode or int
+cell A pre-registered entity type, e.g. u'PERSON', u'ORG', etc.
+row
+cell attrs
+cell #[code dict]
+cell Placeholder for future support of entity attributes.
+row
+cell specs
+cell #[code [[{int: unicode}]]]
+cell A list of surface forms, where each surface form is defined as a list of token definitions, and each token definition is a dictionary mapping attribute IDs to attribute values.
+table(["Name", "Type", "Description"])
+row
+cell entity_key
+cell unicode or int
+cell Your arbitrary ID string (or its integer encoding)
+row
+cell etype
+cell unicode or int
+cell A pre-registered entity type, e.g. u'PERSON', u'ORG', etc.
+row
+cell attrs
+cell #[code dict]
+cell Placeholder for future support of entity attributes.
+row
+cell specs
+cell #[code [[{int: unicode}]]]
+cell A list of surface forms, where each surface form is defined as a list of token definitions, and each token definition is a dictionary mapping attribute IDs to attribute values.
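p For instance, a pattern with two alternative surface forms for one entity could be added as below; the key, entity type and token attributes are illustrative only.
+code('python', 'matcher add').
    # Each inner list is one surface form; each dict describes one token.
    nlp.matcher.add(u'GoogleNow', u'PRODUCT', {},
        [[{u'ORTH': u'Google'}, {u'ORTH': u'Now'}],
         [{u'LOWER': u'google'}, {u'LOWER': u'now'}]])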
+section("matcher-saveload")
+h(3, "matcher-saveload")

View File

@ -14,7 +14,8 @@
["Span", "#span", "span"],
["Lexeme", "#lexeme", "lexeme"],
["Vocab", "#vocab", "vocab"],
["StringStore", "#stringstore", "stringstore"]
["StringStore", "#stringstore", "stringstore"],
["Matcher", "#matcher", "matcher"]
],
"More": [
["Annotation Specs", "#annotation", "annotation"],

View File

@ -20,6 +20,7 @@ include _api-span
include _api-lexeme
include _api-vocab
include _api-stringstore
include _api-matcher
include _annotation-specs
include _tutorials

View File

@ -19,7 +19,7 @@ p I'll start with some quick code examples, that describe how to train each mode
tagger.model.end_training()
p #[+a("https://github.com/" + SOCIAL.github + "/spaCy/examples/training/train_tagger.py") Full example]
p #[+a("https://github.com/" + SOCIAL.github + "/spaCy/blob/master/examples/training/train_tagger.py") Full example]
+h(2, "train-entity") Training the named entity recognizer
@ -37,7 +37,7 @@ p #[+a("https://github.com/" + SOCIAL.github + "/spaCy/examples/training/train_t
entity.model.end_training()
p #[+a("https://github.com/" + SOCIAL.github + "/spaCy/examples/training/train_ner.y") Full example]
p #[+a("https://github.com/" + SOCIAL.github + "/spaCy/blob/master/examples/training/train_ner.y") Full example]
+h(2, "train-entity") Training the dependency parser
@ -54,7 +54,7 @@ p #[+a("https://github.com/" + SOCIAL.github + "/spaCy/examples/training/train_n
parser.model.end_training()
p #[+a("https://github.com/" + SOCIAL.github + "/spaCy/examples/training/train_parser.py") Full example]
p #[+a("https://github.com/" + SOCIAL.github + "/spaCy/blob/master/examples/training/train_parser.py") Full example]
+h(2, 'feature-templates') Customising the feature extraction
@ -64,9 +64,9 @@ p Because it's a linear model, it's important for accuracy to build conjunction
p The feature extraction proceeds in two passes. In the first pass, we fill an array with the values of all of the atomic predictors. In the second pass, we iterate over the feature templates, and fill a small temporary array with the predictors that will be combined into a conjunction feature. Finally, we hash this array into a 64-bit integer, using the MurmurHash algorithm. You can see this at work in the #[+a("https://github.com/" + SOCIAL.github + "/thinc/blob/94dbe06fd3c8f24d86ab0f5c7984e52dbfcdc6cb/thinc/linear/features.pyx") thinc.linear.features] module.
p It's very easy to change the feature templates, to create novel combinations of the existing atomic predictors. There's currently no API available to add new atomic predictors, though. You'll have to create a subclass of the model, and write your own #[+code set_featuresC] method.
p It's very easy to change the feature templates, to create novel combinations of the existing atomic predictors. There's currently no API available to add new atomic predictors, though. You'll have to create a subclass of the model, and write your own #[code set_featuresC] method.
p The feature templates are passed in using the #[+code features] keyword argument to the constructors of the Tagger, DependencyParser and EntityRecognizer:
p The feature templates are passed in using the #[code features] keyword argument to the constructors of the Tagger, DependencyParser and EntityRecognizer:
+code('python', 'custom tagger templates').
from spacy.vocab import Vocab
@ -79,4 +79,4 @@ p The feature templates are passed in using the #[+code features] keyword argume
(P2_orth,), (P1_orth,), (W_orth,),
(N1_orth,), (N2_orth,)])
p Custom feature templates can be passed to the DependencyParser and EntityRecognizer as well, also using the #[+code features] keyword argument of the constructor.
p Custom feature templates can be passed to the DependencyParser and EntityRecognizer as well, also using the #[code features] keyword argument of the constructor.
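p A sketch of the same idea for the entity recognizer; the template list here is an empty placeholder, and real templates are tuples of spaCy's atomic predictor constants, as in the tagger example above.
+code('python', 'custom entity templates').
    from spacy.vocab import Vocab
    from spacy.pipeline import EntityRecognizer

    # Placeholder: supply your own template tuples, built like the tagger ones above.
    custom_templates = []
    entity = EntityRecognizer(Vocab(), features=custom_templates)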