mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
Draft of NN parser, to be tested
This commit is contained in:
parent
7d1df50aec
commit
ef4fa594aa
121
spacy/_ml.py
121
spacy/_ml.py
|
@ -4,22 +4,6 @@ from thinc.neural._classes.static_vectors import StaticVectors
|
||||||
from thinc.neural._classes.hash_embed import HashEmbed
|
from thinc.neural._classes.hash_embed import HashEmbed
|
||||||
from thinc.neural._classes.convolution import ExtractWindow
|
from thinc.neural._classes.convolution import ExtractWindow
|
||||||
|
|
||||||
from .attrs import ID, PREFIX, SUFFIX, SHAPE, TAG, DEP
|
|
||||||
|
|
||||||
|
|
||||||
@layerize
|
|
||||||
def get_contexts(states, drop=0.):
|
|
||||||
ops = Model.ops
|
|
||||||
context = ops.allocate((len(states), 7), dtype='uint64')
|
|
||||||
for i, state in enumerate(states):
|
|
||||||
context[i, 0] = state.B(0)
|
|
||||||
context[i, 1] = state.S(0)
|
|
||||||
context[i, 2] = state.S(1)
|
|
||||||
context[i, 3] = state.L(state.S(0), 1)
|
|
||||||
context[i, 4] = state.L(state.S(0), 2)
|
|
||||||
context[i, 5] = state.R(state.S(0), 1)
|
|
||||||
context[i, 6] = state.R(state.S(0), 2)
|
|
||||||
return (context, states), None
|
|
||||||
|
|
||||||
def get_col(idx):
|
def get_col(idx):
|
||||||
def forward(X, drop=0.):
|
def forward(X, drop=0.):
|
||||||
|
@ -27,69 +11,68 @@ def get_col(idx):
|
||||||
return layerize(forward)
|
return layerize(forward)
|
||||||
|
|
||||||
|
|
||||||
def extract_features(attrs):
|
def build_model(state2vec, width, depth, nr_class):
|
||||||
ops = Model.ops
|
with Model.define_operators({'>>': chain, '**': clone}):
|
||||||
def forward(contexts_states, drop=0.):
|
model = state2vec >> Maxout(width) ** depth >> Softmax(nr_class)
|
||||||
contexts, states = contexts_states
|
return model
|
||||||
output = ops.allocate((len(states), contexts.shape[1], len(attrs)),
|
|
||||||
dtype='uint64')
|
|
||||||
|
def build_parser_state2vec(tag_vectors, dep_vectors, **cfg):
|
||||||
|
embed_tags = _reshape(chain(get_col(0), tag_vectors))
|
||||||
|
embed_deps = _reshape(chain(get_col(1), dep_vectors))
|
||||||
|
attr_names = ops.asarray([TAG, DEP], dtype='i')
|
||||||
|
def forward(states, drop=0.):
|
||||||
|
n_tokens = state.nr_context_tokens(nF, nB, nS, nL, nR)
|
||||||
for i, state in enumerate(states):
|
for i, state in enumerate(states):
|
||||||
for j, tok_i in enumerate(contexts[i]):
|
state.set_context_tokens(tokens[i], nF, nB, nS, nL, nR)
|
||||||
token = state.get_token(tok_i)
|
state.set_attributes(features[i], tokens[i], attr_names)
|
||||||
for k, attr in enumerate(attrs):
|
state.set_token_vectors(token_vectors[i], tokens[i])
|
||||||
output[i, j, k] = getattr(token, attr)
|
|
||||||
return output, None
|
|
||||||
return layerize(forward)
|
|
||||||
|
|
||||||
|
tagvecs, bp_tag_vecs = embed_deps.begin_update(attr_vals, drop=drop)
|
||||||
|
depvecs, bp_dep_vecs = embed_tags.begin_update(attr_vals, drop=drop)
|
||||||
|
|
||||||
def build_tok2vec(lang, width, depth, embed_size):
|
vector = ops.concatenate((tagvecs, depvecs, tokvecs))
|
||||||
cols = [ID, PREFIX, SUFFIX, SHAPE]
|
|
||||||
|
|
||||||
with Model.define_operators({'>>': chain, '|': concatenate, '**': clone}):
|
shapes = (tagvecs.shape, depvecs.shape, tokvecs.shape)
|
||||||
static = get_col(cols.index(ID)) >> StaticVectors('en', width)
|
def backward(d_vector, sgd=None):
|
||||||
prefix = get_col(cols.index(PREFIX)) >> HashEmbed(width, embed_size)
|
d_depvecs, d_tagvecs, d_tokvecs = ops.backprop_concatenate(d_vector, shapes)
|
||||||
suffix = get_col(cols.index(SUFFIX)) >> HashEmbed(width, embed_size)
|
bp_tagvecs(d_tagvecs)
|
||||||
shape = get_col(cols.index(SHAPE)) >> HashEmbed(width, embed_size)
|
bp_depvecs(d_depvecs)
|
||||||
tok2vec = (
|
return (d_tokvecs, tokens)
|
||||||
extract_features(cols)
|
return vector, backward
|
||||||
>> (static | prefix | suffix | shape)
|
model = layerize(forward)
|
||||||
>> (ExtractWindow(nW=1) >> Maxout(width)) ** depth
|
model._layers = [embed_tags, embed_deps]
|
||||||
)
|
|
||||||
return tok2vec
|
|
||||||
|
|
||||||
|
|
||||||
def build_parse2vec(width, embed_size):
|
|
||||||
cols = [TAG, DEP]
|
|
||||||
with Model.define_operators({'>>': chain, '|': concatenate}):
|
|
||||||
tag_vector = get_col(cols.index(TAG)) >> HashEmbed(width, 1000)
|
|
||||||
dep_vector = get_col(cols.index(DEP)) >> HashEmbed(width, 1000)
|
|
||||||
model = (
|
|
||||||
extract_features([TAG, DEP])
|
|
||||||
>> (tag_vector | dep_vector)
|
|
||||||
)
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def build_model(state2context, tok2vec, parse2vec, width, depth, nr_class):
|
def _reshape(layer):
|
||||||
with Model.define_operators({'>>': chain, '**': clone, '|': concatenate}):
|
def forward(X, drop=0.):
|
||||||
model = (
|
Xh = X.reshape((X.shape[0] * X.shape[1], X.shape[2]))
|
||||||
state2context
|
yh, bp_yh = layer.begin_update(Xh, drop=drop)
|
||||||
>> (tok2vec | parse2vec)
|
n = X.shape[0]
|
||||||
>> Maxout(width) ** depth
|
def backward(d_y, sgd=None):
|
||||||
>> Softmax(nr_class)
|
d_yh = d_y.reshape((n, d_y.size / n))
|
||||||
)
|
d_Xh = bp_yh(d_yh, sgd)
|
||||||
|
return d_Xh.reshape(old_shape)
|
||||||
|
return yh.reshape((n, yh.shape / n)), backward
|
||||||
|
model = layerize(forward)
|
||||||
|
model._layers.append(layer)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def test_build_model(width=100, depth=2, nr_class=10):
|
|
||||||
model = build_model(
|
#def build_tok2vec(lang, width, depth, embed_size, cols):
|
||||||
get_contexts,
|
# with Model.define_operators({'>>': chain, '|': concatenate, '**': clone}):
|
||||||
build_tok2vec('en', width=100, depth=2, embed_size=1000),
|
# static = get_col(cols.index(ID)) >> StaticVectors(lang, width)
|
||||||
build_parse2vec(width=100, embed_size=1000),
|
# prefix = get_col(cols.index(PREFIX)) >> HashEmbed(width, embed_size)
|
||||||
width,
|
# suffix = get_col(cols.index(SUFFIX)) >> HashEmbed(width, embed_size)
|
||||||
depth,
|
# shape = get_col(cols.index(SHAPE)) >> HashEmbed(width, embed_size)
|
||||||
nr_class)
|
# tok2vec = (
|
||||||
assert model is not None
|
# (static | prefix | suffix | shape)
|
||||||
|
# >> Maxout(width, width*4)
|
||||||
|
# >> (ExtractWindow(nW=1) >> Maxout(width, width*3)) ** depth
|
||||||
|
# )
|
||||||
|
# return tok2vec
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -49,67 +49,8 @@ def set_debug(val):
|
||||||
DEBUG = val
|
DEBUG = val
|
||||||
|
|
||||||
|
|
||||||
@layerize
|
def get_templates(*args, **kwargs):
|
||||||
def get_context_tokens(states, drop=0.):
|
return []
|
||||||
for state in states:
|
|
||||||
context[i, 0] = state.B(0)
|
|
||||||
context[i, 1] = state.S(0)
|
|
||||||
context[i, 2] = state.S(1)
|
|
||||||
context[i, 3] = state.L(state.S(0), 1)
|
|
||||||
context[i, 4] = state.L(state.S(0), 2)
|
|
||||||
context[i, 5] = state.R(state.S(0), 1)
|
|
||||||
context[i, 6] = state.R(state.S(0), 2)
|
|
||||||
return (context, states), None
|
|
||||||
|
|
||||||
|
|
||||||
def extract_features(attrs):
|
|
||||||
def forward(contexts_states, drop=0.):
|
|
||||||
contexts, states = contexts_states
|
|
||||||
for i, state in enumerate(states):
|
|
||||||
for j, tok_i in enumerate(contexts[i]):
|
|
||||||
token = state.get_token(tok_i)
|
|
||||||
for k, attr in enumerate(attrs):
|
|
||||||
output[i, j, k] = getattr(token, attr)
|
|
||||||
return output, None
|
|
||||||
return layerize(forward)
|
|
||||||
|
|
||||||
|
|
||||||
def build_tok2vec(lang, width, depth, embed_size):
|
|
||||||
cols = [LEX_ID, PREFIX, SUFFIX, SHAPE]
|
|
||||||
static = StaticVectors('en', width, column=cols.index(LEX_ID))
|
|
||||||
prefix = HashEmbed(width, embed_size, column=cols.index(PREFIX))
|
|
||||||
suffix = HashEmbed(width, embed_size, column=cols.index(SUFFIX))
|
|
||||||
shape = HashEmbed(width, embed_size, column=cols.index(SHAPE))
|
|
||||||
with Model.overload_operaters('>>': chain, '|': concatenate, '+': add):
|
|
||||||
tok2vec = (
|
|
||||||
extract_features(cols)
|
|
||||||
>> (static | prefix | suffix | shape)
|
|
||||||
>> (ExtractWindow(nW=1) >> Maxout(width)) ** depth
|
|
||||||
)
|
|
||||||
return tok2vec
|
|
||||||
|
|
||||||
|
|
||||||
def build_parse2vec(width, embed_size):
|
|
||||||
cols = [TAG, DEP]
|
|
||||||
tag_vector = HashEmbed(width, 1000, column=cols.index(TAG))
|
|
||||||
dep_vector = HashEmbed(width, 1000, column=cols.index(DEP))
|
|
||||||
with Model.overload_operaters('>>': chain):
|
|
||||||
model = (
|
|
||||||
extract_features([TAG, DEP])
|
|
||||||
>> (tag_vector | dep_vector)
|
|
||||||
)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def build_model(get_contexts, tok2vec, parse2vec, width, depth, nr_class):
|
|
||||||
with Model.overload_operaters('>>': chain):
|
|
||||||
model = (
|
|
||||||
get_contexts
|
|
||||||
>> (tok2vec | parse2vec)
|
|
||||||
>> Maxout(width) ** depth
|
|
||||||
>> Softmax(nr_class)
|
|
||||||
)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
cdef class Parser:
|
cdef class Parser:
|
||||||
|
@ -180,17 +121,21 @@ cdef class Parser:
|
||||||
|
|
||||||
def parse_batch(self, docs):
|
def parse_batch(self, docs):
|
||||||
states = self._init_states(docs)
|
states = self._init_states(docs)
|
||||||
todo = list(states)
|
|
||||||
nr_class = self.moves.n_moves
|
nr_class = self.moves.n_moves
|
||||||
|
cdef StateClass state
|
||||||
|
cdef int guess
|
||||||
|
is_valid = self.model.ops.allocate((len(docs), nr_class), dtype='i')
|
||||||
|
todo = list(states)
|
||||||
while todo:
|
while todo:
|
||||||
scores = self.model.predict(todo)
|
scores = self.model.predict(todo)
|
||||||
self._validate_batch(is_valid, scores, states)
|
self._validate_batch(is_valid, states)
|
||||||
|
scores *= is_valid
|
||||||
for state, guess in zip(todo, scores.argmax(axis=1)):
|
for state, guess in zip(todo, scores.argmax(axis=1)):
|
||||||
action = self.moves.c[guess]
|
action = self.moves.c[guess]
|
||||||
action.do(state, action.label)
|
action.do(state.c, action.label)
|
||||||
todo = [state for state in todo if not state.is_final()]
|
todo = [state for state in todo if not state.is_final()]
|
||||||
for state, doc in zip(states, docs):
|
for state, doc in zip(states, docs):
|
||||||
self.moves.finalize_state(state, doc)
|
self.moves.finalize_state(state.c)
|
||||||
|
|
||||||
def pipe(self, stream, int batch_size=1000, int n_threads=2):
|
def pipe(self, stream, int batch_size=1000, int n_threads=2):
|
||||||
"""
|
"""
|
||||||
|
@ -212,8 +157,6 @@ cdef class Parser:
|
||||||
cdef int status
|
cdef int status
|
||||||
queue = []
|
queue = []
|
||||||
for doc in stream:
|
for doc in stream:
|
||||||
doc_ptr[len(queue)] = doc.c
|
|
||||||
lengths[len(queue)] = doc.length
|
|
||||||
queue.append(doc)
|
queue.append(doc)
|
||||||
if len(queue) == batch_size:
|
if len(queue) == batch_size:
|
||||||
self.parse_batch(queue)
|
self.parse_batch(queue)
|
||||||
|
@ -231,48 +174,76 @@ cdef class Parser:
|
||||||
if isinstance(docs, Doc) and isinstance(golds, GoldParse):
|
if isinstance(docs, Doc) and isinstance(golds, GoldParse):
|
||||||
return self.update([docs], [golds], drop=drop)
|
return self.update([docs], [golds], drop=drop)
|
||||||
states = self._init_states(docs)
|
states = self._init_states(docs)
|
||||||
|
d_tokens = [self.model.ops.allocate(d.tensor.shape) for d in docs]
|
||||||
nr_class = self.moves.n_moves
|
nr_class = self.moves.n_moves
|
||||||
|
costs = self.model.ops.allocate((len(docs), nr_class), dtype='f')
|
||||||
|
is_valid = self.model.ops.allocate((len(docs), nr_class), dtype='i')
|
||||||
|
|
||||||
|
todo = zip(states, golds, d_tokens)
|
||||||
while states:
|
while states:
|
||||||
|
states, golds, d_tokens = zip(*todo)
|
||||||
scores, finish_update = self.model.begin_update(states, drop=drop)
|
scores, finish_update = self.model.begin_update(states, drop=drop)
|
||||||
self._validate_batch(is_valid, scores, states)
|
|
||||||
for i, state in enumerate(states):
|
self._cost_batch(is_valid, costs, states, golds)
|
||||||
self.moves.set_costs(costs[i], is_valid, state, golds[i])
|
scores *= is_valid
|
||||||
|
self._set_gradient(gradients, scores, costs)
|
||||||
|
|
||||||
|
token_ids, batch_token_grads = finish_update(gradients, sgd=sgd)
|
||||||
|
for i, tok_i in enumerate(token_ids):
|
||||||
|
d_tokens[tok_i] += batch_token_grads[i]
|
||||||
|
|
||||||
self._transition_batch(states, scores)
|
self._transition_batch(states, scores)
|
||||||
self._set_gradient(gradients, scores, costs)
|
|
||||||
finish_update(gradients, sgd=sgd)
|
# Get unfinished states (and their matching gold and token gradients)
|
||||||
|
todo = zip(states, golds, d_tokens)
|
||||||
|
todo = filter(todo, lambda sp: sp[0].is_final)
|
||||||
|
|
||||||
|
gradients = gradients[:len(todo)]
|
||||||
|
costs = costs[:len(todo)]
|
||||||
|
is_valid = is_valid[:len(todo)]
|
||||||
|
|
||||||
gradients.fill(0)
|
gradients.fill(0)
|
||||||
|
costs.fill(0)
|
||||||
states = [state for state in states if not state.is_final()]
|
is_valid.fill(1)
|
||||||
gradients = gradients[:len(states)]
|
|
||||||
costs = costs[:len(states)]
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def _validate_batch(self, is_valid, scores, states):
|
|
||||||
for i, state in enumerate(states):
|
|
||||||
self.moves.set_valid(is_valid, state)
|
|
||||||
for j in range(self.moves.n_moves):
|
|
||||||
if not is_valid[j]:
|
|
||||||
scores[i, j] = 0
|
|
||||||
|
|
||||||
def _transition_batch(self, states, scores):
|
|
||||||
for state, guess in zip(states, scores.argmax(axis=1)):
|
|
||||||
action = self.moves.c[guess]
|
|
||||||
action.do(state, action.label)
|
|
||||||
|
|
||||||
def _init_states(self, docs):
|
def _init_states(self, docs):
|
||||||
states = []
|
states = []
|
||||||
cdef Doc doc
|
cdef Doc doc
|
||||||
|
cdef StateClass state
|
||||||
for i, doc in enumerate(docs):
|
for i, doc in enumerate(docs):
|
||||||
state = StateClass.init(doc)
|
state = StateClass(doc)
|
||||||
self.moves.initialize_state(state)
|
self.moves.initialize_state(state.c)
|
||||||
|
states.append(state)
|
||||||
return states
|
return states
|
||||||
|
|
||||||
|
def _validate_batch(self, int[:, ::1] is_valid, states):
|
||||||
|
cdef StateClass state
|
||||||
|
cdef int i
|
||||||
|
for i, state in enumerate(states):
|
||||||
|
self.moves.set_valid(&is_valid[i, 0], state.c)
|
||||||
|
|
||||||
|
def _cost_batch(self, weight_t[:, ::1] costs, int[:, ::1] is_valid,
|
||||||
|
states, golds):
|
||||||
|
cdef int i
|
||||||
|
cdef StateClass state
|
||||||
|
cdef GoldParse gold
|
||||||
|
for i, (state, gold) in enumerate(zip(states, golds)):
|
||||||
|
self.moves.set_costs(&is_valid[i, 0], &costs[i, 0], state, gold)
|
||||||
|
|
||||||
|
def _transition_batch(self, states, scores):
|
||||||
|
cdef StateClass state
|
||||||
|
cdef int guess
|
||||||
|
for state, guess in zip(states, scores.argmax(axis=1)):
|
||||||
|
action = self.moves.c[guess]
|
||||||
|
action.do(state.c, action.label)
|
||||||
|
|
||||||
def _set_gradient(self, gradients, scores, costs):
|
def _set_gradient(self, gradients, scores, costs):
|
||||||
"""Do multi-label log loss"""
|
"""Do multi-label log loss"""
|
||||||
cdef double Z, gZ, max_, g_max
|
cdef double Z, gZ, max_, g_max
|
||||||
|
g_scores = scores * (costs <= 0)
|
||||||
maxes = scores.max(axis=1)
|
maxes = scores.max(axis=1)
|
||||||
g_maxes = (scores * costs <= 0).max(axis=1)
|
g_maxes = g_scores.max(axis=1)
|
||||||
exps = (scores-maxes).exp()
|
exps = (scores-maxes).exp()
|
||||||
g_exps = (g_scores-g_maxes).exp()
|
g_exps = (g_scores-g_maxes).exp()
|
||||||
|
|
||||||
|
@ -398,11 +369,11 @@ cdef class StepwiseState:
|
||||||
|
|
||||||
def predict(self):
|
def predict(self):
|
||||||
self.eg.reset()
|
self.eg.reset()
|
||||||
self.eg.c.nr_feat = self.parser.model.set_featuresC(self.eg.c.atoms, self.eg.c.features,
|
#self.eg.c.nr_feat = self.parser.model.set_featuresC(self.eg.c.atoms, self.eg.c.features,
|
||||||
self.stcls.c)
|
# self.stcls.c)
|
||||||
self.parser.moves.set_valid(self.eg.c.is_valid, self.stcls.c)
|
self.parser.moves.set_valid(self.eg.c.is_valid, self.stcls.c)
|
||||||
self.parser.model.set_scoresC(self.eg.c.scores,
|
#self.parser.model.set_scoresC(self.eg.c.scores,
|
||||||
self.eg.c.features, self.eg.c.nr_feat)
|
# self.eg.c.features, self.eg.c.nr_feat)
|
||||||
|
|
||||||
cdef Transition action = self.parser.moves.c[self.eg.guess]
|
cdef Transition action = self.parser.moves.c[self.eg.guess]
|
||||||
return self.parser.moves.move_name(action.move, action.label)
|
return self.parser.moves.move_name(action.move, action.label)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user