mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Fix serialization of extension attr values in DocBin (#4540)
This commit is contained in:
parent
df293f3894
commit
a90025b277
|
@ -1,11 +1,9 @@
|
||||||
# coding: utf8
|
# coding: utf8
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import pytest
|
|
||||||
from spacy.tokens import Doc, DocBin
|
from spacy.tokens import Doc, DocBin
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.xfail
|
|
||||||
def test_issue4528(en_vocab):
|
def test_issue4528(en_vocab):
|
||||||
"""Test that user_data is correctly serialized in DocBin."""
|
"""Test that user_data is correctly serialized in DocBin."""
|
||||||
doc = Doc(en_vocab, words=["hello", "world"])
|
doc = Doc(en_vocab, words=["hello", "world"])
|
||||||
|
|
|
@ -103,7 +103,8 @@ class DocBin(object):
|
||||||
doc = Doc(vocab, words=words, spaces=spaces)
|
doc = Doc(vocab, words=words, spaces=spaces)
|
||||||
doc = doc.from_array(self.attrs, tokens)
|
doc = doc.from_array(self.attrs, tokens)
|
||||||
if self.store_user_data:
|
if self.store_user_data:
|
||||||
doc.user_data.update(srsly.msgpack_loads(self.user_data[i]))
|
user_data = srsly.msgpack_loads(self.user_data[i], use_list=False)
|
||||||
|
doc.user_data.update(user_data)
|
||||||
yield doc
|
yield doc
|
||||||
|
|
||||||
def merge(self, other):
|
def merge(self, other):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user