Merge remote-tracking branch 'origin/develop' into feature/parser-history-model

This commit is contained in:
Matthew Honnibal 2017-10-03 16:56:42 -05:00
commit 246612cb53
12 changed files with 65 additions and 49 deletions

View File

@ -20,9 +20,10 @@ import plac
from pathlib import Path from pathlib import Path
import random import random
import json import json
import tqdm
from thinc.neural.optimizers import Adam from thinc.neural.optimizers import Adam
from thinc.neural.ops import NumpyOps from thinc.neural.ops import NumpyOps
import tqdm
from spacy.vocab import Vocab from spacy.vocab import Vocab
from spacy.pipeline import TokenVectorEncoder, NeuralEntityRecognizer from spacy.pipeline import TokenVectorEncoder, NeuralEntityRecognizer
@ -35,6 +36,7 @@ from spacy.gold import minibatch
from spacy.scorer import Scorer from spacy.scorer import Scorer
import spacy.util import spacy.util
try: try:
unicode unicode
except NameError: except NameError:
@ -55,20 +57,17 @@ def init_vocab():
class Pipeline(object): class Pipeline(object):
def __init__(self, vocab=None, tokenizer=None, tensorizer=None, entity=None): def __init__(self, vocab=None, tokenizer=None, entity=None):
if vocab is None: if vocab is None:
vocab = init_vocab() vocab = init_vocab()
if tokenizer is None: if tokenizer is None:
tokenizer = Tokenizer(vocab, {}, None, None, None) tokenizer = Tokenizer(vocab, {}, None, None, None)
if tensorizer is None:
tensorizer = TokenVectorEncoder(vocab)
if entity is None: if entity is None:
entity = NeuralEntityRecognizer(vocab) entity = NeuralEntityRecognizer(vocab)
self.vocab = vocab self.vocab = vocab
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.tensorizer = tensorizer
self.entity = entity self.entity = entity
self.pipeline = [tensorizer, self.entity] self.pipeline = [self.entity]
def begin_training(self): def begin_training(self):
for model in self.pipeline: for model in self.pipeline:
@ -102,10 +101,8 @@ class Pipeline(object):
golds = [self.make_gold(input_, annot) for input_, annot in golds = [self.make_gold(input_, annot) for input_, annot in
zip(inputs, annots)] zip(inputs, annots)]
tensors, bp_tensors = self.tensorizer.update(docs, golds, drop=drop) self.entity.update(docs, golds, drop=drop,
d_tensors = self.entity.update((docs, tensors), golds, drop=drop, sgd=sgd, losses=losses)
sgd=sgd, losses=losses)
bp_tensors(d_tensors, sgd=sgd)
return losses return losses
def evaluate(self, examples): def evaluate(self, examples):
@ -123,7 +120,6 @@ class Pipeline(object):
elif not path.is_dir(): elif not path.is_dir():
raise IOError("Can't save pipeline to %s\nNot a directory" % path) raise IOError("Can't save pipeline to %s\nNot a directory" % path)
self.vocab.to_disk(path / 'vocab') self.vocab.to_disk(path / 'vocab')
self.tensorizer.to_disk(path / 'tensorizer')
self.entity.to_disk(path / 'ner') self.entity.to_disk(path / 'ner')
def from_disk(self, path): def from_disk(self, path):
@ -133,7 +129,6 @@ class Pipeline(object):
if not path.is_dir(): if not path.is_dir():
raise IOError("Cannot load pipeline from %s\nNot a directory" % path) raise IOError("Cannot load pipeline from %s\nNot a directory" % path)
self.vocab = self.vocab.from_disk(path / 'vocab') self.vocab = self.vocab.from_disk(path / 'vocab')
self.tensorizer = self.tensorizer.from_disk(path / 'tensorizer')
self.entity = self.entity.from_disk(path / 'ner') self.entity = self.entity.from_disk(path / 'ner')

5
fabfile.py vendored
View File

