Remove state argument in pipeline. Other changes

This commit is contained in:
Matthew Honnibal 2017-05-19 13:26:36 -05:00
parent 66ea9aebe7
commit c12ab47a56
2 changed files with 41 additions and 70 deletions

View File

@ -33,7 +33,7 @@ from .morphology cimport Morphology
from .vocab cimport Vocab from .vocab cimport Vocab
from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP, POS from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP, POS
from ._ml import Tok2Vec, flatten, get_col, doc2feats from ._ml import rebatch, Tok2Vec, flatten, get_col, doc2feats
from .parts_of_speech import X from .parts_of_speech import X
@ -57,18 +57,12 @@ class TokenVectorEncoder(object):
docs = [docs] docs = [docs]
tokvecs = self.predict(docs) tokvecs = self.predict(docs)
self.set_annotations(docs, tokvecs) self.set_annotations(docs, tokvecs)
state = {} if state is None else state
state['tokvecs'] = tokvecs
return state
def pipe(self, stream, batch_size=128, n_threads=-1): def pipe(self, stream, batch_size=128, n_threads=-1):
for batch in cytoolz.partition_all(batch_size, stream): for docs in cytoolz.partition_all(batch_size, stream):
docs, states = zip(*batch)
tokvecs = self.predict(docs) tokvecs = self.predict(docs)
self.set_annotations(docs, tokvecs) self.set_annotations(docs, tokvecs)
for state in states: yield from docs
state['tokvecs'] = tokvecs
yield from zip(docs, states)
def predict(self, docs): def predict(self, docs):
feats = self.doc2feats(docs) feats = self.doc2feats(docs)
@ -81,18 +75,12 @@ class TokenVectorEncoder(object):
doc.tensor = tokvecs[start : start + len(doc)] doc.tensor = tokvecs[start : start + len(doc)]
start += len(doc) start += len(doc)
def update(self, docs, golds, state=None, def begin_update(self, docs, drop=0.):
drop=0., sgd=None):
if isinstance(docs, Doc): if isinstance(docs, Doc):
docs = [docs] docs = [docs]
golds = [golds]
state = {} if state is None else state
feats = self.doc2feats(docs) feats = self.doc2feats(docs)
tokvecs, bp_tokvecs = self.model.begin_update(feats, drop=drop) tokvecs, bp_tokvecs = self.model.begin_update(feats, drop=drop)
state['feats'] = feats return tokvecs, bp_tokvecs
state['tokvecs'] = tokvecs
state['bp_tokvecs'] = bp_tokvecs
return state
def get_loss(self, docs, golds, scores): def get_loss(self, docs, golds, scores):
raise NotImplementedError raise NotImplementedError
@ -113,22 +101,16 @@ class NeuralTagger(object):
self.vocab = vocab self.vocab = vocab
self.model = model self.model = model
def __call__(self, doc, state=None): def __call__(self, doc):
assert state is not None tags = self.predict(doc.tensor)
assert 'tokvecs' in state
tokvecs = state['tokvecs']
tags = self.predict(tokvecs)
self.set_annotations([doc], tags) self.set_annotations([doc], tags)
return state
def pipe(self, stream, batch_size=128, n_threads=-1): def pipe(self, stream, batch_size=128, n_threads=-1):
for batch in cytoolz.partition_all(batch_size, stream): for docs in cytoolz.partition_all(batch_size, stream):
docs, states = zip(*batch) tokvecs = self.model.ops.flatten([d.tensor for d in docs])
tag_ids = self.predict(states[0]['tokvecs']) tag_ids = self.predict(tokvecs)
self.set_annotations(docs, tag_ids) self.set_annotations(docs, tag_ids)
for state in states: yield from docs
state['tag_ids'] = tag_ids
yield from zip(docs, states)
def predict(self, tokvecs): def predict(self, tokvecs):
scores = self.model(tokvecs) scores = self.model(tokvecs)
@ -150,11 +132,9 @@ class NeuralTagger(object):
vocab.morphology.assign_tag_id(&doc.c[j], tag_id) vocab.morphology.assign_tag_id(&doc.c[j], tag_id)
idx += 1 idx += 1
def update(self, docs, golds, state=None, drop=0., sgd=None): def update(self, docs_tokvecs, golds, drop=0., sgd=None):
state = {} if state is None else state docs, tokvecs = docs_tokvecs
tokvecs = state['tokvecs']
bp_tokvecs = state['bp_tokvecs']
if self.model.nI is None: if self.model.nI is None:
self.model.nI = tokvecs.shape[1] self.model.nI = tokvecs.shape[1]
@ -163,20 +143,20 @@ class NeuralTagger(object):
d_tokvecs = bp_tag_scores(d_tag_scores, sgd=sgd) d_tokvecs = bp_tag_scores(d_tag_scores, sgd=sgd)
bp_tokvecs(d_tokvecs, sgd=sgd) return d_tokvecs
state['tag_scores'] = tag_scores
state['tag_loss'] = loss
return state
def get_loss(self, docs, golds, scores): def get_loss(self, docs, golds, scores):
tag_index = {tag: i for i, tag in enumerate(self.vocab.morphology.tag_names)} tag_index = {tag: i for i, tag in enumerate(self.vocab.morphology.tag_names)}
cdef int idx = 0 cdef int idx = 0
correct = numpy.zeros((scores.shape[0],), dtype='i') correct = numpy.zeros((scores.shape[0],), dtype='i')
guesses = scores.argmax(axis=1)
for gold in golds: for gold in golds:
for tag in gold.tags: for tag in gold.tags:
correct[idx] = tag_index[tag] if tag is None:
correct[idx] = guesses[idx]
else:
correct[idx] = tag_index[tag]
idx += 1 idx += 1
correct = self.model.ops.xp.array(correct, dtype='i') correct = self.model.ops.xp.array(correct, dtype='i')
d_scores = scores - to_categorical(correct, nb_classes=scores.shape[1]) d_scores = scores - to_categorical(correct, nb_classes=scores.shape[1])
@ -198,15 +178,16 @@ class NeuralTagger(object):
cdef Vocab vocab = self.vocab cdef Vocab vocab = self.vocab
vocab.morphology = Morphology(vocab.strings, new_tag_map, vocab.morphology = Morphology(vocab.strings, new_tag_map,
vocab.morphology.lemmatizer) vocab.morphology.lemmatizer)
self.model = Softmax(self.vocab.morphology.n_tags) token_vector_width = pipeline[0].model.nO
print("Tagging", self.model.nO, "tags") self.model = rebatch(1024, Softmax(self.vocab.morphology.n_tags,
token_vector_width))
#self.model = Softmax(self.vocab.morphology.n_tags)
def use_params(self, params): def use_params(self, params):
with self.model.use_params(params): with self.model.use_params(params):
yield yield
cdef class EntityRecognizer(LinearParser): cdef class EntityRecognizer(LinearParser):
""" """
Annotate named entities on Doc objects. Annotate named entities on Doc objects.
@ -275,8 +256,6 @@ cdef class NeuralEntityRecognizer(NeuralParser):
return ids return ids
cdef class BeamDependencyParser(BeamParser): cdef class BeamDependencyParser(BeamParser):
TransitionSystem = ArcEager TransitionSystem = ArcEager

