Start rigging beam back up

This commit is contained in:
Matthew Honnibal 2021-11-01 12:39:16 +01:00
parent 13b0a24870
commit 3258efedfe

View File

@ -8,6 +8,7 @@ from libc.string cimport memset, memcpy
from libc.stdlib cimport calloc, free
import random
from typing import Optional
import contextlib
import srsly
from thinc.api import set_dropout_rate, CupyOps, get_array_module
@ -191,27 +192,23 @@ class Parser(TrainablePipe):
result = self.moves.init_batch(docs)
self._resize()
return result
if self.cfg["beam_width"] == 1:
return self.greedy_parse(docs, drop=0.0)
else:
return self.beam_parse(
docs,
drop=0.0,
beam_width=self.cfg["beam_width"],
beam_density=self.cfg["beam_density"]
)
with _change_attrs(self.model, beam_width=self.cfg["beam_width"], beam_density=self.cfg["beam_density"]):
states_or_beams, _ = self.model.predict((docs, self.moves))
return states_or_beams
def greedy_parse(self, docs, drop=0.):
    """Parse a batch of docs with a beam width of 1 and return the states.

    Deprecated: this is a thin wrapper that runs the model's ``predict``
    with the ``beam_width`` attribute temporarily forced to 1.

    docs: The batch of documents to parse.
    drop (float): Dropout rate set on the model before predicting.
    RETURNS: The first element of the model's prediction (the states).
    """
    set_dropout_rate(self.model, drop)
    # The NER can resize itself in init_batch if labels are missing, so we
    # have to check whether the model output needs expanding first.
    self._resize()
    # Predict exactly once, with the width forced to 1; the second element
    # of the prediction is not needed here.
    with _change_attrs(self.model, beam_width=1):
        states, _ = self.model.predict((docs, self.moves))
    return states
def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
raise NotImplementedError
# Deprecated
self._resize()
with _change_attrs(self.model, beam_width=self.cfg["beam_width"], beam_density=self.cfg["beam_density"]):
beams, _ = self.model.predict((docs, self.moves))
return beams
def set_annotations(self, docs, states_or_beams):
cdef StateClass state
@ -410,3 +407,19 @@ class Parser(TrainablePipe):
except AttributeError:
raise ValueError(Errors.E149) from None
return self
@contextlib.contextmanager
def _change_attrs(model, **kwargs):
    """Temporarily modify a thinc model's attributes.

    Each keyword argument is written into ``model.attrs`` for the duration
    of the ``with`` block. On exit, including when the block raises, every
    touched key is put back to its previous value, or removed if it was not
    present beforehand.

    model: An object exposing a dict-like ``attrs`` mapping.
    **kwargs: Attribute names and the temporary values to assign.
    YIELDS: The same ``model`` object, with the temporary attributes set.
    """
    # Sentinel marking keys that did not exist before, so that a stored
    # value of None is still restored faithfully.
    unset = object()
    old_attrs = {}
    try:
        for key, value in kwargs.items():
            old_attrs[key] = model.attrs.get(key, unset)
            model.attrs[key] = value
        yield model
    finally:
        # Runs on normal exit and when the with-body raises, so temporary
        # settings never leak into later calls.
        for key, value in old_attrs.items():
            if value is unset:
                # Default avoids a KeyError if the body already removed it.
                model.attrs.pop(key, None)
            else:
                model.attrs[key] = value