mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 18:56:36 +03:00
* Refactor _advance_beam function
This commit is contained in:
parent
0786d9b3c7
commit
d1b55310a1
|
@ -1,9 +1,11 @@
|
|||
# cython: profile=True
|
||||
"""
|
||||
MALT-style dependency parser
|
||||
"""
|
||||
from __future__ import unicode_literals
|
||||
cimport cython
|
||||
from libc.stdint cimport uint32_t, uint64_t
|
||||
from libc.string cimport memset, memcpy
|
||||
import random
|
||||
import os.path
|
||||
from os import path
|
||||
|
@ -152,11 +154,11 @@ cdef class Parser:
|
|||
self._advance_beam(gold, gold_parse, True)
|
||||
violn.check(pred, gold)
|
||||
counts = {}
|
||||
if pred._states[0].loss >= 1:
|
||||
if pred.loss >= 1:
|
||||
self._count_feats(counts, tokens, violn.g_hist, 1)
|
||||
self._count_feats(counts, tokens, violn.p_hist, -1)
|
||||
self.model._model.update(counts)
|
||||
return pred._states[0].loss
|
||||
return pred.loss
|
||||
|
||||
def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold):
|
||||
cdef atom_t[CONTEXT_SIZE] context
|
||||
|
@ -167,22 +169,26 @@ cdef class Parser:
|
|||
for i in range(beam.size):
|
||||
state = <State*>beam.at(i)
|
||||
fill_context(context, state)
|
||||
scores = self.model.score(context)
|
||||
validities = self.moves.get_valid(state)
|
||||
if gold is None:
|
||||
for j in range(self.moves.n_moves):
|
||||
beam.set_cell(i, j, scores[j], validities[j], 0)
|
||||
elif not follow_gold:
|
||||
self.model.set_scores(beam.scores[i], context)
|
||||
self.moves.set_valid(beam.is_valid[i], state)
|
||||
|
||||
if follow_gold:
|
||||
for i in range(beam.size):
|
||||
state = <State*>beam.at(i)
|
||||
for j in range(self.moves.n_moves):
|
||||
move = &self.moves.c[j]
|
||||
cost = move.get_cost(move, state, gold)
|
||||
beam.set_cell(i, j, scores[j], validities[j], cost)
|
||||
else:
|
||||
beam.costs[i][j] = move.get_cost(move, state, gold)
|
||||
beam.is_valid[i][j] = beam.costs[i][j] == 0
|
||||
elif gold is not None:
|
||||
for i in range(beam.size):
|
||||
state = <State*>beam.at(i)
|
||||
for j in range(self.moves.n_moves):
|
||||
move = &self.moves.c[j]
|
||||
cost = move.get_cost(move, state, gold)
|
||||
beam.set_cell(i, j, scores[j], cost == 0, cost)
|
||||
beam.costs[i][j] = move.get_cost(move, state, gold)
|
||||
beam.advance(_transition_state, <void*>self.moves.c)
|
||||
state = <State*>beam.at(0)
|
||||
if state.sent[state.i].sent_end:
|
||||
beam.size = int(beam.size / 2)
|
||||
beam.check_done(_check_final_state, NULL)
|
||||
|
||||
def _count_feats(self, dict counts, Tokens tokens, list hist, int inc):
|
||||
|
|
Loading…
Reference in New Issue
Block a user