2017-11-07 03:22:30 +03:00
|
|
|
"""
|
2018-12-02 06:26:26 +03:00
|
|
|
This example shows how to use an LSTM sentiment classification model trained
|
|
|
|
using Keras in spaCy. spaCy splits the document into sentences, and each
|
|
|
|
sentence is classified using the LSTM. The scores for the sentences are then
|
|
|
|
aggregated to give the document score. This kind of hierarchical model is quite
|
|
|
|
difficult in "pure" Keras or TensorFlow, but it's very effective. The Keras
|
|
|
|
example on this dataset performs quite poorly, because it cuts off the documents
|
|
|
|
so that they're a fixed size. This hurts review accuracy a lot, because people
|
|
|
|
often summarise their rating in the final sentence.
|
2017-11-07 03:22:30 +03:00
|
|
|
|
|
|
|
Prerequisites:
|
|
|
|
python -m spacy download en_vectors_web_lg
|
|
|
|
pip install keras==2.0.9
|
|
|
|
|
|
|
|
Compatible with: spaCy v2.0.0+
|
|
|
|
"""
|
|
|
|
|
2016-10-20 03:49:14 +03:00
|
|
|
import plac
|
|
|
|
import random
|
2016-10-20 04:42:34 +03:00
|
|
|
import pathlib
|
2016-10-20 03:49:14 +03:00
|
|
|
import cytoolz
|
2016-10-19 15:43:13 +03:00
|
|
|
import numpy
|
2016-10-20 04:21:56 +03:00
|
|
|
from keras.models import Sequential, model_from_json
|
2017-11-07 03:22:30 +03:00
|
|
|
from keras.layers import LSTM, Dense, Embedding, Bidirectional
|
2016-10-20 22:32:26 +03:00
|
|
|
from keras.layers import TimeDistributed
|
2016-10-20 05:39:54 +03:00
|
|
|
from keras.optimizers import Adam
|
2017-11-05 19:11:00 +03:00
|
|
|
import thinc.extra.datasets
|
2017-11-07 03:22:30 +03:00
|
|
|
from spacy.compat import pickle
|
2016-10-19 15:43:13 +03:00
|
|
|
import spacy
|
|
|
|
|
|
|
|
|
|
|
|
class SentimentAnalyser(object):
    """spaCy pipeline component that scores documents with a Keras LSTM.

    Build it with ``SentimentAnalyser.load`` from a model directory written by
    ``main``. Calling the component on one Doc scores the (truncated) document
    as a single input; ``pipe`` scores every sentence and averages the
    sentence scores per document. Both write a single float to
    ``doc.sentiment`` on the same scale (the model's sigmoid output).
    """

    @classmethod
    def load(cls, path, nlp, max_length=100):
        """Rebuild the component from a saved model directory.

        path (pathlib.Path): directory holding ``config.json`` (the Keras
            architecture as JSON) and ``model`` (pickled non-embedding weights).
        nlp (Language): pipeline whose vocab provides the embedding table.
        max_length (int): number of tokens fed to the LSTM per input.
        RETURNS (SentimentAnalyser): the loaded component.
        """
        with (path / "config.json").open() as file_:
            model = model_from_json(file_.read())
        # NOTE: unpickling can execute arbitrary code -- only load model
        # directories you trust.
        with (path / "model").open("rb") as file_:
            lstm_weights = pickle.load(file_)
        # The embedding matrix is not stored with the model weights; it is
        # re-read from the vocab and prepended as the first weight array.
        embeddings = get_embeddings(nlp.vocab)
        model.set_weights([embeddings] + lstm_weights)
        return cls(model, max_length=max_length)

    def __init__(self, model, max_length=100):
        self._model = model
        self.max_length = max_length

    def __call__(self, doc):
        """Score ``doc`` as one input, set ``doc.sentiment`` and return it."""
        X = get_features([doc], self.max_length)
        y = self._model.predict(X)
        self.set_sentiment(doc, y)
        # Pipeline components must hand the Doc back; otherwise nlp(text)
        # evaluates to None.
        return doc

    def pipe(self, docs, batch_size=1000, n_threads=2):
        """Score docs sentence by sentence, yielding each doc in input order.

        Every sentence is classified separately and ``doc.sentiment`` is set
        to the mean of its sentences' scores, i.e. the same scale that
        ``__call__`` writes. Docs without any sentences are left unchanged.

        docs (iterable): Docs that already have sentence boundaries.
        batch_size (int): number of docs processed per model call.
        n_threads (int): accepted for signature compatibility; unused.
        YIELDS (Doc): the scored docs.
        """
        for minibatch in cytoolz.partition_all(batch_size, docs):
            minibatch = list(minibatch)
            sentences = []
            owners = []  # index (within minibatch) of each sentence's doc
            for i, doc in enumerate(minibatch):
                for sent in doc.sents:
                    sentences.append(sent)
                    owners.append(i)
            # Don't call predict() on an empty batch.
            if sentences:
                Xs = get_features(sentences, self.max_length)
                ys = self._model.predict(Xs)
                totals = [0.0] * len(minibatch)
                counts = [0] * len(minibatch)
                for i, y in zip(owners, ys):
                    # y is one output row; take its single class score as a
                    # Python float rather than a 1-element array.
                    totals[i] += float(y[0])
                    counts[i] += 1
                for doc, total, count in zip(minibatch, totals, counts):
                    if count:
                        doc.sentiment = total / count
            for doc in minibatch:
                yield doc

    def set_sentiment(self, doc, y):
        """Store the first prediction row's score on ``doc.sentiment``."""
        doc.sentiment = float(y[0])
        # Sentiment has a native slot for a single float.
        # For arbitrary data storage, there's:
        # doc.user_data['my_data'] = y
