mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 09:26:27 +03:00
Fix serialization when pre-trained vectors
This commit is contained in:
parent
980fb6e854
commit
05596159bf
|
@ -145,8 +145,8 @@ class BaseThincComponent(object):
|
||||||
|
|
||||||
deserialize = OrderedDict((
|
deserialize = OrderedDict((
|
||||||
('cfg', lambda b: self.cfg.update(ujson.loads(b))),
|
('cfg', lambda b: self.cfg.update(ujson.loads(b))),
|
||||||
('model', load_model),
|
|
||||||
('vocab', lambda b: self.vocab.from_bytes(b))
|
('vocab', lambda b: self.vocab.from_bytes(b))
|
||||||
|
('model', load_model),
|
||||||
))
|
))
|
||||||
util.from_bytes(bytes_data, deserialize, exclude)
|
util.from_bytes(bytes_data, deserialize, exclude)
|
||||||
return self
|
return self
|
||||||
|
@ -154,8 +154,8 @@ class BaseThincComponent(object):
|
||||||
def to_disk(self, path, **exclude):
|
def to_disk(self, path, **exclude):
|
||||||
serialize = OrderedDict((
|
serialize = OrderedDict((
|
||||||
('cfg', lambda p: p.open('w').write(json_dumps(self.cfg))),
|
('cfg', lambda p: p.open('w').write(json_dumps(self.cfg))),
|
||||||
|
('vocab', lambda p: self.vocab.to_disk(p)),
|
||||||
('model', lambda p: p.open('wb').write(self.model.to_bytes())),
|
('model', lambda p: p.open('wb').write(self.model.to_bytes())),
|
||||||
('vocab', lambda p: self.vocab.to_disk(p))
|
|
||||||
))
|
))
|
||||||
util.to_disk(path, serialize, exclude)
|
util.to_disk(path, serialize, exclude)
|
||||||
|
|
||||||
|
@ -168,8 +168,8 @@ class BaseThincComponent(object):
|
||||||
|
|
||||||
deserialize = OrderedDict((
|
deserialize = OrderedDict((
|
||||||
('cfg', lambda p: self.cfg.update(_load_cfg(p))),
|
('cfg', lambda p: self.cfg.update(_load_cfg(p))),
|
||||||
('model', load_model),
|
|
||||||
('vocab', lambda p: self.vocab.from_disk(p)),
|
('vocab', lambda p: self.vocab.from_disk(p)),
|
||||||
|
('model', load_model),
|
||||||
))
|
))
|
||||||
util.from_disk(path, deserialize, exclude)
|
util.from_disk(path, deserialize, exclude)
|
||||||
return self
|
return self
|
||||||
|
@ -289,6 +289,7 @@ class TokenVectorEncoder(BaseThincComponent):
|
||||||
pipeline (list): The pipeline the model is part of.
|
pipeline (list): The pipeline the model is part of.
|
||||||
"""
|
"""
|
||||||
if self.model is True:
|
if self.model is True:
|
||||||
|
self.cfg['pretrained_dims'] = self.vocab.vectors_length
|
||||||
self.model = self.Model(**self.cfg)
|
self.model = self.Model(**self.cfg)
|
||||||
link_vectors_to_models(self.vocab)
|
link_vectors_to_models(self.vocab)
|
||||||
|
|
||||||
|
@ -398,6 +399,7 @@ class NeuralTagger(BaseThincComponent):
|
||||||
vocab.morphology.lemmatizer,
|
vocab.morphology.lemmatizer,
|
||||||
exc=vocab.morphology.exc)
|
exc=vocab.morphology.exc)
|
||||||
if self.model is True:
|
if self.model is True:
|
||||||
|
self.cfg['pretrained_dims'] = self.vocab.vectors.data.shape[1]
|
||||||
self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg)
|
self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg)
|
||||||
link_vectors_to_models(self.vocab)
|
link_vectors_to_models(self.vocab)
|
||||||
|
|
||||||
|
@ -486,6 +488,7 @@ class NeuralLabeller(NeuralTagger):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.cfg = dict(cfg)
|
self.cfg = dict(cfg)
|
||||||
self.cfg.setdefault('cnn_maxout_pieces', 2)
|
self.cfg.setdefault('cnn_maxout_pieces', 2)
|
||||||
|
self.cfg.setdefault('pretrained_dims', self.vocab.vectors.data.shape[1])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def labels(self):
|
def labels(self):
|
||||||
|
@ -508,8 +511,8 @@ class NeuralLabeller(NeuralTagger):
|
||||||
self.labels[dep] = len(self.labels)
|
self.labels[dep] = len(self.labels)
|
||||||
token_vector_width = pipeline[0].model.nO
|
token_vector_width = pipeline[0].model.nO
|
||||||
if self.model is True:
|
if self.model is True:
|
||||||
self.model = self.Model(len(self.labels), token_vector_width=token_vector_width,
|
self.cfg['pretrained_dims'] = self.vocab.vectors.data.shape[1]
|
||||||
pretrained_dims=self.vocab.vectors_length)
|
self.model = self.Model(len(self.labels), **self.cfg)
|
||||||
link_vectors_to_models(self.vocab)
|
link_vectors_to_models(self.vocab)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
Loading…
Reference in New Issue
Block a user