diff --git a/examples/chainer_sentiment.py b/examples/chainer_sentiment.py
index 251eadef3..ac3881e75 100644
--- a/examples/chainer_sentiment.py
+++ b/examples/chainer_sentiment.py
@@ -64,6 +64,7 @@ class SentimentAnalyser(object):
         # For arbitrary data storage, there's:
         # doc.user_data['my_data'] = y
 
+
 class Classifier(Chain):
     def __init__(self, predictor):
         super(Classifier, self).__init__(predictor=predictor)
@@ -77,9 +78,10 @@ class Classifier(Chain):
 
 
 class SentimentModel(Chain):
-    def __init__(self, shape, **settings):
+    def __init__(self, nlp, shape, **settings):
         Chain.__init__(self,
-            embed=_Embed(shape['nr_vector'], shape['nr_dim'], shape['nr_hidden']),
+            embed=_Embed(shape['nr_vector'], shape['nr_dim'], shape['nr_hidden'],
+                initialW=lambda arr: set_vectors(arr, nlp.vocab)),
             encode=_Encode(shape['nr_hidden'], shape['nr_hidden']),
             attend=_Attend(shape['nr_hidden'], shape['nr_hidden']),
             predict=_Predict(shape['nr_hidden'], shape['nr_class']))
@@ -205,16 +207,14 @@ def get_features(docs, max_length):
     return Xs
 
 
-def get_embeddings(vocab, max_rank=1000):
-    if max_rank is None:
-        max_rank = max(lex.rank+1 for lex in vocab if lex.has_vector)
-    vectors = xp.ndarray((max_rank+1, vocab.vectors_length), dtype='f')
+def set_vectors(vectors, vocab):
     for lex in vocab:
-        if lex.has_vector and lex.rank < max_rank:
+        if lex.has_vector and (lex.rank+1) < vectors.shape[0]:
             lex.norm = lex.rank+1
             vectors[lex.rank + 1] = lex.vector
         else:
             lex.norm = 0
+    vectors.unchain_backwards()
     return vectors
 
 
@@ -222,13 +222,10 @@ def train(train_texts, train_labels, dev_texts, dev_labels,
           lstm_shape, lstm_settings, lstm_optimizer, batch_size=100, nb_epoch=5,
           by_sentence=True):
     nlp = spacy.load('en', entity=False)
-    for lex in nlp.vocab:
-        if lex.rank >= (lstm_shape['nr_vector'] - 1):
-            lex.norm = 0
-        else:
-            lex.norm = lex.rank+1
+    if 'nr_vector' not in lstm_shape:
+        lstm_shape['nr_vector'] = max(lex.rank+1 for lex in nlp.vocab if lex.has_vector)
     print("Make model")
-    model = Classifier(SentimentModel(lstm_shape, **lstm_settings))
+    model = Classifier(SentimentModel(nlp, lstm_shape, **lstm_settings))
     print("Parsing texts...")
     if by_sentence:
         train_data = SentenceDataset(nlp, train_texts, train_labels, lstm_shape['max_length'])