Prevent combining a pretrained Tok2Vec layer with pretrained components from a base model

This commit is contained in:
svlandeg 2020-05-29 17:38:33 +02:00
parent 04ba37b667
commit 291483157d
2 changed files with 9 additions and 2 deletions

View File

@ -15,6 +15,7 @@ import random
from .._ml import create_default_optimizer
from ..util import use_gpu as set_gpu
from ..errors import Errors
from ..gold import GoldCorpus
from ..compat import path2str
from ..lookups import Lookups
@ -182,6 +183,7 @@ def train(
msg.warn("Unable to activate GPU: {}".format(use_gpu))
msg.text("Using CPU only")
use_gpu = -1
base_components = []
if base_model:
msg.text("Starting with base model '{}'".format(base_model))
nlp = util.load_model(base_model)
@ -227,6 +229,7 @@ def train(
exits=1,
)
msg.text("Extending component from base model '{}'".format(pipe))
base_components.append(pipe)
disabled_pipes = nlp.disable_pipes(
[p for p in nlp.pipe_names if p not in pipeline]
)
@ -299,7 +302,7 @@ def train(
# Load in pretrained weights
if init_tok2vec is not None:
components = _load_pretrained_tok2vec(nlp, init_tok2vec)
components = _load_pretrained_tok2vec(nlp, init_tok2vec, base_components)
msg.text("Loaded pretrained tok2vec for: {}".format(components))
# Verify textcat config
@ -642,7 +645,7 @@ def _load_vectors(nlp, vectors):
util.load_model(vectors, vocab=nlp.vocab)
def _load_pretrained_tok2vec(nlp, loc):
def _load_pretrained_tok2vec(nlp, loc, base_components):
"""Load pretrained weights for the 'token-to-vector' part of the component
models, which is typically a CNN. See 'spacy pretrain'. Experimental.
"""
@ -651,6 +654,8 @@ def _load_pretrained_tok2vec(nlp, loc):
loaded = []
for name, component in nlp.pipeline:
if hasattr(component, "model") and hasattr(component.model, "tok2vec"):
if name in base_components:
raise ValueError(Errors.E200.format(component=name))
component.tok2vec.from_bytes(weights_data)
loaded.append(name)
return loaded

View File

@ -568,6 +568,8 @@ class Errors(object):
E198 = ("Unable to return {n} most similar vectors for the current vectors "
"table, which contains {n_rows} vectors.")
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
E200 = ("Specifying a base model with a pretrained component '{component}' "
"can not be combined with adding a pretrained Tok2Vec layer.")
@add_codes