mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-30 20:06:30 +03:00
bdb485cc80
* Add callback to copy vocab/tokenizer from model Add callback `spacy.copy_from_base_model.v1` to copy the tokenizer settings and/or vocab (including vectors) from a base model. * Move spacy.copy_from_base_model.v1 to spacy.training.callbacks * Add documentation * Modify to specify model as tokenizer and vocab params
33 lines
1.2 KiB
Python
33 lines
1.2 KiB
Python
from typing import Callable, Optional

from ..errors import Errors
from ..language import Language
from ..util import load_model, registry, logger
|
|
|
|
|
|
@registry.callbacks("spacy.copy_from_base_model.v1")
def create_copy_from_base_model(
    tokenizer: Optional[str] = None,
    vocab: Optional[str] = None,
) -> Callable[[Language], None]:
    """Create a callback that copies tokenizer settings and/or vocab from base models.

    tokenizer (Optional[str]): Model identifier passed to `load_model`. If set,
        the tokenizer of that model is copied into the current pipeline.
        Falsy values (None, "") skip the tokenizer step.
    vocab (Optional[str]): Model identifier passed to `load_model`. If set,
        the serialized vocab of that model is loaded into the current
        pipeline's vocab. Falsy values (None, "") skip the vocab step.
    RETURNS (Callable[[Language], None]): A callback that takes the `nlp`
        object, modifies it in place and returns nothing.
    """

    def copy_from_base_model(nlp: Language) -> None:
        """Copy the configured tokenizer and/or vocab into `nlp` in place.

        RAISES (ValueError): If a tokenizer source is configured and its
            `[nlp.tokenizer]` config differs from the one of `nlp`.
        """
        if tokenizer:
            logger.info(f"Copying tokenizer from: {tokenizer}")
            base_nlp = load_model(tokenizer)
            curr_config = nlp.config["nlp"]["tokenizer"]
            base_config = base_nlp.config["nlp"]["tokenizer"]
            # Tokenizer settings are only transferred between tokenizers that
            # were created from the same config block.
            if curr_config != base_config:
                raise ValueError(
                    Errors.E872.format(
                        curr_config=curr_config,
                        base_config=base_config,
                    )
                )
            # The vocab is excluded here; it is handled separately below.
            nlp.tokenizer.from_bytes(base_nlp.tokenizer.to_bytes(exclude=["vocab"]))
        if vocab:
            logger.info(f"Copying vocab from: {vocab}")
            # only reload if the vocab is from a different model; otherwise
            # reuse the base model that was already loaded for the tokenizer
            if tokenizer != vocab:
                base_nlp = load_model(vocab)
            nlp.vocab.from_bytes(base_nlp.vocab.to_bytes())

    return copy_from_base_model
|