mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 18:56:36 +03:00
Rebatch parser inputs, with mid-sentence states
This commit is contained in:
parent
679efe79c8
commit
c245ff6b27
|
@ -426,9 +426,11 @@ cdef class Parser:
|
||||||
golds = [golds]
|
golds = [golds]
|
||||||
|
|
||||||
cuda_stream = get_cuda_stream()
|
cuda_stream = get_cuda_stream()
|
||||||
golds = [self.moves.preprocess_gold(g) for g in golds]
|
|
||||||
|
|
||||||
states = self.moves.init_batch(docs)
|
states, golds = self._init_gold_batch(docs, golds)
|
||||||
|
max_length = min([len(doc) for doc in docs])
|
||||||
|
#golds = [self.moves.preprocess_gold(g) for g in golds]
|
||||||
|
#states = self.moves.init_batch(docs)
|
||||||
state2vec, vec2scores = self.get_batch_model(len(states), tokvecs, cuda_stream,
|
state2vec, vec2scores = self.get_batch_model(len(states), tokvecs, cuda_stream,
|
||||||
0.0)
|
0.0)
|
||||||
|
|
||||||
|
@ -438,6 +440,7 @@ cdef class Parser:
|
||||||
backprops = []
|
backprops = []
|
||||||
d_tokvecs = state2vec.ops.allocate(tokvecs.shape)
|
d_tokvecs = state2vec.ops.allocate(tokvecs.shape)
|
||||||
cdef float loss = 0.
|
cdef float loss = 0.
|
||||||
|
#while len(todo and len(todo) >= len(states):
|
||||||
while todo:
|
while todo:
|
||||||
states, golds = zip(*todo)
|
states, golds = zip(*todo)
|
||||||
|
|
||||||
|
@ -467,10 +470,54 @@ cdef class Parser:
|
||||||
todo = [st for st in todo if not st[0].is_final()]
|
todo = [st for st in todo if not st[0].is_final()]
|
||||||
if losses is not None:
|
if losses is not None:
|
||||||
losses[self.name] += (d_scores**2).sum()
|
losses[self.name] += (d_scores**2).sum()
|
||||||
|
if len(backprops) >= (max_length * 2):
|
||||||
|
break
|
||||||
self._make_updates(d_tokvecs,
|
self._make_updates(d_tokvecs,
|
||||||
backprops, sgd, cuda_stream)
|
backprops, sgd, cuda_stream)
|
||||||
return self.model[0].ops.unflatten(d_tokvecs, [len(d) for d in docs])
|
return self.model[0].ops.unflatten(d_tokvecs, [len(d) for d in docs])
|
||||||
|
|
||||||
|
def _init_gold_batch(self, docs, golds):
|
||||||
|
"""Make a square batch, of length equal to the shortest doc. A long
|
||||||
|
doc will get multiple states. Let's say we have a doc of length 2*N,
|
||||||
|
where N is the shortest doc. We'll make two states, one representing
|
||||||
|
long_doc[:N], and another representing long_doc[N:]."""
|
||||||
|
cdef StateClass state
|
||||||
|
lengths = [len(doc) for doc in docs]
|
||||||
|
# Cap to min length
|
||||||
|
min_length = min(lengths)
|
||||||
|
offset = 0
|
||||||
|
states = []
|
||||||
|
extra_golds = []
|
||||||
|
cdef np.ndarray py_costs = numpy.zeros((self.moves.n_moves,), dtype='f')
|
||||||
|
cdef np.ndarray py_is_valid = numpy.zeros((self.moves.n_moves,), dtype='i')
|
||||||
|
costs = <float*>py_costs.data
|
||||||
|
is_valid = <int*>py_is_valid.data
|
||||||
|
for doc, gold in zip(docs, golds):
|
||||||
|
gold = self.moves.preprocess_gold(gold)
|
||||||
|
state = StateClass(doc, offset=offset)
|
||||||
|
self.moves.initialize_state(state.c)
|
||||||
|
states.append(state)
|
||||||
|
extra_golds.append(gold)
|
||||||
|
start = min(min_length, len(doc))
|
||||||
|
while start < len(doc):
|
||||||
|
length = min(min_length, len(doc)-start)
|
||||||
|
state = StateClass(doc, offset=offset)
|
||||||
|
self.moves.initialize_state(state.c)
|
||||||
|
while state.B(0) < start and not state.is_final():
|
||||||
|
py_is_valid.fill(0)
|
||||||
|
py_costs.fill(0)
|
||||||
|
self.moves.set_costs(is_valid, costs, state, gold)
|
||||||
|
for i in range(self.moves.n_moves):
|
||||||
|
if is_valid[i] and costs[i] <= 0:
|
||||||
|
self.moves.c[i].do(state.c, self.moves.c[i].label)
|
||||||
|
break
|
||||||
|
start += length
|
||||||
|
if not state.is_final():
|
||||||
|
states.append(state)
|
||||||
|
extra_golds.append(gold)
|
||||||
|
offset += len(doc)
|
||||||
|
return states, extra_golds
|
||||||
|
|
||||||
def _make_updates(self, d_tokvecs, backprops, sgd, cuda_stream=None):
|
def _make_updates(self, d_tokvecs, backprops, sgd, cuda_stream=None):
|
||||||
# Tells CUDA to block, so our async copies complete.
|
# Tells CUDA to block, so our async copies complete.
|
||||||
if cuda_stream is not None:
|
if cuda_stream is not None:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user