diff --git a/spacy/cli/configure.py b/spacy/cli/configure.py index aa28da480..2d0f1ade2 100644 --- a/spacy/cli/configure.py +++ b/spacy/cli/configure.py @@ -2,6 +2,7 @@ from pathlib import Path import re from wasabi import msg import typer +from thinc.api import Config import spacy from spacy.language import Language @@ -162,7 +163,7 @@ def configure_resume_cli( @configure_cli.command("transformer") def use_transformer( base_model: str, output_path: Path, transformer_name: str = "roberta-base" -): +) -> Config: """Replace pipeline tok2vec with transformer.""" # 1. identify tok2vec @@ -206,7 +207,7 @@ def use_transformer( @configure_cli.command("tok2vec") -def use_tok2vec(base_model: str, output_path: Path) -> Language: +def use_tok2vec(base_model: str, output_path: Path) -> Config: """Replace pipeline tok2vec with CNN tok2vec.""" nlp = spacy.load(base_model) _check_single_tok2vec(base_model, nlp.config)