mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-15 06:09:01 +03:00
Add method to decode predicted characters
This commit is contained in:
parent
f2808f78a7
commit
5a272d9029
|
@ -908,6 +908,20 @@ class ClozeMultitask(Pipe):
|
|||
if losses is not None:
|
||||
losses[self.name] += loss
|
||||
|
||||
@staticmethod
|
||||
def decode_utf8_predictions(char_array):
|
||||
# The format alternates filling from start and end, and 255 is missing
|
||||
words = []
|
||||
char_array = char_array.reshape((char_array.shape[0], -1, 256))
|
||||
nr_char = char_array.shape[1]
|
||||
char_array = char_array.argmax(axis=-1)
|
||||
for row in char_array:
|
||||
starts = [chr(c) for c in row[::2] if c != 255]
|
||||
ends = [chr(c) for c in row[1::2] if c != 255]
|
||||
word = "".join(starts + list(reversed(ends)))
|
||||
words.append(word)
|
||||
return words
|
||||
|
||||
|
||||
class TextCategorizer(Pipe):
|
||||
"""Pipeline component for text classification.
|
||||
|
|
Loading…
Reference in New Issue
Block a user