mirror of
https://github.com/explosion/spaCy.git
synced 2025-04-28 21:03:41 +03:00
Gradients look correct
This commit is contained in:
parent
7e04260d38
commit
8e48b58cd6
|
@ -1,4 +1,4 @@
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals, print_function
|
||||||
import plac
|
import plac
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
|
@ -9,7 +9,7 @@ from spacy.syntax.nonproj import PseudoProjectivity
|
||||||
from spacy.language import Language
|
from spacy.language import Language
|
||||||
from spacy.gold import GoldParse
|
from spacy.gold import GoldParse
|
||||||
from spacy.tagger import Tagger
|
from spacy.tagger import Tagger
|
||||||
from spacy.pipeline import DependencyParser, BeamDependencyParser
|
from spacy.pipeline import DependencyParser, TokenVectorEncoder
|
||||||
from spacy.syntax.parser import get_templates
|
from spacy.syntax.parser import get_templates
|
||||||
from spacy.syntax.arc_eager import ArcEager
|
from spacy.syntax.arc_eager import ArcEager
|
||||||
from spacy.scorer import Scorer
|
from spacy.scorer import Scorer
|
||||||
|
@ -36,10 +36,10 @@ def read_conllx(loc, n=0):
|
||||||
try:
|
try:
|
||||||
id_ = int(id_) - 1
|
id_ = int(id_) - 1
|
||||||
head = (int(head) - 1) if head != '0' else id_
|
head = (int(head) - 1) if head != '0' else id_
|
||||||
dep = 'ROOT' if dep == 'root' else dep
|
dep = 'ROOT' if dep == 'root' else 'unlabelled'
|
||||||
tokens.append((id_, word, tag, head, dep, 'O'))
|
# Hack for efficiency
|
||||||
|
tokens.append((id_, word, pos+'__'+morph, head, dep, 'O'))
|
||||||
except:
|
except:
|
||||||
print(line)
|
|
||||||
raise
|
raise
|
||||||
tuples = [list(t) for t in zip(*tokens)]
|
tuples = [list(t) for t in zip(*tokens)]
|
||||||
yield (None, [[tuples, []]])
|
yield (None, [[tuples, []]])
|
||||||
|
@ -48,19 +48,37 @@ def read_conllx(loc, n=0):
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
def score_model(vocab, tagger, parser, gold_docs, verbose=False):
|
def score_model(vocab, encoder, tagger, parser, Xs, ys, verbose=False):
|
||||||
scorer = Scorer()
|
scorer = Scorer()
|
||||||
for _, gold_doc in gold_docs:
|
correct = 0.
|
||||||
for (ids, words, tags, heads, deps, entities), _ in gold_doc:
|
total = 0.
|
||||||
doc = Doc(vocab, words=words)
|
for doc, gold in zip(Xs, ys):
|
||||||
tagger(doc)
|
doc = Doc(vocab, words=[w.text for w in doc])
|
||||||
parser(doc)
|
encoder(doc)
|
||||||
PseudoProjectivity.deprojectivize(doc)
|
tagger(doc)
|
||||||
gold = GoldParse(doc, tags=tags, heads=heads, deps=deps)
|
parser(doc)
|
||||||
scorer.score(doc, gold, verbose=verbose)
|
PseudoProjectivity.deprojectivize(doc)
|
||||||
|
scorer.score(doc, gold, verbose=verbose)
|
||||||
|
for token, tag in zip(doc, gold.tags):
|
||||||
|
univ_guess, _ = token.tag_.split('_', 1)
|
||||||
|
univ_truth, _ = tag.split('_', 1)
|
||||||
|
correct += univ_guess == univ_truth
|
||||||
|
total += 1
|
||||||
return scorer
|
return scorer
|
||||||
|
|
||||||
|
|
||||||
|
def organize_data(vocab, train_sents):
|
||||||
|
Xs = []
|
||||||
|
ys = []
|
||||||
|
for _, doc_sents in train_sents:
|
||||||
|
for (ids, words, tags, heads, deps, ner), _ in doc_sents:
|
||||||
|
doc = Doc(vocab, words=words)
|
||||||
|
gold = GoldParse(doc, tags=tags, heads=heads, deps=deps)
|
||||||
|
Xs.append(doc)
|
||||||
|
ys.append(gold)
|
||||||
|
return Xs, ys
|
||||||
|
|
||||||
|
|
||||||
def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None):
|
def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None):
|
||||||
LangClass = spacy.util.get_lang_class(lang_name)
|
LangClass = spacy.util.get_lang_class(lang_name)
|
||||||
train_sents = list(read_conllx(train_loc))
|
train_sents = list(read_conllx(train_loc))
|
||||||
|
@ -114,21 +132,37 @@ def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None):
|
||||||
for tag in tags:
|
for tag in tags:
|
||||||
assert tag in vocab.morphology.tag_map, repr(tag)
|
assert tag in vocab.morphology.tag_map, repr(tag)
|
||||||
tagger = Tagger(vocab)
|
tagger = Tagger(vocab)
|
||||||
|
encoder = TokenVectorEncoder(vocab)
|
||||||
parser = DependencyParser(vocab, actions=actions, features=features, L1=0.0)
|
parser = DependencyParser(vocab, actions=actions, features=features, L1=0.0)
|
||||||
|
|
||||||
for itn in range(30):
|
|
||||||
loss = 0.
|
Xs, ys = organize_data(vocab, train_sents)
|
||||||
for _, doc_sents in train_sents:
|
Xs = Xs[:1]
|
||||||
for (ids, words, tags, heads, deps, ner), _ in doc_sents:
|
ys = ys[:1]
|
||||||
doc = Doc(vocab, words=words)
|
with encoder.model.begin_training(Xs[:100], ys[:100]) as (trainer, optimizer):
|
||||||
gold = GoldParse(doc, tags=tags, heads=heads, deps=deps)
|
docs = list(Xs)
|
||||||
tagger(doc)
|
for doc in docs:
|
||||||
loss += parser.update(doc, gold, itn=itn)
|
encoder(doc)
|
||||||
doc = Doc(vocab, words=words)
|
parser.begin_training(docs, ys)
|
||||||
|
nn_loss = [0.]
|
||||||
|
def track_progress():
|
||||||
|
scorer = score_model(vocab, encoder, tagger, parser, Xs, ys)
|
||||||
|
itn = len(nn_loss)
|
||||||
|
print('%d:\t%.3f\t%.3f\t%.3f' % (itn, nn_loss[-1], scorer.uas, scorer.tags_acc))
|
||||||
|
nn_loss.append(0.)
|
||||||
|
trainer.each_epoch.append(track_progress)
|
||||||
|
trainer.batch_size = 1
|
||||||
|
trainer.nb_epoch = 100
|
||||||
|
for docs, golds in trainer.iterate(Xs, ys, progress_bar=False):
|
||||||
|
docs = [Doc(vocab, words=[w.text for w in doc]) for doc in docs]
|
||||||
|
tokvecs, upd_tokvecs = encoder.begin_update(docs)
|
||||||
|
for doc, tokvec in zip(docs, tokvecs):
|
||||||
|
doc.tensor = tokvec
|
||||||
|
for doc, gold in zip(docs, golds):
|
||||||
tagger.update(doc, gold)
|
tagger.update(doc, gold)
|
||||||
random.shuffle(train_sents)
|
d_tokvecs, loss = parser.update(docs, golds, sgd=optimizer)
|
||||||
scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc))
|
upd_tokvecs(d_tokvecs, sgd=optimizer)
|
||||||
print('%d:\t%.3f\t%.3f\t%.3f' % (itn, loss, scorer.uas, scorer.tags_acc))
|
nn_loss[-1] += loss
|
||||||
nlp = LangClass(vocab=vocab, tagger=tagger, parser=parser)
|
nlp = LangClass(vocab=vocab, tagger=tagger, parser=parser)
|
||||||
nlp.end_training(model_dir)
|
nlp.end_training(model_dir)
|
||||||
scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc))
|
scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc))
|
||||||
|
|
34
spacy/_ml.py
34
spacy/_ml.py
|
@ -1,5 +1,5 @@
|
||||||
from thinc.api import layerize, chain, clone, concatenate, with_flatten
|
from thinc.api import layerize, chain, clone, concatenate, with_flatten
|
||||||
from thinc.neural import Model, Maxout, Softmax
|
from thinc.neural import Model, Maxout, Softmax, Affine
|
||||||
from thinc.neural._classes.hash_embed import HashEmbed
|
from thinc.neural._classes.hash_embed import HashEmbed
|
||||||
|
|
||||||
from thinc.neural._classes.convolution import ExtractWindow
|
from thinc.neural._classes.convolution import ExtractWindow
|
||||||
|
@ -21,11 +21,41 @@ def build_model(state2vec, width, depth, nr_class):
|
||||||
state2vec
|
state2vec
|
||||||
>> Maxout(width, 1344)
|
>> Maxout(width, 1344)
|
||||||
>> Maxout(width, width)
|
>> Maxout(width, width)
|
||||||
>> Softmax(nr_class, width)
|
>> Affine(nr_class, width)
|
||||||
)
|
)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def build_debug_model(state2vec, width, depth, nr_class):
|
||||||
|
with Model.define_operators({'>>': chain, '**': clone}):
|
||||||
|
model = (
|
||||||
|
state2vec
|
||||||
|
>> Maxout(width)
|
||||||
|
>> Affine(nr_class)
|
||||||
|
)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def build_debug_state2vec(width, nr_vector=1000, nF=1, nB=0, nS=1, nL=2, nR=2):
|
||||||
|
ops = Model.ops
|
||||||
|
def forward(tokens_attrs_vectors, drop=0.):
|
||||||
|
tokens, attr_vals, tokvecs = tokens_attrs_vectors
|
||||||
|
|
||||||
|
orig_tokvecs_shape = tokvecs.shape
|
||||||
|
tokvecs = tokvecs.reshape((tokvecs.shape[0], tokvecs.shape[1] *
|
||||||
|
tokvecs.shape[2]))
|
||||||
|
|
||||||
|
vector = tokvecs
|
||||||
|
|
||||||
|
def backward(d_vector, sgd=None):
|
||||||
|
d_tokvecs = vector.reshape(orig_tokvecs_shape)
|
||||||
|
return (tokens, d_tokvecs)
|
||||||
|
return vector, backward
|
||||||
|
model = layerize(forward)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
def build_parser_state2vec(width, nr_vector=1000, nF=1, nB=0, nS=1, nL=2, nR=2):
|
def build_parser_state2vec(width, nr_vector=1000, nF=1, nB=0, nS=1, nL=2, nR=2):
|
||||||
embed_tags = _reshape(chain(get_col(0), HashEmbed(16, nr_vector)))
|
embed_tags = _reshape(chain(get_col(0), HashEmbed(16, nr_vector)))
|
||||||
embed_deps = _reshape(chain(get_col(1), HashEmbed(16, nr_vector)))
|
embed_deps = _reshape(chain(get_col(1), HashEmbed(16, nr_vector)))
|
||||||
|
|
|
@ -28,6 +28,8 @@ from murmurhash.mrmr cimport hash64
|
||||||
from preshed.maps cimport MapStruct
|
from preshed.maps cimport MapStruct
|
||||||
from preshed.maps cimport map_get
|
from preshed.maps cimport map_get
|
||||||
|
|
||||||
|
from numpy import exp
|
||||||
|
|
||||||
from . import _parse_features
|
from . import _parse_features
|
||||||
from ._parse_features cimport CONTEXT_SIZE
|
from ._parse_features cimport CONTEXT_SIZE
|
||||||
from ._parse_features cimport fill_context
|
from ._parse_features cimport fill_context
|
||||||
|
@ -43,6 +45,7 @@ from ..gold cimport GoldParse
|
||||||
from ..attrs cimport TAG, DEP
|
from ..attrs cimport TAG, DEP
|
||||||
|
|
||||||
from .._ml import build_parser_state2vec, build_model
|
from .._ml import build_parser_state2vec, build_model
|
||||||
|
from .._ml import build_debug_state2vec, build_debug_model
|
||||||
|
|
||||||
|
|
||||||
USE_FTRL = True
|
USE_FTRL = True
|
||||||
|
@ -111,8 +114,8 @@ cdef class Parser:
|
||||||
return (Parser, (self.vocab, self.moves, self.model), None, None)
|
return (Parser, (self.vocab, self.moves, self.model), None, None)
|
||||||
|
|
||||||
def build_model(self, width=8, nr_vector=1000, nF=1, nB=1, nS=1, nL=1, nR=1, **_):
|
def build_model(self, width=8, nr_vector=1000, nF=1, nB=1, nS=1, nL=1, nR=1, **_):
|
||||||
state2vec = build_parser_state2vec(width, nr_vector, nF, nB, nL, nR)
|
state2vec = build_debug_state2vec(width, nr_vector, nF, nB, nL, nR)
|
||||||
model = build_model(state2vec, width, 2, self.moves.n_moves)
|
model = build_debug_model(state2vec, width, 2, self.moves.n_moves)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def __call__(self, Doc tokens):
|
def __call__(self, Doc tokens):
|
||||||
|
@ -166,32 +169,22 @@ cdef class Parser:
|
||||||
cdef Doc doc
|
cdef Doc doc
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
cdef int guess
|
cdef int guess
|
||||||
is_valid = self.model.ops.allocate((len(docs), nr_class), dtype='i')
|
|
||||||
tokvecs = [d.tensor for d in docs]
|
tokvecs = [d.tensor for d in docs]
|
||||||
attr_names = self.model.ops.allocate((2,), dtype='i')
|
|
||||||
attr_names[0] = TAG
|
|
||||||
attr_names[1] = DEP
|
|
||||||
all_states = list(states)
|
all_states = list(states)
|
||||||
todo = zip(states, tokvecs)
|
todo = zip(states, tokvecs)
|
||||||
while todo:
|
while todo:
|
||||||
states, tokvecs = zip(*todo)
|
states, tokvecs = zip(*todo)
|
||||||
features = self._get_features(states, tokvecs, attr_names)
|
scores, _ = self._begin_update(states, tokvecs)
|
||||||
scores = self.model.predict(features)
|
|
||||||
self._validate_batch(is_valid, states)
|
|
||||||
scores *= is_valid
|
|
||||||
for state, guess in zip(states, scores.argmax(axis=1)):
|
for state, guess in zip(states, scores.argmax(axis=1)):
|
||||||
action = self.moves.c[guess]
|
action = self.moves.c[guess]
|
||||||
action.do(state.c, action.label)
|
action.do(state.c, action.label)
|
||||||
todo = filter(lambda sp: not sp[0].is_final(), todo)
|
todo = filter(lambda sp: not sp[0].py_is_final(), todo)
|
||||||
for state, doc in zip(all_states, docs):
|
for state, doc in zip(all_states, docs):
|
||||||
self.moves.finalize_state(state.c)
|
self.moves.finalize_state(state.c)
|
||||||
for i in range(doc.length):
|
for i in range(doc.length):
|
||||||
doc.c[i] = state.c._sent[i]
|
doc.c[i] = state.c._sent[i]
|
||||||
|
|
||||||
|
def begin_training(self, docs, golds):
|
||||||
def update(self, docs, golds, drop=0., sgd=None):
|
|
||||||
if isinstance(docs, Doc) and isinstance(golds, GoldParse):
|
|
||||||
return self.update([docs], [golds], drop=drop)
|
|
||||||
for gold in golds:
|
for gold in golds:
|
||||||
self.moves.preprocess_gold(gold)
|
self.moves.preprocess_gold(gold)
|
||||||
states = self._init_states(docs)
|
states = self._init_states(docs)
|
||||||
|
@ -204,39 +197,60 @@ cdef class Parser:
|
||||||
attr_names = self.model.ops.allocate((2,), dtype='i')
|
attr_names = self.model.ops.allocate((2,), dtype='i')
|
||||||
attr_names[0] = TAG
|
attr_names[0] = TAG
|
||||||
attr_names[1] = DEP
|
attr_names[1] = DEP
|
||||||
|
|
||||||
|
features = self._get_features(states, tokvecs, attr_names)
|
||||||
|
self.model.begin_training(features)
|
||||||
|
|
||||||
|
|
||||||
|
def update(self, docs, golds, drop=0., sgd=None):
|
||||||
|
if isinstance(docs, Doc) and isinstance(golds, GoldParse):
|
||||||
|
return self.update([docs], [golds], drop=drop)
|
||||||
|
for gold in golds:
|
||||||
|
self.moves.preprocess_gold(gold)
|
||||||
|
states = self._init_states(docs)
|
||||||
|
tokvecs = [d.tensor for d in docs]
|
||||||
|
d_tokens = [self.model.ops.allocate(d.tensor.shape) for d in docs]
|
||||||
|
nr_class = self.moves.n_moves
|
||||||
output = list(d_tokens)
|
output = list(d_tokens)
|
||||||
todo = zip(states, tokvecs, golds, d_tokens)
|
todo = zip(states, tokvecs, golds, d_tokens)
|
||||||
assert len(states) == len(todo)
|
assert len(states) == len(todo)
|
||||||
loss = 0.
|
loss = 0.
|
||||||
while todo:
|
while todo:
|
||||||
states, tokvecs, golds, d_tokens = zip(*todo)
|
states, tokvecs, golds, d_tokens = zip(*todo)
|
||||||
features = self._get_features(states, tokvecs, attr_names)
|
scores, finish_update = self._begin_update(states, tokvecs)
|
||||||
|
token_ids, batch_token_grads = finish_update(golds, sgd=sgd)
|
||||||
scores, finish_update = self.model.begin_update(features, drop=drop)
|
|
||||||
assert scores.shape == (len(states), self.moves.n_moves), (len(states), scores.shape)
|
|
||||||
|
|
||||||
self._cost_batch(costs, is_valid, states, golds)
|
|
||||||
scores *= is_valid
|
|
||||||
self._set_gradient(gradients, scores, costs)
|
|
||||||
loss += numpy.abs(gradients).sum() / gradients.shape[0]
|
|
||||||
|
|
||||||
token_ids, batch_token_grads = finish_update(gradients, sgd=sgd)
|
|
||||||
for i, tok_i in enumerate(token_ids):
|
for i, tok_i in enumerate(token_ids):
|
||||||
d_tokens[i][tok_i] += batch_token_grads[i]
|
d_tokens[i][tok_i] += batch_token_grads[i]
|
||||||
|
|
||||||
self._transition_batch(states, scores)
|
self._transition_batch(states, scores)
|
||||||
|
|
||||||
# Get unfinished states (and their matching gold and token gradients)
|
# Get unfinished states (and their matching gold and token gradients)
|
||||||
todo = filter(lambda sp: not sp[0].is_final(), todo)
|
todo = filter(lambda sp: not sp[0].py_is_final(), todo)
|
||||||
costs = costs[:len(todo)]
|
|
||||||
is_valid = is_valid[:len(todo)]
|
|
||||||
gradients = gradients[:len(todo)]
|
|
||||||
|
|
||||||
gradients.fill(0)
|
|
||||||
costs.fill(0)
|
|
||||||
is_valid.fill(1)
|
|
||||||
return output, loss
|
return output, loss
|
||||||
|
|
||||||
|
def _begin_update(self, states, tokvecs, drop=0.):
|
||||||
|
nr_class = self.moves.n_moves
|
||||||
|
attr_names = self.model.ops.allocate((2,), dtype='i')
|
||||||
|
attr_names[0] = TAG
|
||||||
|
attr_names[1] = DEP
|
||||||
|
|
||||||
|
features = self._get_features(states, tokvecs, attr_names)
|
||||||
|
scores, finish_update = self.model.begin_update(features, drop=drop)
|
||||||
|
is_valid = self.model.ops.allocate((len(states), nr_class), dtype='i')
|
||||||
|
self._validate_batch(is_valid, states)
|
||||||
|
softmaxed = self.model.ops.softmax(scores)
|
||||||
|
softmaxed *= is_valid
|
||||||
|
softmaxed /= softmaxed.sum(axis=1)
|
||||||
|
print('Scores', softmaxed[0])
|
||||||
|
def backward(golds, sgd=None):
|
||||||
|
costs = self.model.ops.allocate((len(states), nr_class), dtype='f')
|
||||||
|
d_scores = self.model.ops.allocate((len(states), nr_class), dtype='f')
|
||||||
|
|
||||||
|
self._cost_batch(costs, is_valid, states, golds)
|
||||||
|
self._set_gradient(d_scores, scores, is_valid, costs)
|
||||||
|
return finish_update(d_scores, sgd=sgd)
|
||||||
|
return softmaxed, backward
|
||||||
|
|
||||||
def _init_states(self, docs):
|
def _init_states(self, docs):
|
||||||
states = []
|
states = []
|
||||||
cdef Doc doc
|
cdef Doc doc
|
||||||
|
@ -281,20 +295,20 @@ cdef class Parser:
|
||||||
action = self.moves.c[guess]
|
action = self.moves.c[guess]
|
||||||
action.do(state.c, action.label)
|
action.do(state.c, action.label)
|
||||||
|
|
||||||
def _set_gradient(self, gradients, scores, costs):
|
def _set_gradient(self, gradients, scores, is_valid, costs):
|
||||||
"""Do multi-label log loss"""
|
"""Do multi-label log loss"""
|
||||||
cdef double Z, gZ, max_, g_max
|
cdef double Z, gZ, max_, g_max
|
||||||
g_scores = scores * (costs <= 0)
|
scores = scores * is_valid
|
||||||
maxes = scores.max(axis=1).reshape((scores.shape[0], 1))
|
g_scores = scores * is_valid * (costs <= 0.)
|
||||||
g_maxes = g_scores.max(axis=1).reshape((g_scores.shape[0], 1))
|
exps = numpy.exp(scores - scores.max(axis=1))
|
||||||
exps = numpy.exp((scores-maxes))
|
exps *= is_valid
|
||||||
g_exps = numpy.exp(g_scores-g_maxes)
|
g_exps = numpy.exp(g_scores - g_scores.max(axis=1))
|
||||||
|
g_exps *= costs <= 0.
|
||||||
Zs = exps.sum(axis=1).reshape((exps.shape[0], 1))
|
g_exps *= is_valid
|
||||||
gZs = g_exps.sum(axis=1).reshape((g_exps.shape[0], 1))
|
gradients[:] = exps / exps.sum(axis=1)
|
||||||
logprob = exps / Zs
|
gradients -= g_exps / g_exps.sum(axis=1)
|
||||||
g_logprob = g_exps / gZs
|
print('Gradient', gradients[0])
|
||||||
gradients[:] = logprob - g_logprob
|
print('Costs', costs[0])
|
||||||
|
|
||||||
def step_through(self, Doc doc, GoldParse gold=None):
|
def step_through(self, Doc doc, GoldParse gold=None):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -34,7 +34,7 @@ cdef class StateClass:
|
||||||
def token_vector_lenth(self):
|
def token_vector_lenth(self):
|
||||||
return self.doc.tensor.shape[1]
|
return self.doc.tensor.shape[1]
|
||||||
|
|
||||||
def is_final(self):
|
def py_is_final(self):
|
||||||
return self.c.is_final()
|
return self.c.is_final()
|
||||||
|
|
||||||
def print_state(self, words):
|
def print_state(self, words):
|
||||||
|
@ -47,31 +47,38 @@ cdef class StateClass:
|
||||||
return ' '.join((third, second, top, '|', n0, n1))
|
return ' '.join((third, second, top, '|', n0, n1))
|
||||||
|
|
||||||
def nr_context_tokens(self, int nF, int nB, int nS, int nL, int nR):
|
def nr_context_tokens(self, int nF, int nB, int nS, int nL, int nR):
|
||||||
return 1+nF+nB+nS + nL + (nS * nL) + (nS * nR)
|
return 3
|
||||||
|
#return 1+nF+nB+nS + nL + (nS * nL) + (nS * nR)
|
||||||
|
|
||||||
def set_context_tokens(self, int[:] output, nF=1, nB=0, nS=2,
|
def set_context_tokens(self, int[:] output, nF=1, nB=0, nS=2,
|
||||||
nL=2, nR=2):
|
nL=2, nR=2):
|
||||||
output[0] = self.B(0)
|
output[0] = self.B(0)
|
||||||
output[1] = self.S(0)
|
output[1] = self.S(0)
|
||||||
output[2] = self.S(1)
|
output[2] = self.S(1)
|
||||||
output[3] = self.L(self.S(0), 1)
|
#output[3] = self.L(self.S(0), 1)
|
||||||
output[4] = self.L(self.S(0), 2)
|
#output[4] = self.L(self.S(0), 2)
|
||||||
output[5] = self.R(self.S(0), 1)
|
#output[5] = self.R(self.S(0), 1)
|
||||||
output[6] = self.R(self.S(0), 2)
|
#output[6] = self.R(self.S(0), 2)
|
||||||
output[7] = self.L(self.S(1), 1)
|
#output[7] = self.L(self.S(1), 1)
|
||||||
output[8] = self.L(self.S(1), 2)
|
#output[8] = self.L(self.S(1), 2)
|
||||||
output[9] = self.R(self.S(1), 1)
|
#output[9] = self.R(self.S(1), 1)
|
||||||
output[10] = self.R(self.S(1), 2)
|
#output[10] = self.R(self.S(1), 2)
|
||||||
|
|
||||||
def set_attributes(self, uint64_t[:, :] vals, int[:] tokens, int[:] names):
|
def set_attributes(self, uint64_t[:, :] vals, int[:] tokens, int[:] names):
|
||||||
cdef int i, j, tok_i
|
cdef int i, j, tok_i
|
||||||
for i in range(tokens.shape[0]):
|
for i in range(tokens.shape[0]):
|
||||||
tok_i = tokens[i]
|
tok_i = tokens[i]
|
||||||
token = &self.c._sent[tok_i]
|
if tok_i >= 0:
|
||||||
for j in range(names.shape[0]):
|
token = &self.c._sent[tok_i]
|
||||||
vals[i, j] = Token.get_struct_attr(token, <attr_id_t>names[j])
|
for j in range(names.shape[0]):
|
||||||
|
vals[i, j] = Token.get_struct_attr(token, <attr_id_t>names[j])
|
||||||
|
else:
|
||||||
|
vals[i] = 0
|
||||||
|
|
||||||
def set_token_vectors(self, float[:, :] tokvecs,
|
def set_token_vectors(self, float[:, :] tokvecs,
|
||||||
float[:, :] all_tokvecs, int[:] indices):
|
float[:, :] all_tokvecs, int[:] indices):
|
||||||
for i in range(indices.shape[0]):
|
for i in range(indices.shape[0]):
|
||||||
tokvecs[i] = all_tokvecs[indices[i]]
|
if indices[i] >= 0:
|
||||||
|
tokvecs[i] = all_tokvecs[indices[i]]
|
||||||
|
else:
|
||||||
|
tokvecs[i] = 0
|
||||||
|
|
Loading…
Reference in New Issue
Block a user