@staticmethod
def decode_utf8_predictions(char_array):
    """Decode per-byte score vectors into a list of word strings.

    char_array: Array-like with `shape`, `reshape` and `argmax`. It is
        reshaped to (n_words, n_slots, 256), so each row must hold a
        multiple of 256 scores; the highest-scoring index in each group of
        256 is taken as the predicted byte value for that slot.
    RETURNS (list): One decoded string per row of `char_array`.

    The slot layout alternates between filling from the start and from the
    end of the word: even slots hold bytes 0, 1, 2, ... and odd slots hold
    bytes -1, -2, -3, ... The value 255 marks a missing (padding) slot and
    is dropped.
    """
    # reshape() cannot infer the -1 dimension for a zero-size array, so
    # handle the "no words" case up front instead of raising.
    if char_array.shape[0] == 0:
        return []
    byte_ids = char_array.reshape((char_array.shape[0], -1, 256)).argmax(axis=-1)
    words = []
    for row in byte_ids:
        starts = [int(b) for b in row[::2] if b != 255]
        ends = [int(b) for b in row[1::2] if b != 255]
        # End slots are stored last-byte-first; flip them back into reading
        # order before joining them onto the start bytes.
        raw = bytearray(starts + ends[::-1])
        # The values are bytes of a UTF-8 encoding, so they must be decoded
        # together: mapping each byte through chr() would garble every
        # multi-byte character. Predicted (or clipped) sequences may be
        # invalid UTF-8, so substitute U+FFFD rather than raising.
        words.append(raw.decode("utf8", "replace"))
    return words