Mirror of https://github.com/explosion/spaCy.git (synced 2025-01-26 09:14:32 +03:00)
Fix loading of multiple pre-trained vectors
This patch addresses #1660, which was caused by keying all pre-trained vectors with the same ID when telling Thinc how to refer to them. As a result, loading multiple models that had pre-trained vectors produced errors or incorrect behaviour.

The Vectors class now includes a .name attribute, which defaults to:

    {nlp.meta['lang']}_{nlp.meta['name']}.vectors

The vectors name is set in the cfg of the pipeline components under the key pretrained_vectors, replacing the previous cfg key pretrained_dims. To keep existing models compatible with this change, we check for the pretrained_dims key when loading models in from_disk and from_bytes, and add the pretrained_vectors cfg key if we find it.
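As a rough sketch of the behaviour described above (illustrative only; the helper names below are hypothetical and not part of the spaCy API):

# Sketch only: mirrors the naming scheme and the backward-compatibility
# shim described in the commit message; not the actual spaCy implementation.

def default_vectors_name(meta):
    # Each model's vectors get their own key, so loading a second model
    # with vectors no longer overwrites the first model's entry in Thinc.
    return '%s_%s.vectors' % (meta['lang'], meta['name'])

def upgrade_component_cfg(cfg, vectors_name):
    # Older models stored 'pretrained_dims'; newer models store the vectors
    # name under 'pretrained_vectors'. Translate the old key on load.
    if cfg.get('pretrained_dims') and 'pretrained_vectors' not in cfg:
        cfg['pretrained_vectors'] = vectors_name
    return cfg

meta = {'lang': 'en', 'name': 'core_web_md'}
print(default_vectors_name(meta))          # en_core_web_md.vectors
print(upgrade_component_cfg({'pretrained_dims': 300}, 'en_core_web_md.vectors'))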
parent 070b6c6495
commit 95a9615221
spacy/_ml.py | 20 lines changed

@@ -242,6 +242,10 @@ class PrecomputableAffine(Model):


 def link_vectors_to_models(vocab):
     vectors = vocab.vectors
+    if vectors.name is None:
+        raise ValueError(
+            "Unnamed vectors -- this won't allow multiple vectors "
+            "models to be loaded. (Shape: (%d, %d))" % vectors.data.shape)
     ops = Model.ops
     for word in vocab:
         if word.orth in vectors.key2row:
@@ -251,11 +255,11 @@ def link_vectors_to_models(vocab):
     data = ops.asarray(vectors.data)
     # Set an entry here, so that vectors are accessed by StaticVectors
     # (unideal, I know)
-    thinc.extra.load_nlp.VECTORS[(ops.device, VECTORS_KEY)] = data
+    thinc.extra.load_nlp.VECTORS[(ops.device, vectors.name)] = data


 def Tok2Vec(width, embed_size, **kwargs):
-    pretrained_dims = kwargs.get('pretrained_dims', 0)
+    pretrained_vectors = kwargs.get('pretrained_vectors', None)
     cnn_maxout_pieces = kwargs.get('cnn_maxout_pieces', 2)
     cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH]
     with Model.define_operators({'>>': chain, '|': concatenate, '**': clone,
@@ -268,16 +272,16 @@ def Tok2Vec(width, embed_size, **kwargs):
                            name='embed_suffix')
         shape = HashEmbed(width, embed_size//2, column=cols.index(SHAPE),
                           name='embed_shape')
-        if pretrained_dims is not None and pretrained_dims >= 1:
-            glove = StaticVectors(VECTORS_KEY, width, column=cols.index(ID))
+        if pretrained_vectors is not None:
+            glove = StaticVectors(pretrained_vectors, width, column=cols.index(ID))

             embed = uniqued(
                 (glove | norm | prefix | suffix | shape)
-                >> LN(Maxout(width, width*5, pieces=3)), column=5)
+                >> LN(Maxout(width, width*5, pieces=3)), column=cols.index(ORTH))
         else:
             embed = uniqued(
                 (norm | prefix | suffix | shape)
-                >> LN(Maxout(width, width*4, pieces=3)), column=5)
+                >> LN(Maxout(width, width*4, pieces=3)), column=cols.index(ORTH))

         convolution = Residual(
             ExtractWindow(nW=1)
@@ -433,13 +437,13 @@ def build_tagger_model(nr_class, **cfg):
         token_vector_width = cfg['token_vector_width']
     else:
         token_vector_width = util.env_opt('token_vector_width', 128)
-    pretrained_dims = cfg.get('pretrained_dims', 0)
+    pretrained_vectors = cfg['pretrained_vectors']
     with Model.define_operators({'>>': chain, '+': add}):
         if 'tok2vec' in cfg:
             tok2vec = cfg['tok2vec']
         else:
             tok2vec = Tok2Vec(token_vector_width, embed_size,
-                              pretrained_dims=pretrained_dims)
+                              pretrained_vectors=pretrained_vectors)
         softmax = with_flatten(Softmax(nr_class, token_vector_width))
         model = (
             tok2vec
spacy/language.py

@@ -133,6 +133,8 @@ class Language(object):
         if vocab is True:
             factory = self.Defaults.create_vocab
             vocab = factory(self, **meta.get('vocab', {}))
+            if vocab.vectors.name is None:
+                vocab.vectors.name = meta.get('vectors', {}).get('name')
         self.vocab = vocab
         if make_doc is True:
             factory = self.Defaults.create_tokenizer
@@ -158,7 +160,8 @@ class Language(object):
         self._meta.setdefault('license', '')
         self._meta['vectors'] = {'width': self.vocab.vectors_length,
                                  'vectors': len(self.vocab.vectors),
-                                 'keys': self.vocab.vectors.n_keys}
+                                 'keys': self.vocab.vectors.n_keys,
+                                 'name': self.vocab.vectors.name}
         self._meta['pipeline'] = self.pipe_names
         return self._meta
@@ -457,6 +460,8 @@ class Language(object):
         else:
             device = None
         link_vectors_to_models(self.vocab)
+        if self.vocab.vectors.data.shape[1]:
+            cfg['pretrained_vectors'] = self.vocab.vectors.name
         if sgd is None:
             sgd = create_default_optimizer(Model.ops)
         self._optimizer = sgd
@@ -629,6 +634,7 @@ class Language(object):
             ('tokenizer', lambda p: self.tokenizer.from_disk(p, vocab=False)),
             ('meta.json', lambda p: self.meta.update(util.read_json(p)))
         ))
+        _fix_pretrained_vectors_name(self)
         for name, proc in self.pipeline:
             if name in disable:
                 continue
@@ -674,6 +680,7 @@ class Language(object):
             ('tokenizer', lambda b: self.tokenizer.from_bytes(b, vocab=False)),
             ('meta', lambda b: self.meta.update(ujson.loads(b)))
         ))
+        _fix_pretrained_vectors_name(self)
         for i, (name, proc) in enumerate(self.pipeline):
             if name in disable:
                 continue
@@ -683,6 +690,24 @@ class Language(object):
         msg = util.from_bytes(bytes_data, deserializers, {})
         return self


+def _fix_pretrained_vectors_name(nlp):
+    # TODO: Replace this once we handle vectors consistently as static
+    # data
+    if 'vectors' in nlp.meta and nlp.meta['vectors'].get('name'):
+        nlp.vocab.vectors.name = nlp.meta['vectors']['name']
+    elif 'name' in nlp.meta and 'lang' in nlp.meta:
+        vectors_name = '%s_%s.vectors' % (nlp.meta['lang'], nlp.meta['name'])
+        nlp.vocab.vectors.name = vectors_name
+    else:
+        raise ValueError("Unnamed vectors")
+    for name, proc in nlp.pipeline:
+        if not hasattr(proc, 'cfg'):
+            continue
+        if proc.cfg.get('pretrained_dims'):
+            assert nlp.vocab.vectors.name
+            proc.cfg['pretrained_vectors'] = nlp.vocab.vectors.name
+            print(proc.cfg)
+
+
 class DisabledPipes(list):
     """Manager for temporary pipeline disabling."""
spacy/pipeline.pyx

@@ -202,8 +202,10 @@ class Pipe(object):
     def from_bytes(self, bytes_data, **exclude):
         """Load the pipe from a bytestring."""
         def load_model(b):
+            # TODO: Remove this once we don't have to handle previous models
+            if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg:
+                self.cfg['pretrained_vectors'] = self.vocab.vectors.name
             if self.model is True:
-                self.cfg.setdefault('pretrained_dims', self.vocab.vectors_length)
                 self.model = self.Model(**self.cfg)
             self.model.from_bytes(b)
@@ -227,8 +229,10 @@ class Pipe(object):
     def from_disk(self, path, **exclude):
         """Load the pipe from disk."""
         def load_model(p):
+            # TODO: Remove this once we don't have to handle previous models
+            if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg:
+                self.cfg['pretrained_vectors'] = self.vocab.vectors.name
             if self.model is True:
-                self.cfg.setdefault('pretrained_dims', self.vocab.vectors_length)
                 self.model = self.Model(**self.cfg)
             self.model.from_bytes(p.open('rb').read())
@@ -286,7 +290,6 @@ class Tensorizer(Pipe):
         self.model = model
         self.input_models = []
         self.cfg = dict(cfg)
-        self.cfg['pretrained_dims'] = self.vocab.vectors.data.shape[1]
         self.cfg.setdefault('cnn_maxout_pieces', 3)

     def __call__(self, doc):
@@ -403,8 +406,6 @@ class Tagger(Pipe):
         self.model = model
         self.cfg = OrderedDict(sorted(cfg.items()))
         self.cfg.setdefault('cnn_maxout_pieces', 2)
-        self.cfg.setdefault('pretrained_dims',
-                            self.vocab.vectors.data.shape[1])

     @property
     def labels(self):
@@ -516,7 +517,6 @@ class Tagger(Pipe):
                 vocab.morphology.lemmatizer,
                 exc=vocab.morphology.exc)
         if self.model is True:
-            self.cfg['pretrained_dims'] = self.vocab.vectors.data.shape[1]
             self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg)
         link_vectors_to_models(self.vocab)
         if sgd is None:
@@ -525,6 +525,14 @@ class Tagger(Pipe):

     @classmethod
     def Model(cls, n_tags, **cfg):
+        if cfg.get('pretrained_dims') and not cfg.get('pretrained_vectors'):
+            raise ValueError(
+                "Bad configuration of Tagger --- this is probably a bug "
+                "within spaCy. We changed the name of an internal attribute "
+                "for loading pre-trained vectors, and the class has been "
+                "passed the old name (pretrained_dims) but not the new name "
+                "(pretrained_vectors)")
+        print(cfg)
         return build_tagger_model(n_tags, **cfg)

     def add_label(self, label, values=None):
@@ -572,6 +580,10 @@ class Tagger(Pipe):

     def from_bytes(self, bytes_data, **exclude):
         def load_model(b):
+            # TODO: Remove this once we don't have to handle previous models
+            if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg:
+                self.cfg['pretrained_vectors'] = self.vocab.vectors.name
+
             if self.model is True:
                 token_vector_width = util.env_opt(
                     'token_vector_width',
@@ -597,7 +609,6 @@ class Tagger(Pipe):
         return self

     def to_disk(self, path, **exclude):
-        self.cfg.setdefault('pretrained_dims', self.vocab.vectors.data.shape[1])
         tag_map = OrderedDict(sorted(self.vocab.morphology.tag_map.items()))
         serialize = OrderedDict((
             ('vocab', lambda p: self.vocab.to_disk(p)),
@@ -610,6 +621,9 @@ class Tagger(Pipe):

     def from_disk(self, path, **exclude):
         def load_model(p):
+            # TODO: Remove this once we don't have to handle previous models
+            if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg:
+                self.cfg['pretrained_vectors'] = self.vocab.vectors.name
             if self.model is True:
                 self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg)
             with p.open('rb') as file_:
@@ -659,8 +673,6 @@ class MultitaskObjective(Tagger):
                 "one of: dep, tag, ent, dep_tag_offset, ent_tag.")
         self.cfg = dict(cfg)
         self.cfg.setdefault('cnn_maxout_pieces', 2)
-        self.cfg.setdefault('pretrained_dims',
-                            self.vocab.vectors.data.shape[1])

     @property
     def labels(self):
@@ -904,7 +916,6 @@ class TextCategorizer(Pipe):
         else:
             token_vector_width = 64
         if self.model is True:
-            self.cfg['pretrained_dims'] = self.vocab.vectors_length
             self.model = self.Model(len(self.labels), token_vector_width,
                                     **self.cfg)
         link_vectors_to_models(self.vocab)
spacy/syntax/nn_parser.pyx

@@ -256,7 +256,7 @@ cdef class Parser:
         if hist_width != 0:
             raise ValueError("Currently history width is hard-coded to 0")
         tok2vec = Tok2Vec(token_vector_width, embed_size,
-                          pretrained_dims=cfg.get('pretrained_dims', 0))
+                          pretrained_vectors=cfg.get('pretrained_vectors', None))
         tok2vec = chain(tok2vec, flatten)
         lower = PrecomputableAffine(hidden_width,
                                     nF=cls.nr_feature, nI=token_vector_width,
@@ -294,9 +294,9 @@ cdef class Parser:
             unless True (default), in which case a new instance is created with
             `Parser.Moves()`.
         model (object): Defines how the parse-state is created, updated and
-            evaluated. The value is set to the .model attribute unless True
-            (default), in which case a new instance is created with
-            `Parser.Model()`.
+            evaluated. The value is set to the .model attribute. If set to True
+            (default), a new instance will be created with `Parser.Model()`
+            in parser.begin_training(), parser.from_disk() or parser.from_bytes().
         **cfg: Arbitrary configuration parameters. Set to the `.cfg` attribute
         """
         self.vocab = vocab
@@ -308,8 +308,6 @@ cdef class Parser:
             cfg['beam_width'] = util.env_opt('beam_width', 1)
         if 'beam_density' not in cfg:
             cfg['beam_density'] = util.env_opt('beam_density', 0.0)
-        if 'pretrained_dims' not in cfg:
-            cfg['pretrained_dims'] = self.vocab.vectors.data.shape[1]
         cfg.setdefault('cnn_maxout_pieces', 3)
         self.cfg = cfg
         if 'actions' in self.cfg:
@@ -832,7 +830,6 @@ cdef class Parser:
             self.moves.add_action(action, label)
         cfg.setdefault('token_vector_width', 128)
         if self.model is True:
-            cfg['pretrained_dims'] = self.vocab.vectors_length
             self.model, cfg = self.Model(self.moves.n_moves, **cfg)
         if sgd is None:
             sgd = self.create_optimizer()
@@ -896,9 +893,12 @@ cdef class Parser:
         }
         util.from_disk(path, deserializers, exclude)
         if 'model' not in exclude:
+            # TODO: Remove this once we don't have to handle previous models
+            if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg:
+                self.cfg['pretrained_vectors'] = self.vocab.vectors.name
+            print("Create parser model", self.cfg)
             path = util.ensure_path(path)
             if self.model is True:
-                self.cfg.setdefault('pretrained_dims', self.vocab.vectors_length)
                 self.model, cfg = self.Model(**self.cfg)
             else:
                 cfg = {}
@@ -941,12 +941,14 @@ cdef class Parser:
         ))
         msg = util.from_bytes(bytes_data, deserializers, exclude)
         if 'model' not in exclude:
+            # TODO: Remove this once we don't have to handle previous models
+            if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg:
+                self.cfg['pretrained_vectors'] = self.vocab.vectors.name
+            print("Create parser model", self.cfg)
             if self.model is True:
                 self.model, cfg = self.Model(**self.cfg)
-                cfg['pretrained_dims'] = self.vocab.vectors_length
             else:
                 cfg = {}
-            cfg['pretrained_dims'] = self.vocab.vectors_length
             if 'tok2vec_model' in msg:
                 self.model[0].from_bytes(msg['tok2vec_model'])
             if 'lower_model' in msg:
spacy/tests/conftest.py

@@ -19,7 +19,9 @@ _languages = ['bn', 'da', 'de', 'en', 'es', 'fi', 'fr', 'ga', 'he', 'hu', 'id',
 _models = {'en': ['en_core_web_sm'],
            'de': ['de_core_news_md'],
            'fr': ['fr_core_news_sm'],
-           'xx': ['xx_ent_web_md']}
+           'xx': ['xx_ent_web_md'],
+           'en_core_web_md': ['en_core_web_md'],
+           'es_core_news_md': ['es_core_news_md']}


 # only used for tests that require loading the models
@@ -183,6 +185,9 @@ def pytest_addoption(parser):

     for lang in _languages + ['all']:
         parser.addoption("--%s" % lang, action="store_true", help="Use %s models" % lang)
+    for model in _models:
+        if model not in _languages:
+            parser.addoption("--%s" % model, action="store_true", help="Use %s model" % model)


 def pytest_runtest_setup(item):
spacy/vectors.pyx

@@ -1,6 +1,7 @@
 # coding: utf8
 from __future__ import unicode_literals

+import functools
 import numpy
 from collections import OrderedDict
 import msgpack
@@ -19,6 +20,20 @@ def unpickle_vectors(bytes_data):
     return Vectors().from_bytes(bytes_data)


+class GlobalRegistry(object):
+    '''Global store of vectors, to avoid repeatedly loading the data.'''
+    data = {}
+
+    @classmethod
+    def register(cls, name, data):
+        cls.data[name] = data
+        return functools.partial(cls.get, name)
+
+    @classmethod
+    def get(cls, name):
+        return cls.data[name]
+
+
 cdef class Vectors:
     """Store, save and load word vectors.
@@ -31,18 +46,21 @@ cdef class Vectors:
     the table need to be assigned --- so len(list(vectors.keys())) may be
     greater or smaller than vectors.shape[0].
     """
+    cdef public object name
     cdef public object data
     cdef public object key2row
     cdef public object _unset

-    def __init__(self, *, shape=None, data=None, keys=None):
+    def __init__(self, *, shape=None, data=None, keys=None, name=None):
         """Create a new vector store.

         shape (tuple): Size of the table, as (# entries, # columns)
         data (numpy.ndarray): The vector data.
         keys (iterable): A sequence of keys, aligned with the data.
+        name (string): A name to identify the vectors table.
         RETURNS (Vectors): The newly created object.
         """
+        self.name = name
         if data is None:
             if shape is None:
                 shape = (0,0)
spacy/vocab.pyx

@@ -381,6 +381,7 @@ cdef class Vocab:
             self.lexemes_from_bytes(file_.read())
         if self.vectors is not None:
             self.vectors.from_disk(path, exclude='strings.json')
+        if self.vectors.name is not None:
             link_vectors_to_models(self)
         return self

@@ -421,6 +422,8 @@ cdef class Vocab:
             ('vectors', lambda b: serialize_vectors(b))
         ))
         util.from_bytes(bytes_data, setters, exclude)
+        if self.vectors.name is not None:
+            link_vectors_to_models(self)
         return self

     def lexemes_to_bytes(self):