@ -14,6 +14,7 @@ VENV_DIR = path.join(PWD, ENV)
def env(lang='python2.7'): def env(lang='python2.7'):
if path.exists(VENV_DIR): if path.exists(VENV_DIR):
local('rm -rf {env}'.format(env=VENV_DIR)) local('rm -rf {env}'.format(env=VENV_DIR))
local('pip install virtualenv')
local('python -m virtualenv -p {lang} {env}'.format(lang=lang, env=VENV_DIR)) local('python -m virtualenv -p {lang} {env}'.format(lang=lang, env=VENV_DIR))
@ -32,6 +33,10 @@ def make():
local('pip install -r requirements.txt') local('pip install -r requirements.txt')
local('python setup.py build_ext --inplace') local('python setup.py build_ext --inplace')
def sdist():
with virtualenv(VENV_DIR):
with lcd(path.dirname(__file__)):
local('python setup.py sdist')
def clean(): def clean():
with lcd(path.dirname(__file__)): with lcd(path.dirname(__file__)):

View File

@ -3,7 +3,7 @@ pathlib
numpy>=1.7 numpy>=1.7
cymem>=1.30,<1.32 cymem>=1.30,<1.32
preshed>=1.0.0,<2.0.0 preshed>=1.0.0,<2.0.0
thinc>=6.8.2,<6.9.0 thinc>=6.9.0,<6.10.0
murmurhash>=0.28,<0.29 murmurhash>=0.28,<0.29
plac<1.0.0,>=0.9.6 plac<1.0.0,>=0.9.6
six six

View File

@ -195,7 +195,7 @@ def setup_package():
'murmurhash>=0.28,<0.29', 'murmurhash>=0.28,<0.29',
'cymem>=1.30,<1.32', 'cymem>=1.30,<1.32',
'preshed>=1.0.0,<2.0.0', 'preshed>=1.0.0,<2.0.0',
'thinc>=6.8.2,<6.9.0', 'thinc>=6.9.0,<6.10.0',
'plac<1.0.0,>=0.9.6', 'plac<1.0.0,>=0.9.6',
'six', 'six',
'pathlib', 'pathlib',

View File

@ -4,11 +4,13 @@ from __future__ import unicode_literals
from .cli.info import info as cli_info from .cli.info import info as cli_info
from .glossary import explain from .glossary import explain
from .deprecated import resolve_load_name from .deprecated import resolve_load_name
#from .about import __version__
from .about import __version__ from .about import __version__
from . import util from . import util
def load(name, **overrides): def load(name, **overrides):
from .deprecated import resolve_load_name
name = resolve_load_name(name, **overrides) name = resolve_load_name(name, **overrides)
return util.load_model(name, **overrides) return util.load_model(name, **overrides)

View File

