Merge branch 'develop' into nightly.spacy.io

This commit is contained in:
Ines Montani 2020-08-06 00:23:26 +02:00
commit c1095e2f22
137 changed files with 3881 additions and 2153 deletions

View File

@ -16,7 +16,7 @@ from bin.ud import conll17_ud_eval
from spacy.tokens import Token, Doc from spacy.tokens import Token, Doc
from spacy.gold import Example from spacy.gold import Example
from spacy.util import compounding, minibatch, minibatch_by_words from spacy.util import compounding, minibatch, minibatch_by_words
from spacy.syntax.nonproj import projectivize from spacy.pipeline._parser_internals.nonproj import projectivize
from spacy.matcher import Matcher from spacy.matcher import Matcher
from spacy import displacy from spacy import displacy
from collections import defaultdict from collections import defaultdict

View File

@ -1,37 +1,46 @@
# Training hyper-parameters and additional features. [paths]
[training] train = ""
# Whether to train on sequences with 'gold standard' sentence boundaries dev = ""
# and tokens. If you set this to true, take care to ensure your run-time raw = null
# data is passed in sentence-by-sentence via some prior preprocessing.
gold_preproc = false
# Limitations on training document length or number of examples.
max_length = 5000
limit = 0
# Data augmentation
orth_variant_level = 0.0
dropout = 0.1
# Controls early-stopping. 0 or -1 mean unlimited.
patience = 1600
max_epochs = 0
max_steps = 20000
eval_frequency = 200
# Other settings
seed = 0
accumulate_gradient = 1
use_pytorch_for_gpu_memory = false
# Control how scores are printed and checkpoints are evaluated.
eval_batch_size = 128
score_weights = {"dep_las": 0.4, "ents_f": 0.4, "tag_acc": 0.2}
init_tok2vec = null init_tok2vec = null
discard_oversize = false
batch_by = "words"
raw_text = null
tag_map = null
vectors = null
base_model = null
morph_rules = null
[training.batch_size] [system]
seed = 0
use_pytorch_for_gpu_memory = false
[training]
seed = ${system:seed}
dropout = 0.1
init_tok2vec = ${paths:init_tok2vec}
vectors = null
accumulate_gradient = 1
max_steps = 0
max_epochs = 0
patience = 10000
eval_frequency = 200
score_weights = {"dep_las": 0.4, "ents_f": 0.4, "tag_acc": 0.2}
frozen_components = []
[training.train_corpus]
@readers = "spacy.Corpus.v1"
path = ${paths:train}
gold_preproc = true
max_length = 0
limit = 0
[training.dev_corpus]
@readers = "spacy.Corpus.v1"
path = ${paths:dev}
gold_preproc = ${training.train_corpus:gold_preproc}
max_length = 0
limit = 0
[training.batcher]
@batchers = "batch_by_words.v1"
discard_oversize = false
tolerance = 0.2
[training.batcher.size]
@schedules = "compounding.v1" @schedules = "compounding.v1"
start = 100 start = 100
stop = 1000 stop = 1000

View File

@ -1,30 +1,45 @@
[paths]
train = ""
dev = ""
raw = null
init_tok2vec = null
[system]
seed = 0
use_pytorch_for_gpu_memory = false
[training] [training]
seed = ${system:seed}
dropout = 0.2
init_tok2vec = ${paths:init_tok2vec}
vectors = null
accumulate_gradient = 1
max_steps = 0 max_steps = 0
max_epochs = 0
patience = 10000 patience = 10000
eval_frequency = 200 eval_frequency = 200
dropout = 0.2 score_weights = {"dep_las": 0.8, "tag_acc": 0.2}
init_tok2vec = null
vectors = null [training.read_train]
max_epochs = 100 @readers = "spacy.Corpus.v1"
orth_variant_level = 0.0 path = ${paths:train}
gold_preproc = true gold_preproc = true
max_length = 0 max_length = 0
scores = ["tag_acc", "dep_uas", "dep_las", "speed"]
score_weights = {"dep_las": 0.8, "tag_acc": 0.2}
limit = 0 limit = 0
seed = 0
accumulate_gradient = 1 [training.read_dev]
@readers = "spacy.Corpus.v1"
path = ${paths:dev}
gold_preproc = ${training.read_train:gold_preproc}
max_length = 0
limit = 0
[training.batcher]
@batchers = "batch_by_words.v1"
discard_oversize = false discard_oversize = false
raw_text = null tolerance = 0.2
tag_map = null
morph_rules = null
base_model = null
eval_batch_size = 128 [training.batcher.size]
use_pytorch_for_gpu_memory = false
batch_by = "words"
[training.batch_size]
@schedules = "compounding.v1" @schedules = "compounding.v1"
start = 100 start = 100
stop = 1000 stop = 1000

View File

@ -13,7 +13,7 @@ import spacy
import spacy.util import spacy.util
from spacy.tokens import Token, Doc from spacy.tokens import Token, Doc
from spacy.gold import Example from spacy.gold import Example
from spacy.syntax.nonproj import projectivize from spacy.pipeline._parser_internals.nonproj import projectivize
from collections import defaultdict from collections import defaultdict
from spacy.matcher import Matcher from spacy.matcher import Matcher

View File

@ -48,7 +48,8 @@ def main(model, output_dir=None):
# You can change the dimension of vectors in your KB by using an encoder that changes the dimensionality. # You can change the dimension of vectors in your KB by using an encoder that changes the dimensionality.
# For simplicity, we'll just use the original vector dimension here instead. # For simplicity, we'll just use the original vector dimension here instead.
vectors_dim = nlp.vocab.vectors.shape[1] vectors_dim = nlp.vocab.vectors.shape[1]
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=vectors_dim) kb = KnowledgeBase(entity_vector_length=vectors_dim)
kb.initialize(nlp.vocab)
# set up the data # set up the data
entity_ids = [] entity_ids = []
@ -95,7 +96,8 @@ def main(model, output_dir=None):
print("Loading vocab from", vocab_path) print("Loading vocab from", vocab_path)
print("Loading KB from", kb_path) print("Loading KB from", kb_path)
vocab2 = Vocab().from_disk(vocab_path) vocab2 = Vocab().from_disk(vocab_path)
kb2 = KnowledgeBase(vocab=vocab2) kb2 = KnowledgeBase(entity_vector_length=1)
kb2.initialize(vocab2)
kb2.load_bulk(kb_path) kb2.load_bulk(kb_path)
print() print()
_print_kb(kb2) _print_kb(kb2)

View File

