Tmp: Committing work that was sitting around

Matthew Honnibal 2018-10-21 16:15:53 +02:00
parent 3b17eb7c49
commit f46b22879b
3 changed files with 699 additions and 219 deletions

View File

@@ -0,0 +1,259 @@
from thinc.typedefs cimport atom_t
from .stateclass cimport StateClass
from ._state cimport StateC
cdef int fill_context(atom_t* context, const StateC* state) nogil
# Context elements
# Ensure each token's attributes are listed: w, p, c, c6, c4. The order
# is referenced by incrementing the enum...
# Tokens are listed in left-to-right order.
#cdef size_t* SLOTS = [
# S2w, S1w,
# S0l0w, S0l2w, S0lw,
# S0w,
# S0r0w, S0r2w, S0rw,
# N0l0w, N0l2w, N0lw,
# P2w, P1w,
# N0w, N1w, N2w, N3w, 0
#]
# NB: The order of the enum is _NOT_ arbitrary!!
cpdef enum:
S2w
S2W
S2p
S2c
S2c4
S2c6
S2L
S2_prefix
S2_suffix
S2_shape
S2_ne_iob
S2_ne_type
S1w
S1W
S1p
S1c
S1c4
S1c6
S1L
S1_prefix
S1_suffix
S1_shape
S1_ne_iob
S1_ne_type
S1rw
S1rW
S1rp
S1rc
S1rc4
S1rc6
S1rL
S1r_prefix
S1r_suffix
S1r_shape
S1r_ne_iob
S1r_ne_type
S0lw
S0lW
S0lp
S0lc
S0lc4
S0lc6
S0lL
S0l_prefix
S0l_suffix
S0l_shape
S0l_ne_iob
S0l_ne_type
S0l2w
S0l2W
S0l2p
S0l2c
S0l2c4
S0l2c6
S0l2L
S0l2_prefix
S0l2_suffix
S0l2_shape
S0l2_ne_iob
S0l2_ne_type
S0w
S0W
S0p
S0c
S0c4
S0c6
S0L
S0_prefix
S0_suffix
S0_shape
S0_ne_iob
S0_ne_type
S0r2w
S0r2W
S0r2p
S0r2c
S0r2c4
S0r2c6
S0r2L
S0r2_prefix
S0r2_suffix
S0r2_shape
S0r2_ne_iob
S0r2_ne_type
S0rw
S0rW
S0rp
S0rc
S0rc4
S0rc6
S0rL
S0r_prefix
S0r_suffix
S0r_shape
S0r_ne_iob
S0r_ne_type
N0l2w
N0l2W
N0l2p
N0l2c
N0l2c4
N0l2c6
N0l2L
N0l2_prefix
N0l2_suffix
N0l2_shape
N0l2_ne_iob
N0l2_ne_type
N0lw
N0lW
N0lp
N0lc
N0lc4
N0lc6
N0lL
N0l_prefix
N0l_suffix
N0l_shape
N0l_ne_iob
N0l_ne_type
N0w
N0W
N0p
N0c
N0c4
N0c6
N0L
N0_prefix
N0_suffix
N0_shape
N0_ne_iob
N0_ne_type
N1w
N1W
N1p
N1c
N1c4
N1c6
N1L
N1_prefix
N1_suffix
N1_shape
N1_ne_iob
N1_ne_type
N2w
N2W
N2p
N2c
N2c4
N2c6
N2L
N2_prefix
N2_suffix
N2_shape
N2_ne_iob
N2_ne_type
P1w
P1W
P1p
P1c
P1c4
P1c6
P1L
P1_prefix
P1_suffix
P1_shape
P1_ne_iob
P1_ne_type
P2w
P2W
P2p
P2c
P2c4
P2c6
P2L
P2_prefix
P2_suffix
P2_shape
P2_ne_iob
P2_ne_type
E0w
E0W
E0p
E0c
E0c4
E0c6
E0L
E0_prefix
E0_suffix
E0_shape
E0_ne_iob
E0_ne_type
E1w
E1W
E1p
E1c
E1c4
E1c6
E1L
E1_prefix
E1_suffix
E1_shape
E1_ne_iob
E1_ne_type
# Misc features at the end
dist
N0lv
S0lv
S0rv
S1lv
S1rv
S0_has_head
S1_has_head
S2_has_head
CONTEXT_SIZE
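# NB: Because a cpdef enum numbers its members from 0, the trailing
# CONTEXT_SIZE member equals the total number of atoms above, so the
# context array can be declared as atom_t[CONTEXT_SIZE]. The trick in
# miniature (illustrative, not part of this commit):
#
#   cpdef enum:
#       FIRST      # == 0
#       SECOND     # == 1
#       N_FIELDS   # == 2, usable directly as an array size
#
#   cdef atom_t buf[N_FIELDS]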

View File

@@ -0,0 +1,419 @@
"""
Fill an array, context, with every _atomic_ value our features reference.
We then write the _actual features_ as tuples of the atoms. The machinery
that translates from the tuples to feature-extractors (which pick the values
out of "context") is in features/extractor.pyx
The atomic feature names are listed in a big enum, so that the feature tuples
can refer to them.
"""
# coding: utf-8
from __future__ import unicode_literals
from libc.string cimport memset
from itertools import combinations
from cymem.cymem cimport Pool
from ..structs cimport TokenC
from .stateclass cimport StateClass
from ._state cimport StateC
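# Each token contributes 12 consecutive atoms to the context, mirroring the
# enum blocks in the .pxd: orth (w), lemma (W), tag (p), Brown cluster (c),
# its first 4 and 6 bits (c4, c6), dep label (L), prefix, suffix, shape,
# NER IOB and NER type. fill_token writes exactly these 12 slots, so
# &context[S0w] addresses a token's whole block at once.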
cdef inline void fill_token(atom_t* context, const TokenC* token) nogil:
if token is NULL:
context[0] = 0
context[1] = 0
context[2] = 0
context[3] = 0
context[4] = 0
context[5] = 0
context[6] = 0
context[7] = 0
context[8] = 0
context[9] = 0
context[10] = 0
context[11] = 0
else:
context[0] = token.lex.orth
context[1] = token.lemma
context[2] = token.tag
context[3] = token.lex.cluster
# The cluster bit-string is stored little-endian, so masking with (2**n)-1
# keeps its first n bits. Worked example:
#   s = "1110010101"                        # cluster path, as written
#   stored = int(''.join(reversed(s)), 2)   # == 679
#   first_4_bits = stored & 15              # 15 == 0b1111, gives 7
#   "{0:04b}".format(first_4_bits)[::-1]    # == "1110", s's first 4 bits
# The masks are numbers where all bits are set (15 == 0b1111,
# 63 == 0b111111), so the bitwise AND keeps the low 4 or 6 bits, i.e. the
# first 4 or 6 characters of the cluster string.
context[4] = token.lex.cluster & 15
context[5] = token.lex.cluster & 63
context[6] = token.dep if token.head != 0 else 0
context[7] = token.lex.prefix
context[8] = token.lex.suffix
context[9] = token.lex.shape
context[10] = token.ent_iob
context[11] = token.ent_type
cdef int fill_context(atom_t* ctxt, const StateC* st) nogil:
# Take care to fill every element of context!
# We could memset, but this makes it very easy to have broken features that
# make almost no impact on accuracy. If instead they're unset, the impact
# tends to be dramatic, so we get an obvious regression to fix...
fill_token(&ctxt[S2w], st.S_(2))
fill_token(&ctxt[S1w], st.S_(1))
fill_token(&ctxt[S1rw], st.R_(st.S(1), 1))
fill_token(&ctxt[S0lw], st.L_(st.S(0), 1))
fill_token(&ctxt[S0l2w], st.L_(st.S(0), 2))
fill_token(&ctxt[S0w], st.S_(0))
fill_token(&ctxt[S0r2w], st.R_(st.S(0), 2))
fill_token(&ctxt[S0rw], st.R_(st.S(0), 1))
fill_token(&ctxt[N0lw], st.L_(st.B(0), 1))
fill_token(&ctxt[N0l2w], st.L_(st.B(0), 2))
fill_token(&ctxt[N0w], st.B_(0))
fill_token(&ctxt[N1w], st.B_(1))
fill_token(&ctxt[N2w], st.B_(2))
fill_token(&ctxt[P1w], st.safe_get(st.B(0)-1))
fill_token(&ctxt[P2w], st.safe_get(st.B(0)-2))
fill_token(&ctxt[E0w], st.E_(0))
fill_token(&ctxt[E1w], st.E_(1))
if st.stack_depth() >= 1 and not st.eol():
ctxt[dist] = min_(st.B(0) - st.E(0), 5)
else:
ctxt[dist] = 0
ctxt[N0lv] = min_(st.n_L(st.B(0)), 5)
ctxt[S0lv] = min_(st.n_L(st.S(0)), 5)
ctxt[S0rv] = min_(st.n_R(st.S(0)), 5)
ctxt[S1lv] = min_(st.n_L(st.S(1)), 5)
ctxt[S1rv] = min_(st.n_R(st.S(1)), 5)
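    # The has_head values are stored shifted by +1, reserving 0 for "no
    # such stack slot", so templates can distinguish empty slot / no head /
    # has head.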
ctxt[S0_has_head] = 0
ctxt[S1_has_head] = 0
ctxt[S2_has_head] = 0
if st.stack_depth() >= 1:
ctxt[S0_has_head] = st.has_head(st.S(0)) + 1
if st.stack_depth() >= 2:
ctxt[S1_has_head] = st.has_head(st.S(1)) + 1
if st.stack_depth() >= 3:
ctxt[S2_has_head] = st.has_head(st.S(2)) + 1
cdef inline int min_(int a, int b) nogil:
    return a if a < b else b
ner = (
(N0W,),
(P1W,),
(N1W,),
(P2W,),
(N2W,),
(P1W, N0W,),
(N0W, N1W),
(N0_prefix,),
(N0_suffix,),
(P1_shape,),
(N0_shape,),
(N1_shape,),
(P1_shape, N0_shape,),
(N0_shape, P1_shape,),
(P1_shape, N0_shape, N1_shape),
(N2_shape,),
(P2_shape,),
#(P2_norm, P1_norm, W_norm),
#(P1_norm, W_norm, N1_norm),
#(W_norm, N1_norm, N2_norm)
(P2p,),
(P1p,),
(N0p,),
(N1p,),
(N2p,),
(P1p, N0p),
(N0p, N1p),
(P2p, P1p, N0p),
(P1p, N0p, N1p),
(N0p, N1p, N2p),
(P2c,),
(P1c,),
(N0c,),
(N1c,),
(N2c,),
(P1c, N0c),
(N0c, N1c),
(E0W,),
(E0c,),
(E0p,),
(E0W, N0W),
(E0c, N0W),
(E0p, N0W),
(E0p, P1p, N0p),
(E0c, P1c, N0c),
(E0w, P1c),
(E0p, P1p),
(E0c, P1c),
(E0p, E1p),
(E0c, P1p),
(E1W,),
(E1c,),
(E1p,),
(E0W, E1W),
(E0W, E1p,),
(E0p, E1W,),
(E0p, E1W),
(P1_ne_iob,),
(P1_ne_iob, P1_ne_type),
(N0w, P1_ne_iob, P1_ne_type),
(N0_shape,),
(N1_shape,),
(N2_shape,),
(P1_shape,),
(P2_shape,),
(N0_prefix,),
(N0_suffix,),
(P1_ne_iob,),
(P2_ne_iob,),
(P1_ne_iob, P2_ne_iob),
(P1_ne_iob, P1_ne_type),
(P2_ne_iob, P2_ne_type),
(N0w, P1_ne_iob, P1_ne_type),
(N0w, N1w),
)
unigrams = (
(S2W, S2p),
(S2c6, S2p),
(S1W, S1p),
(S1c6, S1p),
(S0W, S0p),
(S0c6, S0p),
(N0W, N0p),
(N0p,),
(N0c,),
(N0c6, N0p),
(N0L,),
(N1W, N1p),
(N1c6, N1p),
(N2W, N2p),
(N2c6, N2p),
(S0r2W, S0r2p),
(S0r2c6, S0r2p),
(S0r2L,),
(S0rW, S0rp),
(S0rc6, S0rp),
(S0rL,),
(S0l2W, S0l2p),
(S0l2c6, S0l2p),
(S0l2L,),
(S0lW, S0lp),
(S0lc6, S0lp),
(S0lL,),
(N0l2W, N0l2p),
(N0l2c6, N0l2p),
(N0l2L,),
(N0lW, N0lp),
(N0lc6, N0lp),
(N0lL,),
)
s0_n0 = (
(S0W, S0p, N0W, N0p),
(S0c, S0p, N0c, N0p),
(S0c6, S0p, N0c6, N0p),
(S0c4, S0p, N0c4, N0p),
(S0p, N0p),
(S0W, N0p),
(S0p, N0W),
(S0W, N0c),
(S0c, N0W),
(S0p, N0c),
(S0c, N0p),
(S0W, S0rp, N0p),
(S0p, S0rp, N0p),
(S0p, N0lp, N0W),
(S0p, N0lp, N0p),
(S0L, N0p),
(S0p, S0rL, N0p),
(S0p, N0lL, N0p),
(S0p, S0rv, N0p),
(S0p, N0lv, N0p),
(S0c6, S0rL, S0r2L, N0p),
(S0p, N0lL, N0l2L, N0p),
)
s1_s0 = (
(S1p, S0p),
(S1p, S0p, S0_has_head),
(S1W, S0p),
(S1W, S0p, S0_has_head),
(S1c, S0p),
(S1c, S0p, S0_has_head),
(S1p, S1rL, S0p),
(S1p, S1rL, S0p, S0_has_head),
(S1p, S0lL, S0p),
(S1p, S0lL, S0p, S0_has_head),
(S1p, S0lL, S0l2L, S0p),
(S1p, S0lL, S0l2L, S0p, S0_has_head),
(S1L, S0L, S0W),
(S1L, S0L, S0p),
(S1p, S1L, S0L, S0p),
(S1p, S0p),
)
s1_n0 = (
(S1p, N0p),
(S1c, N0c),
(S1c, N0p),
(S1p, N0c),
(S1W, S1p, N0p),
(S1p, N0W, N0p),
(S1c6, S1p, N0c6, N0p),
(S1L, N0p),
(S1p, S1rL, N0p),
(S1p, S1rp, N0p),
)
s0_n1 = (
(S0p, N1p),
(S0c, N1c),
(S0c, N1p),
(S0p, N1c),
(S0W, S0p, N1p),
(S0p, N1W, N1p),
(S0c6, S0p, N1c6, N1p),
(S0L, N1p),
(S0p, S0rL, N1p),
)
n0_n1 = (
(N0W, N0p, N1W, N1p),
(N0W, N0p, N1p),
(N0p, N1W, N1p),
(N0c, N0p, N1c, N1p),
(N0c6, N0p, N1c6, N1p),
(N0c, N1c),
(N0p, N1c),
)
tree_shape = (
(dist,),
(S0p, S0_has_head, S1_has_head, S2_has_head),
(S0p, S0lv, S0rv),
(N0p, N0lv),
)
trigrams = (
(N0p, N1p, N2p),
(S0p, S0lp, S0l2p),
(S0p, S0rp, S0r2p),
(S0p, S1p, S2p),
(S1p, S0p, N0p),
(S0p, S0lp, N0p),
(S0p, N0p, N0lp),
(N0p, N0lp, N0l2p),
(S0W, S0p, S0rL, S0r2L),
(S0p, S0rL, S0r2L),
(S0W, S0p, S0lL, S0l2L),
(S0p, S0lL, S0l2L),
(N0W, N0p, N0lL, N0l2L),
(N0p, N0lL, N0l2L),
)
words = (
S2w,
S1w,
S1rw,
S0lw,
S0l2w,
S0w,
S0r2w,
S0rw,
N0lw,
N0l2w,
N0w,
N1w,
N2w,
P1w,
P2w
)
tags = (
S2p,
S1p,
S1rp,
S0lp,
S0l2p,
S0p,
S0r2p,
S0rp,
N0lp,
N0l2p,
N0p,
N1p,
N2p,
P1p,
P2p
)
labels = (
S2L,
S1L,
S1rL,
S0lL,
S0l2L,
S0L,
S0r2L,
S0rL,
N0lL,
N0l2L,
N0L,
N1L,
N2L,
P1L,
P2L
)
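# How these templates get used: the extractor machinery in
# features/extractor.pyx (not shown in this commit) picks each tuple's
# atoms out of the filled context array and hashes them into a feature
# key. A rough Python sketch of the idea, with hypothetical names:
#
#   def extract(context, templates):
#       feats = []
#       for template in templates:              # e.g. (S0p, N0p)
#           values = tuple(context[i] for i in template)
#           if any(values):                     # all-zero atoms carry no signal
#               feats.append(hash((template, values)))
#       return feats
#
#   extract(context, s0_n0)   # one hashed key per s0_n0 template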

View File

@@ -41,188 +41,6 @@ from .transition_system cimport Transition
from . import _beam_utils, nonproj
def get_templates(*args, **kwargs):
return []
DEBUG = False
def set_debug(val):
global DEBUG
DEBUG = val
cdef class precompute_hiddens:
"""Allow a model to be "primed" by pre-computing input features in bulk.
This is used for the parser, where we want to take a batch of documents,
and compute vectors for each (token, position) pair. These vectors can then
be reused, especially for beam-search.
Let's say we're using 12 features for each state, e.g. word at start of
buffer, three words on stack, their children, etc. In the normal arc-eager
system, a document of length N is processed in 2*N states. This means we'll
create 2*N*12 feature vectors --- but if we pre-compute, we only need
N*12 vector computations. The saving for beam-search is much better:
if we have a beam of width k, we'd normally make 2*N*12*k computations,
so precomputing saves a factor of k. This also gives a nice CPU/GPU division:
we can do all our hard maths up front, packed into large multiplications,
and do the hard-to-program parsing on the CPU.
"""
cdef int nF, nO, nP
cdef bint _is_synchronized
cdef public object ops
cdef np.ndarray _features
cdef np.ndarray _cached
cdef np.ndarray bias
cdef object _cuda_stream
cdef object _bp_hiddens
def __init__(self, batch_size, tokvecs, lower_model, cuda_stream=None,
drop=0.):
gpu_cached, bp_features = lower_model.begin_update(tokvecs, drop=drop)
cdef np.ndarray cached
if not isinstance(gpu_cached, numpy.ndarray):
# Note the passing of cuda_stream here: it lets
# cupy make the copy asynchronously.
# We then have to block before first use.
cached = gpu_cached.get(stream=cuda_stream)
else:
cached = gpu_cached
if not isinstance(lower_model.b, numpy.ndarray):
self.bias = lower_model.b.get()
else:
self.bias = lower_model.b
self.nF = cached.shape[1]
self.nP = getattr(lower_model, 'nP', 1)
self.nO = cached.shape[2]
self.ops = lower_model.ops
self._is_synchronized = False
self._cuda_stream = cuda_stream
self._cached = cached
self._bp_hiddens = bp_features
cdef const float* get_feat_weights(self) except NULL:
if not self._is_synchronized and self._cuda_stream is not None:
self._cuda_stream.synchronize()
self._is_synchronized = True
return <float*>self._cached.data
def __call__(self, X):
return self.begin_update(X)[0]
def begin_update(self, token_ids, drop=0.):
cdef np.ndarray state_vector = numpy.zeros(
(token_ids.shape[0], self.nO, self.nP), dtype='f')
# This is tricky, but (assuming a GPU is available):
# - Input to forward on CPU
# - Output from forward on CPU
# - Input to backward on GPU!
# - Output from backward on GPU
bp_hiddens = self._bp_hiddens
feat_weights = self.get_feat_weights()
cdef int[:, ::1] ids = token_ids
sum_state_features(<float*>state_vector.data,
feat_weights, &ids[0,0],
token_ids.shape[0], self.nF, self.nO*self.nP)
state_vector += self.bias
state_vector, bp_nonlinearity = self._nonlinearity(state_vector)
def backward(d_state_vector_ids, sgd=None):
d_state_vector, token_ids = d_state_vector_ids
d_state_vector = bp_nonlinearity(d_state_vector, sgd)
# This will usually be on GPU
if not isinstance(d_state_vector, self.ops.xp.ndarray):
d_state_vector = self.ops.xp.array(d_state_vector)
d_tokens = bp_hiddens((d_state_vector, token_ids), sgd)
return d_tokens
return state_vector, backward
def _nonlinearity(self, state_vector):
if self.nP == 1:
state_vector = state_vector.reshape(state_vector.shape[:-1])
mask = state_vector >= 0.
state_vector *= mask
else:
state_vector, mask = self.ops.maxout(state_vector)
def backprop_nonlinearity(d_best, sgd=None):
if self.nP == 1:
d_best *= mask
d_best = d_best.reshape((d_best.shape + (1,)))
return d_best
else:
return self.ops.backprop_maxout(d_best, mask, self.nP)
return state_vector, backprop_nonlinearity
cdef void sum_state_features(float* output,
const float* cached, const int* token_ids, int B, int F, int O) nogil:
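    # B = number of states in the batch, F = features per state, O = output
    # width (nO*nP). The first F*O floats of `cached` form a padding block:
    # negative token ids (missing context tokens) read their contribution
    # from it instead of indexing a real token's row.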
cdef int idx, b, f, i
cdef const float* feature
padding = cached
cached += F * O
for b in range(B):
for f in range(F):
if token_ids[f] < 0:
feature = &padding[f*O]
else:
idx = token_ids[f] * F * O + f*O
feature = &cached[idx]
for i in range(O):
output[i] += feature[i]
output += O
token_ids += F
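# Illustrative numpy equivalent of the loop above (hypothetical shapes):
# viewing `cached` as (n_tokens, F, O) after the padding block, one state
# with feature token ids `ids` (shape (F,)) gets
#
#   state_vector = cached[ids, np.arange(F)].sum(axis=0) + bias
#
# with ids < 0 routed to the padding block rather than to a real row.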
cdef void cpu_log_loss(float* d_scores,
const float* costs, const int* is_valid, const float* scores,
int O) nogil:
"""Do multi-label log loss"""
cdef double max_, gmax, Z, gZ
best = arg_max_if_gold(scores, costs, is_valid, O)
guess = arg_max_if_valid(scores, is_valid, O)
Z = 1e-10
gZ = 1e-10
max_ = scores[guess]
gmax = scores[best]
for i in range(O):
if is_valid[i]:
Z += exp(scores[i] - max_)
if costs[i] <= costs[best]:
gZ += exp(scores[i] - gmax)
for i in range(O):
if not is_valid[i]:
d_scores[i] = 0.
elif costs[i] <= costs[best]:
d_scores[i] = (exp(scores[i]-max_) / Z) - (exp(scores[i]-gmax)/gZ)
else:
d_scores[i] = exp(scores[i]-max_) / Z
cdef void cpu_regression_loss(float* d_scores,
const float* costs, const int* is_valid, const float* scores,
int O) nogil:
cdef float eps = 2.
best = arg_max_if_gold(scores, costs, is_valid, O)
for i in range(O):
if not is_valid[i]:
d_scores[i] = 0.
elif scores[i] < scores[best]:
d_scores[i] = 0.
else:
# I doubt this is correct?
# Looking for something like Huber loss
diff = scores[i] - (-costs[i])  # the target score is -cost
if diff > eps:
d_scores[i] = eps
elif diff < -eps:
d_scores[i] = -eps
else:
d_scores[i] = diff
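# For reference: clipping the difference to [-eps, eps] is exactly the
# gradient of the Huber loss with delta = eps, which seems to be what the
# comment above is reaching for.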
def _collect_states(beams):
cdef StateClass state
@@ -545,25 +363,26 @@ cdef class Parser:
def update(self, docs, golds, drop=0., sgd=None, losses=None):
if not any(self.moves.has_gold(gold) for gold in golds):
return None
if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() >= 0.0:
return self.update_beam(docs, golds,
self.cfg['beam_width'], self.cfg['beam_density'],
drop=drop, sgd=sgd, losses=losses)
if losses is not None and self.name not in losses:
losses[self.name] = 0.
if isinstance(docs, Doc) and isinstance(golds, GoldParse):
docs = [docs]
golds = [golds]
cuda_stream = util.get_cuda_stream()
states, golds, max_steps = self._init_gold_batch(docs, golds)
(tokvecs, bp_tokvecs), state2vec, vec2scores = self.get_batch_model(docs, cuda_stream,
drop)
todo = [(s, g) for (s, g) in zip(states, golds)
if not s.is_final() and g is not None]
if not todo:
return None
if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() >= 0.0:
return self.update_beam(docs, golds,
self.cfg['beam_width'], self.cfg['beam_density'],
drop=drop, sgd=sgd, losses=losses)
else:
return self.update_greedy(docs, golds, drop=drop, sgd=sgd, losses=losses)
def update_greedy(self, docs, golds, drop=0., sgd=None, losses=None):
tokvecs, bp_tokvecs = self.model.tok2vec(docs)
states = self.init_states(docs, tokvecs)
histories, get_costs = self.model.predict_histories(states)
costs = get_costs(golds)
d_tokens = self.model.update(states, histories, costs)
return bp_tokvecs(tokvecs)
backprops = []
# Add a padding vector to the d_tokvecs gradient, so that missing
# values don't affect the real gradient.
d_tokvecs = state2vec.ops.allocate((tokvecs.shape[0]+1, tokvecs.shape[1]))
@@ -571,32 +390,11 @@ cdef class Parser:
n_steps = 0
while todo:
states, golds = zip(*todo)
token_ids = self.get_token_ids(states)
vector, bp_vector = state2vec.begin_update(token_ids, drop=0.0)
if drop != 0:
mask = vec2scores.ops.get_dropout_mask(vector.shape, drop)
vector *= mask
hists = numpy.asarray([st.history for st in states], dtype='i')
if self.cfg.get('hist_size', 0):
scores, bp_scores = vec2scores.begin_update((vector, hists), drop=drop)
else:
scores, bp_scores = vec2scores.begin_update(vector, drop=drop)
vector, bp_vector = state2vec.begin_update(states, drop=0.0)
scores, bp_scores = vec2scores.begin_update(vector, drop=drop)
d_scores = self.get_batch_loss(states, golds, scores)
d_scores /= len(docs)
d_vector = bp_scores(d_scores, sgd=sgd)
if drop != 0:
d_vector *= mask
if isinstance(self.model[0].ops, CupyOps) \
and not isinstance(token_ids, state2vec.ops.xp.ndarray):
# Move token_ids and d_vector to GPU, asynchronously
backprops.append((
util.get_async(cuda_stream, token_ids),
util.get_async(cuda_stream, d_vector),
bp_vector
))
else:
backprops.append((token_ids, d_vector, bp_vector))
self.transition_batch(states, scores)
todo = [(st, gold) for (st, gold) in todo
@@ -658,7 +456,6 @@ cdef class Parser:
for beam in beams:
_cleanup(beam)
def _init_gold_batch(self, whole_docs, whole_golds):
"""Make a square batch, of length equal to the shortest doc. A long
doc will get multiple states. Let's say we have a doc of length 2*N,
@@ -719,6 +516,11 @@ cdef class Parser:
names.append(name)
return names
@property
def labels(self):
return [label.split('-')[1] for label in self.move_names
if '-' in label]
def get_batch_model(self, docs, stream, dropout):
tok2vec, lower, upper = self.model
tokvecs, bp_tokvecs = tok2vec.begin_update(docs, drop=dropout)