dont use get_array_module (#11056)

This commit is contained in:
kadarakos 2022-07-04 17:15:33 +02:00 committed by GitHub
parent e9eb59699f
commit 5240baccfe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -192,7 +192,7 @@ class TextCategorizer(TrainablePipe):
if not any(len(doc) for doc in docs):
# Handle cases where there are no tokens in any docs.
tensors = [doc.tensor for doc in docs]
xp = get_array_module(tensors)
xp = self.model.ops.xp
scores = xp.zeros((len(list(docs)), len(self.labels)))
return scores
scores = self.model.predict(docs)