diff --git a/spacy/tests/parser/test_nn_beam.py b/spacy/tests/parser/test_nn_beam.py index ad0dfa7a1..45c85d969 100644 --- a/spacy/tests/parser/test_nn_beam.py +++ b/spacy/tests/parser/test_nn_beam.py @@ -8,6 +8,7 @@ from ...syntax.arc_eager import ArcEager from ...tokens import Doc from ...gold import GoldParse from ...syntax._beam_utils import ParserBeam, update_beam +from ...syntax.stateclass import StateClass @pytest.fixture @@ -27,6 +28,10 @@ def moves(vocab): def docs(vocab): return [Doc(vocab, words=['Rats', 'bite', 'things'])] +@pytest.fixture +def states(docs): + return [StateClass(doc) for doc in docs] + @pytest.fixture def tokvecs(docs, vector_size): output = [] @@ -57,8 +62,8 @@ def vector_size(): @pytest.fixture -def beam(moves, docs, golds, beam_width): - return ParserBeam(moves, docs, golds, width=beam_width) +def beam(moves, states, golds, beam_width): + return ParserBeam(moves, states, golds, width=beam_width) @pytest.fixture def scores(moves, batch_size, beam_width): @@ -80,19 +85,3 @@ def test_beam_advance(beam, scores): def test_beam_advance_too_few_scores(beam, scores): with pytest.raises(IndexError): beam.advance(scores[:-1]) - - -def test_update_beam(moves, docs, tokvecs, golds, vector_size): - @layerize - def state2vec(X, drop=0.): - vec = numpy.ones((X.shape[0], vector_size), dtype='f') - return vec, None - @layerize - def vec2scores(X, drop=0.): - scores = numpy.ones((X.shape[0], moves.n_moves), dtype='f') - return scores, None - d_loss, backprops = update_beam(moves, 13, docs, tokvecs, golds, - state2vec, vec2scores, drop=0.0, sgd=None, - losses={}, width=4, density=0.001) - -