Mirror of https://github.com/explosion/spaCy.git (synced 2025-08-04 20:30:24 +03:00)
Add back support for beam parsing to the refactored parser (#10633)
* Add back support for beam parsing

  Beam parsing was already implemented as part of the `BeamBatch` class.
  This change adds its greedy counterpart, `GreedyBatch`. Both classes are
  hooked up in `TransitionModel`, which selects `GreedyBatch` when the beam
  width is one and `BeamBatch` otherwise.

* Use kwarg for beam width

* Avoid implicit default for beam_width and beam_density

* Parser.{beam,greedy}_parse: ensure labels are added

* Remove 'deprecated' comments

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
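The dispatch described in the first bullet is small enough to sketch. The following is a hedged paraphrase of the selection logic from the `forward` diff below; the standalone `make_batch` helper is hypothetical (in the commit the branch lives inline in `forward`), and the imports assume the refactored-parser branch where these modules exist:

```python
# Hypothetical helper paraphrasing the inline branch in forward(); the
# imports assume the refactored-parser branch this commit targets.
from spacy.pipeline._parser_internals import _beam_utils
from spacy.pipeline._parser_internals.batch import GreedyBatch


def make_batch(moves, states, beam_width: int = 1, beam_density: float = 0.0):
    if beam_width == 1:
        # Greedy decoding: follow the single best-scoring transition.
        return GreedyBatch(moves, states, None)
    # Beam decoding: keep the beam_width best partial parses alive,
    # pruning candidates with beam_density.
    return _beam_utils.BeamBatch(
        moves, states, None, width=beam_width, density=beam_density
    )
```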
This commit is contained in:

parent 8cc29154ac
commit e3a5350540

setup.py | 1 +
@@ -42,6 +42,7 @@ MOD_NAMES = [
     "spacy.pipeline.tagger",
     "spacy.pipeline.transition_parser",
     "spacy.pipeline._parser_internals.arc_eager",
+    "spacy.pipeline._parser_internals.batch",
     "spacy.pipeline._parser_internals.ner",
     "spacy.pipeline._parser_internals.nonproj",
     "spacy.pipeline._parser_internals._state",
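`MOD_NAMES` is the list of Cython modules that `setup.py` compiles, so the new `batch` module has to be registered here. A hedged paraphrase of how such a list is typically consumed (the real build script also sets per-module compiler options; this sketch simplifies):

```python
from setuptools import Extension

MOD_NAMES = ["spacy.pipeline._parser_internals.batch"]  # excerpt

# Each dotted name becomes a Cython extension; a module missing from
# MOD_NAMES would simply never be compiled.
ext_modules = [
    Extension(name, sources=[name.replace(".", "/") + ".pyx"])
    for name in MOD_NAMES
]
```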
spacy/ml/tb_framework.py

@@ -3,6 +3,8 @@ from thinc.api import Ops, Model, normal_init, chain, list2array, Linear
 from thinc.api import uniform_init, glorot_uniform_init, zero_init
 from thinc.types import Floats1d, Floats2d, Floats3d, Ints2d, Floats4d
 import numpy
+from ..pipeline._parser_internals import _beam_utils
+from ..pipeline._parser_internals.batch import GreedyBatch
 from ..tokens.doc import Doc
 from ..util import registry
 
@@ -15,6 +17,8 @@ State = Any  # TODO
 def TransitionModel(
     *,
     tok2vec: Model[List[Doc], List[Floats2d]],
+    beam_width: int = 1,
+    beam_density: float = 0.0,
     state_tokens: int,
     hidden_width: int,
     maxout_pieces: int,
@@ -49,6 +53,8 @@ def TransitionModel(
             "nF": state_tokens,
         },
         attrs={
+            "beam_width": beam_width,
+            "beam_density": beam_density,
             "unseen_classes": set(unseen_classes),
             "resize_output": resize_output,
         },
@@ -139,6 +145,9 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: bool):
     nO = model.get_dim("nO")
     nI = model.get_dim("nI")
 
+    beam_width = model.attrs["beam_width"]
+    beam_density = model.attrs["beam_density"]
+
     ops = model.ops
     docs, moves = docs_moves
     states = moves.init_batch(docs)
@@ -149,20 +158,24 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: bool):
     all_which = []
     all_statevecs = []
     all_scores = []
-    next_states = [s for s in states if not s.is_final()]
+    if beam_width == 1:
+        batch = GreedyBatch(moves, states, None)
+    else:
+        batch = _beam_utils.BeamBatch(
+            moves, states, None, width=beam_width, density=beam_density
+        )
     seen_mask = _get_seen_mask(model)
-    ids = numpy.zeros((len(states), nF), dtype="i")
     arange = model.ops.xp.arange(nF)
-    while next_states:
-        ids = ids[: len(next_states)]
-        for i, state in enumerate(next_states):
+    while not batch.is_done:
+        ids = numpy.zeros((len(batch.get_unfinished_states()), nF), dtype="i")
+        for i, state in enumerate(batch.get_unfinished_states()):
             state.set_context_tokens(ids, i, nF)
         # Sum the state features, add the bias and apply the activation (maxout)
         # to create the state vectors.
         preacts2f = feats[ids, arange].sum(axis=1)  # type: ignore
         preacts2f += lower_b
         preacts = model.ops.reshape3f(preacts2f, preacts2f.shape[0], nH, nP)
-        assert preacts.shape[0] == len(next_states), preacts.shape
+        assert preacts.shape[0] == len(batch.get_unfinished_states()), preacts.shape
         statevecs, which = ops.maxout(preacts)
         # Multiply the state-vector by the scores weights and add the bias,
         # to get the logits.
@@ -171,11 +184,11 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: bool):
         scores[:, seen_mask] = model.ops.xp.nanmin(scores)
         # Transition the states, filtering out any that are finished.
         cpu_scores = model.ops.to_numpy(scores)
-        next_states = moves.transition_states(next_states, cpu_scores)
+        batch.advance(cpu_scores)
         all_scores.append(scores)
         if is_train:
             # Remember intermediate results for the backprop.
-            all_ids.append(ids.copy())
+            all_ids.append(ids)
             all_statevecs.append(statevecs)
             all_which.append(which)
 
@@ -211,7 +224,7 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: bool):
         model.inc_grad("lower_pad", d_tokvecs[-1])
         return (backprop_tok2vec(d_tokvecs[:-1]), None)
 
-    return (states, all_scores), backprop_parser
+    return (list(batch), all_scores), backprop_parser
 
 
 def _forward_reference(
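The net effect of the `forward` changes is that the decode loop now talks to the `Batch` interface instead of tracking `next_states` by hand. A minimal sketch of that loop shape, with a hypothetical `score_states` callback standing in for the feature/maxout/logits computation:

```python
def decode(batch, score_states):
    """Drive any Batch (greedy or beam) to completion.

    score_states is a hypothetical callback mapping the unfinished
    states to a (n_states, n_moves) numpy score matrix.
    """
    all_scores = []
    while not batch.is_done:
        scores = score_states(batch.get_unfinished_states())
        # advance() applies the transitions and drops finished states,
        # so the caller no longer filters states itself.
        batch.advance(scores)
        all_scores.append(scores)
    # Iterating the batch (via __getitem__/__len__) yields the final
    # states or beams.
    return list(batch), all_scores
```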
spacy/pipeline/_parser_internals/_beam_utils.pyx

@@ -10,6 +10,7 @@ from thinc.extra.search cimport MaxViolation
 from ...typedefs cimport hash_t, class_t
 from .transition_system cimport TransitionSystem, Transition
 from ...errors import Errors
+from .batch cimport Batch
 from .stateclass cimport StateC, StateClass
 
 
@@ -27,7 +28,7 @@ cdef int check_final_state(void* _state, void* extra_args) except -1:
     return state.is_final()
 
 
-cdef class BeamBatch(object):
+cdef class BeamBatch(Batch):
     cdef public TransitionSystem moves
     cdef public object states
     cdef public object docs
spacy/pipeline/_parser_internals/batch.pxd (new file, 2 lines)

@@ -0,0 +1,2 @@
+cdef class Batch:
+    pass
spacy/pipeline/_parser_internals/batch.pyx (new file, 49 lines)

@@ -0,0 +1,49 @@
+from typing import Any
+
+TransitionSystem = Any  # TODO
+
+cdef class Batch:
+    def advance(self, scores):
+        raise NotImplementedError
+
+    def get_states(self):
+        raise NotImplementedError
+
+    @property
+    def is_done(self):
+        raise NotImplementedError
+
+    def get_unfinished_states(self):
+        raise NotImplementedError
+
+    def __getitem__(self, i):
+        raise NotImplementedError
+
+    def __len__(self):
+        raise NotImplementedError
+
+
+class GreedyBatch(Batch):
+    def __init__(self, moves: TransitionSystem, states, golds):
+        self._moves = moves
+        self._states = states
+        self._next_states = [s for s in states if not s.is_final()]
+
+    def advance(self, scores):
+        self._next_states = self._moves.transition_states(self._next_states, scores)
+
+    def get_states(self):
+        return self._states
+
+    @property
+    def is_done(self):
+        return all(s.is_final() for s in self._states)
+
+    def get_unfinished_states(self):
+        return [st for st in self._states if not st.is_final()]
+
+    def __getitem__(self, i):
+        return self._states[i]
+
+    def __len__(self):
+        return len(self._states)
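To make the contract concrete, here is a hedged, pure-Python exercise of `GreedyBatch` with hypothetical `DummyState`/`DummyMoves` stand-ins; the real arguments are Cython parser states and a `TransitionSystem`, and the import assumes the compiled module added by this commit:

```python
from spacy.pipeline._parser_internals.batch import GreedyBatch  # this commit


class DummyState:
    """Hypothetical state that finishes after a fixed number of steps."""

    def __init__(self, steps_left):
        self.steps_left = steps_left

    def is_final(self):
        return self.steps_left == 0


class DummyMoves:
    """Hypothetical TransitionSystem: one transition per advance() call."""

    def transition_states(self, states, scores):
        for state in states:
            state.steps_left -= 1
        return [s for s in states if not s.is_final()]


states = [DummyState(2), DummyState(1)]
batch = GreedyBatch(DummyMoves(), states, None)  # golds are unused here
while not batch.is_done:
    batch.advance(scores=None)  # the real loop passes a numpy score matrix
assert len(batch) == 2 and all(s.is_final() for s in batch.get_states())
```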
spacy/pipeline/transition_parser.pyx

@@ -250,15 +250,14 @@ class Parser(TrainablePipe):
         return states_or_beams
 
     def greedy_parse(self, docs, drop=0.):
-        # TODO: Deprecated
         self._resize()
+        self._ensure_labels_are_added(docs)
         with _change_attrs(self.model, beam_width=1):
             states, _ = self.model.predict((docs, self.moves))
         return states
 
     def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
-        # TODO: Deprecated
         self._resize()
+        self._ensure_labels_are_added(docs)
         with _change_attrs(self.model, beam_width=self.cfg["beam_width"], beam_density=self.cfg["beam_density"]):
             beams, _ = self.model.predict((docs, self.moves))
         return beams
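With the `# TODO: Deprecated` comments gone, these are again supported entry points. A hedged usage sketch, assuming a pipeline with a trained parser pipe on the refactored-parser branch (the model name and beam settings are placeholders):

```python
import spacy

nlp = spacy.load("en_core_web_sm")  # placeholder: any pipeline with a parser
parser = nlp.get_pipe("parser")
docs = [nlp.make_doc("Beam search keeps several hypotheses alive.")]

# Greedy decoding: the model attrs are temporarily set to beam_width=1.
states = parser.greedy_parse(docs)

# Beam decoding: width and density are taken from the component config,
# per the diff above.
beams = parser.beam_parse(docs, beam_width=4)
```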
spacy/tests/parser/test_ner.py

@@ -181,7 +181,6 @@ def test_issue4267():
     assert token.ent_iob == 2
 
 
-@pytest.mark.xfail(reason="no beam parser yet")
 @pytest.mark.issue(4313)
 def test_issue4313():
     """This should not crash or exit with some strange error code"""
@@ -597,7 +596,6 @@ def test_overfitting_IO():
     assert ents[1].kb_id == 0
 
 
-@pytest.mark.xfail(reason="no beam parser yet")
 def test_beam_ner_scores():
     # Test that we can get confidence values out of the beam_ner pipe
     beam_width = 16
@@ -633,7 +631,6 @@ def test_beam_ner_scores():
     assert 0 - eps <= score <= 1 + eps
 
 
-@pytest.mark.xfail(reason="no beam parser yet")
 def test_beam_overfitting_IO(neg_key):
     # Simple test to try and quickly overfit the Beam NER component
     nlp = English()
spacy/tests/parser/test_parse.py

@@ -401,7 +401,6 @@ def test_overfitting_IO(pipe_name):
     assert_equal(batch_deps_1, no_batch_deps)
 
 
-@pytest.mark.xfail(reason="no beam parser yet")
 def test_beam_parser_scores():
     # Test that we can get confidence values out of the beam_parser pipe
     beam_width = 16
@@ -440,7 +439,6 @@ def test_beam_parser_scores():
     assert 0 - eps <= head_score <= 1 + eps
 
 
-@pytest.mark.xfail(reason="no beam parser yet")
 def test_beam_overfitting_IO():
     # Simple test to try and quickly overfit the Beam dependency parser
     nlp = English()