Mirror of https://github.com/explosion/spaCy.git
Tidy up and fix types

commit 5f049b2de6
parent dbfc9f688f
@@ -94,7 +94,7 @@ def _parse_overrides(args: List[str], is_cli: bool = False) -> Dict[str, Any]:
             opt = opt.replace("--", "")
             if "." not in opt:
                 if is_cli:
-                    raise radicli.CliParseError(f"unrecognized argument: {orig_opt}")
+                    raise radicli.CliParserError(f"unrecognized argument: {orig_opt}")
                 else:
                     msg.fail(f"{err}: can't override top-level sections", exits=1)
             if "=" in opt:  # we have --opt=value
@@ -53,8 +53,9 @@ def assemble_cli(
     with nlp.select_pipes(disable=[*sourced]):
         nlp.initialize()
     msg.good("Initialized pipeline")
-    msg.divider("Serializing to disk")
-    if output_path is not None and not output_path.exists():
-        output_path.mkdir(parents=True)
-        msg.good(f"Created output directory: {output_path}")
-    nlp.to_disk(output_path)
+    if output_path is not None:
+        msg.divider("Serializing to disk")
+        if not output_path.exists():
+            output_path.mkdir(parents=True)
+            msg.good(f"Created output directory: {output_path}")
+        nlp.to_disk(output_path)
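Note: moving the serialization block under an explicit `if output_path is not None:` check lets a type checker narrow `output_path` from Optional[Path] to Path before `exists()`, `mkdir()` and `to_disk()` are called. A minimal sketch of that narrowing pattern, assuming an Optional[Path] argument (the `serialize` helper is illustrative, not part of the commit):

    from pathlib import Path
    from typing import Optional

    def serialize(output_path: Optional[Path]) -> None:
        # Outside the check, output_path may be None, so attribute access is flagged.
        if output_path is not None:
            # Inside this branch the checker treats output_path as Path.
            if not output_path.exists():
                output_path.mkdir(parents=True)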
@@ -1,5 +1,4 @@
-from typing import Callable, Iterable, Mapping, Optional, Any, Union, Literal
-from enum import Enum
+from typing import Callable, Iterable, Mapping, Optional, Any, Union, Literal, cast
 from pathlib import Path
 from wasabi import Printer
 import srsly
@@ -19,6 +18,8 @@ from ..training.converters import conllu_to_docs
 # matched by file extension and content. To add a converter, add a new
 # entry to this dict with the file extension mapped to the converter function
 # imported from /converters.
+ConvertersType = Literal["auto", "conllubio", "conllu", "conll", "ner", "iob", "json"]
+FileTypesType = Literal["json", "spacy"]
 
 CONVERTERS: Mapping[str, Callable[..., Iterable[Doc]]] = {
     "conllubio": conllu_to_docs,
@@ -28,19 +29,11 @@ CONVERTERS: Mapping[str, Callable[..., Iterable[Doc]]] = {
     "iob": iob_to_docs,
     "json": json_to_docs,
 }
-AUTO = "auto"
-ConvertersType = Literal["auto", "conllubio", "conllu", "conll", "ner", "iob", "json"]
-
-
+AUTO: ConvertersType = "auto"
 # File types that can be written to stdout
 FILE_TYPES_STDOUT = ("json",)
 
 
-class FileTypes(str, Enum):
-    json = "json"
-    spacy = "spacy"
-
-
 @cli.command(
     "convert",
     # fmt: off
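Note: these hunks replace the FileTypes(str, Enum) class with Literal aliases (ConvertersType, FileTypesType) defined next to CONVERTERS. A minimal sketch of why a Literal alias is the lighter option for CLI arguments, assuming an illustrative write_output function in place of the real command:

    from typing import Literal

    FileTypesType = Literal["json", "spacy"]

    def write_output(file_type: FileTypesType = "spacy") -> str:
        # Plain strings are accepted and still checked statically; an Enum
        # would force callers to pass FileTypes.json rather than "json".
        return f"writing {file_type} output"

    write_output("json")     # accepted
    # write_output("jsonl")  # rejected by a type checker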
@@ -61,7 +54,7 @@ class FileTypes(str, Enum):
 def convert_cli(
     input_path: ExistingPathOrDash,
     output_dir: ExistingDirPathOrDash = "-",
-    file_type: Literal["json", "spacy"] = "spacy",
+    file_type: FileTypesType = "spacy",
     n_sents: int = 1,
     seg_sents: bool = False,
     model: Optional[str] = None,
@@ -110,7 +103,7 @@ def convert(
     input_path: Path,
     output_dir: Union[str, Path],
     *,
-    file_type: str = "json",
+    file_type: FileTypesType = "json",
     n_sents: int = 1,
     seg_sents: bool = False,
     model: Optional[str] = None,
@@ -173,14 +166,16 @@ def convert(
             msg.good(f"Generated output file ({len_docs} documents): {output_file}")
 
 
-def _print_docs_to_stdout(data: Any, output_type: str) -> None:
+def _print_docs_to_stdout(data: Any, output_type: FileTypesType) -> None:
     if output_type == "json":
         srsly.write_json("-", data)
     else:
         sys.stdout.buffer.write(data)
 
 
-def _write_docs_to_file(data: Any, output_file: Path, output_type: str) -> None:
+def _write_docs_to_file(
+    data: Any, output_file: Path, output_type: FileTypesType
+) -> None:
     if not output_file.parent.exists():
         output_file.parent.mkdir(parents=True)
     if output_type == "json":
@@ -190,7 +185,7 @@ def _write_docs_to_file(data: Any, output_file: Path, output_type: str) -> None:
             file_.write(data)
 
 
-def autodetect_ner_format(input_data: str) -> Optional[str]:
+def autodetect_ner_format(input_data: str) -> Optional[ConvertersType]:
     # guess format from the first 20 lines
     lines = input_data.split("\n")[:20]
     format_guesses = {"ner": 0, "iob": 0}
@@ -213,7 +208,7 @@ def verify_cli_args(
     msg: Printer,
     input_path: Path,
     output_dir: Union[str, Path],
-    file_type: str,
+    file_type: FileTypesType,
     converter: str,
     ner_map: Optional[Path],
 ):
@@ -236,7 +231,7 @@ def verify_cli_args(
         msg.fail(f"Can't find converter for {converter}", exits=1)
 
 
-def _get_converter(msg, converter, input_path: Path):
+def _get_converter(msg: Printer, converter: ConvertersType, input_path: Path) -> str:
     if input_path.is_dir():
         if converter == AUTO:
             input_locs = walk_directory(input_path, suffix=None)
@@ -265,4 +260,4 @@ def _get_converter(msg, converter, input_path: Path):
                 "Conversion may not succeed. "
                 "See https://spacy.io/api/cli#convert"
             )
-    return converter
+    return cast(str, converter)
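Note: `cast(str, converter)` only affects static analysis; at runtime it returns its argument unchanged. A minimal sketch of the behaviour, assuming a small stand-in alias and function (pick_converter is illustrative):

    from typing import Literal, cast

    ConvertersType = Literal["auto", "ner", "iob"]

    def pick_converter(converter: ConvertersType) -> str:
        # cast() is a no-op at runtime; it just tells the checker to treat
        # the Literal value as a plain str return value.
        return cast(str, converter)

    assert pick_converter("ner") == "ner"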
@@ -4,7 +4,7 @@ from pathlib import Path
 import logging
 from typing import Optional, Tuple, Any, Dict, List
 import numpy
-import wasabi.tables
+from wasabi import msg, row
 from radicli import Arg, ExistingPath, ExistingFilePath
 
 from ..pipeline import TextCategorizer, MultiLabel_TextCategorizer
@@ -13,11 +13,9 @@ from ..training import Corpus
 from ._util import cli, import_code, setup_gpu
 from .. import util
 
-_DEFAULTS = {
-    "n_trials": 11,
-    "use_gpu": -1,
-    "gold_preproc": False,
-}
+DEFAULT_N_TRIALS: int = 11
+DEFAULT_USE_GPU: int = -1
+DEFAULT_GOLD_PREPROC: bool = False
 
 
 @cli.command(
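Note: replacing the _DEFAULTS dict with annotated module-level constants is what lets the `# type: ignore` comments in the following hunks be dropped: a dict of mixed int/bool values loses the precise per-key types, so its lookups do not type-check as bool defaults, while an annotated constant keeps its type. A minimal sketch, assuming illustrative run_* functions:

    _DEFAULTS = {"n_trials": 11, "use_gpu": -1, "gold_preproc": False}

    DEFAULT_GOLD_PREPROC: bool = False

    def run_dict(gold_preproc: bool = _DEFAULTS["gold_preproc"]) -> bool:  # type: ignore
        # The dict value is inferred as int (bool is a subtype of int), so the
        # default needs a suppression to pass a strict checker.
        return gold_preproc

    def run_const(gold_preproc: bool = DEFAULT_GOLD_PREPROC) -> bool:
        # The annotated constant is a bool, so the default type-checks cleanly.
        return gold_preproc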
@@ -41,10 +39,10 @@ def find_threshold_cli(
     pipe_name: str,
     threshold_key: str,
     scores_key: str,
-    n_trials: int = _DEFAULTS["n_trials"],
+    n_trials: int = DEFAULT_N_TRIALS,
     code_path: Optional[ExistingFilePath] = None,
-    use_gpu: int = _DEFAULTS["use_gpu"],
-    gold_preproc: bool = _DEFAULTS["gold_preproc"],
+    use_gpu: int = DEFAULT_USE_GPU,
+    gold_preproc: bool = DEFAULT_GOLD_PREPROC,
     verbose: bool = False,
 ):
     """
@@ -83,9 +81,9 @@ def find_threshold(
     threshold_key: str,
     scores_key: str,
     *,
-    n_trials: int = _DEFAULTS["n_trials"],  # type: ignore
-    use_gpu: int = _DEFAULTS["use_gpu"],  # type: ignore
-    gold_preproc: bool = _DEFAULTS["gold_preproc"],  # type: ignore
+    n_trials: int = DEFAULT_N_TRIALS,
+    use_gpu: int = DEFAULT_USE_GPU,
+    gold_preproc: bool = DEFAULT_GOLD_PREPROC,
     silent: bool = True,
 ) -> Tuple[float, float, Dict[float, float]]:
     """
@@ -104,30 +102,25 @@ def find_threshold(
     RETURNS (Tuple[float, float, Dict[float, float]]): Best found threshold, the corresponding score, scores for all
         evaluated thresholds.
     """
-
     setup_gpu(use_gpu, silent=silent)
     data_path = util.ensure_path(data_path)
     if not data_path.exists():
-        wasabi.msg.fail("Evaluation data not found", data_path, exits=1)
+        msg.fail("Evaluation data not found", data_path, exits=1)
     nlp = util.load_model(model)
-
     if pipe_name not in nlp.component_names:
-        raise AttributeError(
-            Errors.E001.format(name=pipe_name, opts=nlp.component_names)
-        )
+        err = Errors.E001.format(name=pipe_name, opts=nlp.component_names)
+        raise AttributeError(err)
     pipe = nlp.get_pipe(pipe_name)
     if not hasattr(pipe, "scorer"):
         raise AttributeError(Errors.E1045)
     if type(pipe) == TextCategorizer:
-        wasabi.msg.warn(
+        msg.warn(
             "The `textcat` component doesn't use a threshold as it's not applicable to the concept of "
             "exclusive classes. All thresholds will yield the same results."
         )
     if not silent:
-        wasabi.msg.info(
-            title=f"Optimizing for {scores_key} for component '{pipe_name}' with {n_trials} "
-            f"trials."
-        )
+        text = f"Optimizing for {scores_key} for component '{pipe_name}' with {n_trials} trials."
+        msg.info(text)
     # Load evaluation corpus.
     corpus = Corpus(data_path, gold_preproc=gold_preproc)
     dev_dataset = list(corpus(nlp))
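Note: the find_threshold hunks switch from the wasabi.msg / wasabi.tables.row module paths to the names imported at the top of the file (`from wasabi import msg, row`). A minimal usage sketch of those two helpers with made-up values:

    from wasabi import msg, row

    # msg is wasabi's ready-made Printer instance; row() formats one table row.
    msg.info("Starting threshold search")
    print(row(["Threshold", "Score"], widths=(10, 10)))
    print(row([0.5, 0.83], widths=(10, 10)))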
@@ -155,9 +148,9 @@ def find_threshold(
         RETURNS (Dict[str, Any]): Filtered dictionary.
         """
         if keys[0] not in config:
-            wasabi.msg.fail(
-                title=f"Failed to look up `{full_key}` in config: sub-key {[keys[0]]} not found.",
-                text=f"Make sure you specified {[keys[0]]} correctly. The following sub-keys are available instead: "
+            msg.fail(
+                f"Failed to look up `{full_key}` in config: sub-key {[keys[0]]} not found.",
+                f"Make sure you specified {[keys[0]]} correctly. The following sub-keys are available instead: "
                 f"{list(config.keys())}",
                 exits=1,
             )
@@ -172,7 +165,7 @@ def find_threshold(
     config_keys_full = ["components", pipe_name, *config_keys]
     table_col_widths = (10, 10)
     thresholds = numpy.linspace(0, 1, n_trials)
-    print(wasabi.tables.row(["Threshold", f"{scores_key}"], widths=table_col_widths))
+    print(row(["Threshold", f"{scores_key}"], widths=table_col_widths))
     for threshold in thresholds:
         # Reload pipeline with overrides specifying the new threshold.
         nlp = util.load_model(
@@ -191,34 +184,28 @@ def find_threshold(
                 "cfg",
                 set_nested_item(getattr(pipe, "cfg"), config_keys, threshold),
             )
-
         eval_scores = nlp.evaluate(dev_dataset)
         if scores_key not in eval_scores:
-            wasabi.msg.fail(
+            msg.fail(
                 title=f"Failed to look up score `{scores_key}` in evaluation results.",
                 text=f"Make sure you specified the correct value for `scores_key`. The following scores are "
                 f"available: {list(eval_scores.keys())}",
                 exits=1,
             )
         scores[threshold] = eval_scores[scores_key]
-
         if not isinstance(scores[threshold], (float, int)):
-            wasabi.msg.fail(
-                f"Returned score for key '{scores_key}' is not numeric. Threshold optimization only works for numeric "
-                f"scores.",
+            msg.fail(
+                f"Returned score for key '{scores_key}' is not numeric. Threshold "
+                f"optimization only works for numeric scores.",
                 exits=1,
             )
-        print(
-            wasabi.row(
-                [round(threshold, 3), round(scores[threshold], 3)],
-                widths=table_col_widths,
-            )
-        )
+        data = [round(threshold, 3), round(scores[threshold], 3)]
+        print(row(data, widths=table_col_widths))
     best_threshold = max(scores.keys(), key=(lambda key: scores[key]))
     # If all scores are identical, emit warning.
     if len(set(scores.values())) == 1:
-        wasabi.msg.warn(
-            title="All scores are identical. Verify that all settings are correct.",
+        msg.warn(
+            "All scores are identical. Verify that all settings are correct.",
             text=""
             if (
                 not isinstance(pipe, MultiLabel_TextCategorizer)
@@ -229,7 +216,7 @@ def find_threshold(
     else:
         if not silent:
             print(
-                f"\nBest threshold: {round(best_threshold, ndigits=4)} with {scores_key} value of {scores[best_threshold]}."
+                f"\nBest threshold: {round(best_threshold, ndigits=4)} with "
+                f"{scores_key} value of {scores[best_threshold]}."
             )
-
     return best_threshold, scores[best_threshold], scores
@@ -1,6 +1,6 @@
-from typing import Optional, List, Tuple, Literal
+from typing import Optional, List, Tuple, Literal, cast
 from pathlib import Path
-from wasabi import Printer, diff_strings
+from wasabi import Printer, msg, diff_strings
 from thinc.api import Config
 import srsly
 import re
@@ -25,16 +25,17 @@ OptimizationsType = Literal["efficiency", "accuracy"]
 
 class InitValues:
     """
-    Default values for initialization. Dedicated class to allow synchronized default values for init_config_cli() and
-    init_config(), i.e. initialization calls via CLI respectively Python.
+    Default values for initialization. Dedicated class to allow synchronized
+    default values for init_config_cli() and init_config(), i.e. initialization
+    calls via CLI respectively Python.
     """
 
-    lang = "en"
-    pipeline = SimpleFrozenList(["tagger", "parser", "ner"])
-    optimize = "efficiency"
-    gpu = False
-    pretraining = False
-    force_overwrite = False
+    lang: str = "en"
+    pipeline: List[str] = SimpleFrozenList(["tagger", "parser", "ner"])
+    optimize: OptimizationsType = "efficiency"
+    gpu: bool = False
+    pretraining: bool = False
+    force_overwrite: bool = False
 
 
 @cli.subcommand(
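Note: the InitValues attributes now carry explicit annotations, so the CLI wrapper and the Python API can share one set of typed defaults. A minimal sketch of that pattern, assuming an illustrative init_config_sketch function in place of the real init_config:

    from typing import List, Literal

    OptimizationsType = Literal["efficiency", "accuracy"]

    class InitValues:
        # Annotated class attributes act as a single source of typed defaults.
        lang: str = "en"
        pipeline: List[str] = ["tagger", "parser", "ner"]
        optimize: OptimizationsType = "efficiency"

    def init_config_sketch(
        lang: str = InitValues.lang,
        pipeline: List[str] = InitValues.pipeline,
        optimize: OptimizationsType = InitValues.optimize,
    ) -> dict:
        # Both entry points default to the same values without repeating them.
        return {"lang": lang, "pipeline": pipeline, "optimize": optimize}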
@@ -68,10 +69,11 @@ def init_config_cli(
     DOCS: https://spacy.io/api/cli#init-config
     """
     is_stdout = output_file == "-"
-    if not is_stdout and output_file.exists() and not force_overwrite:
-        msg = Printer()
-        msg.fail(
-            "The provided output file already exists. To force overwriting the config file, set the --force or -F flag.",
-            exits=1,
-        )
+    if not is_stdout:
+        if output_file.exists() and not force_overwrite:
+            msg.fail(
+                "The provided output file already exists. To force overwriting "
+                "the config file, set the --force or -F flag.",
+                exits=1,
+            )
     config = init_config(
@@ -239,7 +241,10 @@ def init_config(
 
 
 def save_config(
-    config: Config, output_file: Path, is_stdout: bool = False, silent: bool = False
+    config: Config,
+    output_file: Path,
+    is_stdout: bool = False,
+    silent: bool = False,
 ) -> None:
     no_print = is_stdout or silent
     msg = Printer(no_print=no_print)
@@ -5,15 +5,13 @@ from typing import Tuple, List, Dict, Any
 import pkg_resources
 import time
 from pathlib import Path
-
-import spacy
-import numpy
 import pytest
 import srsly
 from click import NoSuchOption
 from packaging.specifiers import SpecifierSet
 from thinc.api import Config, ConfigValidationError
 
+import spacy
 from spacy import about
 from spacy.cli import info
 from spacy.cli._util import is_subpath_of, load_project_config, walk_directory
@@ -745,13 +743,13 @@ def test_get_labels_from_model(factory_name, pipe_name):
 
 def test_permitted_package_names():
     # https://www.python.org/dev/peps/pep-0426/#name
-    assert _is_permitted_package_name("Meine_Bäume") == False
-    assert _is_permitted_package_name("_package") == False
-    assert _is_permitted_package_name("package_") == False
-    assert _is_permitted_package_name(".package") == False
-    assert _is_permitted_package_name("package.") == False
-    assert _is_permitted_package_name("-package") == False
-    assert _is_permitted_package_name("package-") == False
+    assert _is_permitted_package_name("Meine_Bäume") is False
+    assert _is_permitted_package_name("_package") is False
+    assert _is_permitted_package_name("package_") is False
+    assert _is_permitted_package_name(".package") is False
+    assert _is_permitted_package_name("package.") is False
+    assert _is_permitted_package_name("-package") is False
+    assert _is_permitted_package_name("package-") is False
 
 
 def test_debug_data_compile_gold():
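Note: the assertions move from `== False` to `is False`, which additionally checks that the helper returns an actual bool rather than any falsy value, and satisfies the common E712 lint rule. A tiny illustration with a stand-in predicate (permitted_stub is illustrative, not the real _is_permitted_package_name):

    def permitted_stub(name: str) -> bool:
        # Stand-in predicate used only to show the assertion style.
        return not name.startswith(("_", "-", "."))

    # `is False` passes only when the return value is the bool singleton False;
    # `== False` would also pass if the function returned 0 or "".
    assert permitted_stub("_package") is False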