|
2016-10-19 15:43:13 +03:00
|
|
|
|
|
|
|
|
2016-10-24 00:17:41 +03:00
|
|
|
def get_labelled_sentences(docs, doc_labels):
    """Flatten documents into sentences that inherit their document's label.

    docs (iterable): objects exposing ``.sents`` (e.g. spaCy Docs).
    doc_labels (iterable): one label per document, aligned with ``docs``.
    RETURNS (tuple): ``(sentences, labels)`` where ``sentences`` is a list and
        ``labels`` is an int32 numpy array with one entry per sentence.
    """
    pairs = [
        (sentence, doc_label)
        for doc, doc_label in zip(docs, doc_labels)
        for sentence in doc.sents
    ]
    sentences = [sentence for sentence, _ in pairs]
    label_array = numpy.asarray([doc_label for _, doc_label in pairs], dtype="int32")
    return sentences, label_array
|
2016-10-24 00:17:41 +03:00
|
|
|
|
|
|
|
|
2016-10-19 20:37:09 +03:00
|
|
|
def get_features(docs, max_length):
    """Encode each doc as a fixed-width row of word-vector ids.

    docs (iterable): items iterable over tokens that expose ``.orth`` and
        ``.vocab`` (spaCy Docs or Spans).
    max_length (int): row width; longer inputs are truncated and shorter ones
        are right-padded with 0.
    RETURNS (numpy.ndarray): int32 array of shape ``(len(docs), max_length)``.

    A token whose ``vocab.vectors.find`` result is negative (no vector) is
    encoded as 0, the same value used for padding, so the Embedding layer
    (built with ``mask_zero=True``) masks it. NOTE(review): a real word whose
    vector sits in row 0 is therefore masked as well.
    """
    docs = list(docs)
    Xs = numpy.zeros((len(docs), max_length), dtype="int32")
    for i, doc in enumerate(docs):
        for j, token in enumerate(doc):
            # Check the bound before writing, so max_length == 0 gives empty
            # rows instead of an IndexError on Xs[i, 0].
            if j >= max_length:
                break
            vector_id = token.vocab.vectors.find(key=token.orth)
            if vector_id >= 0:
                Xs[i, j] = vector_id
            # Otherwise keep the 0 already placed by numpy.zeros.
    return Xs
|
2016-10-20 03:49:14 +03:00
|
|
|
|
|
|
|
|
2018-12-02 06:26:26 +03:00
|
|
|
def train(
    train_texts,
    train_labels,
    dev_texts,
    dev_labels,
    lstm_shape,
    lstm_settings,
    lstm_optimizer,
    batch_size=100,
    nb_epoch=5,
    by_sentence=True,
):
    """Parse the texts with spaCy and fit a freshly compiled LSTM on them.

    lstm_shape (dict): ``max_length``, ``nr_hidden`` and ``nr_class``.
    lstm_settings (dict): ``dropout`` and ``lr``.
    lstm_optimizer: accepted but not used by this function.
    by_sentence (bool): train on individual sentences, each carrying its
        document's label, instead of on whole documents.
    RETURNS: the fitted Keras model.
    """
    print("Loading spaCy")
    nlp = spacy.load("en_vectors_web_lg")
    nlp.add_pipe(nlp.create_pipe("sentencizer"))
    model = compile_lstm(get_embeddings(nlp.vocab), lstm_shape, lstm_settings)

    print("Parsing texts...")
    train_docs = list(nlp.pipe(train_texts))
    dev_docs = list(nlp.pipe(dev_texts))
    if by_sentence:
        # Split documents into sentences; every sentence reuses its doc label.
        train_docs, train_labels = get_labelled_sentences(train_docs, train_labels)
        dev_docs, dev_labels = get_labelled_sentences(dev_docs, dev_labels)

    width = lstm_shape["max_length"]
    train_X = get_features(train_docs, width)
    dev_X = get_features(dev_docs, width)
    model.fit(
        train_X,
        train_labels,
        validation_data=(dev_X, dev_labels),
        epochs=nb_epoch,
        batch_size=batch_size,
    )
    return model
