Restore patches from nn-beam-parser to spacy/syntax

This commit is contained in:
Matthew Honnibal 2017-08-18 22:23:03 +02:00
parent ec482580b5
commit 5f81d700ff
7 changed files with 151 additions and 214 deletions

View File

@ -6,7 +6,6 @@ from cpython.ref cimport PyObject, Py_INCREF, Py_XDECREF
from thinc.extra.search cimport Beam
from thinc.extra.search import MaxViolation
from thinc.typedefs cimport hash_t, class_t
from thinc.extra.search cimport MaxViolation
from .transition_system cimport TransitionSystem, Transition
from .stateclass cimport StateClass
@ -46,10 +45,9 @@ cdef class ParserBeam(object):
cdef public object states
cdef public object golds
cdef public object beams
cdef public object dones
def __init__(self, TransitionSystem moves, states, golds,
int width, float density):
int width=4, float density=0.001):
self.moves = moves
self.states = states
self.golds = golds
@ -63,7 +61,6 @@ cdef class ParserBeam(object):
st = <StateClass>beam.at(i)
st.c.offset = state.c.offset
self.beams.append(beam)
self.dones = [False] * len(self.beams)
def __dealloc__(self):
if self.beams is not None:
@ -73,7 +70,7 @@ cdef class ParserBeam(object):
@property
def is_done(self):
return all(b.is_done or self.dones[i] for i, b in enumerate(self.beams))
return all(b.is_done for b in self.beams)
def __getitem__(self, i):
return self.beams[i]
@ -84,42 +81,32 @@ cdef class ParserBeam(object):
def advance(self, scores, follow_gold=False):
cdef Beam beam
for i, beam in enumerate(self.beams):
if beam.is_done or not scores[i].size or self.dones[i]:
if beam.is_done or not scores[i].size:
continue
self._set_scores(beam, scores[i])
if self.golds is not None:
self._set_costs(beam, self.golds[i], follow_gold=follow_gold)
if follow_gold:
assert self.golds is not None
beam.advance(_transition_state, NULL, <void*>self.moves.c)
else:
beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
beam.check_done(_check_final_state, NULL)
if beam.is_done and self.golds is not None:
if beam.is_done:
for j in range(beam.size):
state = <StateClass>beam.at(j)
if state.is_final():
try:
if self.moves.is_gold_parse(state, self.golds[i]):
beam._states[j].loss = 0.0
elif beam._states[j].loss == 0.0:
beam._states[j].loss = 1.0
except NotImplementedError:
break
if is_gold(<StateClass>beam.at(j), self.golds[i], self.moves.strings):
beam._states[j].loss = 0.0
elif beam._states[j].loss == 0.0:
beam._states[j].loss = 1.0
def _set_scores(self, Beam beam, float[:, ::1] scores):
cdef float* c_scores = &scores[0, 0]
cdef int nr_state = min(scores.shape[0], beam.size)
cdef int nr_class = scores.shape[1]
for i in range(nr_state):
for i in range(beam.size):
state = <StateClass>beam.at(i)
if not state.is_final():
for j in range(nr_class):
beam.scores[i][j] = c_scores[i * nr_class + j]
self.moves.set_valid(beam.is_valid[i], state.c)
else:
for j in range(beam.nr_class):
beam.scores[i][j] = 0
beam.costs[i][j] = 0
beam.scores[i][j] = c_scores[i * beam.nr_class + j]
self.moves.set_valid(beam.is_valid[i], state.c)
def _set_costs(self, Beam beam, GoldParse gold, int follow_gold=False):
for i in range(beam.size):
@ -132,6 +119,21 @@ cdef class ParserBeam(object):
beam.is_valid[i][j] = 0
def is_gold(StateClass state, GoldParse gold, strings):
predicted = set()
truth = set()
for i in range(gold.length):
if gold.cand_to_gold[i] is None:
continue
if state.safe_get(i).dep:
predicted.add((i, state.H(i), strings[state.safe_get(i).dep]))
else:
predicted.add((i, state.H(i), 'ROOT'))
id_, word, tag, head, dep, ner = gold.orig_annot[gold.cand_to_gold[i]]
truth.add((id_, head, dep))
return truth == predicted
def get_token_ids(states, int n_tokens):
cdef StateClass state
cdef np.ndarray ids = numpy.zeros((len(states), n_tokens),
@ -148,11 +150,9 @@ def get_token_ids(states, int n_tokens):
nr_update = 0
def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
states, tokvecs, golds,
state2vec, vec2scores,
int width, float density,
sgd=None, losses=None, drop=0.):
state2vec, vec2scores, drop=0., sgd=None,
losses=None, int width=4, float density=0.001):
global nr_update
cdef MaxViolation violn
nr_update += 1
pbeam = ParserBeam(moves, states, golds,
width=width, density=density)
@ -163,8 +163,6 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
backprops = []
violns = [MaxViolation() for _ in range(len(states))]
for t in range(max_steps):
if pbeam.is_done and gbeam.is_done:
break
# The beam maps let us find the right row in the flattened scores
# arrays for each state. States are identified by (example id, history).
# We keep a different beam map for each step (since we'll have a flat
@ -196,17 +194,14 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
# Track the "maximum violation", to use in the update.
for i, violn in enumerate(violns):
violn.check_crf(pbeam[i], gbeam[i])
histories = []
losses = []
for violn in violns:
if violn.p_hist:
histories.append(violn.p_hist + violn.g_hist)
losses.append(violn.p_probs + violn.g_probs)
else:
histories.append([])
losses.append([])
states_d_scores = get_gradient(moves.n_moves, beam_maps, histories, losses)
return states_d_scores, backprops[:len(states_d_scores)]
# Only make updates if we have non-gold states
histories = [((v.p_hist + v.g_hist) if v.p_hist else []) for v in violns]
losses = [((v.p_probs + v.g_probs) if v.p_probs else []) for v in violns]
states_d_scores = get_gradient(moves.n_moves, beam_maps,
histories, losses)
assert len(states_d_scores) == len(backprops), (len(states_d_scores), len(backprops))
return states_d_scores, backprops
def get_states(pbeams, gbeams, beam_map, nr_update):
@ -219,11 +214,12 @@ def get_states(pbeams, gbeams, beam_map, nr_update):
for eg_id, (pbeam, gbeam) in enumerate(zip(pbeams, gbeams)):
p_indices.append([])
g_indices.append([])
if pbeam.loss > 0 and pbeam.min_score > gbeam.score:
continue
for i in range(pbeam.size):
state = <StateClass>pbeam.at(i)
if not state.is_final():
key = tuple([eg_id] + pbeam.histories[i])
assert key not in seen, (key, seen)
seen[key] = len(states)
p_indices[-1].append(len(states))
states.append(state)
@ -259,27 +255,18 @@ def get_gradient(nr_class, beam_maps, histories, losses):
"""
nr_step = len(beam_maps)
grads = []
nr_step = 0
for eg_id, hists in enumerate(histories):
for loss, hist in zip(losses[eg_id], hists):
if loss != 0.0 and not numpy.isnan(loss):
nr_step = max(nr_step, len(hist))
for i in range(nr_step):
grads.append(numpy.zeros((max(beam_maps[i].values())+1, nr_class), dtype='f'))
for beam_map in beam_maps:
if beam_map:
grads.append(numpy.zeros((max(beam_map.values())+1, nr_class), dtype='f'))
assert len(histories) == len(losses)
for eg_id, hists in enumerate(histories):
for loss, hist in zip(losses[eg_id], hists):
if loss == 0.0 or numpy.isnan(loss):
continue
key = tuple([eg_id])
# Adjust loss for length
avg_loss = loss / len(hist)
loss += avg_loss * (nr_step - len(hist))
for j, clas in enumerate(hist):
i = beam_maps[j][key]
# In step j, at state i action clas
# resulted in loss
grads[j][i, clas] += loss
grads[j][i, clas] += loss / len(histories)
key = key + tuple([clas])
return grads

View File

@ -37,6 +37,7 @@ cdef cppclass StateC:
this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint))
this._sent = <TokenC*>calloc(length + (PADDING * 2), sizeof(TokenC))
this._ents = <Entity*>calloc(length + (PADDING * 2), sizeof(Entity))
this.offset = 0
cdef int i
for i in range(length + (PADDING * 2)):
this._ents[i].end = -1
@ -73,16 +74,7 @@ cdef cppclass StateC:
free(this.shifted - PADDING)
void set_context_tokens(int* ids, int n) nogil:
if n == 8:
ids[0] = this.B(0)
ids[1] = this.B(1)
ids[2] = this.S(0)
ids[3] = this.S(1)
ids[4] = this.H(this.S(0))
ids[5] = this.L(this.B(0), 1)
ids[6] = this.L(this.S(0), 2)
ids[7] = this.R(this.S(0), 1)
elif n == 13:
if n == 13:
ids[0] = this.B(0)
ids[1] = this.B(1)
ids[2] = this.S(0)

View File

@ -351,20 +351,6 @@ cdef class ArcEager(TransitionSystem):
def __get__(self):
return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)
def is_gold_parse(self, StateClass state, GoldParse gold):
predicted = set()
truth = set()
for i in range(gold.length):
if gold.cand_to_gold[i] is None:
continue
if state.safe_get(i).dep:
predicted.add((i, state.H(i), self.strings[state.safe_get(i).dep]))
else:
predicted.add((i, state.H(i), 'ROOT'))
id_, word, tag, head, dep, ner = gold.orig_annot[gold.cand_to_gold[i]]
truth.add((id_, head, dep))
return truth == predicted
def has_gold(self, GoldParse gold, start=0, end=None):
end = end or len(gold.heads)
if all([tag is None for tag in gold.heads[start:end]]):
@ -399,6 +385,7 @@ cdef class ArcEager(TransitionSystem):
for i in range(self.n_moves):
if self.c[i].move == move and self.c[i].label == label:
return self.c[i]
return Transition(clas=0, move=MISSING, label=0)
def move_name(self, int move, attr_t label):
label_str = self.strings[label]

View File

@ -34,6 +34,7 @@ from ._parse_features cimport CONTEXT_SIZE
from ._parse_features cimport fill_context
from .stateclass cimport StateClass
from .parser cimport Parser
from ._beam_utils import is_gold
DEBUG = False
@ -107,7 +108,7 @@ cdef class BeamParser(Parser):
# The non-monotonic oracle makes it difficult to ensure final costs are
# correct. Therefore do final correction
for i in range(pred.size):
if self.moves.is_gold_parse(<StateClass>pred.at(i), gold_parse):
if is_gold(<StateClass>pred.at(i), gold_parse, self.moves.strings):
pred._states[i].loss = 0.0
elif pred._states[i].loss == 0.0:
pred._states[i].loss = 1.0
@ -213,7 +214,7 @@ def _check_train_integrity(Beam pred, Beam gold, GoldParse gold_parse, Transitio
if not pred._states[i].is_done or pred._states[i].loss == 0:
continue
state = <StateClass>pred.at(i)
if moves.is_gold_parse(state, gold_parse) == True:
if is_gold(state, gold_parse, moves.strings) == True:
for dep in gold_parse.orig_annot:
print(dep[1], dep[3], dep[4])
print("Cost", pred._states[i].loss)
@ -227,7 +228,7 @@ def _check_train_integrity(Beam pred, Beam gold, GoldParse gold_parse, Transitio
if not gold._states[i].is_done:
continue
state = <StateClass>gold.at(i)
if moves.is_gold(state, gold_parse) == False:
if is_gold(state, gold_parse, moves.strings) == False:
print("Truth")
for dep in gold_parse.orig_annot:
print(dep[1], dep[3], dep[4])
@ -237,16 +238,3 @@ def _check_train_integrity(Beam pred, Beam gold, GoldParse gold_parse, Transitio
raise Exception("Gold parse is not gold-standard")
def is_gold(StateClass state, GoldParse gold, StringStore strings):
predicted = set()
truth = set()
for i in range(gold.length):
if gold.cand_to_gold[i] is None:
continue
if state.safe_get(i).dep:
predicted.add((i, state.H(i), strings[state.safe_get(i).dep]))
else:
predicted.add((i, state.H(i), 'ROOT'))
id_, word, tag, head, dep, ner = gold.orig_annot[gold.cand_to_gold[i]]
truth.add((id_, head, dep))
return truth == predicted

View File

@ -14,8 +14,4 @@ cdef class Parser:
cdef readonly TransitionSystem moves
cdef readonly object cfg
cdef void _parse_step(self, StateC* state,
const float* feat_weights,
int nr_class, int nr_feat, int nr_piece) nogil
#cdef int parseC(self, TokenC* tokens, int length, int nr_feat) nogil

View File

@ -37,9 +37,7 @@ from preshed.maps cimport MapStruct
from preshed.maps cimport map_get
from thinc.api import layerize, chain, noop, clone
from thinc.neural import Model, Affine, ELU, ReLu, Maxout
from thinc.neural import Model, Affine, ReLu, Maxout
from thinc.neural._classes.batchnorm import BatchNorm as BN
from thinc.neural._classes.selu import SELU
from thinc.neural._classes.layernorm import LayerNorm
from thinc.neural.ops import NumpyOps, CupyOps
@ -48,10 +46,10 @@ from thinc.neural.util import get_array_module
from .. import util
from ..util import get_async, get_cuda_stream
from .._ml import zero_init, PrecomputableAffine, PrecomputableMaxouts
from .._ml import Tok2Vec, doc2feats, rebatch
from .._ml import Tok2Vec, doc2feats, rebatch, fine_tune
from .._ml import Residual, drop_layer
from ..compat import json_dumps
from . import _beam_utils
from . import _parse_features
from ._parse_features cimport CONTEXT_SIZE
from ._parse_features cimport fill_context
@ -64,8 +62,11 @@ from ..structs cimport TokenC
from ..tokens.doc cimport Doc
from ..strings cimport StringStore
from ..gold cimport GoldParse
from ..attrs cimport TAG, DEP
from ..attrs cimport ID, TAG, DEP, ORTH, NORM, PREFIX, SUFFIX, TAG
from . import _beam_utils
USE_FINE_TUNE = True
BEAM_PARSE = True
def get_templates(*args, **kwargs):
return []
@ -237,11 +238,14 @@ cdef class Parser:
Base class of the DependencyParser and EntityRecognizer.
"""
@classmethod
def Model(cls, nr_class, token_vector_width=128, hidden_width=128, depth=1, **cfg):
def Model(cls, nr_class, token_vector_width=128, hidden_width=300, depth=1, **cfg):
depth = util.env_opt('parser_hidden_depth', depth)
token_vector_width = util.env_opt('token_vector_width', token_vector_width)
hidden_width = util.env_opt('hidden_width', hidden_width)
parser_maxout_pieces = util.env_opt('parser_maxout_pieces', 2)
embed_size = util.env_opt('embed_size', 4000)
tensors = fine_tune(Tok2Vec(token_vector_width, embed_size,
preprocess=doc2feats()))
if parser_maxout_pieces == 1:
lower = PrecomputableAffine(hidden_width if depth >= 1 else nr_class,
nF=cls.nr_feature,
@ -253,15 +257,10 @@ cdef class Parser:
nI=token_vector_width)
with Model.use_device('cpu'):
if depth == 0:
upper = chain()
upper.is_noop = True
else:
upper = chain(
clone(Maxout(hidden_width), (depth-1)),
zero_init(Affine(nr_class, drop_factor=0.0))
)
upper.is_noop = False
upper = chain(
clone(Residual(ReLu(hidden_width)), (depth-1)),
zero_init(Affine(nr_class, drop_factor=0.0))
)
# TODO: This is an unfortunate hack atm!
# Used to set input dimensions in network.
lower.begin_training(lower.ops.allocate((500, token_vector_width)))
@ -273,7 +272,7 @@ cdef class Parser:
'hidden_width': hidden_width,
'maxout_pieces': parser_maxout_pieces
}
return (lower, upper), cfg
return (tensors, lower, upper), cfg
def __init__(self, Vocab vocab, moves=True, model=True, **cfg):
"""
@ -299,10 +298,6 @@ cdef class Parser:
self.moves = self.TransitionSystem(self.vocab.strings, {})
else:
self.moves = moves
if 'beam_width' not in cfg:
cfg['beam_width'] = util.env_opt('beam_width', 1)
if 'beam_density' not in cfg:
cfg['beam_density'] = util.env_opt('beam_density', 0.0)
self.cfg = cfg
if 'actions' in self.cfg:
for action, labels in self.cfg.get('actions', {}).items():
@ -325,7 +320,7 @@ cdef class Parser:
if beam_width is None:
beam_width = self.cfg.get('beam_width', 1)
if beam_density is None:
beam_density = self.cfg.get('beam_density', 0.0)
beam_density = self.cfg.get('beam_density', 0.001)
cdef Beam beam
if beam_width == 1:
states = self.parse_batch([doc], [doc.tensor])
@ -341,7 +336,7 @@ cdef class Parser:
return output
def pipe(self, docs, int batch_size=1000, int n_threads=2,
beam_width=None, beam_density=None):
beam_width=1, beam_density=0.001):
"""
Process a stream of documents.
@ -353,21 +348,21 @@ cdef class Parser:
The number of threads with which to work on the buffer in parallel.
Yields (Doc): Documents, in order.
"""
cdef StateClass parse_state
if beam_width is None:
beam_width = self.cfg.get('beam_width', 1)
if beam_density is None:
beam_density = self.cfg.get('beam_density', 0.0)
if BEAM_PARSE:
beam_width = 8
cdef Doc doc
queue = []
cdef Beam beam
for docs in cytoolz.partition_all(batch_size, docs):
docs = list(docs)
tokvecs = [d.tensor for d in docs]
tokvecs = [doc.tensor for doc in docs]
if beam_width == 1:
parse_states = self.parse_batch(docs, tokvecs)
else:
parse_states = self.beam_parse(docs, tokvecs,
beam_width=beam_width, beam_density=beam_density)
beams = self.beam_parse(docs, tokvecs,
beam_width=beam_width, beam_density=beam_density)
parse_states = []
for beam in beams:
parse_states.append(<StateClass>beam.at(0))
self.set_annotations(docs, parse_states)
yield from docs
@ -382,8 +377,12 @@ cdef class Parser:
int nr_class, nr_feat, nr_piece, nr_dim, nr_state
if isinstance(docs, Doc):
docs = [docs]
if isinstance(tokvecses, np.ndarray):
tokvecses = [tokvecses]
tokvecs = self.model[0].ops.flatten(tokvecses)
if USE_FINE_TUNE:
tokvecs += self.model[0].ops.flatten(self.model[0]((docs, tokvecses)))
nr_state = len(docs)
nr_class = self.moves.n_moves
@ -407,27 +406,20 @@ cdef class Parser:
cdef np.ndarray scores
c_token_ids = <int*>token_ids.data
c_is_valid = <int*>is_valid.data
cdef int has_hidden = not getattr(vec2scores, 'is_noop', False)
while not next_step.empty():
if not has_hidden:
for i in cython.parallel.prange(
next_step.size(), num_threads=6, nogil=True):
self._parse_step(next_step[i],
feat_weights, nr_class, nr_feat, nr_piece)
else:
for i in range(next_step.size()):
st = next_step[i]
st.set_context_tokens(&c_token_ids[i*nr_feat], nr_feat)
self.moves.set_valid(&c_is_valid[i*nr_class], st)
for i in range(next_step.size()):
st = next_step[i]
st.set_context_tokens(&c_token_ids[i*nr_feat], nr_feat)
self.moves.set_valid(&c_is_valid[i*nr_class], st)
vectors = state2vec(token_ids[:next_step.size()])
scores = vec2scores(vectors)
c_scores = <float*>scores.data
for i in range(next_step.size()):
st = next_step[i]
guess = arg_max_if_valid(
&c_scores[i*nr_class], &c_is_valid[i*nr_class], nr_class)
action = self.moves.c[guess]
action.do(st, action.label)
scores = vec2scores(vectors)
c_scores = <float*>scores.data
for i in range(next_step.size()):
st = next_step[i]
guess = arg_max_if_valid(
&c_scores[i*nr_class], &c_is_valid[i*nr_class], nr_class)
action = self.moves.c[guess]
action.do(st, action.label)
this_step, next_step = next_step, this_step
next_step.clear()
for st in this_step:
@ -435,18 +427,22 @@ cdef class Parser:
next_step.push_back(st)
return states
def beam_parse(self, docs, tokvecses, int beam_width=3, float beam_density=0.001):
def beam_parse(self, docs, tokvecses, int beam_width=8, float beam_density=0.001):
cdef Beam beam
cdef np.ndarray scores
cdef Doc doc
cdef int nr_class = self.moves.n_moves
cdef StateClass stcls, output
tokvecs = self.model[0].ops.flatten(tokvecses)
if USE_FINE_TUNE:
tokvecs += self.model[0].ops.flatten(self.model[0]((docs, tokvecses)))
cuda_stream = get_cuda_stream()
state2vec, vec2scores = self.get_batch_model(len(docs), tokvecs,
cuda_stream, 0.0)
beams = []
cdef int offset = 0
cdef int j = 0
cdef int k
for doc in docs:
beam = Beam(nr_class, beam_width, min_density=beam_density)
beam.initialize(self.moves.init_beam_state, doc.length, doc.c)
@ -459,58 +455,42 @@ cdef class Parser:
states = []
for i in range(beam.size):
stcls = <StateClass>beam.at(i)
states.append(stcls)
# This way we avoid having to score finalized states
# We do have to take care to keep indexes aligned, though
if not stcls.is_final():
states.append(stcls)
token_ids = self.get_token_ids(states)
vectors = state2vec(token_ids)
scores = vec2scores(vectors)
j = 0
c_scores = <float*>scores.data
for i in range(beam.size):
stcls = <StateClass>beam.at(i)
if not stcls.is_final():
self.moves.set_valid(beam.is_valid[i], stcls.c)
for j in range(nr_class):
beam.scores[i][j] = scores[i, j]
for k in range(nr_class):
beam.scores[i][k] = c_scores[j * scores.shape[1] + k]
j += 1
beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
beam.check_done(_check_final_state, NULL)
beams.append(beam)
return beams
cdef void _parse_step(self, StateC* state,
const float* feat_weights,
int nr_class, int nr_feat, int nr_piece) nogil:
'''This only works with no hidden layers -- fast but inaccurate'''
#for i in cython.parallel.prange(next_step.size(), num_threads=4, nogil=True):
# self._parse_step(next_step[i], feat_weights, nr_class, nr_feat)
token_ids = <int*>calloc(nr_feat, sizeof(int))
scores = <float*>calloc(nr_class * nr_piece, sizeof(float))
is_valid = <int*>calloc(nr_class, sizeof(int))
state.set_context_tokens(token_ids, nr_feat)
sum_state_features(scores,
feat_weights, token_ids, 1, nr_feat, nr_class * nr_piece)
self.moves.set_valid(is_valid, state)
guess = arg_maxout_if_valid(scores, is_valid, nr_class, nr_piece)
action = self.moves.c[guess]
action.do(state, action.label)
free(is_valid)
free(scores)
free(token_ids)
def update(self, docs_tokvecs, golds, drop=0., sgd=None, losses=None):
if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() >= 0.5:
return self.update_beam(docs_tokvecs, golds,
self.cfg['beam_width'], self.cfg['beam_density'],
drop=drop, sgd=sgd, losses=losses)
if BEAM_PARSE:
return self.update_beam(docs_tokvecs, golds, drop=drop, sgd=sgd,
losses=losses)
if losses is not None and self.name not in losses:
losses[self.name] = 0.
docs, tokvec_lists = docs_tokvecs
tokvecs = self.model[0].ops.flatten(tokvec_lists)
my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop)
tokvecs += self.model[0].ops.flatten(my_tokvecs)
if isinstance(docs, Doc) and isinstance(golds, GoldParse):
docs = [docs]
golds = [golds]
if USE_FINE_TUNE:
my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop)
my_tokvecs = self.model[0].ops.flatten(my_tokvecs)
tokvecs += my_tokvecs
cuda_stream = get_cuda_stream()
@ -537,13 +517,13 @@ cdef class Parser:
scores, bp_scores = vec2scores.begin_update(vector, drop=drop)
d_scores = self.get_batch_loss(states, golds, scores)
d_vector = bp_scores(d_scores / d_scores.shape[0], sgd=sgd)
d_vector = bp_scores(d_scores, sgd=sgd)
if drop != 0:
d_vector *= mask
if isinstance(self.model[0].ops, CupyOps) \
and not isinstance(token_ids, state2vec.ops.xp.ndarray):
# Move token_ids and d_vector to CPU, asynchronously
# Move token_ids and d_vector to GPU, asynchronously
backprops.append((
get_async(cuda_stream, token_ids),
get_async(cuda_stream, d_vector),
@ -561,24 +541,21 @@ cdef class Parser:
self._make_updates(d_tokvecs,
backprops, sgd, cuda_stream)
d_tokvecs = self.model[0].ops.unflatten(d_tokvecs, [len(d) for d in docs])
bp_my_tokvecs(d_tokvecs, sgd=sgd)
if USE_FINE_TUNE:
bp_my_tokvecs(d_tokvecs, sgd=sgd)
return d_tokvecs
def update_beam(self, docs_tokvecs, golds, width=None, density=None,
drop=0., sgd=None, losses=None):
if width is None:
width = self.cfg.get('beam_width', 2)
if density is None:
density = self.cfg.get('beam_density', 0.0)
def update_beam(self, docs_tokvecs, golds, drop=0., sgd=None, losses=None):
if losses is not None and self.name not in losses:
losses[self.name] = 0.
docs, tokvecs = docs_tokvecs
lengths = [len(d) for d in docs]
assert min(lengths) >= 1
tokvecs = self.model[0].ops.flatten(tokvecs)
my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop)
my_tokvecs = self.model[0].ops.flatten(my_tokvecs)
tokvecs += my_tokvecs
if USE_FINE_TUNE:
my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop)
my_tokvecs = self.model[0].ops.flatten(my_tokvecs)
tokvecs += my_tokvecs
states = self.moves.init_batch(docs)
for gold in golds:
@ -590,8 +567,8 @@ cdef class Parser:
states_d_scores, backprops = _beam_utils.update_beam(self.moves, self.nr_feature, 500,
states, tokvecs, golds,
state2vec, vec2scores,
width, density,
sgd=sgd, drop=drop, losses=losses)
drop, sgd, losses,
width=8)
backprop_lower = []
for i, d_scores in enumerate(states_d_scores):
if losses is not None:
@ -609,7 +586,8 @@ cdef class Parser:
d_tokvecs = self.model[0].ops.allocate(tokvecs.shape)
self._make_updates(d_tokvecs, backprop_lower, sgd, cuda_stream)
d_tokvecs = self.model[0].ops.unflatten(d_tokvecs, lengths)
bp_my_tokvecs(d_tokvecs, sgd=sgd)
if USE_FINE_TUNE:
bp_my_tokvecs(d_tokvecs, sgd=sgd)
return d_tokvecs
def _init_gold_batch(self, whole_docs, whole_golds):
@ -656,9 +634,9 @@ cdef class Parser:
for ids, d_vector, bp_vector in backprops:
d_state_features = bp_vector(d_vector, sgd=sgd)
mask = ids >= 0
d_state_features *= mask.reshape(ids.shape + (1,))
self.model[0].ops.scatter_add(d_tokvecs, ids * mask,
d_state_features)
indices = xp.nonzero(mask)
self.model[0].ops.scatter_add(d_tokvecs, ids[indices],
d_state_features[indices])
@property
def move_names(self):
@ -669,12 +647,12 @@ cdef class Parser:
return names
def get_batch_model(self, batch_size, tokvecs, stream, dropout):
lower, upper = self.model
_, lower, upper = self.model
state2vec = precompute_hiddens(batch_size, tokvecs,
lower, stream, drop=dropout)
return state2vec, upper
nr_feature = 8
nr_feature = 13
def get_token_ids(self, states):
cdef StateClass state
@ -759,10 +737,12 @@ cdef class Parser:
def to_disk(self, path, **exclude):
serializers = {
'lower_model': lambda p: p.open('wb').write(
'tok2vec_model': lambda p: p.open('wb').write(
self.model[0].to_bytes()),
'upper_model': lambda p: p.open('wb').write(
'lower_model': lambda p: p.open('wb').write(
self.model[1].to_bytes()),
'upper_model': lambda p: p.open('wb').write(
self.model[2].to_bytes()),
'vocab': lambda p: self.vocab.to_disk(p),
'moves': lambda p: self.moves.to_disk(p, strings=False),
'cfg': lambda p: p.open('w').write(json_dumps(self.cfg))
@ -783,24 +763,29 @@ cdef class Parser:
self.model, cfg = self.Model(**self.cfg)
else:
cfg = {}
with (path / 'lower_model').open('rb') as file_:
with (path / 'tok2vec_model').open('rb') as file_:
bytes_data = file_.read()
self.model[0].from_bytes(bytes_data)
with (path / 'upper_model').open('rb') as file_:
with (path / 'lower_model').open('rb') as file_:
bytes_data = file_.read()
self.model[1].from_bytes(bytes_data)
with (path / 'upper_model').open('rb') as file_:
bytes_data = file_.read()
self.model[2].from_bytes(bytes_data)
self.cfg.update(cfg)
return self
def to_bytes(self, **exclude):
serializers = OrderedDict((
('lower_model', lambda: self.model[0].to_bytes()),
('upper_model', lambda: self.model[1].to_bytes()),
('tok2vec_model', lambda: self.model[0].to_bytes()),
('lower_model', lambda: self.model[1].to_bytes()),
('upper_model', lambda: self.model[2].to_bytes()),
('vocab', lambda: self.vocab.to_bytes()),
('moves', lambda: self.moves.to_bytes(strings=False)),
('cfg', lambda: ujson.dumps(self.cfg))
))
if 'model' in exclude:
exclude['tok2vec_model'] = True
exclude['lower_model'] = True
exclude['upper_model'] = True
exclude.pop('model')
@ -811,6 +796,7 @@ cdef class Parser:
('vocab', lambda b: self.vocab.from_bytes(b)),
('moves', lambda b: self.moves.from_bytes(b, strings=False)),
('cfg', lambda b: self.cfg.update(ujson.loads(b))),
('tok2vec_model', lambda b: None),
('lower_model', lambda b: None),
('upper_model', lambda b: None)
))
@ -820,10 +806,12 @@ cdef class Parser:
self.model, cfg = self.Model(self.moves.n_moves)
else:
cfg = {}
if 'tok2vec_model' in msg:
self.model[0].from_bytes(msg['tok2vec_model'])
if 'lower_model' in msg:
self.model[0].from_bytes(msg['lower_model'])
self.model[1].from_bytes(msg['lower_model'])
if 'upper_model' in msg:
self.model[1].from_bytes(msg['upper_model'])
self.model[2].from_bytes(msg['upper_model'])
self.cfg.update(cfg)
return self

View File

@ -99,9 +99,6 @@ cdef class TransitionSystem:
def preprocess_gold(self, GoldParse gold):
raise NotImplementedError
def is_gold_parse(self, StateClass state, GoldParse gold):
raise NotImplementedError
cdef Transition lookup_transition(self, object name) except *:
raise NotImplementedError
@ -110,6 +107,8 @@ cdef class TransitionSystem:
def is_valid(self, StateClass stcls, move_name):
action = self.lookup_transition(move_name)
if action.move == 0:
return False
return action.is_valid(stcls.c, action.label)
cdef int set_valid(self, int* is_valid, const StateC* st) nogil: