mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 02:06:31 +03:00
106 lines
3.5 KiB
Python
106 lines
3.5 KiB
Python
from __future__ import division, unicode_literals, print_function
|
|
import spacy
|
|
|
|
import plac
|
|
from pathlib import Path
|
|
|
|
from spacy_hook import get_embeddings, get_word_ids
|
|
from spacy_hook import create_similarity_pipeline
|
|
|
|
|
|
def train(model_dir, train_loc, dev_loc, shape, settings):
|
|
print("Loading spaCy")
|
|
nlp = spacy.load('en', tagger=False, parser=False, entity=False, matcher=False)
|
|
print("Compiling network")
|
|
model = build_model(get_embeddings(nlp.vocab), shape, settings)
|
|
print("Processing texts...")
|
|
train_X = get_features(list(nlp.pipe(train_texts)))
|
|
dev_X = get_features(list(nlp.pipe(dev_texts)))
|
|
|
|
model.fit(
|
|
train_X,
|
|
train_labels,
|
|
validation_data=(dev_X, dev_labels),
|
|
nb_epoch=settings['nr_epoch'],
|
|
batch_size=settings['batch_size'])
|
|
|
|
|
|
def evaluate(model_dir, dev_loc):
|
|
nlp = spacy.load('en', path=model_dir,
|
|
tagger=False, parser=False, entity=False, matcher=False,
|
|
create_pipeline=create_similarity_pipeline)
|
|
n = 0
|
|
correct = 0
|
|
for (text1, text2), label in zip(dev_texts, dev_labels):
|
|
doc1 = nlp(text1)
|
|
doc2 = nlp(text2)
|
|
sim = doc1.similarity(doc2)
|
|
if bool(sim >= 0.5) == label:
|
|
correct += 1
|
|
n += 1
|
|
return correct, total
|
|
|
|
|
|
def demo(model_dir):
|
|
nlp = spacy.load('en', path=model_dir,
|
|
tagger=False, parser=False, entity=False, matcher=False,
|
|
create_pipeline=create_similarity_pipeline)
|
|
doc1 = nlp(u'Worst fries ever! Greasy and horrible...')
|
|
doc2 = nlp(u'The milkshakes are good. The fries are bad.')
|
|
print('doc1.similarity(doc2)', doc1.similarity(doc2))
|
|
sent1a, sent1b = doc1.sents
|
|
print('sent1a.similarity(sent1b)', sent1a.similarity(sent1b))
|
|
print('sent1a.similarity(doc2)', sent1a.similarity(doc2))
|
|
print('sent1b.similarity(doc2)', sent1b.similarity(doc2))
|
|
|
|
|
|
LABELS = {'entailment': 0, 'contradiction': 1, 'neutral': 2}
|
|
def read_snli(loc):
|
|
with open(loc) as file_:
|
|
for line in file_:
|
|
eg = json.loads(line)
|
|
label = eg['gold_label']
|
|
if label == '-':
|
|
continue
|
|
text1 = eg['sentence1']
|
|
text2 = eg['sentence2']
|
|
yield text1, text2, LABELS[label]
|
|
|
|
|
|
@plac.annotations(
|
|
mode=("Mode to execute", "positional", None, str, ["train", "evaluate", "demo"]),
|
|
model_dir=("Path to spaCy model directory", "positional", None, Path),
|
|
train_loc=("Path to training data", "positional", None, Path),
|
|
dev_loc=("Path to development data", "positional", None, Path),
|
|
max_length=("Length to truncate sentences", "option", "L", int),
|
|
nr_hidden=("Number of hidden units", "option", "H", int),
|
|
dropout=("Dropout level", "option", "d", float),
|
|
learn_rate=("Learning rate", "option", "e", float),
|
|
batch_size=("Batch size for neural network training", "option", "b", float),
|
|
nr_epoch=("Number of training epochs", "option", "i", float)
|
|
)
|
|
def main(mode, model_dir, train_loc, dev_loc,
|
|
max_length=100,
|
|
nr_hidden=100,
|
|
dropout=0.2,
|
|
learn_rate=0.001,
|
|
batch_size=100,
|
|
nr_epoch=5):
|
|
shape = (max_length, nr_hidden, 3)
|
|
settings = {
|
|
'lr': learn_rate,
|
|
'dropout': dropout,
|
|
'batch_size': batch_size,
|
|
'nr_epoch': nr_epoch
|
|
}
|
|
if mode == 'train':
|
|
train(model_dir, train_loc, dev_loc, shape, settings)
|
|
elif mode == 'evaluate':
|
|
evaluate(model_dir, dev_loc)
|
|
else:
|
|
demo(model_dir)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
plac.call(main)
|