@ -6,7 +6,7 @@ requires = [
"cymem>=2.0.2,<2.1.0", "cymem>=2.0.2,<2.1.0",
"preshed>=3.0.2,<3.1.0", "preshed>=3.0.2,<3.1.0",
"murmurhash>=0.28.0,<1.1.0", "murmurhash>=0.28.0,<1.1.0",
"thinc>=8.0.0a19,<8.0.0a30", "thinc>=8.0.0a22,<8.0.0a30",
"blis>=0.4.0,<0.5.0", "blis>=0.4.0,<0.5.0",
"pytokenizations", "pytokenizations",
"smart_open>=2.0.0,<3.0.0" "smart_open>=2.0.0,<3.0.0"

View File

@ -1,7 +1,7 @@
# Our libraries # Our libraries
cymem>=2.0.2,<2.1.0 cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0 preshed>=3.0.2,<3.1.0
thinc>=8.0.0a19,<8.0.0a30 thinc>=8.0.0a22,<8.0.0a30
blis>=0.4.0,<0.5.0 blis>=0.4.0,<0.5.0
ml_datasets>=0.1.1 ml_datasets>=0.1.1
murmurhash>=0.28.0,<1.1.0 murmurhash>=0.28.0,<1.1.0

View File

@ -34,13 +34,13 @@ setup_requires =
cymem>=2.0.2,<2.1.0 cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0 preshed>=3.0.2,<3.1.0
murmurhash>=0.28.0,<1.1.0 murmurhash>=0.28.0,<1.1.0
thinc>=8.0.0a19,<8.0.0a30 thinc>=8.0.0a22,<8.0.0a30
install_requires = install_requires =
# Our libraries # Our libraries
murmurhash>=0.28.0,<1.1.0 murmurhash>=0.28.0,<1.1.0
cymem>=2.0.2,<2.1.0 cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0 preshed>=3.0.2,<3.1.0
thinc>=8.0.0a19,<8.0.0a30 thinc>=8.0.0a22,<8.0.0a30
blis>=0.4.0,<0.5.0 blis>=0.4.0,<0.5.0
wasabi>=0.7.1,<1.1.0 wasabi>=0.7.1,<1.1.0
srsly>=2.1.0,<3.0.0 srsly>=2.1.0,<3.0.0

View File

@ -31,6 +31,7 @@ MOD_NAMES = [
"spacy.vocab", "spacy.vocab",
"spacy.attrs", "spacy.attrs",
"spacy.kb", "spacy.kb",
"spacy.ml.parser_model",
"spacy.morphology", "spacy.morphology",
"spacy.pipeline.dep_parser", "spacy.pipeline.dep_parser",
"spacy.pipeline.morphologizer", "spacy.pipeline.morphologizer",
@ -40,14 +41,14 @@ MOD_NAMES = [
"spacy.pipeline.sentencizer", "spacy.pipeline.sentencizer",
"spacy.pipeline.senter", "spacy.pipeline.senter",
"spacy.pipeline.tagger", "spacy.pipeline.tagger",
"spacy.syntax.stateclass", "spacy.pipeline.transition_parser",
"spacy.syntax._state", "spacy.pipeline._parser_internals.arc_eager",
"spacy.pipeline._parser_internals.ner",
"spacy.pipeline._parser_internals.nonproj",
"spacy.pipeline._parser_internals._state",
"spacy.pipeline._parser_internals.stateclass",
"spacy.pipeline._parser_internals.transition_system",
"spacy.tokenizer", "spacy.tokenizer",
"spacy.syntax.nn_parser",
"spacy.syntax._parser_model",
"spacy.syntax.nonproj",
"spacy.syntax.transition_system",
"spacy.syntax.arc_eager",
"spacy.gold.gold_io", "spacy.gold.gold_io",
"spacy.tokens.doc", "spacy.tokens.doc",
"spacy.tokens.span", "spacy.tokens.span",
@ -57,7 +58,6 @@ MOD_NAMES = [
"spacy.matcher.matcher", "spacy.matcher.matcher",
"spacy.matcher.phrasematcher", "spacy.matcher.phrasematcher",
"spacy.matcher.dependencymatcher", "spacy.matcher.dependencymatcher",
"spacy.syntax.ner",
"spacy.symbols", "spacy.symbols",
"spacy.vectors", "spacy.vectors",
] ]

View File

@ -8,6 +8,7 @@ warnings.filterwarnings("ignore", message="numpy.ufunc size changed") # noqa
# These are imported as part of the API # These are imported as part of the API
from thinc.api import prefer_gpu, require_gpu # noqa: F401 from thinc.api import prefer_gpu, require_gpu # noqa: F401
from thinc.api import Config
from . import pipeline # noqa: F401 from . import pipeline # noqa: F401
from .cli.info import info # noqa: F401 from .cli.info import info # noqa: F401
@ -26,17 +27,17 @@ if sys.maxunicode == 65535:
def load( def load(
name: Union[str, Path], name: Union[str, Path],
disable: Iterable[str] = tuple(), disable: Iterable[str] = tuple(),
component_cfg: Dict[str, Dict[str, Any]] = util.SimpleFrozenDict(), config: Union[Dict[str, Any], Config] = util.SimpleFrozenDict(),
) -> Language: ) -> Language:
"""Load a spaCy model from an installed package or a local path. """Load a spaCy model from an installed package or a local path.
name (str): Package name or model path. name (str): Package name or model path.
disable (Iterable[str]): Names of pipeline components to disable. disable (Iterable[str]): Names of pipeline components to disable.
component_cfg (Dict[str, dict]): Config overrides for pipeline components, config (Dict[str, Any] / Config): Config overrides as nested dict or dict
keyed by component names. keyed by section values in dot notation.
RETURNS (Language): The loaded nlp object. RETURNS (Language): The loaded nlp object.
""" """
return util.load_model(name, disable=disable, component_cfg=component_cfg) return util.load_model(name, disable=disable, config=config)
def blank(name: str, **overrides) -> Language: def blank(name: str, **overrides) -> Language:

View File

@ -15,6 +15,7 @@ from .debug_model import debug_model # noqa: F401
from .evaluate import evaluate # noqa: F401 from .evaluate import evaluate # noqa: F401
from .convert import convert # noqa: F401 from .convert import convert # noqa: F401
from .init_model import init_model # noqa: F401 from .init_model import init_model # noqa: F401
from .init_config import init_config # noqa: F401
from .validate import validate # noqa: F401 from .validate import validate # noqa: F401
from .project.clone import project_clone # noqa: F401 from .project.clone import project_clone # noqa: F401
from .project.assets import project_assets # noqa: F401 from .project.assets import project_assets # noqa: F401

View File

@ -6,7 +6,7 @@ import hashlib
import typer import typer
from typer.main import get_command from typer.main import get_command
from contextlib import contextmanager from contextlib import contextmanager
from thinc.config import ConfigValidationError from thinc.config import Config, ConfigValidationError
from configparser import InterpolationError from configparser import InterpolationError
import sys import sys
@ -31,6 +31,7 @@ DEBUG_HELP = """Suite of helpful commands for debugging and profiling. Includes
commands to check and validate your config files, training and evaluation data, commands to check and validate your config files, training and evaluation data,
and custom model implementations. and custom model implementations.
""" """
INIT_HELP = """Commands for initializing configs and models."""
# Wrappers for Typer's annotations. Initially created to set defaults and to # Wrappers for Typer's annotations. Initially created to set defaults and to
# keep the names short, but not needed at the moment. # keep the names short, but not needed at the moment.
@ -40,9 +41,11 @@ Opt = typer.Option
app = typer.Typer(name=NAME, help=HELP) app = typer.Typer(name=NAME, help=HELP)
project_cli = typer.Typer(name="project", help=PROJECT_HELP, no_args_is_help=True) project_cli = typer.Typer(name="project", help=PROJECT_HELP, no_args_is_help=True)
debug_cli = typer.Typer(name="debug", help=DEBUG_HELP, no_args_is_help=True) debug_cli = typer.Typer(name="debug", help=DEBUG_HELP, no_args_is_help=True)
init_cli = typer.Typer(name="init", help=INIT_HELP, no_args_is_help=True)
app.add_typer(project_cli) app.add_typer(project_cli)
app.add_typer(debug_cli) app.add_typer(debug_cli)
app.add_typer(init_cli)
def setup_cli() -> None: def setup_cli() -> None:
@ -172,16 +175,34 @@ def get_checksum(path: Union[Path, str]) -> str:
@contextmanager @contextmanager
def show_validation_error(title: str = "Config validation error"): def show_validation_error(
file_path: Optional[Union[str, Path]] = None,
*,
title: str = "Config validation error",
hint_init: bool = True,
):
"""Helper to show custom config validation errors on the CLI. """Helper to show custom config validation errors on the CLI.
file_path (str / Path): Optional file path of config file, used in hints.
title (str): Title of the custom formatted error. title (str): Title of the custom formatted error.
hint_init (bool): Show hint about filling config.
""" """
try: try:
yield yield
except (ConfigValidationError, InterpolationError) as e: except (ConfigValidationError, InterpolationError) as e:
msg.fail(title, spaced=True) msg.fail(title, spaced=True)
print(str(e).replace("Config validation error", "").strip()) # TODO: This is kinda hacky and we should probably provide a better
# helper for this in Thinc
err_text = str(e).replace("Config validation error", "").strip()
print(err_text)
if hint_init and "field required" in err_text:
config_path = file_path if file_path is not None else "config.cfg"
msg.text(
"If your config contains missing values, you can run the 'init "
"config' command to fill in all the defaults, if possible:",
spaced=True,
)
print(f"{COMMAND} init config {config_path} --base {config_path}\n")
sys.exit(1) sys.exit(1)
@ -196,3 +217,15 @@ def import_code(code_path: Optional[Union[Path, str]]) -> None:
import_file("python_code", code_path) import_file("python_code", code_path)
except Exception as e: except Exception as e:
msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1) msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1)
def get_sourced_components(config: Union[Dict[str, Any], Config]) -> List[str]:
    """Collect the names of all components sourced from another model.

    A component counts as sourced when its block in the "components" section
    of the original config has a "source" key, e.g.
    {"source": "en_core_web_sm"}, and no "factory" key. A block containing
    "factory" is assumed to refer to a component factory and is skipped.

    config (Dict[str, Any] / Config): The original config.
    RETURNS (List[str]): Names of the sourced components, in the order they
        appear in the config.
    """
    sourced = []
    for component_name, component_cfg in config.get("components", {}).items():
        # A factory entry always wins, even if "source" is also present.
        if "factory" in component_cfg:
            continue
        if "source" in component_cfg:
            sourced.append(component_name)
    return sourced

View File

@ -8,9 +8,9 @@ import typer
from thinc.api import Config from thinc.api import Config
from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
from ._util import import_code, debug_cli from ._util import import_code, debug_cli, get_sourced_components
from ..gold import Corpus, Example from ..gold import Corpus, Example
from ..syntax import nonproj from ..pipeline._parser_internals import nonproj
from ..language import Language from ..language import Language
from .. import util from .. import util
@ -33,7 +33,6 @@ def debug_config_cli(
ctx: typer.Context, # This is only used to read additional arguments ctx: typer.Context, # This is only used to read additional arguments
config_path: Path = Arg(..., help="Path to config file", exists=True), config_path: Path = Arg(..., help="Path to config file", exists=True),
code_path: Optional[Path] = Opt(None, "--code-path", "-c", help="Path to Python file with additional code (registered functions) to be imported"), code_path: Optional[Path] = Opt(None, "--code-path", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
output_path: Optional[Path] = Opt(None, "--output", "-o", help="Output path for filled config or '-' for standard output", allow_dash=True),
auto_fill: bool = Opt(False, "--auto-fill", "-F", help="Whether or not to auto-fill the config with built-in defaults if possible"), auto_fill: bool = Opt(False, "--auto-fill", "-F", help="Whether or not to auto-fill the config with built-in defaults if possible"),
diff: bool = Opt(False, "--diff", "-D", help="Show a visual diff if config was auto-filled") diff: bool = Opt(False, "--diff", "-D", help="Show a visual diff if config was auto-filled")
# fmt: on # fmt: on
@ -49,15 +48,12 @@ def debug_config_cli(
""" """
overrides = parse_config_overrides(ctx.args) overrides = parse_config_overrides(ctx.args)
import_code(code_path) import_code(code_path)
with show_validation_error(): with show_validation_error(config_path):
config = Config().from_disk(config_path) config = Config().from_disk(config_path, overrides=overrides)
try: try:
nlp, _ = util.load_model_from_config( nlp, _ = util.load_model_from_config(config, auto_fill=auto_fill)
config, overrides=overrides, auto_fill=auto_fill
)
except ValueError as e: except ValueError as e:
msg.fail(str(e), exits=1) msg.fail(str(e), exits=1)
is_stdout = output_path is not None and str(output_path) == "-"
if auto_fill: if auto_fill:
orig_config = config.to_str() orig_config = config.to_str()
filled_config = nlp.config.to_str() filled_config = nlp.config.to_str()
@ -68,12 +64,7 @@ def debug_config_cli(
if diff: if diff:
print(diff_strings(config.to_str(), nlp.config.to_str())) print(diff_strings(config.to_str(), nlp.config.to_str()))
else: else:
msg.good("Original config is valid", show=not is_stdout) msg.good("Original config is valid")
if is_stdout:
print(nlp.config.to_str())
elif output_path is not None:
nlp.config.to_disk(output_path)
msg.good(f"Saved updated config to {output_path}")
@debug_cli.command( @debug_cli.command(
@ -142,12 +133,13 @@ def debug_data(
msg.fail("Development data not found", dev_path, exits=1) msg.fail("Development data not found", dev_path, exits=1)
if not config_path.exists(): if not config_path.exists():
msg.fail("Config file not found", config_path, exits=1) msg.fail("Config file not found", config_path, exits=1)
with show_validation_error(): with show_validation_error(config_path):
cfg = Config().from_disk(config_path) cfg = Config().from_disk(config_path, overrides=config_overrides)
nlp, config = util.load_model_from_config(cfg, overrides=config_overrides) nlp, config = util.load_model_from_config(cfg)
# TODO: handle base model # Use original config here, not resolved version
lang = config["nlp"]["lang"] sourced_components = get_sourced_components(cfg)
base_model = config["training"]["base_model"] frozen_components = config["training"]["frozen_components"]
resume_components = [p for p in sourced_components if p not in frozen_components]
pipeline = nlp.pipe_names pipeline = nlp.pipe_names
factory_names = [nlp.get_pipe_meta(pipe).factory for pipe in nlp.pipe_names] factory_names = [nlp.get_pipe_meta(pipe).factory for pipe in nlp.pipe_names]
tag_map_path = util.ensure_path(config["training"]["tag_map"]) tag_map_path = util.ensure_path(config["training"]["tag_map"])
@ -169,13 +161,12 @@ def debug_data(
loading_train_error_message = "" loading_train_error_message = ""
loading_dev_error_message = "" loading_dev_error_message = ""
with msg.loading("Loading corpus..."): with msg.loading("Loading corpus..."):
corpus = Corpus(train_path, dev_path)
try: try:
train_dataset = list(corpus.train_dataset(nlp)) train_dataset = list(Corpus(train_path)(nlp))
except ValueError as e: except ValueError as e:
loading_train_error_message = f"Training data cannot be loaded: {e}" loading_train_error_message = f"Training data cannot be loaded: {e}"
try: try:
dev_dataset = list(corpus.dev_dataset(nlp)) dev_dataset = list(Corpus(dev_path)(nlp))
except ValueError as e: except ValueError as e:
loading_dev_error_message = f"Development data cannot be loaded: {e}" loading_dev_error_message = f"Development data cannot be loaded: {e}"
if loading_train_error_message or loading_dev_error_message: if loading_train_error_message or loading_dev_error_message:
@ -195,13 +186,15 @@ def debug_data(
train_texts = gold_train_data["texts"] train_texts = gold_train_data["texts"]
dev_texts = gold_dev_data["texts"] dev_texts = gold_dev_data["texts"]
frozen_components = config["training"]["frozen_components"]
msg.divider("Training stats") msg.divider("Training stats")
msg.text(f"Language: {config['nlp']['lang']}")
msg.text(f"Training pipeline: {', '.join(pipeline)}") msg.text(f"Training pipeline: {', '.join(pipeline)}")
if base_model: if resume_components:
msg.text(f"Starting with base model '{base_model}'") msg.text(f"Components from other models: {', '.join(resume_components)}")
else: if frozen_components:
msg.text(f"Starting with blank model '{lang}'") msg.text(f"Frozen components: {', '.join(frozen_components)}")
msg.text(f"{len(train_dataset)} training docs") msg.text(f"{len(train_dataset)} training docs")
msg.text(f"{len(dev_dataset)} evaluation docs") msg.text(f"{len(dev_dataset)} evaluation docs")
@ -212,7 +205,9 @@ def debug_data(
msg.warn(f"{overlap} training examples also in evaluation data") msg.warn(f"{overlap} training examples also in evaluation data")
else: else:
msg.good("No overlap between training and evaluation data") msg.good("No overlap between training and evaluation data")
if not base_model and len(train_dataset) < BLANK_MODEL_THRESHOLD: # TODO: make this feedback more fine-grained and report on updated
# components vs. blank components
if not resume_components and len(train_dataset) < BLANK_MODEL_THRESHOLD:
text = ( text = (
f"Low number of examples to train from a blank model ({len(train_dataset)})" f"Low number of examples to train from a blank model ({len(train_dataset)})"
) )

View File

@ -2,13 +2,11 @@ from typing import Dict, Any, Optional
from pathlib import Path from pathlib import Path
from wasabi import msg from wasabi import msg
from thinc.api import require_gpu, fix_random_seed, set_dropout_rate, Adam, Config from thinc.api import require_gpu, fix_random_seed, set_dropout_rate, Adam, Config
from thinc.api import Model from thinc.api import Model, data_validation
import typer import typer
from ._util import Arg, Opt, debug_cli, show_validation_error, parse_config_overrides from ._util import Arg, Opt, debug_cli, show_validation_error, parse_config_overrides
from .. import util from .. import util
from ..lang.en import English
from ..util import dot_to_object
@debug_cli.command("model") @debug_cli.command("model")
@ -16,7 +14,7 @@ def debug_model_cli(
# fmt: off # fmt: off
ctx: typer.Context, # This is only used to read additional arguments ctx: typer.Context, # This is only used to read additional arguments
config_path: Path = Arg(..., help="Path to config file", exists=True), config_path: Path = Arg(..., help="Path to config file", exists=True),
section: str = Arg(..., help="Section that defines the model to be analysed"), component: str = Arg(..., help="Name of the pipeline component of which the model should be analysed"),
layers: str = Opt("", "--layers", "-l", help="Comma-separated names of layer IDs to print"), layers: str = Opt("", "--layers", "-l", help="Comma-separated names of layer IDs to print"),
dimensions: bool = Opt(False, "--dimensions", "-DIM", help="Show dimensions"), dimensions: bool = Opt(False, "--dimensions", "-DIM", help="Show dimensions"),
parameters: bool = Opt(False, "--parameters", "-PAR", help="Show parameters"), parameters: bool = Opt(False, "--parameters", "-PAR", help="Show parameters"),
@ -25,7 +23,7 @@ def debug_model_cli(
P0: bool = Opt(False, "--print-step0", "-P0", help="Print model before training"), P0: bool = Opt(False, "--print-step0", "-P0", help="Print model before training"),
P1: bool = Opt(False, "--print-step1", "-P1", help="Print model after initialization"), P1: bool = Opt(False, "--print-step1", "-P1", help="Print model after initialization"),
P2: bool = Opt(False, "--print-step2", "-P2", help="Print model after training"), P2: bool = Opt(False, "--print-step2", "-P2", help="Print model after training"),
P3: bool = Opt(True, "--print-step3", "-P3", help="Print final predictions"), P3: bool = Opt(False, "--print-step3", "-P3", help="Print final predictions"),
use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU") use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU")
# fmt: on # fmt: on
): ):
@ -50,10 +48,10 @@ def debug_model_cli(
"print_prediction": P3, "print_prediction": P3,
} }
config_overrides = parse_config_overrides(ctx.args) config_overrides = parse_config_overrides(ctx.args)
cfg = Config().from_disk(config_path) with show_validation_error(config_path):
with show_validation_error(): cfg = Config().from_disk(config_path, overrides=config_overrides)
try: try:
_, config = util.load_model_from_config(cfg, overrides=config_overrides) nlp, config = util.load_model_from_config(cfg)
except ValueError as e: except ValueError as e:
msg.fail(str(e), exits=1) msg.fail(str(e), exits=1)
seed = config["pretraining"]["seed"] seed = config["pretraining"]["seed"]
@ -61,12 +59,12 @@ def debug_model_cli(
msg.info(f"Fixing random seed: {seed}") msg.info(f"Fixing random seed: {seed}")
fix_random_seed(seed) fix_random_seed(seed)
component = dot_to_object(config, section) pipe = nlp.get_pipe(component)
if hasattr(component, "model"): if hasattr(pipe, "model"):
model = component.model model = pipe.model
else: else:
msg.fail( msg.fail(
f"The section '{section}' does not specify an object that holds a Model.", f"The component '{component}' does not specify an object that holds a Model.",
exits=1, exits=1,
) )
debug_model(model, print_settings=print_settings) debug_model(model, print_settings=print_settings)
@ -84,15 +82,17 @@ def debug_model(model: Model, *, print_settings: Optional[Dict[str, Any]] = None
# STEP 0: Printing before training # STEP 0: Printing before training
msg.info(f"Analysing model with ID {model.id}") msg.info(f"Analysing model with ID {model.id}")
if print_settings.get("print_before_training"): if print_settings.get("print_before_training"):
msg.info(f"Before training:") msg.divider(f"STEP 0 - before training")
_print_model(model, print_settings) _print_model(model, print_settings)
# STEP 1: Initializing the model and printing again # STEP 1: Initializing the model and printing again
Y = _get_output(model.ops.xp) Y = _get_output(model.ops.xp)
_set_output_dim(nO=Y.shape[-1], model=model) _set_output_dim(nO=Y.shape[-1], model=model)
model.initialize(X=_get_docs(), Y=Y) # The output vector might differ from the official type of the output layer
with data_validation(False):
model.initialize(X=_get_docs(), Y=Y)
if print_settings.get("print_after_init"): if print_settings.get("print_after_init"):
msg.info(f"After initialization:") msg.divider(f"STEP 1 - after initialization")
_print_model(model, print_settings) _print_model(model, print_settings)
# STEP 2: Updating the model and printing again # STEP 2: Updating the model and printing again
@ -104,13 +104,14 @@ def debug_model(model: Model, *, print_settings: Optional[Dict[str, Any]] = None
get_dX(dY) get_dX(dY)
model.finish_update(optimizer) model.finish_update(optimizer)
if print_settings.get("print_after_training"): if print_settings.get("print_after_training"):
msg.info(f"After training:") msg.divider(f"STEP 2 - after training")
_print_model(model, print_settings) _print_model(model, print_settings)
# STEP 3: the final prediction # STEP 3: the final prediction
prediction = model.predict(_get_docs()) prediction = model.predict(_get_docs())
if print_settings.get("print_prediction"): if print_settings.get("print_prediction"):
msg.info(f"Prediction:", str(prediction)) msg.divider(f"STEP 3 - prediction")
msg.info(str(prediction))
def get_gradient(model, Y): def get_gradient(model, Y):
@ -127,8 +128,8 @@ def _sentences():
] ]
def _get_docs(): def _get_docs(lang: str = "en"):
nlp = English() nlp = util.get_lang_class(lang)()
return list(nlp.pipe(_sentences())) return list(nlp.pipe(_sentences()))

View File

@ -7,23 +7,7 @@ import typer
from ._util import app, Arg, Opt from ._util import app, Arg, Opt
from .. import about from .. import about
from ..util import is_package, get_base_version, run_command from ..util import is_package, get_base_version, run_command
from ..errors import OLD_MODEL_SHORTCUTS
# These are the old shortcuts we previously supported in spacy download. As of
# v3, shortcuts are deprecated so we're not expecting to add anything to this
# list. It only exists to show users warnings.
OLD_SHORTCUTS = {
"en": "en_core_web_sm",
"de": "de_core_news_sm",
"es": "es_core_news_sm",
"pt": "pt_core_news_sm",
"fr": "fr_core_news_sm",
"it": "it_core_news_sm",
"nl": "nl_core_news_sm",
"el": "el_core_news_sm",
"nb": "nb_core_news_sm",
"lt": "lt_core_news_sm",
"xx": "xx_ent_wiki_sm",
}
@app.command( @app.command(
@ -66,12 +50,12 @@ def download(model: str, direct: bool = False, *pip_args) -> None:
download_model(dl_tpl.format(m=model_name, v=version), pip_args) download_model(dl_tpl.format(m=model_name, v=version), pip_args)
else: else:
model_name = model model_name = model
if model in OLD_SHORTCUTS: if model in OLD_MODEL_SHORTCUTS:
msg.warn( msg.warn(
f"As of spaCy v3.0, shortcuts like '{model}' are deprecated. " f"As of spaCy v3.0, shortcuts like '{model}' are deprecated. Please "
f"Please use the full model name '{OLD_SHORTCUTS[model]}' instead." f"use the full model name '{OLD_MODEL_SHORTCUTS[model]}' instead."
) )
model_name = OLD_SHORTCUTS[model] model_name = OLD_MODEL_SHORTCUTS[model]
compatibility = get_compatibility() compatibility = get_compatibility()
version = get_version(model_name, compatibility) version = get_version(model_name, compatibility)
download_model(dl_tpl.format(m=model_name, v=version), pip_args) download_model(dl_tpl.format(m=model_name, v=version), pip_args)

View File

@ -1,5 +1,4 @@
from typing import Optional, List, Dict from typing import Optional, List, Dict
from timeit import default_timer as timer
from wasabi import Printer from wasabi import Printer
from pathlib import Path from pathlib import Path
import re import re
@ -64,9 +63,9 @@ def evaluate(
msg.fail("Evaluation data not found", data_path, exits=1) msg.fail("Evaluation data not found", data_path, exits=1)
if displacy_path and not displacy_path.exists(): if displacy_path and not displacy_path.exists():
msg.fail("Visualization output directory not found", displacy_path, exits=1) msg.fail("Visualization output directory not found", displacy_path, exits=1)
corpus = Corpus(data_path, data_path) corpus = Corpus(data_path, gold_preproc=gold_preproc)
nlp = util.load_model(model) nlp = util.load_model(model)
dev_dataset = list(corpus.dev_dataset(nlp, gold_preproc=gold_preproc)) dev_dataset = list(corpus(nlp))
scores = nlp.evaluate(dev_dataset, verbose=False) scores = nlp.evaluate(dev_dataset, verbose=False)
metrics = { metrics = {
"TOK": "token_acc", "TOK": "token_acc",

81
spacy/cli/init_config.py Normal file
View File

@ -0,0 +1,81 @@
from typing import Optional, List
from pathlib import Path
from thinc.api import Config
from wasabi import msg
from ..util import load_model_from_config, get_lang_class, load_model
from ._util import init_cli, Arg, Opt, show_validation_error
@init_cli.command("config")
def init_config_cli(
    # fmt: off
    output_path: Path = Arg("-", help="Output path or - for stdout", allow_dash=True),
    base_path: Optional[Path] = Opt(None, "--base", "-b", help="Optional base config to fill", exists=True, dir_okay=False),
    model: Optional[str] = Opt(None, "--model", "-m", help="Optional model to copy config from"),
    lang: Optional[str] = Opt(None, "--lang", "-l", help="Optional language code for blank config"),
    pipeline: Optional[str] = Opt(None, "--pipeline", "-p", help="Optional pipeline components to use")
    # fmt: on
):
    """Generate a starter config.cfg for training."""
    # NOTE: the docstring above is kept short on purpose -- Typer displays it
    # as the help text of the "init config" command.
    validate_cli_args(base_path, model, lang)
    # "-" means: print the generated config instead of saving it, and keep
    # all status messages quiet so that stdout only contains the config.
    is_stdout = str(output_path) == "-"
    # --pipeline is a comma-separated string. Drop blank entries (e.g. from a
    # trailing or doubled comma) so no empty component name is passed on.
    components = []
    if pipeline:
        components = [name.strip() for name in pipeline.split(",") if name.strip()]
    cfg = init_config(
        output_path, base_path, model, lang, components, silent=is_stdout
    )
    if is_stdout:
        print(cfg.to_str())
    else:
        cfg.to_disk(output_path)
        msg.good("Saved config", output_path)
def init_config(
    output_path: Path,
    config_path: Optional[Path],
    model: Optional[str],
    lang: Optional[str],
    pipeline: Optional[List[str]],
    silent: bool = False,
) -> Config:
    """Create a config from a base config, an existing model or a language.

    The three sources are checked in the order config_path, model, lang and
    the first one that is not None is used.

    output_path (Path): Not read by this function.
    config_path (Optional[Path]): Base config to load from disk and auto-fill.
    model (Optional[str]): Model to load. Its components are reduced to those
        named in `pipeline` and missing ones from `pipeline` are added.
    lang (Optional[str]): Language code used to create a blank nlp object, to
        which the components in `pipeline` are added.
    pipeline (Optional[List[str]]): Component names. None is treated like an
        empty list.
    silent (bool): Don't show the informational messages.
    RETURNS (Config): The config of the resulting nlp object. If none of
        config_path, model and lang is set, the function falls through and
        implicitly returns None.
    """
    if config_path is not None:
        msg.info("Generating config from base config", show=not silent)
        with show_validation_error(config_path, hint_init=False):
            config = Config().from_disk(config_path)
            try:
                nlp, _ = load_model_from_config(config, auto_fill=True)
            except ValueError as e:
                msg.fail(str(e), exits=1)
        return nlp.config
    # The signature allows pipeline=None, so normalise it once here instead of
    # raising a TypeError in the membership test and loops below.
    if pipeline is None:
        pipeline = []
    ext = f" with pipeline {pipeline}" if pipeline else ""
    if model is not None:
        msg.info(f"Generating config from model {model}{ext}", show=not silent)
        nlp = load_model(model)
        # Iterate over a snapshot of the names, so removing components can't
        # interfere with the iteration.
        # NOTE(review): with an empty pipeline this removes every component
        # of the loaded model - confirm that this is the intended result.
        for existing_pipe_name in list(nlp.pipe_names):
            if existing_pipe_name not in pipeline:
                nlp.remove_pipe(existing_pipe_name)
        for pipe_name in pipeline:
            if pipe_name not in nlp.pipe_names:
                nlp.add_pipe(pipe_name)
        return nlp.config
    if lang is not None:
        msg.info(f"Generating config for language '{lang}'{ext}", show=not silent)
        nlp = get_lang_class(lang)()
        for pipe_name in pipeline:
            nlp.add_pipe(pipe_name)
        return nlp.config
def validate_cli_args(
    config_path: Optional[Path], model: Optional[str], lang: Optional[str]
) -> None:
    """Check that exactly one of --base, --model and --lang was provided.

    config_path (Optional[Path]): Value of the --base option.
    model (Optional[str]): Value of the --model option.
    lang (Optional[str]): Value of the --lang option.

    If zero or more than one of the values is not None, a failure message
    listing the received options is printed via msg.fail with exits=1.
    """
    candidates = (("--base", config_path), ("--model", model), ("--lang", lang))
    provided = [(flag, value) for flag, value in candidates if value is not None]
    if len(provided) == 1:
        return
    # Echo back what was received so the user can spot the conflict.
    received = " ".join(f"{flag} {value}" for flag, value in provided)
    if not received:
        received = "no arguments"
    msg.fail(
        "The init config command expects only one of the following arguments: "
        "--base (base config to fill and update), --lang (language code to "
        "use for blank config) or --model (base model to copy config from).",
        f"Got: {received}",
        exits=1,
    )

View File

@ -10,14 +10,14 @@ import gzip
import zipfile import zipfile
import srsly import srsly
import warnings import warnings
from wasabi import Printer from wasabi import msg, Printer
import typer
from ._util import app, Arg, Opt from ._util import app, init_cli, Arg, Opt
from ..vectors import Vectors from ..vectors import Vectors
from ..errors import Errors, Warnings from ..errors import Errors, Warnings
from ..language import Language from ..language import Language
from ..util import ensure_path, get_lang_class, load_model, OOV_RANK from ..util import ensure_path, get_lang_class, load_model, OOV_RANK
from ..lookups import Lookups
try: try:
import ftfy import ftfy
@ -28,9 +28,15 @@ except ImportError:
DEFAULT_OOV_PROB = -20 DEFAULT_OOV_PROB = -20
@app.command("init-model") @init_cli.command("model")
@app.command(
"init-model",
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
hidden=True, # hide this from main CLI help but still allow it to work with warning
)
def init_model_cli( def init_model_cli(
# fmt: off # fmt: off
ctx: typer.Context, # This is only used to read additional arguments
lang: str = Arg(..., help="Model language"), lang: str = Arg(..., help="Model language"),
output_dir: Path = Arg(..., help="Model output directory"), output_dir: Path = Arg(..., help="Model output directory"),
freqs_loc: Optional[Path] = Arg(None, help="Location of words frequencies file", exists=True), freqs_loc: Optional[Path] = Arg(None, help="Location of words frequencies file", exists=True),
@ -48,6 +54,12 @@ def init_model_cli(
Create a new model from raw data. If vectors are provided in Word2Vec format, Create a new model from raw data. If vectors are provided in Word2Vec format,
they can be either a .txt or zipped as a .zip or .tar.gz. they can be either a .txt or zipped as a .zip or .tar.gz.
""" """
if ctx.command.name == "init-model":
msg.warn(
"The init-model command is now available via the 'init model' "
"subcommand (without the hyphen). You can run python -m spacy init "
"--help for an overview of the other available initialization commands."
)
init_model( init_model(
lang, lang,
output_dir, output_dir,

View File

@ -87,9 +87,9 @@ def pretrain(
else: else:
msg.info("Using CPU") msg.info("Using CPU")
msg.info(f"Loading config from: {config_path}") msg.info(f"Loading config from: {config_path}")
config = Config().from_disk(config_path) with show_validation_error(config_path):
with show_validation_error(): config = Config().from_disk(config_path, overrides=config_overrides)
nlp, config = util.load_model_from_config(config, overrides=config_overrides) nlp, config = util.load_model_from_config(config)
# TODO: validate that [pretraining] block exists # TODO: validate that [pretraining] block exists
if not output_dir.exists(): if not output_dir.exists():
output_dir.mkdir() output_dir.mkdir()

View File

@ -1,7 +1,6 @@
from typing import Optional from typing import Optional
from pathlib import Path from pathlib import Path
from wasabi import msg from wasabi import msg
import tqdm
import re import re
import shutil import shutil
import requests import requests

View File

@ -11,10 +11,10 @@ import random
import typer import typer
from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
from ._util import import_code from ._util import import_code, get_sourced_components
from ..gold import Corpus, Example
from ..language import Language from ..language import Language
from .. import util from .. import util
from ..gold.example import Example
from ..errors import Errors from ..errors import Errors
@ -28,8 +28,6 @@ from ..ml import models # noqa: F401
def train_cli( def train_cli(
# fmt: off # fmt: off
ctx: typer.Context, # This is only used to read additional arguments ctx: typer.Context, # This is only used to read additional arguments
train_path: Path = Arg(..., help="Location of training data", exists=True),
dev_path: Path = Arg(..., help="Location of development data", exists=True),
config_path: Path = Arg(..., help="Path to config file", exists=True), config_path: Path = Arg(..., help="Path to config file", exists=True),
output_path: Optional[Path] = Opt(None, "--output", "--output-path", "-o", help="Output directory to store model in"), output_path: Optional[Path] = Opt(None, "--output", "--output-path", "-o", help="Output directory to store model in"),
code_path: Optional[Path] = Opt(None, "--code-path", "-c", help="Path to Python file with additional code (registered functions) to be imported"), code_path: Optional[Path] = Opt(None, "--code-path", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
@ -51,12 +49,11 @@ def train_cli(
referenced in the config. referenced in the config.
""" """
util.set_env_log(verbose) util.set_env_log(verbose)
verify_cli_args(train_path, dev_path, config_path, output_path) verify_cli_args(config_path, output_path)
overrides = parse_config_overrides(ctx.args) overrides = parse_config_overrides(ctx.args)
import_code(code_path) import_code(code_path)
train( train(
config_path, config_path,
{"train": train_path, "dev": dev_path},
output_path=output_path, output_path=output_path,
config_overrides=overrides, config_overrides=overrides,
use_gpu=use_gpu, use_gpu=use_gpu,
@ -66,8 +63,6 @@ def train_cli(
def train( def train(
config_path: Path, config_path: Path,
data_paths: Dict[str, Path],
raw_text: Optional[Path] = None,
output_path: Optional[Path] = None, output_path: Optional[Path] = None,
config_overrides: Dict[str, Any] = {}, config_overrides: Dict[str, Any] = {},
use_gpu: int = -1, use_gpu: int = -1,
@ -79,41 +74,37 @@ def train(
else: else:
msg.info("Using CPU") msg.info("Using CPU")
msg.info(f"Loading config and nlp from: {config_path}") msg.info(f"Loading config and nlp from: {config_path}")
config = Config().from_disk(config_path) with show_validation_error(config_path):
config = Config().from_disk(config_path, overrides=config_overrides)
if config.get("training", {}).get("seed") is not None: if config.get("training", {}).get("seed") is not None:
fix_random_seed(config["training"]["seed"]) fix_random_seed(config["training"]["seed"])
with show_validation_error(): # Use original config here before it's resolved to functions
nlp, config = util.load_model_from_config(config, overrides=config_overrides) sourced_components = get_sourced_components(config)
if config["training"]["base_model"]: with show_validation_error(config_path):
# TODO: do something to check base_nlp against regular nlp described in config? nlp, config = util.load_model_from_config(config)
# If everything matches it will look something like:
# base_nlp = util.load_model(config["training"]["base_model"])
# nlp = base_nlp
raise NotImplementedError("base_model not supported yet.")
if config["training"]["vectors"] is not None: if config["training"]["vectors"] is not None:
util.load_vectors_into_model(nlp, config["training"]["vectors"]) util.load_vectors_into_model(nlp, config["training"]["vectors"])
verify_config(nlp) verify_config(nlp)
raw_text, tag_map, morph_rules, weights_data = load_from_paths(config) raw_text, tag_map, morph_rules, weights_data = load_from_paths(config)
if config["training"]["use_pytorch_for_gpu_memory"]: if config.get("system", {}).get("use_pytorch_for_gpu_memory"):
# It feels kind of weird to not have a default for this. # It feels kind of weird to not have a default for this.
use_pytorch_for_gpu_memory() use_pytorch_for_gpu_memory()
training = config["training"] T_cfg = config["training"]
optimizer = training["optimizer"] optimizer = T_cfg["optimizer"]
limit = training["limit"] train_corpus = T_cfg["train_corpus"]
corpus = Corpus(data_paths["train"], data_paths["dev"], limit=limit) dev_corpus = T_cfg["dev_corpus"]
if resume_training: batcher = T_cfg["batcher"]
msg.info("Resuming training") # Components that shouldn't be updated during training
nlp.resume_training() frozen_components = T_cfg["frozen_components"]
else: # Sourced components that require resume_training
msg.info(f"Initializing the nlp pipeline: {nlp.pipe_names}") resume_components = [p for p in sourced_components if p not in frozen_components]
train_examples = corpus.train_dataset( msg.info(f"Pipeline: {nlp.pipe_names}")
nlp, if resume_components:
shuffle=False, with nlp.select_pipes(enable=resume_components):
gold_preproc=training["gold_preproc"], msg.info(f"Resuming training for: {resume_components}")
max_length=training["max_length"], nlp.resume_training()
) with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
train_examples = list(train_examples) nlp.begin_training(lambda: train_corpus(nlp))
nlp.begin_training(lambda: train_examples)
if tag_map: if tag_map:
# Replace tag map with provided mapping # Replace tag map with provided mapping
@ -139,38 +130,36 @@ def train(
msg.fail(err, exits=1) msg.fail(err, exits=1)
tok2vec.from_bytes(weights_data) tok2vec.from_bytes(weights_data)
msg.info("Loading training corpus")
train_batches = create_train_batches(nlp, corpus, training)
evaluate = create_evaluation_callback(nlp, optimizer, corpus, training)
# Create iterator, which yields out info after each optimization step. # Create iterator, which yields out info after each optimization step.
msg.info("Start training") msg.info("Start training")
score_weights = T_cfg["score_weights"]
training_step_iterator = train_while_improving( training_step_iterator = train_while_improving(
nlp, nlp,
optimizer, optimizer,
train_batches, create_train_batches(train_corpus(nlp), batcher, T_cfg["max_epochs"]),
evaluate, create_evaluation_callback(nlp, dev_corpus, score_weights),
dropout=training["dropout"], dropout=T_cfg["dropout"],
accumulate_gradient=training["accumulate_gradient"], accumulate_gradient=T_cfg["accumulate_gradient"],
patience=training["patience"], patience=T_cfg["patience"],
max_steps=training["max_steps"], max_steps=T_cfg["max_steps"],
eval_frequency=training["eval_frequency"], eval_frequency=T_cfg["eval_frequency"],
raw_text=raw_text, raw_text=None,
exclude=frozen_components,
) )
msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}") msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
print_row = setup_printer(training, nlp) print_row = setup_printer(T_cfg, nlp)
try: try:
progress = tqdm.tqdm(total=training["eval_frequency"], leave=False) progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False)
for batch, info, is_best_checkpoint in training_step_iterator: for batch, info, is_best_checkpoint in training_step_iterator:
progress.update(1) progress.update(1)
if is_best_checkpoint is not None: if is_best_checkpoint is not None:
progress.close() progress.close()
print_row(info) print_row(info)
if is_best_checkpoint and output_path is not None: if is_best_checkpoint and output_path is not None:
update_meta(training, nlp, info) update_meta(T_cfg, nlp, info)
nlp.to_disk(output_path / "model-best") nlp.to_disk(output_path / "model-best")
progress = tqdm.tqdm(total=training["eval_frequency"], leave=False) progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False)
except Exception as e: except Exception as e:
if output_path is not None: if output_path is not None:
msg.warn( msg.warn(
@ -191,72 +180,32 @@ def train(
msg.good(f"Saved model to output directory {final_model_path}") msg.good(f"Saved model to output directory {final_model_path}")
def create_train_batches( def create_train_batches(iterator, batcher, max_epochs: int):
nlp: Language, corpus: Corpus, cfg: Union[Config, Dict[str, Any]] epoch = 1
): examples = []
max_epochs = cfg["max_epochs"] # Stream the first epoch, so we start training faster and support
train_examples = list( # infinite streams.
corpus.train_dataset( for batch in batcher(iterator):
nlp, yield epoch, batch
shuffle=True, if max_epochs != 1:
gold_preproc=cfg["gold_preproc"], examples.extend(batch)
max_length=cfg["max_length"], if not examples:
) # Raise error if no data
) raise ValueError(Errors.E986)
epoch = 0 while epoch != max_epochs:
batch_strategy = cfg["batch_by"] random.shuffle(examples)
while True: for batch in batcher(examples):
if len(train_examples) == 0:
raise ValueError(Errors.E988)
epoch += 1
if batch_strategy == "padded":
batches = util.minibatch_by_padded_size(
train_examples,
size=cfg["batch_size"],
buffer=256,
discard_oversize=cfg["discard_oversize"],
)
elif batch_strategy == "words":
batches = util.minibatch_by_words(
train_examples,
size=cfg["batch_size"],
discard_oversize=cfg["discard_oversize"],
)
else:
batches = util.minibatch(train_examples, size=cfg["batch_size"])
# make sure the minibatch_by_words result is not empty, or we'll have an infinite training loop
try:
first = next(batches)
yield epoch, first
except StopIteration:
raise ValueError(Errors.E986)
for batch in batches:
yield epoch, batch yield epoch, batch
if max_epochs >= 1 and epoch >= max_epochs: epoch += 1
break
random.shuffle(train_examples)
def create_evaluation_callback( def create_evaluation_callback(
nlp: Language, nlp: Language, dev_corpus: Callable, weights: Dict[str, float],
optimizer: Optimizer,
corpus: Corpus,
cfg: Union[Config, Dict[str, Any]],
) -> Callable[[], Tuple[float, Dict[str, float]]]: ) -> Callable[[], Tuple[float, Dict[str, float]]]:
def evaluate() -> Tuple[float, Dict[str, float]]: def evaluate() -> Tuple[float, Dict[str, float]]:
dev_examples = corpus.dev_dataset( dev_examples = list(dev_corpus(nlp))
nlp, gold_preproc=cfg["gold_preproc"] scores = nlp.evaluate(dev_examples)
)
dev_examples = list(dev_examples)
n_words = sum(len(ex.predicted) for ex in dev_examples)
batch_size = cfg["eval_batch_size"]
if optimizer.averages:
with nlp.use_params(optimizer.averages):
scores = nlp.evaluate(dev_examples, batch_size=batch_size)
else:
scores = nlp.evaluate(dev_examples, batch_size=batch_size)
# Calculate a weighted sum based on score_weights for the main score # Calculate a weighted sum based on score_weights for the main score
weights = cfg["score_weights"]
try: try:
weighted_score = sum(scores[s] * weights.get(s, 0.0) for s in weights) weighted_score = sum(scores[s] * weights.get(s, 0.0) for s in weights)
except KeyError as e: except KeyError as e:
@ -280,6 +229,7 @@ def train_while_improving(
patience: int, patience: int,
max_steps: int, max_steps: int,
raw_text: List[Dict[str, str]], raw_text: List[Dict[str, str]],
exclude: List[str],
): ):
"""Train until an evaluation stops improving. Works as a generator, """Train until an evaluation stops improving. Works as a generator,
with each iteration yielding a tuple `(batch, info, is_best_checkpoint)`, with each iteration yielding a tuple `(batch, info, is_best_checkpoint)`,
@ -325,8 +275,6 @@ def train_while_improving(
dropouts = dropout dropouts = dropout
results = [] results = []
losses = {} losses = {}
to_enable = [name for name, proc in nlp.pipeline if hasattr(proc, "model")]
if raw_text: if raw_text:
random.shuffle(raw_text) random.shuffle(raw_text)
raw_examples = [ raw_examples = [
@ -336,20 +284,26 @@ def train_while_improving(
for step, (epoch, batch) in enumerate(train_data): for step, (epoch, batch) in enumerate(train_data):
dropout = next(dropouts) dropout = next(dropouts)
with nlp.select_pipes(enable=to_enable): for subbatch in subdivide_batch(batch, accumulate_gradient):
for subbatch in subdivide_batch(batch, accumulate_gradient): nlp.update(
nlp.update(subbatch, drop=dropout, losses=losses, sgd=False) subbatch, drop=dropout, losses=losses, sgd=False, exclude=exclude
if raw_text: )
# If raw text is available, perform 'rehearsal' updates, if raw_text:
# which use unlabelled data to reduce overfitting. # If raw text is available, perform 'rehearsal' updates,
raw_batch = list(next(raw_batches)) # which use unlabelled data to reduce overfitting.
nlp.rehearse(raw_batch, sgd=optimizer, losses=losses) raw_batch = list(next(raw_batches))
for name, proc in nlp.pipeline: nlp.rehearse(raw_batch, sgd=optimizer, losses=losses, exclude=exclude)
if hasattr(proc, "model"): # TODO: refactor this so we don't have to run it separately in here
proc.model.finish_update(optimizer) for name, proc in nlp.pipeline:
if name not in exclude and hasattr(proc, "model"):
proc.model.finish_update(optimizer)
optimizer.step_schedules() optimizer.step_schedules()
if not (step % eval_frequency): if not (step % eval_frequency):
score, other_scores = evaluate() if optimizer.averages:
with nlp.use_params(optimizer.averages):
score, other_scores = evaluate()
else:
score, other_scores = evaluate()
results.append((score, step)) results.append((score, step))
is_best_checkpoint = score == max(results)[0] is_best_checkpoint = score == max(results)[0]
else: else:
@ -460,17 +414,7 @@ def load_from_paths(
msg.fail("Can't find raw text", raw_text, exits=1) msg.fail("Can't find raw text", raw_text, exits=1)
raw_text = list(srsly.read_jsonl(config["training"]["raw_text"])) raw_text = list(srsly.read_jsonl(config["training"]["raw_text"]))
tag_map = {} tag_map = {}
tag_map_path = util.ensure_path(config["training"]["tag_map"])
if tag_map_path is not None:
if not tag_map_path.exists():
msg.fail("Can't find tag map path", tag_map_path, exits=1)
tag_map = srsly.read_json(config["training"]["tag_map"])
morph_rules = {} morph_rules = {}
morph_rules_path = util.ensure_path(config["training"]["morph_rules"])
if morph_rules_path is not None:
if not morph_rules_path.exists():
msg.fail("Can't find tag map path", morph_rules_path, exits=1)
morph_rules = srsly.read_json(config["training"]["morph_rules"])
weights_data = None weights_data = None
init_tok2vec = util.ensure_path(config["training"]["init_tok2vec"]) init_tok2vec = util.ensure_path(config["training"]["init_tok2vec"])
if init_tok2vec is not None: if init_tok2vec is not None:
@ -481,19 +425,10 @@ def load_from_paths(
return raw_text, tag_map, morph_rules, weights_data return raw_text, tag_map, morph_rules, weights_data
def verify_cli_args( def verify_cli_args(config_path: Path, output_path: Optional[Path] = None,) -> None:
train_path: Path,
dev_path: Path,
config_path: Path,
output_path: Optional[Path] = None,
) -> None:
# Make sure all files and paths exists if they are needed # Make sure all files and paths exists if they are needed
if not config_path or not config_path.exists(): if not config_path or not config_path.exists():
msg.fail("Config file not found", config_path, exits=1) msg.fail("Config file not found", config_path, exits=1)
if not train_path or not train_path.exists():
msg.fail("Training data not found", train_path, exits=1)
if not dev_path or not dev_path.exists():
msg.fail("Development data not found", dev_path, exits=1)
if output_path is not None: if output_path is not None:
if not output_path.exists(): if not output_path.exists():
output_path.mkdir() output_path.mkdir()

View File

@ -13,8 +13,9 @@ from ..util import get_package_path, get_model_meta, is_compatible_version
@app.command("validate") @app.command("validate")
def validate_cli(): def validate_cli():
""" """
Validate that the currently installed version of spaCy is compatible Validate the currently installed models and spaCy version. Checks if the
with the installed models. Should be run after `pip install -U spacy`. installed models are compatible and shows upgrade instructions if available.
Should be run after `pip install -U spacy`.
""" """
validate() validate()

View File

@ -1,7 +1,20 @@
[paths]
train = ""
dev = ""
raw = null
init_tok2vec = null
[system]
seed = 0
use_pytorch_for_gpu_memory = false
[nlp] [nlp]
lang = null lang = null
pipeline = [] pipeline = []
load_vocab_data = true load_vocab_data = true
before_creation = null
after_creation = null
after_pipeline_creation = null
[nlp.tokenizer] [nlp.tokenizer]
@tokenizers = "spacy.Tokenizer.v1" @tokenizers = "spacy.Tokenizer.v1"
@ -13,38 +26,57 @@ load_vocab_data = true
# Training hyper-parameters and additional features. # Training hyper-parameters and additional features.
[training] [training]
# Whether to train on sequences with 'gold standard' sentence boundaries seed = ${system:seed}
# and tokens. If you set this to true, take care to ensure your run-time
# data is passed in sentence-by-sentence via some prior preprocessing.
gold_preproc = false
# Limitations on training document length or number of examples.
max_length = 5000
limit = 0
# Data augmentation
orth_variant_level = 0.0
dropout = 0.1 dropout = 0.1
accumulate_gradient = 1
# Extra resources for transfer-learning or pseudo-rehearsal
init_tok2vec = ${paths:init_tok2vec}
raw_text = ${paths:raw}
vectors = null
# Controls early-stopping. 0 or -1 mean unlimited. # Controls early-stopping. 0 or -1 mean unlimited.
patience = 1600 patience = 1600
max_epochs = 0 max_epochs = 0
max_steps = 20000 max_steps = 20000
eval_frequency = 200 eval_frequency = 200
eval_batch_size = 128
# Other settings
seed = 0
accumulate_gradient = 1
use_pytorch_for_gpu_memory = false
# Control how scores are printed and checkpoints are evaluated. # Control how scores are printed and checkpoints are evaluated.
score_weights = {} score_weights = {}
# These settings are invalid for the transformer models. # Names of pipeline components that shouldn't be updated during training
init_tok2vec = null frozen_components = []
[training.train_corpus]
@readers = "spacy.Corpus.v1"
path = ${paths:train}
# Whether to train on sequences with 'gold standard' sentence boundaries
# and tokens. If you set this to true, take care to ensure your run-time
# data is passed in sentence-by-sentence via some prior preprocessing.
gold_preproc = false
# Limitations on training document length
max_length = 2000
# Limitation on number of training examples
limit = 0
[training.dev_corpus]
@readers = "spacy.Corpus.v1"
path = ${paths:dev}
# Whether to train on sequences with 'gold standard' sentence boundaries
# and tokens. If you set this to true, take care to ensure your run-time
# data is passed in sentence-by-sentence via some prior preprocessing.
gold_preproc = false
# Limitations on training document length
max_length = 2000
# Limitation on number of training examples
limit = 0
[training.batcher]
@batchers = "batch_by_words.v1"
discard_oversize = false discard_oversize = false
raw_text = null tolerance = 0.2
tag_map = null
morph_rules = null [training.batcher.size]
base_model = null @schedules = "compounding.v1"
vectors = null start = 100
batch_by = "words" stop = 1000
batch_size = 1000 compound = 1.001
[training.optimizer] [training.optimizer]
@optimizers = "Adam.v1" @optimizers = "Adam.v1"
@ -69,8 +101,8 @@ max_length = 500
dropout = 0.2 dropout = 0.2
n_save_every = null n_save_every = null
batch_size = 3000 batch_size = 3000
seed = ${training:seed} seed = ${system:seed}
use_pytorch_for_gpu_memory = ${training:use_pytorch_for_gpu_memory} use_pytorch_for_gpu_memory = ${system:use_pytorch_for_gpu_memory}
tok2vec_model = "components.tok2vec.model" tok2vec_model = "components.tok2vec.model"
[pretraining.objective] [pretraining.objective]

View File

@ -63,8 +63,6 @@ class Warnings:
"have the spacy-lookups-data package installed.") "have the spacy-lookups-data package installed.")
W024 = ("Entity '{entity}' - Alias '{alias}' combination already exists in " W024 = ("Entity '{entity}' - Alias '{alias}' combination already exists in "
"the Knowledge Base.") "the Knowledge Base.")
W025 = ("'{name}' requires '{attr}' to be assigned, but none of the "
"previous components in the pipeline declare that they assign it.")
W026 = ("Unable to set all sentence boundaries from dependency parses.") W026 = ("Unable to set all sentence boundaries from dependency parses.")
W027 = ("Found a large training file of {size} bytes. Note that it may " W027 = ("Found a large training file of {size} bytes. Note that it may "
"be more efficient to split your training data into multiple " "be more efficient to split your training data into multiple "
@ -376,7 +374,8 @@ class Errors:
E138 = ("Invalid JSONL format for raw text '{text}'. Make sure the input " E138 = ("Invalid JSONL format for raw text '{text}'. Make sure the input "
"includes either the `text` or `tokens` key. For more info, see " "includes either the `text` or `tokens` key. For more info, see "
"the docs:\nhttps://spacy.io/api/cli#pretrain-jsonl") "the docs:\nhttps://spacy.io/api/cli#pretrain-jsonl")
E139 = ("Knowledge Base for component '{name}' is empty.") E139 = ("Knowledge Base for component '{name}' is empty. Use the methods "
"kb.add_entity and kb.add_alias to add entries.")
E140 = ("The list of entities, prior probabilities and entity vectors " E140 = ("The list of entities, prior probabilities and entity vectors "
"should be of equal length.") "should be of equal length.")
E141 = ("Entity vectors should be of length {required} instead of the " E141 = ("Entity vectors should be of length {required} instead of the "
@ -483,10 +482,31 @@ class Errors:
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].") E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
# TODO: fix numbering after merging develop into master # TODO: fix numbering after merging develop into master
E941 = ("Can't find model '{name}'. It looks like you're trying to load a "
"model from a shortcut, which is deprecated as of spaCy v3.0. To "
"load the model, use its full name instead:\n\n"
"nlp = spacy.load(\"{full}\")\n\nFor more details on the available "
"models, see the models directory: https://spacy.io/models. If you "
"want to create a blank model, use spacy.blank: "
"nlp = spacy.blank(\"{name}\")")
E942 = ("Executing after_{name} callback failed. Expected the function to "
"return an initialized nlp object but got: {value}. Maybe "
"you forgot to return the modified object in your function?")
E943 = ("Executing before_creation callback failed. Expected the function to "
"return an uninitialized Language subclass but got: {value}. Maybe "
"you forgot to return the modified object in your function or "
"returned the initialized nlp object instead?")
E944 = ("Can't copy pipeline component '{name}' from source model '{model}': "
"not found in pipeline. Available components: {opts}")
E945 = ("Can't copy pipeline component '{name}' from source. Expected loaded "
"nlp object, but got: {source}")
E946 = ("The Vocab for the knowledge base is not initialized. Did you forget to "
"call kb.initialize()?")
E947 = ("Matcher.add received invalid 'greedy' argument: expected " E947 = ("Matcher.add received invalid 'greedy' argument: expected "
"a string value from {expected} but got: '{arg}'") "a string value from {expected} but got: '{arg}'")
E948 = ("Matcher.add received invalid 'patterns' argument: expected " E948 = ("Matcher.add received invalid 'patterns' argument: expected "
"a List, but got: {arg_type}") "a List, but got: {arg_type}")
E949 = ("Can only create an alignment when the texts are the same.")
E952 = ("The section '{name}' is not a valid section in the provided config.") E952 = ("The section '{name}' is not a valid section in the provided config.")
E953 = ("Mismatched IDs received by the Tok2Vec listener: {id1} vs. {id2}") E953 = ("Mismatched IDs received by the Tok2Vec listener: {id1} vs. {id2}")
E954 = ("The Tok2Vec listener did not receive a valid input.") E954 = ("The Tok2Vec listener did not receive a valid input.")
@ -569,11 +589,13 @@ class Errors:
"into {values}, but found {value}.") "into {values}, but found {value}.")
E983 = ("Invalid key for '{dict}': {key}. Available keys: " E983 = ("Invalid key for '{dict}': {key}. Available keys: "
"{keys}") "{keys}")
E984 = ("Invalid component config for '{name}': no 'factory' key " E984 = ("Invalid component config for '{name}': component block needs either "
"specifying the registered function used to initialize the " "a key 'factory' specifying the registered function used to "
"component. For example, factory = \"ner\" will use the 'ner' " "initialize the component, or a key 'source' key specifying a "
"factory and all other settings in the block will be passed " "spaCy model to copy the component from. For example, factory = "
"to it as arguments.\n\n{config}") "\"ner\" will use the 'ner' factory and all other settings in the "
"block will be passed to it as arguments. Alternatively, source = "
"\"en_core_web_sm\" will copy the component from that model.\n\n{config}")
E985 = ("Can't load model from config file: no 'nlp' section found.\n\n{config}") E985 = ("Can't load model from config file: no 'nlp' section found.\n\n{config}")
E986 = ("Could not create any training batches: check your input. " E986 = ("Could not create any training batches: check your input. "
"Perhaps discard_oversize should be set to False ?") "Perhaps discard_oversize should be set to False ?")
@ -608,6 +630,9 @@ class Errors:
"initializing the pipeline:\n" "initializing the pipeline:\n"
'cfg = {"tokenizer": {"segmenter": "pkuseg", "pkuseg_model": name_or_path}}\n' 'cfg = {"tokenizer": {"segmenter": "pkuseg", "pkuseg_model": name_or_path}}\n'
'nlp = Chinese(config=cfg)') 'nlp = Chinese(config=cfg)')
E1001 = ("Target token outside of matched span for match with tokens "
"'{span}' and offset '{index}' matched by patterns '{patterns}'.")
E1002 = ("Span index out of range.")
@add_codes @add_codes
@ -617,6 +642,15 @@ class TempErrors:
"issue tracker: http://github.com/explosion/spaCy/issues") "issue tracker: http://github.com/explosion/spaCy/issues")
# Deprecated model shortcuts, only used in errors and warnings
OLD_MODEL_SHORTCUTS = {
"en": "en_core_web_sm", "de": "de_core_news_sm", "es": "es_core_news_sm",
"pt": "pt_core_news_sm", "fr": "fr_core_news_sm", "it": "it_core_news_sm",
"nl": "nl_core_news_sm", "el": "el_core_news_sm", "nb": "nb_core_news_sm",
"lt": "lt_core_news_sm", "xx": "xx_ent_wiki_sm"
}
# fmt: on # fmt: on

View File

@ -1,11 +1,8 @@
from .corpus import Corpus from .corpus import Corpus # noqa: F401
from .example import Example from .example import Example # noqa: F401
from .align import Alignment from .align import Alignment # noqa: F401
from .iob_utils import iob_to_biluo, biluo_to_iob # noqa: F401
from .iob_utils import iob_to_biluo, biluo_to_iob from .iob_utils import biluo_tags_from_offsets, offsets_from_biluo_tags # noqa: F401
from .iob_utils import biluo_tags_from_offsets, offsets_from_biluo_tags from .iob_utils import spans_from_biluo_tags, tags_to_entities # noqa: F401
from .iob_utils import spans_from_biluo_tags from .gold_io import docs_to_json, read_json_file # noqa: F401
from .iob_utils import tags_to_entities from .batchers import minibatch_by_padded_size, minibatch_by_words # noqa: F401
from .gold_io import docs_to_json
from .gold_io import read_json_file

View File

@ -4,6 +4,8 @@ from thinc.types import Ragged
from dataclasses import dataclass from dataclasses import dataclass
import tokenizations import tokenizations
from ..errors import Errors
@dataclass @dataclass
class Alignment: class Alignment:
@ -18,6 +20,8 @@ class Alignment:
@classmethod @classmethod
def from_strings(cls, A: List[str], B: List[str]) -> "Alignment": def from_strings(cls, A: List[str], B: List[str]) -> "Alignment":
if "".join(A).replace(" ", "").lower() != "".join(B).replace(" ", "").lower():
raise ValueError(Errors.E949)
x2y, y2x = tokenizations.get_alignments(A, B) x2y, y2x = tokenizations.get_alignments(A, B)
return Alignment.from_indices(x2y=x2y, y2x=y2x) return Alignment.from_indices(x2y=x2y, y2x=y2x)

171
spacy/gold/batchers.py Normal file
View File

@ -0,0 +1,171 @@
from typing import Union, Iterator, Iterable, Sequence, TypeVar, List, Callable
from typing import Optional, Any
from functools import partial
import itertools
from ..util import registry, minibatch
# A batch size specification: either one constant int, or an iterable of ints
# that the batchers below consume with next() as a size schedule.
Sizing = Union[Iterable[int], int]
# Type variable standing for whatever item type is being batched.
ItemT = TypeVar("ItemT")
# Signature of a configured batcher: takes an iterable of items and returns an
# iterable of lists of those items.
BatcherT = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]
@registry.batchers("batch_by_padded.v1")
def configure_minibatch_by_padded_size(
    *,
    size: Sizing,
    buffer: int,
    discard_oversize: bool,
    get_length: Optional[Callable[[ItemT], int]] = None
) -> BatcherT:
    """Build a batcher around minibatch_by_padded_size with its settings
    bound, registered under the name "batch_by_padded.v1".

    size (int / Iterable[int]): Target padded batch size, constant or schedule.
    buffer (int): Number of items collected before they are grouped by length.
    discard_oversize (bool): Whether to drop sub-batches whose padded size
        reaches the target size.
    get_length (Optional[Callable]): Function returning an item's length. If
        None, the default of minibatch_by_padded_size is used.
    RETURNS (Callable): The batcher, which only needs the items to batch.
    """
    settings = {
        "size": size,
        "buffer": buffer,
        "discard_oversize": discard_oversize,
    }
    # Forward get_length only when it was given, so that a None here never
    # displaces the default value of the underlying function.
    if get_length is not None:
        settings["get_length"] = get_length
    return partial(minibatch_by_padded_size, **settings)
@registry.batchers("batch_by_words.v1")
def configure_minibatch_by_words(
    *,
    size: Sizing,
    tolerance: float,
    discard_oversize: bool,
    get_length: Optional[Callable[[ItemT], int]] = None
) -> BatcherT:
    """Build a batcher around minibatch_by_words with its settings bound,
    registered under the name "batch_by_words.v1".

    size (int / Iterable[int]): Target number of words per batch, constant or
        schedule.
    tolerance (float): Fraction of the target size by which a batch may
        exceed it.
    discard_oversize (bool): Whether to drop items that are longer than the
        target size plus tolerance instead of yielding them on their own.
    get_length (Optional[Callable]): Function returning an item's length. If
        None, the default of minibatch_by_words is used.
    RETURNS (Callable): The batcher, which only needs the items to batch.
    """
    # Forward get_length only when it was given, so that a None here never
    # displaces the default value of the underlying function.
    optionals = {"get_length": get_length} if get_length is not None else {}
    return partial(
        minibatch_by_words,
        size=size,
        # The tolerance has to be bound explicitly: without it the configured
        # value is dropped and minibatch_by_words falls back to its default.
        tolerance=tolerance,
        discard_oversize=discard_oversize,
        **optionals
    )
@registry.batchers("batch_by_sequence.v1")
def configure_minibatch(
    size: Sizing, get_length: Optional[Callable[[ItemT], int]] = None
) -> BatcherT:
    """Build a batcher around minibatch with its settings bound, registered
    under the name "batch_by_sequence.v1".

    size (int / Iterable[int]): Batch size, constant or schedule.
    get_length (Optional[Callable]): Forwarded to minibatch as get_length
        when given; left out entirely when None.
    RETURNS (Callable): The batcher, which only needs the items to batch.
    """
    settings = {"size": size}
    # Forward get_length only when it was given.
    if get_length is not None:
        settings["get_length"] = get_length
    return partial(minibatch, **settings)
def minibatch_by_padded_size(
    docs: Iterator["Doc"],
    size: Sizing,
    buffer: int = 256,
    discard_oversize: bool = False,
    get_length: Callable = len,
) -> Iterator[Iterator["Doc"]]:
    """Create minibatches of items with similar lengths, so that little
    padding is needed. Items are read in chunks of `buffer`, each chunk is
    grouped by length and the resulting sub-batches are yielded.

    docs (Iterable): The items to batch.
    size (int / Iterable[int]): Target padded size (longest item length times
        number of items) of a sub-batch. One value is taken per buffered chunk.
    buffer (int): Number of items to collect before grouping them by length.
    discard_oversize (bool): Whether to drop sub-batches whose padded size
        reaches the target size.
    get_length (Callable): Function returning the length of an item.
    YIELDS (list): The sub-batches.
    """
    if isinstance(size, int):
        size_ = itertools.repeat(size)
    else:
        # iter() turns lists and tuples into iterators and leaves generators
        # untouched; calling next() on a plain list would raise a TypeError.
        size_ = iter(size)
    for outer_batch in minibatch(docs, size=buffer):
        outer_batch = list(outer_batch)
        target_size = next(size_)
        for indices in _batch_by_length(outer_batch, target_size, get_length):
            subbatch = [outer_batch[i] for i in indices]
            # Measure with get_length, the same function used for grouping,
            # so items are not required to support len().
            longest = max(get_length(seq) for seq in subbatch)
            padded_size = longest * len(subbatch)
            if discard_oversize and padded_size >= target_size:
                continue
            yield subbatch
def minibatch_by_words(
    docs, size, tolerance=0.2, discard_oversize=False, get_length=len
):
    """Create minibatches of roughly a given number of words. If any examples
    are longer than the specified batch length, they will appear in a batch by
    themselves, or be discarded if discard_oversize=True.
    The argument 'docs' can be a list of strings, Docs or Examples.

    docs (Iterable): The items to batch.
    size (int / Iterable[int]): Target number of words per batch. Either a
        constant, or a schedule from which a new target is taken whenever a
        batch is closed.
    tolerance (float): Fraction of the target size by which a batch may
        exceed it.
    discard_oversize (bool): Whether to drop items longer than the target size
        plus tolerance instead of yielding them in a batch of their own.
    get_length (Callable): Function returning the number of words of an item.
    YIELDS (list): The batches.
    """
    if isinstance(size, int):
        size_ = itertools.repeat(size)
    else:
        # iter() handles lists, tuples and any other iterable, and returns
        # iterators/generators unchanged, so next() below always works.
        size_ = iter(size)
    target_size = next(size_)
    tol_size = target_size * tolerance
    batch = []
    overflow = []
    batch_size = 0
    overflow_size = 0
    for doc in docs:
        n_words = get_length(doc)
        # if the current example exceeds the maximum batch size, it is returned separately
        # but only if discard_oversize=False.
        if n_words > target_size + tol_size:
            if not discard_oversize:
                yield [doc]
        # add the example to the current batch if there's no overflow yet and it still fits
        elif overflow_size == 0 and (batch_size + n_words) <= target_size:
            batch.append(doc)
            batch_size += n_words
        # add the example to the overflow buffer if it fits in the tolerance margin
        elif (batch_size + overflow_size + n_words) <= (target_size + tol_size):
            overflow.append(doc)
            overflow_size += n_words
        # yield the previous batch and start a new one. The new one gets the overflow examples.
        else:
            if batch:
                yield batch
            target_size = next(size_)
            tol_size = target_size * tolerance
            batch = overflow
            batch_size = overflow_size
            overflow = []
            overflow_size = 0
            # this example still fits
            if (batch_size + n_words) <= target_size:
                batch.append(doc)
                batch_size += n_words
            # this example fits in overflow
            elif (batch_size + n_words) <= (target_size + tol_size):
                overflow.append(doc)
                overflow_size += n_words
            # this example does not fit with the previous overflow: start another new batch
            else:
                if batch:
                    yield batch
                target_size = next(size_)
                tol_size = target_size * tolerance
                batch = [doc]
                batch_size = n_words
    # Whatever is left in the overflow buffer belongs to the final batch.
    batch.extend(overflow)
    if batch:
        yield batch
def _batch_by_length(
seqs: Sequence[Any], max_words: int, get_length=len
) -> List[List[Any]]:
"""Given a list of sequences, return a batched list of indices into the
list, where the batches are grouped by length, in descending order.
Batches may be at most max_words in size, defined as max sequence length * size.
"""
# Use negative index so we can get sort by position ascending.
lengths_indices = [(get_length(seq), i) for i, seq in enumerate(seqs)]
lengths_indices.sort()
batches = []
batch = []
for length, i in lengths_indices:
if not batch:
batch.append(i)
elif length * (len(batch) + 1) <= max_words:
batch.append(i)
else:
batches.append(batch)
batch = [i]
if batch:
batches.append(batch)
# Check lengths match
assert sum(len(b) for b in batches) == len(seqs)
batches = [list(sorted(batch)) for batch in batches]
batches.reverse()
return batches

View File

@ -1,4 +1,4 @@
from .iob2docs import iob2docs # noqa: F401 from .iob2docs import iob2docs # noqa: F401
from .conll_ner2docs import conll_ner2docs # noqa: F401 from .conll_ner2docs import conll_ner2docs # noqa: F401
from .json2docs import json2docs from .json2docs import json2docs # noqa: F401
from .conllu2docs import conllu2docs # noqa: F401 from .conllu2docs import conllu2docs # noqa: F401

View File

@ -1,6 +1,5 @@
from typing import Union, List, Iterable, Iterator, TYPE_CHECKING from typing import Union, List, Iterable, Iterator, TYPE_CHECKING, Callable
from pathlib import Path from pathlib import Path
import random
from .. import util from .. import util
from .example import Example from .example import Example
@ -12,26 +11,43 @@ if TYPE_CHECKING:
from ..language import Language # noqa: F401 from ..language import Language # noqa: F401
@util.registry.readers("spacy.Corpus.v1")
def create_docbin_reader(
path: Path, gold_preproc: bool, max_length: int = 0, limit: int = 0
) -> Callable[["Language"], Iterable[Example]]:
return Corpus(path, gold_preproc=gold_preproc, max_length=max_length, limit=limit)
class Corpus: class Corpus:
"""An annotated corpus, reading train and dev datasets from """Iterate Example objects from a file or directory of DocBin (.spacy)
the DocBin (.spacy) format. formated data files.
path (Path): The directory or filename to read from.
gold_preproc (bool): Whether to set up the Example object with gold-standard
sentences and tokens for the predictions. Gold preprocessing helps
the annotations align to the tokenization, and may result in sequences
of more consistent length. However, it may reduce run-time accuracy due
to train/test skew. Defaults to False.
max_length (int): Maximum document length. Longer documents will be
split into sentences, if sentence boundaries are available. Defaults to
0, which indicates no limit.
limit (int): Limit corpus to a subset of examples, e.g. for debugging.
Defaults to 0, which indicates no limit.
DOCS: https://spacy.io/api/corpus DOCS: https://spacy.io/api/corpus
""" """
def __init__( def __init__(
self, train_loc: Union[str, Path], dev_loc: Union[str, Path], limit: int = 0 self,
path,
*,
limit: int = 0,
gold_preproc: bool = False,
max_length: bool = False,
) -> None: ) -> None:
"""Create a Corpus. self.path = util.ensure_path(path)
self.gold_preproc = gold_preproc
train (str / Path): File or directory of training data. self.max_length = max_length
dev (str / Path): File or directory of development data.
limit (int): Max. number of examples returned.
DOCS: https://spacy.io/api/corpus#init
"""
self.train_loc = train_loc
self.dev_loc = dev_loc
self.limit = limit self.limit = limit
@staticmethod @staticmethod
@ -54,6 +70,21 @@ class Corpus:
locs.append(path) locs.append(path)
return locs return locs
def __call__(self, nlp: "Language") -> Iterator[Example]:
"""Yield examples from the data.
nlp (Language): The current nlp object.
YIELDS (Example): The examples.
DOCS: https://spacy.io/api/corpus#call
"""
ref_docs = self.read_docbin(nlp.vocab, self.walk_corpus(self.path))
if self.gold_preproc:
examples = self.make_examples_gold_preproc(nlp, ref_docs)
else:
examples = self.make_examples(nlp, ref_docs, self.max_length)
yield from examples
def _make_example( def _make_example(
self, nlp: "Language", reference: Doc, gold_preproc: bool self, nlp: "Language", reference: Doc, gold_preproc: bool
) -> Example: ) -> Example:
@ -114,68 +145,3 @@ class Corpus:
i += 1 i += 1
if self.limit >= 1 and i >= self.limit: if self.limit >= 1 and i >= self.limit:
break break
def count_train(self, nlp: "Language") -> int:
"""Returns count of words in train examples.
nlp (Language): The current nlp. object.
RETURNS (int): The word count.
DOCS: https://spacy.io/api/corpus#count_train
"""
n = 0
i = 0
for example in self.train_dataset(nlp):
n += len(example.predicted)
if self.limit >= 0 and i >= self.limit:
break
i += 1
return n
def train_dataset(
self,
nlp: "Language",
*,
shuffle: bool = True,
gold_preproc: bool = False,
max_length: int = 0
) -> Iterator[Example]:
"""Yield examples from the training data.
nlp (Language): The current nlp object.
shuffle (bool): Whether to shuffle the examples.
gold_preproc (bool): Whether to train on gold-standard sentences and tokens.
max_length (int): Maximum document length. Longer documents will be
split into sentences, if sentence boundaries are available. 0 for
no limit.
YIELDS (Example): The examples.
DOCS: https://spacy.io/api/corpus#train_dataset
"""
ref_docs = self.read_docbin(nlp.vocab, self.walk_corpus(self.train_loc))
if gold_preproc:
examples = self.make_examples_gold_preproc(nlp, ref_docs)
else:
examples = self.make_examples(nlp, ref_docs, max_length)
if shuffle:
examples = list(examples)
random.shuffle(examples)
yield from examples
def dev_dataset(
self, nlp: "Language", *, gold_preproc: bool = False
) -> Iterator[Example]:
"""Yield examples from the development data.
nlp (Language): The current nlp object.
gold_preproc (bool): Whether to train on gold-standard sentences and tokens.
YIELDS (Example): The examples.
DOCS: https://spacy.io/api/corpus#dev_dataset
"""
ref_docs = self.read_docbin(nlp.vocab, self.walk_corpus(self.dev_loc))
if gold_preproc:
examples = self.make_examples_gold_preproc(nlp, ref_docs)
else:
examples = self.make_examples(nlp, ref_docs, max_length=0)
yield from examples

View File

@ -4,4 +4,6 @@ from ..tokens.doc cimport Doc
cdef class Example: cdef class Example:
cdef readonly Doc x cdef readonly Doc x
cdef readonly Doc y cdef readonly Doc y
cdef readonly object _alignment cdef readonly object _cached_alignment
cdef readonly object _cached_words_x
cdef readonly object _cached_words_y

View File

@ -10,7 +10,7 @@ from .align import Alignment
from .iob_utils import biluo_to_iob, biluo_tags_from_offsets, biluo_tags_from_doc from .iob_utils import biluo_to_iob, biluo_tags_from_offsets, biluo_tags_from_doc
from .iob_utils import spans_from_biluo_tags from .iob_utils import spans_from_biluo_tags
from ..errors import Errors, Warnings from ..errors import Errors, Warnings
from ..syntax import nonproj from ..pipeline._parser_internals import nonproj
cpdef Doc annotations2doc(vocab, tok_annot, doc_annot): cpdef Doc annotations2doc(vocab, tok_annot, doc_annot):
@ -32,9 +32,9 @@ cdef class Example:
raise TypeError(Errors.E972.format(arg="predicted")) raise TypeError(Errors.E972.format(arg="predicted"))
if reference is None: if reference is None:
raise TypeError(Errors.E972.format(arg="reference")) raise TypeError(Errors.E972.format(arg="reference"))
self.x = predicted self.predicted = predicted
self.y = reference self.reference = reference
self._alignment = alignment self._cached_alignment = alignment
def __len__(self): def __len__(self):
return len(self.predicted) return len(self.predicted)
@ -45,7 +45,8 @@ cdef class Example:
def __set__(self, doc): def __set__(self, doc):
self.x = doc self.x = doc
self._alignment = None self._cached_alignment = None
self._cached_words_x = [t.text for t in doc]
property reference: property reference:
def __get__(self): def __get__(self):
@ -53,7 +54,8 @@ cdef class Example:
def __set__(self, doc): def __set__(self, doc):
self.y = doc self.y = doc
self._alignment = None self._cached_alignment = None
self._cached_words_y = [t.text for t in doc]
def copy(self): def copy(self):
return Example( return Example(
@ -79,13 +81,15 @@ cdef class Example:
@property @property
def alignment(self): def alignment(self):
if self._alignment is None: words_x = [token.text for token in self.x]
spacy_words = [token.orth_ for token in self.predicted] words_y = [token.text for token in self.y]
gold_words = [token.orth_ for token in self.reference] if self._cached_alignment is None or \
if gold_words == []: words_x != self._cached_words_x or \
gold_words = spacy_words words_y != self._cached_words_y:
self._alignment = Alignment.from_strings(spacy_words, gold_words) self._cached_alignment = Alignment.from_strings(words_x, words_y)
return self._alignment self._cached_words_x = words_x
self._cached_words_y = words_y
return self._cached_alignment
def get_aligned(self, field, as_string=False): def get_aligned(self, field, as_string=False):
"""Return an aligned array for a token attribute.""" """Return an aligned array for a token attribute."""
@ -179,15 +183,15 @@ cdef class Example:
"links": self._links_to_dict() "links": self._links_to_dict()
}, },
"token_annotation": { "token_annotation": {
"ids": [t.i+1 for t in self.reference], "ORTH": [t.text for t in self.reference],
"words": [t.text for t in self.reference], "SPACY": [bool(t.whitespace_) for t in self.reference],
"tags": [t.tag_ for t in self.reference], "TAG": [t.tag_ for t in self.reference],
"lemmas": [t.lemma_ for t in self.reference], "LEMMA": [t.lemma_ for t in self.reference],
"pos": [t.pos_ for t in self.reference], "POS": [t.pos_ for t in self.reference],
"morphs": [t.morph_ for t in self.reference], "MORPH": [t.morph_ for t in self.reference],
"heads": [t.head.i for t in self.reference], "HEAD": [t.head.i for t in self.reference],
"deps": [t.dep_ for t in self.reference], "DEP": [t.dep_ for t in self.reference],
"sent_starts": [int(bool(t.is_sent_start)) for t in self.reference] "SENT_START": [int(bool(t.is_sent_start)) for t in self.reference]
} }
} }
@ -331,10 +335,14 @@ def _fix_legacy_dict_data(example_dict):
for key, value in old_token_dict.items(): for key, value in old_token_dict.items():
if key in ("text", "ids", "brackets"): if key in ("text", "ids", "brackets"):
pass pass
elif key in remapping.values():
token_dict[key] = value
elif key.lower() in remapping: elif key.lower() in remapping:
token_dict[remapping[key.lower()]] = value token_dict[remapping[key.lower()]] = value
else: else:
raise KeyError(Errors.E983.format(key=key, dict="token_annotation", keys=remapping.keys())) all_keys = set(remapping.values())
all_keys.update(remapping.keys())
raise KeyError(Errors.E983.format(key=key, dict="token_annotation", keys=all_keys))
text = example_dict.get("text", example_dict.get("raw")) text = example_dict.get("text", example_dict.get("raw"))
if _has_field(token_dict, "ORTH") and not _has_field(token_dict, "SPACY"): if _has_field(token_dict, "ORTH") and not _has_field(token_dict, "SPACY"):
token_dict["SPACY"] = _guess_spaces(text, token_dict["ORTH"]) token_dict["SPACY"] = _guess_spaces(text, token_dict["ORTH"])

View File

@ -71,17 +71,25 @@ cdef class KnowledgeBase:
DOCS: https://spacy.io/api/kb DOCS: https://spacy.io/api/kb
""" """
def __init__(self, Vocab vocab, entity_vector_length=64): def __init__(self, entity_vector_length):
self.vocab = vocab """Create a KnowledgeBase. Make sure to call kb.initialize() before using it."""
self.mem = Pool() self.mem = Pool()
self.entity_vector_length = entity_vector_length self.entity_vector_length = entity_vector_length
self._entry_index = PreshMap() self._entry_index = PreshMap()
self._alias_index = PreshMap() self._alias_index = PreshMap()
self.vocab = None
def initialize(self, Vocab vocab):
self.vocab = vocab
self.vocab.strings.add("") self.vocab.strings.add("")
self._create_empty_vectors(dummy_hash=self.vocab.strings[""]) self._create_empty_vectors(dummy_hash=self.vocab.strings[""])
def require_vocab(self):
if self.vocab is None:
raise ValueError(Errors.E946)
@property @property
def entity_vector_length(self): def entity_vector_length(self):
"""RETURNS (uint64): length of the entity vectors""" """RETURNS (uint64): length of the entity vectors"""
@ -94,12 +102,14 @@ cdef class KnowledgeBase:
return len(self._entry_index) return len(self._entry_index)
def get_entity_strings(self): def get_entity_strings(self):
self.require_vocab()
return [self.vocab.strings[x] for x in self._entry_index] return [self.vocab.strings[x] for x in self._entry_index]
def get_size_aliases(self): def get_size_aliases(self):
return len(self._alias_index) return len(self._alias_index)
def get_alias_strings(self): def get_alias_strings(self):
self.require_vocab()
return [self.vocab.strings[x] for x in self._alias_index] return [self.vocab.strings[x] for x in self._alias_index]
def add_entity(self, unicode entity, float freq, vector[float] entity_vector): def add_entity(self, unicode entity, float freq, vector[float] entity_vector):
@ -107,6 +117,7 @@ cdef class KnowledgeBase:
Add an entity to the KB, optionally specifying its log probability based on corpus frequency Add an entity to the KB, optionally specifying its log probability based on corpus frequency
Return the hash of the entity ID/name at the end. Return the hash of the entity ID/name at the end.
""" """
self.require_vocab()
cdef hash_t entity_hash = self.vocab.strings.add(entity) cdef hash_t entity_hash = self.vocab.strings.add(entity)
# Return if this entity was added before # Return if this entity was added before
@ -129,6 +140,7 @@ cdef class KnowledgeBase:
return entity_hash return entity_hash
cpdef set_entities(self, entity_list, freq_list, vector_list): cpdef set_entities(self, entity_list, freq_list, vector_list):
self.require_vocab()
if len(entity_list) != len(freq_list) or len(entity_list) != len(vector_list): if len(entity_list) != len(freq_list) or len(entity_list) != len(vector_list):
raise ValueError(Errors.E140) raise ValueError(Errors.E140)
@ -164,10 +176,12 @@ cdef class KnowledgeBase:
i += 1 i += 1
def contains_entity(self, unicode entity): def contains_entity(self, unicode entity):
self.require_vocab()
cdef hash_t entity_hash = self.vocab.strings.add(entity) cdef hash_t entity_hash = self.vocab.strings.add(entity)
return entity_hash in self._entry_index return entity_hash in self._entry_index
def contains_alias(self, unicode alias): def contains_alias(self, unicode alias):
self.require_vocab()
cdef hash_t alias_hash = self.vocab.strings.add(alias) cdef hash_t alias_hash = self.vocab.strings.add(alias)
return alias_hash in self._alias_index return alias_hash in self._alias_index
@ -176,6 +190,7 @@ cdef class KnowledgeBase:
For a given alias, add its potential entities and prior probabilies to the KB. For a given alias, add its potential entities and prior probabilies to the KB.
Return the alias_hash at the end Return the alias_hash at the end
""" """
self.require_vocab()
# Throw an error if the length of entities and probabilities are not the same # Throw an error if the length of entities and probabilities are not the same
if not len(entities) == len(probabilities): if not len(entities) == len(probabilities):
raise ValueError(Errors.E132.format(alias=alias, raise ValueError(Errors.E132.format(alias=alias,
@ -219,6 +234,7 @@ cdef class KnowledgeBase:
Throw an error if this entity+prior prob would exceed the sum of 1. Throw an error if this entity+prior prob would exceed the sum of 1.
For efficiency, it's best to use the method `add_alias` as much as possible instead of this one. For efficiency, it's best to use the method `add_alias` as much as possible instead of this one.
""" """
self.require_vocab()
# Check if the alias exists in the KB # Check if the alias exists in the KB
cdef hash_t alias_hash = self.vocab.strings[alias] cdef hash_t alias_hash = self.vocab.strings[alias]
if not alias_hash in self._alias_index: if not alias_hash in self._alias_index:
@ -265,6 +281,7 @@ cdef class KnowledgeBase:
and the prior probability of that alias resolving to that entity. and the prior probability of that alias resolving to that entity.
If the alias is not known in the KB, and empty list is returned. If the alias is not known in the KB, and empty list is returned.
""" """
self.require_vocab()
cdef hash_t alias_hash = self.vocab.strings[alias] cdef hash_t alias_hash = self.vocab.strings[alias]
if not alias_hash in self._alias_index: if not alias_hash in self._alias_index:
return [] return []
@ -281,6 +298,7 @@ cdef class KnowledgeBase:
if entry_index != 0] if entry_index != 0]
def get_vector(self, unicode entity): def get_vector(self, unicode entity):
self.require_vocab()
cdef hash_t entity_hash = self.vocab.strings[entity] cdef hash_t entity_hash = self.vocab.strings[entity]
# Return an empty list if this entity is unknown in this KB # Return an empty list if this entity is unknown in this KB
@ -293,6 +311,7 @@ cdef class KnowledgeBase:
def get_prior_prob(self, unicode entity, unicode alias): def get_prior_prob(self, unicode entity, unicode alias):
""" Return the prior probability of a given alias being linked to a given entity, """ Return the prior probability of a given alias being linked to a given entity,
or return 0.0 when this combination is not known in the knowledge base""" or return 0.0 when this combination is not known in the knowledge base"""
self.require_vocab()
cdef hash_t alias_hash = self.vocab.strings[alias] cdef hash_t alias_hash = self.vocab.strings[alias]
cdef hash_t entity_hash = self.vocab.strings[entity] cdef hash_t entity_hash = self.vocab.strings[entity]
@ -311,6 +330,7 @@ cdef class KnowledgeBase:
def dump(self, loc): def dump(self, loc):
self.require_vocab()
cdef Writer writer = Writer(loc) cdef Writer writer = Writer(loc)
writer.write_header(self.get_size_entities(), self.entity_vector_length) writer.write_header(self.get_size_entities(), self.entity_vector_length)

View File

@ -18,7 +18,7 @@ from timeit import default_timer as timer
from .tokens.underscore import Underscore from .tokens.underscore import Underscore
from .vocab import Vocab, create_vocab from .vocab import Vocab, create_vocab
from .pipe_analysis import analyze_pipes, analyze_all_pipes, validate_attrs from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis
from .gold import Example from .gold import Example
from .scorer import Scorer from .scorer import Scorer
from .util import create_default_optimizer, registry from .util import create_default_optimizer, registry
@ -37,8 +37,6 @@ from . import util
from . import about from . import about
# TODO: integrate pipeline analyis
ENABLE_PIPELINE_ANALYSIS = False
# This is the base config with all settings (training etc.) # This is the base config with all settings (training etc.)
DEFAULT_CONFIG_PATH = Path(__file__).parent / "default_config.cfg" DEFAULT_CONFIG_PATH = Path(__file__).parent / "default_config.cfg"
DEFAULT_CONFIG = Config().from_disk(DEFAULT_CONFIG_PATH) DEFAULT_CONFIG = Config().from_disk(DEFAULT_CONFIG_PATH)
@ -522,6 +520,25 @@ class Language:
return add_component(func) return add_component(func)
return add_component return add_component
def analyze_pipes(
self,
*,
keys: List[str] = ["assigns", "requires", "scores", "retokenizes"],
pretty: bool = False,
) -> Optional[Dict[str, Any]]:
"""Analyze the current pipeline components, print a summary of what
they assign or require and check that all requirements are met.
keys (List[str]): The meta values to display in the table. Corresponds
to values in FactoryMeta, defined by @Language.factory decorator.
pretty (bool): Pretty-print the results.
RETURNS (dict): The data.
"""
analysis = analyze_pipes(self, keys=keys)
if pretty:
print_pipe_analysis(analysis, keys=keys)
return analysis
def get_pipe(self, name: str) -> Callable[[Doc], Doc]: def get_pipe(self, name: str) -> Callable[[Doc], Doc]:
"""Get a pipeline component for a given component name. """Get a pipeline component for a given component name.
@ -541,7 +558,6 @@ class Language:
name: Optional[str] = None, name: Optional[str] = None,
*, *,
config: Optional[Dict[str, Any]] = SimpleFrozenDict(), config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
overrides: Optional[Dict[str, Any]] = SimpleFrozenDict(),
validate: bool = True, validate: bool = True,
) -> Callable[[Doc], Doc]: ) -> Callable[[Doc], Doc]:
"""Create a pipeline component. Mostly used internally. To create and """Create a pipeline component. Mostly used internally. To create and
@ -552,8 +568,6 @@ class Language:
Defaults to factory name if not set. Defaults to factory name if not set.
config (Optional[Dict[str, Any]]): Config parameters to use for this config (Optional[Dict[str, Any]]): Config parameters to use for this
component. Will be merged with default config, if available. component. Will be merged with default config, if available.
overrides (Optional[Dict[str, Any]]): Config overrides, typically
passed in via the CLI.
validate (bool): Whether to validate the component config against the validate (bool): Whether to validate the component config against the
arguments and types expected by the factory. arguments and types expected by the factory.
RETURNS (Callable[[Doc], Doc]): The pipeline component. RETURNS (Callable[[Doc], Doc]): The pipeline component.
@ -596,13 +610,39 @@ class Language:
# registered functions twice # registered functions twice
# TODO: customize validation to make it more readable / relate it to # TODO: customize validation to make it more readable / relate it to
# pipeline component and why it failed, explain default config # pipeline component and why it failed, explain default config
resolved, filled = registry.resolve(cfg, validate=validate, overrides=overrides) resolved, filled = registry.resolve(cfg, validate=validate)
filled = filled[factory_name] filled = filled[factory_name]
filled["factory"] = factory_name filled["factory"] = factory_name
filled.pop("@factories", None) filled.pop("@factories", None)
self._pipe_configs[name] = filled self._pipe_configs[name] = filled
return resolved[factory_name] return resolved[factory_name]
def create_pipe_from_source(
self, source_name: str, source: "Language", *, name: str,
) -> Tuple[Callable[[Doc], Doc], str]:
"""Create a pipeline component by copying it from an existing model.
source_name (str): Name of the component in the source pipeline.
source (Language): The source nlp object to copy from.
name (str): Optional alternative name to use in current pipeline.
RETURNS (Tuple[Callable, str]): The component and its factory name.
"""
# TODO: handle errors and mismatches (vectors etc.)
if not isinstance(source, self.__class__):
raise ValueError(Errors.E945.format(name=source_name, source=type(source)))
if not source.has_pipe(source_name):
raise KeyError(
Errors.E944.format(
name=source_name,
model=f"{source.meta['lang']}_{source.meta['name']}",
opts=", ".join(source.pipe_names),
)
)
pipe = source.get_pipe(source_name)
pipe_config = util.copy_config(source.config["components"][source_name])
self._pipe_configs[name] = pipe_config
return pipe, pipe_config["factory"]
def add_pipe( def add_pipe(
self, self,
factory_name: str, factory_name: str,
@ -612,8 +652,8 @@ class Language:
after: Optional[Union[str, int]] = None, after: Optional[Union[str, int]] = None,
first: Optional[bool] = None, first: Optional[bool] = None,
last: Optional[bool] = None, last: Optional[bool] = None,
source: Optional["Language"] = None,
config: Optional[Dict[str, Any]] = SimpleFrozenDict(), config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
overrides: Optional[Dict[str, Any]] = SimpleFrozenDict(),
validate: bool = True, validate: bool = True,
) -> Callable[[Doc], Doc]: ) -> Callable[[Doc], Doc]:
"""Add a component to the processing pipeline. Valid components are """Add a component to the processing pipeline. Valid components are
@ -631,10 +671,10 @@ class Language:
component directly after. component directly after.
first (bool): If True, insert component first in the pipeline. first (bool): If True, insert component first in the pipeline.
last (bool): If True, insert component last in the pipeline. last (bool): If True, insert component last in the pipeline.
source (Language): Optional loaded nlp object to copy the pipeline
component from.
config (Optional[Dict[str, Any]]): Config parameters to use for this config (Optional[Dict[str, Any]]): Config parameters to use for this
component. Will be merged with default config, if available. component. Will be merged with default config, if available.
overrides (Optional[Dict[str, Any]]): Config overrides, typically
passed in via the CLI.
validate (bool): Whether to validate the component config against the validate (bool): Whether to validate the component config against the
arguments and types expected by the factory. arguments and types expected by the factory.
RETURNS (Callable[[Doc], Doc]): The pipeline component. RETURNS (Callable[[Doc], Doc]): The pipeline component.
@ -645,29 +685,30 @@ class Language:
bad_val = repr(factory_name) bad_val = repr(factory_name)
err = Errors.E966.format(component=bad_val, name=name) err = Errors.E966.format(component=bad_val, name=name)
raise ValueError(err) raise ValueError(err)
if not self.has_factory(factory_name):
err = Errors.E002.format(
name=factory_name,
opts=", ".join(self.factory_names),
method="add_pipe",
lang=util.get_object_name(self),
lang_code=self.lang,
)
name = name if name is not None else factory_name name = name if name is not None else factory_name
if name in self.pipe_names: if name in self.pipe_names:
raise ValueError(Errors.E007.format(name=name, opts=self.pipe_names)) raise ValueError(Errors.E007.format(name=name, opts=self.pipe_names))
pipe_component = self.create_pipe( if source is not None:
factory_name, # We're loading the component from a model. After loading the
name=name, # component, we know its real factory name
config=config, pipe_component, factory_name = self.create_pipe_from_source(
overrides=overrides, factory_name, source, name=name
validate=validate, )
) else:
if not self.has_factory(factory_name):
err = Errors.E002.format(
name=factory_name,
opts=", ".join(self.factory_names),
method="add_pipe",
lang=util.get_object_name(self),
lang_code=self.lang,
)
pipe_component = self.create_pipe(
factory_name, name=name, config=config, validate=validate,
)
pipe_index = self._get_pipe_index(before, after, first, last) pipe_index = self._get_pipe_index(before, after, first, last)
self._pipe_meta[name] = self.get_factory_meta(factory_name) self._pipe_meta[name] = self.get_factory_meta(factory_name)
self.pipeline.insert(pipe_index, (name, pipe_component)) self.pipeline.insert(pipe_index, (name, pipe_component))
if ENABLE_PIPELINE_ANALYSIS:
analyze_pipes(self, name, pipe_index)
return pipe_component return pipe_component
def _get_pipe_index( def _get_pipe_index(
@ -754,12 +795,11 @@ class Language:
# to Language.pipeline to make sure the configs are handled correctly # to Language.pipeline to make sure the configs are handled correctly
pipe_index = self.pipe_names.index(name) pipe_index = self.pipe_names.index(name)
self.remove_pipe(name) self.remove_pipe(name)
if not len(self.pipeline): # we have no components to insert before/after if not len(self.pipeline) or pipe_index == len(self.pipeline):
# we have no components to insert before/after, or we're replacing the last component
self.add_pipe(factory_name, name=name) self.add_pipe(factory_name, name=name)
else: else:
self.add_pipe(factory_name, name=name, before=pipe_index) self.add_pipe(factory_name, name=name, before=pipe_index)
if ENABLE_PIPELINE_ANALYSIS:
analyze_all_pipes(self)
def rename_pipe(self, old_name: str, new_name: str) -> None: def rename_pipe(self, old_name: str, new_name: str) -> None:
"""Rename a pipeline component. """Rename a pipeline component.
@ -793,8 +833,6 @@ class Language:
# because factory may be used for something else # because factory may be used for something else
self._pipe_meta.pop(name) self._pipe_meta.pop(name)
self._pipe_configs.pop(name) self._pipe_configs.pop(name)
if ENABLE_PIPELINE_ANALYSIS:
analyze_all_pipes(self)
return removed return removed
def __call__( def __call__(
@ -900,6 +938,7 @@ class Language:
sgd: Optional[Optimizer] = None, sgd: Optional[Optimizer] = None,
losses: Optional[Dict[str, float]] = None, losses: Optional[Dict[str, float]] = None,
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None, component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
exclude: Iterable[str] = tuple(),
): ):
"""Update the models in the pipeline. """Update the models in the pipeline.
@ -910,6 +949,7 @@ class Language:
losses (Dict[str, float]): Dictionary to update with the loss, keyed by component. losses (Dict[str, float]): Dictionary to update with the loss, keyed by component.
component_cfg (Dict[str, Dict]): Config parameters for specific pipeline component_cfg (Dict[str, Dict]): Config parameters for specific pipeline
components, keyed by component name. components, keyed by component name.
exclude (Iterable[str]): Names of components that shouldn't be updated.
RETURNS (Dict[str, float]): The updated losses dictionary RETURNS (Dict[str, float]): The updated losses dictionary
DOCS: https://spacy.io/api/language#update DOCS: https://spacy.io/api/language#update
@ -942,12 +982,12 @@ class Language:
component_cfg[name].setdefault("drop", drop) component_cfg[name].setdefault("drop", drop)
component_cfg[name].setdefault("set_annotations", False) component_cfg[name].setdefault("set_annotations", False)
for name, proc in self.pipeline: for name, proc in self.pipeline:
if not hasattr(proc, "update"): if name in exclude or not hasattr(proc, "update"):
continue continue
proc.update(examples, sgd=None, losses=losses, **component_cfg[name]) proc.update(examples, sgd=None, losses=losses, **component_cfg[name])
if sgd not in (None, False): if sgd not in (None, False):
for name, proc in self.pipeline: for name, proc in self.pipeline:
if hasattr(proc, "model"): if name not in exclude and hasattr(proc, "model"):
proc.model.finish_update(sgd) proc.model.finish_update(sgd)
return losses return losses
@ -958,6 +998,7 @@ class Language:
sgd: Optional[Optimizer] = None, sgd: Optional[Optimizer] = None,
losses: Optional[Dict[str, float]] = None, losses: Optional[Dict[str, float]] = None,
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None, component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
exclude: Iterable[str] = tuple(),
) -> Dict[str, float]: ) -> Dict[str, float]:
"""Make a "rehearsal" update to the models in the pipeline, to prevent """Make a "rehearsal" update to the models in the pipeline, to prevent
forgetting. Rehearsal updates run an initial copy of the model over some forgetting. Rehearsal updates run an initial copy of the model over some
@ -969,6 +1010,7 @@ class Language:
sgd (Optional[Optimizer]): An optimizer. sgd (Optional[Optimizer]): An optimizer.
component_cfg (Dict[str, Dict]): Config parameters for specific pipeline component_cfg (Dict[str, Dict]): Config parameters for specific pipeline
components, keyed by component name. components, keyed by component name.
exclude (Iterable[str]): Names of components that shouldn't be updated.
RETURNS (dict): Results from the update. RETURNS (dict): Results from the update.
EXAMPLE: EXAMPLE:
@ -1012,7 +1054,7 @@ class Language:
get_grads.b1 = sgd.b1 get_grads.b1 = sgd.b1
get_grads.b2 = sgd.b2 get_grads.b2 = sgd.b2
for name, proc in pipes: for name, proc in pipes:
if not hasattr(proc, "rehearse"): if name in exclude or not hasattr(proc, "rehearse"):
continue continue
grads = {} grads = {}
proc.rehearse( proc.rehearse(
@ -1063,7 +1105,7 @@ class Language:
return self._optimizer return self._optimizer
def resume_training( def resume_training(
self, *, sgd: Optional[Optimizer] = None, device: int = -1 self, *, sgd: Optional[Optimizer] = None, device: int = -1,
) -> Optimizer: ) -> Optimizer:
"""Continue training a pretrained model. """Continue training a pretrained model.
@ -1099,6 +1141,7 @@ class Language:
batch_size: int = 256, batch_size: int = 256,
scorer: Optional[Scorer] = None, scorer: Optional[Scorer] = None,
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None, component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
scorer_cfg: Optional[Dict[str, Any]] = None,
) -> Dict[str, Union[float, dict]]: ) -> Dict[str, Union[float, dict]]:
"""Evaluate a model's pipeline components. """Evaluate a model's pipeline components.
@ -1109,6 +1152,8 @@ class Language:
will be created. will be created.
component_cfg (dict): An optional dictionary with extra keyword component_cfg (dict): An optional dictionary with extra keyword
arguments for specific components. arguments for specific components.
scorer_cfg (dict): An optional dictionary with extra keyword arguments
for the scorer.
RETURNS (Scorer): The scorer containing the evaluation results. RETURNS (Scorer): The scorer containing the evaluation results.
DOCS: https://spacy.io/api/language#evaluate DOCS: https://spacy.io/api/language#evaluate
@ -1126,8 +1171,10 @@ class Language:
raise TypeError(err) raise TypeError(err)
if component_cfg is None: if component_cfg is None:
component_cfg = {} component_cfg = {}
if scorer_cfg is None:
scorer_cfg = {}
if scorer is None: if scorer is None:
kwargs = component_cfg.get("scorer", {}) kwargs = dict(scorer_cfg)
kwargs.setdefault("verbose", verbose) kwargs.setdefault("verbose", verbose)
kwargs.setdefault("nlp", self) kwargs.setdefault("nlp", self)
scorer = Scorer(**kwargs) scorer = Scorer(**kwargs)
@ -1136,9 +1183,9 @@ class Language:
start_time = timer() start_time = timer()
# tokenize the texts only for timing purposes # tokenize the texts only for timing purposes
if not hasattr(self.tokenizer, "pipe"): if not hasattr(self.tokenizer, "pipe"):
_ = [self.tokenizer(text) for text in texts] _ = [self.tokenizer(text) for text in texts] # noqa: F841
else: else:
_ = list(self.tokenizer.pipe(texts)) _ = list(self.tokenizer.pipe(texts)) # noqa: F841
for name, pipe in self.pipeline: for name, pipe in self.pipeline:
kwargs = component_cfg.get(name, {}) kwargs = component_cfg.get(name, {})
kwargs.setdefault("batch_size", batch_size) kwargs.setdefault("batch_size", batch_size)
@ -1357,8 +1404,8 @@ class Language:
cls, cls,
config: Union[Dict[str, Any], Config] = {}, config: Union[Dict[str, Any], Config] = {},
*, *,
vocab: Union[Vocab, bool] = True,
disable: Iterable[str] = tuple(), disable: Iterable[str] = tuple(),
overrides: Dict[str, Any] = {},
auto_fill: bool = True, auto_fill: bool = True,
validate: bool = True, validate: bool = True,
) -> "Language": ) -> "Language":
@ -1367,6 +1414,7 @@ class Language:
the default config of the given language is used. the default config of the given language is used.
config (Dict[str, Any] / Config): The loaded config. config (Dict[str, Any] / Config): The loaded config.
vocab (Vocab): A Vocab object. If True, a vocab is created.
disable (Iterable[str]): List of pipeline component names to disable. disable (Iterable[str]): List of pipeline component names to disable.
auto_fill (bool): Automatically fill in missing values in config based auto_fill (bool): Automatically fill in missing values in config based
on defaults and function argument annotations. on defaults and function argument annotations.
@ -1397,43 +1445,76 @@ class Language:
config = util.copy_config(config) config = util.copy_config(config)
orig_pipeline = config.pop("components", {}) orig_pipeline = config.pop("components", {})
config["components"] = {} config["components"] = {}
non_pipe_overrides, pipe_overrides = _get_config_overrides(overrides)
resolved, filled = registry.resolve( resolved, filled = registry.resolve(
config, validate=validate, schema=ConfigSchema, overrides=non_pipe_overrides config, validate=validate, schema=ConfigSchema
) )
filled["components"] = orig_pipeline filled["components"] = orig_pipeline
config["components"] = orig_pipeline config["components"] = orig_pipeline
create_tokenizer = resolved["nlp"]["tokenizer"] create_tokenizer = resolved["nlp"]["tokenizer"]
create_lemmatizer = resolved["nlp"]["lemmatizer"] create_lemmatizer = resolved["nlp"]["lemmatizer"]
nlp = cls( before_creation = resolved["nlp"]["before_creation"]
create_tokenizer=create_tokenizer, create_lemmatizer=create_lemmatizer, after_creation = resolved["nlp"]["after_creation"]
after_pipeline_creation = resolved["nlp"]["after_pipeline_creation"]
lang_cls = cls
if before_creation is not None:
lang_cls = before_creation(cls)
if (
not isinstance(lang_cls, type)
or not issubclass(lang_cls, cls)
or lang_cls is not cls
):
raise ValueError(Errors.E943.format(value=type(lang_cls)))
nlp = lang_cls(
vocab=vocab,
create_tokenizer=create_tokenizer,
create_lemmatizer=create_lemmatizer,
) )
if after_creation is not None:
nlp = after_creation(nlp)
if not isinstance(nlp, cls):
raise ValueError(Errors.E942.format(name="creation", value=type(nlp)))
# Note that we don't load vectors here, instead they get loaded explicitly # Note that we don't load vectors here, instead they get loaded explicitly
# inside stuff like the spacy train function. If we loaded them here, # inside stuff like the spacy train function. If we loaded them here,
# then we would load them twice at runtime: once when we make from config, # then we would load them twice at runtime: once when we make from config,
# and then again when we load from disk. # and then again when we load from disk.
pipeline = config.get("components", {}) pipeline = config.get("components", {})
# If components are loaded from a source (existing models), we cache
# them here so they're only loaded once
source_nlps = {}
for pipe_name in config["nlp"]["pipeline"]: for pipe_name in config["nlp"]["pipeline"]:
if pipe_name not in pipeline: if pipe_name not in pipeline:
opts = ", ".join(pipeline.keys()) opts = ", ".join(pipeline.keys())
raise ValueError(Errors.E956.format(name=pipe_name, opts=opts)) raise ValueError(Errors.E956.format(name=pipe_name, opts=opts))
pipe_cfg = util.copy_config(pipeline[pipe_name]) pipe_cfg = util.copy_config(pipeline[pipe_name])
if pipe_name not in disable: if pipe_name not in disable:
if "factory" not in pipe_cfg: if "factory" not in pipe_cfg and "source" not in pipe_cfg:
err = Errors.E984.format(name=pipe_name, config=pipe_cfg) err = Errors.E984.format(name=pipe_name, config=pipe_cfg)
raise ValueError(err) raise ValueError(err)
factory = pipe_cfg.pop("factory") if "factory" in pipe_cfg:
# The pipe name (key in the config) here is the unique name of the factory = pipe_cfg.pop("factory")
# component, not necessarily the factory # The pipe name (key in the config) here is the unique name
nlp.add_pipe( # of the component, not necessarily the factory
factory, nlp.add_pipe(
name=pipe_name, factory, name=pipe_name, config=pipe_cfg, validate=validate,
config=pipe_cfg, )
overrides=pipe_overrides, else:
validate=validate, model = pipe_cfg["source"]
) if model not in source_nlps:
# We only need the components here and we need to init
# model with the same vocab as the current nlp object
source_nlps[model] = util.load_model(
model, vocab=nlp.vocab, disable=["vocab", "tokenizer"]
)
source_name = pipe_cfg.get("component", pipe_name)
nlp.add_pipe(source_name, source=source_nlps[model], name=pipe_name)
nlp.config = filled if auto_fill else config nlp.config = filled if auto_fill else config
nlp.resolved = resolved nlp.resolved = resolved
if after_pipeline_creation is not None:
nlp = after_pipeline_creation(nlp)
if not isinstance(nlp, cls):
raise ValueError(
Errors.E942.format(name="pipeline_creation", value=type(nlp))
)
return nlp return nlp
def to_disk( def to_disk(
@ -1599,15 +1680,6 @@ class FactoryMeta:
default_score_weights: Optional[Dict[str, float]] = None # noqa: E704 default_score_weights: Optional[Dict[str, float]] = None # noqa: E704
def _get_config_overrides(
items: Dict[str, Any], prefix: str = "components"
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
prefix = f"{prefix}."
non_pipe = {k: v for k, v in items.items() if not k.startswith(prefix)}
pipe = {k.replace(prefix, ""): v for k, v in items.items() if k.startswith(prefix)}
return non_pipe, pipe
def _fix_pretrained_vectors_name(nlp: Language) -> None: def _fix_pretrained_vectors_name(nlp: Language) -> None:
# TODO: Replace this once we handle vectors consistently as static # TODO: Replace this once we handle vectors consistently as static
# data # data

View File

@ -80,7 +80,7 @@ def _get_transition_table(
B_start, B_end = (0, n_labels) B_start, B_end = (0, n_labels)
I_start, I_end = (B_end, B_end + n_labels) I_start, I_end = (B_end, B_end + n_labels)
L_start, L_end = (I_end, I_end + n_labels) L_start, L_end = (I_end, I_end + n_labels)
U_start, _ = (L_end, L_end + n_labels) U_start, _ = (L_end, L_end + n_labels) # noqa: F841
# Using ranges allows us to set specific cells, which is necessary to express # Using ranges allows us to set specific cells, which is necessary to express
# that only actions of the same label are valid continuations. # that only actions of the same label are valid continuations.
B_range = numpy.arange(B_start, B_end) B_range = numpy.arange(B_start, B_end)

View File

@ -1,6 +1,7 @@
from typing import List from typing import List
from thinc.api import Model from thinc.api import Model
from thinc.types import Floats2d from thinc.types import Floats2d
from ..tokens import Doc from ..tokens import Doc
@ -15,14 +16,14 @@ def CharacterEmbed(nM: int, nC: int) -> Model[List[Doc], List[Floats2d]]:
) )
def init(model, X=None, Y=None): def init(model: Model, X=None, Y=None):
vectors_table = model.ops.alloc3f( vectors_table = model.ops.alloc3f(
model.get_dim("nC"), model.get_dim("nV"), model.get_dim("nM") model.get_dim("nC"), model.get_dim("nV"), model.get_dim("nM")
) )
model.set_param("E", vectors_table) model.set_param("E", vectors_table)
def forward(model, docs, is_train): def forward(model: Model, docs: List[Doc], is_train: bool):
if docs is None: if docs is None:
return [] return []
ids = [] ids = []

View File

@ -14,7 +14,7 @@ def IOB() -> Model[Padded, Padded]:
) )
def init(model, X: Optional[Padded] = None, Y: Optional[Padded] = None): def init(model: Model, X: Optional[Padded] = None, Y: Optional[Padded] = None) -> None:
if X is not None and Y is not None: if X is not None and Y is not None:
if X.data.shape != Y.data.shape: if X.data.shape != Y.data.shape:
# TODO: Fix error # TODO: Fix error

