mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-04 12:20:20 +03:00
Test with update_with_oracle_cut_size={0, 1, 5, 100}
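And fix a bug that occurs with a cut size of 1: the lower bound of the randomized cut length was 1 // 2 == 0, so the sampled length could be 0; it is now clamped to at least 1.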
This commit is contained in:
parent
850ce0583d
commit
c20572a82a
|
@@ -306,7 +306,7 @@ class Parser(TrainablePipe):
|
||||||
if max_moves >= 1:
|
if max_moves >= 1:
|
||||||
# Chop sequences into lengths of this many words, to make the
|
# Chop sequences into lengths of this many words, to make the
|
||||||
# batch uniform length.
|
# batch uniform length.
|
||||||
max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
|
max_moves = int(random.uniform(max(max_moves // 2, 1), max_moves * 2))
|
||||||
init_states, gold_states, _ = self._init_gold_batch(
|
init_states, gold_states, _ = self._init_gold_batch(
|
||||||
examples,
|
examples,
|
||||||
max_length=max_moves
|
max_length=max_moves
|
||||||
|
|
|
@@ -1,3 +1,4 @@
|
||||||
|
import itertools
|
||||||
import pytest
|
import pytest
|
||||||
import numpy
|
import numpy
|
||||||
from numpy.testing import assert_equal
|
from numpy.testing import assert_equal
|
||||||
|
@@ -401,12 +402,15 @@ def test_incomplete_data(pipe_name):
|
||||||
assert doc[2].head.i == 1
|
assert doc[2].head.i == 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("pipe_name", PARSERS)
|
@pytest.mark.parametrize(
|
||||||
def test_overfitting_IO(pipe_name):
|
"pipe_name,max_moves", itertools.product(PARSERS, [0, 1, 5, 100])
|
||||||
|
)
|
||||||
|
def test_overfitting_IO(pipe_name, max_moves):
|
||||||
fix_random_seed(0)
|
fix_random_seed(0)
|
||||||
# Simple test to try and quickly overfit the dependency parser (normal or beam)
|
# Simple test to try and quickly overfit the dependency parser (normal or beam)
|
||||||
nlp = English()
|
nlp = English()
|
||||||
parser = nlp.add_pipe(pipe_name)
|
parser = nlp.add_pipe(pipe_name)
|
||||||
|
parser.cfg["update_with_oracle_cut_size"] = max_moves
|
||||||
train_examples = []
|
train_examples = []
|
||||||
for text, annotations in TRAIN_DATA:
|
for text, annotations in TRAIN_DATA:
|
||||||
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
|
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user