mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
Start rigging beam back up
This commit is contained in:
parent
13b0a24870
commit
3258efedfe
|
@ -8,6 +8,7 @@ from libc.string cimport memset, memcpy
|
|||
from libc.stdlib cimport calloc, free
|
||||
import random
|
||||
from typing import Optional
|
||||
import contextlib
|
||||
|
||||
import srsly
|
||||
from thinc.api import set_dropout_rate, CupyOps, get_array_module
|
||||
|
@ -191,27 +192,23 @@ class Parser(TrainablePipe):
|
|||
result = self.moves.init_batch(docs)
|
||||
self._resize()
|
||||
return result
|
||||
if self.cfg["beam_width"] == 1:
|
||||
return self.greedy_parse(docs, drop=0.0)
|
||||
else:
|
||||
return self.beam_parse(
|
||||
docs,
|
||||
drop=0.0,
|
||||
beam_width=self.cfg["beam_width"],
|
||||
beam_density=self.cfg["beam_density"]
|
||||
)
|
||||
with _change_attrs(self.model, beam_width=self.cfg["beam_width"], beam_density=self.cfg["beam_density"]):
|
||||
states_or_beams, _ = self.model.predict((docs, self.moves))
|
||||
return states_or_beams
|
||||
|
||||
def greedy_parse(self, docs, drop=0.):
|
||||
set_dropout_rate(self.model, drop)
|
||||
# This is pretty dirty, but the NER can resize itself in init_batch,
|
||||
# if labels are missing. We therefore have to check whether we need to
|
||||
# expand our model output.
|
||||
# Deprecated
|
||||
self._resize()
|
||||
states, scores = self.model.predict((docs, self.moves))
|
||||
with _change_attrs(self.model, beam_width=1):
|
||||
states, _ = self.model.predict((docs, self.moves))
|
||||
return states
|
||||
|
||||
def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
|
||||
raise NotImplementedError
|
||||
# Deprecated
|
||||
self._resize()
|
||||
with _change_attrs(self.model, beam_width=self.cfg["beam_width"], beam_density=self.cfg["beam_density"]):
|
||||
beams, _ = self.model.predict((docs, self.moves))
|
||||
return beams
|
||||
|
||||
def set_annotations(self, docs, states_or_beams):
|
||||
cdef StateClass state
|
||||
|
@ -410,3 +407,19 @@ class Parser(TrainablePipe):
|
|||
except AttributeError:
|
||||
raise ValueError(Errors.E149) from None
|
||||
return self
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _change_attrs(model, **kwargs):
|
||||
"""Temporarily modify a thinc model's attributes."""
|
||||
unset = object()
|
||||
old_attrs = {}
|
||||
for key, value in kwargs.items():
|
||||
old_attrs[key] = model.attrs.get(key, unset)
|
||||
model.attrs[key] = value
|
||||
yield model
|
||||
for key, value in old_attrs.items():
|
||||
if value is unset:
|
||||
model.attrs.pop(key)
|
||||
else:
|
||||
model.attrs[key] = value
|
||||
|
|
Loading…
Reference in New Issue
Block a user