View File

@ -4,14 +4,14 @@ from thinc.api import Model
from ..attrs import LOWER from ..attrs import LOWER
def extract_ngrams(ngram_size, attr=LOWER) -> Model: def extract_ngrams(ngram_size: int, attr: int = LOWER) -> Model:
model = Model("extract_ngrams", forward) model = Model("extract_ngrams", forward)
model.attrs["ngram_size"] = ngram_size model.attrs["ngram_size"] = ngram_size
model.attrs["attr"] = attr model.attrs["attr"] = attr
return model return model
def forward(model, docs, is_train: bool): def forward(model: Model, docs, is_train: bool):
batch_keys = [] batch_keys = []
batch_vals = [] batch_vals = []
for doc in docs: for doc in docs:

View File

@ -1,5 +1,4 @@
from pathlib import Path from typing import Optional
from thinc.api import chain, clone, list2ragged, reduce_mean, residual from thinc.api import chain, clone, list2ragged, reduce_mean, residual
from thinc.api import Model, Maxout, Linear from thinc.api import Model, Maxout, Linear
@ -9,7 +8,7 @@ from ...vocab import Vocab
@registry.architectures.register("spacy.EntityLinker.v1") @registry.architectures.register("spacy.EntityLinker.v1")
def build_nel_encoder(tok2vec, nO=None): def build_nel_encoder(tok2vec: Model, nO: Optional[int] = None) -> Model:
with Model.define_operators({">>": chain, "**": clone}): with Model.define_operators({">>": chain, "**": clone}):
token_width = tok2vec.get_dim("nO") token_width = tok2vec.get_dim("nO")
output_layer = Linear(nO=nO, nI=token_width) output_layer = Linear(nO=nO, nI=token_width)
@ -26,8 +25,15 @@ def build_nel_encoder(tok2vec, nO=None):
@registry.assets.register("spacy.KBFromFile.v1") @registry.assets.register("spacy.KBFromFile.v1")
def load_kb(vocab_path, kb_path) -> KnowledgeBase: def load_kb(vocab_path: str, kb_path: str) -> KnowledgeBase:
vocab = Vocab().from_disk(vocab_path) vocab = Vocab().from_disk(vocab_path)
kb = KnowledgeBase(vocab=vocab) kb = KnowledgeBase(entity_vector_length=1)
kb.initialize(vocab)
kb.load_bulk(kb_path) kb.load_bulk(kb_path)
return kb return kb
@registry.assets.register("spacy.EmptyKB.v1")
def empty_kb(entity_vector_length: int) -> KnowledgeBase:
kb = KnowledgeBase(entity_vector_length=entity_vector_length)
return kb

