mirror of
https://github.com/explosion/spaCy.git
synced 2025-04-25 11:23:40 +03:00
Add assert_packed_msg_equal util function
This commit is contained in:
parent
9acf8686f7
commit
1ebd0d3f27
|
@ -10,6 +10,7 @@ import numpy
|
||||||
import tempfile
|
import tempfile
|
||||||
import shutil
|
import shutil
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import msgpack
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
@ -105,3 +106,13 @@ def assert_docs_equal(doc1, doc2):
|
||||||
assert [ t.ent_type for t in doc1 ] == [ t.ent_type for t in doc2 ]
|
assert [ t.ent_type for t in doc1 ] == [ t.ent_type for t in doc2 ]
|
||||||
assert [ t.ent_iob for t in doc1 ] == [ t.ent_iob for t in doc2 ]
|
assert [ t.ent_iob for t in doc1 ] == [ t.ent_iob for t in doc2 ]
|
||||||
assert [ ent for ent in doc1.ents ] == [ ent for ent in doc2.ents ]
|
assert [ ent for ent in doc1.ents ] == [ ent for ent in doc2.ents ]
|
||||||
|
|
||||||
|
|
||||||
|
def assert_packed_msg_equal(b1, b2):
|
||||||
|
"""Assert that two packed msgpack messages are equal."""
|
||||||
|
msg1 = msgpack.loads(b1, encoding='utf8')
|
||||||
|
msg2 = msgpack.loads(b2, encoding='utf8')
|
||||||
|
assert sorted(msg1.keys()) == sorted(msg2.keys())
|
||||||
|
for (k1, v1), (k2, v2) in zip(sorted(msg1.items()), sorted(msg2.items())):
|
||||||
|
assert k1 == k2
|
||||||
|
assert v1 == v2
|
||||||
|
|
Loading…
Reference in New Issue
Block a user