Add method to decode predicted characters

This commit is contained in:
Matthew Honnibal 2019-10-21 03:56:15 +02:00
parent f2808f78a7
commit 5a272d9029

View File

@@ -908,6 +908,20 @@ class ClozeMultitask(Pipe):
if losses is not None: if losses is not None:
losses[self.name] += loss losses[self.name] += loss
@staticmethod
def decode_utf8_predictions(char_array):
    """Decode predicted byte scores back into word strings.

    Each row of `char_array` is viewed as consecutive groups of 256 scores;
    the highest-scoring index within a group is the predicted byte value
    for that slot. Slots alternate between the start and the end of the
    word (even slots fill from the start, odd slots from the end), and the
    value 255 marks a missing slot that is skipped.

    char_array: Array with one row per word and a multiple of 256 scores
        per row.
    RETURNS (list): One decoded string per row. Byte sequences that are
        not valid UTF-8 (for example a multi-byte character whose middle
        bytes are absent) are decoded with U+FFFD replacement characters
        instead of raising.
    """
    # A zero-row array cannot be reshaped with an inferred (-1) dimension,
    # so handle the empty batch up front.
    if char_array.shape[0] == 0:
        return []
    # The format alternates filling from start and end, and 255 is missing
    byte_ids = char_array.reshape((char_array.shape[0], -1, 256)).argmax(axis=-1)
    words = []
    for row in byte_ids:
        starts = [int(c) for c in row[::2] if c != 255]
        ends = [int(c) for c in row[1::2] if c != 255]
        # The predictions are UTF-8 *bytes*, so they have to be decoded as
        # one byte sequence. Mapping each byte through chr() would turn
        # every multi-byte character into Latin-1 mojibake.
        word = bytes(starts + ends[::-1]).decode("utf8", errors="replace")
        words.append(word)
    return words
class TextCategorizer(Pipe): class TextCategorizer(Pipe):
"""Pipeline component for text classification. """Pipeline component for text classification.