mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
Data running through, likely errors in model
This commit is contained in:
parent
fa7c1990b6
commit
7e04260d38
143
spacy/_ml.py
143
spacy/_ml.py
|
@ -1,4 +1,4 @@
|
||||||
from thinc.api import layerize, chain, clone, concatenate
|
from thinc.api import layerize, chain, clone, concatenate, with_flatten
|
||||||
from thinc.neural import Model, Maxout, Softmax
|
from thinc.neural import Model, Maxout, Softmax
|
||||||
from thinc.neural._classes.hash_embed import HashEmbed
|
from thinc.neural._classes.hash_embed import HashEmbed
|
||||||
|
|
||||||
|
@ -10,88 +10,137 @@ from .attrs import ID, PREFIX, SUFFIX, SHAPE, TAG, DEP
|
||||||
|
|
||||||
def get_col(idx):
|
def get_col(idx):
|
||||||
def forward(X, drop=0.):
|
def forward(X, drop=0.):
|
||||||
return Model.ops.xp.ascontiguousarray(X[:, idx]), None
|
output = Model.ops.xp.ascontiguousarray(X[:, idx])
|
||||||
|
return output, None
|
||||||
return layerize(forward)
|
return layerize(forward)
|
||||||
|
|
||||||
|
|
||||||
def build_model(state2vec, width, depth, nr_class):
|
def build_model(state2vec, width, depth, nr_class):
|
||||||
with Model.define_operators({'>>': chain, '**': clone}):
|
with Model.define_operators({'>>': chain, '**': clone}):
|
||||||
model = state2vec >> Maxout(width) ** depth >> Softmax(nr_class)
|
model = (
|
||||||
|
state2vec
|
||||||
|
>> Maxout(width, 1344)
|
||||||
|
>> Maxout(width, width)
|
||||||
|
>> Softmax(nr_class, width)
|
||||||
|
)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def build_parser_state2vec(width, nr_vector=1000, nF=1, nB=0, nS=1, nL=2, nR=2):
|
def build_parser_state2vec(width, nr_vector=1000, nF=1, nB=0, nS=1, nL=2, nR=2):
|
||||||
embed_tags = _reshape(chain(get_col(0), HashEmbed(width, nr_vector)))
|
embed_tags = _reshape(chain(get_col(0), HashEmbed(16, nr_vector)))
|
||||||
embed_deps = _reshape(chain(get_col(1), HashEmbed(width, nr_vector)))
|
embed_deps = _reshape(chain(get_col(1), HashEmbed(16, nr_vector)))
|
||||||
ops = embed_tags.ops
|
ops = embed_tags.ops
|
||||||
attr_names = ops.asarray([TAG, DEP], dtype='i')
|
def forward(tokens_attrs_vectors, drop=0.):
|
||||||
extract = build_feature_extractor(attr_names, nF, nB, nS, nL, nR)
|
tokens, attr_vals, tokvecs = tokens_attrs_vectors
|
||||||
def forward(states, drop=0.):
|
|
||||||
tokens, attr_vals, tokvecs = extract(states)
|
|
||||||
tagvecs, bp_tagvecs = embed_deps.begin_update(attr_vals, drop=drop)
|
tagvecs, bp_tagvecs = embed_deps.begin_update(attr_vals, drop=drop)
|
||||||
depvecs, bp_depvecs = embed_tags.begin_update(attr_vals, drop=drop)
|
depvecs, bp_depvecs = embed_tags.begin_update(attr_vals, drop=drop)
|
||||||
|
orig_tokvecs_shape = tokvecs.shape
|
||||||
tokvecs = tokvecs.reshape((tokvecs.shape[0], tokvecs.shape[1] *
|
tokvecs = tokvecs.reshape((tokvecs.shape[0], tokvecs.shape[1] *
|
||||||
tokvecs.shape[2]))
|
tokvecs.shape[2]))
|
||||||
|
|
||||||
vector = ops.concatenate((tagvecs, depvecs, tokvecs))
|
|
||||||
|
|
||||||
shapes = (tagvecs.shape, depvecs.shape, tokvecs.shape)
|
shapes = (tagvecs.shape, depvecs.shape, tokvecs.shape)
|
||||||
|
assert tagvecs.shape[0] == depvecs.shape[0] == tokvecs.shape[0], shapes
|
||||||
|
vector = ops.xp.hstack((tagvecs, depvecs, tokvecs))
|
||||||
|
|
||||||
def backward(d_vector, sgd=None):
|
def backward(d_vector, sgd=None):
|
||||||
d_depvecs, d_tagvecs, d_tokvecs = ops.backprop_concatenate(d_vector, shapes)
|
d_tagvecs, d_depvecs, d_tokvecs = backprop_concatenate(d_vector, shapes)
|
||||||
|
assert d_tagvecs.shape == shapes[0], (d_tagvecs.shape, shapes)
|
||||||
|
assert d_depvecs.shape == shapes[1], (d_depvecs.shape, shapes)
|
||||||
|
assert d_tokvecs.shape == shapes[2], (d_tokvecs.shape, shapes)
|
||||||
bp_tagvecs(d_tagvecs)
|
bp_tagvecs(d_tagvecs)
|
||||||
bp_depvecs(d_depvecs)
|
bp_depvecs(d_depvecs)
|
||||||
d_tokvecs = d_tokvecs.reshape((len(states), tokens.shape[1], tokvecs.shape[2]))
|
d_tokvecs = d_tokvecs.reshape(orig_tokvecs_shape)
|
||||||
return (d_tokvecs, tokens)
|
|
||||||
|
return (tokens, d_tokvecs)
|
||||||
return vector, backward
|
return vector, backward
|
||||||
model = layerize(forward)
|
model = layerize(forward)
|
||||||
model._layers = [embed_tags, embed_deps]
|
model._layers = [embed_tags, embed_deps]
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def build_feature_extractor(attr_names, nF, nB, nS, nL, nR):
|
def backprop_concatenate(gradient, shapes):
|
||||||
def forward(states, drop=0.):
|
grads = []
|
||||||
ops = model.ops
|
start = 0
|
||||||
n_tokens = states[0].nr_context_tokens(nF, nB, nS, nL, nR)
|
for shape in shapes:
|
||||||
vector_length = states[0].token_vector_length
|
end = start + shape[1]
|
||||||
tokens = ops.allocate((len(states), n_tokens), dtype='i')
|
grads.append(gradient[:, start : end])
|
||||||
features = ops.allocate((len(states), n_tokens, attr_names.shape[0]), dtype='i')
|
start = end
|
||||||
tokvecs = ops.allocate((len(states), n_tokens, vector_length), dtype='f')
|
return grads
|
||||||
for i, state in enumerate(states):
|
|
||||||
state.set_context_tokens(tokens[i], nF, nB, nS, nL, nR)
|
|
||||||
state.set_attributes(features[i], tokens[i], attr_names)
|
|
||||||
state.set_token_vectors(tokvecs[i], tokens[i])
|
|
||||||
def backward(d_features, sgd=None):
|
|
||||||
return d_features
|
|
||||||
return (tokens, features, tokvecs), backward
|
|
||||||
model = layerize(forward)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def _reshape(layer):
|
def _reshape(layer):
|
||||||
def forward(X, drop=0.):
|
'''Transforms input with shape
|
||||||
Xh = X.reshape((X.shape[0] * X.shape[1], X.shape[2]))
|
(states, tokens, features)
|
||||||
yh, bp_yh = layer.begin_update(Xh, drop=drop)
|
into input with shape:
|
||||||
n = X.shape[0]
|
(states * tokens, features)
|
||||||
old_shape = X.shape
|
So that it can be used with a token-wise feature extraction layer, e.g.
|
||||||
def backward(d_y, sgd=None):
|
an embedding layer. The embedding layer outputs:
|
||||||
d_yh = d_y.reshape((n, d_y.size / n))
|
(states * tokens, ndim)
|
||||||
d_Xh = bp_yh(d_yh, sgd)
|
But we want to concatenate the vectors for the tokens, so we produce:
|
||||||
return d_Xh.reshape(old_shape)
|
(states, tokens * ndim)
|
||||||
return yh.reshape((n, yh.shape / n)), backward
|
We then need to reverse the transforms to do the backward pass. Recall
|
||||||
|
the simple rule here: each layer is a map:
|
||||||
|
inputs -> (outputs, (d_outputs->d_inputs))
|
||||||
|
So the shapes must match like this:
|
||||||
|
shape of forward input == shape of backward output
|
||||||
|
shape of backward input == shape of forward output
|
||||||
|
'''
|
||||||
|
def forward(X__bfm, drop=0.):
|
||||||
|
b, f, m = X__bfm.shape
|
||||||
|
B = b*f
|
||||||
|
M = f*m
|
||||||
|
X__Bm = X__bfm.reshape((B, m))
|
||||||
|
y__Bn, bp_yBn = layer.begin_update(X__Bm, drop=drop)
|
||||||
|
n = y__Bn.shape[1]
|
||||||
|
N = f * n
|
||||||
|
y__bN = y__Bn.reshape((b, N))
|
||||||
|
def backward(dy__bN, sgd=None):
|
||||||
|
dy__Bn = dy__bN.reshape((B, n))
|
||||||
|
dX__Bm = bp_yBn(dy__Bn, sgd)
|
||||||
|
if dX__Bm is None:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return dX__Bm.reshape((b, f, m))
|
||||||
|
return y__bN, backward
|
||||||
model = layerize(forward)
|
model = layerize(forward)
|
||||||
model._layers.append(layer)
|
model._layers.append(layer)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def build_tok2vec(lang, width, depth, embed_size, cols):
|
|
||||||
|
@layerize
|
||||||
|
def flatten(seqs, drop=0.):
|
||||||
|
ops = Model.ops
|
||||||
|
def finish_update(d_X, sgd=None):
|
||||||
|
return d_X
|
||||||
|
X = ops.xp.concatenate([ops.asarray(seq) for seq in seqs])
|
||||||
|
return X, finish_update
|
||||||
|
|
||||||
|
|
||||||
|
def build_tok2vec(lang, width, depth=2, embed_size=1000):
|
||||||
|
cols = [ID, PREFIX, SUFFIX, SHAPE]
|
||||||
with Model.define_operators({'>>': chain, '|': concatenate, '**': clone}):
|
with Model.define_operators({'>>': chain, '|': concatenate, '**': clone}):
|
||||||
static = get_col(cols.index(ID)) >> StaticVectors(lang, width)
|
#static = get_col(cols.index(ID)) >> StaticVectors(lang, width)
|
||||||
|
lower = get_col(cols.index(ID)) >> HashEmbed(width, embed_size)
|
||||||
prefix = get_col(cols.index(PREFIX)) >> HashEmbed(width, embed_size)
|
prefix = get_col(cols.index(PREFIX)) >> HashEmbed(width, embed_size)
|
||||||
suffix = get_col(cols.index(SUFFIX)) >> HashEmbed(width, embed_size)
|
suffix = get_col(cols.index(SUFFIX)) >> HashEmbed(width, embed_size)
|
||||||
shape = get_col(cols.index(SHAPE)) >> HashEmbed(width, embed_size)
|
shape = get_col(cols.index(SHAPE)) >> HashEmbed(width, embed_size)
|
||||||
tok2vec = (
|
tok2vec = (
|
||||||
(static | prefix | suffix | shape)
|
doc2feats(cols)
|
||||||
>> Maxout(width, width*4)
|
>> with_flatten(
|
||||||
>> (ExtractWindow(nW=1) >> Maxout(width, width*3)) ** depth
|
#(static | prefix | suffix | shape)
|
||||||
|
(lower | prefix | suffix | shape)
|
||||||
|
>> Maxout(width, width*4)
|
||||||
|
>> (ExtractWindow(nW=1) >> Maxout(width, width*3))
|
||||||
|
>> (ExtractWindow(nW=1) >> Maxout(width, width*3))
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return tok2vec
|
return tok2vec
|
||||||
|
|
||||||
|
|
||||||
|
def doc2feats(cols):
|
||||||
|
def forward(docs, drop=0.):
|
||||||
|
feats = [doc.to_array(cols) for doc in docs]
|
||||||
|
feats = [model.ops.asarray(f, dtype='uint64') for f in feats]
|
||||||
|
return feats, None
|
||||||
|
model = layerize(forward)
|
||||||
|
return model
|
||||||
|
|
|
@ -304,5 +304,24 @@ TAG_MAP = {
|
||||||
"VERB__VerbForm=Ger": {"morph": "VerbForm=Ger", "pos": "VERB"},
|
"VERB__VerbForm=Ger": {"morph": "VerbForm=Ger", "pos": "VERB"},
|
||||||
"VERB__VerbForm=Inf": {"morph": "VerbForm=Inf", "pos": "VERB"},
|
"VERB__VerbForm=Inf": {"morph": "VerbForm=Inf", "pos": "VERB"},
|
||||||
"X___": {"morph": "_", "pos": "X"},
|
"X___": {"morph": "_", "pos": "X"},
|
||||||
"SP": {"morph": "_", "pos": "SPACE"}
|
"SP": {"morph": "_", "pos": "SPACE"},
|
||||||
|
"ADV": {POS: ADV},
|
||||||
|
"NOUN": {POS: NOUN},
|
||||||
|
"ADP": {POS: ADP},
|
||||||
|
"PRON": {POS: PRON},
|
||||||
|
"SCONJ": {POS: SCONJ},
|
||||||
|
"PROPN": {POS: PROPN},
|
||||||
|
"DET": {POS: DET},
|
||||||
|
"SYM": {POS: SYM},
|
||||||
|
"INTJ": {POS: INTJ},
|
||||||
|
"PUNCT": {POS: PUNCT},
|
||||||
|
"NUM": {POS: NUM},
|
||||||
|
"AUX": {POS: AUX},
|
||||||
|
"X": {POS: X},
|
||||||
|
"CONJ": {POS: CONJ},
|
||||||
|
"CCONJ": {POS: CCONJ}, # U20
|
||||||
|
"ADJ": {POS: ADJ},
|
||||||
|
"VERB": {POS: VERB},
|
||||||
|
"PART": {POS: PART},
|
||||||
|
"_": {POS: PUNCT}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from .syntax.parser cimport Parser
|
from .syntax.parser cimport Parser
|
||||||
from .syntax.beam_parser cimport BeamParser
|
#from .syntax.beam_parser cimport BeamParser
|
||||||
from .syntax.ner cimport BiluoPushDown
|
from .syntax.ner cimport BiluoPushDown
|
||||||
from .syntax.arc_eager cimport ArcEager
|
from .syntax.arc_eager cimport ArcEager
|
||||||
from .tagger cimport Tagger
|
from .tagger cimport Tagger
|
||||||
|
@ -13,9 +13,9 @@ cdef class DependencyParser(Parser):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
cdef class BeamEntityRecognizer(BeamParser):
|
#cdef class BeamEntityRecognizer(BeamParser):
|
||||||
pass
|
# pass
|
||||||
|
#
|
||||||
|
#
|
||||||
cdef class BeamDependencyParser(BeamParser):
|
#cdef class BeamDependencyParser(BeamParser):
|
||||||
pass
|
# pass
|
||||||
|
|
|
@ -1,11 +1,15 @@
|
||||||
# coding: utf8
|
# coding: utf8
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
from thinc.api import chain, layerize, with_getitem
|
||||||
|
from thinc.neural import Model, Softmax
|
||||||
|
|
||||||
from .syntax.parser cimport Parser
|
from .syntax.parser cimport Parser
|
||||||
from .syntax.beam_parser cimport BeamParser
|
#from .syntax.beam_parser cimport BeamParser
|
||||||
from .syntax.ner cimport BiluoPushDown
|
from .syntax.ner cimport BiluoPushDown
|
||||||
from .syntax.arc_eager cimport ArcEager
|
from .syntax.arc_eager cimport ArcEager
|
||||||
from .tagger import Tagger
|
from .tagger import Tagger
|
||||||
|
from ._ml import build_tok2vec
|
||||||
|
|
||||||
# TODO: The disorganization here is pretty embarrassing. At least it's only
|
# TODO: The disorganization here is pretty embarrassing. At least it's only
|
||||||
# internals.
|
# internals.
|
||||||
|
@ -13,6 +17,39 @@ from .syntax.parser import get_templates as get_feature_templates
|
||||||
from .attrs import DEP, ENT_TYPE
|
from .attrs import DEP, ENT_TYPE
|
||||||
|
|
||||||
|
|
||||||
|
class TokenVectorEncoder(object):
|
||||||
|
'''Assign position-sensitive vectors to tokens, using a CNN or RNN.'''
|
||||||
|
def __init__(self, vocab, **cfg):
|
||||||
|
self.vocab = vocab
|
||||||
|
self.model = build_tok2vec(vocab.lang, 64, **cfg)
|
||||||
|
self.tagger = chain(
|
||||||
|
self.model,
|
||||||
|
Softmax(self.vocab.morphology.n_tags))
|
||||||
|
|
||||||
|
def __call__(self, doc):
|
||||||
|
doc.tensor = self.model([doc])[0]
|
||||||
|
|
||||||
|
def begin_update(self, docs, drop=0.):
|
||||||
|
tensors, bp_tensors = self.model.begin_update(docs, drop=drop)
|
||||||
|
for i, doc in enumerate(docs):
|
||||||
|
doc.tensor = tensors[i]
|
||||||
|
return tensors, bp_tensors
|
||||||
|
|
||||||
|
def update(self, docs, golds, drop=0., sgd=None):
|
||||||
|
scores, finish_update = self.tagger.begin_update(docs, drop=drop)
|
||||||
|
losses = scores.copy()
|
||||||
|
loss = 0.0
|
||||||
|
idx = 0
|
||||||
|
for i, gold in enumerate(golds):
|
||||||
|
for j, tag in enumerate(gold.tags):
|
||||||
|
tag_id = docs[0].vocab.morphology.tag_names.index(tag)
|
||||||
|
losses[idx, tag_id] -= 1.0
|
||||||
|
loss += 1-scores[idx, tag_id]
|
||||||
|
idx += 1
|
||||||
|
finish_update(losses, sgd)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
cdef class EntityRecognizer(Parser):
|
cdef class EntityRecognizer(Parser):
|
||||||
"""
|
"""
|
||||||
Annotate named entities on Doc objects.
|
Annotate named entities on Doc objects.
|
||||||
|
@ -31,25 +68,25 @@ cdef class EntityRecognizer(Parser):
|
||||||
freqs.append([label, 1])
|
freqs.append([label, 1])
|
||||||
self.vocab._serializer = None
|
self.vocab._serializer = None
|
||||||
|
|
||||||
|
#
|
||||||
cdef class BeamEntityRecognizer(BeamParser):
|
#cdef class BeamEntityRecognizer(BeamParser):
|
||||||
"""
|
# """
|
||||||
Annotate named entities on Doc objects.
|
# Annotate named entities on Doc objects.
|
||||||
"""
|
# """
|
||||||
TransitionSystem = BiluoPushDown
|
# TransitionSystem = BiluoPushDown
|
||||||
|
#
|
||||||
feature_templates = get_feature_templates('ner')
|
# feature_templates = get_feature_templates('ner')
|
||||||
|
#
|
||||||
def add_label(self, label):
|
# def add_label(self, label):
|
||||||
Parser.add_label(self, label)
|
# Parser.add_label(self, label)
|
||||||
if isinstance(label, basestring):
|
# if isinstance(label, basestring):
|
||||||
label = self.vocab.strings[label]
|
# label = self.vocab.strings[label]
|
||||||
# Set label into serializer. Super hacky :(
|
# # Set label into serializer. Super hacky :(
|
||||||
for attr, freqs in self.vocab.serializer_freqs:
|
# for attr, freqs in self.vocab.serializer_freqs:
|
||||||
if attr == ENT_TYPE and label not in freqs:
|
# if attr == ENT_TYPE and label not in freqs:
|
||||||
freqs.append([label, 1])
|
# freqs.append([label, 1])
|
||||||
self.vocab._serializer = None
|
# self.vocab._serializer = None
|
||||||
|
#
|
||||||
|
|
||||||
cdef class DependencyParser(Parser):
|
cdef class DependencyParser(Parser):
|
||||||
TransitionSystem = ArcEager
|
TransitionSystem = ArcEager
|
||||||
|
@ -66,21 +103,22 @@ cdef class DependencyParser(Parser):
|
||||||
# Super hacky :(
|
# Super hacky :(
|
||||||
self.vocab._serializer = None
|
self.vocab._serializer = None
|
||||||
|
|
||||||
|
#
|
||||||
|
#cdef class BeamDependencyParser(BeamParser):
|
||||||
|
# TransitionSystem = ArcEager
|
||||||
|
#
|
||||||
|
# feature_templates = get_feature_templates('basic')
|
||||||
|
#
|
||||||
|
# def add_label(self, label):
|
||||||
|
# Parser.add_label(self, label)
|
||||||
|
# if isinstance(label, basestring):
|
||||||
|
# label = self.vocab.strings[label]
|
||||||
|
# for attr, freqs in self.vocab.serializer_freqs:
|
||||||
|
# if attr == DEP and label not in freqs:
|
||||||
|
# freqs.append([label, 1])
|
||||||
|
# # Super hacky :(
|
||||||
|
# self.vocab._serializer = None
|
||||||
|
#
|
||||||
|
|
||||||
cdef class BeamDependencyParser(BeamParser):
|
#__all__ = [Tagger, DependencyParser, EntityRecognizer, BeamDependencyParser, BeamEntityRecognizer]
|
||||||
TransitionSystem = ArcEager
|
__all__ = [Tagger, DependencyParser, EntityRecognizer]
|
||||||
|
|
||||||
feature_templates = get_feature_templates('basic')
|
|
||||||
|
|
||||||
def add_label(self, label):
|
|
||||||
Parser.add_label(self, label)
|
|
||||||
if isinstance(label, basestring):
|
|
||||||
label = self.vocab.strings[label]
|
|
||||||
for attr, freqs in self.vocab.serializer_freqs:
|
|
||||||
if attr == DEP and label not in freqs:
|
|
||||||
freqs.append([label, 1])
|
|
||||||
# Super hacky :(
|
|
||||||
self.vocab._serializer = None
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [Tagger, DependencyParser, EntityRecognizer, BeamDependencyParser, BeamEntityRecognizer]
|
|
||||||
|
|
|
@ -3,8 +3,8 @@ from ..structs cimport TokenC
|
||||||
from thinc.typedefs cimport weight_t
|
from thinc.typedefs cimport weight_t
|
||||||
|
|
||||||
|
|
||||||
cdef class BeamParser(Parser):
|
#cdef class BeamParser(Parser):
|
||||||
cdef public int beam_width
|
# cdef public int beam_width
|
||||||
cdef public weight_t beam_density
|
# cdef public weight_t beam_density
|
||||||
|
#
|
||||||
cdef int _parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) except -1
|
# #cdef int _parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) except -1
|
||||||
|
|
|
@ -56,130 +56,130 @@ def get_templates(name):
|
||||||
cdef int BEAM_WIDTH = 16
|
cdef int BEAM_WIDTH = 16
|
||||||
cdef weight_t BEAM_DENSITY = 0.001
|
cdef weight_t BEAM_DENSITY = 0.001
|
||||||
|
|
||||||
cdef class BeamParser(Parser):
|
#cdef class BeamParser(Parser):
|
||||||
def __init__(self, *args, **kwargs):
|
# def __init__(self, *args, **kwargs):
|
||||||
self.beam_width = kwargs.get('beam_width', BEAM_WIDTH)
|
# self.beam_width = kwargs.get('beam_width', BEAM_WIDTH)
|
||||||
self.beam_density = kwargs.get('beam_density', BEAM_DENSITY)
|
# self.beam_density = kwargs.get('beam_density', BEAM_DENSITY)
|
||||||
Parser.__init__(self, *args, **kwargs)
|
# Parser.__init__(self, *args, **kwargs)
|
||||||
|
#
|
||||||
cdef int parseC(self, TokenC* tokens, int length, int nr_feat) nogil:
|
# #cdef int parseC(self, TokenC* tokens, int length, int nr_feat) nogil:
|
||||||
with gil:
|
# # with gil:
|
||||||
self._parseC(tokens, length, nr_feat, self.moves.n_moves)
|
# # self._parseC(tokens, length, nr_feat, self.moves.n_moves)
|
||||||
|
#
|
||||||
cdef int _parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) except -1:
|
# #cdef int _parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) except -1:
|
||||||
cdef Beam beam = Beam(self.moves.n_moves, self.beam_width, min_density=self.beam_density)
|
# # cdef Beam beam = Beam(self.moves.n_moves, self.beam_width, min_density=self.beam_density)
|
||||||
# TODO: How do we handle new labels here? This increases nr_class
|
# # # TODO: How do we handle new labels here? This increases nr_class
|
||||||
beam.initialize(self.moves.init_beam_state, length, tokens)
|
# # beam.initialize(self.moves.init_beam_state, length, tokens)
|
||||||
beam.check_done(_check_final_state, NULL)
|
# # beam.check_done(_check_final_state, NULL)
|
||||||
if beam.is_done:
|
# # if beam.is_done:
|
||||||
_cleanup(beam)
|
# # _cleanup(beam)
|
||||||
return 0
|
# # return 0
|
||||||
while not beam.is_done:
|
# # while not beam.is_done:
|
||||||
self._advance_beam(beam, None, False)
|
# # self._advance_beam(beam, None, False)
|
||||||
state = <StateClass>beam.at(0)
|
# # state = <StateClass>beam.at(0)
|
||||||
self.moves.finalize_state(state.c)
|
# # self.moves.finalize_state(state.c)
|
||||||
for i in range(length):
|
# # for i in range(length):
|
||||||
tokens[i] = state.c._sent[i]
|
# # tokens[i] = state.c._sent[i]
|
||||||
_cleanup(beam)
|
# # _cleanup(beam)
|
||||||
|
#
|
||||||
def update(self, Doc tokens, GoldParse gold_parse, itn=0):
|
# def update(self, Doc tokens, GoldParse gold_parse, itn=0):
|
||||||
self.moves.preprocess_gold(gold_parse)
|
# self.moves.preprocess_gold(gold_parse)
|
||||||
cdef Beam pred = Beam(self.moves.n_moves, self.beam_width)
|
# cdef Beam pred = Beam(self.moves.n_moves, self.beam_width)
|
||||||
pred.initialize(self.moves.init_beam_state, tokens.length, tokens.c)
|
# pred.initialize(self.moves.init_beam_state, tokens.length, tokens.c)
|
||||||
pred.check_done(_check_final_state, NULL)
|
# pred.check_done(_check_final_state, NULL)
|
||||||
# Hack for NER
|
# # Hack for NER
|
||||||
for i in range(pred.size):
|
# for i in range(pred.size):
|
||||||
stcls = <StateClass>pred.at(i)
|
# stcls = <StateClass>pred.at(i)
|
||||||
self.moves.initialize_state(stcls.c)
|
# self.moves.initialize_state(stcls.c)
|
||||||
|
#
|
||||||
cdef Beam gold = Beam(self.moves.n_moves, self.beam_width, min_density=0.0)
|
# cdef Beam gold = Beam(self.moves.n_moves, self.beam_width, min_density=0.0)
|
||||||
gold.initialize(self.moves.init_beam_state, tokens.length, tokens.c)
|
# gold.initialize(self.moves.init_beam_state, tokens.length, tokens.c)
|
||||||
gold.check_done(_check_final_state, NULL)
|
# gold.check_done(_check_final_state, NULL)
|
||||||
violn = MaxViolation()
|
# violn = MaxViolation()
|
||||||
while not pred.is_done and not gold.is_done:
|
# while not pred.is_done and not gold.is_done:
|
||||||
# We search separately here, to allow for ambiguity in the gold parse.
|
# # We search separately here, to allow for ambiguity in the gold parse.
|
||||||
self._advance_beam(pred, gold_parse, False)
|
# self._advance_beam(pred, gold_parse, False)
|
||||||
self._advance_beam(gold, gold_parse, True)
|
# self._advance_beam(gold, gold_parse, True)
|
||||||
violn.check_crf(pred, gold)
|
# violn.check_crf(pred, gold)
|
||||||
if pred.loss > 0 and pred.min_score > (gold.score + self.model.time):
|
# if pred.loss > 0 and pred.min_score > (gold.score + self.model.time):
|
||||||
break
|
# break
|
||||||
else:
|
# else:
|
||||||
# The non-monotonic oracle makes it difficult to ensure final costs are
|
# # The non-monotonic oracle makes it difficult to ensure final costs are
|
||||||
# correct. Therefore do final correction
|
# # correct. Therefore do final correction
|
||||||
for i in range(pred.size):
|
# for i in range(pred.size):
|
||||||
if is_gold(<StateClass>pred.at(i), gold_parse, self.moves.strings):
|
# if is_gold(<StateClass>pred.at(i), gold_parse, self.moves.strings):
|
||||||
pred._states[i].loss = 0.0
|
# pred._states[i].loss = 0.0
|
||||||
elif pred._states[i].loss == 0.0:
|
# elif pred._states[i].loss == 0.0:
|
||||||
pred._states[i].loss = 1.0
|
# pred._states[i].loss = 1.0
|
||||||
violn.check_crf(pred, gold)
|
# violn.check_crf(pred, gold)
|
||||||
if pred.size < 1:
|
# if pred.size < 1:
|
||||||
raise Exception("No candidates", tokens.length)
|
# raise Exception("No candidates", tokens.length)
|
||||||
if gold.size < 1:
|
# if gold.size < 1:
|
||||||
raise Exception("No gold", tokens.length)
|
# raise Exception("No gold", tokens.length)
|
||||||
if pred.loss == 0:
|
# if pred.loss == 0:
|
||||||
self.model.update_from_histories(self.moves, tokens, [(0.0, [])])
|
# self.model.update_from_histories(self.moves, tokens, [(0.0, [])])
|
||||||
elif True:
|
# elif True:
|
||||||
#_check_train_integrity(pred, gold, gold_parse, self.moves)
|
# #_check_train_integrity(pred, gold, gold_parse, self.moves)
|
||||||
histories = list(zip(violn.p_probs, violn.p_hist)) + \
|
# histories = list(zip(violn.p_probs, violn.p_hist)) + \
|
||||||
list(zip(violn.g_probs, violn.g_hist))
|
# list(zip(violn.g_probs, violn.g_hist))
|
||||||
self.model.update_from_histories(self.moves, tokens, histories, min_grad=0.001**(itn+1))
|
# self.model.update_from_histories(self.moves, tokens, histories, min_grad=0.001**(itn+1))
|
||||||
else:
|
# else:
|
||||||
self.model.update_from_histories(self.moves, tokens,
|
# self.model.update_from_histories(self.moves, tokens,
|
||||||
[(1.0, violn.p_hist[0]), (-1.0, violn.g_hist[0])])
|
# [(1.0, violn.p_hist[0]), (-1.0, violn.g_hist[0])])
|
||||||
_cleanup(pred)
|
# _cleanup(pred)
|
||||||
_cleanup(gold)
|
# _cleanup(gold)
|
||||||
return pred.loss
|
# return pred.loss
|
||||||
|
#
|
||||||
def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold):
|
# def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold):
|
||||||
cdef atom_t[CONTEXT_SIZE] context
|
# cdef atom_t[CONTEXT_SIZE] context
|
||||||
cdef Pool mem = Pool()
|
# cdef Pool mem = Pool()
|
||||||
features = <FeatureC*>mem.alloc(self.model.nr_feat, sizeof(FeatureC))
|
# features = <FeatureC*>mem.alloc(self.model.nr_feat, sizeof(FeatureC))
|
||||||
if False:
|
# if False:
|
||||||
mb = Minibatch(self.model.widths, beam.size)
|
# mb = Minibatch(self.model.widths, beam.size)
|
||||||
for i in range(beam.size):
|
# for i in range(beam.size):
|
||||||
stcls = <StateClass>beam.at(i)
|
# stcls = <StateClass>beam.at(i)
|
||||||
if stcls.c.is_final():
|
# if stcls.c.is_final():
|
||||||
nr_feat = 0
|
# nr_feat = 0
|
||||||
else:
|
# else:
|
||||||
nr_feat = self.model.set_featuresC(context, features, stcls.c)
|
# nr_feat = self.model.set_featuresC(context, features, stcls.c)
|
||||||
self.moves.set_valid(beam.is_valid[i], stcls.c)
|
# self.moves.set_valid(beam.is_valid[i], stcls.c)
|
||||||
mb.c.push_back(features, nr_feat, beam.costs[i], beam.is_valid[i], 0)
|
# mb.c.push_back(features, nr_feat, beam.costs[i], beam.is_valid[i], 0)
|
||||||
self.model(mb)
|
# self.model(mb)
|
||||||
for i in range(beam.size):
|
# for i in range(beam.size):
|
||||||
memcpy(beam.scores[i], mb.c.scores(i), mb.c.nr_out() * sizeof(beam.scores[i][0]))
|
# memcpy(beam.scores[i], mb.c.scores(i), mb.c.nr_out() * sizeof(beam.scores[i][0]))
|
||||||
else:
|
# else:
|
||||||
for i in range(beam.size):
|
# for i in range(beam.size):
|
||||||
stcls = <StateClass>beam.at(i)
|
# stcls = <StateClass>beam.at(i)
|
||||||
if not stcls.is_final():
|
# if not stcls.is_final():
|
||||||
nr_feat = self.model.set_featuresC(context, features, stcls.c)
|
# nr_feat = self.model.set_featuresC(context, features, stcls.c)
|
||||||
self.moves.set_valid(beam.is_valid[i], stcls.c)
|
# self.moves.set_valid(beam.is_valid[i], stcls.c)
|
||||||
self.model.set_scoresC(beam.scores[i], features, nr_feat)
|
# self.model.set_scoresC(beam.scores[i], features, nr_feat)
|
||||||
if gold is not None:
|
# if gold is not None:
|
||||||
n_gold = 0
|
# n_gold = 0
|
||||||
lines = []
|
# lines = []
|
||||||
for i in range(beam.size):
|
# for i in range(beam.size):
|
||||||
stcls = <StateClass>beam.at(i)
|
# stcls = <StateClass>beam.at(i)
|
||||||
if not stcls.c.is_final():
|
# if not stcls.c.is_final():
|
||||||
self.moves.set_costs(beam.is_valid[i], beam.costs[i], stcls, gold)
|
# self.moves.set_costs(beam.is_valid[i], beam.costs[i], stcls, gold)
|
||||||
if follow_gold:
|
# if follow_gold:
|
||||||
for j in range(self.moves.n_moves):
|
# for j in range(self.moves.n_moves):
|
||||||
if beam.costs[i][j] >= 1:
|
# if beam.costs[i][j] >= 1:
|
||||||
beam.is_valid[i][j] = 0
|
# beam.is_valid[i][j] = 0
|
||||||
lines.append((stcls.B(0), stcls.B(1),
|
# lines.append((stcls.B(0), stcls.B(1),
|
||||||
stcls.B_(0).ent_iob, stcls.B_(1).ent_iob,
|
# stcls.B_(0).ent_iob, stcls.B_(1).ent_iob,
|
||||||
stcls.B_(1).sent_start,
|
# stcls.B_(1).sent_start,
|
||||||
j,
|
# j,
|
||||||
beam.is_valid[i][j], 'set invalid',
|
# beam.is_valid[i][j], 'set invalid',
|
||||||
beam.costs[i][j], self.moves.c[j].move, self.moves.c[j].label))
|
# beam.costs[i][j], self.moves.c[j].move, self.moves.c[j].label))
|
||||||
n_gold += 1 if beam.is_valid[i][j] else 0
|
# n_gold += 1 if beam.is_valid[i][j] else 0
|
||||||
if follow_gold and n_gold == 0:
|
# if follow_gold and n_gold == 0:
|
||||||
raise Exception("No gold")
|
# raise Exception("No gold")
|
||||||
if follow_gold:
|
# if follow_gold:
|
||||||
beam.advance(_transition_state, NULL, <void*>self.moves.c)
|
# beam.advance(_transition_state, NULL, <void*>self.moves.c)
|
||||||
else:
|
# else:
|
||||||
beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
|
# beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
|
||||||
beam.check_done(_check_final_state, NULL)
|
# beam.check_done(_check_final_state, NULL)
|
||||||
|
#
|
||||||
|
|
||||||
# These are passed as callbacks to thinc.search.Beam
|
# These are passed as callbacks to thinc.search.Beam
|
||||||
cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
|
cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
|
||||||
|
|
|
@ -40,6 +40,9 @@ from ..structs cimport TokenC
|
||||||
from ..tokens.doc cimport Doc
|
from ..tokens.doc cimport Doc
|
||||||
from ..strings cimport StringStore
|
from ..strings cimport StringStore
|
||||||
from ..gold cimport GoldParse
|
from ..gold cimport GoldParse
|
||||||
|
from ..attrs cimport TAG, DEP
|
||||||
|
|
||||||
|
from .._ml import build_parser_state2vec, build_model
|
||||||
|
|
||||||
|
|
||||||
USE_FTRL = True
|
USE_FTRL = True
|
||||||
|
@ -107,6 +110,11 @@ cdef class Parser:
|
||||||
def __reduce__(self):
|
def __reduce__(self):
|
||||||
return (Parser, (self.vocab, self.moves, self.model), None, None)
|
return (Parser, (self.vocab, self.moves, self.model), None, None)
|
||||||
|
|
||||||
|
def build_model(self, width=8, nr_vector=1000, nF=1, nB=1, nS=1, nL=1, nR=1, **_):
|
||||||
|
state2vec = build_parser_state2vec(width, nr_vector, nF, nB, nL, nR)
|
||||||
|
model = build_model(state2vec, width, 2, self.moves.n_moves)
|
||||||
|
return model
|
||||||
|
|
||||||
def __call__(self, Doc tokens):
|
def __call__(self, Doc tokens):
|
||||||
"""
|
"""
|
||||||
Apply the parser or entity recognizer, setting the annotations onto the Doc object.
|
Apply the parser or entity recognizer, setting the annotations onto the Doc object.
|
||||||
|
@ -118,25 +126,7 @@ cdef class Parser:
|
||||||
"""
|
"""
|
||||||
self.parse_batch([tokens])
|
self.parse_batch([tokens])
|
||||||
self.moves.finalize_doc(tokens)
|
self.moves.finalize_doc(tokens)
|
||||||
|
|
||||||
def parse_batch(self, docs):
|
|
||||||
states = self._init_states(docs)
|
|
||||||
nr_class = self.moves.n_moves
|
|
||||||
cdef StateClass state
|
|
||||||
cdef int guess
|
|
||||||
is_valid = self.model.ops.allocate((len(docs), nr_class), dtype='i')
|
|
||||||
todo = list(states)
|
|
||||||
while todo:
|
|
||||||
scores = self.model.predict(todo)
|
|
||||||
self._validate_batch(is_valid, states)
|
|
||||||
scores *= is_valid
|
|
||||||
for state, guess in zip(todo, scores.argmax(axis=1)):
|
|
||||||
action = self.moves.c[guess]
|
|
||||||
action.do(state.c, action.label)
|
|
||||||
todo = [state for state in todo if not state.is_final()]
|
|
||||||
for state, doc in zip(states, docs):
|
|
||||||
self.moves.finalize_state(state.c)
|
|
||||||
|
|
||||||
def pipe(self, stream, int batch_size=1000, int n_threads=2):
|
def pipe(self, stream, int batch_size=1000, int n_threads=2):
|
||||||
"""
|
"""
|
||||||
Process a stream of documents.
|
Process a stream of documents.
|
||||||
|
@ -170,53 +160,106 @@ cdef class Parser:
|
||||||
self.moves.finalize_doc(doc)
|
self.moves.finalize_doc(doc)
|
||||||
yield doc
|
yield doc
|
||||||
|
|
||||||
|
def parse_batch(self, docs):
|
||||||
|
states = self._init_states(docs)
|
||||||
|
nr_class = self.moves.n_moves
|
||||||
|
cdef Doc doc
|
||||||
|
cdef StateClass state
|
||||||
|
cdef int guess
|
||||||
|
is_valid = self.model.ops.allocate((len(docs), nr_class), dtype='i')
|
||||||
|
tokvecs = [d.tensor for d in docs]
|
||||||
|
attr_names = self.model.ops.allocate((2,), dtype='i')
|
||||||
|
attr_names[0] = TAG
|
||||||
|
attr_names[1] = DEP
|
||||||
|
all_states = list(states)
|
||||||
|
todo = zip(states, tokvecs)
|
||||||
|
while todo:
|
||||||
|
states, tokvecs = zip(*todo)
|
||||||
|
features = self._get_features(states, tokvecs, attr_names)
|
||||||
|
scores = self.model.predict(features)
|
||||||
|
self._validate_batch(is_valid, states)
|
||||||
|
scores *= is_valid
|
||||||
|
for state, guess in zip(states, scores.argmax(axis=1)):
|
||||||
|
action = self.moves.c[guess]
|
||||||
|
action.do(state.c, action.label)
|
||||||
|
todo = filter(lambda sp: not sp[0].is_final(), todo)
|
||||||
|
for state, doc in zip(all_states, docs):
|
||||||
|
self.moves.finalize_state(state.c)
|
||||||
|
for i in range(doc.length):
|
||||||
|
doc.c[i] = state.c._sent[i]
|
||||||
|
|
||||||
|
|
||||||
def update(self, docs, golds, drop=0., sgd=None):
|
def update(self, docs, golds, drop=0., sgd=None):
|
||||||
if isinstance(docs, Doc) and isinstance(golds, GoldParse):
|
if isinstance(docs, Doc) and isinstance(golds, GoldParse):
|
||||||
return self.update([docs], [golds], drop=drop)
|
return self.update([docs], [golds], drop=drop)
|
||||||
|
for gold in golds:
|
||||||
|
self.moves.preprocess_gold(gold)
|
||||||
states = self._init_states(docs)
|
states = self._init_states(docs)
|
||||||
|
tokvecs = [d.tensor for d in docs]
|
||||||
d_tokens = [self.model.ops.allocate(d.tensor.shape) for d in docs]
|
d_tokens = [self.model.ops.allocate(d.tensor.shape) for d in docs]
|
||||||
nr_class = self.moves.n_moves
|
nr_class = self.moves.n_moves
|
||||||
costs = self.model.ops.allocate((len(docs), nr_class), dtype='f')
|
costs = self.model.ops.allocate((len(docs), nr_class), dtype='f')
|
||||||
|
gradients = self.model.ops.allocate((len(docs), nr_class), dtype='f')
|
||||||
is_valid = self.model.ops.allocate((len(docs), nr_class), dtype='i')
|
is_valid = self.model.ops.allocate((len(docs), nr_class), dtype='i')
|
||||||
|
attr_names = self.model.ops.allocate((2,), dtype='i')
|
||||||
|
attr_names[0] = TAG
|
||||||
|
attr_names[1] = DEP
|
||||||
|
output = list(d_tokens)
|
||||||
|
todo = zip(states, tokvecs, golds, d_tokens)
|
||||||
|
assert len(states) == len(todo)
|
||||||
|
loss = 0.
|
||||||
|
while todo:
|
||||||
|
states, tokvecs, golds, d_tokens = zip(*todo)
|
||||||
|
features = self._get_features(states, tokvecs, attr_names)
|
||||||
|
|
||||||
todo = zip(states, golds, d_tokens)
|
scores, finish_update = self.model.begin_update(features, drop=drop)
|
||||||
while states:
|
assert scores.shape == (len(states), self.moves.n_moves), (len(states), scores.shape)
|
||||||
states, golds, d_tokens = zip(*todo)
|
|
||||||
scores, finish_update = self.model.begin_update(states, drop=drop)
|
self._cost_batch(costs, is_valid, states, golds)
|
||||||
|
|
||||||
self._cost_batch(is_valid, costs, states, golds)
|
|
||||||
scores *= is_valid
|
scores *= is_valid
|
||||||
self._set_gradient(gradients, scores, costs)
|
self._set_gradient(gradients, scores, costs)
|
||||||
|
loss += numpy.abs(gradients).sum() / gradients.shape[0]
|
||||||
|
|
||||||
token_ids, batch_token_grads = finish_update(gradients, sgd=sgd)
|
token_ids, batch_token_grads = finish_update(gradients, sgd=sgd)
|
||||||
for i, tok_i in enumerate(token_ids):
|
for i, tok_i in enumerate(token_ids):
|
||||||
d_tokens[tok_i] += batch_token_grads[i]
|
d_tokens[i][tok_i] += batch_token_grads[i]
|
||||||
|
|
||||||
self._transition_batch(states, scores)
|
self._transition_batch(states, scores)
|
||||||
|
|
||||||
# Get unfinished states (and their matching gold and token gradients)
|
# Get unfinished states (and their matching gold and token gradients)
|
||||||
todo = zip(states, golds, d_tokens)
|
todo = filter(lambda sp: not sp[0].is_final(), todo)
|
||||||
todo = filter(todo, lambda sp: sp[0].is_final)
|
|
||||||
|
|
||||||
gradients = gradients[:len(todo)]
|
|
||||||
costs = costs[:len(todo)]
|
costs = costs[:len(todo)]
|
||||||
is_valid = is_valid[:len(todo)]
|
is_valid = is_valid[:len(todo)]
|
||||||
|
gradients = gradients[:len(todo)]
|
||||||
|
|
||||||
gradients.fill(0)
|
gradients.fill(0)
|
||||||
costs.fill(0)
|
costs.fill(0)
|
||||||
is_valid.fill(1)
|
is_valid.fill(1)
|
||||||
return 0
|
return output, loss
|
||||||
|
|
||||||
def _init_states(self, docs):
|
def _init_states(self, docs):
|
||||||
states = []
|
states = []
|
||||||
cdef Doc doc
|
cdef Doc doc
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
for i, doc in enumerate(docs):
|
for i, doc in enumerate(docs):
|
||||||
state = StateClass(doc)
|
state = StateClass.init(doc.c, doc.length)
|
||||||
self.moves.initialize_state(state.c)
|
self.moves.initialize_state(state.c)
|
||||||
states.append(state)
|
states.append(state)
|
||||||
return states
|
return states
|
||||||
|
|
||||||
|
def _get_features(self, states, all_tokvecs, attr_names,
|
||||||
|
nF=1, nB=0, nS=2, nL=2, nR=2):
|
||||||
|
n_tokens = states[0].nr_context_tokens(nF, nB, nS, nL, nR)
|
||||||
|
vector_length = all_tokvecs[0].shape[1]
|
||||||
|
tokens = self.model.ops.allocate((len(states), n_tokens), dtype='int32')
|
||||||
|
features = self.model.ops.allocate((len(states), n_tokens, attr_names.shape[0]), dtype='uint64')
|
||||||
|
tokvecs = self.model.ops.allocate((len(states), n_tokens, vector_length), dtype='f')
|
||||||
|
for i, state in enumerate(states):
|
||||||
|
state.set_context_tokens(tokens[i], nF, nB, nS, nL, nR)
|
||||||
|
state.set_attributes(features[i], tokens[i], attr_names)
|
||||||
|
state.set_token_vectors(tokvecs[i], all_tokvecs[i], tokens[i])
|
||||||
|
return (tokens, features, tokvecs)
|
||||||
|
|
||||||
def _validate_batch(self, int[:, ::1] is_valid, states):
|
def _validate_batch(self, int[:, ::1] is_valid, states):
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
cdef int i
|
cdef int i
|
||||||
|
@ -242,13 +285,13 @@ cdef class Parser:
|
||||||
"""Do multi-label log loss"""
|
"""Do multi-label log loss"""
|
||||||
cdef double Z, gZ, max_, g_max
|
cdef double Z, gZ, max_, g_max
|
||||||
g_scores = scores * (costs <= 0)
|
g_scores = scores * (costs <= 0)
|
||||||
maxes = scores.max(axis=1)
|
maxes = scores.max(axis=1).reshape((scores.shape[0], 1))
|
||||||
g_maxes = g_scores.max(axis=1)
|
g_maxes = g_scores.max(axis=1).reshape((g_scores.shape[0], 1))
|
||||||
exps = (scores-maxes).exp()
|
exps = numpy.exp((scores-maxes))
|
||||||
g_exps = (g_scores-g_maxes).exp()
|
g_exps = numpy.exp(g_scores-g_maxes)
|
||||||
|
|
||||||
Zs = exps.sum(axis=1)
|
Zs = exps.sum(axis=1).reshape((exps.shape[0], 1))
|
||||||
gZs = g_exps.sum(axis=1)
|
gZs = g_exps.sum(axis=1).reshape((g_exps.shape[0], 1))
|
||||||
logprob = exps / Zs
|
logprob = exps / Zs
|
||||||
g_logprob = g_exps / gZs
|
g_logprob = g_exps / gZs
|
||||||
gradients[:] = logprob - g_logprob
|
gradients[:] = logprob - g_logprob
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from libc.string cimport memcpy, memset
|
from libc.string cimport memcpy, memset
|
||||||
|
|
||||||
from cymem.cymem cimport Pool
|
from cymem.cymem cimport Pool
|
||||||
|
cimport cython
|
||||||
|
|
||||||
from ..structs cimport TokenC, Entity
|
from ..structs cimport TokenC, Entity
|
||||||
|
|
||||||
|
@ -8,7 +9,7 @@ from ..vocab cimport EMPTY_LEXEME
|
||||||
from ._state cimport StateC
|
from ._state cimport StateC
|
||||||
|
|
||||||
|
|
||||||
|
@cython.final
|
||||||
cdef class StateClass:
|
cdef class StateClass:
|
||||||
cdef Pool mem
|
cdef Pool mem
|
||||||
cdef StateC* c
|
cdef StateC* c
|
||||||
|
|
|
@ -1,14 +1,17 @@
|
||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
|
# cython: infer_types=True
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
from libc.string cimport memcpy, memset
|
from libc.string cimport memcpy, memset
|
||||||
from libc.stdint cimport uint32_t
|
from libc.stdint cimport uint32_t, uint64_t
|
||||||
|
|
||||||
from ..vocab cimport EMPTY_LEXEME
|
from ..vocab cimport EMPTY_LEXEME
|
||||||
from ..structs cimport Entity
|
from ..structs cimport Entity
|
||||||
from ..lexeme cimport Lexeme
|
from ..lexeme cimport Lexeme
|
||||||
from ..symbols cimport punct
|
from ..symbols cimport punct
|
||||||
from ..attrs cimport IS_SPACE
|
from ..attrs cimport IS_SPACE
|
||||||
|
from ..attrs cimport attr_id_t
|
||||||
|
from ..tokens.token cimport Token
|
||||||
|
|
||||||
|
|
||||||
cdef class StateClass:
|
cdef class StateClass:
|
||||||
|
@ -27,6 +30,13 @@ cdef class StateClass:
|
||||||
def queue(self):
|
def queue(self):
|
||||||
return {self.B(i) for i in range(self.c.buffer_length())}
|
return {self.B(i) for i in range(self.c.buffer_length())}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def token_vector_lenth(self):
|
||||||
|
return self.doc.tensor.shape[1]
|
||||||
|
|
||||||
|
def is_final(self):
|
||||||
|
return self.c.is_final()
|
||||||
|
|
||||||
def print_state(self, words):
|
def print_state(self, words):
|
||||||
words = list(words) + ['_']
|
words = list(words) + ['_']
|
||||||
top = words[self.S(0)] + '_%d' % self.S_(0).head
|
top = words[self.S(0)] + '_%d' % self.S_(0).head
|
||||||
|
@ -35,3 +45,33 @@ cdef class StateClass:
|
||||||
n0 = words[self.B(0)]
|
n0 = words[self.B(0)]
|
||||||
n1 = words[self.B(1)]
|
n1 = words[self.B(1)]
|
||||||
return ' '.join((third, second, top, '|', n0, n1))
|
return ' '.join((third, second, top, '|', n0, n1))
|
||||||
|
|
||||||
|
def nr_context_tokens(self, int nF, int nB, int nS, int nL, int nR):
|
||||||
|
return 1+nF+nB+nS + nL + (nS * nL) + (nS * nR)
|
||||||
|
|
||||||
|
def set_context_tokens(self, int[:] output, nF=1, nB=0, nS=2,
|
||||||
|
nL=2, nR=2):
|
||||||
|
output[0] = self.B(0)
|
||||||
|
output[1] = self.S(0)
|
||||||
|
output[2] = self.S(1)
|
||||||
|
output[3] = self.L(self.S(0), 1)
|
||||||
|
output[4] = self.L(self.S(0), 2)
|
||||||
|
output[5] = self.R(self.S(0), 1)
|
||||||
|
output[6] = self.R(self.S(0), 2)
|
||||||
|
output[7] = self.L(self.S(1), 1)
|
||||||
|
output[8] = self.L(self.S(1), 2)
|
||||||
|
output[9] = self.R(self.S(1), 1)
|
||||||
|
output[10] = self.R(self.S(1), 2)
|
||||||
|
|
||||||
|
def set_attributes(self, uint64_t[:, :] vals, int[:] tokens, int[:] names):
|
||||||
|
cdef int i, j, tok_i
|
||||||
|
for i in range(tokens.shape[0]):
|
||||||
|
tok_i = tokens[i]
|
||||||
|
token = &self.c._sent[tok_i]
|
||||||
|
for j in range(names.shape[0]):
|
||||||
|
vals[i, j] = Token.get_struct_attr(token, <attr_id_t>names[j])
|
||||||
|
|
||||||
|
def set_token_vectors(self, float[:, :] tokvecs,
|
||||||
|
float[:, :] all_tokvecs, int[:] indices):
|
||||||
|
for i in range(indices.shape[0]):
|
||||||
|
tokvecs[i] = all_tokvecs[indices[i]]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user