|
|
|
|
|
|
|
|
|
|
|
|
def compile_lstm(embeddings, shape, settings):
    """Assemble and compile the Keras sentiment model.

    Layers, in order:
    1. a frozen Embedding initialised from ``embeddings``, with index 0 masked;
    2. a per-token bias-free Dense projection;
    3. a bidirectional LSTM;
    4. a sigmoid Dense output trained with binary cross-entropy.

    embeddings (numpy.ndarray): embedding matrix, one row per vector id.
    shape (dict): ``max_length``, ``nr_hidden`` and ``nr_class``.
    settings (dict): ``dropout`` (used for both input and recurrent dropout)
        and ``lr`` (the Adam learning rate).
    RETURNS: the compiled ``Sequential`` model.
    """
    n_rows = embeddings.shape[0]
    n_dims = embeddings.shape[1]
    hidden = shape["nr_hidden"]
    drop = settings["dropout"]

    model = Sequential()
    embed = Embedding(
        n_rows,
        n_dims,
        input_length=shape["max_length"],
        trainable=False,
        weights=[embeddings],
        mask_zero=True,
    )
    model.add(embed)
    project = TimeDistributed(Dense(hidden, use_bias=False))
    model.add(project)
    encode = Bidirectional(LSTM(hidden, recurrent_dropout=drop, dropout=drop))
    model.add(encode)
    classify = Dense(shape["nr_class"], activation="sigmoid")
    model.add(classify)

    model.compile(
        optimizer=Adam(lr=settings["lr"]),
        loss="binary_crossentropy",
        metrics=["accuracy"],
    )
    return model
|
|
|
|
|
|
|
|
|
|
|
|
def get_embeddings(vocab):
    """Return the vocab's word-vector table (``vocab.vectors.data``).

    Rows are addressed by the ids that ``vocab.vectors.find`` returns, which
    are the ids ``get_features`` writes, so the table can serve directly as
    Embedding weights.
    """
    vector_table = vocab.vectors
    return vector_table.data
|
2016-10-19 15:43:13 +03:00
|
|
|
|
|
|
|
|
2016-10-24 00:17:41 +03:00
|
|
|
def evaluate(model_dir, texts, labels, max_length=100):
    """Return the accuracy of a saved sentiment model on labelled texts.

    Builds a pipeline from ``en_vectors_web_lg`` plus a sentencizer and the
    ``SentimentAnalyser`` loaded from ``model_dir``. A doc is predicted
    positive when ``doc.sentiment >= 0.5``.

    model_dir (pathlib.Path): directory written by ``main``.
    texts (iterable): raw texts.
    labels (iterable): gold labels aligned with ``texts`` (truthy = positive).
    max_length (int): token window passed to the analyser.
    RETURNS (float): fraction of correctly classified texts.
    RAISES (ValueError): if there is no text/label pair to evaluate.
    """
    nlp = spacy.load("en_vectors_web_lg")
    nlp.add_pipe(nlp.create_pipe("sentencizer"))
    nlp.add_pipe(SentimentAnalyser.load(model_dir, nlp, max_length=max_length))

    correct = 0
    total = 0
    docs = nlp.pipe(texts, batch_size=1000, n_threads=4)
    for doc, label in zip(docs, labels):
        # NOTE(review): the 0.5 threshold treats doc.sentiment as a score in
        # [0, 1], which is the scale SentimentAnalyser.__call__ writes. nlp.pipe
        # presumably routes through SentimentAnalyser.pipe, so confirm that
        # method writes the same scale.
        correct += bool(doc.sentiment >= 0.5) == bool(label)
        total += 1
    if total == 0:
        raise ValueError("evaluate() needs at least one text/label pair")
    return float(correct) / total
|
2016-10-19 15:43:13 +03:00
|
|
|
|
|
|
|
|
|
|
|
def read_data(data_dir, limit=0):
    """Read a labelled review corpus laid out as ``data_dir/{pos,neg}/*``.

    data_dir (unicode / pathlib.Path): directory with ``pos`` and ``neg``
        subdirectories; every file in them is one review, read as UTF-8.
    limit (int): if >= 1, keep only this many examples after shuffling.
    RETURNS (tuple): ``(texts, labels)`` as two aligned lists; label 1 means
        the file came from ``pos``, 0 from ``neg``.
    """
    # Command-line arguments arrive as plain strings; "/" needs a Path.
    data_dir = pathlib.Path(data_dir)
    examples = []
    for subdir, label in (("pos", 1), ("neg", 0)):
        for filename in (data_dir / subdir).iterdir():
            # Explicit encoding: don't depend on the platform default.
            with filename.open(encoding="utf8") as file_:
                text = file_.read()
            examples.append((text, label))
    random.shuffle(examples)
    if limit >= 1:
        examples = examples[:limit]
    # Build both lists explicitly. zip(*examples) yields nothing for an empty
    # corpus, which would break ``texts, labels = read_data(...)``.
    texts = [text for text, _ in examples]
    labels = [label for _, label in examples]
    return texts, labels
