mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-04 04:10:20 +03:00
Fix GPU inference speed problem
This commit is contained in:
parent
065aa062f0
commit
a6b1a068d5
|
@ -5,7 +5,7 @@ from itertools import islice
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import srsly
|
import srsly
|
||||||
from thinc.api import Config, Model
|
from thinc.api import Config, Model, NumpyOps
|
||||||
from thinc.api import SequenceCategoricalCrossentropy, L2Distance
|
from thinc.api import SequenceCategoricalCrossentropy, L2Distance
|
||||||
from thinc.types import Floats2d, Ints2d
|
from thinc.types import Floats2d, Ints2d
|
||||||
|
|
||||||
|
@ -118,6 +118,7 @@ class EditTreeLemmatizer(TrainablePipe):
|
||||||
|
|
||||||
self.cfg: Dict[str, Any] = {"labels": []}
|
self.cfg: Dict[str, Any] = {"labels": []}
|
||||||
self.scorer = scorer
|
self.scorer = scorer
|
||||||
|
self.numpy_ops = NumpyOps()
|
||||||
self.lowercasing = model.has_ref("lowercasing_output")
|
self.lowercasing = model.has_ref("lowercasing_output")
|
||||||
|
|
||||||
def get_loss(
|
def get_loss(
|
||||||
|
@ -206,6 +207,8 @@ class EditTreeLemmatizer(TrainablePipe):
|
||||||
):
|
):
|
||||||
guesses = []
|
guesses = []
|
||||||
for (i, doc, doc_scores) in zip(range(len(docs)), docs, scores):
|
for (i, doc, doc_scores) in zip(range(len(docs)), docs, scores):
|
||||||
|
if lowercasing_flags is not None:
|
||||||
|
doc_lowercasing_flags = self.numpy_ops.asarray(lowercasing_flags[i])
|
||||||
if self.top_k == 1:
|
if self.top_k == 1:
|
||||||
doc_guesses = doc_scores.argmax(axis=1).reshape(-1, 1)
|
doc_guesses = doc_scores.argmax(axis=1).reshape(-1, 1)
|
||||||
else:
|
else:
|
||||||
|
@ -216,7 +219,10 @@ class EditTreeLemmatizer(TrainablePipe):
|
||||||
|
|
||||||
doc_compat_guesses = []
|
doc_compat_guesses = []
|
||||||
for (j, token, candidates) in zip(range(len(doc)), doc, doc_guesses):
|
for (j, token, candidates) in zip(range(len(doc)), doc, doc_guesses):
|
||||||
if lowercasing_flags is not None and lowercasing_flags[i][j] > 0.5:
|
if (
|
||||||
|
lowercasing_flags is not None
|
||||||
|
and doc_lowercasing_flags[j] > 0.5 # type:ignore
|
||||||
|
):
|
||||||
to_lowercase = 1
|
to_lowercase = 1
|
||||||
text = token.lower_
|
text = token.lower_
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user