init_nlp_distill -> init_nlp_student

This commit is contained in:
Daniël de Kok 2023-04-18 17:48:36 +02:00
parent aa0783cf7f
commit 3fba5e07b5
2 changed files with 10 additions and 3 deletions

View File

@ -2,7 +2,7 @@ from typing import Callable, Iterable, Iterator
import pytest
from spacy import Language
from spacy.training import Example
from spacy.training.initialize import init_nlp_distill
from spacy.training.initialize import init_nlp_student
from spacy.training.loop import distill, train
from spacy.util import load_model_from_config, registry
from thinc.api import Config
@ -101,7 +101,7 @@ def test_distill_loop(config_str):
train(teacher)
orig_config = Config().from_str(config_str)
student = init_nlp_distill(orig_config, teacher)
student = init_nlp_student(orig_config, teacher)
student.initialize()
distill(teacher, student)

View File

@ -94,9 +94,16 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
return nlp
def init_nlp_distill(
def init_nlp_student(
config: Config, teacher: "Language", *, use_gpu: int = -1
) -> "Language":
"""Initialize student pipeline for distillation.
config (Config): Student model configuration.
teacher (Language): The teacher pipeline to distill from.
use_gpu (int): Whether to train on GPU. Make sure to call require_gpu
before calling this function.
"""
raw_config = config
config = raw_config.interpolate()
_set_seed_from_config(config)