spaCy/spacy/training/callbacks.py
Daniël de Kok e2b70df012
Configure isort to use the Black profile, recursively isort the spacy module (#12721)
* Use isort with Black profile

* isort all the things

* Fix import cycles as a result of import sorting

* Add DOCBIN_ALL_ATTRS type definition

* Add isort to requirements

* Remove isort from build dependencies check

* Typo
2023-06-14 17:48:41 +02:00

36 lines
1.3 KiB
Python

from typing import TYPE_CHECKING, Callable, Optional
from ..errors import Errors
from ..util import load_model, logger, registry
if TYPE_CHECKING:
from ..language import Language
@registry.callbacks("spacy.copy_from_base_model.v1")
def create_copy_from_base_model(
    tokenizer: Optional[str] = None,
    vocab: Optional[str] = None,
) -> Callable[["Language"], "Language"]:
    """Create a callback that copies the tokenizer and/or vocab from a base model.

    Args:
        tokenizer: Value passed to ``load_model`` to obtain the pipeline whose
            tokenizer settings are copied. Skipped when falsy.
        vocab: Value passed to ``load_model`` to obtain the pipeline whose
            vocab is copied. Skipped when falsy.

    Returns:
        A callback that takes an ``nlp`` object, updates its tokenizer and/or
        vocab in place and returns that same object.
    """

    def copy_from_base_model(nlp: "Language") -> "Language":
        """Copy the configured tokenizer/vocab data into ``nlp``.

        Raises:
            ValueError: If the tokenizer config of ``nlp`` differs from the
                tokenizer config of the base model (``Errors.E872``).
        """
        if tokenizer:
            logger.info("Copying tokenizer from: %s", tokenizer)
            base_nlp = load_model(tokenizer)
            curr_config = nlp.config["nlp"]["tokenizer"]
            base_config = base_nlp.config["nlp"]["tokenizer"]
            # The serialized tokenizer is only loaded when both pipelines
            # declare the same tokenizer config.
            if curr_config != base_config:
                raise ValueError(
                    Errors.E872.format(
                        curr_config=curr_config,
                        base_config=base_config,
                    )
                )
            # The vocab is excluded here; it is copied separately below.
            nlp.tokenizer.from_bytes(base_nlp.tokenizer.to_bytes(exclude=["vocab"]))
        if vocab:
            logger.info("Copying vocab from: %s", vocab)
            # only reload if the vocab is from a different model; when both
            # names are equal, base_nlp was already loaded in the branch above
            if tokenizer != vocab:
                base_nlp = load_model(vocab)
            nlp.vocab.from_bytes(base_nlp.vocab.to_bytes())
        # Return the pipeline so the callback honours the declared
        # Callable[[Language], Language] signature instead of yielding None.
        return nlp

    return copy_from_base_model