Start rigging beam back up

Matthew Honnibal 2021-11-01 12:39:16 +01:00
parent 13b0a24870
commit 3258efedfe


@@ -8,6 +8,7 @@ from libc.string cimport memset, memcpy
 from libc.stdlib cimport calloc, free
 import random
 from typing import Optional
+import contextlib
 import srsly
 from thinc.api import set_dropout_rate, CupyOps, get_array_module
@@ -191,27 +192,23 @@ class Parser(TrainablePipe):
             result = self.moves.init_batch(docs)
             self._resize()
             return result
-        if self.cfg["beam_width"] == 1:
-            return self.greedy_parse(docs, drop=0.0)
-        else:
-            return self.beam_parse(
-                docs,
-                drop=0.0,
-                beam_width=self.cfg["beam_width"],
-                beam_density=self.cfg["beam_density"]
-            )
+        with _change_attrs(self.model, beam_width=self.cfg["beam_width"], beam_density=self.cfg["beam_density"]):
+            states_or_beams, _ = self.model.predict((docs, self.moves))
+        return states_or_beams

     def greedy_parse(self, docs, drop=0.):
-        set_dropout_rate(self.model, drop)
-        # This is pretty dirty, but the NER can resize itself in init_batch,
-        # if labels are missing. We therefore have to check whether we need to
-        # expand our model output.
+        # Deprecated
         self._resize()
-        states, scores = self.model.predict((docs, self.moves))
+        with _change_attrs(self.model, beam_width=1):
+            states, _ = self.model.predict((docs, self.moves))
         return states

     def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
-        raise NotImplementedError
+        # Deprecated
+        self._resize()
+        with _change_attrs(self.model, beam_width=self.cfg["beam_width"], beam_density=self.cfg["beam_density"]):
+            beams, _ = self.model.predict((docs, self.moves))
+        return beams

     def set_annotations(self, docs, states_or_beams):
         cdef StateClass state
@@ -410,3 +407,19 @@ class Parser(TrainablePipe):
         except AttributeError:
             raise ValueError(Errors.E149) from None
         return self
+
+
+@contextlib.contextmanager
+def _change_attrs(model, **kwargs):
+    """Temporarily modify a thinc model's attributes."""
+    unset = object()
+    old_attrs = {}
+    for key, value in kwargs.items():
+        old_attrs[key] = model.attrs.get(key, unset)
+        model.attrs[key] = value
+    yield model
+    for key, value in old_attrs.items():
+        if value is unset:
+            model.attrs.pop(key)
+        else:
+            model.attrs[key] = value
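
Usage note (not part of the commit): the sketch below shows how the _change_attrs helper added above behaves on its own, assuming it is in scope and using a throwaway thinc Model with an illustrative no-op forward function (the name "demo" and noop_forward are not from the commit). A pre-existing attr is overridden only inside the with block and restored afterwards, while an attr that did not exist before is removed again on exit.

    from thinc.api import Model

    def noop_forward(model, X, is_train):
        # Illustrative forward pass: returns its input unchanged.
        return X, lambda dY: dY

    # Throwaway model with one pre-existing attr.
    model = Model("demo", noop_forward, attrs={"beam_width": 1})

    with _change_attrs(model, beam_width=8, beam_density=0.01):
        # Inside the block both overrides are visible on the model.
        assert model.attrs["beam_width"] == 8
        assert model.attrs["beam_density"] == 0.01

    # Afterwards the original value is back and the added key is gone.
    assert model.attrs["beam_width"] == 1
    assert "beam_density" not in model.attrs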