diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx index 085fd8ea8..e6119ee79 100644 --- a/spacy/pipeline/transition_parser.pyx +++ b/spacy/pipeline/transition_parser.pyx @@ -306,7 +306,7 @@ class Parser(TrainablePipe): if max_moves >= 1: # Chop sequences into lengths of this many words, to make the # batch uniform length. - max_moves = int(random.uniform(max_moves // 2, max_moves * 2)) + max_moves = int(random.uniform(max(max_moves // 2, 1), max_moves * 2)) init_states, gold_states, _ = self._init_gold_batch( examples, max_length=max_moves diff --git a/spacy/tests/parser/test_parse.py b/spacy/tests/parser/test_parse.py index af33dcf5f..df463b700 100644 --- a/spacy/tests/parser/test_parse.py +++ b/spacy/tests/parser/test_parse.py @@ -1,3 +1,4 @@ +import itertools import pytest import numpy from numpy.testing import assert_equal @@ -401,12 +402,15 @@ def test_incomplete_data(pipe_name): assert doc[2].head.i == 1 -@pytest.mark.parametrize("pipe_name", PARSERS) -def test_overfitting_IO(pipe_name): +@pytest.mark.parametrize( + "pipe_name,max_moves", itertools.product(PARSERS, [0, 1, 5, 100]) +) +def test_overfitting_IO(pipe_name, max_moves): fix_random_seed(0) # Simple test to try and quickly overfit the dependency parser (normal or beam) nlp = English() parser = nlp.add_pipe(pipe_name) + parser.cfg["update_with_oracle_cut_size"] = max_moves train_examples = [] for text, annotations in TRAIN_DATA: train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))