diff --git a/setup.py b/setup.py
index 397a9e8e5..f6d010c76 100755
--- a/setup.py
+++ b/setup.py
@@ -42,6 +42,7 @@ MOD_NAMES = [
     "spacy.pipeline.tagger",
     "spacy.pipeline.transition_parser",
     "spacy.pipeline._parser_internals.arc_eager",
+    "spacy.pipeline._parser_internals.batch",
     "spacy.pipeline._parser_internals.ner",
     "spacy.pipeline._parser_internals.nonproj",
     "spacy.pipeline._parser_internals._state",
diff --git a/spacy/ml/tb_framework.py b/spacy/ml/tb_framework.py
index dd2ff6c19..7aa1a9324 100644
--- a/spacy/ml/tb_framework.py
+++ b/spacy/ml/tb_framework.py
@@ -3,6 +3,8 @@ from thinc.api import Ops, Model, normal_init, chain, list2array, Linear
 from thinc.api import uniform_init, glorot_uniform_init, zero_init
 from thinc.types import Floats1d, Floats2d, Floats3d, Ints2d, Floats4d
 import numpy
 
+from ..pipeline._parser_internals import _beam_utils
+from ..pipeline._parser_internals.batch import GreedyBatch
 from ..tokens.doc import Doc
 from ..util import registry
@@ -15,6 +17,8 @@ State = Any  # TODO
 def TransitionModel(
     *,
     tok2vec: Model[List[Doc], List[Floats2d]],
+    beam_width: int = 1,
+    beam_density: float = 0.0,
     state_tokens: int,
     hidden_width: int,
     maxout_pieces: int,
@@ -49,6 +53,8 @@ def TransitionModel(
             "nF": state_tokens,
         },
         attrs={
+            "beam_width": beam_width,
+            "beam_density": beam_density,
             "unseen_classes": set(unseen_classes),
             "resize_output": resize_output,
         },
@@ -139,6 +145,9 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: boo
     nO = model.get_dim("nO")
     nI = model.get_dim("nI")
 
+    beam_width = model.attrs["beam_width"]
+    beam_density = model.attrs["beam_density"]
+
     ops = model.ops
     docs, moves = docs_moves
     states = moves.init_batch(docs)
@@ -149,20 +158,24 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: boo
     all_which = []
     all_statevecs = []
     all_scores = []
-    next_states = [s for s in states if not s.is_final()]
+    if beam_width == 1:
+        batch = GreedyBatch(moves, states, None)
+    else:
+        batch = _beam_utils.BeamBatch(
+            moves, states, None, width=beam_width, density=beam_density
+        )
     seen_mask = _get_seen_mask(model)
-    ids = numpy.zeros((len(states), nF), dtype="i")
     arange = model.ops.xp.arange(nF)
-    while next_states:
-        ids = ids[: len(next_states)]
-        for i, state in enumerate(next_states):
+    while not batch.is_done:
+        ids = numpy.zeros((len(batch.get_unfinished_states()), nF), dtype="i")
+        for i, state in enumerate(batch.get_unfinished_states()):
             state.set_context_tokens(ids, i, nF)
         # Sum the state features, add the bias and apply the activation (maxout)
         # to create the state vectors.
         preacts2f = feats[ids, arange].sum(axis=1)  # type: ignore
         preacts2f += lower_b
         preacts = model.ops.reshape3f(preacts2f, preacts2f.shape[0], nH, nP)
-        assert preacts.shape[0] == len(next_states), preacts.shape
+        assert preacts.shape[0] == len(batch.get_unfinished_states()), preacts.shape
         statevecs, which = ops.maxout(preacts)
         # Multiply the state-vector by the scores weights and add the bias,
         # to get the logits.
@@ -171,11 +184,11 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: boo
         scores[:, seen_mask] = model.ops.xp.nanmin(scores)
         # Transition the states, filtering out any that are finished.
         cpu_scores = model.ops.to_numpy(scores)
-        next_states = moves.transition_states(next_states, cpu_scores)
+        batch.advance(cpu_scores)
         all_scores.append(scores)
         if is_train:
             # Remember intermediate results for the backprop.
-            all_ids.append(ids.copy())
+            all_ids.append(ids)
             all_statevecs.append(statevecs)
             all_which.append(which)
 
@@ -211,7 +224,7 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: boo
         model.inc_grad("lower_pad", d_tokvecs[-1])
         return (backprop_tok2vec(d_tokvecs[:-1]), None)
 
-    return (states, all_scores), backprop_parser
+    return (list(batch), all_scores), backprop_parser
 
 
 def _forward_reference(
diff --git a/spacy/pipeline/_parser_internals/_beam_utils.pyx b/spacy/pipeline/_parser_internals/_beam_utils.pyx
index fa7df2056..6fb46e177 100644
--- a/spacy/pipeline/_parser_internals/_beam_utils.pyx
+++ b/spacy/pipeline/_parser_internals/_beam_utils.pyx
@@ -10,6 +10,7 @@ from thinc.extra.search cimport MaxViolation
 from ...typedefs cimport hash_t, class_t
 from .transition_system cimport TransitionSystem, Transition
 from ...errors import Errors
+from .batch cimport Batch
 from .stateclass cimport StateC, StateClass
 
 
@@ -27,7 +28,7 @@ cdef int check_final_state(void* _state, void* extra_args) except -1:
     return state.is_final()
 
 
-cdef class BeamBatch(object):
+cdef class BeamBatch(Batch):
     cdef public TransitionSystem moves
     cdef public object states
     cdef public object docs
diff --git a/spacy/pipeline/_parser_internals/batch.pxd b/spacy/pipeline/_parser_internals/batch.pxd
new file mode 100644
index 000000000..60734e549
--- /dev/null
+++ b/spacy/pipeline/_parser_internals/batch.pxd
@@ -0,0 +1,2 @@
+cdef class Batch:
+    pass
diff --git a/spacy/pipeline/_parser_internals/batch.pyx b/spacy/pipeline/_parser_internals/batch.pyx
new file mode 100644
index 000000000..7928fb0b9
--- /dev/null
+++ b/spacy/pipeline/_parser_internals/batch.pyx
@@ -0,0 +1,49 @@
+from typing import Any
+
+TransitionSystem = Any  # TODO
+
+cdef class Batch:
+    def advance(self, scores):
+        raise NotImplementedError
+
+    def get_states(self):
+        raise NotImplementedError
+
+    @property
+    def is_done(self):
+        raise NotImplementedError
+
+    def get_unfinished_states(self):
+        raise NotImplementedError
+
+    def __getitem__(self, i):
+        raise NotImplementedError
+
+    def __len__(self):
+        raise NotImplementedError
+
+
+class GreedyBatch(Batch):
+    def __init__(self, moves: TransitionSystem, states, golds):
+        self._moves = moves
+        self._states = states
+        self._next_states = [s for s in states if not s.is_final()]
+
+    def advance(self, scores):
+        self._next_states = self._moves.transition_states(self._next_states, scores)
+
+    def get_states(self):
+        return self._states
+
+    @property
+    def is_done(self):
+        return all(s.is_final() for s in self._states)
+
+    def get_unfinished_states(self):
+        return [st for st in self._states if not st.is_final()]
+
+    def __getitem__(self, i):
+        return self._states[i]
+
+    def __len__(self):
+        return len(self._states)
diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx
index ba01a0c12..ce1d7e717 100644
--- a/spacy/pipeline/transition_parser.pyx
+++ b/spacy/pipeline/transition_parser.pyx
@@ -250,15 +250,14 @@ class Parser(TrainablePipe):
         return states_or_beams
 
     def greedy_parse(self, docs, drop=0.):
-        # TODO: Deprecated
         self._resize()
+        self._ensure_labels_are_added(docs)
         with _change_attrs(self.model, beam_width=1):
             states, _ = self.model.predict((docs, self.moves))
         return states
 
     def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
-        # TODO: Deprecated
-        self._resize()
+        self._ensure_labels_are_added(docs)
         with _change_attrs(self.model, beam_width=self.cfg["beam_width"], beam_density=self.cfg["beam_density"]):
             beams, _ = self.model.predict((docs, self.moves))
         return beams
diff --git a/spacy/tests/parser/test_ner.py b/spacy/tests/parser/test_ner.py
index c7eef189a..620c84465 100644
--- a/spacy/tests/parser/test_ner.py
+++ b/spacy/tests/parser/test_ner.py
@@ -181,7 +181,6 @@ def test_issue4267():
         assert token.ent_iob == 2
 
 
-@pytest.mark.xfail(reason="no beam parser yet")
 @pytest.mark.issue(4313)
 def test_issue4313():
     """This should not crash or exit with some strange error code"""
@@ -597,7 +596,6 @@ def test_overfitting_IO():
     assert ents[1].kb_id == 0
 
 
-@pytest.mark.xfail(reason="no beam parser yet")
 def test_beam_ner_scores():
     # Test that we can get confidence values out of the beam_ner pipe
     beam_width = 16
@@ -633,7 +631,6 @@ def test_beam_ner_scores():
             assert 0 - eps <= score <= 1 + eps
 
 
-@pytest.mark.xfail(reason="no beam parser yet")
 def test_beam_overfitting_IO(neg_key):
     # Simple test to try and quickly overfit the Beam NER component
     nlp = English()
diff --git a/spacy/tests/parser/test_parse.py b/spacy/tests/parser/test_parse.py
index 75b983eee..ca42a3f22 100644
--- a/spacy/tests/parser/test_parse.py
+++ b/spacy/tests/parser/test_parse.py
@@ -401,7 +401,6 @@ def test_overfitting_IO(pipe_name):
     assert_equal(batch_deps_1, no_batch_deps)
 
 
-@pytest.mark.xfail(reason="no beam parser yet")
 def test_beam_parser_scores():
     # Test that we can get confidence values out of the beam_parser pipe
     beam_width = 16
@@ -440,7 +439,6 @@ def test_beam_parser_scores():
            assert 0 - eps <= head_score <= 1 + eps
 
 
-@pytest.mark.xfail(reason="no beam parser yet")
 def test_beam_overfitting_IO():
     # Simple test to try and quickly overfit the Beam dependency parser
     nlp = English()
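
The forward() loop in tb_framework.py now drives decoding purely through the Batch contract introduced in batch.pyx (is_done, get_unfinished_states(), advance(scores), and iterating the batch once decoding ends), which is what lets _beam_utils.BeamBatch drop in for GreedyBatch without touching the scoring code. Below is a minimal, self-contained sketch of that contract, not part of the patch: ToyState and ToyMoves are hypothetical stand-ins for StateClass and TransitionSystem, and only the GreedyBatch body mirrors the class added in this diff.

# Toy driver for the Batch contract. ToyState and ToyMoves are hypothetical
# stand-ins (not spaCy classes): a state just counts down a fixed number of
# transitions, and "moves" applies one transition per unfinished state.
from typing import Any, List


class ToyState:
    def __init__(self, steps_left: int):
        self.steps_left = steps_left

    def is_final(self) -> bool:
        return self.steps_left <= 0


class ToyMoves:
    def transition_states(self, states: List[ToyState], scores: Any) -> List[ToyState]:
        # Real code would pick the best-scoring valid transition per state.
        for state in states:
            state.steps_left -= 1
        return [s for s in states if not s.is_final()]


class GreedyBatch:
    # Mirrors spacy.pipeline._parser_internals.batch.GreedyBatch from this diff.
    def __init__(self, moves, states, golds):
        self._moves = moves
        self._states = states
        self._next_states = [s for s in states if not s.is_final()]

    def advance(self, scores):
        self._next_states = self._moves.transition_states(self._next_states, scores)

    def get_states(self):
        return self._states

    @property
    def is_done(self):
        return all(s.is_final() for s in self._states)

    def get_unfinished_states(self):
        return [st for st in self._states if not st.is_final()]

    def __getitem__(self, i):
        return self._states[i]

    def __len__(self):
        return len(self._states)


batch = GreedyBatch(ToyMoves(), [ToyState(2), ToyState(4)], None)
steps = 0
while not batch.is_done:
    # forward() computes real per-state scores at this point; the loop itself
    # only needs the batch API, so a BeamBatch could be substituted unchanged.
    batch.advance(scores=None)
    steps += 1
print(steps, len(batch))  # -> 4 2: the loop runs until every state is final

With this split, the only per-step coupling between the model and the transition system is the scores array handed to advance(), which is why greedy_parse() and beam_parse() in transition_parser.pyx can both call model.predict() and differ only in the beam_width and beam_density attrs they set.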