mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
Draft up Parser model
This commit is contained in:
parent
ccaf26206b
commit
7d1df50aec
96
spacy/_ml.py
Normal file
96
spacy/_ml.py
Normal file
|
@ -0,0 +1,96 @@
|
||||||
|
from thinc.api import layerize, chain, clone, concatenate, add
|
||||||
|
from thinc.neural import Model, Maxout, Softmax
|
||||||
|
from thinc.neural._classes.static_vectors import StaticVectors
|
||||||
|
from thinc.neural._classes.hash_embed import HashEmbed
|
||||||
|
from thinc.neural._classes.convolution import ExtractWindow
|
||||||
|
|
||||||
|
from .attrs import ID, PREFIX, SUFFIX, SHAPE, TAG, DEP
|
||||||
|
|
||||||
|
|
||||||
|
@layerize
|
||||||
|
def get_contexts(states, drop=0.):
|
||||||
|
ops = Model.ops
|
||||||
|
context = ops.allocate((len(states), 7), dtype='uint64')
|
||||||
|
for i, state in enumerate(states):
|
||||||
|
context[i, 0] = state.B(0)
|
||||||
|
context[i, 1] = state.S(0)
|
||||||
|
context[i, 2] = state.S(1)
|
||||||
|
context[i, 3] = state.L(state.S(0), 1)
|
||||||
|
context[i, 4] = state.L(state.S(0), 2)
|
||||||
|
context[i, 5] = state.R(state.S(0), 1)
|
||||||
|
context[i, 6] = state.R(state.S(0), 2)
|
||||||
|
return (context, states), None
|
||||||
|
|
||||||
|
def get_col(idx):
|
||||||
|
def forward(X, drop=0.):
|
||||||
|
return Model.ops.xp.ascontiguousarray(X[:, idx]), None
|
||||||
|
return layerize(forward)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_features(attrs):
|
||||||
|
ops = Model.ops
|
||||||
|
def forward(contexts_states, drop=0.):
|
||||||
|
contexts, states = contexts_states
|
||||||
|
output = ops.allocate((len(states), contexts.shape[1], len(attrs)),
|
||||||
|
dtype='uint64')
|
||||||
|
for i, state in enumerate(states):
|
||||||
|
for j, tok_i in enumerate(contexts[i]):
|
||||||
|
token = state.get_token(tok_i)
|
||||||
|
for k, attr in enumerate(attrs):
|
||||||
|
output[i, j, k] = getattr(token, attr)
|
||||||
|
return output, None
|
||||||
|
return layerize(forward)
|
||||||
|
|
||||||
|
|
||||||
|
def build_tok2vec(lang, width, depth, embed_size):
|
||||||
|
cols = [ID, PREFIX, SUFFIX, SHAPE]
|
||||||
|
|
||||||
|
with Model.define_operators({'>>': chain, '|': concatenate, '**': clone}):
|
||||||
|
static = get_col(cols.index(ID)) >> StaticVectors('en', width)
|
||||||
|
prefix = get_col(cols.index(PREFIX)) >> HashEmbed(width, embed_size)
|
||||||
|
suffix = get_col(cols.index(SUFFIX)) >> HashEmbed(width, embed_size)
|
||||||
|
shape = get_col(cols.index(SHAPE)) >> HashEmbed(width, embed_size)
|
||||||
|
tok2vec = (
|
||||||
|
extract_features(cols)
|
||||||
|
>> (static | prefix | suffix | shape)
|
||||||
|
>> (ExtractWindow(nW=1) >> Maxout(width)) ** depth
|
||||||
|
)
|
||||||
|
return tok2vec
|
||||||
|
|
||||||
|
|
||||||
|
def build_parse2vec(width, embed_size):
|
||||||
|
cols = [TAG, DEP]
|
||||||
|
with Model.define_operators({'>>': chain, '|': concatenate}):
|
||||||
|
tag_vector = get_col(cols.index(TAG)) >> HashEmbed(width, 1000)
|
||||||
|
dep_vector = get_col(cols.index(DEP)) >> HashEmbed(width, 1000)
|
||||||
|
model = (
|
||||||
|
extract_features([TAG, DEP])
|
||||||
|
>> (tag_vector | dep_vector)
|
||||||
|
)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def build_model(state2context, tok2vec, parse2vec, width, depth, nr_class):
|
||||||
|
with Model.define_operators({'>>': chain, '**': clone, '|': concatenate}):
|
||||||
|
model = (
|
||||||
|
state2context
|
||||||
|
>> (tok2vec | parse2vec)
|
||||||
|
>> Maxout(width) ** depth
|
||||||
|
>> Softmax(nr_class)
|
||||||
|
)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_model(width=100, depth=2, nr_class=10):
|
||||||
|
model = build_model(
|
||||||
|
get_contexts,
|
||||||
|
build_tok2vec('en', width=100, depth=2, embed_size=1000),
|
||||||
|
build_parse2vec(width=100, embed_size=1000),
|
||||||
|
width,
|
||||||
|
depth,
|
||||||
|
nr_class)
|
||||||
|
assert model is not None
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_build_model()
|
Loading…
Reference in New Issue
Block a user