mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 09:56:28 +03:00
Add beam decoding to parser, to allow NER uncertainties
This commit is contained in:
parent
0ca5832427
commit
3da1063b36
|
@ -10,6 +10,8 @@ from libc.stdint cimport uint32_t
|
|||
from libc.string cimport memcpy
|
||||
from cymem.cymem cimport Pool
|
||||
from collections import OrderedDict
|
||||
from thinc.extra.search cimport Beam
|
||||
import numpy
|
||||
|
||||
from .stateclass cimport StateClass
|
||||
from ._state cimport StateC, is_space_token
|
||||
|
@ -510,3 +512,23 @@ cdef class ArcEager(TransitionSystem):
|
|||
"State at failure:\n"
|
||||
"%s" % (self.n_moves, stcls.print_state(gold.words)))
|
||||
assert n_gold >= 1
|
||||
|
||||
def get_beam_annot(self, Beam beam):
|
||||
length = (<StateClass>beam.at(0)).c.length
|
||||
heads = [{} for _ in range(length)]
|
||||
deps = [{} for _ in range(length)]
|
||||
probs = beam.probs
|
||||
for i in range(beam.size):
|
||||
stcls = <StateClass>beam.at(i)
|
||||
self.finalize_state(stcls.c)
|
||||
if stcls.is_final():
|
||||
prob = probs[i]
|
||||
for j in range(stcls.c.length):
|
||||
head = j + stcls.c._sent[j].head
|
||||
dep = stcls.c._sent[j].dep
|
||||
heads[j].setdefault(head, 0.0)
|
||||
heads[j][head] += prob
|
||||
deps[j].setdefault(dep, 0.0)
|
||||
deps[j][dep] += prob
|
||||
return heads, deps
|
||||
|
||||
|
|
|
@ -2,7 +2,10 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
from thinc.typedefs cimport weight_t
|
||||
from thinc.extra.search cimport Beam
|
||||
from collections import OrderedDict
|
||||
import numpy
|
||||
from thinc.neural.ops import NumpyOps
|
||||
|
||||
from .stateclass cimport StateClass
|
||||
from ._state cimport StateC
|
||||
|
@ -122,6 +125,22 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
gold.c.ner[i] = self.lookup_transition(gold.ner[i])
|
||||
return gold
|
||||
|
||||
def get_beam_annot(self, Beam beam):
|
||||
entities = {}
|
||||
probs = beam.probs
|
||||
for i in range(beam.size):
|
||||
stcls = <StateClass>beam.at(i)
|
||||
if stcls.is_final():
|
||||
self.finalize_state(stcls.c)
|
||||
prob = probs[i]
|
||||
for j in range(stcls.c._e_i):
|
||||
start = stcls.c._ents[j].start
|
||||
end = stcls.c._ents[j].end
|
||||
label = stcls.c._ents[j].label
|
||||
entities.setdefault((start, end, label), 0.0)
|
||||
entities[(start, end, label)] += prob
|
||||
return entities
|
||||
|
||||
cdef Transition lookup_transition(self, object name) except *:
|
||||
cdef attr_t label
|
||||
if name == '-' or name == None:
|
||||
|
|
|
@ -29,6 +29,7 @@ from thinc.linear.avgtron cimport AveragedPerceptron
|
|||
from thinc.linalg cimport VecVec
|
||||
from thinc.structs cimport SparseArrayC, FeatureC, ExampleC
|
||||
from thinc.extra.eg cimport Example
|
||||
from thinc.extra.search cimport Beam
|
||||
|
||||
from cymem.cymem cimport Pool, Address
|
||||
from murmurhash.mrmr cimport hash64
|
||||
|
@ -110,7 +111,6 @@ cdef class precompute_hiddens:
|
|||
self.nO = cached.shape[2]
|
||||
self.nP = getattr(lower_model, 'nP', 1)
|
||||
self.ops = lower_model.ops
|
||||
self._features = numpy.zeros((batch_size, self.nO*self.nP), dtype='f')
|
||||
self._is_synchronized = False
|
||||
self._cuda_stream = cuda_stream
|
||||
self._cached = cached
|
||||
|
@ -127,13 +127,12 @@ cdef class precompute_hiddens:
|
|||
return self.begin_update(X)[0]
|
||||
|
||||
def begin_update(self, token_ids, drop=0.):
|
||||
self._features.fill(0)
|
||||
cdef np.ndarray state_vector = numpy.zeros((token_ids.shape[0], self.nO*self.nP), dtype='f')
|
||||
# This is tricky, but (assuming GPU available);
|
||||
# - Input to forward on CPU
|
||||
# - Output from forward on CPU
|
||||
# - Input to backward on GPU!
|
||||
# - Output from backward on GPU
|
||||
cdef np.ndarray state_vector = self._features[:len(token_ids)]
|
||||
bp_hiddens = self._bp_hiddens
|
||||
|
||||
feat_weights = self.get_feat_weights()
|
||||
|
@ -305,7 +304,7 @@ cdef class Parser:
|
|||
def __reduce__(self):
|
||||
return (Parser, (self.vocab, self.moves, self.model), None, None)
|
||||
|
||||
def __call__(self, Doc doc):
|
||||
def __call__(self, Doc doc, beam_width=None, beam_density=None):
|
||||
"""
|
||||
Apply the parser or entity recognizer, setting the annotations onto the Doc object.
|
||||
|
||||
|
@ -314,11 +313,26 @@ cdef class Parser:
|
|||
Returns:
|
||||
None
|
||||
"""
|
||||
if beam_width is None:
|
||||
beam_width = self.cfg.get('beam_width', 1)
|
||||
if beam_density is None:
|
||||
beam_density = self.cfg.get('beam_density', 0.001)
|
||||
cdef Beam beam
|
||||
if beam_width == 1:
|
||||
states = self.parse_batch([doc], [doc.tensor])
|
||||
self.set_annotations([doc], states)
|
||||
return doc
|
||||
else:
|
||||
beam = self.beam_parse([doc], [doc.tensor],
|
||||
beam_width=beam_width, beam_density=beam_density)[0]
|
||||
output = self.moves.get_beam_annot(beam)
|
||||
state = <StateClass>beam.at(0)
|
||||
self.set_annotations([doc], [state])
|
||||
_cleanup(beam)
|
||||
return output
|
||||
|
||||
def pipe(self, docs, int batch_size=1000, int n_threads=2):
|
||||
def pipe(self, docs, int batch_size=1000, int n_threads=2,
|
||||
beam_width=1, beam_density=0.001):
|
||||
"""
|
||||
Process a stream of documents.
|
||||
|
||||
|
@ -336,7 +350,11 @@ cdef class Parser:
|
|||
for docs in cytoolz.partition_all(batch_size, docs):
|
||||
docs = list(docs)
|
||||
tokvecs = [d.tensor for d in docs]
|
||||
if beam_width == 1:
|
||||
parse_states = self.parse_batch(docs, tokvecs)
|
||||
else:
|
||||
parse_states = self.beam_parse(docs, tokvecs,
|
||||
beam_width=beam_width, beam_density=beam_density)
|
||||
self.set_annotations(docs, parse_states)
|
||||
yield from docs
|
||||
|
||||
|
@ -404,6 +422,45 @@ cdef class Parser:
|
|||
next_step.push_back(st)
|
||||
return states
|
||||
|
||||
def beam_parse(self, docs, tokvecses, int beam_width=8, float beam_density=0.001):
|
||||
cdef Beam beam
|
||||
cdef np.ndarray scores
|
||||
cdef Doc doc
|
||||
cdef int nr_class = self.moves.n_moves
|
||||
cdef StateClass stcls, output
|
||||
tokvecs = self.model[0].ops.flatten(tokvecses)
|
||||
cuda_stream = get_cuda_stream()
|
||||
state2vec, vec2scores = self.get_batch_model(len(docs), tokvecs,
|
||||
cuda_stream, 0.0)
|
||||
beams = []
|
||||
cdef int offset = 0
|
||||
for doc in docs:
|
||||
beam = Beam(nr_class, beam_width, min_density=beam_density)
|
||||
beam.initialize(self.moves.init_beam_state, doc.length, doc.c)
|
||||
for i in range(beam.width):
|
||||
stcls = <StateClass>beam.at(i)
|
||||
stcls.c.offset = offset
|
||||
offset += len(doc)
|
||||
beam.check_done(_check_final_state, NULL)
|
||||
while not beam.is_done:
|
||||
states = []
|
||||
for i in range(beam.size):
|
||||
stcls = <StateClass>beam.at(i)
|
||||
states.append(stcls)
|
||||
token_ids = self.get_token_ids(states)
|
||||
vectors = state2vec(token_ids)
|
||||
scores = vec2scores(vectors)
|
||||
for i in range(beam.size):
|
||||
stcls = <StateClass>beam.at(i)
|
||||
if not stcls.is_final():
|
||||
self.moves.set_valid(beam.is_valid[i], stcls.c)
|
||||
for j in range(nr_class):
|
||||
beam.scores[i][j] = scores[i, j]
|
||||
beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
|
||||
beam.check_done(_check_final_state, NULL)
|
||||
beams.append(beam)
|
||||
return beams
|
||||
|
||||
cdef void _parse_step(self, StateC* state,
|
||||
const float* feat_weights,
|
||||
int nr_class, int nr_feat, int nr_piece) nogil:
|
||||
|
@ -560,6 +617,7 @@ cdef class Parser:
|
|||
dtype='i', order='C')
|
||||
c_ids = <int*>ids.data
|
||||
for i, state in enumerate(states):
|
||||
if not state.is_final():
|
||||
state.c.set_context_tokens(c_ids, n_tokens)
|
||||
c_ids += ids.shape[1]
|
||||
return ids
|
||||
|
@ -762,3 +820,30 @@ cdef int _arg_max_clas(const weight_t* scores, int move, const Transition* actio
|
|||
mode = i
|
||||
score = scores[i]
|
||||
return mode
|
||||
|
||||
|
||||
# These are passed as callbacks to thinc.search.Beam
|
||||
cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
|
||||
dest = <StateClass>_dest
|
||||
src = <StateClass>_src
|
||||
moves = <const Transition*>_moves
|
||||
dest.clone(src)
|
||||
moves[clas].do(dest.c, moves[clas].label)
|
||||
|
||||
|
||||
cdef int _check_final_state(void* _state, void* extra_args) except -1:
|
||||
return (<StateClass>_state).is_final()
|
||||
|
||||
|
||||
def _cleanup(Beam beam):
|
||||
for i in range(beam.width):
|
||||
Py_XDECREF(<PyObject*>beam._states[i].content)
|
||||
Py_XDECREF(<PyObject*>beam._parents[i].content)
|
||||
|
||||
|
||||
cdef hash_t _hash_state(void* _state, void* _) except 0:
|
||||
state = <StateClass>_state
|
||||
if state.c.is_final():
|
||||
return 1
|
||||
else:
|
||||
return state.c.hash()
|
||||
|
|
|
@ -137,6 +137,10 @@ cdef class TransitionSystem:
|
|||
"the entity recognizer\n"
|
||||
"The transition system has %d actions." % (self.n_moves))
|
||||
|
||||
def get_class_name(self, int clas):
|
||||
act = self.c[clas]
|
||||
return self.move_name(act.move, act.label)
|
||||
|
||||
def add_action(self, int action, label_name):
|
||||
cdef attr_t label_id
|
||||
if not isinstance(label_name, int):
|
||||
|
|
Loading…
Reference in New Issue
Block a user