Test with update_with_oracle_cut_size={0, 1, 5, 100}

And fix a git that occurs with a cut size of 1.
This commit is contained in:
Daniël de Kok 2023-01-13 16:04:07 +01:00
parent 850ce0583d
commit c20572a82a
2 changed files with 7 additions and 3 deletions

View File

@ -306,7 +306,7 @@ class Parser(TrainablePipe):
if max_moves >= 1: if max_moves >= 1:
# Chop sequences into lengths of this many words, to make the # Chop sequences into lengths of this many words, to make the
# batch uniform length. # batch uniform length.
max_moves = int(random.uniform(max_moves // 2, max_moves * 2)) max_moves = int(random.uniform(max(max_moves // 2, 1), max_moves * 2))
init_states, gold_states, _ = self._init_gold_batch( init_states, gold_states, _ = self._init_gold_batch(
examples, examples,
max_length=max_moves max_length=max_moves

View File

@ -1,3 +1,4 @@
import itertools
import pytest import pytest
import numpy import numpy
from numpy.testing import assert_equal from numpy.testing import assert_equal
@ -401,12 +402,15 @@ def test_incomplete_data(pipe_name):
assert doc[2].head.i == 1 assert doc[2].head.i == 1
@pytest.mark.parametrize("pipe_name", PARSERS) @pytest.mark.parametrize(
def test_overfitting_IO(pipe_name): "pipe_name,max_moves", itertools.product(PARSERS, [0, 1, 5, 100])
)
def test_overfitting_IO(pipe_name, max_moves):
fix_random_seed(0) fix_random_seed(0)
# Simple test to try and quickly overfit the dependency parser (normal or beam) # Simple test to try and quickly overfit the dependency parser (normal or beam)
nlp = English() nlp = English()
parser = nlp.add_pipe(pipe_name) parser = nlp.add_pipe(pipe_name)
parser.cfg["update_with_oracle_cut_size"] = max_moves
train_examples = [] train_examples = []
for text, annotations in TRAIN_DATA: for text, annotations in TRAIN_DATA:
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations)) train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))