@ -1,29 +1,27 @@
import ujson import ujson
from thinc.v2v import Model, Maxout, Softmax, Affine, ReLu, SELU
from thinc.i2v import HashEmbed, StaticVectors
from thinc.t2t import ExtractWindow, ParametricAttention
from thinc.t2v import Pooling, max_pool, mean_pool, sum_pool
from thinc.misc import Residual
from thinc.misc import BatchNorm as BN
from thinc.misc import LayerNorm as LN
from thinc.api import add, layerize, chain, clone, concatenate, with_flatten from thinc.api import add, layerize, chain, clone, concatenate, with_flatten
from thinc.neural import Model, Maxout, Softmax, Affine from thinc.api import FeatureExtracter, with_getitem
from thinc.neural._classes.hash_embed import HashEmbed from thinc.api import uniqued, wrap, flatten_add_lengths, noop
from thinc.linear.linear import LinearModel
from thinc.neural.ops import NumpyOps, CupyOps from thinc.neural.ops import NumpyOps, CupyOps
from thinc.neural.util import get_array_module from thinc.neural.util import get_array_module
import thinc.extra.load_nlp
import random import random
import cytoolz import cytoolz
from thinc.neural._classes.convolution import ExtractWindow
from thinc.neural._classes.static_vectors import StaticVectors
from thinc.neural._classes.batchnorm import BatchNorm as BN
from thinc.neural._classes.layernorm import LayerNorm as LN
from thinc.neural._classes.resnet import Residual
from thinc.neural import ReLu
from thinc.neural._classes.selu import SELU
from thinc import describe from thinc import describe
from thinc.describe import Dimension, Synapses, Biases, Gradient from thinc.describe import Dimension, Synapses, Biases, Gradient
from thinc.neural._classes.affine import _set_dimensions_if_needed from thinc.neural._classes.affine import _set_dimensions_if_needed
from thinc.api import FeatureExtracter, with_getitem import thinc.extra.load_nlp
from thinc.neural.pooling import Pooling, max_pool, mean_pool, sum_pool
from thinc.neural._classes.attention import ParametricAttention
from thinc.linear.linear import LinearModel
from thinc.api import uniqued, wrap, flatten_add_lengths, noop
from .attrs import ID, ORTH, LOWER, NORM, PREFIX, SUFFIX, SHAPE, TAG, DEP, CLUSTER from .attrs import ID, ORTH, LOWER, NORM, PREFIX, SUFFIX, SHAPE, TAG, DEP, CLUSTER
from .tokens.doc import Doc from .tokens.doc import Doc
@ -32,6 +30,10 @@ from . import util
import numpy import numpy
import io import io
# TODO: Unset this once we don't want to support models previous models.
import thinc.neural._classes.layernorm
thinc.neural._classes.layernorm.set_compat_six_eight(True)
VECTORS_KEY = 'spacy_pretrained_vectors' VECTORS_KEY = 'spacy_pretrained_vectors'
@layerize @layerize

View File

@ -32,18 +32,25 @@ numpy.random.seed(0)
model=("Model name or path", "positional", None, str), model=("Model name or path", "positional", None, str),
data_path=("Location of JSON-formatted evaluation data", "positional", None, str), data_path=("Location of JSON-formatted evaluation data", "positional", None, str),
gold_preproc=("Use gold preprocessing", "flag", "G", bool), gold_preproc=("Use gold preprocessing", "flag", "G", bool),
gpu_id=("Use GPU", "option", "g", int),
) )
def evaluate(cmd, model, data_path, gold_preproc=False): def evaluate(cmd, model, data_path, gpu_id=-1, gold_preproc=False):
""" """
Train a model. Expects data in spaCy's JSON format. Train a model. Expects data in spaCy's JSON format.
""" """
util.set_env_log(True) util.use_gpu(gpu_id)
util.set_env_log(False)
data_path = util.ensure_path(data_path) data_path = util.ensure_path(data_path)
if not data_path.exists(): if not data_path.exists():
prints(data_path, title="Evaluation data not found", exits=1) prints(data_path, title="Evaluation data not found", exits=1)
corpus = GoldCorpus(data_path, data_path) corpus = GoldCorpus(data_path, data_path)
nlp = util.load_model(model) nlp = util.load_model(model)
scorer = nlp.evaluate(list(corpus.dev_docs(nlp, gold_preproc=gold_preproc))) dev_docs = list(corpus.dev_docs(nlp, gold_preproc=gold_preproc))
begin = timer()
scorer = nlp.evaluate(dev_docs, verbose=False)
end = timer()
nwords = sum(len(doc_gold[0]) for doc_gold in dev_docs)
print('Time', end-begin, 'words', nwords, 'w.p.s', nwords/(end-begin))
print_results(scorer) print_results(scorer)

View File

