Merge remote-tracking branch 'upstream/develop' into indonesian

This commit is contained in:
Jim Geovedi 2017-08-09 09:17:46 +07:00
commit c62b49b7cc
9 changed files with 167 additions and 102 deletions

View File

@ -5,10 +5,12 @@ from thinc.neural._classes.hash_embed import HashEmbed
from thinc.neural.ops import NumpyOps, CupyOps from thinc.neural.ops import NumpyOps, CupyOps
from thinc.neural.util import get_array_module from thinc.neural.util import get_array_module
import random import random
import cytoolz
from thinc.neural._classes.convolution import ExtractWindow from thinc.neural._classes.convolution import ExtractWindow
from thinc.neural._classes.static_vectors import StaticVectors from thinc.neural._classes.static_vectors import StaticVectors
from thinc.neural._classes.batchnorm import BatchNorm from thinc.neural._classes.batchnorm import BatchNorm
from thinc.neural._classes.layernorm import LayerNorm as LN
from thinc.neural._classes.resnet import Residual from thinc.neural._classes.resnet import Residual
from thinc.neural import ReLu from thinc.neural import ReLu
from thinc.neural._classes.selu import SELU from thinc.neural._classes.selu import SELU
@ -19,7 +21,7 @@ from thinc.api import FeatureExtracter, with_getitem
from thinc.neural.pooling import Pooling, max_pool, mean_pool, sum_pool from thinc.neural.pooling import Pooling, max_pool, mean_pool, sum_pool
from thinc.neural._classes.attention import ParametricAttention from thinc.neural._classes.attention import ParametricAttention
from thinc.linear.linear import LinearModel from thinc.linear.linear import LinearModel
from thinc.api import uniqued, wrap from thinc.api import uniqued, wrap, flatten_add_lengths
from .attrs import ID, ORTH, LOWER, NORM, PREFIX, SUFFIX, SHAPE, TAG, DEP from .attrs import ID, ORTH, LOWER, NORM, PREFIX, SUFFIX, SHAPE, TAG, DEP
from .tokens.doc import Doc from .tokens.doc import Doc
@ -53,6 +55,27 @@ def _logistic(X, drop=0.):
return Y, logistic_bwd return Y, logistic_bwd
@layerize
def add_tuples(X, drop=0.):
"""Give inputs of sequence pairs, where each sequence is (vals, length),
sum the values, returning a single sequence.
If input is:
((vals1, length), (vals2, length)
Output is:
(vals1+vals2, length)
vals are a single tensor for the whole batch.
"""
(vals1, length1), (vals2, length2) = X
assert length1 == length2
def add_tuples_bwd(dY, sgd=None):
return (dY, dY)
return (vals1+vals2, length), add_tuples_bwd
def _zero_init(model): def _zero_init(model):
def _zero_init_impl(self, X, y): def _zero_init_impl(self, X, y):
self.W.fill(0) self.W.fill(0)
@ -61,6 +84,7 @@ def _zero_init(model):
model.W.fill(0.) model.W.fill(0.)
return model return model
@layerize @layerize
def _preprocess_doc(docs, drop=0.): def _preprocess_doc(docs, drop=0.):
keys = [doc.to_array([LOWER]) for doc in docs] keys = [doc.to_array([LOWER]) for doc in docs]
@ -72,7 +96,6 @@ def _preprocess_doc(docs, drop=0.):
return (keys, vals, lengths), None return (keys, vals, lengths), None
def _init_for_precomputed(W, ops): def _init_for_precomputed(W, ops):
if (W**2).sum() != 0.: if (W**2).sum() != 0.:
return return
@ -80,6 +103,7 @@ def _init_for_precomputed(W, ops):
ops.xavier_uniform_init(reshaped) ops.xavier_uniform_init(reshaped)
W[:] = reshaped.reshape(W.shape) W[:] = reshaped.reshape(W.shape)
@describe.on_data(_set_dimensions_if_needed) @describe.on_data(_set_dimensions_if_needed)
@describe.attributes( @describe.attributes(
nI=Dimension("Input size"), nI=Dimension("Input size"),
@ -185,9 +209,9 @@ class PrecomputableMaxouts(Model):
def Tok2Vec(width, embed_size, preprocess=None): def Tok2Vec(width, embed_size, preprocess=None):
cols = [ID, NORM, PREFIX, SUFFIX, SHAPE] cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH]
with Model.define_operators({'>>': chain, '|': concatenate, '**': clone, '+': add}): with Model.define_operators({'>>': chain, '|': concatenate, '**': clone, '+': add}):
norm = get_col(cols.index(NORM)) >> HashEmbed(width, embed_size, name='embed_lower') norm = get_col(cols.index(NORM)) >> HashEmbed(width, embed_size, name='embed_lower')
prefix = get_col(cols.index(PREFIX)) >> HashEmbed(width, embed_size//2, name='embed_prefix') prefix = get_col(cols.index(PREFIX)) >> HashEmbed(width, embed_size//2, name='embed_prefix')
suffix = get_col(cols.index(SUFFIX)) >> HashEmbed(width, embed_size//2, name='embed_suffix') suffix = get_col(cols.index(SUFFIX)) >> HashEmbed(width, embed_size//2, name='embed_suffix')
shape = get_col(cols.index(SHAPE)) >> HashEmbed(width, embed_size//2, name='embed_shape') shape = get_col(cols.index(SHAPE)) >> HashEmbed(width, embed_size//2, name='embed_shape')
@ -196,13 +220,13 @@ def Tok2Vec(width, embed_size, preprocess=None):
tok2vec = ( tok2vec = (
with_flatten( with_flatten(
asarray(Model.ops, dtype='uint64') asarray(Model.ops, dtype='uint64')
>> embed >> uniqued(embed, column=5)
>> Maxout(width, width*4, pieces=3) >> LN(Maxout(width, width*4, pieces=3))
>> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3)) >> Residual(ExtractWindow(nW=1) >> LN(Maxout(width, width*3)))
>> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3)) >> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3))
>> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3)) >> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3))
>> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3)), >> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3)),
pad=4) pad=4)
) )
if preprocess not in (False, None): if preprocess not in (False, None):
tok2vec = preprocess >> tok2vec tok2vec = preprocess >> tok2vec
@ -297,7 +321,7 @@ def zero_init(model):
def doc2feats(cols=None): def doc2feats(cols=None):
cols = [ID, NORM, PREFIX, SUFFIX, SHAPE] cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH]
def forward(docs, drop=0.): def forward(docs, drop=0.):
feats = [] feats = []
for doc in docs: for doc in docs:
@ -323,6 +347,36 @@ def get_token_vectors(tokens_attrs_vectors, drop=0.):
return vectors, backward return vectors, backward
def fine_tune(embedding, combine=None):
if combine is not None:
raise NotImplementedError(
"fine_tune currently only supports addition. Set combine=None")
def fine_tune_fwd(docs_tokvecs, drop=0.):
docs, tokvecs = docs_tokvecs
lengths = model.ops.asarray([len(doc) for doc in docs], dtype='i')
vecs, bp_vecs = embedding.begin_update(docs, drop=drop)
flat_tokvecs = embedding.ops.flatten(tokvecs)
flat_vecs = embedding.ops.flatten(vecs)
output = embedding.ops.unflatten(
(model.mix[0] * flat_vecs + model.mix[1] * flat_tokvecs),
lengths)
def fine_tune_bwd(d_output, sgd=None):
bp_vecs(d_output, sgd=sgd)
flat_grad = model.ops.flatten(d_output)
model.d_mix[1] += flat_tokvecs.dot(flat_grad.T).sum()
model.d_mix[0] += flat_vecs.dot(flat_grad.T).sum()
sgd(model._mem.weights, model._mem.gradient, key=model.id)
return d_output
return output, fine_tune_bwd
model = wrap(fine_tune_fwd, embedding)
model.mix = model._mem.add((model.id, 'mix'), (2,))
model.mix.fill(1.)
model.d_mix = model._mem.add_gradient((model.id, 'd_mix'), (model.id, 'mix'))
return model
@layerize @layerize
def flatten(seqs, drop=0.): def flatten(seqs, drop=0.):
if isinstance(seqs[0], numpy.ndarray): if isinstance(seqs[0], numpy.ndarray):
@ -369,6 +423,26 @@ def preprocess_doc(docs, drop=0.):
vals = ops.allocate(keys.shape[0]) + 1 vals = ops.allocate(keys.shape[0]) + 1
return (keys, vals, lengths), None return (keys, vals, lengths), None
def getitem(i):
def getitem_fwd(X, drop=0.):
return X[i], None
return layerize(getitem_fwd)
def build_tagger_model(nr_class, token_vector_width, **cfg):
with Model.define_operators({'>>': chain, '+': add}):
# Input: (doc, tensor) tuples
private_tok2vec = Tok2Vec(token_vector_width, 7500, preprocess=doc2feats())
model = (
fine_tune(private_tok2vec)
>> with_flatten(
Maxout(token_vector_width, token_vector_width)
>> Softmax(nr_class, token_vector_width)
)
)
model.nI = None
return model
def build_text_classifier(nr_class, width=64, **cfg): def build_text_classifier(nr_class, width=64, **cfg):
nr_vector = cfg.get('nr_vector', 200) nr_vector = cfg.get('nr_vector', 200)
@ -383,7 +457,7 @@ def build_text_classifier(nr_class, width=64, **cfg):
>> _flatten_add_lengths >> _flatten_add_lengths
>> with_getitem(0, >> with_getitem(0,
uniqued( uniqued(
(embed_lower | embed_prefix | embed_suffix | embed_shape) (embed_lower | embed_prefix | embed_suffix | embed_shape)
>> Maxout(width, width+(width//2)*3)) >> Maxout(width, width+(width//2)*3))
>> Residual(ExtractWindow(nW=1) >> ReLu(width, width*3)) >> Residual(ExtractWindow(nW=1) >> ReLu(width, width*3))
>> Residual(ExtractWindow(nW=1) >> ReLu(width, width*3)) >> Residual(ExtractWindow(nW=1) >> ReLu(width, width*3))
@ -404,7 +478,7 @@ def build_text_classifier(nr_class, width=64, **cfg):
>> zero_init(Affine(nr_class, nr_class*2, drop_factor=0.0)) >> zero_init(Affine(nr_class, nr_class*2, drop_factor=0.0))
>> logistic >> logistic
) )
model.lsuv = False model.lsuv = False
return model return model

View File

@ -3,7 +3,7 @@
# https://github.com/pypa/warehouse/blob/master/warehouse/__about__.py # https://github.com/pypa/warehouse/blob/master/warehouse/__about__.py
__title__ = 'spacy-nightly' __title__ = 'spacy-nightly'
__version__ = '2.0.0a6' __version__ = '2.0.0a7'
__summary__ = 'Industrial-strength Natural Language Processing (NLP) with Python and Cython' __summary__ = 'Industrial-strength Natural Language Processing (NLP) with Python and Cython'
__uri__ = 'https://spacy.io' __uri__ = 'https://spacy.io'
__author__ = 'Explosion AI' __author__ = 'Explosion AI'

View File

@ -91,7 +91,8 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
for batch in minibatch(train_docs, size=batch_sizes): for batch in minibatch(train_docs, size=batch_sizes):
docs, golds = zip(*batch) docs, golds = zip(*batch)
nlp.update(docs, golds, sgd=optimizer, nlp.update(docs, golds, sgd=optimizer,
drop=next(dropout_rates), losses=losses) drop=next(dropout_rates), losses=losses,
update_tensors=True)
pbar.update(sum(len(doc) for doc in docs)) pbar.update(sum(len(doc) for doc in docs))
with nlp.use_params(optimizer.averages): with nlp.use_params(optimizer.averages):

View File

@ -277,7 +277,8 @@ class Language(object):
def make_doc(self, text): def make_doc(self, text):
return self.tokenizer(text) return self.tokenizer(text)
def update(self, docs, golds, drop=0., sgd=None, losses=None): def update(self, docs, golds, drop=0., sgd=None, losses=None,
update_tensors=False):
"""Update the models in the pipeline. """Update the models in the pipeline.
docs (iterable): A batch of `Doc` objects. docs (iterable): A batch of `Doc` objects.
@ -310,7 +311,7 @@ class Language(object):
tokvecses, bp_tokvecses = tok2vec.model.begin_update(feats, drop=drop) tokvecses, bp_tokvecses = tok2vec.model.begin_update(feats, drop=drop)
d_tokvecses = proc.update((docs, tokvecses), golds, d_tokvecses = proc.update((docs, tokvecses), golds,
drop=drop, sgd=get_grads, losses=losses) drop=drop, sgd=get_grads, losses=losses)
if d_tokvecses is not None: if update_tensors and d_tokvecses is not None:
bp_tokvecses(d_tokvecses, sgd=sgd) bp_tokvecses(d_tokvecses, sgd=sgd)
for key, (W, dW) in grads.items(): for key, (W, dW) in grads.items():
sgd(W, dW, key=key) sgd(W, dW, key=key)
@ -381,9 +382,18 @@ class Language(object):
return optimizer return optimizer
def evaluate(self, docs_golds): def evaluate(self, docs_golds):
docs, golds = zip(*docs_golds)
scorer = Scorer() scorer = Scorer()
for doc, gold in zip(self.pipe(docs, batch_size=32), golds): docs, golds = zip(*docs_golds)
docs = list(docs)
golds = list(golds)
for pipe in self.pipeline:
if not hasattr(pipe, 'pipe'):
for doc in docs:
pipe(doc)
else:
docs = list(pipe.pipe(docs))
assert len(docs) == len(golds)
for doc, gold in zip(docs, golds):
scorer.score(doc, gold) scorer.score(doc, gold)
doc.tensor = None doc.tensor = None
return scorer return scorer

View File

@ -42,7 +42,7 @@ from .compat import json_dumps
from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP, POS from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP, POS
from ._ml import rebatch, Tok2Vec, flatten, get_col, doc2feats from ._ml import rebatch, Tok2Vec, flatten, get_col, doc2feats
from ._ml import build_text_classifier from ._ml import build_text_classifier, build_tagger_model
from .parts_of_speech import X from .parts_of_speech import X
@ -253,23 +253,25 @@ class NeuralTagger(BaseThincComponent):
self.cfg = dict(cfg) self.cfg = dict(cfg)
def __call__(self, doc): def __call__(self, doc):
tags = self.predict([doc.tensor]) tags = self.predict(([doc], [doc.tensor]))
self.set_annotations([doc], tags) self.set_annotations([doc], tags)
return doc return doc
def pipe(self, stream, batch_size=128, n_threads=-1): def pipe(self, stream, batch_size=128, n_threads=-1):
for docs in cytoolz.partition_all(batch_size, stream): for docs in cytoolz.partition_all(batch_size, stream):
docs = list(docs)
tokvecs = [d.tensor for d in docs] tokvecs = [d.tensor for d in docs]
tag_ids = self.predict(tokvecs) tag_ids = self.predict((docs, tokvecs))
self.set_annotations(docs, tag_ids) self.set_annotations(docs, tag_ids)
yield from docs yield from docs
def predict(self, tokvecs): def predict(self, docs_tokvecs):
scores = self.model(tokvecs) scores = self.model(docs_tokvecs)
scores = self.model.ops.flatten(scores) scores = self.model.ops.flatten(scores)
guesses = scores.argmax(axis=1) guesses = scores.argmax(axis=1)
if not isinstance(guesses, numpy.ndarray): if not isinstance(guesses, numpy.ndarray):
guesses = guesses.get() guesses = guesses.get()
tokvecs = docs_tokvecs[1]
guesses = self.model.ops.unflatten(guesses, guesses = self.model.ops.unflatten(guesses,
[tv.shape[0] for tv in tokvecs]) [tv.shape[0] for tv in tokvecs])
return guesses return guesses
@ -294,8 +296,7 @@ class NeuralTagger(BaseThincComponent):
if self.model.nI is None: if self.model.nI is None:
self.model.nI = tokvecs[0].shape[1] self.model.nI = tokvecs[0].shape[1]
tag_scores, bp_tag_scores = self.model.begin_update(docs_tokvecs, drop=drop)
tag_scores, bp_tag_scores = self.model.begin_update(tokvecs, drop=drop)
loss, d_tag_scores = self.get_loss(docs, golds, tag_scores) loss, d_tag_scores = self.get_loss(docs, golds, tag_scores)
d_tokvecs = bp_tag_scores(d_tag_scores, sgd=sgd) d_tokvecs = bp_tag_scores(d_tag_scores, sgd=sgd)
@ -346,10 +347,8 @@ class NeuralTagger(BaseThincComponent):
@classmethod @classmethod
def Model(cls, n_tags, token_vector_width): def Model(cls, n_tags, token_vector_width):
return with_flatten( return build_tagger_model(n_tags, token_vector_width)
chain(Maxout(token_vector_width, token_vector_width),
Softmax(n_tags, token_vector_width)))
def use_params(self, params): def use_params(self, params):
with self.model.use_params(params): with self.model.use_params(params):
yield yield
@ -432,7 +431,7 @@ class NeuralLabeller(NeuralTagger):
@property @property
def labels(self): def labels(self):
return self.cfg.get('labels', {}) return self.cfg.setdefault('labels', {})
@labels.setter @labels.setter
def labels(self, value): def labels(self, value):
@ -455,10 +454,8 @@ class NeuralLabeller(NeuralTagger):
@classmethod @classmethod
def Model(cls, n_tags, token_vector_width): def Model(cls, n_tags, token_vector_width):
return with_flatten( return build_tagger_model(n_tags, token_vector_width)
chain(Maxout(token_vector_width, token_vector_width),
Softmax(n_tags, token_vector_width)))
def get_loss(self, docs, golds, scores): def get_loss(self, docs, golds, scores):
scores = self.model.ops.flatten(scores) scores = self.model.ops.flatten(scores)
cdef int idx = 0 cdef int idx = 0

View File

@ -385,6 +385,7 @@ cdef class ArcEager(TransitionSystem):
for i in range(self.n_moves): for i in range(self.n_moves):
if self.c[i].move == move and self.c[i].label == label: if self.c[i].move == move and self.c[i].label == label:
return self.c[i] return self.c[i]
return Transition(clas=0, move=MISSING, label=0)
def move_name(self, int move, attr_t label): def move_name(self, int move, attr_t label):
label_str = self.strings[label] label_str = self.strings[label]

View File

@ -14,8 +14,4 @@ cdef class Parser:
cdef readonly TransitionSystem moves cdef readonly TransitionSystem moves
cdef readonly object cfg cdef readonly object cfg
cdef void _parse_step(self, StateC* state,
const float* feat_weights,
int nr_class, int nr_feat, int nr_piece) nogil
#cdef int parseC(self, TokenC* tokens, int length, int nr_feat) nogil #cdef int parseC(self, TokenC* tokens, int length, int nr_feat) nogil

View File

@ -44,7 +44,7 @@ from thinc.neural.util import get_array_module
from .. import util from .. import util
from ..util import get_async, get_cuda_stream from ..util import get_async, get_cuda_stream
from .._ml import zero_init, PrecomputableAffine, PrecomputableMaxouts from .._ml import zero_init, PrecomputableAffine, PrecomputableMaxouts
from .._ml import Tok2Vec, doc2feats, rebatch from .._ml import Tok2Vec, doc2feats, rebatch, fine_tune
from ..compat import json_dumps from ..compat import json_dumps
from . import _parse_features from . import _parse_features
@ -237,6 +237,7 @@ cdef class Parser:
token_vector_width = util.env_opt('token_vector_width', token_vector_width) token_vector_width = util.env_opt('token_vector_width', token_vector_width)
hidden_width = util.env_opt('hidden_width', hidden_width) hidden_width = util.env_opt('hidden_width', hidden_width)
parser_maxout_pieces = util.env_opt('parser_maxout_pieces', 2) parser_maxout_pieces = util.env_opt('parser_maxout_pieces', 2)
tensors = fine_tune(Tok2Vec(token_vector_width, 7500, preprocess=doc2feats()))
if parser_maxout_pieces == 1: if parser_maxout_pieces == 1:
lower = PrecomputableAffine(hidden_width if depth >= 1 else nr_class, lower = PrecomputableAffine(hidden_width if depth >= 1 else nr_class,
nF=cls.nr_feature, nF=cls.nr_feature,
@ -248,15 +249,10 @@ cdef class Parser:
nI=token_vector_width) nI=token_vector_width)
with Model.use_device('cpu'): with Model.use_device('cpu'):
if depth == 0: upper = chain(
upper = chain() clone(Maxout(hidden_width), (depth-1)),
upper.is_noop = True zero_init(Affine(nr_class, drop_factor=0.0))
else: )
upper = chain(
clone(Maxout(hidden_width), (depth-1)),
zero_init(Affine(nr_class, drop_factor=0.0))
)
upper.is_noop = False
# TODO: This is an unfortunate hack atm! # TODO: This is an unfortunate hack atm!
# Used to set input dimensions in network. # Used to set input dimensions in network.
lower.begin_training(lower.ops.allocate((500, token_vector_width))) lower.begin_training(lower.ops.allocate((500, token_vector_width)))
@ -268,7 +264,7 @@ cdef class Parser:
'hidden_width': hidden_width, 'hidden_width': hidden_width,
'maxout_pieces': parser_maxout_pieces 'maxout_pieces': parser_maxout_pieces
} }
return (lower, upper), cfg return (tensors, lower, upper), cfg
def __init__(self, Vocab vocab, moves=True, model=True, **cfg): def __init__(self, Vocab vocab, moves=True, model=True, **cfg):
""" """
@ -344,12 +340,10 @@ cdef class Parser:
The number of threads with which to work on the buffer in parallel. The number of threads with which to work on the buffer in parallel.
Yields (Doc): Documents, in order. Yields (Doc): Documents, in order.
""" """
cdef StateClass parse_state
cdef Doc doc cdef Doc doc
queue = []
for docs in cytoolz.partition_all(batch_size, docs): for docs in cytoolz.partition_all(batch_size, docs):
docs = list(docs) docs = list(docs)
tokvecs = [d.tensor for d in docs] tokvecs = [doc.tensor for doc in docs]
if beam_width == 1: if beam_width == 1:
parse_states = self.parse_batch(docs, tokvecs) parse_states = self.parse_batch(docs, tokvecs)
else: else:
@ -369,8 +363,11 @@ cdef class Parser:
int nr_class, nr_feat, nr_piece, nr_dim, nr_state int nr_class, nr_feat, nr_piece, nr_dim, nr_state
if isinstance(docs, Doc): if isinstance(docs, Doc):
docs = [docs] docs = [docs]
if isinstance(tokvecses, np.ndarray):
tokvecses = [tokvecses]
tokvecs = self.model[0].ops.flatten(tokvecses) tokvecs = self.model[0].ops.flatten(tokvecses)
tokvecs += self.model[0].ops.flatten(self.model[0]((docs, tokvecses)))
nr_state = len(docs) nr_state = len(docs)
nr_class = self.moves.n_moves nr_class = self.moves.n_moves
@ -394,27 +391,20 @@ cdef class Parser:
cdef np.ndarray scores cdef np.ndarray scores
c_token_ids = <int*>token_ids.data c_token_ids = <int*>token_ids.data
c_is_valid = <int*>is_valid.data c_is_valid = <int*>is_valid.data
cdef int has_hidden = not getattr(vec2scores, 'is_noop', False)
while not next_step.empty(): while not next_step.empty():
if not has_hidden: for i in range(next_step.size()):
for i in cython.parallel.prange( st = next_step[i]
next_step.size(), num_threads=6, nogil=True): st.set_context_tokens(&c_token_ids[i*nr_feat], nr_feat)
self._parse_step(next_step[i], self.moves.set_valid(&c_is_valid[i*nr_class], st)
feat_weights, nr_class, nr_feat, nr_piece)
else:
for i in range(next_step.size()):
st = next_step[i]
st.set_context_tokens(&c_token_ids[i*nr_feat], nr_feat)
self.moves.set_valid(&c_is_valid[i*nr_class], st)
vectors = state2vec(token_ids[:next_step.size()]) vectors = state2vec(token_ids[:next_step.size()])
scores = vec2scores(vectors) scores = vec2scores(vectors)
c_scores = <float*>scores.data c_scores = <float*>scores.data
for i in range(next_step.size()): for i in range(next_step.size()):
st = next_step[i] st = next_step[i]
guess = arg_max_if_valid( guess = arg_max_if_valid(
&c_scores[i*nr_class], &c_is_valid[i*nr_class], nr_class) &c_scores[i*nr_class], &c_is_valid[i*nr_class], nr_class)
action = self.moves.c[guess] action = self.moves.c[guess]
action.do(st, action.label) action.do(st, action.label)
this_step, next_step = next_step, this_step this_step, next_step = next_step, this_step
next_step.clear() next_step.clear()
for st in this_step: for st in this_step:
@ -429,6 +419,7 @@ cdef class Parser:
cdef int nr_class = self.moves.n_moves cdef int nr_class = self.moves.n_moves
cdef StateClass stcls, output cdef StateClass stcls, output
tokvecs = self.model[0].ops.flatten(tokvecses) tokvecs = self.model[0].ops.flatten(tokvecses)
tokvecs += self.model[0].ops.flatten(self.model[0]((docs, tokvecses)))
cuda_stream = get_cuda_stream() cuda_stream = get_cuda_stream()
state2vec, vec2scores = self.get_batch_model(len(docs), tokvecs, state2vec, vec2scores = self.get_batch_model(len(docs), tokvecs,
cuda_stream, 0.0) cuda_stream, 0.0)
@ -461,28 +452,6 @@ cdef class Parser:
beams.append(beam) beams.append(beam)
return beams return beams
cdef void _parse_step(self, StateC* state,
const float* feat_weights,
int nr_class, int nr_feat, int nr_piece) nogil:
'''This only works with no hidden layers -- fast but inaccurate'''
#for i in cython.parallel.prange(next_step.size(), num_threads=4, nogil=True):
# self._parse_step(next_step[i], feat_weights, nr_class, nr_feat)
token_ids = <int*>calloc(nr_feat, sizeof(int))
scores = <float*>calloc(nr_class * nr_piece, sizeof(float))
is_valid = <int*>calloc(nr_class, sizeof(int))
state.set_context_tokens(token_ids, nr_feat)
sum_state_features(scores,
feat_weights, token_ids, 1, nr_feat, nr_class * nr_piece)
self.moves.set_valid(is_valid, state)
guess = arg_maxout_if_valid(scores, is_valid, nr_class, nr_piece)
action = self.moves.c[guess]
action.do(state, action.label)
free(is_valid)
free(scores)
free(token_ids)
def update(self, docs_tokvecs, golds, drop=0., sgd=None, losses=None): def update(self, docs_tokvecs, golds, drop=0., sgd=None, losses=None):
if losses is not None and self.name not in losses: if losses is not None and self.name not in losses:
losses[self.name] = 0. losses[self.name] = 0.
@ -491,6 +460,9 @@ cdef class Parser:
if isinstance(docs, Doc) and isinstance(golds, GoldParse): if isinstance(docs, Doc) and isinstance(golds, GoldParse):
docs = [docs] docs = [docs]
golds = [golds] golds = [golds]
my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=0.)
my_tokvecs = self.model[0].ops.flatten(my_tokvecs)
tokvecs += my_tokvecs
cuda_stream = get_cuda_stream() cuda_stream = get_cuda_stream()
@ -540,7 +512,9 @@ cdef class Parser:
break break
self._make_updates(d_tokvecs, self._make_updates(d_tokvecs,
backprops, sgd, cuda_stream) backprops, sgd, cuda_stream)
return self.model[0].ops.unflatten(d_tokvecs, [len(d) for d in docs]) d_tokvecs = self.model[0].ops.unflatten(d_tokvecs, [len(d) for d in docs])
#bp_my_tokvecs(d_tokvecs, sgd=sgd)
return d_tokvecs
def _init_gold_batch(self, whole_docs, whole_golds): def _init_gold_batch(self, whole_docs, whole_golds):
"""Make a square batch, of length equal to the shortest doc. A long """Make a square batch, of length equal to the shortest doc. A long
@ -603,7 +577,7 @@ cdef class Parser:
return names return names
def get_batch_model(self, batch_size, tokvecs, stream, dropout): def get_batch_model(self, batch_size, tokvecs, stream, dropout):
lower, upper = self.model _, lower, upper = self.model
state2vec = precompute_hiddens(batch_size, tokvecs, state2vec = precompute_hiddens(batch_size, tokvecs,
lower, stream, drop=dropout) lower, stream, drop=dropout)
return state2vec, upper return state2vec, upper
@ -693,10 +667,12 @@ cdef class Parser:
def to_disk(self, path, **exclude): def to_disk(self, path, **exclude):
serializers = { serializers = {
'lower_model': lambda p: p.open('wb').write( 'tok2vec_model': lambda p: p.open('wb').write(
self.model[0].to_bytes()), self.model[0].to_bytes()),
'upper_model': lambda p: p.open('wb').write( 'lower_model': lambda p: p.open('wb').write(
self.model[1].to_bytes()), self.model[1].to_bytes()),
'upper_model': lambda p: p.open('wb').write(
self.model[2].to_bytes()),
'vocab': lambda p: self.vocab.to_disk(p), 'vocab': lambda p: self.vocab.to_disk(p),
'moves': lambda p: self.moves.to_disk(p, strings=False), 'moves': lambda p: self.moves.to_disk(p, strings=False),
'cfg': lambda p: p.open('w').write(json_dumps(self.cfg)) 'cfg': lambda p: p.open('w').write(json_dumps(self.cfg))
@ -717,24 +693,29 @@ cdef class Parser:
self.model, cfg = self.Model(**self.cfg) self.model, cfg = self.Model(**self.cfg)
else: else:
cfg = {} cfg = {}
with (path / 'lower_model').open('rb') as file_: with (path / 'tok2vec_model').open('rb') as file_:
bytes_data = file_.read() bytes_data = file_.read()
self.model[0].from_bytes(bytes_data) self.model[0].from_bytes(bytes_data)
with (path / 'upper_model').open('rb') as file_: with (path / 'lower_model').open('rb') as file_:
bytes_data = file_.read() bytes_data = file_.read()
self.model[1].from_bytes(bytes_data) self.model[1].from_bytes(bytes_data)
with (path / 'upper_model').open('rb') as file_:
bytes_data = file_.read()
self.model[2].from_bytes(bytes_data)
self.cfg.update(cfg) self.cfg.update(cfg)
return self return self
def to_bytes(self, **exclude): def to_bytes(self, **exclude):
serializers = OrderedDict(( serializers = OrderedDict((
('lower_model', lambda: self.model[0].to_bytes()), ('tok2vec_model', lambda: self.model[0].to_bytes()),
('upper_model', lambda: self.model[1].to_bytes()), ('lower_model', lambda: self.model[1].to_bytes()),
('upper_model', lambda: self.model[2].to_bytes()),
('vocab', lambda: self.vocab.to_bytes()), ('vocab', lambda: self.vocab.to_bytes()),
('moves', lambda: self.moves.to_bytes(strings=False)), ('moves', lambda: self.moves.to_bytes(strings=False)),
('cfg', lambda: ujson.dumps(self.cfg)) ('cfg', lambda: ujson.dumps(self.cfg))
)) ))
if 'model' in exclude: if 'model' in exclude:
exclude['tok2vec_model'] = True
exclude['lower_model'] = True exclude['lower_model'] = True
exclude['upper_model'] = True exclude['upper_model'] = True
exclude.pop('model') exclude.pop('model')
@ -745,6 +726,7 @@ cdef class Parser:
('vocab', lambda b: self.vocab.from_bytes(b)), ('vocab', lambda b: self.vocab.from_bytes(b)),
('moves', lambda b: self.moves.from_bytes(b, strings=False)), ('moves', lambda b: self.moves.from_bytes(b, strings=False)),
('cfg', lambda b: self.cfg.update(ujson.loads(b))), ('cfg', lambda b: self.cfg.update(ujson.loads(b))),
('tok2vec_model', lambda b: None),
('lower_model', lambda b: None), ('lower_model', lambda b: None),
('upper_model', lambda b: None) ('upper_model', lambda b: None)
)) ))
@ -754,10 +736,12 @@ cdef class Parser:
self.model, cfg = self.Model(self.moves.n_moves) self.model, cfg = self.Model(self.moves.n_moves)
else: else:
cfg = {} cfg = {}
if 'tok2vec_model' in msg:
self.model[0].from_bytes(msg['tok2vec_model'])
if 'lower_model' in msg: if 'lower_model' in msg:
self.model[0].from_bytes(msg['lower_model']) self.model[1].from_bytes(msg['lower_model'])
if 'upper_model' in msg: if 'upper_model' in msg:
self.model[1].from_bytes(msg['upper_model']) self.model[2].from_bytes(msg['upper_model'])
self.cfg.update(cfg) self.cfg.update(cfg)
return self return self

View File

@ -107,6 +107,8 @@ cdef class TransitionSystem:
def is_valid(self, StateClass stcls, move_name): def is_valid(self, StateClass stcls, move_name):
action = self.lookup_transition(move_name) action = self.lookup_transition(move_name)
if action.move == 0:
return False
return action.is_valid(stcls.c, action.label) return action.is_valid(stcls.c, action.label)
cdef int set_valid(self, int* is_valid, const StateC* st) nogil: cdef int set_valid(self, int* is_valid, const StateC* st) nogil: