mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 10:46:29 +03:00
Fix use_params and pipe methods
This commit is contained in:
parent
ca70b08661
commit
c2c825127a
|
@ -220,13 +220,19 @@ class Language(object):
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def use_params(self, params, **cfg):
|
def use_params(self, params, **cfg):
|
||||||
contexts = [pipe.model.use_params(params) for pipe
|
contexts = [pipe.use_params(params) for pipe
|
||||||
in self.pipeline if hasattr(pipe, 'model')
|
in self.pipeline if hasattr(pipe, 'use_params')]
|
||||||
and hasattr(pipe.model, 'use_params')]
|
# TODO: Having trouble with contextlib
|
||||||
|
# Workaround: these aren't actually context managers atm.
|
||||||
|
for context in contexts:
|
||||||
|
try:
|
||||||
|
next(context)
|
||||||
|
except StopIteration:
|
||||||
|
pass
|
||||||
yield
|
yield
|
||||||
for context in contexts:
|
for context in contexts:
|
||||||
try:
|
try:
|
||||||
next(context.gen)
|
next(context)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -242,7 +248,8 @@ class Language(object):
|
||||||
parse (bool)
|
parse (bool)
|
||||||
entity (bool)
|
entity (bool)
|
||||||
"""
|
"""
|
||||||
stream = ((self.make_doc(text), None) for text in texts)
|
#stream = ((self.make_doc(text), None) for text in texts)
|
||||||
|
stream = ((doc, {}) for doc in texts)
|
||||||
for proc in self.pipeline:
|
for proc in self.pipeline:
|
||||||
name = getattr(proc, 'name', None)
|
name = getattr(proc, 'name', None)
|
||||||
if name in disabled and not disabled[name]:
|
if name in disabled and not disabled[name]:
|
||||||
|
|
|
@ -61,8 +61,14 @@ class TokenVectorEncoder(object):
|
||||||
state['tokvecs'] = tokvecs
|
state['tokvecs'] = tokvecs
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def pipe(self, docs, **kwargs):
|
def pipe(self, stream, batch_size=128, n_threads=-1):
|
||||||
raise NotImplementedError
|
for batch in cytoolz.partition_all(batch_size, stream):
|
||||||
|
docs, states = zip(*batch)
|
||||||
|
tokvecs = self.predict(docs)
|
||||||
|
self.set_annotations(docs, tokvecs)
|
||||||
|
for state in states:
|
||||||
|
state['tokvecs'] = tokvecs
|
||||||
|
yield from zip(docs, states)
|
||||||
|
|
||||||
def predict(self, docs):
|
def predict(self, docs):
|
||||||
feats = self.doc2feats(docs)
|
feats = self.doc2feats(docs)
|
||||||
|
@ -96,6 +102,10 @@ class TokenVectorEncoder(object):
|
||||||
if self.model is True:
|
if self.model is True:
|
||||||
self.model = self.Model()
|
self.model = self.Model()
|
||||||
|
|
||||||
|
def use_params(self, params):
|
||||||
|
with self.model.use_params(params):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
class NeuralTagger(object):
|
class NeuralTagger(object):
|
||||||
name = 'nn_tagger'
|
name = 'nn_tagger'
|
||||||
|
@ -112,11 +122,13 @@ class NeuralTagger(object):
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def pipe(self, stream, batch_size=128, n_threads=-1):
|
def pipe(self, stream, batch_size=128, n_threads=-1):
|
||||||
for batch in cytoolz.partition_all(batch_size, batch):
|
for batch in cytoolz.partition_all(batch_size, stream):
|
||||||
docs, tokvecs = zip(*batch)
|
docs, states = zip(*batch)
|
||||||
tag_ids = self.predict(docs, tokvecs)
|
tag_ids = self.predict(states[0]['tokvecs'])
|
||||||
self.set_annotations(docs, tag_ids)
|
self.set_annotations(docs, tag_ids)
|
||||||
yield from docs
|
for state in states:
|
||||||
|
state['tag_ids'] = tag_ids
|
||||||
|
yield from zip(docs, states)
|
||||||
|
|
||||||
def predict(self, tokvecs):
|
def predict(self, tokvecs):
|
||||||
scores = self.model(tokvecs)
|
scores = self.model(tokvecs)
|
||||||
|
@ -130,7 +142,7 @@ class NeuralTagger(object):
|
||||||
docs = [docs]
|
docs = [docs]
|
||||||
cdef Doc doc
|
cdef Doc doc
|
||||||
cdef int idx = 0
|
cdef int idx = 0
|
||||||
cdef int i, j
|
cdef int i, j, tag_id
|
||||||
cdef Vocab vocab = self.vocab
|
cdef Vocab vocab = self.vocab
|
||||||
for i, doc in enumerate(docs):
|
for i, doc in enumerate(docs):
|
||||||
doc_tag_ids = batch_tag_ids[idx:idx+len(doc)]
|
doc_tag_ids = batch_tag_ids[idx:idx+len(doc)]
|
||||||
|
@ -147,7 +159,6 @@ class NeuralTagger(object):
|
||||||
self.model.nI = tokvecs.shape[1]
|
self.model.nI = tokvecs.shape[1]
|
||||||
|
|
||||||
tag_scores, bp_tag_scores = self.model.begin_update(tokvecs, drop=drop)
|
tag_scores, bp_tag_scores = self.model.begin_update(tokvecs, drop=drop)
|
||||||
|
|
||||||
loss, d_tag_scores = self.get_loss(docs, golds, tag_scores)
|
loss, d_tag_scores = self.get_loss(docs, golds, tag_scores)
|
||||||
|
|
||||||
d_tokvecs = bp_tag_scores(d_tag_scores, sgd=sgd)
|
d_tokvecs = bp_tag_scores(d_tag_scores, sgd=sgd)
|
||||||
|
@ -167,24 +178,33 @@ class NeuralTagger(object):
|
||||||
for tag in gold.tags:
|
for tag in gold.tags:
|
||||||
correct[idx] = tag_index[tag]
|
correct[idx] = tag_index[tag]
|
||||||
idx += 1
|
idx += 1
|
||||||
correct = self.model.ops.xp.array(correct)
|
correct = self.model.ops.xp.array(correct, dtype='i')
|
||||||
d_scores = scores - to_categorical(correct, nb_classes=scores.shape[1])
|
d_scores = scores - to_categorical(correct, nb_classes=scores.shape[1])
|
||||||
loss = (d_scores**2).sum()
|
loss = (d_scores**2).sum()
|
||||||
d_scores = self.model.ops.asarray(d_scores)
|
d_scores = self.model.ops.asarray(d_scores, dtype='f')
|
||||||
return loss, d_scores
|
return float(loss), d_scores
|
||||||
|
|
||||||
def begin_training(self, gold_tuples, pipeline=None):
|
def begin_training(self, gold_tuples, pipeline=None):
|
||||||
tag_map = dict(self.vocab.morphology.tag_map)
|
orig_tag_map = dict(self.vocab.morphology.tag_map)
|
||||||
|
new_tag_map = {}
|
||||||
for raw_text, annots_brackets in gold_tuples:
|
for raw_text, annots_brackets in gold_tuples:
|
||||||
for annots, brackets in annots_brackets:
|
for annots, brackets in annots_brackets:
|
||||||
ids, words, tags, heads, deps, ents = annots
|
ids, words, tags, heads, deps, ents = annots
|
||||||
for tag in tags:
|
for tag in tags:
|
||||||
if tag not in tag_map:
|
if tag in orig_tag_map:
|
||||||
tag_map[tag] = {POS: X}
|
new_tag_map[tag] = orig_tag_map[tag]
|
||||||
|
else:
|
||||||
|
new_tag_map[tag] = {POS: X}
|
||||||
cdef Vocab vocab = self.vocab
|
cdef Vocab vocab = self.vocab
|
||||||
vocab.morphology = Morphology(vocab.strings, tag_map,
|
vocab.morphology = Morphology(vocab.strings, new_tag_map,
|
||||||
vocab.morphology.lemmatizer)
|
vocab.morphology.lemmatizer)
|
||||||
self.model = Softmax(self.vocab.morphology.n_tags)
|
self.model = Softmax(self.vocab.morphology.n_tags)
|
||||||
|
print("Tagging", self.model.nO, "tags")
|
||||||
|
|
||||||
|
def use_params(self, params):
|
||||||
|
with self.model.use_params(params):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
cdef class EntityRecognizer(LinearParser):
|
cdef class EntityRecognizer(LinearParser):
|
||||||
|
|
|
@ -7,6 +7,7 @@ from __future__ import unicode_literals, print_function
|
||||||
|
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
import ujson
|
import ujson
|
||||||
|
import contextlib
|
||||||
|
|
||||||
from libc.math cimport exp
|
from libc.math cimport exp
|
||||||
cimport cython
|
cimport cython
|
||||||
|
@ -297,18 +298,15 @@ cdef class Parser:
|
||||||
The number of threads with which to work on the buffer in parallel.
|
The number of threads with which to work on the buffer in parallel.
|
||||||
Yields (Doc): Documents, in order.
|
Yields (Doc): Documents, in order.
|
||||||
"""
|
"""
|
||||||
cdef StateClass state
|
cdef StateClass parse_state
|
||||||
cdef Doc doc
|
cdef Doc doc
|
||||||
queue = []
|
queue = []
|
||||||
for batch in cytoolz.partition_all(batch_size, stream):
|
for batch in cytoolz.partition_all(batch_size, stream):
|
||||||
docs, tokvecs = zip(*batch)
|
batch = list(batch)
|
||||||
states = self.parse_batch(docs, tokvecs)
|
docs, states = zip(*batch)
|
||||||
for doc, state in zip(docs, states):
|
parse_states = self.parse_batch(docs, states[0]['tokvecs'])
|
||||||
self.moves.finalize_state(state.c)
|
self.set_annotations(docs, parse_states)
|
||||||
for i in range(doc.length):
|
yield from zip(docs, states)
|
||||||
doc.c[i] = state.c._sent[i]
|
|
||||||
self.moves.finalize_doc(doc)
|
|
||||||
yield doc
|
|
||||||
|
|
||||||
def parse_batch(self, docs, tokvecs):
|
def parse_batch(self, docs, tokvecs):
|
||||||
cuda_stream = get_cuda_stream()
|
cuda_stream = get_cuda_stream()
|
||||||
|
@ -324,7 +322,7 @@ cdef class Parser:
|
||||||
scores = vec2scores(vectors)
|
scores = vec2scores(vectors)
|
||||||
self.transition_batch(states, scores)
|
self.transition_batch(states, scores)
|
||||||
todo = [st for st in states if not st.is_final()]
|
todo = [st for st in states if not st.is_final()]
|
||||||
self.finish_batch(states, docs)
|
return states
|
||||||
|
|
||||||
def update(self, docs, golds, state=None, drop=0., sgd=None):
|
def update(self, docs, golds, state=None, drop=0., sgd=None):
|
||||||
assert state is not None
|
assert state is not None
|
||||||
|
@ -437,7 +435,7 @@ cdef class Parser:
|
||||||
c_d_scores += d_scores.shape[1]
|
c_d_scores += d_scores.shape[1]
|
||||||
return d_scores
|
return d_scores
|
||||||
|
|
||||||
def finish_batch(self, states, docs):
|
def set_annotations(self, docs, states):
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
cdef Doc doc
|
cdef Doc doc
|
||||||
for state, doc in zip(states, docs):
|
for state, doc in zip(states, docs):
|
||||||
|
@ -465,6 +463,12 @@ cdef class Parser:
|
||||||
if self.model is True:
|
if self.model is True:
|
||||||
self.model = self.Model(self.moves.n_moves, **cfg)
|
self.model = self.Model(self.moves.n_moves, **cfg)
|
||||||
|
|
||||||
|
def use_params(self, params):
|
||||||
|
# Can't decorate cdef class :(. Workaround.
|
||||||
|
with self.model[0].use_params(params):
|
||||||
|
with self.model[1].use_params(params):
|
||||||
|
yield
|
||||||
|
|
||||||
def to_disk(self, path):
|
def to_disk(self, path):
|
||||||
path = util.ensure_path(path)
|
path = util.ensure_path(path)
|
||||||
with (path / 'model.bin').open('wb') as file_:
|
with (path / 'model.bin').open('wb') as file_:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user