View File

@ -1,10 +1,20 @@
from typing import Optional, Iterable, Tuple, List, TYPE_CHECKING
import numpy import numpy
from thinc.api import chain, Maxout, LayerNorm, Softmax, Linear, zero_init, Model from thinc.api import chain, Maxout, LayerNorm, Softmax, Linear, zero_init, Model
from thinc.api import MultiSoftmax, list2array from thinc.api import MultiSoftmax, list2array
if TYPE_CHECKING:
# This lets us add type hints for mypy etc. without causing circular imports
from ...vocab import Vocab # noqa: F401
from ...tokens import Doc # noqa: F401
def build_multi_task_model(tok2vec, maxout_pieces, token_vector_width, nO=None):
def build_multi_task_model(
tok2vec: Model,
maxout_pieces: int,
token_vector_width: int,
nO: Optional[int] = None,
) -> Model:
softmax = Softmax(nO=nO, nI=token_vector_width * 2) softmax = Softmax(nO=nO, nI=token_vector_width * 2)
model = chain( model = chain(
tok2vec, tok2vec,
@ -22,7 +32,13 @@ def build_multi_task_model(tok2vec, maxout_pieces, token_vector_width, nO=None):
return model return model
def build_cloze_multi_task_model(vocab, tok2vec, maxout_pieces, hidden_size, nO=None): def build_cloze_multi_task_model(
vocab: "Vocab",
tok2vec: Model,
maxout_pieces: int,
hidden_size: int,
nO: Optional[int] = None,
) -> Model:
# nO = vocab.vectors.data.shape[1] # nO = vocab.vectors.data.shape[1]
output_layer = chain( output_layer = chain(
list2array(), list2array(),
@ -43,24 +59,24 @@ def build_cloze_multi_task_model(vocab, tok2vec, maxout_pieces, hidden_size, nO=
def build_cloze_characters_multi_task_model( def build_cloze_characters_multi_task_model(
vocab, tok2vec, maxout_pieces, hidden_size, nr_char vocab: "Vocab", tok2vec: Model, maxout_pieces: int, hidden_size: int, nr_char: int
): ) -> Model:
output_layer = chain( output_layer = chain(
list2array(), list2array(),
Maxout(hidden_size, nP=maxout_pieces), Maxout(hidden_size, nP=maxout_pieces),
LayerNorm(nI=hidden_size), LayerNorm(nI=hidden_size),
MultiSoftmax([256] * nr_char, nI=hidden_size), MultiSoftmax([256] * nr_char, nI=hidden_size),
) )
model = build_masked_language_model(vocab, chain(tok2vec, output_layer)) model = build_masked_language_model(vocab, chain(tok2vec, output_layer))
model.set_ref("tok2vec", tok2vec) model.set_ref("tok2vec", tok2vec)
model.set_ref("output_layer", output_layer) model.set_ref("output_layer", output_layer)
return model return model
def build_masked_language_model(vocab, wrapped_model, mask_prob=0.15): def build_masked_language_model(
vocab: "Vocab", wrapped_model: Model, mask_prob: float = 0.15
) -> Model:
"""Convert a model into a BERT-style masked language model""" """Convert a model into a BERT-style masked language model"""
random_words = _RandomWords(vocab) random_words = _RandomWords(vocab)
def mlm_forward(model, docs, is_train): def mlm_forward(model, docs, is_train):
@ -74,7 +90,7 @@ def build_masked_language_model(vocab, wrapped_model, mask_prob=0.15):
return output, mlm_backward return output, mlm_backward
def mlm_initialize(model, X=None, Y=None): def mlm_initialize(model: Model, X=None, Y=None):
wrapped = model.layers[0] wrapped = model.layers[0]
wrapped.initialize(X=X, Y=Y) wrapped.initialize(X=X, Y=Y)
for dim in wrapped.dim_names: for dim in wrapped.dim_names:
@ -90,12 +106,11 @@ def build_masked_language_model(vocab, wrapped_model, mask_prob=0.15):
dims={dim: None for dim in wrapped_model.dim_names}, dims={dim: None for dim in wrapped_model.dim_names},
) )
mlm_model.set_ref("wrapped", wrapped_model) mlm_model.set_ref("wrapped", wrapped_model)
return mlm_model return mlm_model
class _RandomWords: class _RandomWords:
def __init__(self, vocab): def __init__(self, vocab: "Vocab") -> None:
self.words = [lex.text for lex in vocab if lex.prob != 0.0] self.words = [lex.text for lex in vocab if lex.prob != 0.0]
self.probs = [lex.prob for lex in vocab if lex.prob != 0.0] self.probs = [lex.prob for lex in vocab if lex.prob != 0.0]
self.words = self.words[:10000] self.words = self.words[:10000]
@ -104,7 +119,7 @@ class _RandomWords:
self.probs /= self.probs.sum() self.probs /= self.probs.sum()
self._cache = [] self._cache = []
def next(self): def next(self) -> str:
if not self._cache: if not self._cache:
self._cache.extend( self._cache.extend(
numpy.random.choice(len(self.words), 10000, p=self.probs) numpy.random.choice(len(self.words), 10000, p=self.probs)
@ -113,9 +128,11 @@ class _RandomWords:
return self.words[index] return self.words[index]
def _apply_mask(docs, random_words, mask_prob=0.15): def _apply_mask(
docs: Iterable["Doc"], random_words: _RandomWords, mask_prob: float = 0.15
) -> Tuple[numpy.ndarray, List["Doc"]]:
# This needs to be here to avoid circular imports # This needs to be here to avoid circular imports
from ...tokens import Doc from ...tokens import Doc # noqa: F811
N = sum(len(doc) for doc in docs) N = sum(len(doc) for doc in docs)
mask = numpy.random.uniform(0.0, 1.0, (N,)) mask = numpy.random.uniform(0.0, 1.0, (N,))
@ -141,7 +158,7 @@ def _apply_mask(docs, random_words, mask_prob=0.15):
return mask, masked_docs return mask, masked_docs
def _replace_word(word, random_words, mask="[MASK]"): def _replace_word(word: str, random_words: _RandomWords, mask: str = "[MASK]") -> str:
roll = numpy.random.random() roll = numpy.random.random()
if roll < 0.8: if roll < 0.8:
return mask return mask

View File

@ -1,6 +1,5 @@
from pydantic import StrictInt from typing import Optional
from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops, with_array from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops
from thinc.api import LayerNorm, Maxout, Mish
from ...util import registry from ...util import registry
from .._precomputable_affine import PrecomputableAffine from .._precomputable_affine import PrecomputableAffine
@ -10,16 +9,15 @@ from ..tb_framework import TransitionModel
@registry.architectures.register("spacy.TransitionBasedParser.v1") @registry.architectures.register("spacy.TransitionBasedParser.v1")
def build_tb_parser_model( def build_tb_parser_model(
tok2vec: Model, tok2vec: Model,
nr_feature_tokens: StrictInt, nr_feature_tokens: int,
hidden_width: StrictInt, hidden_width: int,
maxout_pieces: StrictInt, maxout_pieces: int,
use_upper=True, use_upper: bool = True,
nO=None, nO: Optional[int] = None,
): ) -> Model:
t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
tok2vec = chain(tok2vec, list2array(), Linear(hidden_width, t2v_width),) tok2vec = chain(tok2vec, list2array(), Linear(hidden_width, t2v_width),)
tok2vec.set_dim("nO", hidden_width) tok2vec.set_dim("nO", hidden_width)
lower = PrecomputableAffine( lower = PrecomputableAffine(
nO=hidden_width if use_upper else nO, nO=hidden_width if use_upper else nO,
nF=nr_feature_tokens, nF=nr_feature_tokens,

View File

@ -26,7 +26,6 @@ def BiluoTagger(
with_array(softmax_activation()), with_array(softmax_activation()),
padded2list(), padded2list(),
) )
return Model( return Model(
"biluo-tagger", "biluo-tagger",
forward, forward,
@ -52,7 +51,6 @@ def IOBTagger(
with_array(softmax_activation()), with_array(softmax_activation()),
padded2list(), padded2list(),
) )
return Model( return Model(
"iob-tagger", "iob-tagger",
forward, forward,

View File

@ -1,10 +1,11 @@
from typing import Optional
from thinc.api import zero_init, with_array, Softmax, chain, Model from thinc.api import zero_init, with_array, Softmax, chain, Model
from ...util import registry from ...util import registry
@registry.architectures.register("spacy.Tagger.v1") @registry.architectures.register("spacy.Tagger.v1")
def build_tagger_model(tok2vec, nO=None) -> Model: def build_tagger_model(tok2vec: Model, nO: Optional[int] = None) -> Model:
# TODO: glorot_uniform_init seems to work a bit better than zero_init here?! # TODO: glorot_uniform_init seems to work a bit better than zero_init here?!
t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
output_layer = Softmax(nO, t2v_width, init_W=zero_init) output_layer = Softmax(nO, t2v_width, init_W=zero_init)

View File

@ -2,10 +2,9 @@ from typing import Optional
from thinc.api import Model, reduce_mean, Linear, list2ragged, Logistic from thinc.api import Model, reduce_mean, Linear, list2ragged, Logistic
from thinc.api import chain, concatenate, clone, Dropout, ParametricAttention from thinc.api import chain, concatenate, clone, Dropout, ParametricAttention
from thinc.api import SparseLinear, Softmax, softmax_activation, Maxout, reduce_sum from thinc.api import SparseLinear, Softmax, softmax_activation, Maxout, reduce_sum
from thinc.api import HashEmbed, with_ragged, with_array, with_cpu, uniqued from thinc.api import HashEmbed, with_array, with_cpu, uniqued
from thinc.api import Relu, residual, expand_window, FeatureExtractor from thinc.api import Relu, residual, expand_window, FeatureExtractor
from ... import util
from ...attrs import ID, ORTH, PREFIX, SUFFIX, SHAPE, LOWER from ...attrs import ID, ORTH, PREFIX, SUFFIX, SHAPE, LOWER
from ...util import registry from ...util import registry
from ..extract_ngrams import extract_ngrams from ..extract_ngrams import extract_ngrams
@ -40,7 +39,12 @@ def build_simple_cnn_text_classifier(
@registry.architectures.register("spacy.TextCatBOW.v1") @registry.architectures.register("spacy.TextCatBOW.v1")
def build_bow_text_classifier(exclusive_classes, ngram_size, no_output_layer, nO=None): def build_bow_text_classifier(
exclusive_classes: bool,
ngram_size: int,
no_output_layer: bool,
nO: Optional[int] = None,
) -> Model:
with Model.define_operators({">>": chain}): with Model.define_operators({">>": chain}):
sparse_linear = SparseLinear(nO) sparse_linear = SparseLinear(nO)
model = extract_ngrams(ngram_size, attr=ORTH) >> sparse_linear model = extract_ngrams(ngram_size, attr=ORTH) >> sparse_linear
@ -55,16 +59,16 @@ def build_bow_text_classifier(exclusive_classes, ngram_size, no_output_layer, nO
@registry.architectures.register("spacy.TextCatEnsemble.v1") @registry.architectures.register("spacy.TextCatEnsemble.v1")
def build_text_classifier( def build_text_classifier(
width, width: int,
embed_size, embed_size: int,
pretrained_vectors, pretrained_vectors: Optional[bool],
exclusive_classes, exclusive_classes: bool,
ngram_size, ngram_size: int,
window_size, window_size: int,
conv_depth, conv_depth: int,
dropout, dropout: Optional[float],
nO=None, nO: Optional[int] = None,
): ) -> Model:
cols = [ORTH, LOWER, PREFIX, SUFFIX, SHAPE, ID] cols = [ORTH, LOWER, PREFIX, SUFFIX, SHAPE, ID]
with Model.define_operators({">>": chain, "|": concatenate, "**": clone}): with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
lower = HashEmbed( lower = HashEmbed(
@ -91,7 +95,6 @@ def build_text_classifier(
dropout=dropout, dropout=dropout,
seed=13, seed=13,
) )
width_nI = sum(layer.get_dim("nO") for layer in [lower, prefix, suffix, shape]) width_nI = sum(layer.get_dim("nO") for layer in [lower, prefix, suffix, shape])
trained_vectors = FeatureExtractor(cols) >> with_array( trained_vectors = FeatureExtractor(cols) >> with_array(
uniqued( uniqued(
@ -100,7 +103,6 @@ def build_text_classifier(
column=cols.index(ORTH), column=cols.index(ORTH),
) )
) )
if pretrained_vectors: if pretrained_vectors:
static_vectors = StaticVectors(width) static_vectors = StaticVectors(width)
vector_layer = trained_vectors | static_vectors vector_layer = trained_vectors | static_vectors
@ -152,7 +154,12 @@ def build_text_classifier(
@registry.architectures.register("spacy.TextCatLowData.v1") @registry.architectures.register("spacy.TextCatLowData.v1")
def build_text_classifier_lowdata(width, pretrained_vectors, dropout, nO=None): def build_text_classifier_lowdata(
width: int,
pretrained_vectors: Optional[bool],
dropout: Optional[float],
nO: Optional[int] = None,
) -> Model:
# Note, before v.3, this was the default if setting "low_data" and "pretrained_dims" # Note, before v.3, this was the default if setting "low_data" and "pretrained_dims"
with Model.define_operators({">>": chain, "**": clone}): with Model.define_operators({">>": chain, "**": clone}):
model = ( model = (

View File

@ -6,16 +6,15 @@ from thinc.api import expand_window, residual, Maxout, Mish, PyTorchLSTM
from thinc.types import Floats2d from thinc.types import Floats2d
from ...tokens import Doc from ...tokens import Doc
from ... import util
from ...util import registry from ...util import registry
from ...ml import _character_embed from ...ml import _character_embed
from ..staticvectors import StaticVectors from ..staticvectors import StaticVectors
from ...pipeline.tok2vec import Tok2VecListener from ...pipeline.tok2vec import Tok2VecListener
from ...attrs import ID, ORTH, NORM, PREFIX, SUFFIX, SHAPE from ...attrs import ORTH, NORM, PREFIX, SUFFIX, SHAPE
@registry.architectures.register("spacy.Tok2VecListener.v1") @registry.architectures.register("spacy.Tok2VecListener.v1")
def tok2vec_listener_v1(width, upstream="*"): def tok2vec_listener_v1(width: int, upstream: str = "*"):
tok2vec = Tok2VecListener(upstream_name=upstream, width=width) tok2vec = Tok2VecListener(upstream_name=upstream, width=width)
return tok2vec return tok2vec
@ -45,10 +44,11 @@ def build_hash_embed_cnn_tok2vec(
width=width, width=width,
depth=depth, depth=depth,
window_size=window_size, window_size=window_size,
maxout_pieces=maxout_pieces maxout_pieces=maxout_pieces,
) ),
) )
@registry.architectures.register("spacy.Tok2Vec.v1") @registry.architectures.register("spacy.Tok2Vec.v1")
def build_Tok2Vec_model( def build_Tok2Vec_model(
embed: Model[List[Doc], List[Floats2d]], embed: Model[List[Doc], List[Floats2d]],
@ -68,7 +68,6 @@ def MultiHashEmbed(
width: int, rows: int, also_embed_subwords: bool, also_use_static_vectors: bool width: int, rows: int, also_embed_subwords: bool, also_use_static_vectors: bool
): ):
cols = [NORM, PREFIX, SUFFIX, SHAPE, ORTH] cols = [NORM, PREFIX, SUFFIX, SHAPE, ORTH]
seed = 7 seed = 7
def make_hash_embed(feature): def make_hash_embed(feature):
@ -124,11 +123,11 @@ def CharacterEmbed(width: int, rows: int, nM: int, nC: int):
chain( chain(
FeatureExtractor([NORM]), FeatureExtractor([NORM]),
list2ragged(), list2ragged(),
with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)) with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)),
) ),
), ),
with_array(Maxout(width, nM * nC + width, nP=3, normalize=True, dropout=0.0)), with_array(Maxout(width, nM * nC + width, nP=3, normalize=True, dropout=0.0)),
ragged2list() ragged2list(),
) )
return model return model
@ -155,12 +154,7 @@ def MaxoutWindowEncoder(width: int, window_size: int, maxout_pieces: int, depth:
def MishWindowEncoder(width, window_size, depth): def MishWindowEncoder(width, window_size, depth):
cnn = chain( cnn = chain(
expand_window(window_size=window_size), expand_window(window_size=window_size),
Mish( Mish(nO=width, nI=width * ((window_size * 2) + 1), dropout=0.0, normalize=True),
nO=width,
nI=width * ((window_size * 2) + 1),
dropout=0.0,
normalize=True
),
) )
model = clone(residual(cnn), depth) model = clone(residual(cnn), depth)
model.set_dim("nO", width) model.set_dim("nO", width)

View File

@ -1,8 +1,6 @@
from libc.string cimport memset, memcpy from libc.string cimport memset, memcpy
from libc.stdlib cimport calloc, free, realloc from ..typedefs cimport weight_t, hash_t
from ..typedefs cimport weight_t, class_t, hash_t from ..pipeline._parser_internals._state cimport StateC
from ._state cimport StateC
cdef struct SizesC: cdef struct SizesC:

View File

@ -1,29 +1,18 @@
# cython: infer_types=True, cdivision=True, boundscheck=False # cython: infer_types=True, cdivision=True, boundscheck=False
cimport cython.parallel
cimport numpy as np cimport numpy as np
from libc.math cimport exp from libc.math cimport exp
from libcpp.vector cimport vector
from libc.string cimport memset, memcpy from libc.string cimport memset, memcpy
from libc.stdlib cimport calloc, free, realloc from libc.stdlib cimport calloc, free, realloc
from cymem.cymem cimport Pool
from thinc.extra.search cimport Beam
from thinc.backends.linalg cimport Vec, VecVec from thinc.backends.linalg cimport Vec, VecVec
cimport blis.cy cimport blis.cy
import numpy import numpy
import numpy.random import numpy.random
from thinc.api import Linear, Model, CupyOps, NumpyOps, use_ops, noop from thinc.api import Model, CupyOps, NumpyOps
from ..typedefs cimport weight_t, class_t, hash_t
from ..tokens.doc cimport Doc
from .stateclass cimport StateClass
from .transition_system cimport Transition
from ..compat import copy_array
from ..errors import Errors, TempErrors
from ..util import create_default_optimizer
from .. import util from .. import util
from . import nonproj from ..typedefs cimport weight_t, class_t, hash_t
from ..pipeline._parser_internals.stateclass cimport StateClass
cdef WeightsC get_c_weights(model) except *: cdef WeightsC get_c_weights(model) except *:

View File

@ -1,5 +1,5 @@
from thinc.api import Model, noop, use_ops, Linear from thinc.api import Model, noop, use_ops, Linear
from ..syntax._parser_model import ParserStepModel from .parser_model import ParserStepModel
def TransitionModel(tok2vec, lower, upper, dropout=0.2, unseen_classes=set()): def TransitionModel(tok2vec, lower, upper, dropout=0.2, unseen_classes=set()):

View File

@ -1,9 +1,8 @@
from typing import List, Dict, Iterable, Optional, Union, TYPE_CHECKING from typing import List, Dict, Iterable, Optional, Union, TYPE_CHECKING
from wasabi import Printer from wasabi import msg
import warnings
from .tokens import Doc, Token, Span from .tokens import Doc, Token, Span
from .errors import Errors, Warnings from .errors import Errors
from .util import dot_to_dict from .util import dot_to_dict
if TYPE_CHECKING: if TYPE_CHECKING:
@ -11,48 +10,7 @@ if TYPE_CHECKING:
from .language import Language # noqa: F401 from .language import Language # noqa: F401
def analyze_pipes( DEFAULT_KEYS = ["requires", "assigns", "scores", "retokenizes"]
nlp: "Language", name: str, index: int, warn: bool = True
) -> List[str]:
"""Analyze a pipeline component with respect to its position in the current
pipeline and the other components. Will check whether requirements are
fulfilled (e.g. if previous components assign the attributes).
nlp (Language): The current nlp object.
name (str): The name of the pipeline component to analyze.
index (int): The index of the component in the pipeline.
warn (bool): Show user warning if problem is found.
RETURNS (List[str]): The problems found for the given pipeline component.
"""
assert nlp.pipeline[index][0] == name
prev_pipes = nlp.pipeline[:index]
meta = nlp.get_pipe_meta(name)
requires = {annot: False for annot in meta.requires}
if requires:
for prev_name, prev_pipe in prev_pipes:
prev_meta = nlp.get_pipe_meta(prev_name)
for annot in prev_meta.assigns:
requires[annot] = True
problems = []
for annot, fulfilled in requires.items():
if not fulfilled:
problems.append(annot)
if warn:
warnings.warn(Warnings.W025.format(name=name, attr=annot))
return problems
def analyze_all_pipes(nlp: "Language", warn: bool = True) -> Dict[str, List[str]]:
"""Analyze all pipes in the pipeline in order.
nlp (Language): The current nlp object.
warn (bool): Show user warning if problem is found.
RETURNS (Dict[str, List[str]]): The problems found, keyed by component name.
"""
problems = {}
for i, name in enumerate(nlp.pipe_names):
problems[name] = analyze_pipes(nlp, name, i, warn=warn)
return problems
def validate_attrs(values: Iterable[str]) -> Iterable[str]: def validate_attrs(values: Iterable[str]) -> Iterable[str]:
@ -101,89 +59,77 @@ def validate_attrs(values: Iterable[str]) -> Iterable[str]:
return values return values
def _get_feature_for_attr(nlp: "Language", attr: str, feature: str) -> List[str]: def get_attr_info(nlp: "Language", attr: str) -> Dict[str, List[str]]:
assert feature in ["assigns", "requires"] """Check which components in the pipeline assign or require an attribute.
result = []
nlp (Language): The current nlp object.
attr (str): The attribute, e.g. "doc.tensor".
RETURNS (Dict[str, List[str]]): A dict keyed by "assigns" and "requires",
mapped to a list of component names.
"""
result = {"assigns": [], "requires": []}
for pipe_name in nlp.pipe_names: for pipe_name in nlp.pipe_names:
meta = nlp.get_pipe_meta(pipe_name) meta = nlp.get_pipe_meta(pipe_name)
pipe_assigns = getattr(meta, feature, []) if attr in meta.assigns:
if attr in pipe_assigns: result["assigns"].append(pipe_name)
result.append(pipe_name) if attr in meta.requires:
result["requires"].append(pipe_name)
return result return result
def get_assigns_for_attr(nlp: "Language", attr: str) -> List[str]: def analyze_pipes(
"""Get all pipeline components that assign an attr, e.g. "doc.tensor". nlp: "Language", *, keys: List[str] = DEFAULT_KEYS,
) -> Dict[str, Union[List[str], Dict[str, List[str]]]]:
pipeline (Language): The current nlp object.
attr (str): The attribute to check.
RETURNS (List[str]): Names of components that require the attr.
"""
return _get_feature_for_attr(nlp, attr, "assigns")
def get_requires_for_attr(nlp: "Language", attr: str) -> List[str]:
"""Get all pipeline components that require an attr, e.g. "doc.tensor".
pipeline (Language): The current nlp object.
attr (str): The attribute to check.
RETURNS (List[str]): Names of components that require the attr.
"""
return _get_feature_for_attr(nlp, attr, "requires")
def print_summary(
nlp: "Language", pretty: bool = True, no_print: bool = False
) -> Optional[Dict[str, Union[List[str], Dict[str, List[str]]]]]:
"""Print a formatted summary for the current nlp object's pipeline. Shows """Print a formatted summary for the current nlp object's pipeline. Shows
a table with the pipeline components and why they assign and require, as a table with the pipeline components and why they assign and require, as
well as any problems if available. well as any problems if available.
nlp (Language): The nlp object. nlp (Language): The nlp object.
pretty (bool): Pretty-print the results (color etc). keys (List[str]): The meta keys to show in the table.
no_print (bool): Don't print anything, just return the data. RETURNS (dict): A dict with "summary" and "problems".
RETURNS (dict): A dict with "overview" and "problems".
""" """
msg = Printer(pretty=pretty, no_print=no_print) result = {"summary": {}, "problems": {}}
overview = [] all_attrs = set()
problems = {}
for i, name in enumerate(nlp.pipe_names): for i, name in enumerate(nlp.pipe_names):
meta = nlp.get_pipe_meta(name) meta = nlp.get_pipe_meta(name)
overview.append((i, name, meta.requires, meta.assigns, meta.retokenizes)) all_attrs.update(meta.assigns)
problems[name] = analyze_pipes(nlp, name, i, warn=False) all_attrs.update(meta.requires)
result["summary"][name] = {key: getattr(meta, key, None) for key in keys}
prev_pipes = nlp.pipeline[:i]
requires = {annot: False for annot in meta.requires}
if requires:
for prev_name, prev_pipe in prev_pipes:
prev_meta = nlp.get_pipe_meta(prev_name)
for annot in prev_meta.assigns:
requires[annot] = True
result["problems"][name] = []
for annot, fulfilled in requires.items():
if not fulfilled:
result["problems"][name].append(annot)
result["attrs"] = {attr: get_attr_info(nlp, attr) for attr in all_attrs}
return result
def print_pipe_analysis(
analysis: Dict[str, Union[List[str], Dict[str, List[str]]]],
*,
keys: List[str] = DEFAULT_KEYS,
) -> Optional[Dict[str, Union[List[str], Dict[str, List[str]]]]]:
"""Print a formatted version of the pipe analysis produced by analyze_pipes.
analysis (Dict[str, Union[List[str], Dict[str, List[str]]]]): The analysis.
keys (List[str]): The meta keys to show in the table.
"""
msg.divider("Pipeline Overview") msg.divider("Pipeline Overview")
header = ("#", "Component", "Requires", "Assigns", "Retokenizes") header = ["#", "Component", *[key.capitalize() for key in keys]]
msg.table(overview, header=header, divider=True, multiline=True) summary = analysis["summary"].items()
n_problems = sum(len(p) for p in problems.values()) body = [[i, n, *[v for v in m.values()]] for i, (n, m) in enumerate(summary)]
if any(p for p in problems.values()): msg.table(body, header=header, divider=True, multiline=True)
n_problems = sum(len(p) for p in analysis["problems"].values())
if any(p for p in analysis["problems"].values()):
msg.divider(f"Problems ({n_problems})") msg.divider(f"Problems ({n_problems})")
for name, problem in problems.items(): for name, problem in analysis["problems"].items():
if problem: if problem:
msg.warn(f"'{name}' requirements not met: {', '.join(problem)}") msg.warn(f"'{name}' requirements not met: {', '.join(problem)}")
else: else:
msg.good("No problems found.") msg.good("No problems found.")
if no_print:
return {"overview": overview, "problems": problems}
def count_pipeline_interdependencies(nlp: "Language") -> List[int]:
"""Count how many subsequent components require an annotation set by each
component in the pipeline.
nlp (Language): The current nlp object.
RETURNS (List[int]): The interdependency counts.
"""
pipe_assigns = []
pipe_requires = []
for name in nlp.pipe_names:
meta = nlp.get_pipe_meta(name)
pipe_assigns.append(set(meta.assigns))
pipe_requires.append(set(meta.requires))
counts = []
for i, assigns in enumerate(pipe_assigns):
count = 0
for requires in pipe_requires[i + 1 :]:
if assigns.intersection(requires):
count += 1
counts.append(count)
return counts

View File

@ -1,3 +1,4 @@
from .attributeruler import AttributeRuler
from .dep_parser import DependencyParser from .dep_parser import DependencyParser
from .entity_linker import EntityLinker from .entity_linker import EntityLinker
from .ner import EntityRecognizer from .ner import EntityRecognizer
@ -13,6 +14,7 @@ from .tok2vec import Tok2Vec
from .functions import merge_entities, merge_noun_chunks, merge_subtokens from .functions import merge_entities, merge_noun_chunks, merge_subtokens
__all__ = [ __all__ = [
"AttributeRuler",
"DependencyParser", "DependencyParser",
"EntityLinker", "EntityLinker",
"EntityRecognizer", "EntityRecognizer",

View File

@ -1,15 +1,14 @@
from libc.string cimport memcpy, memset, memmove from libc.string cimport memcpy, memset
from libc.stdlib cimport malloc, calloc, free from libc.stdlib cimport calloc, free
from libc.stdint cimport uint32_t, uint64_t from libc.stdint cimport uint32_t, uint64_t
from cpython.exc cimport PyErr_CheckSignals, PyErr_SetFromErrno from cpython.exc cimport PyErr_CheckSignals, PyErr_SetFromErrno
from murmurhash.mrmr cimport hash64 from murmurhash.mrmr cimport hash64
from ..vocab cimport EMPTY_LEXEME from ...vocab cimport EMPTY_LEXEME
from ..structs cimport TokenC, SpanC from ...structs cimport TokenC, SpanC
from ..lexeme cimport Lexeme from ...lexeme cimport Lexeme
from ..symbols cimport punct from ...attrs cimport IS_SPACE
from ..attrs cimport IS_SPACE from ...typedefs cimport attr_t
from ..typedefs cimport attr_t
cdef inline bint is_space_token(const TokenC* token) nogil: cdef inline bint is_space_token(const TokenC* token) nogil:

View File

@ -1,8 +1,6 @@
from cymem.cymem cimport Pool
from .stateclass cimport StateClass from .stateclass cimport StateClass
from ..typedefs cimport weight_t, attr_t from ...typedefs cimport weight_t, attr_t
from .transition_system cimport TransitionSystem, Transition from .transition_system cimport Transition, TransitionSystem
cdef class ArcEager(TransitionSystem): cdef class ArcEager(TransitionSystem):

View File

@ -1,24 +1,17 @@
# cython: profile=True, cdivision=True, infer_types=True # cython: profile=True, cdivision=True, infer_types=True
from cpython.ref cimport Py_INCREF
from cymem.cymem cimport Pool, Address from cymem.cymem cimport Pool, Address
from libc.stdint cimport int32_t from libc.stdint cimport int32_t
from collections import defaultdict, Counter from collections import defaultdict, Counter
import json
from ..typedefs cimport hash_t, attr_t from ...typedefs cimport hash_t, attr_t
from ..strings cimport hash_string from ...strings cimport hash_string
from ..structs cimport TokenC from ...structs cimport TokenC
from ..tokens.doc cimport Doc, set_children_from_heads from ...tokens.doc cimport Doc, set_children_from_heads
from ...gold.example cimport Example
from ...errors import Errors
from .stateclass cimport StateClass from .stateclass cimport StateClass
from ._state cimport StateC from ._state cimport StateC
from .transition_system cimport move_cost_func_t, label_cost_func_t
from ..gold.example cimport Example
from ..errors import Errors
from .nonproj import is_nonproj_tree
from . import nonproj
# Calculate cost as gold/not gold. We don't use scalar value anyway. # Calculate cost as gold/not gold. We don't use scalar value anyway.
cdef int BINARY_COSTS = 1 cdef int BINARY_COSTS = 1

View File

@ -1,6 +1,4 @@
from .transition_system cimport TransitionSystem from .transition_system cimport TransitionSystem
from .transition_system cimport Transition
from ..typedefs cimport attr_t
cdef class BiluoPushDown(TransitionSystem): cdef class BiluoPushDown(TransitionSystem):

View File

@ -2,17 +2,14 @@ from collections import Counter
from libc.stdint cimport int32_t from libc.stdint cimport int32_t
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
from ..typedefs cimport weight_t from ...typedefs cimport weight_t, attr_t
from ...lexeme cimport Lexeme
from ...attrs cimport IS_SPACE
from ...gold.example cimport Example
from ...errors import Errors
from .stateclass cimport StateClass from .stateclass cimport StateClass
from ._state cimport StateC from ._state cimport StateC
from .transition_system cimport Transition from .transition_system cimport Transition, do_func_t
from .transition_system cimport do_func_t
from ..lexeme cimport Lexeme
from ..attrs cimport IS_SPACE
from ..gold.iob_utils import biluo_tags_from_offsets
from ..gold.example cimport Example
from ..errors import Errors
cdef enum: cdef enum:

View File

@ -5,9 +5,9 @@ scheme.
""" """
from copy import copy from copy import copy
from ..tokens.doc cimport Doc, set_children_from_heads from ...tokens.doc cimport Doc, set_children_from_heads
from ..errors import Errors from ...errors import Errors
DELIMITER = '||' DELIMITER = '||'

View File

@ -1,12 +1,8 @@
from libc.string cimport memcpy, memset
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
cimport cython
from ..structs cimport TokenC, SpanC from ...structs cimport TokenC, SpanC
from ..typedefs cimport attr_t from ...typedefs cimport attr_t
from ..vocab cimport EMPTY_LEXEME
from ._state cimport StateC from ._state cimport StateC

View File

@ -1,7 +1,7 @@
# cython: infer_types=True # cython: infer_types=True
import numpy import numpy
from ..tokens.doc cimport Doc from ...tokens.doc cimport Doc
cdef class StateClass: cdef class StateClass:

View File

@ -1,11 +1,11 @@
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
from ..typedefs cimport attr_t, weight_t from ...typedefs cimport attr_t, weight_t
from ..structs cimport TokenC from ...structs cimport TokenC
from ..strings cimport StringStore from ...strings cimport StringStore
from ...gold.example cimport Example
from .stateclass cimport StateClass from .stateclass cimport StateClass
from ._state cimport StateC from ._state cimport StateC
from ..gold.example cimport Example
cdef struct Transition: cdef struct Transition:

View File

@ -1,19 +1,17 @@
# cython: infer_types=True # cython: infer_types=True
from __future__ import print_function from __future__ import print_function
from cpython.ref cimport Py_INCREF
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
from collections import Counter from collections import Counter
import srsly import srsly
from ..typedefs cimport weight_t from ...typedefs cimport weight_t, attr_t
from ..tokens.doc cimport Doc from ...tokens.doc cimport Doc
from ..structs cimport TokenC from ...structs cimport TokenC
from .stateclass cimport StateClass from .stateclass cimport StateClass
from ..typedefs cimport attr_t
from ..errors import Errors from ...errors import Errors
from .. import util from ... import util
cdef weight_t MIN_SCORE = -90000 cdef weight_t MIN_SCORE = -90000

View File

@ -0,0 +1,266 @@
import srsly
from typing import List, Dict, Union, Iterable, Any, Optional
from pathlib import Path
from .pipe import Pipe
from ..errors import Errors
from ..language import Language
from ..matcher import Matcher
from ..symbols import IDS
from ..tokens import Doc, Span
from ..tokens._retokenize import normalize_token_attrs, set_token_attrs
from ..vocab import Vocab
from .. import util
# A single Matcher pattern: a list of dicts keyed by attribute ID or attribute
# name, e.g. [{"TAG": "NN"}].
MatcherPatternType = List[Dict[Union[int, str], Any]]
# One AttributeRuler entry, keyed by the argument names of AttributeRuler.add
# ("patterns", "attrs", "index").
AttributeRulerPatternType = Dict[str, Union[MatcherPatternType, Dict, int]]
@Language.factory("attribute_ruler")
def make_attribute_ruler(
    nlp: Language,
    name: str,
    pattern_dicts: Optional[Iterable[AttributeRulerPatternType]] = None,
):
    """Factory registered under "attribute_ruler": create an AttributeRuler
    that shares the vocab of the given nlp object.

    nlp (Language): The nlp object whose vocab the component uses.
    name (str): The component instance name.
    pattern_dicts (Iterable[Dict]): Optional pattern dicts that are passed on
        to the AttributeRuler constructor.
    RETURNS (AttributeRuler): The newly created component.
    """
    ruler = AttributeRuler(nlp.vocab, name, pattern_dicts=pattern_dicts)
    return ruler
class AttributeRuler(Pipe):
    """Set token-level attributes for tokens matched by Matcher patterns.
    Additionally supports importing patterns from tag maps and morph rules.

    DOCS: https://spacy.io/api/attributeruler
    """

    def __init__(
        self,
        vocab: Vocab,
        name: str = "attribute_ruler",
        *,
        pattern_dicts: Optional[Iterable[AttributeRulerPatternType]] = None,
    ) -> None:
        """Initialize the AttributeRuler.

        vocab (Vocab): The vocab.
        name (str): The pipe name. Defaults to "attribute_ruler".
        pattern_dicts (Iterable[Dict]): A list of pattern dicts with the keys as
            the arguments to AttributeRuler.add (`patterns`/`attrs`/`index`) to add
            as patterns.

        RETURNS (AttributeRuler): The AttributeRuler component.

        DOCS: https://spacy.io/api/attributeruler#init
        """
        self.name = name
        self.vocab = vocab
        self.matcher = Matcher(self.vocab)
        # Three parallel lists: position i holds the data for matcher key i.
        self.attrs = []
        self._attrs_unnormed = []  # store for reference
        self.indices = []
        if pattern_dicts:
            self.add_patterns(pattern_dicts)

    def __call__(self, doc: Doc) -> Doc:
        """Apply the attributeruler to a Doc and set all attribute exceptions.

        doc (Doc): The document to process.
        RETURNS (Doc): The processed Doc.

        DOCS: https://spacy.io/api/attributeruler#call
        """
        matches = self.matcher(doc)
        for match_id, start, end in matches:
            # The match ID is the integer key used in `add`, i.e. the position
            # of the rule in self.attrs / self.indices.
            span = Span(doc, start, end, label=match_id)
            attrs = self.attrs[span.label]
            index = self.indices[span.label]
            try:
                token = span[index]
            except IndexError:
                raise ValueError(
                    Errors.E1001.format(
                        patterns=self.matcher.get(span.label),
                        span=[t.text for t in span],
                        index=index,
                    )
                )
            set_token_attrs(token, attrs)
        return doc

    def load_from_tag_map(
        self, tag_map: Dict[str, Dict[Union[int, str], Union[int, str]]]
    ) -> None:
        """Add one rule per tag of a tag map. Each rule matches tokens with
        that TAG and assigns the tag's token-level attributes plus a MORPH
        value built from the remaining entries.

        tag_map (dict): The tag map, keyed by tag and mapped to a dict of
            attributes and morphological features.
        """
        for tag, attrs in tag_map.items():
            pattern = [{"TAG": tag}]
            attrs, morph_attrs = _split_morph_attrs(attrs)
            morph = self.vocab.morphology.add(morph_attrs)
            attrs["MORPH"] = self.vocab.strings[morph]
            self.add([pattern], attrs)

    def load_from_morph_rules(
        self, morph_rules: Dict[str, Dict[str, Dict[Union[int, str], Union[int, str]]]]
    ) -> None:
        """Add one rule per (tag, word) pair of a morph rules dict. Each rule
        matches tokens with that ORTH and TAG and assigns the entry's
        token-level attributes plus a MORPH value built from the remaining
        entries.

        morph_rules (dict): The morph rules, keyed by tag, then by word, and
            mapped to a dict of attributes and morphological features.
        """
        for tag, rules_for_tag in morph_rules.items():
            for word, attrs in rules_for_tag.items():
                pattern = [{"ORTH": word, "TAG": tag}]
                attrs, morph_attrs = _split_morph_attrs(attrs)
                morph = self.vocab.morphology.add(morph_attrs)
                attrs["MORPH"] = self.vocab.strings[morph]
                self.add([pattern], attrs)

    def add(
        self, patterns: Iterable[MatcherPatternType], attrs: Dict, index: int = 0
    ) -> None:
        """Add Matcher patterns for tokens that should be modified with the
        provided attributes. The token at the specified index within the
        matched span will be assigned the attributes.

        patterns (Iterable[List[Dict]]): A list of Matcher patterns.
        attrs (Dict): The attributes to assign to the target token in the
            matched span.
        index (int): The index of the token in the matched span to modify. May
            be negative to index from the end of the span. Defaults to 0.

        DOCS: https://spacy.io/api/attributeruler#add
        """
        key = len(self.attrs)
        # Normalize and register with the matcher before touching the parallel
        # lists, so a failure in either step leaves them aligned.
        normed_attrs = normalize_token_attrs(self.vocab, attrs)
        self.matcher.add(key, patterns)
        self._attrs_unnormed.append(attrs)
        self.attrs.append(normed_attrs)
        self.indices.append(index)

    def add_patterns(self, pattern_dicts: Iterable[AttributeRulerPatternType]) -> None:
        """Add several rules at once.

        pattern_dicts (Iterable[Dict]): Dicts whose keys are the arguments of
            AttributeRuler.add (`patterns`/`attrs`/`index`).
        """
        for p in pattern_dicts:
            self.add(**p)

    @property
    def patterns(self) -> List[AttributeRulerPatternType]:
        """All added rules as dicts with the keys `patterns`, `attrs` (as
        originally passed to `add`, when available) and `index`.
        """
        all_patterns = []
        for key, attrs in enumerate(self._attrs_unnormed):
            all_patterns.append(
                {
                    "patterns": self.matcher.get(key)[1],
                    "attrs": attrs,
                    "index": self.indices[key],
                }
            )
        return all_patterns

    def _get_matcher_patterns(self) -> Dict[int, Any]:
        """Collect the Matcher patterns of every rule, keyed by rule position."""
        return {key: self.matcher.get(key)[1] for key in range(len(self.attrs))}

    def _set_loaded_attrs(self, attrs: List[Dict]) -> None:
        """Install deserialized (normalized) attrs and keep `patterns` aligned."""
        self.attrs = attrs
        # The attrs as originally passed to `add` are not part of the
        # serialized data. Fall back to copies of the normalized attrs so the
        # `patterns` property stays index-aligned instead of failing.
        self._attrs_unnormed = [dict(entry) for entry in attrs]

    def _add_loaded_patterns(self, patterns) -> None:
        """Register deserialized patterns with the matcher and check that they
        line up with the loaded attrs and indices.

        RAISES (ValueError): If the number of pattern keys differs from the
            number of attrs or indices.
        """
        if not patterns:
            return
        for key, pattern in patterns.items():
            self.matcher.add(key, pattern)
        expected = len(patterns)
        # Explicit check instead of `assert`, which is stripped under -O.
        if len(self.attrs) != expected or len(self.indices) != expected:
            raise ValueError(
                f"Inconsistent AttributeRuler data: loaded {expected} pattern "
                f"keys, but {len(self.attrs)} attrs and "
                f"{len(self.indices)} indices."
            )

    def to_bytes(self, exclude: Iterable[str] = tuple()) -> bytes:
        """Serialize the attributeruler to a bytestring.

        exclude (Iterable[str]): String names of serialization fields to exclude.
        RETURNS (bytes): The serialized object.

        DOCS: https://spacy.io/api/attributeruler#to_bytes
        """
        patterns = self._get_matcher_patterns()
        serialize = {
            "vocab": self.vocab.to_bytes,
            "patterns": lambda: srsly.msgpack_dumps(patterns),
            "attrs": lambda: srsly.msgpack_dumps(self.attrs),
            "indices": lambda: srsly.msgpack_dumps(self.indices),
        }
        return util.to_bytes(serialize, exclude)

    def from_bytes(
        self, bytes_data: bytes, exclude: Iterable[str] = tuple()
    ) -> "AttributeRuler":
        """Load the attributeruler from a bytestring.

        bytes_data (bytes): The data to load.
        exclude (Iterable[str]): String names of serialization fields to exclude.
        RETURNS (AttributeRuler): The loaded object.

        DOCS: https://spacy.io/api/attributeruler#from_bytes
        """
        # Patterns are collected first and only added to the matcher once
        # attrs and indices have been loaded as well.
        data = {"patterns": b""}

        def load_patterns(b):
            data["patterns"] = srsly.msgpack_loads(b)

        def load_attrs(b):
            self._set_loaded_attrs(srsly.msgpack_loads(b))

        def load_indices(b):
            self.indices = srsly.msgpack_loads(b)

        deserialize = {
            "vocab": lambda b: self.vocab.from_bytes(b),
            "patterns": load_patterns,
            "attrs": load_attrs,
            "indices": load_indices,
        }
        util.from_bytes(bytes_data, deserialize, exclude)
        self._add_loaded_patterns(data["patterns"])
        return self

    def to_disk(self, path: Union[Path, str], exclude: Iterable[str] = tuple()) -> None:
        """Serialize the attributeruler to disk.

        path (Union[Path, str]): A path to a directory.
        exclude (Iterable[str]): String names of serialization fields to exclude.

        DOCS: https://spacy.io/api/attributeruler#to_disk
        """
        patterns = self._get_matcher_patterns()
        serialize = {
            "vocab": lambda p: self.vocab.to_disk(p),
            "patterns": lambda p: srsly.write_msgpack(p, patterns),
            "attrs": lambda p: srsly.write_msgpack(p, self.attrs),
            "indices": lambda p: srsly.write_msgpack(p, self.indices),
        }
        util.to_disk(path, serialize, exclude)

    def from_disk(
        self, path: Union[Path, str], exclude: Iterable[str] = tuple()
    ) -> "AttributeRuler":
        """Load the attributeruler from disk.

        path (Union[Path, str]): A path to a directory.
        exclude (Iterable[str]): String names of serialization fields to exclude.
        RETURNS (AttributeRuler): The loaded object.

        DOCS: https://spacy.io/api/attributeruler#from_disk
        """
        # Patterns are collected first and only added to the matcher once
        # attrs and indices have been loaded as well.
        data = {"patterns": b""}

        def load_patterns(p):
            data["patterns"] = srsly.read_msgpack(p)

        def load_attrs(p):
            self._set_loaded_attrs(srsly.read_msgpack(p))

        def load_indices(p):
            self.indices = srsly.read_msgpack(p)

        deserialize = {
            "vocab": lambda p: self.vocab.from_disk(p),
            "patterns": load_patterns,
            "attrs": load_attrs,
            "indices": load_indices,
        }
        util.from_disk(path, deserialize, exclude)
        self._add_loaded_patterns(data["patterns"])
        return self
def _split_morph_attrs(attrs):
"""Split entries from a tag map or morph rules dict into to two dicts, one
with the token-level features (POS, LEMMA) and one with the remaining
features, which are presumed to be individual MORPH features."""
other_attrs = {}
morph_attrs = {}
for k, v in attrs.items():
if k in "_" or k in IDS.keys() or k in IDS.values():
other_attrs[k] = v
else:
morph_attrs[k] = v
return other_attrs, morph_attrs

View File

@ -1,13 +1,13 @@
# cython: infer_types=True, profile=True, binding=True # cython: infer_types=True, profile=True, binding=True
from typing import Optional, Iterable from typing import Optional, Iterable
from thinc.api import CosineDistance, to_categorical, get_array_module, Model, Config from thinc.api import Model, Config
from ..syntax.nn_parser cimport Parser from .transition_parser cimport Parser
from ..syntax.arc_eager cimport ArcEager from ._parser_internals.arc_eager cimport ArcEager
from .functions import merge_subtokens from .functions import merge_subtokens
from ..language import Language from ..language import Language
from ..syntax import nonproj from ._parser_internals import nonproj
from ..scorer import Scorer from ..scorer import Scorer
@ -34,7 +34,7 @@ DEFAULT_PARSER_MODEL = Config().from_str(default_model_config)["model"]
@Language.factory( @Language.factory(
"parser", "parser",
assigns=["token.dep", "token.is_sent_start", "doc.sents"], assigns=["token.dep", "token.head", "token.is_sent_start", "doc.sents"],
default_config={ default_config={
"moves": None, "moves": None,
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
@ -120,7 +120,8 @@ cdef class DependencyParser(Parser):
return dep return dep
results = {} results = {}
results.update(Scorer.score_spans(examples, "sents", **kwargs)) results.update(Scorer.score_spans(examples, "sents", **kwargs))
results.update(Scorer.score_deps(examples, "dep", getter=dep_getter, kwargs.setdefault("getter", dep_getter)
ignore_labels=("p", "punct"), **kwargs)) kwargs.setdefault("ignore_label", ("p", "punct"))
results.update(Scorer.score_deps(examples, "dep", **kwargs))
del results["sents_per_type"] del results["sents_per_type"]
return results return results

View File

@ -33,24 +33,31 @@ dropout = null
""" """
DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"] DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
default_kb_config = """
[kb]
@assets = "spacy.EmptyKB.v1"
entity_vector_length = 64
"""
DEFAULT_NEL_KB = Config().from_str(default_kb_config)["kb"]
@Language.factory( @Language.factory(
"entity_linker", "entity_linker",
requires=["doc.ents", "doc.sents", "token.ent_iob", "token.ent_type"], requires=["doc.ents", "doc.sents", "token.ent_iob", "token.ent_type"],
assigns=["token.ent_kb_id"], assigns=["token.ent_kb_id"],
default_config={ default_config={
"kb": None, # TODO - what kind of default makes sense here? "kb": DEFAULT_NEL_KB,
"model": DEFAULT_NEL_MODEL,
"labels_discard": [], "labels_discard": [],
"incl_prior": True, "incl_prior": True,
"incl_context": True, "incl_context": True,
"model": DEFAULT_NEL_MODEL,
}, },
) )
def make_entity_linker( def make_entity_linker(
nlp: Language, nlp: Language,
name: str, name: str,
model: Model, model: Model,
kb: Optional[KnowledgeBase], kb: KnowledgeBase,
*, *,
labels_discard: Iterable[str], labels_discard: Iterable[str],
incl_prior: bool, incl_prior: bool,
@ -92,10 +99,10 @@ class EntityLinker(Pipe):
model (thinc.api.Model): The Thinc Model powering the pipeline component. model (thinc.api.Model): The Thinc Model powering the pipeline component.
name (str): The component instance name, used to add entries to the name (str): The component instance name, used to add entries to the
losses during training. losses during training.
kb (KnowledgeBase): TODO: kb (KnowledgeBase): The KnowledgeBase holding all entities and their aliases.
labels_discard (Iterable[str]): TODO: labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction.
incl_prior (bool): TODO: incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
incl_context (bool): TODO: incl_context (bool): Whether or not to include the local context in the model.
DOCS: https://spacy.io/api/entitylinker#init DOCS: https://spacy.io/api/entitylinker#init
""" """
@ -108,14 +115,12 @@ class EntityLinker(Pipe):
"incl_prior": incl_prior, "incl_prior": incl_prior,
"incl_context": incl_context, "incl_context": incl_context,
} }
self.kb = kb if not isinstance(kb, KnowledgeBase):
if self.kb is None:
# create an empty KB that should be filled by calling from_disk
self.kb = KnowledgeBase(vocab=vocab)
else:
del cfg["kb"] # we don't want to duplicate its serialization
if not isinstance(self.kb, KnowledgeBase):
raise ValueError(Errors.E990.format(type=type(self.kb))) raise ValueError(Errors.E990.format(type=type(self.kb)))
kb.initialize(vocab)
self.kb = kb
if "kb" in cfg:
del cfg["kb"] # we don't want to duplicate its serialization
self.cfg = dict(cfg) self.cfg = dict(cfg)
self.distance = CosineDistance(normalize=False) self.distance = CosineDistance(normalize=False)
# how many neightbour sentences to take into account # how many neightbour sentences to take into account
@ -222,9 +227,9 @@ class EntityLinker(Pipe):
set_dropout_rate(self.model, drop) set_dropout_rate(self.model, drop)
if not sentence_docs: if not sentence_docs:
warnings.warn(Warnings.W093.format(name="Entity Linker")) warnings.warn(Warnings.W093.format(name="Entity Linker"))
return 0.0 return losses
sentence_encodings, bp_context = self.model.begin_update(sentence_docs) sentence_encodings, bp_context = self.model.begin_update(sentence_docs)
loss, d_scores = self.get_similarity_loss( loss, d_scores = self.get_loss(
sentence_encodings=sentence_encodings, examples=examples sentence_encodings=sentence_encodings, examples=examples
) )
bp_context(d_scores) bp_context(d_scores)
@ -235,7 +240,7 @@ class EntityLinker(Pipe):
self.set_annotations(docs, predictions) self.set_annotations(docs, predictions)
return losses return losses
def get_similarity_loss(self, examples: Iterable[Example], sentence_encodings): def get_loss(self, examples: Iterable[Example], sentence_encodings):
entity_encodings = [] entity_encodings = []
for eg in examples: for eg in examples:
kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True) kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
@ -247,7 +252,7 @@ class EntityLinker(Pipe):
entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32") entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
if sentence_encodings.shape != entity_encodings.shape: if sentence_encodings.shape != entity_encodings.shape:
err = Errors.E147.format( err = Errors.E147.format(
method="get_similarity_loss", msg="gold entities do not match up" method="get_loss", msg="gold entities do not match up"
) )
raise RuntimeError(err) raise RuntimeError(err)
gradients = self.distance.get_grad(sentence_encodings, entity_encodings) gradients = self.distance.get_grad(sentence_encodings, entity_encodings)
@ -337,13 +342,13 @@ class EntityLinker(Pipe):
final_kb_ids.append(candidates[0].entity_) final_kb_ids.append(candidates[0].entity_)
else: else:
random.shuffle(candidates) random.shuffle(candidates)
# this will set all prior probabilities to 0 if they should be excluded from the model # set all prior probabilities to 0 if incl_prior=False
prior_probs = xp.asarray( prior_probs = xp.asarray(
[c.prior_prob for c in candidates] [c.prior_prob for c in candidates]
) )
if not self.cfg.get("incl_prior"): if not self.cfg.get("incl_prior"):
prior_probs = xp.asarray( prior_probs = xp.asarray(
[0.0 for c in candidates] [0.0 for _ in candidates]
) )
scores = prior_probs scores = prior_probs
# add in similarity from the context # add in similarity from the context
@ -437,9 +442,8 @@ class EntityLinker(Pipe):
raise ValueError(Errors.E149) raise ValueError(Errors.E149)
def load_kb(p): def load_kb(p):
self.kb = KnowledgeBase( self.kb = KnowledgeBase(entity_vector_length=self.cfg["entity_width"])
vocab=self.vocab, entity_vector_length=self.cfg["entity_width"] self.kb.initialize(self.vocab)
)
self.kb.load_bulk(p) self.kb.load_bulk(p)
deserialize = {} deserialize = {}

View File

@ -1,7 +1,7 @@
# cython: infer_types=True, profile=True, binding=True # cython: infer_types=True, profile=True, binding=True
from typing import Optional from typing import Optional
import numpy import numpy
from thinc.api import CosineDistance, to_categorical, to_categorical, Model, Config from thinc.api import CosineDistance, to_categorical, Model, Config
from thinc.api import set_dropout_rate from thinc.api import set_dropout_rate
from ..tokens.doc cimport Doc from ..tokens.doc cimport Doc
@ -9,7 +9,7 @@ from ..tokens.doc cimport Doc
from .pipe import Pipe from .pipe import Pipe
from .tagger import Tagger from .tagger import Tagger
from ..language import Language from ..language import Language
from ..syntax import nonproj from ._parser_internals import nonproj
from ..attrs import POS, ID from ..attrs import POS, ID
from ..errors import Errors from ..errors import Errors
@ -219,3 +219,6 @@ class ClozeMultitask(Pipe):
if losses is not None: if losses is not None:
losses[self.name] += loss losses[self.name] += loss
def add_label(self, label):
raise NotImplementedError

View File

@ -1,9 +1,9 @@
# cython: infer_types=True, profile=True, binding=True # cython: infer_types=True, profile=True, binding=True
from typing import Optional, Iterable from typing import Optional, Iterable
from thinc.api import CosineDistance, to_categorical, get_array_module, Model, Config from thinc.api import Model, Config
from ..syntax.nn_parser cimport Parser from .transition_parser cimport Parser
from ..syntax.ner cimport BiluoPushDown from ._parser_internals.ner cimport BiluoPushDown
from ..language import Language from ..language import Language
from ..scorer import Scorer from ..scorer import Scorer

2
spacy/pipeline/pipe.pxd Normal file
View File

@ -0,0 +1,2 @@
cdef class Pipe:
cdef public str name

View File

@ -8,7 +8,7 @@ from ..errors import Errors
from .. import util from .. import util
class Pipe: cdef class Pipe:
"""This class is a base class and not instantiated directly. Trainable """This class is a base class and not instantiated directly. Trainable
pipeline components like the EntityRecognizer or TextCategorizer inherit pipeline components like the EntityRecognizer or TextCategorizer inherit
from it and it defines the interface that components should follow to from it and it defines the interface that components should follow to
@ -17,8 +17,6 @@ class Pipe:
DOCS: https://spacy.io/api/pipe DOCS: https://spacy.io/api/pipe
""" """
name = None
def __init__(self, vocab, model, name, **cfg): def __init__(self, vocab, model, name, **cfg):
"""Initialize a pipeline component. """Initialize a pipeline component.

View File

@ -203,3 +203,9 @@ class Sentencizer(Pipe):
cfg = srsly.read_json(path) cfg = srsly.read_json(path)
self.punct_chars = set(cfg.get("punct_chars", self.default_punct_chars)) self.punct_chars = set(cfg.get("punct_chars", self.default_punct_chars))
return self return self
def get_loss(self, examples, scores):
raise NotImplementedError
def add_label(self, label):
raise NotImplementedError

View File

@ -108,8 +108,8 @@ class SentenceRecognizer(Tagger):
truths = [] truths = []
for eg in examples: for eg in examples:
eg_truth = [] eg_truth = []
for x in eg.get_aligned("sent_start"): for x in eg.get_aligned("SENT_START"):
if x == None: if x is None:
eg_truth.append(None) eg_truth.append(None)
elif x == 1: elif x == 1:
eg_truth.append(labels[1]) eg_truth.append(labels[1])

View File

@ -4,12 +4,12 @@ from thinc.api import SequenceCategoricalCrossentropy, set_dropout_rate, Model
from thinc.api import Optimizer, Config from thinc.api import Optimizer, Config
from thinc.util import to_numpy from thinc.util import to_numpy
from ..errors import Errors
from ..gold import Example, spans_from_biluo_tags, iob_to_biluo, biluo_to_iob from ..gold import Example, spans_from_biluo_tags, iob_to_biluo, biluo_to_iob
from ..tokens import Doc from ..tokens import Doc
from ..language import Language from ..language import Language
from ..vocab import Vocab from ..vocab import Vocab
from ..scorer import Scorer from ..scorer import Scorer
from .. import util
from .pipe import Pipe from .pipe import Pipe
@ -37,7 +37,6 @@ DEFAULT_SIMPLE_NER_MODEL = Config().from_str(default_model_config)["model"]
default_config={"labels": [], "model": DEFAULT_SIMPLE_NER_MODEL}, default_config={"labels": [], "model": DEFAULT_SIMPLE_NER_MODEL},
scores=["ents_p", "ents_r", "ents_f", "ents_per_type"], scores=["ents_p", "ents_r", "ents_f", "ents_per_type"],
default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0}, default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0},
) )
def make_simple_ner( def make_simple_ner(
nlp: Language, name: str, model: Model, labels: Iterable[str] nlp: Language, name: str, model: Model, labels: Iterable[str]
@ -60,7 +59,9 @@ class SimpleNER(Pipe):
self.vocab = vocab self.vocab = vocab
self.model = model self.model = model
self.name = name self.name = name
self.labels = labels self.cfg = {"labels": []}
for label in labels:
self.add_label(label)
self.loss_func = SequenceCategoricalCrossentropy( self.loss_func = SequenceCategoricalCrossentropy(
names=self.get_tag_names(), normalize=True, missing_value=None names=self.get_tag_names(), normalize=True, missing_value=None
) )
@ -70,9 +71,20 @@ class SimpleNER(Pipe):
def is_biluo(self) -> bool: def is_biluo(self) -> bool:
return self.model.name.startswith("biluo") return self.model.name.startswith("biluo")
@property
def labels(self) -> Tuple[str]:
return tuple(self.cfg["labels"])
def add_label(self, label: str) -> None: def add_label(self, label: str) -> None:
"""Add a new label to the pipe.
label (str): The label to add.
DOCS: https://spacy.io/api/simplener#add_label
"""
if not isinstance(label, str):
raise ValueError(Errors.E187)
if label not in self.labels: if label not in self.labels:
self.labels.append(label) self.cfg["labels"].append(label)
self.vocab.strings.add(label)
def get_tag_names(self) -> List[str]: def get_tag_names(self) -> List[str]:
if self.is_biluo: if self.is_biluo:
@ -131,11 +143,9 @@ class SimpleNER(Pipe):
return losses return losses
def get_loss(self, examples: List[Example], scores) -> Tuple[List[Floats2d], float]: def get_loss(self, examples: List[Example], scores) -> Tuple[List[Floats2d], float]:
loss = 0
d_scores = []
truths = [] truths = []
for eg in examples: for eg in examples:
tags = eg.get_aligned("TAG", as_string=True) tags = eg.get_aligned_ner()
gold_tags = [(tag if tag != "-" else None) for tag in tags] gold_tags = [(tag if tag != "-" else None) for tag in tags]
if not self.is_biluo: if not self.is_biluo:
gold_tags = biluo_to_iob(gold_tags) gold_tags = biluo_to_iob(gold_tags)
@ -159,7 +169,6 @@ class SimpleNER(Pipe):
if not hasattr(get_examples, "__call__"): if not hasattr(get_examples, "__call__"):
gold_tuples = get_examples gold_tuples = get_examples
get_examples = lambda: gold_tuples get_examples = lambda: gold_tuples
labels = _get_labels(get_examples())
for label in _get_labels(get_examples()): for label in _get_labels(get_examples()):
self.add_label(label) self.add_label(label)
labels = self.labels labels = self.labels

View File

@ -259,7 +259,7 @@ class Tagger(Pipe):
DOCS: https://spacy.io/api/tagger#get_loss DOCS: https://spacy.io/api/tagger#get_loss
""" """
loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False) loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False)
truths = [eg.get_aligned("tag", as_string=True) for eg in examples] truths = [eg.get_aligned("TAG", as_string=True) for eg in examples]
d_scores, loss = loss_func(scores, truths) d_scores, loss = loss_func(scores, truths)
if self.model.ops.xp.isnan(loss): if self.model.ops.xp.isnan(loss):
raise ValueError("nan value when computing loss") raise ValueError("nan value when computing loss")

View File

@ -238,8 +238,11 @@ class TextCategorizer(Pipe):
DOCS: https://spacy.io/api/textcategorizer#rehearse DOCS: https://spacy.io/api/textcategorizer#rehearse
""" """
if losses is not None:
losses.setdefault(self.name, 0.0)
if self._rehearsal_model is None: if self._rehearsal_model is None:
return return losses
try: try:
docs = [eg.predicted for eg in examples] docs = [eg.predicted for eg in examples]
except AttributeError: except AttributeError:
@ -250,7 +253,7 @@ class TextCategorizer(Pipe):
raise TypeError(err) raise TypeError(err)
if not any(len(doc) for doc in docs): if not any(len(doc) for doc in docs):
# Handle cases where there are no tokens in any docs. # Handle cases where there are no tokens in any docs.
return return losses
set_dropout_rate(self.model, drop) set_dropout_rate(self.model, drop)
scores, bp_scores = self.model.begin_update(docs) scores, bp_scores = self.model.begin_update(docs)
target = self._rehearsal_model(examples) target = self._rehearsal_model(examples)
@ -259,7 +262,6 @@ class TextCategorizer(Pipe):
if sgd is not None: if sgd is not None:
self.model.finish_update(sgd) self.model.finish_update(sgd)
if losses is not None: if losses is not None:
losses.setdefault(self.name, 0.0)
losses[self.name] += (gradient ** 2).sum() losses[self.name] += (gradient ** 2).sum()
return losses return losses
@ -353,7 +355,7 @@ class TextCategorizer(Pipe):
for cat in y.cats: for cat in y.cats:
self.add_label(cat) self.add_label(cat)
self.require_labels() self.require_labels()
docs = [Doc(Vocab(), words=["hello"])] docs = [Doc(self.vocab, words=["hello"])]
truths, _ = self._examples_to_truth(examples) truths, _ = self._examples_to_truth(examples)
self.set_output(len(self.labels)) self.set_output(len(self.labels))
self.model.initialize(X=docs, Y=truths) self.model.initialize(X=docs, Y=truths)

View File

@ -199,6 +199,9 @@ class Tok2Vec(Pipe):
docs = [Doc(self.vocab, words=["hello"])] docs = [Doc(self.vocab, words=["hello"])]
self.model.initialize(X=docs) self.model.initialize(X=docs)
def add_label(self, label):
raise NotImplementedError
class Tok2VecListener(Model): class Tok2VecListener(Model):
"""A layer that gets fed its answers from an upstream connection, """A layer that gets fed its answers from an upstream connection,

View File

@ -1,16 +1,15 @@
from .stateclass cimport StateClass from cymem.cymem cimport Pool
from .arc_eager cimport TransitionSystem
from ..vocab cimport Vocab from ..vocab cimport Vocab
from ..tokens.doc cimport Doc from .pipe cimport Pipe
from ..structs cimport TokenC from ._parser_internals.transition_system cimport Transition, TransitionSystem
from ._state cimport StateC from ._parser_internals._state cimport StateC
from ._parser_model cimport WeightsC, ActivationsC, SizesC from ..ml.parser_model cimport WeightsC, ActivationsC, SizesC
cdef class Parser: cdef class Parser(Pipe):
cdef readonly Vocab vocab cdef readonly Vocab vocab
cdef public object model cdef public object model
cdef public str name
cdef public object _rehearsal_model cdef public object _rehearsal_model
cdef readonly TransitionSystem moves cdef readonly TransitionSystem moves
cdef readonly object cfg cdef readonly object cfg

View File

@ -1,42 +1,32 @@
# cython: infer_types=True, cdivision=True, boundscheck=False # cython: infer_types=True, cdivision=True, boundscheck=False
cimport cython.parallel from __future__ import print_function
from cymem.cymem cimport Pool
cimport numpy as np cimport numpy as np
from itertools import islice from itertools import islice
from cpython.ref cimport PyObject, Py_XDECREF
from cpython.exc cimport PyErr_CheckSignals, PyErr_SetFromErrno
from libc.math cimport exp
from libcpp.vector cimport vector from libcpp.vector cimport vector
from libc.string cimport memset, memcpy from libc.string cimport memset
from libc.stdlib cimport calloc, free from libc.stdlib cimport calloc, free
from cymem.cymem cimport Pool
from thinc.backends.linalg cimport Vec, VecVec
from thinc.api import chain, clone, Linear, list2array, NumpyOps, CupyOps, use_ops
from thinc.api import get_array_module, zero_init, set_dropout_rate
from itertools import islice
import srsly import srsly
from ._parser_internals.stateclass cimport StateClass
from ..ml.parser_model cimport alloc_activations, free_activations
from ..ml.parser_model cimport predict_states, arg_max_if_valid
from ..ml.parser_model cimport WeightsC, ActivationsC, SizesC, cpu_log_loss
from ..ml.parser_model cimport get_c_weights, get_c_sizes
from ..tokens.doc cimport Doc
from ..errors import Errors, Warnings
from .. import util
from ..util import create_default_optimizer
from thinc.api import set_dropout_rate
import numpy.random import numpy.random
import numpy import numpy
import warnings import warnings
from ..tokens.doc cimport Doc
from ..typedefs cimport weight_t, class_t, hash_t
from ._parser_model cimport alloc_activations, free_activations
from ._parser_model cimport predict_states, arg_max_if_valid
from ._parser_model cimport WeightsC, ActivationsC, SizesC, cpu_log_loss
from ._parser_model cimport get_c_weights, get_c_sizes
from .stateclass cimport StateClass
from ._state cimport StateC
from .transition_system cimport Transition
from ..util import create_default_optimizer, registry cdef class Parser(Pipe):
from ..compat import copy_array
from ..errors import Errors, Warnings
from .. import util
from . import nonproj
cdef class Parser:
""" """
Base class of the DependencyParser and EntityRecognizer. Base class of the DependencyParser and EntityRecognizer.
""" """
@ -107,7 +97,7 @@ cdef class Parser:
@property @property
def tok2vec(self): def tok2vec(self):
'''Return the embedding and convolutional layer of the model.''' """Return the embedding and convolutional layer of the model."""
return self.model.get_ref("tok2vec") return self.model.get_ref("tok2vec")
@property @property
@ -138,13 +128,13 @@ cdef class Parser:
raise NotImplementedError raise NotImplementedError
def init_multitask_objectives(self, get_examples, pipeline, **cfg): def init_multitask_objectives(self, get_examples, pipeline, **cfg):
'''Setup models for secondary objectives, to benefit from multi-task """Setup models for secondary objectives, to benefit from multi-task
learning. This method is intended to be overridden by subclasses. learning. This method is intended to be overridden by subclasses.
For instance, the dependency parser can benefit from sharing For instance, the dependency parser can benefit from sharing
an input representation with a label prediction model. These auxiliary an input representation with a label prediction model. These auxiliary
models are discarded after training. models are discarded after training.
''' """
pass pass
def use_params(self, params): def use_params(self, params):

View File

@ -1,4 +1,5 @@
from typing import Dict, List, Union, Optional, Sequence, Any, Callable, Type from typing import Dict, List, Union, Optional, Sequence, Any, Callable, Type
from typing import Iterable, TypeVar, TYPE_CHECKING
from enum import Enum from enum import Enum
from pydantic import BaseModel, Field, ValidationError, validator from pydantic import BaseModel, Field, ValidationError, validator
from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool
@ -8,6 +9,16 @@ from thinc.api import Optimizer
from .attrs import NAMES from .attrs import NAMES
if TYPE_CHECKING:
# This lets us add type hints for mypy etc. without causing circular imports
from .language import Language # noqa: F401
from .gold import Example # noqa: F401
ItemT = TypeVar("ItemT")
Batcher = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]
Reader = Callable[["Language", str], Iterable["Example"]]
def validate(schema: Type[BaseModel], obj: Dict[str, Any]) -> List[str]: def validate(schema: Type[BaseModel], obj: Dict[str, Any]) -> List[str]:
"""Validate data against a given pydantic schema. """Validate data against a given pydantic schema.
@ -181,30 +192,22 @@ class ModelMetaSchema(BaseModel):
class ConfigSchemaTraining(BaseModel): class ConfigSchemaTraining(BaseModel):
# fmt: off # fmt: off
base_model: Optional[StrictStr] = Field(..., title="The base model to use")
vectors: Optional[StrictStr] = Field(..., title="Path to vectors") vectors: Optional[StrictStr] = Field(..., title="Path to vectors")
gold_preproc: StrictBool = Field(..., title="Whether to train on gold-standard sentences and tokens") train_corpus: Reader = Field(..., title="Reader for the training data")
max_length: StrictInt = Field(..., title="Maximum length of examples (longer examples are divided into sentences if possible)") dev_corpus: Reader = Field(..., title="Reader for the dev data")
limit: StrictInt = Field(..., title="Number of examples to use (0 for all)") batcher: Batcher = Field(..., title="Batcher for the training data")
orth_variant_level: StrictFloat = Field(..., title="Orth variants for data augmentation")
dropout: StrictFloat = Field(..., title="Dropout rate") dropout: StrictFloat = Field(..., title="Dropout rate")
patience: StrictInt = Field(..., title="How many steps to continue without improvement in evaluation score") patience: StrictInt = Field(..., title="How many steps to continue without improvement in evaluation score")
max_epochs: StrictInt = Field(..., title="Maximum number of epochs to train for") max_epochs: StrictInt = Field(..., title="Maximum number of epochs to train for")
max_steps: StrictInt = Field(..., title="Maximum number of update steps to train for") max_steps: StrictInt = Field(..., title="Maximum number of update steps to train for")
eval_frequency: StrictInt = Field(..., title="How often to evaluate during training (steps)") eval_frequency: StrictInt = Field(..., title="How often to evaluate during training (steps)")
eval_batch_size: StrictInt = Field(..., title="Evaluation batch size")
seed: Optional[StrictInt] = Field(..., title="Random seed") seed: Optional[StrictInt] = Field(..., title="Random seed")
accumulate_gradient: StrictInt = Field(..., title="Whether to divide the batch up into substeps") accumulate_gradient: StrictInt = Field(..., title="Whether to divide the batch up into substeps")
use_pytorch_for_gpu_memory: StrictBool = Field(..., title="Allocate memory via PyTorch")
score_weights: Dict[StrictStr, Union[StrictFloat, StrictInt]] = Field(..., title="Scores to report and their weights for selecting final model") score_weights: Dict[StrictStr, Union[StrictFloat, StrictInt]] = Field(..., title="Scores to report and their weights for selecting final model")
init_tok2vec: Optional[StrictStr] = Field(..., title="Path to pretrained tok2vec weights") init_tok2vec: Optional[StrictStr] = Field(..., title="Path to pretrained tok2vec weights")
discard_oversize: StrictBool = Field(..., title="Whether to skip examples longer than batch size") raw_text: Optional[StrictStr] = Field(default=None, title="Raw text")
batch_by: StrictStr = Field(..., title="Batch examples by type")
raw_text: Optional[StrictStr] = Field(..., title="Raw text")
tag_map: Optional[StrictStr] = Field(..., title="Path to JSON-formatted tag map")
morph_rules: Optional[StrictStr] = Field(..., title="Path to morphology rules")
batch_size: Union[Sequence[int], int] = Field(..., title="The batch size or batch size schedule")
optimizer: Optimizer = Field(..., title="The optimizer to use") optimizer: Optimizer = Field(..., title="The optimizer to use")
frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training")
# fmt: on # fmt: on
class Config: class Config:
@ -219,6 +222,9 @@ class ConfigSchemaNlp(BaseModel):
tokenizer: Callable = Field(..., title="The tokenizer to use") tokenizer: Callable = Field(..., title="The tokenizer to use")
lemmatizer: Callable = Field(..., title="The lemmatizer to use") lemmatizer: Callable = Field(..., title="The lemmatizer to use")
load_vocab_data: StrictBool = Field(..., title="Whether to load additional vocab data from spacy-lookups-data") load_vocab_data: StrictBool = Field(..., title="Whether to load additional vocab data from spacy-lookups-data")
before_creation: Optional[Callable[[Type["Language"]], Type["Language"]]] = Field(..., title="Optional callback to modify Language class before initialization")
after_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after creation and before the pipeline is constructed")
after_pipeline_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after the pipeline is constructed")
# fmt: on # fmt: on
class Config: class Config:

View File

@ -1,55 +1,61 @@
from typing import Optional, Iterable, Dict, Any, Callable, Tuple, TYPE_CHECKING
import numpy as np import numpy as np
from .gold import Example
from .tokens import Token, Doc
from .errors import Errors from .errors import Errors
from .util import get_lang_class from .util import get_lang_class
from .morphology import Morphology from .morphology import Morphology
if TYPE_CHECKING:
# This lets us add type hints for mypy etc. without causing circular imports
from .language import Language # noqa: F401
DEFAULT_PIPELINE = ["senter", "tagger", "morphologizer", "parser", "ner", "textcat"]
class PRFScore: class PRFScore:
""" """A precision / recall / F score."""
A precision / recall / F score
"""
def __init__(self): def __init__(self) -> None:
self.tp = 0 self.tp = 0
self.fp = 0 self.fp = 0
self.fn = 0 self.fn = 0
def score_set(self, cand, gold): def score_set(self, cand: set, gold: set) -> None:
self.tp += len(cand.intersection(gold)) self.tp += len(cand.intersection(gold))
self.fp += len(cand - gold) self.fp += len(cand - gold)
self.fn += len(gold - cand) self.fn += len(gold - cand)
@property @property
def precision(self): def precision(self) -> float:
return self.tp / (self.tp + self.fp + 1e-100) return self.tp / (self.tp + self.fp + 1e-100)
@property @property
def recall(self): def recall(self) -> float:
return self.tp / (self.tp + self.fn + 1e-100) return self.tp / (self.tp + self.fn + 1e-100)
@property @property
def fscore(self): def fscore(self) -> float:
p = self.precision p = self.precision
r = self.recall r = self.recall
return 2 * ((p * r) / (p + r + 1e-100)) return 2 * ((p * r) / (p + r + 1e-100))
def to_dict(self): def to_dict(self) -> Dict[str, float]:
return {"p": self.precision, "r": self.recall, "f": self.fscore} return {"p": self.precision, "r": self.recall, "f": self.fscore}
class ROCAUCScore: class ROCAUCScore:
""" """An AUC ROC score."""
An AUC ROC score.
"""
def __init__(self): def __init__(self) -> None:
self.golds = [] self.golds = []
self.cands = [] self.cands = []
self.saved_score = 0.0 self.saved_score = 0.0
self.saved_score_at_len = 0 self.saved_score_at_len = 0
def score_set(self, cand, gold): def score_set(self, cand, gold) -> None:
self.cands.append(cand) self.cands.append(cand)
self.golds.append(gold) self.golds.append(gold)
@ -70,51 +76,52 @@ class ROCAUCScore:
class Scorer: class Scorer:
"""Compute evaluation scores.""" """Compute evaluation scores."""
def __init__(self, nlp=None, **cfg): def __init__(
self,
nlp: Optional["Language"] = None,
default_lang: str = "xx",
default_pipeline=DEFAULT_PIPELINE,
**cfg,
) -> None:
"""Initialize the Scorer. """Initialize the Scorer.
DOCS: https://spacy.io/api/scorer#init DOCS: https://spacy.io/api/scorer#init
""" """
self.nlp = nlp self.nlp = nlp
self.cfg = cfg self.cfg = cfg
if not nlp: if not nlp:
# create a default pipeline nlp = get_lang_class(default_lang)()
nlp = get_lang_class("xx")() for pipe in default_pipeline:
nlp.add_pipe("senter") nlp.add_pipe(pipe)
nlp.add_pipe("tagger")
nlp.add_pipe("morphologizer")
nlp.add_pipe("parser")
nlp.add_pipe("ner")
nlp.add_pipe("textcat")
self.nlp = nlp self.nlp = nlp
def score(self, examples): def score(self, examples: Iterable[Example]) -> Dict[str, Any]:
"""Evaluate a list of Examples. """Evaluate a list of Examples.
examples (Iterable[Example]): The predicted annotations + correct annotations. examples (Iterable[Example]): The predicted annotations + correct annotations.
RETURNS (Dict): A dictionary of scores. RETURNS (Dict): A dictionary of scores.
DOCS: https://spacy.io/api/scorer#score DOCS: https://spacy.io/api/scorer#score
""" """
scores = {} scores = {}
if hasattr(self.nlp.tokenizer, "score"): if hasattr(self.nlp.tokenizer, "score"):
scores.update(self.nlp.tokenizer.score(examples, **self.cfg)) scores.update(self.nlp.tokenizer.score(examples, **self.cfg))
for name, component in self.nlp.pipeline: for name, component in self.nlp.pipeline:
if hasattr(component, "score"): if hasattr(component, "score"):
scores.update(component.score(examples, **self.cfg)) scores.update(component.score(examples, **self.cfg))
return scores return scores
@staticmethod @staticmethod
def score_tokenization(examples, **cfg): def score_tokenization(examples: Iterable[Example], **cfg) -> Dict[str, float]:
"""Returns accuracy and PRF scores for tokenization. """Returns accuracy and PRF scores for tokenization.
* token_acc: # correct tokens / # gold tokens * token_acc: # correct tokens / # gold tokens
* token_p/r/f: PRF for token character spans * token_p/r/f: PRF for token character spans
examples (Iterable[Example]): Examples to score examples (Iterable[Example]): Examples to score
RETURNS (dict): A dictionary containing the scores token_acc/p/r/f. RETURNS (Dict[str, float]): A dictionary containing the scores
token_acc/p/r/f.
DOCS: https://spacy.io/api/scorer#score_tokenization
""" """
acc_score = PRFScore() acc_score = PRFScore()
prf_score = PRFScore() prf_score = PRFScore()
@ -145,16 +152,24 @@ class Scorer:
} }
@staticmethod @staticmethod
def score_token_attr(examples, attr, getter=getattr, **cfg): def score_token_attr(
examples: Iterable[Example],
attr: str,
*,
getter: Callable[[Token, str], Any] = getattr,
**cfg,
) -> Dict[str, float]:
"""Returns an accuracy score for a token-level attribute. """Returns an accuracy score for a token-level attribute.
examples (Iterable[Example]): Examples to score examples (Iterable[Example]): Examples to score
attr (str): The attribute to score. attr (str): The attribute to score.
getter (callable): Defaults to getattr. If provided, getter (Callable[[Token, str], Any]): Defaults to getattr. If provided,
getter(token, attr) should return the value of the attribute for an getter(token, attr) should return the value of the attribute for an
individual token. individual token.
RETURNS (dict): A dictionary containing the accuracy score under the RETURNS (Dict[str, float]): A dictionary containing the accuracy score
key attr_acc. under the key attr_acc.
DOCS: https://spacy.io/api/scorer#score_token_attr
""" """
tag_score = PRFScore() tag_score = PRFScore()
for example in examples: for example in examples:
@ -172,17 +187,21 @@ class Scorer:
gold_i = align.x2y[token.i].dataXd[0, 0] gold_i = align.x2y[token.i].dataXd[0, 0]
pred_tags.add((gold_i, getter(token, attr))) pred_tags.add((gold_i, getter(token, attr)))
tag_score.score_set(pred_tags, gold_tags) tag_score.score_set(pred_tags, gold_tags)
return { return {f"{attr}_acc": tag_score.fscore}
attr + "_acc": tag_score.fscore,
}
@staticmethod @staticmethod
def score_token_attr_per_feat(examples, attr, getter=getattr, **cfg): def score_token_attr_per_feat(
examples: Iterable[Example],
attr: str,
*,
getter: Callable[[Token, str], Any] = getattr,
**cfg,
):
"""Return PRF scores per feat for a token attribute in UFEATS format. """Return PRF scores per feat for a token attribute in UFEATS format.
examples (Iterable[Example]): Examples to score examples (Iterable[Example]): Examples to score
attr (str): The attribute to score. attr (str): The attribute to score.
getter (callable): Defaults to getattr. If provided, getter (Callable[[Token, str], Any]): Defaults to getattr. If provided,
getter(token, attr) should return the value of the attribute for an getter(token, attr) should return the value of the attribute for an
individual token. individual token.
RETURNS (dict): A dictionary containing the per-feat PRF scores under RETURNS (dict): A dictionary containing the per-feat PRF scores under
@ -223,20 +242,26 @@ class Scorer:
per_feat[field].score_set( per_feat[field].score_set(
pred_per_feat.get(field, set()), gold_per_feat.get(field, set()), pred_per_feat.get(field, set()), gold_per_feat.get(field, set()),
) )
return { return {f"{attr}_per_feat": per_feat}
attr + "_per_feat": per_feat,
}
@staticmethod @staticmethod
def score_spans(examples, attr, getter=getattr, **cfg): def score_spans(
examples: Iterable[Example],
attr: str,
*,
getter: Callable[[Doc, str], Any] = getattr,
**cfg,
) -> Dict[str, Any]:
"""Returns PRF scores for labeled spans. """Returns PRF scores for labeled spans.
examples (Iterable[Example]): Examples to score examples (Iterable[Example]): Examples to score
attr (str): The attribute to score. attr (str): The attribute to score.
getter (callable): Defaults to getattr. If provided, getter (Callable[[Doc, str], Any]): Defaults to getattr. If provided,
getter(doc, attr) should return the spans for the individual doc. getter(doc, attr) should return the spans for the individual doc.
RETURNS (dict): A dictionary containing the PRF scores under the RETURNS (Dict[str, Any]): A dictionary containing the PRF scores under
keys attr_p/r/f and the per-type PRF scores under attr_per_type. the keys attr_p/r/f and the per-type PRF scores under attr_per_type.
DOCS: https://spacy.io/api/scorer#score_spans
""" """
score = PRFScore() score = PRFScore()
score_per_type = dict() score_per_type = dict()
@ -256,14 +281,12 @@ class Scorer:
# Find all predidate labels, for all and per type # Find all predidate labels, for all and per type
gold_spans = set() gold_spans = set()
pred_spans = set() pred_spans = set()
# Special case for ents: # Special case for ents:
# If we have missing values in the gold, we can't easily tell # If we have missing values in the gold, we can't easily tell
# whether our NER predictions are true. # whether our NER predictions are true.
# It seems bad but it's what we've always done. # It seems bad but it's what we've always done.
if attr == "ents" and not all(token.ent_iob != 0 for token in gold_doc): if attr == "ents" and not all(token.ent_iob != 0 for token in gold_doc):
continue continue
for span in getter(gold_doc, attr): for span in getter(gold_doc, attr):
gold_span = (span.label_, span.start, span.end - 1) gold_span = (span.label_, span.start, span.end - 1)
gold_spans.add(gold_span) gold_spans.add(gold_span)
@ -279,38 +302,39 @@ class Scorer:
# Score for all labels # Score for all labels
score.score_set(pred_spans, gold_spans) score.score_set(pred_spans, gold_spans)
results = { results = {
attr + "_p": score.precision, f"{attr}_p": score.precision,
attr + "_r": score.recall, f"{attr}_r": score.recall,
attr + "_f": score.fscore, f"{attr}_f": score.fscore,
attr + "_per_type": {k: v.to_dict() for k, v in score_per_type.items()}, f"{attr}_per_type": {k: v.to_dict() for k, v in score_per_type.items()},
} }
return results return results
@staticmethod @staticmethod
def score_cats( def score_cats(
examples, examples: Iterable[Example],
attr, attr: str,
getter=getattr, *,
labels=[], getter: Callable[[Doc, str], Any] = getattr,
multi_label=True, labels: Iterable[str] = tuple(),
positive_label=None, multi_label: bool = True,
**cfg positive_label: Optional[str] = None,
): **cfg,
) -> Dict[str, Any]:
"""Returns PRF and ROC AUC scores for a doc-level attribute with a """Returns PRF and ROC AUC scores for a doc-level attribute with a
dict with scores for each label like Doc.cats. The reported overall dict with scores for each label like Doc.cats. The reported overall
score depends on the scorer settings. score depends on the scorer settings.
examples (Iterable[Example]): Examples to score examples (Iterable[Example]): Examples to score
attr (str): The attribute to score. attr (str): The attribute to score.
getter (callable): Defaults to getattr. If provided, getter (Callable[[Doc, str], Any]): Defaults to getattr. If provided,
getter(doc, attr) should return the values for the individual doc. getter(doc, attr) should return the values for the individual doc.
labels (Iterable[str]): The set of possible labels. Defaults to []. labels (Iterable[str]): The set of possible labels. Defaults to [].
multi_label (bool): Whether the attribute allows multiple labels. multi_label (bool): Whether the attribute allows multiple labels.
Defaults to True. Defaults to True.
positive_label (str): The positive label for a binary task with positive_label (str): The positive label for a binary task with
exclusive classes. Defaults to None. exclusive classes. Defaults to None.
RETURNS (dict): A dictionary containing the scores, with inapplicable RETURNS (Dict[str, Any]): A dictionary containing the scores, with
scores as None: inapplicable scores as None:
for all: for all:
attr_score (one of attr_f / attr_macro_f / attr_macro_auc), attr_score (one of attr_f / attr_macro_f / attr_macro_auc),
attr_score_desc (text description of the overall score), attr_score_desc (text description of the overall score),
@ -319,6 +343,8 @@ class Scorer:
for binary exclusive with positive label: attr_p/r/f for binary exclusive with positive label: attr_p/r/f
for 3+ exclusive classes, macro-averaged fscore: attr_macro_f for 3+ exclusive classes, macro-averaged fscore: attr_macro_f
for multilabel, macro-averaged AUC: attr_macro_auc for multilabel, macro-averaged AUC: attr_macro_auc
DOCS: https://spacy.io/api/scorer#score_cats
""" """
score = PRFScore() score = PRFScore()
f_per_type = dict() f_per_type = dict()
@ -367,64 +393,67 @@ class Scorer:
) )
) )
results = { results = {
attr + "_score": None, f"{attr}_score": None,
attr + "_score_desc": None, f"{attr}_score_desc": None,
attr + "_p": None, f"{attr}_p": None,
attr + "_r": None, f"{attr}_r": None,
attr + "_f": None, f"{attr}_f": None,
attr + "_macro_f": None, f"{attr}_macro_f": None,
attr + "_macro_auc": None, f"{attr}_macro_auc": None,
attr + "_f_per_type": {k: v.to_dict() for k, v in f_per_type.items()}, f"{attr}_f_per_type": {k: v.to_dict() for k, v in f_per_type.items()},
attr + "_auc_per_type": {k: v.score for k, v in auc_per_type.items()}, f"{attr}_auc_per_type": {k: v.score for k, v in auc_per_type.items()},
} }
if len(labels) == 2 and not multi_label and positive_label: if len(labels) == 2 and not multi_label and positive_label:
results[attr + "_p"] = score.precision results[f"{attr}_p"] = score.precision
results[attr + "_r"] = score.recall results[f"{attr}_r"] = score.recall
results[attr + "_f"] = score.fscore results[f"{attr}_f"] = score.fscore
results[attr + "_score"] = results[attr + "_f"] results[f"{attr}_score"] = results[f"{attr}_f"]
results[attr + "_score_desc"] = "F (" + positive_label + ")" results[f"{attr}_score_desc"] = f"F ({positive_label})"
elif not multi_label: elif not multi_label:
results[attr + "_macro_f"] = sum( results[f"{attr}_macro_f"] = sum(
[score.fscore for label, score in f_per_type.items()] [score.fscore for label, score in f_per_type.items()]
) / (len(f_per_type) + 1e-100) ) / (len(f_per_type) + 1e-100)
results[attr + "_score"] = results[attr + "_macro_f"] results[f"{attr}_score"] = results[f"{attr}_macro_f"]
results[attr + "_score_desc"] = "macro F" results[f"{attr}_score_desc"] = "macro F"
else: else:
results[attr + "_macro_auc"] = max( results[f"{attr}_macro_auc"] = max(
sum([score.score for label, score in auc_per_type.items()]) sum([score.score for label, score in auc_per_type.items()])
/ (len(auc_per_type) + 1e-100), / (len(auc_per_type) + 1e-100),
-1, -1,
) )
results[attr + "_score"] = results[attr + "_macro_auc"] results[f"{attr}_score"] = results[f"{attr}_macro_auc"]
results[attr + "_score_desc"] = "macro AUC" results[f"{attr}_score_desc"] = "macro AUC"
return results return results
@staticmethod @staticmethod
def score_deps( def score_deps(
examples, examples: Iterable[Example],
attr, attr: str,
getter=getattr, *,
head_attr="head", getter: Callable[[Token, str], Any] = getattr,
head_getter=getattr, head_attr: str = "head",
ignore_labels=tuple(), head_getter: Callable[[Token, str], Any] = getattr,
**cfg ignore_labels: Tuple[str] = tuple(),
): **cfg,
) -> Dict[str, Any]:
"""Returns the UAS, LAS, and LAS per type scores for dependency """Returns the UAS, LAS, and LAS per type scores for dependency
parses. parses.
examples (Iterable[Example]): Examples to score examples (Iterable[Example]): Examples to score
attr (str): The attribute containing the dependency label. attr (str): The attribute containing the dependency label.
getter (callable): Defaults to getattr. If provided, getter (Callable[[Token, str], Any]): Defaults to getattr. If provided,
getter(token, attr) should return the value of the attribute for an getter(token, attr) should return the value of the attribute for an
individual token. individual token.
head_attr (str): The attribute containing the head token. Defaults to head_attr (str): The attribute containing the head token. Defaults to
'head'. 'head'.
head_getter (callable): Defaults to getattr. If provided, head_getter (Callable[[Token, str], Any]): Defaults to getattr. If provided,
head_getter(token, attr) should return the value of the head for an head_getter(token, attr) should return the value of the head for an
individual token. individual token.
ignore_labels (Tuple): Labels to ignore while scoring (e.g., punct). ignore_labels (Tuple): Labels to ignore while scoring (e.g., punct).
RETURNS (dict): A dictionary containing the scores: RETURNS (Dict[str, Any]): A dictionary containing the scores:
attr_uas, attr_las, and attr_las_per_type. attr_uas, attr_las, and attr_las_per_type.
DOCS: https://spacy.io/api/scorer#score_deps
""" """
unlabelled = PRFScore() unlabelled = PRFScore()
labelled = PRFScore() labelled = PRFScore()
@ -482,10 +511,11 @@ class Scorer:
set(item[:2] for item in pred_deps), set(item[:2] for item in gold_deps) set(item[:2] for item in pred_deps), set(item[:2] for item in gold_deps)
) )
return { return {
attr + "_uas": unlabelled.fscore, f"{attr}_uas": unlabelled.fscore,
attr + "_las": labelled.fscore, f"{attr}_las": labelled.fscore,
attr f"{attr}_las_per_type": {
+ "_las_per_type": {k: v.to_dict() for k, v in labelled_per_dep.items()}, k: v.to_dict() for k, v in labelled_per_dep.items()
},
} }

View File

@ -282,3 +282,15 @@ def test_span_eq_hash(doc, doc_not_parsed):
assert hash(doc[0:2]) == hash(doc[0:2]) assert hash(doc[0:2]) == hash(doc[0:2])
assert hash(doc[0:2]) != hash(doc[1:3]) assert hash(doc[0:2]) != hash(doc[1:3])
assert hash(doc[0:2]) != hash(doc_not_parsed[0:2]) assert hash(doc[0:2]) != hash(doc_not_parsed[0:2])
def test_span_boundaries(doc):
    """Check that span indexing is relative to the span start and that
    indices outside the span raise IndexError."""
    first, last = 1, 5
    span = doc[first:last]
    # Every in-range offset must map to the matching token of the doc.
    for offset in range(last - first):
        assert span[offset] == doc[first + offset]
    # The span holds four tokens, so -5 and 5 are both out of range.
    for bad_index in (-5, 5):
        with pytest.raises(IndexError):
            span[bad_index]

View File

@ -29,9 +29,7 @@ def test_zh_tokenizer_serialize_jieba(zh_tokenizer_jieba):
def test_zh_tokenizer_serialize_pkuseg_with_processors(zh_tokenizer_pkuseg): def test_zh_tokenizer_serialize_pkuseg_with_processors(zh_tokenizer_pkuseg):
nlp = Chinese( nlp = Chinese(
meta={ meta={
"tokenizer": { "tokenizer": {"config": {"segmenter": "pkuseg", "pkuseg_model": "medicine"}}
"config": {"segmenter": "pkuseg", "pkuseg_model": "medicine",}
}
} }
) )
zh_tokenizer_serialize(nlp.tokenizer) zh_tokenizer_serialize(nlp.tokenizer)

View File

@ -21,7 +21,7 @@ re_pattern5 = "B*A*B"
longest1 = "A A A A A" longest1 = "A A A A A"
longest2 = "A A A A A" longest2 = "A A A A A"
longest3 = "A A" longest3 = "A A"
longest4 = "B A A A A A B" # "FIRST" would be "B B" longest4 = "B A A A A A B" # "FIRST" would be "B B"
longest5 = "B B A A A A A B" longest5 = "B B A A A A A B"

View File

@ -4,8 +4,8 @@ from spacy import registry
from spacy.gold import Example from spacy.gold import Example
from spacy.pipeline import DependencyParser from spacy.pipeline import DependencyParser
from spacy.tokens import Doc from spacy.tokens import Doc
from spacy.syntax.nonproj import projectivize from spacy.pipeline._parser_internals.nonproj import projectivize
from spacy.syntax.arc_eager import ArcEager from spacy.pipeline._parser_internals.arc_eager import ArcEager
from spacy.pipeline.dep_parser import DEFAULT_PARSER_MODEL from spacy.pipeline.dep_parser import DEFAULT_PARSER_MODEL

View File

@ -5,7 +5,7 @@ from spacy.lang.en import English
from spacy.language import Language from spacy.language import Language
from spacy.lookups import Lookups from spacy.lookups import Lookups
from spacy.syntax.ner import BiluoPushDown from spacy.pipeline._parser_internals.ner import BiluoPushDown
from spacy.gold import Example from spacy.gold import Example
from spacy.tokens import Doc from spacy.tokens import Doc
from spacy.vocab import Vocab from spacy.vocab import Vocab
@ -210,7 +210,7 @@ def test_train_empty():
nlp.begin_training() nlp.begin_training()
for itn in range(2): for itn in range(2):
losses = {} losses = {}
batches = util.minibatch(train_examples) batches = util.minibatch(train_examples, size=8)
for batch in batches: for batch in batches:
nlp.update(batch, losses=losses) nlp.update(batch, losses=losses)

View File

@ -3,8 +3,8 @@ import pytest
from spacy import registry from spacy import registry
from spacy.gold import Example from spacy.gold import Example
from spacy.vocab import Vocab from spacy.vocab import Vocab
from spacy.syntax.arc_eager import ArcEager from spacy.pipeline._parser_internals.arc_eager import ArcEager
from spacy.syntax.nn_parser import Parser from spacy.pipeline.transition_parser import Parser
from spacy.tokens.doc import Doc from spacy.tokens.doc import Doc
from thinc.api import Model from thinc.api import Model
from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL

View File

@ -1,7 +1,7 @@
import pytest import pytest
from spacy.syntax.nonproj import ancestors, contains_cycle, is_nonproj_arc from spacy.pipeline._parser_internals.nonproj import ancestors, contains_cycle
from spacy.syntax.nonproj import is_nonproj_tree from spacy.pipeline._parser_internals.nonproj import is_nonproj_tree, is_nonproj_arc
from spacy.syntax import nonproj from spacy.pipeline._parser_internals import nonproj
from ..util import get_doc from ..util import get_doc

View File

@ -1,15 +1,10 @@
import spacy.language
from spacy.language import Language from spacy.language import Language
from spacy.pipe_analysis import print_summary, validate_attrs from spacy.pipe_analysis import get_attr_info, validate_attrs
from spacy.pipe_analysis import get_assigns_for_attr, get_requires_for_attr
from spacy.pipe_analysis import count_pipeline_interdependencies
from mock import Mock from mock import Mock
import pytest import pytest
def test_component_decorator_assigns(): def test_component_decorator_assigns():
spacy.language.ENABLE_PIPELINE_ANALYSIS = True
@Language.component("c1", assigns=["token.tag", "doc.tensor"]) @Language.component("c1", assigns=["token.tag", "doc.tensor"])
def test_component1(doc): def test_component1(doc):
return doc return doc
@ -32,10 +27,11 @@ def test_component_decorator_assigns():
nlp = Language() nlp = Language()
nlp.add_pipe("c1") nlp.add_pipe("c1")
with pytest.warns(UserWarning): nlp.add_pipe("c2")
nlp.add_pipe("c2") problems = nlp.analyze_pipes()["problems"]
assert problems["c2"] == ["token.pos"]
nlp.add_pipe("c3") nlp.add_pipe("c3")
assert get_assigns_for_attr(nlp, "doc.tensor") == ["c1", "c2"] assert get_attr_info(nlp, "doc.tensor")["assigns"] == ["c1", "c2"]
nlp.add_pipe("c1", name="c4") nlp.add_pipe("c1", name="c4")
test_component4_meta = nlp.get_pipe_meta("c1") test_component4_meta = nlp.get_pipe_meta("c1")
assert test_component4_meta.factory == "c1" assert test_component4_meta.factory == "c1"
@ -43,9 +39,8 @@ def test_component_decorator_assigns():
assert not Language.has_factory("c4") assert not Language.has_factory("c4")
assert nlp.pipe_factories["c1"] == "c1" assert nlp.pipe_factories["c1"] == "c1"
assert nlp.pipe_factories["c4"] == "c1" assert nlp.pipe_factories["c4"] == "c1"
assert get_assigns_for_attr(nlp, "doc.tensor") == ["c1", "c2", "c4"] assert get_attr_info(nlp, "doc.tensor")["assigns"] == ["c1", "c2", "c4"]
assert get_requires_for_attr(nlp, "token.pos") == ["c2"] assert get_attr_info(nlp, "token.pos")["requires"] == ["c2"]
assert print_summary(nlp, no_print=True)
assert nlp("hello world") assert nlp("hello world")
@ -100,7 +95,6 @@ def test_analysis_validate_attrs_invalid(attr):
def test_analysis_validate_attrs_remove_pipe(): def test_analysis_validate_attrs_remove_pipe():
"""Test that attributes are validated correctly on remove.""" """Test that attributes are validated correctly on remove."""
spacy.language.ENABLE_PIPELINE_ANALYSIS = True
@Language.component("pipe_analysis_c6", assigns=["token.tag"]) @Language.component("pipe_analysis_c6", assigns=["token.tag"])
def c1(doc): def c1(doc):
@ -112,26 +106,9 @@ def test_analysis_validate_attrs_remove_pipe():
nlp = Language() nlp = Language()
nlp.add_pipe("pipe_analysis_c6") nlp.add_pipe("pipe_analysis_c6")
with pytest.warns(UserWarning): nlp.add_pipe("pipe_analysis_c7")
nlp.add_pipe("pipe_analysis_c7") problems = nlp.analyze_pipes()["problems"]
with pytest.warns(None) as record: assert problems["pipe_analysis_c7"] == ["token.pos"]
nlp.remove_pipe("pipe_analysis_c7") nlp.remove_pipe("pipe_analysis_c7")
assert not record.list problems = nlp.analyze_pipes()["problems"]
assert all(p == [] for p in problems.values())
def test_pipe_interdependencies():
prefix = "test_pipe_interdependencies"
@Language.component(f"{prefix}.fancifier", assigns=("doc._.fancy",))
def fancifier(doc):
return doc
@Language.component(f"{prefix}.needer", requires=("doc._.fancy",))
def needer(doc):
return doc
nlp = Language()
nlp.add_pipe(f"{prefix}.fancifier")
nlp.add_pipe(f"{prefix}.needer")
counts = count_pipeline_interdependencies(nlp)
assert counts == [1, 0]

View File

@ -0,0 +1,207 @@
import pytest
import numpy
from spacy.lang.en import English
from spacy.pipeline import AttributeRuler
from spacy import util, registry
from ..util import get_doc, make_tempdir
@pytest.fixture
def nlp():
    """Return a newly constructed English pipeline."""
    pipeline = English()
    return pipeline
@pytest.fixture
def pattern_dicts():
    """Return three pattern dicts for the attribute ruler tests.

    Each dict has "patterns" and "attrs"; only the last one carries an
    explicit "index".
    """
    article_rule = {
        "patterns": [[{"ORTH": "a"}], [{"ORTH": "irrelevant"}]],
        "attrs": {"LEMMA": "the", "MORPH": "Case=Nom|Number=Plur"},
    }
    # one pattern sets the lemma
    lemma_rule = {"patterns": [[{"ORTH": "test"}]], "attrs": {"LEMMA": "cat"}}
    # another pattern sets the morphology
    morph_rule = {
        "patterns": [[{"ORTH": "test"}]],
        "attrs": {"MORPH": "Case=Nom|Number=Sing"},
        "index": 0,
    }
    return [article_rule, lemma_rule, morph_rule]
@registry.assets("attribute_ruler_patterns")
def attribute_ruler_patterns():
    """Return the same three pattern dicts as the ``pattern_dicts`` fixture,
    registered as an asset so a config can reference them by name."""
    article_rule = {
        "patterns": [[{"ORTH": "a"}], [{"ORTH": "irrelevant"}]],
        "attrs": {"LEMMA": "the", "MORPH": "Case=Nom|Number=Plur"},
    }
    # one pattern sets the lemma
    lemma_rule = {"patterns": [[{"ORTH": "test"}]], "attrs": {"LEMMA": "cat"}}
    # another pattern sets the morphology
    morph_rule = {
        "patterns": [[{"ORTH": "test"}]],
        "attrs": {"MORPH": "Case=Nom|Number=Sing"},
        "index": 0,
    }
    return [article_rule, lemma_rule, morph_rule]
@pytest.fixture
def tag_map():
    """Return a tag map with entries for the "." and "," tags only."""
    period = {"POS": "PUNCT", "PunctType": "peri"}
    comma = {"POS": "PUNCT", "PunctType": "comm"}
    return {".": period, ",": comma}
@pytest.fixture
def morph_rules():
    """Return morph rules with a single entry: the word "the" under tag "DT"."""
    the_attrs = {"POS": "DET", "LEMMA": "a", "Case": "Nom"}
    return {"DT": {"the": the_attrs}}
def test_attributeruler_init(nlp, pattern_dicts):
    """Add patterns one by one via ``add`` and check the assigned attributes."""
    ruler = nlp.add_pipe("attribute_ruler")
    for pattern_dict in pattern_dicts:
        ruler.add(**pattern_dict)
    doc = nlp("This is a test.")
    article, noun = doc[2], doc[3]
    assert article.lemma_ == "the"
    assert article.morph_ == "Case=Nom|Number=Plur"
    assert noun.lemma_ == "cat"
    assert noun.morph_ == "Case=Nom|Number=Sing"
def test_attributeruler_init_patterns(nlp, pattern_dicts):
    """Create the attribute ruler from config, first with inline pattern
    dicts and then with patterns resolved from a registered asset."""

    def check_annotations(doc):
        # Both ways of supplying the patterns must yield the same attributes.
        assert doc[2].lemma_ == "the"
        assert doc[2].morph_ == "Case=Nom|Number=Plur"
        assert doc[3].lemma_ == "cat"
        assert doc[3].morph_ == "Case=Nom|Number=Sing"

    # initialize with patterns
    inline_config = {"pattern_dicts": pattern_dicts}
    nlp.add_pipe("attribute_ruler", config=inline_config)
    check_annotations(nlp("This is a test."))
    nlp.remove_pipe("attribute_ruler")
    # initialize with patterns from asset
    asset_config = {"pattern_dicts": {"@assets": "attribute_ruler_patterns"}}
    nlp.add_pipe("attribute_ruler", config=asset_config)
    check_annotations(nlp("This is a test."))
def test_attributeruler_tag_map(nlp, tag_map):
    """Load a tag map into the ruler and check that only the token tagged
    "." receives POS and morphology."""
    ruler = AttributeRuler(nlp.vocab)
    ruler.load_from_tag_map(tag_map)
    words = ["This", "is", "a", "test", "."]
    tags = ["DT", "VBZ", "DT", "NN", "."]
    doc = ruler(get_doc(nlp.vocab, words=words, tags=tags))
    for i, token in enumerate(doc):
        # Only the final token's tag is present in the tag map.
        if i == 4:
            expected_pos, expected_morph = "PUNCT", "PunctType=peri"
        else:
            expected_pos, expected_morph = "", ""
        assert token.pos_ == expected_pos
        assert token.morph_ == expected_morph
def test_attributeruler_morph_rules(nlp, morph_rules):
    """Load morph rules into the ruler and check that only the token "the"
    (index 2) is modified."""
    ruler = AttributeRuler(nlp.vocab)
    ruler.load_from_morph_rules(morph_rules)
    words = ["This", "is", "the", "test", "."]
    tags = ["DT", "VBZ", "DT", "NN", "."]
    doc = ruler(get_doc(nlp.vocab, words=words, tags=tags))
    for i, token in enumerate(doc):
        if i == 2:
            assert token.pos_ == "DET"
            assert token.lemma_ == "a"
            assert token.morph_ == "Case=Nom"
            continue
        # Every other token must be left without POS or morphology.
        assert token.pos_ == ""
        assert token.morph_ == ""
def test_attributeruler_indices(nlp):
    """Check that ``index`` selects which token of a match is modified and
    that an index outside the match raises ValueError."""
    ruler = nlp.add_pipe("attribute_ruler")
    a_test = [[{"ORTH": "a"}, {"ORTH": "test"}]]
    this_is = [[{"ORTH": "This"}, {"ORTH": "is"}]]
    ruler.add(a_test, {"LEMMA": "the", "MORPH": "Case=Nom|Number=Plur"}, index=0)
    ruler.add(this_is, {"LEMMA": "was", "MORPH": "Case=Nom|Number=Sing"}, index=1)
    ruler.add(a_test, {"LEMMA": "cat"}, index=-1)

    text = "This is a test."
    doc = nlp(text)
    for i, token in enumerate(doc):
        if i == 1:
            assert token.lemma_ == "was"
            assert token.morph_ == "Case=Nom|Number=Sing"
        elif i == 2:
            assert token.lemma_ == "the"
            assert token.morph_ == "Case=Nom|Number=Plur"
        elif i == 3:
            assert token.lemma_ == "cat"
        else:
            assert token.morph_ == ""

    # raises an error when trying to modify a token outside of the match
    # (the match is two tokens long, so both 2 and 10 are out of range)
    for outside_index in (2, 10):
        ruler.add(a_test, {"LEMMA": "cat"}, index=outside_index)
        with pytest.raises(ValueError):
            doc = nlp(text)
def test_attributeruler_patterns_prop(nlp, pattern_dicts):
    """Check that the ``patterns`` property returns the pattern dicts that
    were added via ``add_patterns``, in the same order."""
    a = nlp.add_pipe("attribute_ruler")
    a.add_patterns(pattern_dicts)
    for p1, p2 in zip(pattern_dicts, a.patterns):
        assert p1["patterns"] == p2["patterns"]
        assert p1["attrs"] == p2["attrs"]
        # Test for key presence rather than truthiness: the only explicit
        # index in the fixture is 0, which a truthiness check would skip.
        if "index" in p1:
            assert p1["index"] == p2["index"]
def test_attributeruler_serialize(nlp, pattern_dicts):
    """Check that the attribute ruler survives a bytes roundtrip and a disk
    roundtrip and still produces the same annotations afterwards."""
    a = nlp.add_pipe("attribute_ruler")
    a.add_patterns(pattern_dicts)

    text = "This is a test."
    attrs = ["ORTH", "LEMMA", "MORPH"]
    doc = nlp(text)

    # bytes roundtrip
    a_reloaded = AttributeRuler(nlp.vocab).from_bytes(a.to_bytes())
    assert a.to_bytes() == a_reloaded.to_bytes()
    doc1 = a_reloaded(nlp.make_doc(text))
    # The comparison result must be asserted, otherwise it is discarded and
    # the reloaded ruler's output is never actually checked.
    assert numpy.array_equal(doc.to_array(attrs), doc1.to_array(attrs))

    # disk roundtrip
    with make_tempdir() as tmp_dir:
        nlp.to_disk(tmp_dir)
        nlp2 = util.load_model_from_path(tmp_dir)
        doc2 = nlp2(text)
        assert nlp2.get_pipe("attribute_ruler").to_bytes() == a.to_bytes()
        assert numpy.array_equal(doc.to_array(attrs), doc2.to_array(attrs))

View File

@ -21,7 +21,8 @@ def assert_almost_equal(a, b):
def test_kb_valid_entities(nlp): def test_kb_valid_entities(nlp):
"""Test the valid construction of a KB with 3 entities and two aliases""" """Test the valid construction of a KB with 3 entities and two aliases"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3) mykb = KnowledgeBase(entity_vector_length=3)
mykb.initialize(nlp.vocab)
# adding entities # adding entities
mykb.add_entity(entity="Q1", freq=19, entity_vector=[8, 4, 3]) mykb.add_entity(entity="Q1", freq=19, entity_vector=[8, 4, 3])
@ -50,7 +51,8 @@ def test_kb_valid_entities(nlp):
def test_kb_invalid_entities(nlp): def test_kb_invalid_entities(nlp):
"""Test the invalid construction of a KB with an alias linked to a non-existing entity""" """Test the invalid construction of a KB with an alias linked to a non-existing entity"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) mykb = KnowledgeBase(entity_vector_length=1)
mykb.initialize(nlp.vocab)
# adding entities # adding entities
mykb.add_entity(entity="Q1", freq=19, entity_vector=[1]) mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
@ -66,7 +68,8 @@ def test_kb_invalid_entities(nlp):
def test_kb_invalid_probabilities(nlp): def test_kb_invalid_probabilities(nlp):
"""Test the invalid construction of a KB with wrong prior probabilities""" """Test the invalid construction of a KB with wrong prior probabilities"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) mykb = KnowledgeBase(entity_vector_length=1)
mykb.initialize(nlp.vocab)
# adding entities # adding entities
mykb.add_entity(entity="Q1", freq=19, entity_vector=[1]) mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
@ -80,7 +83,8 @@ def test_kb_invalid_probabilities(nlp):
def test_kb_invalid_combination(nlp): def test_kb_invalid_combination(nlp):
"""Test the invalid construction of a KB with non-matching entity and probability lists""" """Test the invalid construction of a KB with non-matching entity and probability lists"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) mykb = KnowledgeBase(entity_vector_length=1)
mykb.initialize(nlp.vocab)
# adding entities # adding entities
mykb.add_entity(entity="Q1", freq=19, entity_vector=[1]) mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
@ -96,7 +100,8 @@ def test_kb_invalid_combination(nlp):
def test_kb_invalid_entity_vector(nlp): def test_kb_invalid_entity_vector(nlp):
"""Test the invalid construction of a KB with non-matching entity vector lengths""" """Test the invalid construction of a KB with non-matching entity vector lengths"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3) mykb = KnowledgeBase(entity_vector_length=3)
mykb.initialize(nlp.vocab)
# adding entities # adding entities
mykb.add_entity(entity="Q1", freq=19, entity_vector=[1, 2, 3]) mykb.add_entity(entity="Q1", freq=19, entity_vector=[1, 2, 3])
@ -106,9 +111,47 @@ def test_kb_invalid_entity_vector(nlp):
mykb.add_entity(entity="Q2", freq=5, entity_vector=[2]) mykb.add_entity(entity="Q2", freq=5, entity_vector=[2])
def test_kb_default(nlp):
    """Test that the default (empty) KB is loaded when not providing a config"""
    linker = nlp.add_pipe("entity_linker", config={})
    kb = linker.kb
    assert len(kb) == 0
    assert kb.get_size_entities() == 0
    assert kb.get_size_aliases() == 0
    # default value from pipeline.entity_linker
    default_length = 64
    assert kb.entity_vector_length == default_length
def test_kb_custom_length(nlp):
    """Test that the default (empty) KB can be configured with a custom entity length"""
    kb_config = {"kb": {"entity_vector_length": 35}}
    linker = nlp.add_pipe("entity_linker", config=kb_config)
    kb = linker.kb
    assert len(kb) == 0
    assert kb.get_size_entities() == 0
    assert kb.get_size_aliases() == 0
    assert kb.entity_vector_length == 35
def test_kb_undefined(nlp):
    """Test that the EL can't train without defining a KB"""
    no_kb_config = {}
    linker = nlp.add_pipe("entity_linker", config=no_kb_config)
    with pytest.raises(ValueError):
        linker.begin_training()
def test_kb_empty(nlp):
    """Test that the EL can't train with an empty KB"""
    empty_kb = {"@assets": "spacy.EmptyKB.v1", "entity_vector_length": 342}
    linker = nlp.add_pipe("entity_linker", config={"kb": empty_kb})
    assert len(linker.kb) == 0
    with pytest.raises(ValueError):
        linker.begin_training()
def test_candidate_generation(nlp): def test_candidate_generation(nlp):
"""Test correct candidate generation""" """Test correct candidate generation"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) mykb = KnowledgeBase(entity_vector_length=1)
mykb.initialize(nlp.vocab)
# adding entities # adding entities
mykb.add_entity(entity="Q1", freq=27, entity_vector=[1]) mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
@ -133,7 +176,8 @@ def test_candidate_generation(nlp):
def test_append_alias(nlp): def test_append_alias(nlp):
"""Test that we can append additional alias-entity pairs""" """Test that we can append additional alias-entity pairs"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) mykb = KnowledgeBase(entity_vector_length=1)
mykb.initialize(nlp.vocab)
# adding entities # adding entities
mykb.add_entity(entity="Q1", freq=27, entity_vector=[1]) mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
@ -163,7 +207,8 @@ def test_append_alias(nlp):
def test_append_invalid_alias(nlp): def test_append_invalid_alias(nlp):
"""Test that append an alias will throw an error if prior probs are exceeding 1""" """Test that append an alias will throw an error if prior probs are exceeding 1"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) mykb = KnowledgeBase(entity_vector_length=1)
mykb.initialize(nlp.vocab)
# adding entities # adding entities
mykb.add_entity(entity="Q1", freq=27, entity_vector=[1]) mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
@ -184,7 +229,8 @@ def test_preserving_links_asdoc(nlp):
@registry.assets.register("myLocationsKB.v1") @registry.assets.register("myLocationsKB.v1")
def dummy_kb() -> KnowledgeBase: def dummy_kb() -> KnowledgeBase:
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) mykb = KnowledgeBase(entity_vector_length=1)
mykb.initialize(nlp.vocab)
# adding entities # adding entities
mykb.add_entity(entity="Q1", freq=19, entity_vector=[1]) mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
mykb.add_entity(entity="Q2", freq=8, entity_vector=[1]) mykb.add_entity(entity="Q2", freq=8, entity_vector=[1])
@ -289,7 +335,8 @@ def test_overfitting_IO():
# create artificial KB - assign same prior weight to the two russ cochran's # create artificial KB - assign same prior weight to the two russ cochran's
# Q2146908 (Russ Cochran): American golfer # Q2146908 (Russ Cochran): American golfer
# Q7381115 (Russ Cochran): publisher # Q7381115 (Russ Cochran): publisher
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3) mykb = KnowledgeBase(entity_vector_length=3)
mykb.initialize(nlp.vocab)
mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3]) mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7]) mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7])
mykb.add_alias( mykb.add_alias(

View File

@ -8,6 +8,8 @@ from thinc.api import Model, Linear
from thinc.config import ConfigValidationError from thinc.config import ConfigValidationError
from pydantic import StrictInt, StrictStr from pydantic import StrictInt, StrictStr
from ..util import make_tempdir
def test_pipe_function_component(): def test_pipe_function_component():
name = "test_component" name = "test_component"
@ -374,3 +376,65 @@ def test_language_factories_scores():
cfg = nlp.config["training"] cfg = nlp.config["training"]
expected_weights = {"a1": 0.25, "a2": 0.25, "b1": 0.1, "b2": 0.35, "b3": 0.05} expected_weights = {"a1": 0.25, "a2": 0.25, "b1": 0.1, "b2": 0.35, "b3": 0.05}
assert cfg["score_weights"] == expected_weights assert cfg["score_weights"] == expected_weights
def test_pipe_factories_from_source():
    """Test adding components from a source model."""
    component = "my_tagger"
    source_nlp = English()
    source_nlp.add_pipe("tagger", name=component)
    nlp = English()
    # A string is rejected as the source.
    with pytest.raises(ValueError):
        nlp.add_pipe(component, source="en_core_web_sm")
    nlp.add_pipe(component, source=source_nlp)
    assert component in nlp.pipe_names
    # The source pipeline has no component called "custom".
    with pytest.raises(KeyError):
        nlp.add_pipe("custom", source=source_nlp)
def test_pipe_factories_from_source_custom():
    """Test adding components from a source model with custom components."""
    name = "test_pipe_factories_from_source_custom"
    factory_defaults = {"arg": "hello"}

    @Language.factory(name, default_config=factory_defaults)
    def test_factory(nlp, name, arg: str):
        return lambda doc: doc

    source_nlp = English()
    source_nlp.add_pipe("tagger")
    source_nlp.add_pipe(name, config={"arg": "world"})

    nlp = English()
    nlp.add_pipe(name, source=source_nlp)
    assert name in nlp.pipe_names
    # The factory meta keeps the registered default ...
    meta = nlp.get_pipe_meta(name)
    assert meta.default_config["arg"] == "hello"
    # ... while the component config carries the value set on the source.
    component_cfg = nlp.config["components"][name]
    assert component_cfg["factory"] == name
    assert component_cfg["arg"] == "world"
def test_pipe_factories_from_source_config():
    """Test building a pipeline from a config whose component is sourced
    from a pipeline saved to disk."""
    name = "test_pipe_factories_from_source_config"
    factory_defaults = {"arg": "hello"}

    @Language.factory(name, default_config=factory_defaults)
    def test_factory(nlp, name, arg: str):
        return lambda doc: doc

    source_nlp = English()
    source_nlp.add_pipe("tagger")
    source_nlp.add_pipe(name, name="yolo", config={"arg": "world"})

    with make_tempdir() as tempdir:
        source_nlp.to_disk(tempdir)
        # "custom" is copied from the component named "yolo" in the saved
        # source pipeline; "parser" is created from its factory.
        dest_config = {
            "nlp": {"lang": "en", "pipeline": ["parser", "custom"]},
            "components": {
                "parser": {"factory": "parser"},
                "custom": {"source": str(tempdir), "component": "yolo"},
            },
        }
        nlp = English.from_config(dest_config)
        assert nlp.pipe_names == ["parser", "custom"]
        assert nlp.pipe_factories == {"parser": "parser", "custom": name}
        custom_meta = nlp.get_pipe_meta("custom")
        assert custom_meta.factory == name
        assert custom_meta.default_config["arg"] == "hello"
        custom_cfg = nlp.config["components"]["custom"]
        assert custom_cfg["factory"] == name
        assert custom_cfg["arg"] == "world"

View File

@ -70,6 +70,14 @@ def test_replace_pipe(nlp, name, replacement, invalid_replacement):
assert nlp.get_pipe(name) == nlp.create_pipe(replacement) assert nlp.get_pipe(name) == nlp.create_pipe(replacement)
def test_replace_last_pipe(nlp):
    """Check that replacing the last component keeps the pipeline order."""
    expected_names = ["sentencizer", "ner"]
    for factory_name in ("sentencizer", "ner"):
        nlp.add_pipe(factory_name)
    assert nlp.pipe_names == expected_names
    nlp.replace_pipe("ner", "ner")
    assert nlp.pipe_names == expected_names
@pytest.mark.parametrize("old_name,new_name", [("old_pipe", "new_pipe")]) @pytest.mark.parametrize("old_name,new_name", [("old_pipe", "new_pipe")])
def test_rename_pipe(nlp, old_name, new_name): def test_rename_pipe(nlp, old_name, new_name):
with pytest.raises(ValueError): with pytest.raises(ValueError):

View File

@ -1,418 +1,45 @@
import pytest from spacy.lang.en import English
from collections import namedtuple from spacy.gold import Example
from thinc.api import NumpyOps from spacy import util
from spacy.ml._biluo import BILUO, _get_transition_table from ..util import make_tempdir
@pytest.fixture( TRAIN_DATA = [
params=[ ("Who is Shaka Khan?", {"entities": [(7, 17, "PERSON")]}),
["PER", "ORG", "LOC", "MISC"], ("I like London and Berlin.", {"entities": [(7, 13, "LOC"), (18, 24, "LOC")]}),
["GPE", "PERSON", "NUMBER", "CURRENCY", "EVENT"], ]
]
)
def labels(request):
return request.param
@pytest.fixture def test_overfitting_IO():
def ops(): # Simple test to try and quickly overfit the SimpleNER component - ensuring the ML models work correctly
return NumpyOps() nlp = English()
ner = nlp.add_pipe("simple_ner")
train_examples = []
for text, annotations in TRAIN_DATA:
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
for ent in annotations.get("entities"):
ner.add_label(ent[2])
optimizer = nlp.begin_training()
for i in range(50):
losses = {}
nlp.update(train_examples, sgd=optimizer, losses=losses)
assert losses["ner"] < 0.0001
def _get_actions(labels): # test the trained model
action_names = ( test_text = "I like London."
[f"B{label}" for label in labels] doc = nlp(test_text)
+ [f"I{label}" for label in labels] ents = doc.ents
+ [f"L{label}" for label in labels] assert len(ents) == 1
+ [f"U{label}" for label in labels] assert ents[0].text == "London"
+ ["O"] assert ents[0].label_ == "LOC"
)
A = namedtuple("actions", action_names)
return A(**{name: i for i, name in enumerate(action_names)})
# Also test the results are still the same after IO
def test_init_biluo_layer(labels): with make_tempdir() as tmp_dir:
model = BILUO() nlp.to_disk(tmp_dir)
model.set_dim("nO", model.attrs["get_num_actions"](len(labels))) nlp2 = util.load_model_from_path(tmp_dir)
model.initialize() doc2 = nlp2(test_text)
assert model.get_dim("nO") == len(labels) * 4 + 1 ents2 = doc2.ents
assert len(ents2) == 1
assert ents2[0].text == "London"
def test_transition_table(ops): assert ents2[0].label_ == "LOC"
labels = ["per", "loc", "org"]
table = _get_transition_table(len(labels))
a = _get_actions(labels)
assert table.shape == (2, len(a), len(a))
# Not last token, prev action was B
assert table[0, a.Bper, a.Bper] == 0
assert table[0, a.Bper, a.Bloc] == 0
assert table[0, a.Bper, a.Borg] == 0
assert table[0, a.Bper, a.Iper] == 1
assert table[0, a.Bper, a.Iloc] == 0
assert table[0, a.Bper, a.Iorg] == 0
assert table[0, a.Bper, a.Lper] == 1
assert table[0, a.Bper, a.Lloc] == 0
assert table[0, a.Bper, a.Lorg] == 0
assert table[0, a.Bper, a.Uper] == 0
assert table[0, a.Bper, a.Uloc] == 0
assert table[0, a.Bper, a.Uorg] == 0
assert table[0, a.Bper, a.O] == 0
assert table[0, a.Bloc, a.Bper] == 0
assert table[0, a.Bloc, a.Bloc] == 0
assert table[0, a.Bloc, a.Borg] == 0
assert table[0, a.Bloc, a.Iper] == 0
assert table[0, a.Bloc, a.Iloc] == 1
assert table[0, a.Bloc, a.Iorg] == 0
assert table[0, a.Bloc, a.Lper] == 0
assert table[0, a.Bloc, a.Lloc] == 1
assert table[0, a.Bloc, a.Lorg] == 0
assert table[0, a.Bloc, a.Uper] == 0
assert table[0, a.Bloc, a.Uloc] == 0
assert table[0, a.Bloc, a.Uorg] == 0
assert table[0, a.Bloc, a.O] == 0
assert table[0, a.Borg, a.Bper] == 0
assert table[0, a.Borg, a.Bloc] == 0
assert table[0, a.Borg, a.Borg] == 0
assert table[0, a.Borg, a.Iper] == 0
assert table[0, a.Borg, a.Iloc] == 0
assert table[0, a.Borg, a.Iorg] == 1
assert table[0, a.Borg, a.Lper] == 0
assert table[0, a.Borg, a.Lloc] == 0
assert table[0, a.Borg, a.Lorg] == 1
assert table[0, a.Borg, a.Uper] == 0
assert table[0, a.Borg, a.Uloc] == 0
assert table[0, a.Borg, a.Uorg] == 0
assert table[0, a.Borg, a.O] == 0
# Not last token, prev action was I
assert table[0, a.Iper, a.Bper] == 0
assert table[0, a.Iper, a.Bloc] == 0
assert table[0, a.Iper, a.Borg] == 0
assert table[0, a.Iper, a.Iper] == 1
assert table[0, a.Iper, a.Iloc] == 0
assert table[0, a.Iper, a.Iorg] == 0
assert table[0, a.Iper, a.Lper] == 1
assert table[0, a.Iper, a.Lloc] == 0
assert table[0, a.Iper, a.Lorg] == 0
assert table[0, a.Iper, a.Uper] == 0
assert table[0, a.Iper, a.Uloc] == 0
assert table[0, a.Iper, a.Uorg] == 0
assert table[0, a.Iper, a.O] == 0
assert table[0, a.Iloc, a.Bper] == 0
assert table[0, a.Iloc, a.Bloc] == 0
assert table[0, a.Iloc, a.Borg] == 0
assert table[0, a.Iloc, a.Iper] == 0
assert table[0, a.Iloc, a.Iloc] == 1
assert table[0, a.Iloc, a.Iorg] == 0
assert table[0, a.Iloc, a.Lper] == 0
assert table[0, a.Iloc, a.Lloc] == 1
assert table[0, a.Iloc, a.Lorg] == 0
assert table[0, a.Iloc, a.Uper] == 0
assert table[0, a.Iloc, a.Uloc] == 0
assert table[0, a.Iloc, a.Uorg] == 0
assert table[0, a.Iloc, a.O] == 0
assert table[0, a.Iorg, a.Bper] == 0
assert table[0, a.Iorg, a.Bloc] == 0
assert table[0, a.Iorg, a.Borg] == 0
assert table[0, a.Iorg, a.Iper] == 0
assert table[0, a.Iorg, a.Iloc] == 0
assert table[0, a.Iorg, a.Iorg] == 1
assert table[0, a.Iorg, a.Lper] == 0
assert table[0, a.Iorg, a.Lloc] == 0
assert table[0, a.Iorg, a.Lorg] == 1
assert table[0, a.Iorg, a.Uper] == 0
assert table[0, a.Iorg, a.Uloc] == 0
assert table[0, a.Iorg, a.Uorg] == 0
assert table[0, a.Iorg, a.O] == 0
# Not last token, prev action was L
assert table[0, a.Lper, a.Bper] == 1
assert table[0, a.Lper, a.Bloc] == 1
assert table[0, a.Lper, a.Borg] == 1
assert table[0, a.Lper, a.Iper] == 0
assert table[0, a.Lper, a.Iloc] == 0
assert table[0, a.Lper, a.Iorg] == 0
assert table[0, a.Lper, a.Lper] == 0
assert table[0, a.Lper, a.Lloc] == 0
assert table[0, a.Lper, a.Lorg] == 0
assert table[0, a.Lper, a.Uper] == 1
assert table[0, a.Lper, a.Uloc] == 1
assert table[0, a.Lper, a.Uorg] == 1
assert table[0, a.Lper, a.O] == 1
assert table[0, a.Lloc, a.Bper] == 1
assert table[0, a.Lloc, a.Bloc] == 1
assert table[0, a.Lloc, a.Borg] == 1
assert table[0, a.Lloc, a.Iper] == 0
assert table[0, a.Lloc, a.Iloc] == 0
assert table[0, a.Lloc, a.Iorg] == 0
assert table[0, a.Lloc, a.Lper] == 0
assert table[0, a.Lloc, a.Lloc] == 0
assert table[0, a.Lloc, a.Lorg] == 0
assert table[0, a.Lloc, a.Uper] == 1
assert table[0, a.Lloc, a.Uloc] == 1
assert table[0, a.Lloc, a.Uorg] == 1
assert table[0, a.Lloc, a.O] == 1
assert table[0, a.Lorg, a.Bper] == 1
assert table[0, a.Lorg, a.Bloc] == 1
assert table[0, a.Lorg, a.Borg] == 1
assert table[0, a.Lorg, a.Iper] == 0
assert table[0, a.Lorg, a.Iloc] == 0
assert table[0, a.Lorg, a.Iorg] == 0
assert table[0, a.Lorg, a.Lper] == 0
assert table[0, a.Lorg, a.Lloc] == 0
assert table[0, a.Lorg, a.Lorg] == 0
assert table[0, a.Lorg, a.Uper] == 1
assert table[0, a.Lorg, a.Uloc] == 1
assert table[0, a.Lorg, a.Uorg] == 1
assert table[0, a.Lorg, a.O] == 1
# Not last token, prev action was U
assert table[0, a.Uper, a.Bper] == 1
assert table[0, a.Uper, a.Bloc] == 1
assert table[0, a.Uper, a.Borg] == 1
assert table[0, a.Uper, a.Iper] == 0
assert table[0, a.Uper, a.Iloc] == 0
assert table[0, a.Uper, a.Iorg] == 0
assert table[0, a.Uper, a.Lper] == 0
assert table[0, a.Uper, a.Lloc] == 0
assert table[0, a.Uper, a.Lorg] == 0
assert table[0, a.Uper, a.Uper] == 1
assert table[0, a.Uper, a.Uloc] == 1
assert table[0, a.Uper, a.Uorg] == 1
assert table[0, a.Uper, a.O] == 1
assert table[0, a.Uloc, a.Bper] == 1
assert table[0, a.Uloc, a.Bloc] == 1
assert table[0, a.Uloc, a.Borg] == 1
assert table[0, a.Uloc, a.Iper] == 0
assert table[0, a.Uloc, a.Iloc] == 0
assert table[0, a.Uloc, a.Iorg] == 0
assert table[0, a.Uloc, a.Lper] == 0
assert table[0, a.Uloc, a.Lloc] == 0
assert table[0, a.Uloc, a.Lorg] == 0
assert table[0, a.Uloc, a.Uper] == 1
assert table[0, a.Uloc, a.Uloc] == 1
assert table[0, a.Uloc, a.Uorg] == 1
assert table[0, a.Uloc, a.O] == 1
assert table[0, a.Uorg, a.Bper] == 1
assert table[0, a.Uorg, a.Bloc] == 1
assert table[0, a.Uorg, a.Borg] == 1
assert table[0, a.Uorg, a.Iper] == 0
assert table[0, a.Uorg, a.Iloc] == 0
assert table[0, a.Uorg, a.Iorg] == 0
assert table[0, a.Uorg, a.Lper] == 0
assert table[0, a.Uorg, a.Lloc] == 0
assert table[0, a.Uorg, a.Lorg] == 0
assert table[0, a.Uorg, a.Uper] == 1
assert table[0, a.Uorg, a.Uloc] == 1
assert table[0, a.Uorg, a.Uorg] == 1
assert table[0, a.Uorg, a.O] == 1
# Not last token, prev action was O
assert table[0, a.O, a.Bper] == 1
assert table[0, a.O, a.Bloc] == 1
assert table[0, a.O, a.Borg] == 1
assert table[0, a.O, a.Iper] == 0
assert table[0, a.O, a.Iloc] == 0
assert table[0, a.O, a.Iorg] == 0
assert table[0, a.O, a.Lper] == 0
assert table[0, a.O, a.Lloc] == 0
assert table[0, a.O, a.Lorg] == 0
assert table[0, a.O, a.Uper] == 1
assert table[0, a.O, a.Uloc] == 1
assert table[0, a.O, a.Uorg] == 1
assert table[0, a.O, a.O] == 1
# Last token, prev action was B
assert table[1, a.Bper, a.Bper] == 0
assert table[1, a.Bper, a.Bloc] == 0
assert table[1, a.Bper, a.Borg] == 0
assert table[1, a.Bper, a.Iper] == 0
assert table[1, a.Bper, a.Iloc] == 0
assert table[1, a.Bper, a.Iorg] == 0
assert table[1, a.Bper, a.Lper] == 1
assert table[1, a.Bper, a.Lloc] == 0
assert table[1, a.Bper, a.Lorg] == 0
assert table[1, a.Bper, a.Uper] == 0
assert table[1, a.Bper, a.Uloc] == 0
assert table[1, a.Bper, a.Uorg] == 0
assert table[1, a.Bper, a.O] == 0
assert table[1, a.Bloc, a.Bper] == 0
assert table[1, a.Bloc, a.Bloc] == 0
assert table[0, a.Bloc, a.Borg] == 0
assert table[1, a.Bloc, a.Iper] == 0
assert table[1, a.Bloc, a.Iloc] == 0
assert table[1, a.Bloc, a.Iorg] == 0
assert table[1, a.Bloc, a.Lper] == 0
assert table[1, a.Bloc, a.Lloc] == 1
assert table[1, a.Bloc, a.Lorg] == 0
assert table[1, a.Bloc, a.Uper] == 0
assert table[1, a.Bloc, a.Uloc] == 0
assert table[1, a.Bloc, a.Uorg] == 0
assert table[1, a.Bloc, a.O] == 0
assert table[1, a.Borg, a.Bper] == 0
assert table[1, a.Borg, a.Bloc] == 0
assert table[1, a.Borg, a.Borg] == 0
assert table[1, a.Borg, a.Iper] == 0
assert table[1, a.Borg, a.Iloc] == 0
assert table[1, a.Borg, a.Iorg] == 0
assert table[1, a.Borg, a.Lper] == 0
assert table[1, a.Borg, a.Lloc] == 0
assert table[1, a.Borg, a.Lorg] == 1
assert table[1, a.Borg, a.Uper] == 0
assert table[1, a.Borg, a.Uloc] == 0
assert table[1, a.Borg, a.Uorg] == 0
assert table[1, a.Borg, a.O] == 0
# Last token, prev action was I
assert table[1, a.Iper, a.Bper] == 0
assert table[1, a.Iper, a.Bloc] == 0
assert table[1, a.Iper, a.Borg] == 0
assert table[1, a.Iper, a.Iper] == 0
assert table[1, a.Iper, a.Iloc] == 0
assert table[1, a.Iper, a.Iorg] == 0
assert table[1, a.Iper, a.Lper] == 1
assert table[1, a.Iper, a.Lloc] == 0
assert table[1, a.Iper, a.Lorg] == 0
assert table[1, a.Iper, a.Uper] == 0
assert table[1, a.Iper, a.Uloc] == 0
assert table[1, a.Iper, a.Uorg] == 0
assert table[1, a.Iper, a.O] == 0
assert table[1, a.Iloc, a.Bper] == 0
assert table[1, a.Iloc, a.Bloc] == 0
assert table[1, a.Iloc, a.Borg] == 0
assert table[1, a.Iloc, a.Iper] == 0
assert table[1, a.Iloc, a.Iloc] == 0
assert table[1, a.Iloc, a.Iorg] == 0
assert table[1, a.Iloc, a.Lper] == 0
assert table[1, a.Iloc, a.Lloc] == 1
assert table[1, a.Iloc, a.Lorg] == 0
assert table[1, a.Iloc, a.Uper] == 0
assert table[1, a.Iloc, a.Uloc] == 0
assert table[1, a.Iloc, a.Uorg] == 0
assert table[1, a.Iloc, a.O] == 0
assert table[1, a.Iorg, a.Bper] == 0
assert table[1, a.Iorg, a.Bloc] == 0
assert table[1, a.Iorg, a.Borg] == 0
assert table[1, a.Iorg, a.Iper] == 0
assert table[1, a.Iorg, a.Iloc] == 0
assert table[1, a.Iorg, a.Iorg] == 0
assert table[1, a.Iorg, a.Lper] == 0
assert table[1, a.Iorg, a.Lloc] == 0
assert table[1, a.Iorg, a.Lorg] == 1
assert table[1, a.Iorg, a.Uper] == 0
assert table[1, a.Iorg, a.Uloc] == 0
assert table[1, a.Iorg, a.Uorg] == 0
assert table[1, a.Iorg, a.O] == 0
# Last token, prev action was L
assert table[1, a.Lper, a.Bper] == 0
assert table[1, a.Lper, a.Bloc] == 0
assert table[1, a.Lper, a.Borg] == 0
assert table[1, a.Lper, a.Iper] == 0
assert table[1, a.Lper, a.Iloc] == 0
assert table[1, a.Lper, a.Iorg] == 0
assert table[1, a.Lper, a.Lper] == 0
assert table[1, a.Lper, a.Lloc] == 0
assert table[1, a.Lper, a.Lorg] == 0
assert table[1, a.Lper, a.Uper] == 1
assert table[1, a.Lper, a.Uloc] == 1
assert table[1, a.Lper, a.Uorg] == 1
assert table[1, a.Lper, a.O] == 1
assert table[1, a.Lloc, a.Bper] == 0
assert table[1, a.Lloc, a.Bloc] == 0
assert table[1, a.Lloc, a.Borg] == 0
assert table[1, a.Lloc, a.Iper] == 0
assert table[1, a.Lloc, a.Iloc] == 0
assert table[1, a.Lloc, a.Iorg] == 0
assert table[1, a.Lloc, a.Lper] == 0
assert table[1, a.Lloc, a.Lloc] == 0
assert table[1, a.Lloc, a.Lorg] == 0
assert table[1, a.Lloc, a.Uper] == 1
assert table[1, a.Lloc, a.Uloc] == 1
assert table[1, a.Lloc, a.Uorg] == 1
assert table[1, a.Lloc, a.O] == 1
assert table[1, a.Lorg, a.Bper] == 0
assert table[1, a.Lorg, a.Bloc] == 0
assert table[1, a.Lorg, a.Borg] == 0
assert table[1, a.Lorg, a.Iper] == 0
assert table[1, a.Lorg, a.Iloc] == 0
assert table[1, a.Lorg, a.Iorg] == 0
assert table[1, a.Lorg, a.Lper] == 0
assert table[1, a.Lorg, a.Lloc] == 0
assert table[1, a.Lorg, a.Lorg] == 0
assert table[1, a.Lorg, a.Uper] == 1
assert table[1, a.Lorg, a.Uloc] == 1
assert table[1, a.Lorg, a.Uorg] == 1
assert table[1, a.Lorg, a.O] == 1
# Last token, prev action was U
assert table[1, a.Uper, a.Bper] == 0
assert table[1, a.Uper, a.Bloc] == 0
assert table[1, a.Uper, a.Borg] == 0
assert table[1, a.Uper, a.Iper] == 0
assert table[1, a.Uper, a.Iloc] == 0
assert table[1, a.Uper, a.Iorg] == 0
assert table[1, a.Uper, a.Lper] == 0
assert table[1, a.Uper, a.Lloc] == 0
assert table[1, a.Uper, a.Lorg] == 0
assert table[1, a.Uper, a.Uper] == 1
assert table[1, a.Uper, a.Uloc] == 1
assert table[1, a.Uper, a.Uorg] == 1
assert table[1, a.Uper, a.O] == 1
assert table[1, a.Uloc, a.Bper] == 0
assert table[1, a.Uloc, a.Bloc] == 0
assert table[1, a.Uloc, a.Borg] == 0
assert table[1, a.Uloc, a.Iper] == 0
assert table[1, a.Uloc, a.Iloc] == 0
assert table[1, a.Uloc, a.Iorg] == 0
assert table[1, a.Uloc, a.Lper] == 0
assert table[1, a.Uloc, a.Lloc] == 0
assert table[1, a.Uloc, a.Lorg] == 0
assert table[1, a.Uloc, a.Uper] == 1
assert table[1, a.Uloc, a.Uloc] == 1
assert table[1, a.Uloc, a.Uorg] == 1
assert table[1, a.Uloc, a.O] == 1
assert table[1, a.Uorg, a.Bper] == 0
assert table[1, a.Uorg, a.Bloc] == 0
assert table[1, a.Uorg, a.Borg] == 0
assert table[1, a.Uorg, a.Iper] == 0
assert table[1, a.Uorg, a.Iloc] == 0
assert table[1, a.Uorg, a.Iorg] == 0
assert table[1, a.Uorg, a.Lper] == 0
assert table[1, a.Uorg, a.Lloc] == 0
assert table[1, a.Uorg, a.Lorg] == 0
assert table[1, a.Uorg, a.Uper] == 1
assert table[1, a.Uorg, a.Uloc] == 1
assert table[1, a.Uorg, a.Uorg] == 1
assert table[1, a.Uorg, a.O] == 1
# Last token, prev action was O
assert table[1, a.O, a.Bper] == 0
assert table[1, a.O, a.Bloc] == 0
assert table[1, a.O, a.Borg] == 0
assert table[1, a.O, a.Iper] == 0
assert table[1, a.O, a.Iloc] == 0
assert table[1, a.O, a.Iorg] == 0
assert table[1, a.O, a.Lper] == 0
assert table[1, a.O, a.Lloc] == 0
assert table[1, a.O, a.Lorg] == 0
assert table[1, a.O, a.Uper] == 1
assert table[1, a.O, a.Uloc] == 1
assert table[1, a.O, a.Uorg] == 1
assert table[1, a.O, a.O] == 1

View File

@ -117,9 +117,7 @@ def test_overfitting_IO():
assert cats2["POSITIVE"] + cats2["NEGATIVE"] == pytest.approx(1.0, 0.1) assert cats2["POSITIVE"] + cats2["NEGATIVE"] == pytest.approx(1.0, 0.1)
# Test scoring # Test scoring
scores = nlp.evaluate( scores = nlp.evaluate(train_examples, scorer_cfg={"positive_label": "POSITIVE"})
train_examples, component_cfg={"scorer": {"positive_label": "POSITIVE"}}
)
assert scores["cats_f"] == 1.0 assert scores["cats_f"] == 1.0
assert scores["cats_score"] == 1.0 assert scores["cats_score"] == 1.0
assert "cats_score_desc" in scores assert "cats_score_desc" in scores

View File

@ -438,9 +438,8 @@ def test_issue4402():
data = DocBin(docs=docs, attrs=attrs).to_bytes() data = DocBin(docs=docs, attrs=attrs).to_bytes()
with output_file.open("wb") as file_: with output_file.open("wb") as file_:
file_.write(data) file_.write(data)
corpus = Corpus(train_loc=str(output_file), dev_loc=str(output_file)) reader = Corpus(output_file)
train_data = list(reader(nlp))
train_data = list(corpus.train_dataset(nlp))
assert len(train_data) == 2 assert len(train_data) == 2
split_train_data = [] split_train_data = []

View File

@ -139,7 +139,8 @@ def test_issue4665():
def test_issue4674(): def test_issue4674():
"""Test that setting entities with overlapping identifiers does not mess up IO""" """Test that setting entities with overlapping identifiers does not mess up IO"""
nlp = English() nlp = English()
kb = KnowledgeBase(nlp.vocab, entity_vector_length=3) kb = KnowledgeBase(entity_vector_length=3)
kb.initialize(nlp.vocab)
vector1 = [0.9, 1.1, 1.01] vector1 = [0.9, 1.1, 1.01]
vector2 = [1.8, 2.25, 2.01] vector2 = [1.8, 2.25, 2.01]
with pytest.warns(UserWarning): with pytest.warns(UserWarning):
@ -156,7 +157,8 @@ def test_issue4674():
dir_path.mkdir() dir_path.mkdir()
file_path = dir_path / "kb" file_path = dir_path / "kb"
kb.dump(str(file_path)) kb.dump(str(file_path))
kb2 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=3) kb2 = KnowledgeBase(entity_vector_length=3)
kb2.initialize(nlp.vocab)
kb2.load_bulk(str(file_path)) kb2.load_bulk(str(file_path))
assert kb2.get_size_entities() == 1 assert kb2.get_size_entities() == 1

View File

@ -27,6 +27,6 @@ def test_issue5137():
with make_tempdir() as tmpdir: with make_tempdir() as tmpdir:
nlp.to_disk(tmpdir) nlp.to_disk(tmpdir)
overrides = {"my_component": {"categories": "my_categories"}} overrides = {"components": {"my_component": {"categories": "my_categories"}}}
nlp2 = spacy.load(tmpdir, component_cfg=overrides) nlp2 = spacy.load(tmpdir, config=overrides)
assert nlp2.get_pipe("my_component").categories == "my_categories" assert nlp2.get_pipe("my_component").categories == "my_categories"

View File

@ -72,7 +72,8 @@ def entity_linker():
@registry.assets.register("TestIssue5230KB.v1") @registry.assets.register("TestIssue5230KB.v1")
def dummy_kb() -> KnowledgeBase: def dummy_kb() -> KnowledgeBase:
kb = KnowledgeBase(nlp.vocab, entity_vector_length=1) kb = KnowledgeBase(entity_vector_length=1)
kb.initialize(nlp.vocab)
kb.add_entity("test", 0.0, zeros((1, 1), dtype="f")) kb.add_entity("test", 0.0, zeros((1, 1), dtype="f"))
return kb return kb
@ -121,7 +122,8 @@ def test_writer_with_path_py35():
def test_save_and_load_knowledge_base(): def test_save_and_load_knowledge_base():
nlp = Language() nlp = Language()
kb = KnowledgeBase(nlp.vocab, entity_vector_length=1) kb = KnowledgeBase(entity_vector_length=1)
kb.initialize(nlp.vocab)
with make_tempdir() as d: with make_tempdir() as d:
path = d / "kb" path = d / "kb"
try: try:
@ -130,7 +132,8 @@ def test_save_and_load_knowledge_base():
pytest.fail(str(e)) pytest.fail(str(e))
try: try:
kb_loaded = KnowledgeBase(nlp.vocab, entity_vector_length=1) kb_loaded = KnowledgeBase(entity_vector_length=1)
kb_loaded.initialize(nlp.vocab)
kb_loaded.load_bulk(path) kb_loaded.load_bulk(path)
except Exception as e: except Exception as e:
pytest.fail(str(e)) pytest.fail(str(e))

View File

@ -2,6 +2,7 @@ import pytest
from thinc.config import Config, ConfigValidationError from thinc.config import Config, ConfigValidationError
import spacy import spacy
from spacy.lang.en import English from spacy.lang.en import English
from spacy.lang.de import German
from spacy.language import Language from spacy.language import Language
from spacy.util import registry, deep_merge_configs, load_model_from_config from spacy.util import registry, deep_merge_configs, load_model_from_config
from spacy.ml.models import build_Tok2Vec_model, build_tb_parser_model from spacy.ml.models import build_Tok2Vec_model, build_tb_parser_model
@ -11,8 +12,23 @@ from ..util import make_tempdir
nlp_config_string = """ nlp_config_string = """
[paths]
train = ""
dev = ""
[training] [training]
batch_size = 666
[training.train_corpus]
@readers = "spacy.Corpus.v1"
path = ${paths:train}
[training.dev_corpus]
@readers = "spacy.Corpus.v1"
path = ${paths:dev}
[training.batcher]
@batchers = "batch_by_words.v1"
size = 666
[nlp] [nlp]
lang = "en" lang = "en"
@ -73,14 +89,9 @@ def my_parser():
width=321, width=321,
rows=5432, rows=5432,
also_embed_subwords=True, also_embed_subwords=True,
also_use_static_vectors=False also_use_static_vectors=False,
), ),
MaxoutWindowEncoder( MaxoutWindowEncoder(width=321, window_size=3, maxout_pieces=4, depth=2),
width=321,
window_size=3,
maxout_pieces=4,
depth=2
)
) )
parser = build_tb_parser_model( parser = build_tb_parser_model(
tok2vec=tok2vec, nr_feature_tokens=7, hidden_width=65, maxout_pieces=5 tok2vec=tok2vec, nr_feature_tokens=7, hidden_width=65, maxout_pieces=5
@ -93,7 +104,7 @@ def test_create_nlp_from_config():
with pytest.raises(ConfigValidationError): with pytest.raises(ConfigValidationError):
nlp, _ = load_model_from_config(config, auto_fill=False) nlp, _ = load_model_from_config(config, auto_fill=False)
nlp, resolved = load_model_from_config(config, auto_fill=True) nlp, resolved = load_model_from_config(config, auto_fill=True)
assert nlp.config["training"]["batch_size"] == 666 assert nlp.config["training"]["batcher"]["size"] == 666
assert len(nlp.config["training"]) > 1 assert len(nlp.config["training"]) > 1
assert nlp.pipe_names == ["tok2vec", "tagger"] assert nlp.pipe_names == ["tok2vec", "tagger"]
assert len(nlp.config["components"]) == 2 assert len(nlp.config["components"]) == 2
@ -272,3 +283,33 @@ def test_serialize_config_missing_pipes():
assert "tok2vec" not in config["components"] assert "tok2vec" not in config["components"]
with pytest.raises(ValueError): with pytest.raises(ValueError):
load_model_from_config(config, auto_fill=True) load_model_from_config(config, auto_fill=True)
def test_config_overrides():
    """Test config overrides in nested and dot notation.

    Overrides are applied once directly through Config.from_str and then
    through spacy.load on a pipeline saved to disk; loading without any
    overrides must give back the original English pipeline.
    """
    overrides_nested = {"nlp": {"lang": "de", "pipeline": ["tagger"]}}
    overrides_dot = {"nlp.lang": "de", "nlp.pipeline": ["tagger"]}
    # load_model from config with overrides passed directly to Config
    overridden = Config().from_str(nlp_config_string, overrides=overrides_dot)
    nlp, _ = load_model_from_config(overridden, auto_fill=True)
    assert isinstance(nlp, German)
    assert nlp.pipe_names == ["tagger"]
    # Serialized roundtrip with config passed in
    base_nlp, _ = load_model_from_config(
        Config().from_str(nlp_config_string), auto_fill=True
    )
    assert isinstance(base_nlp, English)
    assert base_nlp.pipe_names == ["tok2vec", "tagger"]
    # (keyword arguments for spacy.load, expected class, expected pipe names)
    roundtrips = [
        ({"config": overrides_nested}, German, ["tagger"]),
        ({"config": overrides_dot}, German, ["tagger"]),
        ({}, English, ["tok2vec", "tagger"]),
    ]
    for load_kwargs, expected_cls, expected_pipes in roundtrips:
        # Each case saves to and loads from its own temporary directory.
        with make_tempdir() as d:
            base_nlp.to_disk(d)
            nlp = spacy.load(d, **load_kwargs)
        assert isinstance(nlp, expected_cls)
        assert nlp.pipe_names == expected_pipes

Some files were not shown because too many files have changed in this diff Show More