Mirror of https://github.com/explosion/spaCy.git (synced 2024-12-25 01:16:28 +03:00)
Merge pull request #1392 from explosion/feature/parser-history-model

💫 Parser history features

Commit eb0595bea9
spacy/_ml.py (68 changed lines)
@@ -32,7 +32,7 @@ import io

 # TODO: Unset this once we no longer want to support previous models.
 import thinc.neural._classes.layernorm
-thinc.neural._classes.layernorm.set_compat_six_eight(True)
+thinc.neural._classes.layernorm.set_compat_six_eight(False)

 VECTORS_KEY = 'spacy_pretrained_vectors'
@@ -213,6 +213,72 @@ class PrecomputableMaxouts(Model):
             return dXf
         return Yfp, backward

+# Thinc's Embed class is a bit broken atm, so drop this here.
+from thinc import describe
+from thinc.neural._classes.embed import _uniform_init
+
+
+@describe.attributes(
+    nV=describe.Dimension("Number of vectors"),
+    nO=describe.Dimension("Size of output"),
+    vectors=describe.Weights("Embedding table",
+        lambda obj: (obj.nV, obj.nO),
+        _uniform_init(-0.1, 0.1)
+    ),
+    d_vectors=describe.Gradient("vectors")
+)
+class Embed(Model):
+    name = 'embed'
+
+    def __init__(self, nO, nV=None, **kwargs):
+        if nV is not None:
+            nV += 1
+        Model.__init__(self, **kwargs)
+        if 'name' in kwargs:
+            self.name = kwargs['name']
+        self.column = kwargs.get('column', 0)
+        self.nO = nO
+        self.nV = nV
+
+    def predict(self, ids):
+        if ids.ndim == 2:
+            ids = ids[:, self.column]
+        return self.ops.xp.ascontiguousarray(self.vectors[ids], dtype='f')
+
+    def begin_update(self, ids, drop=0.):
+        if ids.ndim == 2:
+            ids = ids[:, self.column]
+        vectors = self.ops.xp.ascontiguousarray(self.vectors[ids], dtype='f')
+        def backprop_embed(d_vectors, sgd=None):
+            n_vectors = d_vectors.shape[0]
+            self.ops.scatter_add(self.d_vectors, ids, d_vectors)
+            if sgd is not None:
+                sgd(self._mem.weights, self._mem.gradient, key=self.id)
+            return None
+        return vectors, backprop_embed
+
+
+def HistoryFeatures(nr_class, hist_size=8, nr_dim=8):
+    '''Wrap a model, adding features representing action history.'''
+    if hist_size == 0:
+        return layerize(noop())
+    embed_tables = [Embed(nr_dim, nr_class, column=i, name='embed%d' % i)
+                    for i in range(hist_size)]
+    embed = concatenate(*embed_tables)
+    ops = embed.ops
+    def add_history_fwd(vectors_hists, drop=0.):
+        vectors, hist_ids = vectors_hists
+        hist_feats, bp_hists = embed.begin_update(hist_ids, drop=drop)
+        outputs = ops.xp.hstack((vectors, hist_feats))
+
+        def add_history_bwd(d_outputs, sgd=None):
+            d_vectors = d_outputs[:, :vectors.shape[1]]
+            d_hists = d_outputs[:, vectors.shape[1]:]
+            bp_hists(d_hists, sgd=sgd)
+            return embed.ops.xp.ascontiguousarray(d_vectors)
+        return outputs, add_history_bwd
+    return wrap(add_history_fwd, embed)
+
+
 def drop_layer(layer, factor=2.):
     def drop_layer_fwd(X, drop=0.):
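In sum, the additions above give _ml.py one embedding table per history slot plus a wrapper that concatenates those embeddings onto the state vector and splits the gradient back apart by column on the backward pass. Note Embed's nV += 1 and the act+1 encoding used by the parser state below: ID 0 stays free to mean "no action yet". A standalone numpy sketch of the forward/backward split (illustrative only, not spaCy's API):

import numpy

def history_concat_forward(vectors, hist_ids, tables):
    # one table per history slot: column i embeds the i-th most recent action
    hist_feats = numpy.hstack([tables[i][hist_ids[:, i]]
                               for i in range(len(tables))])
    outputs = numpy.hstack((vectors, hist_feats))

    def backward(d_outputs):
        # columns up to vectors.shape[1] flow back to the state vectors,
        # the rest to the history embedding tables
        d_vectors = d_outputs[:, :vectors.shape[1]]
        d_hists = d_outputs[:, vectors.shape[1]:]
        return d_vectors, d_hists
    return outputs, backward

nr_class, hist_size, nr_dim = 10, 4, 8
tables = [numpy.random.uniform(-0.1, 0.1, (nr_class + 1, nr_dim))
          for _ in range(hist_size)]
vectors = numpy.zeros((2, 16), dtype='f')          # batch of 2 parser states
hist_ids = numpy.zeros((2, hist_size), dtype='i')  # 0 = "no action yet"
outputs, backward = history_concat_forward(vectors, hist_ids, tables)
assert outputs.shape == (2, 16 + hist_size * nr_dim)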
spacy/cli/evaluate.py
@@ -42,6 +42,7 @@ def evaluate(cmd, model, data_path, gpu_id=-1, gold_preproc=False,
     Evaluate a model. To render a sample of parses in an HTML file, set an output
     directory as the displacy_path argument.
     """
     if gpu_id >= 0:
         util.use_gpu(gpu_id)
+    util.set_env_log(False)
     data_path = util.ensure_path(data_path)
spacy/syntax/_beam_utils.pyx
@@ -21,6 +21,7 @@ cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves)
     moves = <const Transition*>_moves
     dest.clone(src)
     moves[clas].do(dest.c, moves[clas].label)
+    dest.c.push_hist(clas)


 cdef int _check_final_state(void* _state, void* extra_args) except -1:
@@ -149,7 +150,7 @@ nr_update = 0
 def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
                 states, golds,
                 state2vec, vec2scores,
-                int width, float density,
+                int width, float density, int hist_feats,
                 losses=None, drop=0.):
     global nr_update
     cdef MaxViolation violn
@@ -180,6 +181,10 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
     # Now that we have our flat list of states, feed them through the model
     token_ids = get_token_ids(states, nr_feature)
     vectors, bp_vectors = state2vec.begin_update(token_ids, drop=drop)
-    scores, bp_scores = vec2scores.begin_update(vectors, drop=drop)
+    if hist_feats:
+        hists = numpy.asarray([st.history[:hist_feats] for st in states], dtype='i')
+        scores, bp_scores = vec2scores.begin_update((vectors, hists), drop=drop)
+    else:
+        scores, bp_scores = vec2scores.begin_update(vectors, drop=drop)

     # Store the callbacks for the backward pass
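When hist_feats is non-zero, update_beam packs each state's most recent actions into an (n_states, hist_feats) integer matrix and hands the model a (vectors, hists) tuple instead of the bare vectors. A toy version of that input contract (FakeState is a hypothetical stand-in for StateClass):

import numpy

class FakeState:
    # stand-in for StateClass: .history holds the last 8 action IDs,
    # most recent first, zero-padded (see the history property below)
    def __init__(self, history):
        self.history = numpy.asarray(history, dtype='i')

states = [FakeState([4, 2, 0, 0, 0, 0, 0, 0]),
          FakeState([3, 4, 2, 0, 0, 0, 0, 0])]
hist_feats = 4
hists = numpy.asarray([st.history[:hist_feats] for st in states], dtype='i')
assert hists.shape == (len(states), hist_feats)
# vec2scores.begin_update((vectors, hists), ...) then consumes the pair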
spacy/syntax/_state.pxd
@@ -1,4 +1,4 @@
-from libc.string cimport memcpy, memset
+from libc.string cimport memcpy, memset, memmove
 from libc.stdlib cimport malloc, calloc, free
 from libc.stdint cimport uint32_t, uint64_t
@@ -15,6 +15,23 @@ from ..typedefs cimport attr_t
 cdef inline bint is_space_token(const TokenC* token) nogil:
     return Lexeme.c_check_flag(token.lex, IS_SPACE)

+cdef struct RingBufferC:
+    int[8] data
+    int i
+    int default
+
+cdef inline int ring_push(RingBufferC* ring, int value) nogil:
+    ring.data[ring.i] = value
+    ring.i += 1
+    if ring.i >= 8:
+        ring.i = 0
+
+cdef inline int ring_get(RingBufferC* ring, int i) nogil:
+    if i >= ring.i:
+        return ring.default
+    else:
+        return ring.data[ring.i-i]
+

 cdef cppclass StateC:
     int* _stack
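RingBufferC is the whole memory for the history: eight ints, a write index that wraps, and a default value returned for out-of-range lookups. A pure-Python sketch of the same semantics (not spaCy API), including one quirk worth knowing: the `i >= ring.i` guard means entries at or beyond the current write index read back as `default`, even when older values are still physically in the array after a wrap:

class RingBuffer:
    def __init__(self, default=0):
        self.data = [0] * 8
        self.i = 0
        self.default = default

    def push(self, value):
        # mirror ring_push: write at the index, then wrap it
        self.data[self.i] = value
        self.i += 1
        if self.i >= 8:
            self.i = 0

    def get(self, i):
        # mirror ring_get: 1-based lookback, most recent first
        if i >= self.i:
            return self.default
        return self.data[self.i - i]

ring = RingBuffer()
for act in (5, 7, 9):
    ring.push(act + 1)       # StateC stores act+1, so 0 can mean "empty"
assert ring.get(1) == 10     # most recent action (9) + 1
assert ring.get(2) == 8      # second most recent (7) + 1
assert ring.get(3) == 0      # guard i >= self.i kicks in: falls back to default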
@@ -23,6 +40,7 @@ cdef cppclass StateC:
     TokenC* _sent
     Entity* _ents
     TokenC _empty_token
+    RingBufferC _hist
     int length
     int offset
     int _s_i
@@ -37,6 +55,7 @@ cdef cppclass StateC:
         this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint))
         this._sent = <TokenC*>calloc(length + (PADDING * 2), sizeof(TokenC))
         this._ents = <Entity*>calloc(length + (PADDING * 2), sizeof(Entity))
+        memset(&this._hist, 0, sizeof(this._hist))
         this.offset = 0
         cdef int i
         for i in range(length + (PADDING * 2)):
@@ -74,6 +93,9 @@ cdef cppclass StateC:
         free(this.shifted - PADDING)

     void set_context_tokens(int* ids, int n) nogil:
+        if n == 2:
+            ids[0] = this.B(0)
+            ids[1] = this.S(0)
         if n == 8:
             ids[0] = this.B(0)
             ids[1] = this.B(1)
@@ -271,7 +293,14 @@ cdef cppclass StateC:
         sig[8] = this.B_(0)[0]
         sig[9] = this.E_(0)[0]
         sig[10] = this.E_(1)[0]
-        return hash64(sig, sizeof(sig), this._s_i)
+        return hash64(sig, sizeof(sig), this._s_i) \
+             + hash64(<void*>&this._hist, sizeof(RingBufferC), 1)
+
+    void push_hist(int act) nogil:
+        ring_push(&this._hist, act+1)
+
+    int get_hist(int i) nogil:
+        return ring_get(&this._hist, i)

     void push() nogil:
         if this.B(0) != -1:
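Folding the history into the state signature matters for beam search: two partial parses with an identical stack/buffer configuration but different action histories must now hash, and therefore compare, as distinct candidates. A rough Python illustration of the idea, using blake2b in place of hash64 (the names and layout here are illustrative, not spaCy's):

import hashlib, struct

def state_hash(sig_fields, hist):
    # stand-in for hash64(sig) + hash64(&_hist): any mixing of the two
    # digests works, provided states differing only in history differ
    h_sig = hashlib.blake2b(struct.pack('11q', *sig_fields), digest_size=8)
    h_hist = hashlib.blake2b(struct.pack('8i', *hist), digest_size=8)
    return (int.from_bytes(h_sig.digest(), 'little')
            + int.from_bytes(h_hist.digest(), 'little')) % 2**64

sig = list(range(11))
a = state_hash(sig, [1, 2, 3, 0, 0, 0, 0, 0])
b = state_hash(sig, [2, 1, 3, 0, 0, 0, 0, 0])
assert a != b   # same configuration, different history: distinct states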
spacy/syntax/nn_parser.pyx
@@ -50,6 +50,7 @@ from .._ml import zero_init, PrecomputableAffine, PrecomputableMaxouts
 from .._ml import Tok2Vec, doc2feats, rebatch, fine_tune
 from .._ml import Residual, drop_layer, flatten
 from .._ml import link_vectors_to_models
+from .._ml import HistoryFeatures
 from ..compat import json_dumps

 from . import _parse_features
@@ -67,12 +68,10 @@ from ..gold cimport GoldParse
 from ..attrs cimport ID, TAG, DEP, ORTH, NORM, PREFIX, SUFFIX, TAG
 from . import _beam_utils

 USE_FINE_TUNE = True

 def get_templates(*args, **kwargs):
     return []

 USE_FTRL = True
 DEBUG = False
 def set_debug(val):
     global DEBUG
@@ -239,12 +238,17 @@ cdef class Parser:
     Base class of the DependencyParser and EntityRecognizer.
     """
     @classmethod
-    def Model(cls, nr_class, token_vector_width=128, hidden_width=200, depth=1, **cfg):
-        depth = util.env_opt('parser_hidden_depth', depth)
-        token_vector_width = util.env_opt('token_vector_width', token_vector_width)
-        hidden_width = util.env_opt('hidden_width', hidden_width)
-        parser_maxout_pieces = util.env_opt('parser_maxout_pieces', 2)
-        embed_size = util.env_opt('embed_size', 7000)
+    def Model(cls, nr_class, **cfg):
+        depth = util.env_opt('parser_hidden_depth', cfg.get('hidden_depth', 2))
+        token_vector_width = util.env_opt('token_vector_width', cfg.get('token_vector_width', 128))
+        hidden_width = util.env_opt('hidden_width', cfg.get('hidden_width', 128))
+        parser_maxout_pieces = util.env_opt('parser_maxout_pieces', cfg.get('maxout_pieces', 1))
+        embed_size = util.env_opt('embed_size', cfg.get('embed_size', 7000))
+        hist_size = util.env_opt('history_feats', cfg.get('hist_size', 4))
+        hist_width = util.env_opt('history_width', cfg.get('hist_width', 16))
+        if hist_size >= 1 and depth == 0:
+            raise ValueError("Inconsistent hyper-params: "
+                             "history_feats >= 1 but parser_hidden_depth==0")
         tok2vec = Tok2Vec(token_vector_width, embed_size,
                           pretrained_dims=cfg.get('pretrained_dims', 0))
         tok2vec = chain(tok2vec, flatten)
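Every hyper-parameter above goes through util.env_opt, so cfg values act as defaults that an environment variable can override at train time; that is how the history features can be switched on or off without code changes. A toy re-implementation of the lookup order (the SPACY_ prefix naming is an assumption for illustration, not a documented contract):

import os

def env_opt(name, default):
    # sketch: environment beats cfg/default; cast to the default's type
    value = os.environ.get('SPACY_' + name.upper())
    return type(default)(value) if value is not None else default

cfg = {}
hist_size = env_opt('history_feats', cfg.get('hist_size', 4))
# e.g. SPACY_HISTORY_FEATS=0 would disable the history features; note the
# guard above: history_feats >= 1 requires parser_hidden_depth >= 1.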
@@ -262,22 +266,40 @@ cdef class Parser:
         if depth == 0:
             upper = chain()
             upper.is_noop = True
-        else:
+        elif hist_size and depth == 1:
+            upper = chain(
+                clone(Maxout(hidden_width), depth-1),
+                HistoryFeatures(nr_class=nr_class, hist_size=hist_size,
+                                nr_dim=hist_width),
+                zero_init(Affine(nr_class, hidden_width+hist_size*hist_width,
+                                 drop_factor=0.0)))
+            upper.is_noop = False
+        elif hist_size:
+            upper = chain(
+                HistoryFeatures(nr_class=nr_class, hist_size=hist_size,
+                                nr_dim=hist_width),
+                LayerNorm(Maxout(hidden_width, hidden_width+hist_size*hist_width)),
+                clone(LayerNorm(Maxout(hidden_width, hidden_width)), depth-2),
+                zero_init(Affine(nr_class, hidden_width, drop_factor=0.0))
+            )
+            upper.is_noop = False
+        else:
             upper = chain(
                 clone(LayerNorm(Maxout(hidden_width, hidden_width)), depth-1),
                 zero_init(Affine(nr_class, hidden_width, drop_factor=0.0))
             )
             upper.is_noop = False

         # TODO: This is an unfortunate hack atm!
         # Used to set input dimensions in network.
         lower.begin_training(lower.ops.allocate((500, token_vector_width)))
         upper.begin_training(upper.ops.allocate((500, hidden_width)))
         cfg = {
             'nr_class': nr_class,
-            'depth': depth,
+            'hidden_depth': depth,
             'token_vector_width': token_vector_width,
             'hidden_width': hidden_width,
-            'maxout_pieces': parser_maxout_pieces
+            'maxout_pieces': parser_maxout_pieces,
+            'hist_size': hist_size,
+            'hist_width': hist_width
         }
         return (tok2vec, lower, upper), cfg
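A note on the branch structure above: HistoryFeatures widens its input, so whichever layer consumes its output must expect hidden_width + hist_size * hist_width columns. That is why the depth == 1 branch sizes its Affine at that width, and the deeper branch sizes its first Maxout the same way. A quick arithmetic sketch with the defaults from Model (illustrative numbers only):

hidden_width = 128   # default from cfg above
hist_size = 4        # history_feats default
hist_width = 16      # history_width default
concat_width = hidden_width + hist_size * hist_width
assert concat_width == 192
# depth == 1:  Affine(nr_class, 192) produces the scores directly
# depth >= 2:  Maxout(128, 192) maps back down, then hidden->hidden layers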
@@ -350,7 +372,7 @@ cdef class Parser:
             _cleanup(beam)
         return output

-    def pipe(self, docs, int batch_size=1000, int n_threads=2,
+    def pipe(self, docs, int batch_size=256, int n_threads=2,
             beam_width=None, beam_density=None):
         """
         Process a stream of documents.
@@ -427,11 +449,17 @@ cdef class Parser:
                 self._parse_step(next_step[i],
                     feat_weights, nr_class, nr_feat, nr_piece)
             else:
+                hists = []
                 for i in range(nr_step):
                     st = next_step[i]
                     st.set_context_tokens(&c_token_ids[i*nr_feat], nr_feat)
                     self.moves.set_valid(&c_is_valid[i*nr_class], st)
+                    hists.append([st.get_hist(j+1) for j in range(8)])
+                hists = numpy.asarray(hists)
                 vectors = state2vec(token_ids[:next_step.size()])
-                scores = vec2scores(vectors)
+                if self.cfg.get('hist_size'):
+                    scores = vec2scores((vectors, hists))
+                else:
+                    scores = vec2scores(vectors)
                 c_scores = <float*>scores.data
                 for i in range(nr_step):
@@ -440,6 +468,7 @@ cdef class Parser:
                         &c_scores[i*nr_class], &c_is_valid[i*nr_class], nr_class)
                     action = self.moves.c[guess]
                     action.do(st, action.label)
+                    st.push_hist(guess)
                 this_step, next_step = next_step, this_step
                 next_step.clear()
                 for st in this_step:
@@ -478,6 +507,11 @@ cdef class Parser:
             states.append(stcls)
         token_ids = self.get_token_ids(states)
         vectors = state2vec(token_ids)
-        scores = vec2scores(vectors)
+        if self.cfg.get('hist_size', 0):
+            hists = numpy.asarray([st.history[:self.cfg['hist_size']]
+                                   for st in states], dtype='i')
+            scores = vec2scores((vectors, hists))
+        else:
+            scores = vec2scores(vectors)
         j = 0
         c_scores = <float*>scores.data
@@ -497,8 +531,6 @@ cdef class Parser:
                           const float* feat_weights,
                           int nr_class, int nr_feat, int nr_piece) nogil:
        '''This only works with no hidden layers -- fast but inaccurate'''
-        #for i in cython.parallel.prange(next_step.size(), num_threads=4, nogil=True):
-        #    self._parse_step(next_step[i], feat_weights, nr_class, nr_feat)
        token_ids = <int*>calloc(nr_feat, sizeof(int))
        scores = <float*>calloc(nr_class * nr_piece, sizeof(float))
        is_valid = <int*>calloc(nr_class, sizeof(int))
@@ -510,6 +542,7 @@ cdef class Parser:
         guess = arg_maxout_if_valid(scores, is_valid, nr_class, nr_piece)
         action = self.moves.c[guess]
         action.do(state, action.label)
+        state.push_hist(guess)

         free(is_valid)
         free(scores)
@@ -550,6 +583,10 @@ cdef class Parser:
         if drop != 0:
             mask = vec2scores.ops.get_dropout_mask(vector.shape, drop)
             vector *= mask
-        scores, bp_scores = vec2scores.begin_update(vector, drop=drop)
+        hists = numpy.asarray([st.history for st in states], dtype='i')
+        if self.cfg.get('hist_size', 0):
+            scores, bp_scores = vec2scores.begin_update((vector, hists), drop=drop)
+        else:
+            scores, bp_scores = vec2scores.begin_update(vector, drop=drop)

         d_scores = self.get_batch_loss(states, golds, scores)
@@ -569,7 +606,8 @@ cdef class Parser:
             else:
                 backprops.append((token_ids, d_vector, bp_vector))
             self.transition_batch(states, scores)
-            todo = [st for st in todo if not st[0].is_final()]
+            todo = [(st, gold) for (st, gold) in todo
+                    if not st.is_final()]
             if losses is not None:
                 losses[self.name] += (d_scores**2).sum()
             n_steps += 1
@@ -602,7 +640,7 @@ cdef class Parser:
         states_d_scores, backprops = _beam_utils.update_beam(self.moves, self.nr_feature, 500,
                                                              states, golds,
                                                              state2vec, vec2scores,
-                                                             width, density,
+                                                             width, density, self.cfg.get('hist_size', 0),
                                                              drop=drop, losses=losses)
         backprop_lower = []
         cdef float batch_size = len(docs)
@@ -648,6 +686,7 @@ cdef class Parser:
             while state.B(0) < start and not state.is_final():
                 action = self.moves.c[oracle_actions.pop(0)]
                 action.do(state.c, action.label)
+                state.c.push_hist(action.clas)
                 n_moves += 1
             has_gold = self.moves.has_gold(gold, start=start,
                                            end=start+max_length)
@@ -711,6 +750,7 @@ cdef class Parser:
             action = self.moves.c[guess]
             action.do(state.c, action.label)
             c_scores += scores.shape[1]
+            state.c.push_hist(guess)

     def get_batch_loss(self, states, golds, float[:, ::1] scores):
         cdef StateClass state
@@ -934,6 +974,7 @@ cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves)
     moves = <const Transition*>_moves
     dest.clone(src)
     moves[clas].do(dest.c, moves[clas].label)
+    dest.c.push_hist(clas)


 cdef int _check_final_state(void* _state, void* extra_args) except -1:
spacy/syntax/stateclass.pyx
@@ -4,6 +4,7 @@ from __future__ import unicode_literals

 from libc.string cimport memcpy, memset
 from libc.stdint cimport uint32_t, uint64_t
+import numpy

 from ..vocab cimport EMPTY_LEXEME
 from ..structs cimport Entity
@@ -38,6 +39,13 @@ cdef class StateClass:
     def token_vector_lenth(self):
         return self.doc.tensor.shape[1]

+    @property
+    def history(self):
+        hist = numpy.ndarray((8,), dtype='i')
+        for i in range(8):
+            hist[i] = self.c.get_hist(i+1)
+        return hist
+
     def is_final(self):
         return self.c.is_final()
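The property gives Python-level code a fixed-width view of the Cython ring buffer: slot i of the result is get_hist(i+1), i.e. the (i+1)-th most recent action, so the row reads most-recent-first and is zero-padded once the lookback runs out. A small sketch of the same loop against any get_hist-like callable (a hypothetical stand-in, not spaCy API):

import numpy

def history_row(get_hist):
    # mirrors StateClass.history: slots 1..8 of the ring buffer
    hist = numpy.ndarray((8,), dtype='i')
    for i in range(8):
        hist[i] = get_hist(i + 1)
    return hist

# With the RingBuffer sketch from earlier, after pushing actions 5, 7, 9:
# history_row(ring.get) -> array([10, 8, 0, 0, 0, 0, 0, 0], dtype=int32)
# (act+1 encoding, most recent first, zero-padded)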
@@ -54,27 +62,3 @@ cdef class StateClass:
         n0 = words[self.B(0)]
         n1 = words[self.B(1)]
         return ' '.join((third, second, top, '|', n0, n1))
-
-    @classmethod
-    def nr_context_tokens(cls):
-        return 13
-
-    def set_context_tokens(self, int[::1] output):
-        output[0] = self.B(0)
-        output[1] = self.B(1)
-        output[2] = self.S(0)
-        output[3] = self.S(1)
-        output[4] = self.S(2)
-        output[5] = self.L(self.S(0), 1)
-        output[6] = self.L(self.S(0), 2)
-        output[6] = self.R(self.S(0), 1)
-        output[7] = self.L(self.B(0), 1)
-        output[8] = self.R(self.S(0), 2)
-        output[9] = self.L(self.S(1), 1)
-        output[10] = self.L(self.S(1), 2)
-        output[11] = self.R(self.S(1), 1)
-        output[12] = self.R(self.S(1), 2)
-
-        for i in range(13):
-            if output[i] != -1:
-                output[i] += self.c.offset
website docs (training hyper-parameter table)
@@ -314,6 +314,16 @@ p
     +cell Size of the parser's and NER's hidden layers.
     +cell #[code 128]

+  +row
+    +cell #[code history_feats]
+    +cell Number of previous action ID features for parser and NER.
+    +cell #[code 128]
+
+  +row
+    +cell #[code history_width]
+    +cell Number of embedding dimensions for each action ID.
+    +cell #[code 128]
+
   +row
     +cell #[code learn_rate]
     +cell Learning rate.