mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
* Ensure parser and tagger function correctly when training from missing values, indicated by -1
This commit is contained in:
parent
4ff180db74
commit
67d6e53a69
|
@ -255,19 +255,23 @@ cdef class EnPosTagger:
|
|||
tokens._tag_strings = self.tag_names
|
||||
tokens.is_tagged = True
|
||||
|
||||
def train(self, Tokens tokens, object golds):
|
||||
def train(self, Tokens tokens, object gold_tag_strs):
|
||||
cdef int i
|
||||
cdef int loss
|
||||
cdef atom_t[N_CONTEXT_FIELDS] context
|
||||
cdef const weight_t* scores
|
||||
golds = [self.tag_names.index(g) if g is not None else -1
|
||||
for g in gold_tag_strs]
|
||||
correct = 0
|
||||
for i in range(tokens.length):
|
||||
fill_context(context, i, tokens.data)
|
||||
scores = self.model.score(context)
|
||||
guess = arg_max(scores, self.model.n_classes)
|
||||
self.model.update(context, guess, golds[i], guess != golds[i])
|
||||
loss = guess != golds[i] if golds[i] != -1 else 0
|
||||
self.model.update(context, guess, golds[i], loss)
|
||||
tokens.data[i].tag = guess
|
||||
self.set_morph(i, tokens.data)
|
||||
correct += guess == golds[i]
|
||||
correct += loss == 0
|
||||
return correct
|
||||
|
||||
cdef int set_morph(self, const int i, TokenC* tokens) except -1:
|
||||
|
|
|
@ -102,8 +102,12 @@ cdef class GreedyParser:
|
|||
cdef int* labels_array = <int*>mem.alloc(tokens.length, sizeof(int))
|
||||
cdef int i
|
||||
for i in range(tokens.length):
|
||||
heads_array[i] = gold_heads[i]
|
||||
labels_array[i] = self.moves.label_ids[gold_labels[i]]
|
||||
if gold_heads[i] is None:
|
||||
heads_array[i] = -1
|
||||
labels_array[i] = -1
|
||||
else:
|
||||
heads_array[i] = gold_heads[i]
|
||||
labels_array[i] = self.moves.label_ids[gold_labels[i]]
|
||||
|
||||
py_words = [t.orth_ for t in tokens]
|
||||
py_moves = ['S', 'D', 'L', 'R', 'BS', 'BR']
|
||||
|
@ -123,6 +127,7 @@ cdef class GreedyParser:
|
|||
self.moves.transition(state, &guess)
|
||||
cdef int n_corr = 0
|
||||
for i in range(tokens.length):
|
||||
if gold_heads[i] != -1:
|
||||
n_corr += (i + state.sent[i].head) == gold_heads[i]
|
||||
if force_gold and n_corr != tokens.length:
|
||||
print py_words
|
||||
|
|
Loading…
Reference in New Issue
Block a user