mirror of
https://github.com/explosion/spaCy.git
synced 2025-03-12 07:15:48 +03:00
Remove beam for now (maybe)
Remove beam_utils Update setup.py Remove beam
This commit is contained in:
parent
98ca14f577
commit
3c0fc10dc4
1
setup.py
1
setup.py
|
@ -39,7 +39,6 @@ MOD_NAMES = [
|
||||||
"spacy.tokenizer",
|
"spacy.tokenizer",
|
||||||
"spacy.syntax.nn_parser",
|
"spacy.syntax.nn_parser",
|
||||||
"spacy.syntax._parser_model",
|
"spacy.syntax._parser_model",
|
||||||
"spacy.syntax._beam_utils",
|
|
||||||
"spacy.syntax.nonproj",
|
"spacy.syntax.nonproj",
|
||||||
"spacy.syntax.transition_system",
|
"spacy.syntax.transition_system",
|
||||||
"spacy.syntax.arc_eager",
|
"spacy.syntax.arc_eager",
|
||||||
|
|
|
@ -1,9 +0,0 @@
|
||||||
from ..typedefs cimport hash_t, class_t
|
|
||||||
|
|
||||||
# These are passed as callbacks to thinc.search.Beam
|
|
||||||
cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1
|
|
||||||
|
|
||||||
cdef int check_final_state(void* _state, void* extra_args) except -1
|
|
||||||
|
|
||||||
|
|
||||||
cdef hash_t hash_state(void* _state, void* _) except 0
|
|
|
@ -1,329 +0,0 @@
|
||||||
# cython: infer_types=True, profile=True
|
|
||||||
cimport numpy as np
|
|
||||||
from cpython.ref cimport PyObject, Py_XDECREF
|
|
||||||
from thinc.extra.search cimport Beam
|
|
||||||
from thinc.extra.search cimport MaxViolation
|
|
||||||
|
|
||||||
from thinc.extra.search import MaxViolation
|
|
||||||
import numpy
|
|
||||||
|
|
||||||
from ..typedefs cimport hash_t, class_t
|
|
||||||
from .transition_system cimport TransitionSystem, Transition
|
|
||||||
from .stateclass cimport StateC, StateClass
|
|
||||||
from ..gold.example cimport Example
|
|
||||||
|
|
||||||
from ..errors import Errors
|
|
||||||
|
|
||||||
|
|
||||||
# These are passed as callbacks to thinc.search.Beam
|
|
||||||
cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
|
|
||||||
dest = <StateC*>_dest
|
|
||||||
src = <StateC*>_src
|
|
||||||
moves = <const Transition*>_moves
|
|
||||||
dest.clone(src)
|
|
||||||
moves[clas].do(dest, moves[clas].label)
|
|
||||||
dest.push_hist(clas)
|
|
||||||
|
|
||||||
|
|
||||||
cdef int check_final_state(void* _state, void* extra_args) except -1:
|
|
||||||
state = <StateC*>_state
|
|
||||||
return state.is_final()
|
|
||||||
|
|
||||||
|
|
||||||
cdef hash_t hash_state(void* _state, void* _) except 0:
|
|
||||||
state = <StateC*>_state
|
|
||||||
if state.is_final():
|
|
||||||
return 1
|
|
||||||
else:
|
|
||||||
return state.hash()
|
|
||||||
|
|
||||||
|
|
||||||
def collect_states(beams):
|
|
||||||
cdef StateClass state
|
|
||||||
cdef Beam beam
|
|
||||||
states = []
|
|
||||||
for state_or_beam in beams:
|
|
||||||
if isinstance(state_or_beam, StateClass):
|
|
||||||
states.append(state_or_beam)
|
|
||||||
else:
|
|
||||||
beam = state_or_beam
|
|
||||||
state = StateClass.borrow(<StateC*>beam.at(0))
|
|
||||||
states.append(state)
|
|
||||||
return states
|
|
||||||
|
|
||||||
|
|
||||||
cdef class ParserBeam(object):
|
|
||||||
cdef public TransitionSystem moves
|
|
||||||
cdef public object states
|
|
||||||
cdef public object golds
|
|
||||||
cdef public object beams
|
|
||||||
cdef public object dones
|
|
||||||
|
|
||||||
def __init__(self, TransitionSystem moves, states, golds,
|
|
||||||
int width, float density=0.):
|
|
||||||
self.moves = moves
|
|
||||||
self.states = states
|
|
||||||
self.golds = golds
|
|
||||||
self.beams = []
|
|
||||||
cdef Beam beam
|
|
||||||
cdef StateClass state
|
|
||||||
cdef StateC* st
|
|
||||||
for state in states:
|
|
||||||
beam = Beam(self.moves.n_moves, width, min_density=density)
|
|
||||||
beam.initialize(self.moves.init_beam_state,
|
|
||||||
self.moves.del_beam_state, state.c.length,
|
|
||||||
state.c._sent)
|
|
||||||
for i in range(beam.width):
|
|
||||||
st = <StateC*>beam.at(i)
|
|
||||||
st.offset = state.c.offset
|
|
||||||
self.beams.append(beam)
|
|
||||||
self.dones = [False] * len(self.beams)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_done(self):
|
|
||||||
return all(b.is_done or self.dones[i]
|
|
||||||
for i, b in enumerate(self.beams))
|
|
||||||
|
|
||||||
def __getitem__(self, i):
|
|
||||||
return self.beams[i]
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.beams)
|
|
||||||
|
|
||||||
def advance(self, scores, follow_gold=False):
|
|
||||||
cdef Beam beam
|
|
||||||
for i, beam in enumerate(self.beams):
|
|
||||||
if beam.is_done or not scores[i].size or self.dones[i]:
|
|
||||||
continue
|
|
||||||
self._set_scores(beam, scores[i])
|
|
||||||
if self.golds is not None:
|
|
||||||
self._set_costs(beam, self.golds[i], follow_gold=follow_gold)
|
|
||||||
beam.advance(transition_state, hash_state, <void*>self.moves.c)
|
|
||||||
beam.check_done(check_final_state, NULL)
|
|
||||||
# This handles the non-monotonic stuff for the parser.
|
|
||||||
if beam.is_done and self.golds is not None:
|
|
||||||
for j in range(beam.size):
|
|
||||||
state = StateClass.borrow(<StateC*>beam.at(j))
|
|
||||||
if state.is_final():
|
|
||||||
try:
|
|
||||||
if self.moves.is_gold_parse(state, self.golds[i]):
|
|
||||||
beam._states[j].loss = 0.0
|
|
||||||
except NotImplementedError:
|
|
||||||
break
|
|
||||||
|
|
||||||
def _set_scores(self, Beam beam, float[:, ::1] scores):
|
|
||||||
cdef float* c_scores = &scores[0, 0]
|
|
||||||
cdef int nr_state = min(scores.shape[0], beam.size)
|
|
||||||
cdef int nr_class = scores.shape[1]
|
|
||||||
for i in range(nr_state):
|
|
||||||
state = <StateC*>beam.at(i)
|
|
||||||
if not state.is_final():
|
|
||||||
for j in range(nr_class):
|
|
||||||
beam.scores[i][j] = c_scores[i * nr_class + j]
|
|
||||||
self.moves.set_valid(beam.is_valid[i], state)
|
|
||||||
else:
|
|
||||||
for j in range(beam.nr_class):
|
|
||||||
beam.scores[i][j] = 0
|
|
||||||
beam.costs[i][j] = 0
|
|
||||||
|
|
||||||
def _set_costs(self, Beam beam, Example example, int follow_gold=False):
|
|
||||||
for i in range(beam.size):
|
|
||||||
state = StateClass.borrow(<StateC*>beam.at(i))
|
|
||||||
if not state.is_final():
|
|
||||||
self.moves.set_costs(beam.is_valid[i], beam.costs[i],
|
|
||||||
state, example)
|
|
||||||
if follow_gold:
|
|
||||||
min_cost = 0
|
|
||||||
for j in range(beam.nr_class):
|
|
||||||
if beam.is_valid[i][j] and beam.costs[i][j] < min_cost:
|
|
||||||
min_cost = beam.costs[i][j]
|
|
||||||
for j in range(beam.nr_class):
|
|
||||||
if beam.costs[i][j] > min_cost:
|
|
||||||
beam.is_valid[i][j] = 0
|
|
||||||
|
|
||||||
|
|
||||||
def get_token_ids(states, int n_tokens):
|
|
||||||
cdef StateClass state
|
|
||||||
cdef np.ndarray ids = numpy.zeros((len(states), n_tokens),
|
|
||||||
dtype='int32', order='C')
|
|
||||||
c_ids = <int*>ids.data
|
|
||||||
for i, state in enumerate(states):
|
|
||||||
if not state.is_final():
|
|
||||||
state.c.set_context_tokens(c_ids, n_tokens)
|
|
||||||
else:
|
|
||||||
ids[i] = -1
|
|
||||||
c_ids += ids.shape[1]
|
|
||||||
return ids
|
|
||||||
|
|
||||||
|
|
||||||
nr_update = 0
|
|
||||||
|
|
||||||
|
|
||||||
def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
|
|
||||||
states, golds,
|
|
||||||
state2vec, vec2scores,
|
|
||||||
int width, losses=None, drop=0.,
|
|
||||||
early_update=True, beam_density=0.0):
|
|
||||||
global nr_update
|
|
||||||
cdef MaxViolation violn
|
|
||||||
nr_update += 1
|
|
||||||
pbeam = ParserBeam(moves, states, golds, width=width, density=beam_density)
|
|
||||||
gbeam = ParserBeam(moves, states, golds, width=width, density=beam_density)
|
|
||||||
cdef StateClass state
|
|
||||||
beam_maps = []
|
|
||||||
backprops = []
|
|
||||||
violns = [MaxViolation() for _ in range(len(states))]
|
|
||||||
for t in range(max_steps):
|
|
||||||
if pbeam.is_done and gbeam.is_done:
|
|
||||||
break
|
|
||||||
# The beam maps let us find the right row in the flattened scores
|
|
||||||
# arrays for each state. States are identified by (example id,
|
|
||||||
# history). We keep a different beam map for each step (since we'll
|
|
||||||
# have a flat scores array for each step). The beam map will let us
|
|
||||||
# take the per-state losses, and compute the gradient for each (step,
|
|
||||||
# state, class).
|
|
||||||
beam_maps.append({})
|
|
||||||
# Gather all states from the two beams in a list. Some stats may occur
|
|
||||||
# in both beams. To figure out which beam each state belonged to,
|
|
||||||
# we keep two lists of indices, p_indices and g_indices
|
|
||||||
states, p_indices, g_indices = get_states(pbeam, gbeam, beam_maps[-1],
|
|
||||||
nr_update)
|
|
||||||
if not states:
|
|
||||||
break
|
|
||||||
# Now that we have our flat list of states, feed them through the model
|
|
||||||
token_ids = get_token_ids(states, nr_feature)
|
|
||||||
vectors, bp_vectors = state2vec.begin_update(token_ids, drop=drop)
|
|
||||||
scores, bp_scores = vec2scores.begin_update(vectors, drop=drop)
|
|
||||||
|
|
||||||
# Store the callbacks for the backward pass
|
|
||||||
backprops.append((token_ids, bp_vectors, bp_scores))
|
|
||||||
|
|
||||||
# Unpack the flat scores into lists for the two beams. The indices arrays
|
|
||||||
# tell us which example and state the scores-row refers to.
|
|
||||||
p_scores = [numpy.ascontiguousarray(scores[indices], dtype='f')
|
|
||||||
for indices in p_indices]
|
|
||||||
g_scores = [numpy.ascontiguousarray(scores[indices], dtype='f')
|
|
||||||
for indices in g_indices]
|
|
||||||
# Now advance the states in the beams. The gold beam is constrained to
|
|
||||||
# to follow only gold analyses.
|
|
||||||
pbeam.advance(p_scores)
|
|
||||||
gbeam.advance(g_scores, follow_gold=True)
|
|
||||||
# Track the "maximum violation", to use in the update.
|
|
||||||
for i, violn in enumerate(violns):
|
|
||||||
violn.check_crf(pbeam[i], gbeam[i])
|
|
||||||
histories = []
|
|
||||||
losses = []
|
|
||||||
for violn in violns:
|
|
||||||
if violn.p_hist:
|
|
||||||
histories.append(violn.p_hist + violn.g_hist)
|
|
||||||
losses.append(violn.p_probs + violn.g_probs)
|
|
||||||
else:
|
|
||||||
histories.append([])
|
|
||||||
losses.append([])
|
|
||||||
states_d_scores = get_gradient(moves.n_moves, beam_maps, histories, losses)
|
|
||||||
beams = list(pbeam.beams) + list(gbeam.beams)
|
|
||||||
return states_d_scores, backprops[:len(states_d_scores)], beams
|
|
||||||
|
|
||||||
|
|
||||||
def get_states(pbeams, gbeams, beam_map, nr_update):
|
|
||||||
seen = {}
|
|
||||||
states = []
|
|
||||||
p_indices = []
|
|
||||||
g_indices = []
|
|
||||||
cdef Beam pbeam, gbeam
|
|
||||||
if len(pbeams) != len(gbeams):
|
|
||||||
raise ValueError(Errors.E079.format(pbeams=len(pbeams), gbeams=len(gbeams)))
|
|
||||||
for eg_id, (pbeam, gbeam) in enumerate(zip(pbeams, gbeams)):
|
|
||||||
p_indices.append([])
|
|
||||||
g_indices.append([])
|
|
||||||
for i in range(pbeam.size):
|
|
||||||
state = StateClass.borrow(<StateC*>pbeam.at(i))
|
|
||||||
if not state.is_final():
|
|
||||||
key = tuple([eg_id] + pbeam.histories[i])
|
|
||||||
if key in seen:
|
|
||||||
raise ValueError(Errors.E080.format(key=key))
|
|
||||||
seen[key] = len(states)
|
|
||||||
p_indices[-1].append(len(states))
|
|
||||||
states.append(state)
|
|
||||||
beam_map.update(seen)
|
|
||||||
for i in range(gbeam.size):
|
|
||||||
state = StateClass.borrow(<StateC*>gbeam.at(i))
|
|
||||||
if not state.is_final():
|
|
||||||
key = tuple([eg_id] + gbeam.histories[i])
|
|
||||||
if key in seen:
|
|
||||||
g_indices[-1].append(seen[key])
|
|
||||||
else:
|
|
||||||
g_indices[-1].append(len(states))
|
|
||||||
beam_map[key] = len(states)
|
|
||||||
states.append(state)
|
|
||||||
p_idx = [numpy.asarray(idx, dtype='i') for idx in p_indices]
|
|
||||||
g_idx = [numpy.asarray(idx, dtype='i') for idx in g_indices]
|
|
||||||
return states, p_idx, g_idx
|
|
||||||
|
|
||||||
|
|
||||||
def get_gradient(nr_class, beam_maps, histories, losses):
|
|
||||||
"""The global model assigns a loss to each parse. The beam scores
|
|
||||||
are additive, so the same gradient is applied to each action
|
|
||||||
in the history. This gives the gradient of a single *action*
|
|
||||||
for a beam state -- so we have "the gradient of loss for taking
|
|
||||||
action i given history H."
|
|
||||||
|
|
||||||
Histories: Each hitory is a list of actions
|
|
||||||
Each candidate has a history
|
|
||||||
Each beam has multiple candidates
|
|
||||||
Each batch has multiple beams
|
|
||||||
So history is list of lists of lists of ints
|
|
||||||
"""
|
|
||||||
grads = []
|
|
||||||
nr_steps = []
|
|
||||||
for eg_id, hists in enumerate(histories):
|
|
||||||
nr_step = 0
|
|
||||||
for loss, hist in zip(losses[eg_id], hists):
|
|
||||||
if loss != 0.0 and not numpy.isnan(loss):
|
|
||||||
nr_step = max(nr_step, len(hist))
|
|
||||||
nr_steps.append(nr_step)
|
|
||||||
for i in range(max(nr_steps)):
|
|
||||||
grads.append(numpy.zeros((max(beam_maps[i].values())+1, nr_class),
|
|
||||||
dtype='f'))
|
|
||||||
if len(histories) != len(losses):
|
|
||||||
raise ValueError(Errors.E081.format(n_hist=len(histories), losses=len(losses)))
|
|
||||||
for eg_id, hists in enumerate(histories):
|
|
||||||
for loss, hist in zip(losses[eg_id], hists):
|
|
||||||
if loss == 0.0 or numpy.isnan(loss):
|
|
||||||
continue
|
|
||||||
key = tuple([eg_id])
|
|
||||||
# Adjust loss for length
|
|
||||||
# We need to do this because each state in a short path is scored
|
|
||||||
# multiple times, as we add in the average cost when we run out
|
|
||||||
# of actions.
|
|
||||||
avg_loss = loss / len(hist)
|
|
||||||
loss += avg_loss * (nr_steps[eg_id] - len(hist))
|
|
||||||
for j, clas in enumerate(hist):
|
|
||||||
i = beam_maps[j][key]
|
|
||||||
# In step j, at state i action clas
|
|
||||||
# resulted in loss
|
|
||||||
grads[j][i, clas] += loss
|
|
||||||
key = key + tuple([clas])
|
|
||||||
return grads
|
|
||||||
|
|
||||||
|
|
||||||
def cleanup_beam(Beam beam):
|
|
||||||
cdef StateC* state
|
|
||||||
# Once parsing has finished, states in beam may not be unique. Is this
|
|
||||||
# correct?
|
|
||||||
seen = set()
|
|
||||||
for i in range(beam.width):
|
|
||||||
addr = <size_t>beam._parents[i].content
|
|
||||||
if addr not in seen:
|
|
||||||
state = <StateC*>addr
|
|
||||||
del state
|
|
||||||
seen.add(addr)
|
|
||||||
else:
|
|
||||||
raise ValueError(Errors.E023.format(addr=addr, i=i))
|
|
||||||
addr = <size_t>beam._states[i].content
|
|
||||||
if addr not in seen:
|
|
||||||
state = <StateC*>addr
|
|
||||||
del state
|
|
||||||
seen.add(addr)
|
|
||||||
else:
|
|
||||||
raise ValueError(Errors.E023.format(addr=addr, i=i))
|
|
|
@ -1,7 +1,6 @@
|
||||||
# cython: profile=True, cdivision=True, infer_types=True
|
# cython: profile=True, cdivision=True, infer_types=True
|
||||||
from cpython.ref cimport Py_INCREF
|
from cpython.ref cimport Py_INCREF
|
||||||
from cymem.cymem cimport Pool
|
from cymem.cymem cimport Pool
|
||||||
from thinc.extra.search cimport Beam
|
|
||||||
from libc.stdint cimport int32_t
|
from libc.stdint cimport int32_t
|
||||||
|
|
||||||
from collections import defaultdict, Counter
|
from collections import defaultdict, Counter
|
||||||
|
@ -379,8 +378,6 @@ cdef int _del_state(Pool mem, void* state, void* x) except -1:
|
||||||
cdef class ArcEager(TransitionSystem):
|
cdef class ArcEager(TransitionSystem):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
TransitionSystem.__init__(self, *args, **kwargs)
|
TransitionSystem.__init__(self, *args, **kwargs)
|
||||||
self.init_beam_state = _init_state
|
|
||||||
self.del_beam_state = _del_state
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_actions(cls, **kwargs):
|
def get_actions(cls, **kwargs):
|
||||||
|
@ -444,22 +441,6 @@ cdef class ArcEager(TransitionSystem):
|
||||||
def preprocess_gold(self, gold):
|
def preprocess_gold(self, gold):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def get_beam_parses(self, Beam beam):
|
|
||||||
parses = []
|
|
||||||
probs = beam.probs
|
|
||||||
for i in range(beam.size):
|
|
||||||
state = <StateC*>beam.at(i)
|
|
||||||
if state.is_final():
|
|
||||||
self.finalize_state(state)
|
|
||||||
prob = probs[i]
|
|
||||||
parse = []
|
|
||||||
for j in range(state.length):
|
|
||||||
head = state.H(j)
|
|
||||||
label = self.strings[state._sent[j].dep]
|
|
||||||
parse.append((head, j, label))
|
|
||||||
parses.append((prob, parse))
|
|
||||||
return parses
|
|
||||||
|
|
||||||
cdef Transition lookup_transition(self, object name_or_id) except *:
|
cdef Transition lookup_transition(self, object name_or_id) except *:
|
||||||
if isinstance(name_or_id, int):
|
if isinstance(name_or_id, int):
|
||||||
return self.c[name_or_id]
|
return self.c[name_or_id]
|
||||||
|
@ -600,22 +581,3 @@ cdef class ArcEager(TransitionSystem):
|
||||||
failure_state = stcls.print_state([t.text for t in example])
|
failure_state = stcls.print_state([t.text for t in example])
|
||||||
raise ValueError(Errors.E021.format(n_actions=self.n_moves,
|
raise ValueError(Errors.E021.format(n_actions=self.n_moves,
|
||||||
state=failure_state))
|
state=failure_state))
|
||||||
|
|
||||||
def get_beam_annot(self, Beam beam):
|
|
||||||
length = (<StateC*>beam.at(0)).length
|
|
||||||
heads = [{} for _ in range(length)]
|
|
||||||
deps = [{} for _ in range(length)]
|
|
||||||
probs = beam.probs
|
|
||||||
for i in range(beam.size):
|
|
||||||
state = <StateC*>beam.at(i)
|
|
||||||
self.finalize_state(state)
|
|
||||||
if state.is_final():
|
|
||||||
prob = probs[i]
|
|
||||||
for j in range(state.length):
|
|
||||||
head = j + state._sent[j].head
|
|
||||||
dep = state._sent[j].dep
|
|
||||||
heads[j].setdefault(head, 0.0)
|
|
||||||
heads[j][head] += prob
|
|
||||||
deps[j].setdefault(dep, 0.0)
|
|
||||||
deps[j][dep] += prob
|
|
||||||
return heads, deps
|
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
from thinc.extra.search cimport Beam
|
|
||||||
|
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
|
|
||||||
from ..typedefs cimport weight_t
|
from ..typedefs cimport weight_t
|
||||||
|
@ -99,39 +97,6 @@ cdef class BiluoPushDown(TransitionSystem):
|
||||||
def preprocess_gold(self, gold):
|
def preprocess_gold(self, gold):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def get_beam_annot(self, Beam beam):
|
|
||||||
entities = {}
|
|
||||||
probs = beam.probs
|
|
||||||
for i in range(beam.size):
|
|
||||||
state = <StateC*>beam.at(i)
|
|
||||||
if state.is_final():
|
|
||||||
self.finalize_state(state)
|
|
||||||
prob = probs[i]
|
|
||||||
for j in range(state._e_i):
|
|
||||||
start = state._ents[j].start
|
|
||||||
end = state._ents[j].end
|
|
||||||
label = state._ents[j].label
|
|
||||||
entities.setdefault((start, end, label), 0.0)
|
|
||||||
entities[(start, end, label)] += prob
|
|
||||||
return entities
|
|
||||||
|
|
||||||
def get_beam_parses(self, Beam beam):
|
|
||||||
parses = []
|
|
||||||
probs = beam.probs
|
|
||||||
for i in range(beam.size):
|
|
||||||
state = <StateC*>beam.at(i)
|
|
||||||
if state.is_final():
|
|
||||||
self.finalize_state(state)
|
|
||||||
prob = probs[i]
|
|
||||||
parse = []
|
|
||||||
for j in range(state._e_i):
|
|
||||||
start = state._ents[j].start
|
|
||||||
end = state._ents[j].end
|
|
||||||
label = state._ents[j].label
|
|
||||||
parse.append((start, end, self.strings[label]))
|
|
||||||
parses.append((prob, parse))
|
|
||||||
return parses
|
|
||||||
|
|
||||||
cdef Transition lookup_transition(self, object name) except *:
|
cdef Transition lookup_transition(self, object name) except *:
|
||||||
cdef attr_t label
|
cdef attr_t label
|
||||||
if name == '-' or name == '' or name is None:
|
if name == '-' or name == '' or name is None:
|
||||||
|
|
|
@ -8,7 +8,6 @@ from libcpp.vector cimport vector
|
||||||
from libc.string cimport memset, memcpy
|
from libc.string cimport memset, memcpy
|
||||||
from libc.stdlib cimport calloc, free
|
from libc.stdlib cimport calloc, free
|
||||||
from cymem.cymem cimport Pool
|
from cymem.cymem cimport Pool
|
||||||
from thinc.extra.search cimport Beam
|
|
||||||
from thinc.backends.linalg cimport Vec, VecVec
|
from thinc.backends.linalg cimport Vec, VecVec
|
||||||
|
|
||||||
from thinc.api import chain, clone, Linear, list2array, NumpyOps, CupyOps, use_ops
|
from thinc.api import chain, clone, Linear, list2array, NumpyOps, CupyOps, use_ops
|
||||||
|
@ -28,21 +27,15 @@ from ._parser_model cimport get_c_weights, get_c_sizes
|
||||||
from .stateclass cimport StateClass
|
from .stateclass cimport StateClass
|
||||||
from ._state cimport StateC
|
from ._state cimport StateC
|
||||||
from .transition_system cimport Transition
|
from .transition_system cimport Transition
|
||||||
from . cimport _beam_utils
|
|
||||||
from ..gold.example cimport Example
|
from ..gold.example cimport Example
|
||||||
|
|
||||||
from ..util import link_vectors_to_models, create_default_optimizer, registry
|
from ..util import link_vectors_to_models, create_default_optimizer, registry
|
||||||
from ..compat import copy_array
|
from ..compat import copy_array
|
||||||
from ..errors import Errors, Warnings
|
from ..errors import Errors, Warnings
|
||||||
from .. import util
|
from .. import util
|
||||||
from . import _beam_utils
|
|
||||||
from . import nonproj
|
from . import nonproj
|
||||||
|
|
||||||
|
|
||||||
def get_parses_from_example(example, merge=False, vocab=None):
|
|
||||||
# TODO: This is just a temporary shim to make the refactor easier.
|
|
||||||
return [(example.predicted, example)]
|
|
||||||
|
|
||||||
cdef class Parser:
|
cdef class Parser:
|
||||||
"""
|
"""
|
||||||
Base class of the DependencyParser and EntityRecognizer.
|
Base class of the DependencyParser and EntityRecognizer.
|
||||||
|
@ -146,71 +139,47 @@ cdef class Parser:
|
||||||
'''
|
'''
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def preprocess_gold(self, examples):
|
|
||||||
for ex in examples:
|
|
||||||
yield ex
|
|
||||||
|
|
||||||
def use_params(self, params):
|
def use_params(self, params):
|
||||||
# Can't decorate cdef class :(. Workaround.
|
# Can't decorate cdef class :(. Workaround.
|
||||||
with self.model.use_params(params):
|
with self.model.use_params(params):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
def __call__(self, Doc doc, beam_width=None):
|
def __call__(self, Doc doc):
|
||||||
"""Apply the parser or entity recognizer, setting the annotations onto
|
"""Apply the parser or entity recognizer, setting the annotations onto
|
||||||
the `Doc` object.
|
the `Doc` object.
|
||||||
|
|
||||||
doc (Doc): The document to be processed.
|
doc (Doc): The document to be processed.
|
||||||
"""
|
"""
|
||||||
if beam_width is None:
|
states = self.predict([doc])
|
||||||
beam_width = self.cfg['beam_width']
|
|
||||||
beam_density = self.cfg.get('beam_density', 0.)
|
|
||||||
states = self.predict([doc], beam_width=beam_width,
|
|
||||||
beam_density=beam_density)
|
|
||||||
self.set_annotations([doc], states, tensors=None)
|
self.set_annotations([doc], states, tensors=None)
|
||||||
return doc
|
return doc
|
||||||
|
|
||||||
def pipe(self, docs, int batch_size=256, int n_threads=-1, beam_width=None,
|
def pipe(self, docs, int batch_size=256, int n_threads=-1):
|
||||||
as_example=False):
|
|
||||||
"""Process a stream of documents.
|
"""Process a stream of documents.
|
||||||
|
|
||||||
stream: The sequence of documents to process.
|
stream: The sequence of documents to process.
|
||||||
batch_size (int): Number of documents to accumulate into a working set.
|
batch_size (int): Number of documents to accumulate into a working set.
|
||||||
YIELDS (Doc): Documents, in order.
|
YIELDS (Doc): Documents, in order.
|
||||||
"""
|
"""
|
||||||
if beam_width is None:
|
|
||||||
beam_width = self.cfg['beam_width']
|
|
||||||
beam_density = self.cfg.get('beam_density', 0.)
|
|
||||||
cdef Doc doc
|
cdef Doc doc
|
||||||
for batch in util.minibatch(docs, size=batch_size):
|
for batch in util.minibatch(docs, size=batch_size):
|
||||||
batch_in_order = list(batch)
|
batch_in_order = list(batch)
|
||||||
docs = [self._get_doc(ex) for ex in batch_in_order]
|
docs = [ex.predicted for ex in batch_in_order]
|
||||||
by_length = sorted(docs, key=lambda doc: len(doc))
|
by_length = sorted(docs, key=lambda doc: len(doc))
|
||||||
for subbatch in util.minibatch(by_length, size=max(batch_size//4, 2)):
|
for subbatch in util.minibatch(by_length, size=max(batch_size//4, 2)):
|
||||||
subbatch = list(subbatch)
|
subbatch = list(subbatch)
|
||||||
parse_states = self.predict(subbatch, beam_width=beam_width,
|
parse_states = self.predict(subbatch)
|
||||||
beam_density=beam_density)
|
|
||||||
self.set_annotations(subbatch, parse_states, tensors=None)
|
self.set_annotations(subbatch, parse_states, tensors=None)
|
||||||
if as_example:
|
yield from batch_in_order
|
||||||
annotated_examples = []
|
|
||||||
for ex, doc in zip(batch_in_order, docs):
|
|
||||||
ex.doc = doc
|
|
||||||
annotated_examples.append(ex)
|
|
||||||
yield from annotated_examples
|
|
||||||
else:
|
|
||||||
yield from batch_in_order
|
|
||||||
|
|
||||||
def predict(self, docs, beam_width=1, beam_density=0.0, drop=0.):
|
def predict(self, docs):
|
||||||
if isinstance(docs, Doc):
|
if isinstance(docs, Doc):
|
||||||
docs = [docs]
|
docs = [docs]
|
||||||
if not any(len(doc) for doc in docs):
|
if not any(len(doc) for doc in docs):
|
||||||
result = self.moves.init_batch(docs)
|
result = self.moves.init_batch(docs)
|
||||||
self._resize()
|
self._resize()
|
||||||
return result
|
return result
|
||||||
if beam_width < 2:
|
return self.greedy_parse(docs, drop=0.0)
|
||||||
return self.greedy_parse(docs, drop=drop)
|
|
||||||
else:
|
|
||||||
return self.beam_parse(docs, beam_width=beam_width,
|
|
||||||
beam_density=beam_density, drop=drop)
|
|
||||||
|
|
||||||
def greedy_parse(self, docs, drop=0.):
|
def greedy_parse(self, docs, drop=0.):
|
||||||
cdef vector[StateC*] states
|
cdef vector[StateC*] states
|
||||||
|
@ -232,44 +201,6 @@ cdef class Parser:
|
||||||
weights, sizes)
|
weights, sizes)
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
|
|
||||||
cdef Beam beam
|
|
||||||
cdef Doc doc
|
|
||||||
cdef np.ndarray token_ids
|
|
||||||
set_dropout_rate(self.model, drop)
|
|
||||||
beams = self.moves.init_beams(docs, beam_width, beam_density=beam_density)
|
|
||||||
# This is pretty dirty, but the NER can resize itself in init_batch,
|
|
||||||
# if labels are missing. We therefore have to check whether we need to
|
|
||||||
# expand our model output.
|
|
||||||
self._resize()
|
|
||||||
cdef int nr_feature = self.model.get_ref("lower").get_dim("nF")
|
|
||||||
model = self.model.predict(docs)
|
|
||||||
token_ids = numpy.zeros((len(docs) * beam_width, nr_feature),
|
|
||||||
dtype='i', order='C')
|
|
||||||
cdef int* c_ids
|
|
||||||
cdef int n_states
|
|
||||||
model = self.model.predict(docs)
|
|
||||||
todo = [beam for beam in beams if not beam.is_done]
|
|
||||||
while todo:
|
|
||||||
token_ids.fill(-1)
|
|
||||||
c_ids = <int*>token_ids.data
|
|
||||||
n_states = 0
|
|
||||||
for beam in todo:
|
|
||||||
for i in range(beam.size):
|
|
||||||
state = <StateC*>beam.at(i)
|
|
||||||
# This way we avoid having to score finalized states
|
|
||||||
# We do have to take care to keep indexes aligned, though
|
|
||||||
if not state.is_final():
|
|
||||||
state.set_context_tokens(c_ids, nr_feature)
|
|
||||||
c_ids += nr_feature
|
|
||||||
n_states += 1
|
|
||||||
if n_states == 0:
|
|
||||||
break
|
|
||||||
vectors = model.state2vec.predict(token_ids[:n_states])
|
|
||||||
scores = model.vec2scores.predict(vectors)
|
|
||||||
todo = self.transition_beams(todo, scores)
|
|
||||||
return beams
|
|
||||||
|
|
||||||
cdef void _parseC(self, StateC** states,
|
cdef void _parseC(self, StateC** states,
|
||||||
WeightsC weights, SizesC sizes) nogil:
|
WeightsC weights, SizesC sizes) nogil:
|
||||||
cdef int i, j
|
cdef int i, j
|
||||||
|
@ -290,20 +221,9 @@ cdef class Parser:
|
||||||
unfinished.clear()
|
unfinished.clear()
|
||||||
free_activations(&activations)
|
free_activations(&activations)
|
||||||
|
|
||||||
def set_annotations(self, docs, states_or_beams, tensors=None):
|
def set_annotations(self, docs, states, tensors=None):
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
cdef Beam beam
|
|
||||||
cdef Doc doc
|
cdef Doc doc
|
||||||
states = []
|
|
||||||
beams = []
|
|
||||||
for state_or_beam in states_or_beams:
|
|
||||||
if isinstance(state_or_beam, StateClass):
|
|
||||||
states.append(state_or_beam)
|
|
||||||
else:
|
|
||||||
beam = state_or_beam
|
|
||||||
state = StateClass.borrow(<StateC*>beam.at(0))
|
|
||||||
states.append(state)
|
|
||||||
beams.append(beam)
|
|
||||||
for i, (state, doc) in enumerate(zip(states, docs)):
|
for i, (state, doc) in enumerate(zip(states, docs)):
|
||||||
self.moves.finalize_state(state.c)
|
self.moves.finalize_state(state.c)
|
||||||
for j in range(doc.length):
|
for j in range(doc.length):
|
||||||
|
@ -311,8 +231,6 @@ cdef class Parser:
|
||||||
self.moves.finalize_doc(doc)
|
self.moves.finalize_doc(doc)
|
||||||
for hook in self.postprocesses:
|
for hook in self.postprocesses:
|
||||||
hook(doc)
|
hook(doc)
|
||||||
for beam in beams:
|
|
||||||
_beam_utils.cleanup_beam(beam)
|
|
||||||
|
|
||||||
def transition_states(self, states, float[:, ::1] scores):
|
def transition_states(self, states, float[:, ::1] scores):
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
|
@ -344,20 +262,6 @@ cdef class Parser:
|
||||||
states[i].push_hist(guess)
|
states[i].push_hist(guess)
|
||||||
free(is_valid)
|
free(is_valid)
|
||||||
|
|
||||||
def transition_beams(self, beams, float[:, ::1] scores):
|
|
||||||
cdef Beam beam
|
|
||||||
cdef float* c_scores = &scores[0, 0]
|
|
||||||
for beam in beams:
|
|
||||||
for i in range(beam.size):
|
|
||||||
state = <StateC*>beam.at(i)
|
|
||||||
if not state.is_final():
|
|
||||||
self.moves.set_valid(beam.is_valid[i], state)
|
|
||||||
memcpy(beam.scores[i], c_scores, scores.shape[1] * sizeof(float))
|
|
||||||
c_scores += scores.shape[1]
|
|
||||||
beam.advance(_beam_utils.transition_state, _beam_utils.hash_state, <void*>self.moves.c)
|
|
||||||
beam.check_done(_beam_utils.check_final_state, NULL)
|
|
||||||
return [b for b in beams if not b.is_done]
|
|
||||||
|
|
||||||
def update(self, examples, drop=0., set_annotations=False, sgd=None, losses=None):
|
def update(self, examples, drop=0., set_annotations=False, sgd=None, losses=None):
|
||||||
examples = Example.to_example_objects(examples)
|
examples = Example.to_example_objects(examples)
|
||||||
|
|
||||||
|
@ -366,23 +270,8 @@ cdef class Parser:
|
||||||
losses.setdefault(self.name, 0.)
|
losses.setdefault(self.name, 0.)
|
||||||
for multitask in self._multitasks:
|
for multitask in self._multitasks:
|
||||||
multitask.update(examples, drop=drop, sgd=sgd)
|
multitask.update(examples, drop=drop, sgd=sgd)
|
||||||
# The probability we use beam update, instead of falling back to
|
|
||||||
# a greedy update
|
|
||||||
beam_update_prob = self.cfg['beam_update_prob']
|
|
||||||
if self.cfg['beam_width'] >= 2 and numpy.random.random() < beam_update_prob:
|
|
||||||
return self.update_beam(examples, self.cfg['beam_width'],
|
|
||||||
drop=drop, sgd=sgd, losses=losses, set_annotations=set_annotations,
|
|
||||||
beam_density=self.cfg.get('beam_density', 0.001))
|
|
||||||
|
|
||||||
set_dropout_rate(self.model, drop)
|
set_dropout_rate(self.model, drop)
|
||||||
cut_gold = True
|
states, golds, max_steps = self._init_gold_batch_no_cut(examples)
|
||||||
if cut_gold:
|
|
||||||
# Chop sequences into lengths of this many transitions, to make the
|
|
||||||
# batch uniform length.
|
|
||||||
cut_gold = numpy.random.choice(range(20, 100))
|
|
||||||
states, golds, max_steps = self._init_gold_batch(examples, max_length=cut_gold)
|
|
||||||
else:
|
|
||||||
states, golds, max_steps = self._init_gold_batch_no_cut(examples)
|
|
||||||
states_golds = [(s, g) for (s, g) in zip(states, golds)
|
states_golds = [(s, g) for (s, g) in zip(states, golds)
|
||||||
if not s.is_final() and g is not None]
|
if not s.is_final() and g is not None]
|
||||||
# Prepare the stepwise model, and get the callback for finishing the batch
|
# Prepare the stepwise model, and get the callback for finishing the batch
|
||||||
|
@ -450,52 +339,6 @@ cdef class Parser:
|
||||||
losses[self.name] += loss / n_scores
|
losses[self.name] += loss / n_scores
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
def update_beam(self, examples, width, drop=0., sgd=None, losses=None,
|
|
||||||
set_annotations=False, beam_density=0.0):
|
|
||||||
examples = Example.to_example_objects(examples)
|
|
||||||
docs = [ex.doc for ex in examples]
|
|
||||||
golds = [ex.gold for ex in examples]
|
|
||||||
new_golds = []
|
|
||||||
lengths = [len(d) for d in docs]
|
|
||||||
states = self.moves.init_batch(docs)
|
|
||||||
for gold in golds:
|
|
||||||
self.moves.preprocess_gold(gold)
|
|
||||||
new_golds.append(gold)
|
|
||||||
set_dropout_rate(self.model, drop)
|
|
||||||
model, backprop_tok2vec = self.model.begin_update(docs)
|
|
||||||
states_d_scores, backprops, beams = _beam_utils.update_beam(
|
|
||||||
self.moves,
|
|
||||||
self.model.get_ref("lower").get_dim("nF"),
|
|
||||||
10000,
|
|
||||||
states,
|
|
||||||
golds,
|
|
||||||
model.state2vec,
|
|
||||||
model.vec2scores,
|
|
||||||
width,
|
|
||||||
losses=losses,
|
|
||||||
beam_density=beam_density
|
|
||||||
)
|
|
||||||
for i, d_scores in enumerate(states_d_scores):
|
|
||||||
losses[self.name] += (d_scores**2).mean()
|
|
||||||
ids, bp_vectors, bp_scores = backprops[i]
|
|
||||||
d_vector = bp_scores(d_scores)
|
|
||||||
if isinstance(model.ops, CupyOps) \
|
|
||||||
and not isinstance(ids, model.state2vec.ops.xp.ndarray):
|
|
||||||
model.backprops.append((
|
|
||||||
util.get_async(model.cuda_stream, ids),
|
|
||||||
util.get_async(model.cuda_stream, d_vector),
|
|
||||||
bp_vectors))
|
|
||||||
else:
|
|
||||||
model.backprops.append((ids, d_vector, bp_vectors))
|
|
||||||
backprop_tok2vec(golds)
|
|
||||||
if sgd is not None:
|
|
||||||
self.model.finish_update(sgd)
|
|
||||||
if set_annotations:
|
|
||||||
self.set_annotations(docs, beams)
|
|
||||||
cdef Beam beam
|
|
||||||
for beam in beams:
|
|
||||||
_beam_utils.cleanup_beam(beam)
|
|
||||||
|
|
||||||
def get_gradients(self):
|
def get_gradients(self):
|
||||||
"""Get non-zero gradients of the model's parameters, as a dictionary
|
"""Get non-zero gradients of the model's parameters, as a dictionary
|
||||||
keyed by the parameter ID. The values are (weights, gradients) tuples.
|
keyed by the parameter ID. The values are (weights, gradients) tuples.
|
||||||
|
@ -513,66 +356,9 @@ cdef class Parser:
|
||||||
queue.extend(node._layers)
|
queue.extend(node._layers)
|
||||||
return gradients
|
return gradients
|
||||||
|
|
||||||
def _init_gold_batch_no_cut(self, whole_examples):
|
def _init_gold_batch_no_cut(self, examples):
|
||||||
states = self.moves.init_batch([eg.doc for eg in whole_examples])
|
states = self.moves.init_batch([eg.predicted for eg in examples])
|
||||||
good_docs = []
|
return states, examples
|
||||||
good_golds = []
|
|
||||||
good_states = []
|
|
||||||
for i, eg in enumerate(whole_examples):
|
|
||||||
parses = get_parses_from_example(eg)
|
|
||||||
doc, gold = parses[0]
|
|
||||||
if gold is not None and self.moves.has_gold(gold):
|
|
||||||
good_docs.append(doc)
|
|
||||||
good_golds.append(gold)
|
|
||||||
good_states.append(states[i])
|
|
||||||
n_moves = []
|
|
||||||
for doc, gold in zip(good_docs, good_golds):
|
|
||||||
oracle_actions = self.moves.get_oracle_sequence(doc, gold)
|
|
||||||
n_moves.append(len(oracle_actions))
|
|
||||||
return good_states, good_golds, max(n_moves, default=0) * 2
|
|
||||||
|
|
||||||
def _init_gold_batch(self, whole_examples, min_length=5, max_length=500):
|
|
||||||
"""Make a square batch, of length equal to the shortest doc. A long
|
|
||||||
doc will get multiple states. Let's say we have a doc of length 2*N,
|
|
||||||
where N is the shortest doc. We'll make two states, one representing
|
|
||||||
long_doc[:N], and another representing long_doc[N:]."""
|
|
||||||
cdef:
|
|
||||||
StateClass state
|
|
||||||
Transition action
|
|
||||||
whole_docs = []
|
|
||||||
whole_golds = []
|
|
||||||
for eg in whole_examples:
|
|
||||||
for doc, gold in get_parses_from_example(eg):
|
|
||||||
whole_docs.append(doc)
|
|
||||||
whole_golds.append(gold)
|
|
||||||
whole_states = self.moves.init_batch(whole_docs)
|
|
||||||
max_length = max(min_length, min(max_length, min([len(doc) for doc in whole_docs])))
|
|
||||||
max_moves = 0
|
|
||||||
states = []
|
|
||||||
golds = []
|
|
||||||
for doc, state, gold in zip(whole_docs, whole_states, whole_golds):
|
|
||||||
gold = self.moves.preprocess_gold(gold)
|
|
||||||
if gold is None:
|
|
||||||
continue
|
|
||||||
oracle_actions = self.moves.get_oracle_sequence(doc, gold)
|
|
||||||
start = 0
|
|
||||||
while start < len(doc):
|
|
||||||
state = state.copy()
|
|
||||||
n_moves = 0
|
|
||||||
while state.B(0) < start and not state.is_final():
|
|
||||||
action = self.moves.c[oracle_actions.pop(0)]
|
|
||||||
action.do(state.c, action.label)
|
|
||||||
state.c.push_hist(action.clas)
|
|
||||||
n_moves += 1
|
|
||||||
has_gold = self.moves.has_gold(gold, start=start,
|
|
||||||
end=start+max_length)
|
|
||||||
if not state.is_final() and has_gold:
|
|
||||||
states.append(state)
|
|
||||||
golds.append(gold)
|
|
||||||
max_moves = max(max_moves, n_moves)
|
|
||||||
start += min(max_length, len(doc)-start)
|
|
||||||
max_moves = max(max_moves, len(oracle_actions))
|
|
||||||
return states, golds, max_moves
|
|
||||||
|
|
||||||
def get_batch_loss(self, states, examples, float[:, ::1] scores, losses):
|
def get_batch_loss(self, states, examples, float[:, ::1] scores, losses):
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
|
@ -631,13 +417,8 @@ cdef class Parser:
|
||||||
if sgd is None:
|
if sgd is None:
|
||||||
sgd = self.create_optimizer()
|
sgd = self.create_optimizer()
|
||||||
doc_sample = []
|
doc_sample = []
|
||||||
gold_sample = []
|
|
||||||
for example in islice(get_examples(), 10):
|
for example in islice(get_examples(), 10):
|
||||||
parses = get_parses_from_example(example, merge=False, vocab=self.vocab)
|
doc_sample.append(example.predicted)
|
||||||
for doc, gold in parses:
|
|
||||||
if len(doc):
|
|
||||||
doc_sample.append(doc)
|
|
||||||
gold_sample.append(gold)
|
|
||||||
|
|
||||||
if pipeline is not None:
|
if pipeline is not None:
|
||||||
for name, component in pipeline:
|
for name, component in pipeline:
|
||||||
|
@ -656,12 +437,6 @@ cdef class Parser:
|
||||||
link_vectors_to_models(self.vocab)
|
link_vectors_to_models(self.vocab)
|
||||||
return sgd
|
return sgd
|
||||||
|
|
||||||
def _get_doc(self, example):
|
|
||||||
""" Use this method if the `example` can be both a Doc or an Example """
|
|
||||||
if isinstance(example, Doc):
|
|
||||||
return example
|
|
||||||
return example.doc
|
|
||||||
|
|
||||||
def to_disk(self, path, exclude=tuple(), **kwargs):
|
def to_disk(self, path, exclude=tuple(), **kwargs):
|
||||||
serializers = {
|
serializers = {
|
||||||
'model': lambda p: (self.model.to_disk(p) if self.model is not True else True),
|
'model': lambda p: (self.model.to_disk(p) if self.model is not True else True),
|
||||||
|
|
|
@ -40,8 +40,6 @@ cdef class TransitionSystem:
|
||||||
cdef int _size
|
cdef int _size
|
||||||
cdef public attr_t root_label
|
cdef public attr_t root_label
|
||||||
cdef public freqs
|
cdef public freqs
|
||||||
cdef init_state_t init_beam_state
|
|
||||||
cdef del_state_t del_beam_state
|
|
||||||
cdef public object labels
|
cdef public object labels
|
||||||
|
|
||||||
cdef int initialize_state(self, StateC* state) nogil
|
cdef int initialize_state(self, StateC* state) nogil
|
||||||
|
|
|
@ -1,13 +1,11 @@
|
||||||
# cython: infer_types=True
|
# cython: infer_types=True
|
||||||
from cpython.ref cimport Py_INCREF
|
from cpython.ref cimport Py_INCREF
|
||||||
from cymem.cymem cimport Pool
|
from cymem.cymem cimport Pool
|
||||||
from thinc.extra.search cimport Beam
|
|
||||||
|
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
import srsly
|
import srsly
|
||||||
|
|
||||||
from ..typedefs cimport weight_t
|
from ..typedefs cimport weight_t
|
||||||
from . cimport _beam_utils
|
|
||||||
from ..tokens.doc cimport Doc
|
from ..tokens.doc cimport Doc
|
||||||
from ..structs cimport TokenC
|
from ..structs cimport TokenC
|
||||||
from .stateclass cimport StateClass
|
from .stateclass cimport StateClass
|
||||||
|
@ -47,8 +45,6 @@ cdef class TransitionSystem:
|
||||||
if labels_by_action:
|
if labels_by_action:
|
||||||
self.initialize_actions(labels_by_action, min_freq=min_freq)
|
self.initialize_actions(labels_by_action, min_freq=min_freq)
|
||||||
self.root_label = self.strings.add('ROOT')
|
self.root_label = self.strings.add('ROOT')
|
||||||
self.init_beam_state = _init_state
|
|
||||||
self.del_beam_state = _del_state
|
|
||||||
|
|
||||||
def __reduce__(self):
|
def __reduce__(self):
|
||||||
return (self.__class__, (self.strings, self.labels), None, None)
|
return (self.__class__, (self.strings, self.labels), None, None)
|
||||||
|
@ -64,29 +60,6 @@ cdef class TransitionSystem:
|
||||||
offset += len(doc)
|
offset += len(doc)
|
||||||
return states
|
return states
|
||||||
|
|
||||||
def init_beams(self, docs, beam_width, beam_density=0.):
|
|
||||||
cdef Doc doc
|
|
||||||
beams = []
|
|
||||||
cdef int offset = 0
|
|
||||||
|
|
||||||
# Doc objects might contain labels that we need to register actions for. We need to check for that
|
|
||||||
# *before* we create any Beam objects, because the Beam object needs the correct number of
|
|
||||||
# actions. It's sort of dumb, but the best way is to just call init_batch() -- that triggers the additions,
|
|
||||||
# and it doesn't matter that we create and discard the state objects.
|
|
||||||
self.init_batch(docs)
|
|
||||||
|
|
||||||
for doc in docs:
|
|
||||||
beam = Beam(self.n_moves, beam_width, min_density=beam_density)
|
|
||||||
beam.initialize(self.init_beam_state, self.del_beam_state,
|
|
||||||
doc.length, doc.c)
|
|
||||||
for i in range(beam.width):
|
|
||||||
state = <StateC*>beam.at(i)
|
|
||||||
state.offset = offset
|
|
||||||
offset += len(doc)
|
|
||||||
beam.check_done(_beam_utils.check_final_state, NULL)
|
|
||||||
beams.append(beam)
|
|
||||||
return beams
|
|
||||||
|
|
||||||
def get_oracle_sequence(self, Example example):
|
def get_oracle_sequence(self, Example example):
|
||||||
cdef Pool mem = Pool()
|
cdef Pool mem = Pool()
|
||||||
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
|
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
|
||||||
|
|
Loading…
Reference in New Issue
Block a user