@ -388,7 +388,7 @@ class Language(object):
self._optimizer.device = device self._optimizer.device = device
return self._optimizer return self._optimizer
def evaluate(self, docs_golds): def evaluate(self, docs_golds, verbose=False):
scorer = Scorer() scorer = Scorer()
docs, golds = zip(*docs_golds) docs, golds = zip(*docs_golds)
docs = list(docs) docs = list(docs)
@ -401,7 +401,9 @@ class Language(object):
docs = list(pipe.pipe(docs)) docs = list(pipe.pipe(docs))
assert len(docs) == len(golds) assert len(docs) == len(golds)
for doc, gold in zip(docs, golds): for doc, gold in zip(docs, golds):
scorer.score(doc, gold) if verbose:
print(doc)
scorer.score(doc, gold, verbose=verbose)
return scorer return scorer
@contextmanager @contextmanager

View File

@ -4,7 +4,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from thinc.api import chain, layerize, with_getitem from thinc.api import chain, layerize, with_getitem
from thinc.neural import Model, Softmax
import numpy import numpy
cimport numpy as np cimport numpy as np
import cytoolz import cytoolz
@ -14,17 +13,18 @@ import ujson
import msgpack import msgpack
from thinc.api import add, layerize, chain, clone, concatenate, with_flatten from thinc.api import add, layerize, chain, clone, concatenate, with_flatten
from thinc.neural import Model, Maxout, Softmax, Affine from thinc.v2v import Model, Maxout, Softmax, Affine, ReLu, SELU
from thinc.neural._classes.hash_embed import HashEmbed from thinc.i2v import HashEmbed
from thinc.t2v import Pooling, max_pool, mean_pool, sum_pool
from thinc.t2t import ExtractWindow, ParametricAttention
from thinc.misc import Residual
from thinc.misc import BatchNorm as BN
from thinc.misc import LayerNorm as LN
from thinc.neural.util import to_categorical from thinc.neural.util import to_categorical
from thinc.neural.pooling import Pooling, max_pool, mean_pool
from thinc.neural._classes.difference import Siamese, CauchySimilarity from thinc.neural._classes.difference import Siamese, CauchySimilarity
from thinc.neural._classes.convolution import ExtractWindow
from thinc.neural._classes.resnet import Residual
from thinc.neural._classes.batchnorm import BatchNorm as BN
from .tokens.doc cimport Doc from .tokens.doc cimport Doc
from .syntax.parser cimport Parser as LinearParser from .syntax.parser cimport Parser as LinearParser
from .syntax.nn_parser cimport Parser as NeuralParser from .syntax.nn_parser cimport Parser as NeuralParser

View File

@ -38,10 +38,9 @@ from preshed.maps cimport MapStruct
from preshed.maps cimport map_get from preshed.maps cimport map_get
from thinc.api import layerize, chain, noop, clone, with_flatten from thinc.api import layerize, chain, noop, clone, with_flatten
from thinc.neural import Model, Affine, ReLu, Maxout from thinc.v2v import Model, Maxout, Softmax, Affine, ReLu, SELU
from thinc.neural._classes.batchnorm import BatchNorm as BN from thinc.misc import LayerNorm
from thinc.neural._classes.selu import SELU
from thinc.neural._classes.layernorm import LayerNorm
from thinc.neural.ops import NumpyOps, CupyOps from thinc.neural.ops import NumpyOps, CupyOps
from thinc.neural.util import get_array_module from thinc.neural.util import get_array_module

View File

@ -9,7 +9,8 @@ from .util import get_doc
from pathlib import Path from pathlib import Path
import pytest import pytest
from thinc.neural import Maxout, Softmax from thinc.neural._classes.maxout import Maxout
from thinc.neural._classes.softmax import Softmax
from thinc.api import chain from thinc.api import chain

View File

@ -563,7 +563,10 @@ def minify_html(html):
def use_gpu(gpu_id): def use_gpu(gpu_id):
import cupy.cuda.device try:
import cupy.cuda.device
except ImportError:
return None
from thinc.neural.ops import CupyOps from thinc.neural.ops import CupyOps
device = cupy.cuda.device.Device(gpu_id) device = cupy.cuda.device.Device(gpu_id)
device.use() device.use()