mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-05 04:40:20 +03:00
Start rigging beam back up
This commit is contained in:
parent
07603a26ae
commit
394862b0f4
|
@ -7,6 +7,7 @@ from libcpp.vector cimport vector
|
||||||
from libc.string cimport memset, memcpy
|
from libc.string cimport memset, memcpy
|
||||||
from libc.stdlib cimport calloc, free
|
from libc.stdlib cimport calloc, free
|
||||||
import random
|
import random
|
||||||
|
import contextlib
|
||||||
|
|
||||||
import srsly
|
import srsly
|
||||||
from thinc.api import set_dropout_rate, CupyOps, get_array_module
|
from thinc.api import set_dropout_rate, CupyOps, get_array_module
|
||||||
|
@ -210,14 +211,21 @@ class Parser(TrainablePipe):
|
||||||
with self.model.use_params(params):
|
with self.model.use_params(params):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
def __call__(self, Doc doc):
|
||||||
|
"""Apply the parser or entity recognizer, setting the annotations onto
|
||||||
|
the `Doc` object.
|
||||||
|
|
||||||
|
doc (Doc): The document to be processed.
|
||||||
|
"""
|
||||||
|
states = self.predict([doc])
|
||||||
|
self.set_annotations([doc], states)
|
||||||
|
return doc
|
||||||
|
|
||||||
def pipe(self, docs, *, int batch_size=256):
|
def pipe(self, docs, *, int batch_size=256):
|
||||||
"""Process a stream of documents.
|
"""Process a stream of documents.
|
||||||
|
|
||||||
stream: The sequence of documents to process.
|
stream: The sequence of documents to process.
|
||||||
batch_size (int): Number of documents to accumulate into a working set.
|
batch_size (int): Number of documents to accumulate into a working set.
|
||||||
error_handler (Callable[[str, List[Doc], Exception], Any]): Function that
|
|
||||||
deals with a failing batch of documents. The default function just reraises
|
|
||||||
the exception.
|
|
||||||
|
|
||||||
YIELDS (Doc): Documents, in order.
|
YIELDS (Doc): Documents, in order.
|
||||||
"""
|
"""
|
||||||
|
@ -242,27 +250,23 @@ class Parser(TrainablePipe):
|
||||||
if not any(len(doc) for doc in docs):
|
if not any(len(doc) for doc in docs):
|
||||||
result = self.moves.init_batch(docs)
|
result = self.moves.init_batch(docs)
|
||||||
return result
|
return result
|
||||||
if self.cfg["beam_width"] == 1:
|
with _change_attrs(self.model, beam_width=self.cfg["beam_width"], beam_density=self.cfg["beam_density"]):
|
||||||
return self.greedy_parse(docs, drop=0.0)
|
states_or_beams, _ = self.model.predict((docs, self.moves))
|
||||||
else:
|
return states_or_beams
|
||||||
return self.beam_parse(
|
|
||||||
docs,
|
|
||||||
drop=0.0,
|
|
||||||
beam_width=self.cfg["beam_width"],
|
|
||||||
beam_density=self.cfg["beam_density"]
|
|
||||||
)
|
|
||||||
|
|
||||||
def greedy_parse(self, docs, drop=0.):
|
def greedy_parse(self, docs, drop=0.):
|
||||||
set_dropout_rate(self.model, drop)
|
# Deprecated
|
||||||
# This is pretty dirty, but the NER can resize itself in init_batch,
|
|
||||||
# if labels are missing. We therefore have to check whether we need to
|
|
||||||
# expand our model output.
|
|
||||||
self._resize()
|
self._resize()
|
||||||
states, scores = self.model.predict((docs, self.moves))
|
with _change_attrs(self.model, beam_width=1):
|
||||||
|
states, _ = self.model.predict((docs, self.moves))
|
||||||
return states
|
return states
|
||||||
|
|
||||||
def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
|
def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
|
||||||
raise NotImplementedError
|
# Deprecated
|
||||||
|
self._resize()
|
||||||
|
with _change_attrs(self.model, beam_width=self.cfg["beam_width"], beam_density=self.cfg["beam_density"]):
|
||||||
|
beams, _ = self.model.predict((docs, self.moves))
|
||||||
|
return beams
|
||||||
|
|
||||||
def set_annotations(self, docs, states_or_beams):
|
def set_annotations(self, docs, states_or_beams):
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
|
@ -461,3 +465,19 @@ class Parser(TrainablePipe):
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise ValueError(Errors.E149) from None
|
raise ValueError(Errors.E149) from None
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def _change_attrs(model, **kwargs):
|
||||||
|
"""Temporarily modify a thinc model's attributes."""
|
||||||
|
unset = object()
|
||||||
|
old_attrs = {}
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
old_attrs[key] = model.attrs.get(key, unset)
|
||||||
|
model.attrs[key] = value
|
||||||
|
yield model
|
||||||
|
for key, value in old_attrs.items():
|
||||||
|
if value is unset:
|
||||||
|
model.attrs.pop(key)
|
||||||
|
else:
|
||||||
|
model.attrs[key] = value
|
||||||
|
|
Loading…
Reference in New Issue
Block a user