mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
Fix GPU usage in Language
This commit is contained in:
parent
711ad5edc4
commit
2713041571
|
@ -3,6 +3,10 @@ from __future__ import absolute_import, unicode_literals
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
import dill
|
import dill
|
||||||
|
|
||||||
|
import numpy
|
||||||
|
from thinc.neural import Model
|
||||||
|
from thinc.neural.ops import NumpyOps, CupyOps
|
||||||
|
|
||||||
from .tokenizer import Tokenizer
|
from .tokenizer import Tokenizer
|
||||||
from .vocab import Vocab
|
from .vocab import Vocab
|
||||||
from .tagger import Tagger
|
from .tagger import Tagger
|
||||||
|
@ -179,11 +183,16 @@ class Language(object):
|
||||||
state = process.update(docs, golds,
|
state = process.update(docs, golds,
|
||||||
state=state,
|
state=state,
|
||||||
drop=drop,
|
drop=drop,
|
||||||
sgd=sgd)
|
sgd=get_grads)
|
||||||
else:
|
else:
|
||||||
process(docs, state=state)
|
process(docs, state=state)
|
||||||
if sgd is not None:
|
if sgd is not None:
|
||||||
for key, (W, dW) in grads.items():
|
for key, (W, dW) in grads.items():
|
||||||
|
# TODO: Unhack this when thinc improves
|
||||||
|
if isinstance(W, numpy.ndarray):
|
||||||
|
sgd.ops = NumpyOps()
|
||||||
|
else:
|
||||||
|
sgd.ops = CupyOps()
|
||||||
sgd(W, dW, key=key)
|
sgd(W, dW, key=key)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
@ -197,6 +206,10 @@ class Language(object):
|
||||||
# Handle crossing dependencies
|
# Handle crossing dependencies
|
||||||
gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
|
gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
|
||||||
contexts = []
|
contexts = []
|
||||||
|
if cfg.get('use_gpu'):
|
||||||
|
Model.ops = CupyOps()
|
||||||
|
Model.Ops = CupyOps
|
||||||
|
print("Use GPU")
|
||||||
for proc in self.pipeline:
|
for proc in self.pipeline:
|
||||||
if hasattr(proc, 'begin_training'):
|
if hasattr(proc, 'begin_training'):
|
||||||
context = proc.begin_training(gold_tuples,
|
context = proc.begin_training(gold_tuples,
|
||||||
|
@ -205,6 +218,18 @@ class Language(object):
|
||||||
trainer = Trainer(self, gold_tuples, **cfg)
|
trainer = Trainer(self, gold_tuples, **cfg)
|
||||||
yield trainer, trainer.optimizer
|
yield trainer, trainer.optimizer
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def use_params(self, params, **cfg):
|
||||||
|
contexts = [pipe.model.use_params(params) for pipe
|
||||||
|
in self.pipeline if hasattr(pipe, 'model')
|
||||||
|
and hasattr(pipe.model, 'use_params')]
|
||||||
|
yield
|
||||||
|
for context in contexts:
|
||||||
|
try:
|
||||||
|
next(context.gen)
|
||||||
|
except StopIteration:
|
||||||
|
pass
|
||||||
|
|
||||||
def pipe(self, texts, n_threads=2, batch_size=1000, **disabled):
|
def pipe(self, texts, n_threads=2, batch_size=1000, **disabled):
|
||||||
"""
|
"""
|
||||||
Process texts as a stream, and yield Doc objects in order.
|
Process texts as a stream, and yield Doc objects in order.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user