mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Fixes for thinc 6.7
This commit is contained in:
parent
53d00a0371
commit
4c97371051
|
@ -635,9 +635,9 @@ cdef class Parser:
|
|||
def to_disk(self, path, **exclude):
|
||||
serializers = {
|
||||
'lower_model': lambda p: p.open('wb').write(
|
||||
util.model_to_bytes(self.model[0])),
|
||||
self.model[0].to_bytes()),
|
||||
'upper_model': lambda p: p.open('wb').write(
|
||||
util.model_to_bytes(self.model[1])),
|
||||
self.model[1].to_bytes()),
|
||||
'vocab': lambda p: self.vocab.to_disk(p),
|
||||
'moves': lambda p: self.moves.to_disk(p, strings=False),
|
||||
'cfg': lambda p: p.open('w').write(json_dumps(self.cfg))
|
||||
|
@ -669,8 +669,8 @@ cdef class Parser:
|
|||
|
||||
def to_bytes(self, **exclude):
|
||||
serializers = {
|
||||
'lower_model': lambda: util.model_to_bytes(self.model[0]),
|
||||
'upper_model': lambda: util.model_to_bytes(self.model[1]),
|
||||
'lower_model': lambda: self.model[0].to_bytes(),
|
||||
'upper_model': lambda: self.model[1].to_bytes(),
|
||||
'vocab': lambda: self.vocab.to_bytes(),
|
||||
'moves': lambda: self.moves.to_bytes(strings=False),
|
||||
'cfg': lambda: ujson.dumps(self.cfg)
|
||||
|
@ -692,7 +692,7 @@ cdef class Parser:
|
|||
else:
|
||||
cfg = {}
|
||||
self.model[0].from_bytes(msg['lower_model'])
|
||||
util.model[1].from_bytes(msg['upper_model'])
|
||||
self.model[1].from_bytes(msg['upper_model'])
|
||||
self.cfg.update(cfg)
|
||||
return self
|
||||
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
from ..util import ensure_path
|
||||
from ..util import model_to_bytes, model_from_bytes
|
||||
from .. import util
|
||||
from ..displacy import parse_deps, parse_ents
|
||||
from ..tokens import Span
|
||||
|
@ -20,41 +19,6 @@ def test_util_ensure_path_succeeds(text):
|
|||
assert isinstance(path, Path)
|
||||
|
||||
|
||||
@pytest.mark.models
|
||||
def test_simple_model_roundtrip_bytes():
|
||||
model = Maxout(5, 10, pieces=2)
|
||||
model.b += 1
|
||||
data = model_to_bytes(model)
|
||||
model.b -= 1
|
||||
model_from_bytes(model, data)
|
||||
assert model.b[0, 0] == 1
|
||||
|
||||
|
||||
@pytest.mark.models
|
||||
def test_multi_model_roundtrip_bytes():
|
||||
model = chain(Maxout(5, 10, pieces=2), Maxout(2, 3))
|
||||
model._layers[0].b += 1
|
||||
model._layers[1].b += 2
|
||||
data = model_to_bytes(model)
|
||||
model._layers[0].b -= 1
|
||||
model._layers[1].b -= 2
|
||||
model_from_bytes(model, data)
|
||||
assert model._layers[0].b[0, 0] == 1
|
||||
assert model._layers[1].b[0, 0] == 2
|
||||
|
||||
|
||||
@pytest.mark.models
|
||||
def test_multi_model_load_missing_dims():
|
||||
model = chain(Maxout(5, 10, pieces=2), Maxout(2, 3))
|
||||
model._layers[0].b += 1
|
||||
model._layers[1].b += 2
|
||||
data = model_to_bytes(model)
|
||||
|
||||
model2 = chain(Maxout(5), Maxout())
|
||||
model_from_bytes(model2, data)
|
||||
assert model2._layers[0].b[0, 0] == 1
|
||||
assert model2._layers[1].b[0, 0] == 2
|
||||
|
||||
@pytest.mark.parametrize('package', ['numpy'])
|
||||
def test_util_is_package(package):
|
||||
"""Test that an installed package via pip is recognised by util.is_package."""
|
||||
|
|
Loading…
Reference in New Issue
Block a user