mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-12 01:02:23 +03:00
Add method to decode predicted characters
This commit is contained in:
parent
cff1f5e48a
commit
29f8413095
|
@ -917,6 +917,20 @@ class ClozeMultitask(Pipe):
|
||||||
if losses is not None:
|
if losses is not None:
|
||||||
losses[self.name] += loss
|
losses[self.name] += loss
|
||||||
|
|
||||||
|
@staticmethod
def decode_utf8_predictions(char_array):
    """Decode per-position byte predictions into strings.

    char_array: Array with one row per word. Each row is reshaped into
        consecutive groups of 256 scores, one group per byte slot, and the
        highest-scoring index in a group is taken as the predicted byte.
    RETURNS (list): One decoded string per row of `char_array`.

    Byte slots alternate between the start and the end of the word: even
    slots are read left-to-right from the start of the word, odd slots are
    read right-to-left from the end. The value 255 marks an unused slot
    and is skipped.
    """
    missing = 255
    if char_array.shape[0] == 0:
        # A -1 dimension cannot be inferred for an empty batch, so the
        # reshape below would raise; there is nothing to decode anyway.
        return []
    byte_ids = char_array.reshape((char_array.shape[0], -1, 256))
    byte_ids = byte_ids.argmax(axis=-1)
    words = []
    for row in byte_ids:
        starts = [int(b) for b in row[::2] if b != missing]
        ends = [int(b) for b in row[1::2] if b != missing]
        # Odd slots were filled from the end of the word, so flip them back.
        raw = bytes(starts + ends[::-1])
        # Decode the whole byte sequence at once: mapping each byte through
        # chr() would treat the data as Latin-1 and garble multi-byte
        # characters. Predictions are not guaranteed to form valid UTF-8,
        # so undecodable bytes become U+FFFD instead of raising.
        words.append(raw.decode("utf8", errors="replace"))
    return words
|
||||||
|
|
||||||
|
|
||||||
@component("textcat", assigns=["doc.cats"])
|
@component("textcat", assigns=["doc.cats"])
|
||||||
class TextCategorizer(Pipe):
|
class TextCategorizer(Pipe):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user