mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-10 16:22:29 +03:00
init_nlp_distill -> init_nlp_student
This commit is contained in:
parent
aa0783cf7f
commit
3fba5e07b5
|
@ -2,7 +2,7 @@ from typing import Callable, Iterable, Iterator
|
||||||
import pytest
|
import pytest
|
||||||
from spacy import Language
|
from spacy import Language
|
||||||
from spacy.training import Example
|
from spacy.training import Example
|
||||||
from spacy.training.initialize import init_nlp_distill
|
from spacy.training.initialize import init_nlp_student
|
||||||
from spacy.training.loop import distill, train
|
from spacy.training.loop import distill, train
|
||||||
from spacy.util import load_model_from_config, registry
|
from spacy.util import load_model_from_config, registry
|
||||||
from thinc.api import Config
|
from thinc.api import Config
|
||||||
|
@ -101,7 +101,7 @@ def test_distill_loop(config_str):
|
||||||
train(teacher)
|
train(teacher)
|
||||||
|
|
||||||
orig_config = Config().from_str(config_str)
|
orig_config = Config().from_str(config_str)
|
||||||
student = init_nlp_distill(orig_config, teacher)
|
student = init_nlp_student(orig_config, teacher)
|
||||||
student.initialize()
|
student.initialize()
|
||||||
distill(teacher, student)
|
distill(teacher, student)
|
||||||
|
|
||||||
|
|
|
@ -94,9 +94,16 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
|
||||||
return nlp
|
return nlp
|
||||||
|
|
||||||
|
|
||||||
def init_nlp_distill(
|
def init_nlp_student(
|
||||||
config: Config, teacher: "Language", *, use_gpu: int = -1
|
config: Config, teacher: "Language", *, use_gpu: int = -1
|
||||||
) -> "Language":
|
) -> "Language":
|
||||||
|
"""Initialize student pipeline for distillation.
|
||||||
|
|
||||||
|
config (Config): Student model configuration.
|
||||||
|
teacher (Language): The teacher pipeline to distill from.
|
||||||
|
use_gpu (int): Whether to train on GPU. Make sure to call require_gpu
|
||||||
|
before calling this function.
|
||||||
|
"""
|
||||||
raw_config = config
|
raw_config = config
|
||||||
config = raw_config.interpolate()
|
config = raw_config.interpolate()
|
||||||
_set_seed_from_config(config)
|
_set_seed_from_config(config)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user