Fix loading of multiple pre-trained vectors

This patch addresses #1660, which was caused by keying all pre-trained
vectors with the same ID when telling Thinc how to refer to them. This
meant that if multiple models with pre-trained vectors were loaded,
errors or incorrect behaviour resulted.

The Vectors class now includes a .name attribute, which defaults to:
{nlp.meta['lang']}_{nlp.meta['name']}.vectors
The vectors name is set in the cfg of the pipeline components under the
key pretrained_vectors. This replaces the previous cfg key
pretrained_dims.

In order to make existing models compatible with this change, we check
for the pretrained_dims key when loading models in from_disk and
from_bytes, and add the cfg key pretrained_vectors if we find it.
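For illustration (not part of the diff below), the intended behaviour looks roughly like this; the model name en_core_web_md is only an example:

    import spacy

    nlp = spacy.load('en_core_web_md')   # any model that ships pre-trained vectors
    # Each vectors table now carries its own identifier, e.g. "en_core_web_md.vectors"
    print(nlp.vocab.vectors.name)
    # Pipeline components refer to the table by name rather than by its width
    print(nlp.get_pipe('tagger').cfg.get('pretrained_vectors'))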
Author: Matthew Honnibal
Date:   2018-03-28 16:02:59 +02:00
Parent: 070b6c6495
Commit: 95a9615221
7 changed files with 100 additions and 32 deletions


@ -242,6 +242,10 @@ class PrecomputableAffine(Model):
def link_vectors_to_models(vocab):
vectors = vocab.vectors
if vectors.name is None:
raise ValueError(
"Unnamed vectors -- this won't allow multiple vectors "
"models to be loaded. (Shape: (%d, %d))" % vectors.data.shape)
ops = Model.ops
for word in vocab:
if word.orth in vectors.key2row:
@ -251,11 +255,11 @@ def link_vectors_to_models(vocab):
data = ops.asarray(vectors.data)
# Set an entry here, so that vectors are accessed by StaticVectors
# (unideal, I know)
thinc.extra.load_nlp.VECTORS[(ops.device, VECTORS_KEY)] = data
thinc.extra.load_nlp.VECTORS[(ops.device, vectors.name)] = data
def Tok2Vec(width, embed_size, **kwargs):
pretrained_dims = kwargs.get('pretrained_dims', 0)
pretrained_vectors = kwargs.get('pretrained_vectors', None)
cnn_maxout_pieces = kwargs.get('cnn_maxout_pieces', 2)
cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH]
with Model.define_operators({'>>': chain, '|': concatenate, '**': clone,
@ -268,16 +272,16 @@ def Tok2Vec(width, embed_size, **kwargs):
name='embed_suffix')
shape = HashEmbed(width, embed_size//2, column=cols.index(SHAPE),
name='embed_shape')
if pretrained_dims is not None and pretrained_dims >= 1:
glove = StaticVectors(VECTORS_KEY, width, column=cols.index(ID))
if pretrained_vectors is not None:
glove = StaticVectors(pretrained_vectors, width, column=cols.index(ID))
embed = uniqued(
(glove | norm | prefix | suffix | shape)
>> LN(Maxout(width, width*5, pieces=3)), column=5)
>> LN(Maxout(width, width*5, pieces=3)), column=cols.index(ORTH))
else:
embed = uniqued(
(norm | prefix | suffix | shape)
>> LN(Maxout(width, width*4, pieces=3)), column=5)
>> LN(Maxout(width, width*4, pieces=3)), column=cols.index(ORTH))
convolution = Residual(
ExtractWindow(nW=1)
@ -433,13 +437,13 @@ def build_tagger_model(nr_class, **cfg):
token_vector_width = cfg['token_vector_width']
else:
token_vector_width = util.env_opt('token_vector_width', 128)
pretrained_dims = cfg.get('pretrained_dims', 0)
pretrained_vectors = cfg['pretrained_vectors']
with Model.define_operators({'>>': chain, '+': add}):
if 'tok2vec' in cfg:
tok2vec = cfg['tok2vec']
else:
tok2vec = Tok2Vec(token_vector_width, embed_size,
pretrained_dims=pretrained_dims)
pretrained_vectors=pretrained_vectors)
softmax = with_flatten(Softmax(nr_class, token_vector_width))
model = (
tok2vec

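The effect of keying Thinc's vector store by name instead of the shared VECTORS_KEY can be sketched as follows; the device string and table names are illustrative:

    import numpy
    import thinc.extra.load_nlp

    device = 'cpu'   # the real code uses Model.ops.device
    en_data = numpy.zeros((1000, 300), dtype='f')
    de_data = numpy.zeros((500, 300), dtype='f')
    # Two loaded models can now register their tables side by side, where
    # previously the second would silently overwrite the first.
    thinc.extra.load_nlp.VECTORS[(device, 'en_core_web_md.vectors')] = en_data
    thinc.extra.load_nlp.VECTORS[(device, 'de_core_news_md.vectors')] = de_data
    assert len(thinc.extra.load_nlp.VECTORS) == 2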

@ -133,6 +133,8 @@ class Language(object):
if vocab is True:
factory = self.Defaults.create_vocab
vocab = factory(self, **meta.get('vocab', {}))
if vocab.vectors.name is None:
vocab.vectors.name = meta.get('vectors', {}).get('name')
self.vocab = vocab
if make_doc is True:
factory = self.Defaults.create_tokenizer
@ -158,7 +160,8 @@ class Language(object):
self._meta.setdefault('license', '')
self._meta['vectors'] = {'width': self.vocab.vectors_length,
'vectors': len(self.vocab.vectors),
'keys': self.vocab.vectors.n_keys}
'keys': self.vocab.vectors.n_keys,
'name': self.vocab.vectors.name}
self._meta['pipeline'] = self.pipe_names
return self._meta
@ -457,6 +460,8 @@ class Language(object):
else:
device = None
link_vectors_to_models(self.vocab)
if self.vocab.vectors.data.shape[1]:
cfg['pretrained_vectors'] = self.vocab.vectors.name
if sgd is None:
sgd = create_default_optimizer(Model.ops)
self._optimizer = sgd
@ -629,6 +634,7 @@ class Language(object):
('tokenizer', lambda p: self.tokenizer.from_disk(p, vocab=False)),
('meta.json', lambda p: self.meta.update(util.read_json(p)))
))
_fix_pretrained_vectors_name(self)
for name, proc in self.pipeline:
if name in disable:
continue
@ -674,6 +680,7 @@ class Language(object):
('tokenizer', lambda b: self.tokenizer.from_bytes(b, vocab=False)),
('meta', lambda b: self.meta.update(ujson.loads(b)))
))
_fix_pretrained_vectors_name(self)
for i, (name, proc) in enumerate(self.pipeline):
if name in disable:
continue
@ -683,6 +690,24 @@ class Language(object):
msg = util.from_bytes(bytes_data, deserializers, {})
return self
def _fix_pretrained_vectors_name(nlp):
# TODO: Replace this once we handle vectors consistently as static
# data
if 'vectors' in nlp.meta and nlp.meta['vectors'].get('name'):
nlp.vocab.vectors.name = nlp.meta['vectors']['name']
elif 'name' in nlp.meta and 'lang' in nlp.meta:
vectors_name = '%s_%s.vectors' % (nlp.meta['lang'], nlp.meta['name'])
nlp.vocab.vectors.name = vectors_name
else:
raise ValueError("Unnamed vectors")
for name, proc in nlp.pipeline:
if not hasattr(proc, 'cfg'):
continue
if proc.cfg.get('pretrained_dims'):
assert nlp.vocab.vectors.name
proc.cfg['pretrained_vectors'] = nlp.vocab.vectors.name
print(proc.cfg)
class DisabledPipes(list):
"""Manager for temporary pipeline disabling."""

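The fallback branch of _fix_pretrained_vectors_name derives the name from the model's meta; for example (meta values are illustrative):

    meta = {'lang': 'en', 'name': 'core_web_md'}   # read from the model's meta.json
    vectors_name = '%s_%s.vectors' % (meta['lang'], meta['name'])
    assert vectors_name == 'en_core_web_md.vectors'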

@ -202,8 +202,10 @@ class Pipe(object):
def from_bytes(self, bytes_data, **exclude):
"""Load the pipe from a bytestring."""
def load_model(b):
# TODO: Remove this once we don't have to handle previous models
if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg:
self.cfg['pretrained_vectors'] = self.vocab.vectors.name
if self.model is True:
self.cfg.setdefault('pretrained_dims', self.vocab.vectors_length)
self.model = self.Model(**self.cfg)
self.model.from_bytes(b)
@ -227,8 +229,10 @@ class Pipe(object):
def from_disk(self, path, **exclude):
"""Load the pipe from disk."""
def load_model(p):
# TODO: Remove this once we don't have to handle previous models
if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg:
self.cfg['pretrained_vectors'] = self.vocab.vectors.name
if self.model is True:
self.cfg.setdefault('pretrained_dims', self.vocab.vectors_length)
self.model = self.Model(**self.cfg)
self.model.from_bytes(p.open('rb').read())
@ -286,7 +290,6 @@ class Tensorizer(Pipe):
self.model = model
self.input_models = []
self.cfg = dict(cfg)
self.cfg['pretrained_dims'] = self.vocab.vectors.data.shape[1]
self.cfg.setdefault('cnn_maxout_pieces', 3)
def __call__(self, doc):
@ -403,8 +406,6 @@ class Tagger(Pipe):
self.model = model
self.cfg = OrderedDict(sorted(cfg.items()))
self.cfg.setdefault('cnn_maxout_pieces', 2)
self.cfg.setdefault('pretrained_dims',
self.vocab.vectors.data.shape[1])
@property
def labels(self):
@ -516,7 +517,6 @@ class Tagger(Pipe):
vocab.morphology.lemmatizer,
exc=vocab.morphology.exc)
if self.model is True:
self.cfg['pretrained_dims'] = self.vocab.vectors.data.shape[1]
self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg)
link_vectors_to_models(self.vocab)
if sgd is None:
@ -525,6 +525,14 @@ class Tagger(Pipe):
@classmethod
def Model(cls, n_tags, **cfg):
if cfg.get('pretrained_dims') and not cfg.get('pretrained_vectors'):
raise ValueError(
"Bad configuration of Tagger --- this is probably a bug "
"within spaCy. We changed the name of an internal attribute "
"for loading pre-trained vectors, and the class has been "
"passed the old name (pretrained_dims) but not the new name "
"(pretrained_vectors)")
print(cfg)
return build_tagger_model(n_tags, **cfg)
def add_label(self, label, values=None):
@ -572,6 +580,10 @@ class Tagger(Pipe):
def from_bytes(self, bytes_data, **exclude):
def load_model(b):
# TODO: Remove this once we don't have to handle previous models
if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg:
self.cfg['pretrained_vectors'] = self.vocab.vectors.name
if self.model is True:
token_vector_width = util.env_opt(
'token_vector_width',
@ -597,7 +609,6 @@ class Tagger(Pipe):
return self
def to_disk(self, path, **exclude):
self.cfg.setdefault('pretrained_dims', self.vocab.vectors.data.shape[1])
tag_map = OrderedDict(sorted(self.vocab.morphology.tag_map.items()))
serialize = OrderedDict((
('vocab', lambda p: self.vocab.to_disk(p)),
@ -610,6 +621,9 @@ class Tagger(Pipe):
def from_disk(self, path, **exclude):
def load_model(p):
# TODO: Remove this once we don't have to handle previous models
if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg:
self.cfg['pretrained_vectors'] = self.vocab.vectors.name
if self.model is True:
self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg)
with p.open('rb') as file_:
@ -659,8 +673,6 @@ class MultitaskObjective(Tagger):
"one of: dep, tag, ent, dep_tag_offset, ent_tag.")
self.cfg = dict(cfg)
self.cfg.setdefault('cnn_maxout_pieces', 2)
self.cfg.setdefault('pretrained_dims',
self.vocab.vectors.data.shape[1])
@property
def labels(self):
@ -904,7 +916,6 @@ class TextCategorizer(Pipe):
else:
token_vector_width = 64
if self.model is True:
self.cfg['pretrained_dims'] = self.vocab.vectors_length
self.model = self.Model(len(self.labels), token_vector_width,
**self.cfg)
link_vectors_to_models(self.vocab)

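The load_model shims above amount to the following migration of an old component cfg; the values are illustrative:

    cfg = {'pretrained_dims': 300}            # cfg saved by a pre-change model
    vectors_name = 'en_core_web_md.vectors'   # stands in for self.vocab.vectors.name
    if 'pretrained_dims' in cfg and 'pretrained_vectors' not in cfg:
        cfg['pretrained_vectors'] = vectors_name
    # The component's model is then built from cfg, which now carries the name
    assert cfg['pretrained_vectors'] == 'en_core_web_md.vectors'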

@ -256,7 +256,7 @@ cdef class Parser:
if hist_width != 0:
raise ValueError("Currently history width is hard-coded to 0")
tok2vec = Tok2Vec(token_vector_width, embed_size,
pretrained_dims=cfg.get('pretrained_dims', 0))
pretrained_vectors=cfg.get('pretrained_vectors', None))
tok2vec = chain(tok2vec, flatten)
lower = PrecomputableAffine(hidden_width,
nF=cls.nr_feature, nI=token_vector_width,
@ -294,9 +294,9 @@ cdef class Parser:
unless True (default), in which case a new instance is created with
`Parser.Moves()`.
model (object): Defines how the parse-state is created, updated and
evaluated. The value is set to the .model attribute unless True
(default), in which case a new instance is created with
`Parser.Model()`.
evaluated. The value is set to the .model attribute. If set to True
(default), a new instance will be created with `Parser.Model()`
in parser.begin_training(), parser.from_disk() or parser.from_bytes().
**cfg: Arbitrary configuration parameters. Set to the `.cfg` attribute
"""
self.vocab = vocab
@ -308,8 +308,6 @@ cdef class Parser:
cfg['beam_width'] = util.env_opt('beam_width', 1)
if 'beam_density' not in cfg:
cfg['beam_density'] = util.env_opt('beam_density', 0.0)
if 'pretrained_dims' not in cfg:
cfg['pretrained_dims'] = self.vocab.vectors.data.shape[1]
cfg.setdefault('cnn_maxout_pieces', 3)
self.cfg = cfg
if 'actions' in self.cfg:
@ -832,7 +830,6 @@ cdef class Parser:
self.moves.add_action(action, label)
cfg.setdefault('token_vector_width', 128)
if self.model is True:
cfg['pretrained_dims'] = self.vocab.vectors_length
self.model, cfg = self.Model(self.moves.n_moves, **cfg)
if sgd is None:
sgd = self.create_optimizer()
@ -896,9 +893,12 @@ cdef class Parser:
}
util.from_disk(path, deserializers, exclude)
if 'model' not in exclude:
# TODO: Remove this once we don't have to handle previous models
if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg:
self.cfg['pretrained_vectors'] = self.vocab.vectors.name
print("Create parser model", self.cfg)
path = util.ensure_path(path)
if self.model is True:
self.cfg.setdefault('pretrained_dims', self.vocab.vectors_length)
self.model, cfg = self.Model(**self.cfg)
else:
cfg = {}
@ -941,12 +941,14 @@ cdef class Parser:
))
msg = util.from_bytes(bytes_data, deserializers, exclude)
if 'model' not in exclude:
# TODO: Remove this once we don't have to handle previous models
if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg:
self.cfg['pretrained_vectors'] = self.vocab.vectors.name
print("Create parser model", self.cfg)
if self.model is True:
self.model, cfg = self.Model(**self.cfg)
cfg['pretrained_dims'] = self.vocab.vectors_length
else:
cfg = {}
cfg['pretrained_dims'] = self.vocab.vectors_length
if 'tok2vec_model' in msg:
self.model[0].from_bytes(msg['tok2vec_model'])
if 'lower_model' in msg:


@ -19,7 +19,9 @@ _languages = ['bn', 'da', 'de', 'en', 'es', 'fi', 'fr', 'ga', 'he', 'hu', 'id',
_models = {'en': ['en_core_web_sm'],
'de': ['de_core_news_md'],
'fr': ['fr_core_news_sm'],
'xx': ['xx_ent_web_md']}
'xx': ['xx_ent_web_md'],
'en_core_web_md': ['en_core_web_md'],
'es_core_news_md': ['es_core_news_md']}
# only used for tests that require loading the models
@ -183,6 +185,9 @@ def pytest_addoption(parser):
for lang in _languages + ['all']:
parser.addoption("--%s" % lang, action="store_true", help="Use %s models" % lang)
for model in _models:
if model not in _languages:
parser.addoption("--%s" % model, action="store_true", help="Use %s model" % model)
def pytest_runtest_setup(item):


@ -1,6 +1,7 @@
# coding: utf8
from __future__ import unicode_literals
import functools
import numpy
from collections import OrderedDict
import msgpack
@ -19,6 +20,20 @@ def unpickle_vectors(bytes_data):
return Vectors().from_bytes(bytes_data)
class GlobalRegistry(object):
'''Global store of vectors, to avoid repeatedly loading the data.'''
data = {}
@classmethod
def register(cls, name, data):
cls.data[name] = data
return functools.partial(cls.get, name)
@classmethod
def get(cls, name):
return cls.data[name]
cdef class Vectors:
"""Store, save and load word vectors.
@ -31,18 +46,21 @@ cdef class Vectors:
the table need to be assigned --- so len(list(vectors.keys())) may be
greater or smaller than vectors.shape[0].
"""
cdef public object name
cdef public object data
cdef public object key2row
cdef public object _unset
def __init__(self, *, shape=None, data=None, keys=None):
def __init__(self, *, shape=None, data=None, keys=None, name=None):
"""Create a new vector store.
shape (tuple): Size of the table, as (# entries, # columns)
data (numpy.ndarray): The vector data.
keys (iterable): A sequence of keys, aligned with the data.
name (string): A name to identify the vectors table.
RETURNS (Vectors): The newly created object.
"""
self.name = name
if data is None:
if shape is None:
shape = (0,0)

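A minimal sketch of the new GlobalRegistry together with the name argument to Vectors; the table name is an example:

    import numpy
    from spacy.vectors import Vectors, GlobalRegistry

    data = numpy.zeros((10, 300), dtype='f')
    vectors = Vectors(data=data, name='en_core_web_md.vectors')
    # register() stores the data under the name and returns a zero-argument getter
    get_data = GlobalRegistry.register(vectors.name, vectors.data)
    assert get_data() is vectors.data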

@ -381,7 +381,8 @@ cdef class Vocab:
self.lexemes_from_bytes(file_.read())
if self.vectors is not None:
self.vectors.from_disk(path, exclude='strings.json')
link_vectors_to_models(self)
if self.vectors.name is not None:
link_vectors_to_models(self)
return self
def to_bytes(self, **exclude):
@ -421,6 +422,8 @@ cdef class Vocab:
('vectors', lambda b: serialize_vectors(b))
))
util.from_bytes(bytes_data, setters, exclude)
if self.vectors.name is not None:
link_vectors_to_models(self)
return self
def lexemes_to_bytes(self):