mirror of
https://github.com/explosion/spaCy.git
synced 2025-04-03 16:54:12 +03:00
Add method to decode predicted characters
This commit is contained in:
parent
cff1f5e48a
commit
29f8413095
|
@ -917,6 +917,20 @@ class ClozeMultitask(Pipe):
|
|||
if losses is not None:
|
||||
losses[self.name] += loss
|
||||
|
||||
@staticmethod
def decode_utf8_predictions(char_array):
    """Decode predicted character scores back into word strings.

    char_array: Array whose first axis indexes words. It is reshaped to
        (n_words, n_slots, 256), i.e. 256 scores per character slot, and
        the highest-scoring class of each slot is taken as a byte value.

    Returns a list with one decoded string per word.

    The slot layout alternates filling from the start and the end of the
    word: even slots hold bytes counted from the front, odd slots hold
    bytes counted from the back. The value 255 marks a missing (padding)
    slot and is dropped.
    """
    # An empty batch cannot be reshaped with an inferred (-1) axis.
    if char_array.shape[0] == 0:
        return []
    scores = char_array.reshape((char_array.shape[0], -1, 256))
    byte_ids = scores.argmax(axis=-1)
    words = []
    for row in byte_ids:
        head = bytes(int(c) for c in row[::2] if c != 255)
        tail = bytes(int(c) for c in row[1::2] if c != 255)
        # Odd slots were filled back-to-front, so flip them before joining.
        raw = head + tail[::-1]
        # Decode the assembled bytes as UTF-8 rather than mapping each byte
        # through chr(), which would garble multi-byte characters. Invalid
        # or truncated sequences become U+FFFD instead of raising.
        words.append(raw.decode("utf8", errors="replace"))
    return words
|
||||
|
||||
|
||||
@component("textcat", assigns=["doc.cats"])
|
||||
class TextCategorizer(Pipe):
|
||||
|
|
Loading…
Reference in New Issue
Block a user