mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Randomize the rebatch size in parser
This commit is contained in:
parent
0872cf611d
commit
661873ee4c
|
@ -555,7 +555,10 @@ cdef class Parser:
|
|||
for multitask in self._multitasks:
|
||||
multitask.update(docs, golds, drop=drop, sgd=sgd)
|
||||
cuda_stream = util.get_cuda_stream()
|
||||
states, golds, max_steps = self._init_gold_batch(docs, golds)
|
||||
# Chop sequences into lengths of this many transitions, to make the
|
||||
# batch uniform length.
|
||||
cut_gold = numpy.random.choice(range(20, 100))
|
||||
states, golds, max_steps = self._init_gold_batch(docs, golds, max_length=cut_gold)
|
||||
(tokvecs, bp_tokvecs), state2vec, vec2scores = self.get_batch_model(docs, cuda_stream,
|
||||
drop)
|
||||
todo = [(s, g) for (s, g) in zip(states, golds)
|
||||
|
@ -659,7 +662,7 @@ cdef class Parser:
|
|||
_cleanup(beam)
|
||||
|
||||
|
||||
def _init_gold_batch(self, whole_docs, whole_golds, min_length=5, max_length=2000):
|
||||
def _init_gold_batch(self, whole_docs, whole_golds, min_length=5, max_length=500):
|
||||
"""Make a square batch, of length equal to the shortest doc. A long
|
||||
doc will get multiple states. Let's say we have a doc of length 2*N,
|
||||
where N is the shortest doc. We'll make two states, one representing
|
||||
|
|
Loading…
Reference in New Issue
Block a user