mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-04 13:40:34 +03:00
Fixes for thinc 6.7
This commit is contained in:
parent
53d00a0371
commit
4c97371051
|
@ -635,9 +635,9 @@ cdef class Parser:
|
||||||
def to_disk(self, path, **exclude):
|
def to_disk(self, path, **exclude):
|
||||||
serializers = {
|
serializers = {
|
||||||
'lower_model': lambda p: p.open('wb').write(
|
'lower_model': lambda p: p.open('wb').write(
|
||||||
util.model_to_bytes(self.model[0])),
|
self.model[0].to_bytes()),
|
||||||
'upper_model': lambda p: p.open('wb').write(
|
'upper_model': lambda p: p.open('wb').write(
|
||||||
util.model_to_bytes(self.model[1])),
|
self.model[1].to_bytes()),
|
||||||
'vocab': lambda p: self.vocab.to_disk(p),
|
'vocab': lambda p: self.vocab.to_disk(p),
|
||||||
'moves': lambda p: self.moves.to_disk(p, strings=False),
|
'moves': lambda p: self.moves.to_disk(p, strings=False),
|
||||||
'cfg': lambda p: p.open('w').write(json_dumps(self.cfg))
|
'cfg': lambda p: p.open('w').write(json_dumps(self.cfg))
|
||||||
|
@ -669,8 +669,8 @@ cdef class Parser:
|
||||||
|
|
||||||
def to_bytes(self, **exclude):
|
def to_bytes(self, **exclude):
|
||||||
serializers = {
|
serializers = {
|
||||||
'lower_model': lambda: util.model_to_bytes(self.model[0]),
|
'lower_model': lambda: self.model[0].to_bytes(),
|
||||||
'upper_model': lambda: util.model_to_bytes(self.model[1]),
|
'upper_model': lambda: self.model[1].to_bytes(),
|
||||||
'vocab': lambda: self.vocab.to_bytes(),
|
'vocab': lambda: self.vocab.to_bytes(),
|
||||||
'moves': lambda: self.moves.to_bytes(strings=False),
|
'moves': lambda: self.moves.to_bytes(strings=False),
|
||||||
'cfg': lambda: ujson.dumps(self.cfg)
|
'cfg': lambda: ujson.dumps(self.cfg)
|
||||||
|
@ -692,7 +692,7 @@ cdef class Parser:
|
||||||
else:
|
else:
|
||||||
cfg = {}
|
cfg = {}
|
||||||
self.model[0].from_bytes(msg['lower_model'])
|
self.model[0].from_bytes(msg['lower_model'])
|
||||||
util.model[1].from_bytes(msg['upper_model'])
|
self.model[1].from_bytes(msg['upper_model'])
|
||||||
self.cfg.update(cfg)
|
self.cfg.update(cfg)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,6 @@
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
from ..util import ensure_path
|
from ..util import ensure_path
|
||||||
from ..util import model_to_bytes, model_from_bytes
|
|
||||||
from .. import util
|
from .. import util
|
||||||
from ..displacy import parse_deps, parse_ents
|
from ..displacy import parse_deps, parse_ents
|
||||||
from ..tokens import Span
|
from ..tokens import Span
|
||||||
|
@ -20,41 +19,6 @@ def test_util_ensure_path_succeeds(text):
|
||||||
assert isinstance(path, Path)
|
assert isinstance(path, Path)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.models
|
|
||||||
def test_simple_model_roundtrip_bytes():
|
|
||||||
model = Maxout(5, 10, pieces=2)
|
|
||||||
model.b += 1
|
|
||||||
data = model_to_bytes(model)
|
|
||||||
model.b -= 1
|
|
||||||
model_from_bytes(model, data)
|
|
||||||
assert model.b[0, 0] == 1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.models
|
|
||||||
def test_multi_model_roundtrip_bytes():
|
|
||||||
model = chain(Maxout(5, 10, pieces=2), Maxout(2, 3))
|
|
||||||
model._layers[0].b += 1
|
|
||||||
model._layers[1].b += 2
|
|
||||||
data = model_to_bytes(model)
|
|
||||||
model._layers[0].b -= 1
|
|
||||||
model._layers[1].b -= 2
|
|
||||||
model_from_bytes(model, data)
|
|
||||||
assert model._layers[0].b[0, 0] == 1
|
|
||||||
assert model._layers[1].b[0, 0] == 2
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.models
|
|
||||||
def test_multi_model_load_missing_dims():
|
|
||||||
model = chain(Maxout(5, 10, pieces=2), Maxout(2, 3))
|
|
||||||
model._layers[0].b += 1
|
|
||||||
model._layers[1].b += 2
|
|
||||||
data = model_to_bytes(model)
|
|
||||||
|
|
||||||
model2 = chain(Maxout(5), Maxout())
|
|
||||||
model_from_bytes(model2, data)
|
|
||||||
assert model2._layers[0].b[0, 0] == 1
|
|
||||||
assert model2._layers[1].b[0, 0] == 2
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('package', ['numpy'])
|
@pytest.mark.parametrize('package', ['numpy'])
|
||||||
def test_util_is_package(package):
|
def test_util_is_package(package):
|
||||||
"""Test that an installed package via pip is recognised by util.is_package."""
|
"""Test that an installed package via pip is recognised by util.is_package."""
|
||||||
|
|
Loading…
Reference in New Issue
Block a user