mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-15 14:17:58 +03:00
Add method to decode predicted characters
This commit is contained in:
parent
f2808f78a7
commit
5a272d9029
|
@ -908,6 +908,20 @@ class ClozeMultitask(Pipe):
|
||||||
if losses is not None:
|
if losses is not None:
|
||||||
losses[self.name] += loss
|
losses[self.name] += loss
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def decode_utf8_predictions(char_array):
|
||||||
|
# The format alternates filling from start and end, and 255 is missing
|
||||||
|
words = []
|
||||||
|
char_array = char_array.reshape((char_array.shape[0], -1, 256))
|
||||||
|
nr_char = char_array.shape[1]
|
||||||
|
char_array = char_array.argmax(axis=-1)
|
||||||
|
for row in char_array:
|
||||||
|
starts = [chr(c) for c in row[::2] if c != 255]
|
||||||
|
ends = [chr(c) for c in row[1::2] if c != 255]
|
||||||
|
word = "".join(starts + list(reversed(ends)))
|
||||||
|
words.append(word)
|
||||||
|
return words
|
||||||
|
|
||||||
|
|
||||||
class TextCategorizer(Pipe):
|
class TextCategorizer(Pipe):
|
||||||
"""Pipeline component for text classification.
|
"""Pipeline component for text classification.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user