mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-05 04:40:20 +03:00
Start rigging beam back up
This commit is contained in:
parent
07603a26ae
commit
394862b0f4
|
@ -7,6 +7,7 @@ from libcpp.vector cimport vector
|
|||
from libc.string cimport memset, memcpy
|
||||
from libc.stdlib cimport calloc, free
|
||||
import random
|
||||
import contextlib
|
||||
|
||||
import srsly
|
||||
from thinc.api import set_dropout_rate, CupyOps, get_array_module
|
||||
|
@ -210,14 +211,21 @@ class Parser(TrainablePipe):
|
|||
with self.model.use_params(params):
|
||||
yield
|
||||
|
||||
def __call__(self, Doc doc):
|
||||
"""Apply the parser or entity recognizer, setting the annotations onto
|
||||
the `Doc` object.
|
||||
|
||||
doc (Doc): The document to be processed.
|
||||
"""
|
||||
states = self.predict([doc])
|
||||
self.set_annotations([doc], states)
|
||||
return doc
|
||||
|
||||
def pipe(self, docs, *, int batch_size=256):
|
||||
"""Process a stream of documents.
|
||||
|
||||
stream: The sequence of documents to process.
|
||||
batch_size (int): Number of documents to accumulate into a working set.
|
||||
error_handler (Callable[[str, List[Doc], Exception], Any]): Function that
|
||||
deals with a failing batch of documents. The default function just reraises
|
||||
the exception.
|
||||
|
||||
YIELDS (Doc): Documents, in order.
|
||||
"""
|
||||
|
@ -242,27 +250,23 @@ class Parser(TrainablePipe):
|
|||
if not any(len(doc) for doc in docs):
|
||||
result = self.moves.init_batch(docs)
|
||||
return result
|
||||
if self.cfg["beam_width"] == 1:
|
||||
return self.greedy_parse(docs, drop=0.0)
|
||||
else:
|
||||
return self.beam_parse(
|
||||
docs,
|
||||
drop=0.0,
|
||||
beam_width=self.cfg["beam_width"],
|
||||
beam_density=self.cfg["beam_density"]
|
||||
)
|
||||
with _change_attrs(self.model, beam_width=self.cfg["beam_width"], beam_density=self.cfg["beam_density"]):
|
||||
states_or_beams, _ = self.model.predict((docs, self.moves))
|
||||
return states_or_beams
|
||||
|
||||
def greedy_parse(self, docs, drop=0.):
|
||||
set_dropout_rate(self.model, drop)
|
||||
# This is pretty dirty, but the NER can resize itself in init_batch,
|
||||
# if labels are missing. We therefore have to check whether we need to
|
||||
# expand our model output.
|
||||
# Deprecated
|
||||
self._resize()
|
||||
states, scores = self.model.predict((docs, self.moves))
|
||||
with _change_attrs(self.model, beam_width=1):
|
||||
states, _ = self.model.predict((docs, self.moves))
|
||||
return states
|
||||
|
||||
def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
|
||||
raise NotImplementedError
|
||||
# Deprecated
|
||||
self._resize()
|
||||
with _change_attrs(self.model, beam_width=self.cfg["beam_width"], beam_density=self.cfg["beam_density"]):
|
||||
beams, _ = self.model.predict((docs, self.moves))
|
||||
return beams
|
||||
|
||||
def set_annotations(self, docs, states_or_beams):
|
||||
cdef StateClass state
|
||||
|
@ -461,3 +465,19 @@ class Parser(TrainablePipe):
|
|||
except AttributeError:
|
||||
raise ValueError(Errors.E149) from None
|
||||
return self
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _change_attrs(model, **kwargs):
|
||||
"""Temporarily modify a thinc model's attributes."""
|
||||
unset = object()
|
||||
old_attrs = {}
|
||||
for key, value in kwargs.items():
|
||||
old_attrs[key] = model.attrs.get(key, unset)
|
||||
model.attrs[key] = value
|
||||
yield model
|
||||
for key, value in old_attrs.items():
|
||||
if value is unset:
|
||||
model.attrs.pop(key)
|
||||
else:
|
||||
model.attrs[key] = value
|
||||
|
|
Loading…
Reference in New Issue
Block a user