|
2016-10-19 15:43:13 +03:00
|
|
|
|
|
|
|
|
|
|
|
@plac.annotations(
    train_dir=("Location of training file or directory",),
    dev_dir=("Location of development file or directory",),
    model_dir=("Location of output model directory",),
    is_runtime=("Demonstrate run-time usage", "flag", "r", bool),
    nr_hidden=("Number of hidden units", "option", "H", int),
    max_length=("Maximum sentence length", "option", "L", int),
    dropout=("Dropout", "option", "d", float),
    learn_rate=("Learn rate", "option", "e", float),
    nb_epoch=("Number of training epochs", "option", "i", int),
    batch_size=("Size of minibatches for training LSTM", "option", "b", int),
    nr_examples=("Limit to N examples", "option", "n", int),
)
def main(
    model_dir=None,
    train_dir=None,
    dev_dir=None,
    is_runtime=False,
    nr_hidden=64,
    max_length=100,  # Shape
    dropout=0.5,
    learn_rate=0.001,  # General NN config
    nb_epoch=5,
    batch_size=256,
    nr_examples=-1,
):  # Training params
    """Train the LSTM sentiment model, or evaluate a saved one with ``-r``.

    Any corpus directory that is not given is replaced by the IMDB data from
    ``thinc.extra.datasets.imdb()``. In training mode, the non-embedding
    weights are pickled to ``model_dir/model`` and the architecture is
    written to ``model_dir/config.json``.
    """
    # plac passes paths as strings; the helpers below need Path objects.
    if model_dir is not None:
        model_dir = pathlib.Path(model_dir)
    if train_dir is not None:
        train_dir = pathlib.Path(train_dir)
    if dev_dir is not None:
        dev_dir = pathlib.Path(dev_dir)
    # Only fetch the IMDB dataset if some split actually comes from it.
    needs_imdb = dev_dir is None or (not is_runtime and train_dir is None)
    if needs_imdb:
        imdb_data = thinc.extra.datasets.imdb()
    if is_runtime:
        if model_dir is None:
            raise ValueError("Run-time mode (-r) requires model_dir")
        if dev_dir is None:
            dev_texts, dev_labels = zip(*imdb_data[1])
        else:
            dev_texts, dev_labels = read_data(dev_dir)
        acc = evaluate(model_dir, dev_texts, dev_labels, max_length=max_length)
        print(acc)
    else:
        if train_dir is None:
            train_texts, train_labels = zip(*imdb_data[0])
        else:
            print("Read data")
            train_texts, train_labels = read_data(train_dir, limit=nr_examples)
        if dev_dir is None:
            dev_texts, dev_labels = zip(*imdb_data[1])
        else:
            # Previously imdb_data was passed as a second positional argument,
            # which clashed with limit= (TypeError).
            dev_texts, dev_labels = read_data(dev_dir, limit=nr_examples)
        train_labels = numpy.asarray(train_labels, dtype="int32")
        dev_labels = numpy.asarray(dev_labels, dtype="int32")
        lstm = train(
            train_texts,
            train_labels,
            dev_texts,
            dev_labels,
            {"nr_hidden": nr_hidden, "max_length": max_length, "nr_class": 1},
            {"dropout": dropout, "lr": learn_rate},
            {},
            nb_epoch=nb_epoch,
            batch_size=batch_size,
        )
        weights = lstm.get_weights()
        if model_dir is not None:
            if not model_dir.exists():
                model_dir.mkdir(parents=True)
            # weights[0] is the embedding matrix; SentimentAnalyser.load
            # rebuilds it from the vocab, so only the rest is saved.
            with (model_dir / "model").open("wb") as file_:
                pickle.dump(weights[1:], file_)
            with (model_dir / "config.json").open("w") as file_:
                file_.write(lstm.to_json())
|
2016-10-19 15:43:13 +03:00
|
|
|
|
|
|
|
|
2018-12-02 06:26:26 +03:00
|
|
|
if __name__ == "__main__":
    # Command-line entry point: plac reads main()'s @plac.annotations to
    # build the argument parser, then calls main with the parsed values.
    plac.call(main)
|