View File

@ -35,12 +35,12 @@ from preshed.maps cimport map_get
from thinc.api import layerize, chain from thinc.api import layerize, chain
from thinc.neural import Model, Affine, ELU, ReLu, Maxout from thinc.neural import Model, Affine, ELU, ReLu, Maxout
from thinc.neural.ops import NumpyOps from thinc.neural.ops import NumpyOps, CupyOps
from .. import util from .. import util
from ..util import get_async, get_cuda_stream from ..util import get_async, get_cuda_stream
from .._ml import zero_init, PrecomputableAffine, PrecomputableMaxouts from .._ml import zero_init, PrecomputableAffine, PrecomputableMaxouts
from .._ml import Tok2Vec, doc2feats from .._ml import Tok2Vec, doc2feats, rebatch
from . import _parse_features from . import _parse_features
from ._parse_features cimport CONTEXT_SIZE from ._parse_features cimport CONTEXT_SIZE
@ -229,6 +229,8 @@ cdef class Parser:
nI=token_vector_width, nI=token_vector_width,
pieces=maxout_pieces) pieces=maxout_pieces)
lower = rebatch(1024, lower)
with Model.use_device('cpu'): with Model.use_device('cpu'):
upper = chain( upper = chain(
Maxout(hidden_width), Maxout(hidden_width),
@ -274,7 +276,7 @@ cdef class Parser:
def __reduce__(self): def __reduce__(self):
return (Parser, (self.vocab, self.moves, self.model), None, None) return (Parser, (self.vocab, self.moves, self.model), None, None)
def __call__(self, Doc tokens, state=None): def __call__(self, Doc doc):
""" """
Apply the parser or entity recognizer, setting the annotations onto the Doc object. Apply the parser or entity recognizer, setting the annotations onto the Doc object.
@ -283,10 +285,9 @@ cdef class Parser:
Returns: Returns:
None None
""" """
self.parse_batch([tokens], state['tokvecs']) self.parse_batch([doc], doc.tensor)
return state
def pipe(self, stream, int batch_size=1000, int n_threads=2): def pipe(self, docs, int batch_size=1000, int n_threads=2):
""" """
Process a stream of documents. Process a stream of documents.
@ -301,12 +302,11 @@ cdef class Parser:
cdef StateClass parse_state cdef StateClass parse_state
cdef Doc doc cdef Doc doc
queue = [] queue = []
for batch in cytoolz.partition_all(batch_size, stream): for docs in cytoolz.partition_all(batch_size, docs):
batch = list(batch) tokvecs = self.model[0].ops.flatten([d.tensor for d in docs])
docs, states = zip(*batch) parse_states = self.parse_batch(docs, tokvecs)
parse_states = self.parse_batch(docs, states[0]['tokvecs'])
self.set_annotations(docs, parse_states) self.set_annotations(docs, parse_states)
yield from zip(docs, states) yield from docs
def parse_batch(self, docs, tokvecs): def parse_batch(self, docs, tokvecs):
cuda_stream = get_cuda_stream() cuda_stream = get_cuda_stream()
@ -324,10 +324,8 @@ cdef class Parser:
todo = [st for st in states if not st.is_final()] todo = [st for st in states if not st.is_final()]
return states return states
def update(self, docs, golds, state=None, drop=0., sgd=None): def update(self, docs_tokvecs, golds, drop=0., sgd=None):
assert state is not None docs, tokvecs = docs_tokvecs
assert 'tokvecs' in state
assert 'bp_tokvecs' in state
if isinstance(docs, Doc) and isinstance(golds, GoldParse): if isinstance(docs, Doc) and isinstance(golds, GoldParse):
docs = [docs] docs = [docs]
golds = [golds] golds = [golds]
@ -336,9 +334,6 @@ cdef class Parser:
for gold in golds: for gold in golds:
self.moves.preprocess_gold(gold) self.moves.preprocess_gold(gold)
tokvecs = state['tokvecs']
bp_tokvecs = state['bp_tokvecs']
states = self.moves.init_batch(docs) states = self.moves.init_batch(docs)
state2vec, vec2scores = self.get_batch_model(len(states), tokvecs, cuda_stream, state2vec, vec2scores = self.get_batch_model(len(states), tokvecs, cuda_stream,
drop) drop)
@ -357,17 +352,17 @@ cdef class Parser:
d_scores = self.get_batch_loss(states, golds, scores) d_scores = self.get_batch_loss(states, golds, scores)
d_vector = bp_scores(d_scores, sgd=sgd) d_vector = bp_scores(d_scores, sgd=sgd)
loss += (d_scores**2).sum()
if not isinstance(tokvecs, state2vec.ops.xp.ndarray): if isinstance(self.model[0].ops, CupyOps) \
backprops.append((token_ids, d_vector, bp_vector)) and not isinstance(token_ids, state2vec.ops.xp.ndarray):
else:
# Move token_ids and d_vector to CPU, asynchronously # Move token_ids and d_vector to CPU, asynchronously
backprops.append(( backprops.append((
get_async(cuda_stream, token_ids), get_async(cuda_stream, token_ids),
get_async(cuda_stream, d_vector), get_async(cuda_stream, d_vector),
bp_vector bp_vector
)) ))
else:
backprops.append((token_ids, d_vector, bp_vector))
self.transition_batch(states, scores) self.transition_batch(states, scores)
todo = [st for st in todo if not st[0].is_final()] todo = [st for st in todo if not st[0].is_final()]
# Tells CUDA to block, so our async copies complete. # Tells CUDA to block, so our async copies complete.
@ -385,9 +380,7 @@ cdef class Parser:
else: else:
xp.add.at(d_tokvecs, xp.add.at(d_tokvecs,
token_ids, d_state_features * active_feats) token_ids, d_state_features * active_feats)
bp_tokvecs(d_tokvecs, sgd) return d_tokvecs
state['parser_loss'] = loss
return state
def get_batch_model(self, batch_size, tokvecs, stream, dropout): def get_batch_model(self, batch_size, tokvecs, stream, dropout):
lower, upper = self.model lower, upper = self.model
@ -445,7 +438,6 @@ cdef class Parser:
self.moves.finalize_doc(doc) self.moves.finalize_doc(doc)
def add_label(self, label): def add_label(self, label):
# Doesn't set label into serializer -- subclasses override it to do that.
for action in self.moves.action_types: for action in self.moves.action_types:
added = self.moves.add_action(action, label) added = self.moves.add_action(action, label)
if added: if added: