Start rigging beam back up

This commit is contained in:
Matthew Honnibal 2021-11-01 12:39:16 +01:00 committed by svlandeg
parent 07603a26ae
commit 394862b0f4

View File

@ -7,6 +7,7 @@ from libcpp.vector cimport vector
from libc.string cimport memset, memcpy from libc.string cimport memset, memcpy
from libc.stdlib cimport calloc, free from libc.stdlib cimport calloc, free
import random import random
import contextlib
import srsly import srsly
from thinc.api import set_dropout_rate, CupyOps, get_array_module from thinc.api import set_dropout_rate, CupyOps, get_array_module
@ -210,14 +211,21 @@ class Parser(TrainablePipe):
with self.model.use_params(params): with self.model.use_params(params):
yield yield
def __call__(self, Doc doc):
"""Apply the parser or entity recognizer, setting the annotations onto
the `Doc` object.
doc (Doc): The document to be processed.
"""
states = self.predict([doc])
self.set_annotations([doc], states)
return doc
def pipe(self, docs, *, int batch_size=256): def pipe(self, docs, *, int batch_size=256):
"""Process a stream of documents. """Process a stream of documents.
stream: The sequence of documents to process. stream: The sequence of documents to process.
batch_size (int): Number of documents to accumulate into a working set. batch_size (int): Number of documents to accumulate into a working set.
error_handler (Callable[[str, List[Doc], Exception], Any]): Function that
deals with a failing batch of documents. The default function just reraises
the exception.
YIELDS (Doc): Documents, in order. YIELDS (Doc): Documents, in order.
""" """
@ -242,27 +250,23 @@ class Parser(TrainablePipe):
if not any(len(doc) for doc in docs): if not any(len(doc) for doc in docs):
result = self.moves.init_batch(docs) result = self.moves.init_batch(docs)
return result return result
if self.cfg["beam_width"] == 1: with _change_attrs(self.model, beam_width=self.cfg["beam_width"], beam_density=self.cfg["beam_density"]):
return self.greedy_parse(docs, drop=0.0) states_or_beams, _ = self.model.predict((docs, self.moves))
else: return states_or_beams
return self.beam_parse(
docs,
drop=0.0,
beam_width=self.cfg["beam_width"],
beam_density=self.cfg["beam_density"]
)
def greedy_parse(self, docs, drop=0.): def greedy_parse(self, docs, drop=0.):
set_dropout_rate(self.model, drop) # Deprecated
# This is pretty dirty, but the NER can resize itself in init_batch,
# if labels are missing. We therefore have to check whether we need to
# expand our model output.
self._resize() self._resize()
states, scores = self.model.predict((docs, self.moves)) with _change_attrs(self.model, beam_width=1):
states, _ = self.model.predict((docs, self.moves))
return states return states
def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.): def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
raise NotImplementedError # Deprecated
self._resize()
with _change_attrs(self.model, beam_width=self.cfg["beam_width"], beam_density=self.cfg["beam_density"]):
beams, _ = self.model.predict((docs, self.moves))
return beams
def set_annotations(self, docs, states_or_beams): def set_annotations(self, docs, states_or_beams):
cdef StateClass state cdef StateClass state
@ -461,3 +465,19 @@ class Parser(TrainablePipe):
except AttributeError: except AttributeError:
raise ValueError(Errors.E149) from None raise ValueError(Errors.E149) from None
return self return self
@contextlib.contextmanager
def _change_attrs(model, **kwargs):
"""Temporarily modify a thinc model's attributes."""
unset = object()
old_attrs = {}
for key, value in kwargs.items():
old_attrs[key] = model.attrs.get(key, unset)
model.attrs[key] = value
yield model
for key, value in old_attrs.items():
if value is unset:
model.attrs.pop(key)
else:
model.attrs[key] = value