Merge remote-tracking branch 'upstream/develop' into feature/more-v3-docs

svlandeg 2020-08-18 18:55:12 +02:00
commit abba639565
42 changed files with 779 additions and 687 deletions

View File

@@ -15,7 +15,8 @@ import spacy.util
 from bin.ud import conll17_ud_eval
 from spacy.tokens import Token, Doc
 from spacy.gold import Example
-from spacy.util import compounding, minibatch, minibatch_by_words
+from spacy.util import compounding, minibatch
+from spacy.gold.batchers import minibatch_by_words
 from spacy.pipeline._parser_internals.nonproj import projectivize
 from spacy.matcher import Matcher
 from spacy import displacy
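
Note: minibatch_by_words now lives in spacy.gold.batchers rather than spacy.util. A minimal usage sketch under that assumption (the train_examples variable and the batch size are illustrative):

from spacy.gold.batchers import minibatch_by_words

# group Example objects into batches of roughly 1000 words each
for batch in minibatch_by_words(train_examples, size=1000):
    ...  # e.g. nlp.update(batch)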

View File

@@ -48,8 +48,7 @@ def main(model, output_dir=None):
     # You can change the dimension of vectors in your KB by using an encoder that changes the dimensionality.
     # For simplicity, we'll just use the original vector dimension here instead.
     vectors_dim = nlp.vocab.vectors.shape[1]
-    kb = KnowledgeBase(entity_vector_length=vectors_dim)
-    kb.initialize(nlp.vocab)
+    kb = KnowledgeBase(nlp.vocab, entity_vector_length=vectors_dim)

     # set up the data
     entity_ids = []
@@ -81,7 +80,7 @@ def main(model, output_dir=None):
     if not output_dir.exists():
         output_dir.mkdir()
     kb_path = str(output_dir / "kb")
-    kb.dump(kb_path)
+    kb.to_disk(kb_path)
     print()
     print("Saved KB to", kb_path)
@@ -96,9 +95,8 @@ def main(model, output_dir=None):
     print("Loading vocab from", vocab_path)
     print("Loading KB from", kb_path)
     vocab2 = Vocab().from_disk(vocab_path)
-    kb2 = KnowledgeBase(entity_vector_length=1)
-    kb.initialize(vocab2)
-    kb2.load_bulk(kb_path)
+    kb2 = KnowledgeBase(vocab2, entity_vector_length=1)
+    kb2.from_disk(kb_path)
     print()
     _print_kb(kb2)
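
For reference, the reworked KnowledgeBase API used in this example now receives the Vocab in its constructor and serializes with to_disk/from_disk instead of dump/load_bulk. A minimal sketch (entity IDs, vectors and the output path are illustrative):

import spacy
from spacy.kb import KnowledgeBase

nlp = spacy.blank("en")
kb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
kb.add_entity(entity="Q42", freq=12, entity_vector=[1, 2, 3])
kb.add_alias(alias="Douglas", entities=["Q42"], probabilities=[0.8])
kb.to_disk("./kb")  # writes a single file; an existing directory raises E928

kb2 = KnowledgeBase(nlp.vocab, entity_vector_length=3)
kb2.from_disk("./kb")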

View File

@@ -83,7 +83,7 @@ def main(kb_path, vocab_path, output_dir=None, n_iter=50):
     if "entity_linker" not in nlp.pipe_names:
         print("Loading Knowledge Base from '%s'" % kb_path)
         cfg = {
-            "kb": {
+            "kb_loader": {
                 "@assets": "spacy.KBFromFile.v1",
                 "vocab_path": vocab_path,
                 "kb_path": kb_path,

View File

@@ -68,11 +68,12 @@ def parse_config_overrides(args: List[str]) -> Dict[str, Any]:
         opt = args.pop(0)
         err = f"Invalid CLI argument '{opt}'"
         if opt.startswith("--"):  # new argument
-            opt = opt.replace("--", "").replace("-", "_")
+            opt = opt.replace("--", "")
             if "." not in opt:
                 msg.fail(f"{err}: can't override top-level section", exits=1)
             if "=" in opt:  # we have --opt=value
                 opt, value = opt.split("=", 1)
+                opt = opt.replace("-", "_")
             else:
                 if not args or args[0].startswith("--"):  # flag with no value
                     value = "true"

View File

@@ -229,6 +229,7 @@ if __name__ == '__main__':
 TEMPLATE_MANIFEST = """
 include meta.json
+include config.cfg
 """.strip()

View File

@@ -75,7 +75,7 @@ def train(
         msg.info("Using CPU")
     msg.info(f"Loading config and nlp from: {config_path}")
     with show_validation_error(config_path):
-        config = util.load_config(config_path, overrides=config_overrides)
+        config = util.load_config(config_path, overrides=config_overrides, interpolate=True)
     if config.get("training", {}).get("seed") is not None:
         fix_random_seed(config["training"]["seed"])
     # Use original config here before it's resolved to functions
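
For context, interpolate=True makes util.load_config resolve variable references in the config while loading. A minimal sketch (the file name and override key are illustrative):

from spacy import util

config = util.load_config(
    "config.cfg",
    overrides={"training.seed": 0},
    interpolate=True,  # resolve ${...} variable references to their values while loading
)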

View File

@@ -78,10 +78,11 @@ class Warnings:
             "are currently: {langs}")

     # TODO: fix numbering after merging develop into master
+    W090 = ("Could not locate any binary .spacy files in path '{path}'.")
     W091 = ("Could not clean/remove the temp directory at {dir}: {msg}.")
     W092 = ("Ignoring annotations for sentence starts, as dependency heads are set.")
     W093 = ("Could not find any data to train the {name} on. Is your "
-            "input data correctly formatted ?")
+            "input data correctly formatted?")
     W094 = ("Model '{model}' ({model_version}) specifies an under-constrained "
             "spaCy version requirement: {version}. This can lead to compatibility "
             "problems with older versions, or as new spaCy versions are "
@@ -476,6 +477,10 @@ class Errors:
     E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")

     # TODO: fix numbering after merging develop into master
+    E928 = ("A 'KnowledgeBase' should be written to / read from a file, but the "
+            "provided argument {loc} is an existing directory.")
+    E929 = ("A 'KnowledgeBase' could not be read from {loc} - the path does "
+            "not seem to exist.")
     E930 = ("Received invalid get_examples callback in {name}.begin_training. "
             "Expected function that returns an iterable of Example objects but "
             "got: {obj}")
@@ -503,8 +508,6 @@ class Errors:
             "not found in pipeline. Available components: {opts}")
     E945 = ("Can't copy pipeline component '{name}' from source. Expected loaded "
             "nlp object, but got: {source}")
-    E946 = ("The Vocab for the knowledge base is not initialized. Did you forget to "
-            "call kb.initialize()?")
     E947 = ("Matcher.add received invalid 'greedy' argument: expected "
             "a string value from {expected} but got: '{arg}'")
     E948 = ("Matcher.add received invalid 'patterns' argument: expected "
@@ -600,7 +603,8 @@ class Errors:
             "\"en_core_web_sm\" will copy the component from that model.\n\n{config}")
     E985 = ("Can't load model from config file: no 'nlp' section found.\n\n{config}")
     E986 = ("Could not create any training batches: check your input. "
-            "Perhaps discard_oversize should be set to False ?")
+            "Are the train and dev paths defined? "
+            "Is 'discard_oversize' set appropriately? ")
     E987 = ("The text of an example training instance is either a Doc or "
             "a string, but found {type} instead.")
     E988 = ("Could not parse any training examples. Ensure the data is "
@@ -610,8 +614,6 @@ class Errors:
             "of the training data in spaCy 3.0 onwards. The 'update' "
             "function should now be called with a batch of 'Example' "
             "objects, instead of (text, annotation) tuples. ")
-    E990 = ("An entity linking component needs to be initialized with a "
-            "KnowledgeBase object, but found {type} instead.")
     E991 = ("The function 'select_pipes' should be called with either a "
             "'disable' argument to list the names of the pipe components "
             "that should be disabled, or with an 'enable' argument that "

View File

@@ -1,8 +1,10 @@
+import warnings
 from typing import Union, List, Iterable, Iterator, TYPE_CHECKING, Callable
 from pathlib import Path

 from .. import util
 from .example import Example
+from ..errors import Warnings
 from ..tokens import DocBin, Doc
 from ..vocab import Vocab
@@ -10,6 +12,8 @@ if TYPE_CHECKING:
     # This lets us add type hints for mypy etc. without causing circular imports
     from ..language import Language  # noqa: F401

+FILE_TYPE = ".spacy"
+

 @util.registry.readers("spacy.Corpus.v1")
 def create_docbin_reader(
@@ -53,8 +57,9 @@ class Corpus:
     @staticmethod
     def walk_corpus(path: Union[str, Path]) -> List[Path]:
         path = util.ensure_path(path)
-        if not path.is_dir():
+        if not path.is_dir() and path.parts[-1].endswith(FILE_TYPE):
             return [path]
+        orig_path = path
         paths = [path]
         locs = []
         seen = set()
@@ -66,8 +71,10 @@ class Corpus:
                 continue
             elif path.is_dir():
                 paths.extend(path.iterdir())
-            elif path.parts[-1].endswith(".spacy"):
+            elif path.parts[-1].endswith(FILE_TYPE):
                 locs.append(path)
+        if len(locs) == 0:
+            warnings.warn(Warnings.W090.format(path=orig_path))
         return locs

     def __call__(self, nlp: "Language") -> Iterator[Example]:
@@ -135,7 +142,7 @@ class Corpus:
         i = 0
         for loc in locs:
             loc = util.ensure_path(loc)
-            if loc.parts[-1].endswith(".spacy"):
+            if loc.parts[-1].endswith(FILE_TYPE):
                 doc_bin = DocBin().from_disk(loc)
                 docs = doc_bin.get_docs(vocab)
                 for doc in docs:
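
For context, walk_corpus now also accepts a single .spacy file instead of a directory, and warns (W090) when no binary .spacy files are found. A hedged usage sketch (paths are illustrative, and the Corpus constructor may accept further options):

import spacy
from spacy.gold import Corpus

nlp = spacy.blank("en")
corpus = Corpus("./corpus/train.spacy")  # a directory of .spacy files also works
train_examples = list(corpus(nlp))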

View File

@@ -140,7 +140,7 @@ cdef class KnowledgeBase:
         self._entries.push_back(entry)
         self._aliases_table.push_back(alias)

-    cpdef load_bulk(self, loc)
+    cpdef from_disk(self, loc)
     cpdef set_entities(self, entity_list, freq_list, vector_list)

View File

@@ -1,4 +1,5 @@
 # cython: infer_types=True, profile=True
+from typing import Iterator
 from cymem.cymem cimport Pool
 from preshed.maps cimport PreshMap
 from cpython.exc cimport PyErr_SetFromErrno
@@ -64,6 +65,16 @@ cdef class Candidate:
         return self.prior_prob


+def get_candidates(KnowledgeBase kb, span) -> Iterator[Candidate]:
+    """
+    Return candidate entities for a given span by using the text of the span as the alias
+    and fetching appropriate entries from the index.
+    This particular function is optimized to work with the built-in KB functionality,
+    but any other custom candidate generation method can be used in combination with the KB as well.
+    """
+    return kb.get_alias_candidates(span.text)
+
+
 cdef class KnowledgeBase:
     """A `KnowledgeBase` instance stores unique identifiers for entities and their textual aliases,
     to support entity linking of named entities to real-world concepts.
@@ -71,25 +82,16 @@ cdef class KnowledgeBase:
     DOCS: https://spacy.io/api/kb
     """

-    def __init__(self, entity_vector_length):
-        """Create a KnowledgeBase. Make sure to call kb.initialize() before using it."""
+    def __init__(self, Vocab vocab, entity_vector_length):
+        """Create a KnowledgeBase."""
         self.mem = Pool()
         self.entity_vector_length = entity_vector_length
         self._entry_index = PreshMap()
         self._alias_index = PreshMap()
-        self.vocab = None
-
-    def initialize(self, Vocab vocab):
         self.vocab = vocab
         self.vocab.strings.add("")
         self._create_empty_vectors(dummy_hash=self.vocab.strings[""])

-    def require_vocab(self):
-        if self.vocab is None:
-            raise ValueError(Errors.E946)
-
     @property
     def entity_vector_length(self):
         """RETURNS (uint64): length of the entity vectors"""
@@ -102,14 +104,12 @@ cdef class KnowledgeBase:
         return len(self._entry_index)

     def get_entity_strings(self):
-        self.require_vocab()
         return [self.vocab.strings[x] for x in self._entry_index]

     def get_size_aliases(self):
         return len(self._alias_index)

     def get_alias_strings(self):
-        self.require_vocab()
         return [self.vocab.strings[x] for x in self._alias_index]

     def add_entity(self, unicode entity, float freq, vector[float] entity_vector):
@@ -117,7 +117,6 @@ cdef class KnowledgeBase:
        Add an entity to the KB, optionally specifying its log probability based on corpus frequency
        Return the hash of the entity ID/name at the end.
        """
-        self.require_vocab()
         cdef hash_t entity_hash = self.vocab.strings.add(entity)

         # Return if this entity was added before
@@ -140,7 +139,6 @@ cdef class KnowledgeBase:
         return entity_hash

     cpdef set_entities(self, entity_list, freq_list, vector_list):
-        self.require_vocab()
         if len(entity_list) != len(freq_list) or len(entity_list) != len(vector_list):
             raise ValueError(Errors.E140)
@@ -176,12 +174,10 @@ cdef class KnowledgeBase:
             i += 1

     def contains_entity(self, unicode entity):
-        self.require_vocab()
         cdef hash_t entity_hash = self.vocab.strings.add(entity)
         return entity_hash in self._entry_index

     def contains_alias(self, unicode alias):
-        self.require_vocab()
         cdef hash_t alias_hash = self.vocab.strings.add(alias)
         return alias_hash in self._alias_index
@@ -190,7 +186,6 @@ cdef class KnowledgeBase:
        For a given alias, add its potential entities and prior probabilies to the KB.
        Return the alias_hash at the end
        """
-        self.require_vocab()
         # Throw an error if the length of entities and probabilities are not the same
         if not len(entities) == len(probabilities):
             raise ValueError(Errors.E132.format(alias=alias,
@@ -234,7 +229,6 @@ cdef class KnowledgeBase:
        Throw an error if this entity+prior prob would exceed the sum of 1.
        For efficiency, it's best to use the method `add_alias` as much as possible instead of this one.
        """
-        self.require_vocab()
         # Check if the alias exists in the KB
         cdef hash_t alias_hash = self.vocab.strings[alias]
         if not alias_hash in self._alias_index:
@@ -274,14 +268,12 @@ cdef class KnowledgeBase:
             alias_entry.probs = probs
             self._aliases_table[alias_index] = alias_entry

-    def get_candidates(self, unicode alias):
+    def get_alias_candidates(self, unicode alias) -> Iterator[Candidate]:
        """
        Return candidate entities for an alias. Each candidate defines the entity, the original alias,
        and the prior probability of that alias resolving to that entity.
        If the alias is not known in the KB, and empty list is returned.
        """
-        self.require_vocab()
         cdef hash_t alias_hash = self.vocab.strings[alias]
         if not alias_hash in self._alias_index:
             return []
@@ -298,7 +290,6 @@ cdef class KnowledgeBase:
                 if entry_index != 0]

     def get_vector(self, unicode entity):
-        self.require_vocab()
         cdef hash_t entity_hash = self.vocab.strings[entity]

         # Return an empty list if this entity is unknown in this KB
@@ -311,7 +302,6 @@ cdef class KnowledgeBase:
     def get_prior_prob(self, unicode entity, unicode alias):
        """ Return the prior probability of a given alias being linked to a given entity,
        or return 0.0 when this combination is not known in the knowledge base"""
-        self.require_vocab()
         cdef hash_t alias_hash = self.vocab.strings[alias]
         cdef hash_t entity_hash = self.vocab.strings[entity]
@@ -329,8 +319,7 @@ cdef class KnowledgeBase:
         return 0.0

-    def dump(self, loc):
-        self.require_vocab()
+    def to_disk(self, loc):
         cdef Writer writer = Writer(loc)
         writer.write_header(self.get_size_entities(), self.entity_vector_length)
@@ -370,7 +359,7 @@ cdef class KnowledgeBase:
         writer.close()

-    cpdef load_bulk(self, loc):
+    cpdef from_disk(self, loc):
         cdef hash_t entity_hash
         cdef hash_t alias_hash
         cdef int64_t entry_index
@@ -462,12 +451,11 @@ cdef class KnowledgeBase:
 cdef class Writer:
     def __init__(self, object loc):
-        if path.exists(loc):
-            assert not path.isdir(loc), f"{loc} is directory"
         if isinstance(loc, Path):
             loc = bytes(loc)
         if path.exists(loc):
-            assert not path.isdir(loc), "%s is directory." % loc
+            if path.isdir(loc):
+                raise ValueError(Errors.E928.format(loc=loc))
         cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
         self._fp = fopen(<char*>bytes_loc, 'wb')
         if not self._fp:
@@ -511,8 +499,10 @@ cdef class Reader:
     def __init__(self, object loc):
         if isinstance(loc, Path):
             loc = bytes(loc)
-        assert path.exists(loc)
-        assert not path.isdir(loc)
+        if not path.exists(loc):
+            raise ValueError(Errors.E929.format(loc=loc))
+        if path.isdir(loc):
+            raise ValueError(Errors.E928.format(loc=loc))
         cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
         self._fp = fopen(<char*>bytes_loc, 'rb')
         if not self._fp:
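
For reference, the module-level get_candidates helper defined above can be called directly with a KB and a span; a minimal sketch (entity IDs and texts are illustrative):

import spacy
from spacy.kb import KnowledgeBase, get_candidates

nlp = spacy.blank("en")
kb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
kb.add_entity(entity="Q2", freq=12, entity_vector=[2])
kb.add_alias(alias="douglas", entities=["Q2"], probabilities=[0.9])

doc = nlp("douglas adams")
candidates = get_candidates(kb, doc[0:1])  # uses the span text as the alias (case-sensitive)
print([c.entity_ for c in candidates])     # ["Q2"]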

View File

@@ -772,9 +772,9 @@ class Language:
         self.remove_pipe(name)
         if not len(self.pipeline) or pipe_index == len(self.pipeline):
             # we have no components to insert before/after, or we're replacing the last component
-            self.add_pipe(factory_name, name=name)
+            self.add_pipe(factory_name, name=name, config=config, validate=validate)
         else:
-            self.add_pipe(factory_name, name=name, before=pipe_index)
+            self.add_pipe(factory_name, name=name, before=pipe_index, config=config, validate=validate)

     def rename_pipe(self, old_name: str, new_name: str) -> None:
         """Rename a pipeline component.

View File

@@ -1,9 +1,9 @@
-from typing import Optional
+from typing import Optional, Callable, Iterable
 from thinc.api import chain, clone, list2ragged, reduce_mean, residual
 from thinc.api import Model, Maxout, Linear

 from ...util import registry
-from ...kb import KnowledgeBase
+from ...kb import KnowledgeBase, Candidate, get_candidates
 from ...vocab import Vocab
@@ -25,15 +25,21 @@ def build_nel_encoder(tok2vec: Model, nO: Optional[int] = None) -> Model:

 @registry.assets.register("spacy.KBFromFile.v1")
-def load_kb(vocab_path: str, kb_path: str) -> KnowledgeBase:
-    vocab = Vocab().from_disk(vocab_path)
-    kb = KnowledgeBase(entity_vector_length=1)
-    kb.initialize(vocab)
-    kb.load_bulk(kb_path)
-    return kb
+def load_kb(kb_path: str) -> Callable[[Vocab], KnowledgeBase]:
+    def kb_from_file(vocab):
+        kb = KnowledgeBase(vocab, entity_vector_length=1)
+        kb.from_disk(kb_path)
+        return kb
+    return kb_from_file


 @registry.assets.register("spacy.EmptyKB.v1")
-def empty_kb(entity_vector_length: int) -> KnowledgeBase:
-    kb = KnowledgeBase(entity_vector_length=entity_vector_length)
-    return kb
+def empty_kb(entity_vector_length: int) -> Callable[[Vocab], KnowledgeBase]:
+    def empty_kb_factory(vocab):
+        return KnowledgeBase(vocab=vocab, entity_vector_length=entity_vector_length)
+    return empty_kb_factory


+@registry.assets.register("spacy.CandidateGenerator.v1")
+def create_candidates() -> Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]:
+    return get_candidates
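
Taken together, these registered functions are meant to be referenced from the entity_linker config; a hedged sketch of wiring them up (the kb path is illustrative):

import spacy

nlp = spacy.blank("en")
entity_linker = nlp.add_pipe(
    "entity_linker",
    config={
        "kb_loader": {"@assets": "spacy.KBFromFile.v1", "kb_path": "./kb"},
        "get_candidates": {"@assets": "spacy.CandidateGenerator.v1"},
    },
)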

View File

@@ -6,7 +6,7 @@ from thinc.api import CosineDistance, get_array_module, Model, Optimizer, Config
 from thinc.api import set_dropout_rate
 import warnings

-from ..kb import KnowledgeBase
+from ..kb import KnowledgeBase, Candidate
 from ..tokens import Doc
 from .pipe import Pipe, deserialize_config
 from ..language import Language
@@ -32,35 +32,30 @@ subword_features = true
 """
 DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]

-default_kb_config = """
-[kb]
-@assets = "spacy.EmptyKB.v1"
-entity_vector_length = 64
-"""
-DEFAULT_NEL_KB = Config().from_str(default_kb_config)["kb"]

 @Language.factory(
     "entity_linker",
     requires=["doc.ents", "doc.sents", "token.ent_iob", "token.ent_type"],
     assigns=["token.ent_kb_id"],
     default_config={
-        "kb": DEFAULT_NEL_KB,
+        "kb_loader": {"@assets": "spacy.EmptyKB.v1", "entity_vector_length": 64},
         "model": DEFAULT_NEL_MODEL,
         "labels_discard": [],
         "incl_prior": True,
         "incl_context": True,
+        "get_candidates": {"@assets": "spacy.CandidateGenerator.v1"},
     },
 )
 def make_entity_linker(
     nlp: Language,
     name: str,
     model: Model,
-    kb: KnowledgeBase,
+    kb_loader: Callable[[Vocab], KnowledgeBase],
     *,
     labels_discard: Iterable[str],
     incl_prior: bool,
     incl_context: bool,
+    get_candidates: Callable[[KnowledgeBase, "Span"], Iterable[Candidate]],
 ):
     """Construct an EntityLinker component.
@@ -76,10 +71,11 @@ def make_entity_linker(
         nlp.vocab,
         model,
         name,
-        kb=kb,
+        kb_loader=kb_loader,
         labels_discard=labels_discard,
         incl_prior=incl_prior,
         incl_context=incl_context,
+        get_candidates=get_candidates,
     )
@@ -97,10 +93,11 @@ class EntityLinker(Pipe):
         model: Model,
         name: str = "entity_linker",
         *,
-        kb: KnowledgeBase,
+        kb_loader: Callable[[Vocab], KnowledgeBase],
         labels_discard: Iterable[str],
         incl_prior: bool,
         incl_context: bool,
+        get_candidates: Callable[[KnowledgeBase, "Span"], Iterable[Candidate]],
     ) -> None:
         """Initialize an entity linker.
@@ -108,7 +105,7 @@ class EntityLinker(Pipe):
         model (thinc.api.Model): The Thinc Model powering the pipeline component.
         name (str): The component instance name, used to add entries to the
             losses during training.
-        kb (KnowledgeBase): The KnowledgeBase holding all entities and their aliases.
+        kb_loader (Callable[[Vocab], KnowledgeBase]): A function that creates a KnowledgeBase from a Vocab instance.
         labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction.
         incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
         incl_context (bool): Whether or not to include the local context in the model.
@@ -119,17 +116,12 @@ class EntityLinker(Pipe):
         self.model = model
         self.name = name
         cfg = {
-            "kb": kb,
             "labels_discard": list(labels_discard),
             "incl_prior": incl_prior,
             "incl_context": incl_context,
         }
-        if not isinstance(kb, KnowledgeBase):
-            raise ValueError(Errors.E990.format(type=type(self.kb)))
-        kb.initialize(vocab)
-        self.kb = kb
-        if "kb" in cfg:
-            del cfg["kb"]  # we don't want to duplicate its serialization
+        self.kb = kb_loader(self.vocab)
+        self.get_candidates = get_candidates
         self.cfg = dict(cfg)
         self.distance = CosineDistance(normalize=False)
         # how many neightbour sentences to take into account
@@ -326,10 +318,11 @@ class EntityLinker(Pipe):
                 end_token = sentences[end_sentence].end
                 sent_doc = doc[start_token:end_token].as_doc()
                 # currently, the context is the same for each entity in a sentence (should be refined)
-                sentence_encoding = self.model.predict([sent_doc])[0]
-                xp = get_array_module(sentence_encoding)
-                sentence_encoding_t = sentence_encoding.T
-                sentence_norm = xp.linalg.norm(sentence_encoding_t)
+                xp = self.model.ops.xp
+                if self.cfg.get("incl_context"):
+                    sentence_encoding = self.model.predict([sent_doc])[0]
+                    sentence_encoding_t = sentence_encoding.T
+                    sentence_norm = xp.linalg.norm(sentence_encoding_t)
                 for ent in sent.ents:
                     entity_count += 1
                     to_discard = self.cfg.get("labels_discard", [])
@@ -337,7 +330,7 @@ class EntityLinker(Pipe):
                         # ignoring this entity - setting to NIL
                         final_kb_ids.append(self.NIL)
                     else:
-                        candidates = self.kb.get_candidates(ent.text)
+                        candidates = self.get_candidates(self.kb, ent)
                         if not candidates:
                             # no prediction possible for this entity - setting to NIL
                             final_kb_ids.append(self.NIL)
@@ -421,10 +414,9 @@ class EntityLinker(Pipe):
         DOCS: https://spacy.io/api/entitylinker#to_disk
         """
         serialize = {}
-        self.cfg["entity_width"] = self.kb.entity_vector_length
         serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
         serialize["vocab"] = lambda p: self.vocab.to_disk(p)
-        serialize["kb"] = lambda p: self.kb.dump(p)
+        serialize["kb"] = lambda p: self.kb.to_disk(p)
         serialize["model"] = lambda p: self.model.to_disk(p)
         util.to_disk(path, serialize, exclude)
@@ -446,15 +438,10 @@ class EntityLinker(Pipe):
             except AttributeError:
                 raise ValueError(Errors.E149) from None

-        def load_kb(p):
-            self.kb = KnowledgeBase(entity_vector_length=self.cfg["entity_width"])
-            self.kb.initialize(self.vocab)
-            self.kb.load_bulk(p)
-
         deserialize = {}
         deserialize["vocab"] = lambda p: self.vocab.from_disk(p)
         deserialize["cfg"] = lambda p: self.cfg.update(deserialize_config(p))
-        deserialize["kb"] = load_kb
+        deserialize["kb"] = lambda p: self.kb.from_disk(p)
         deserialize["model"] = load_model
         util.from_disk(path, deserialize, exclude)
         return self

View File

@@ -68,7 +68,6 @@ class Tagger(Pipe):
         name (str): The component instance name, used to add entries to the
             losses during training.
         labels (List): The set of labels. Defaults to None.
-        set_morphology (bool): Whether to set morphological features.

         DOCS: https://spacy.io/api/tagger#init
         """

View File

@@ -167,18 +167,20 @@ class ModelMetaSchema(BaseModel):
     lang: StrictStr = Field(..., title="Two-letter language code, e.g. 'en'")
     name: StrictStr = Field(..., title="Model name")
     version: StrictStr = Field(..., title="Model version")
-    spacy_version: Optional[StrictStr] = Field(None, title="Compatible spaCy version identifier")
-    parent_package: Optional[StrictStr] = Field("spacy", title="Name of parent spaCy package, e.g. spacy or spacy-nightly")
-    pipeline: Optional[List[StrictStr]] = Field([], title="Names of pipeline components")
-    description: Optional[StrictStr] = Field(None, title="Model description")
-    license: Optional[StrictStr] = Field(None, title="Model license")
-    author: Optional[StrictStr] = Field(None, title="Model author name")
-    email: Optional[StrictStr] = Field(None, title="Model author email")
-    url: Optional[StrictStr] = Field(None, title="Model author URL")
-    sources: Optional[Union[List[StrictStr], Dict[str, str]]] = Field(None, title="Training data sources")
-    vectors: Optional[Dict[str, Any]] = Field(None, title="Included word vectors")
-    accuracy: Optional[Dict[str, Union[float, int]]] = Field(None, title="Accuracy numbers")
-    speed: Optional[Dict[str, Union[float, int]]] = Field(None, title="Speed evaluation numbers")
+    spacy_version: StrictStr = Field("", title="Compatible spaCy version identifier")
+    parent_package: StrictStr = Field("spacy", title="Name of parent spaCy package, e.g. spacy or spacy-nightly")
+    pipeline: List[StrictStr] = Field([], title="Names of pipeline components")
+    description: StrictStr = Field("", title="Model description")
+    license: StrictStr = Field("", title="Model license")
+    author: StrictStr = Field("", title="Model author name")
+    email: StrictStr = Field("", title="Model author email")
+    url: StrictStr = Field("", title="Model author URL")
+    sources: Optional[Union[List[StrictStr], List[Dict[str, str]]]] = Field(None, title="Training data sources")
+    vectors: Dict[str, Any] = Field({}, title="Included word vectors")
+    labels: Dict[str, Dict[str, List[str]]] = Field({}, title="Component labels, keyed by component name")
+    accuracy: Dict[str, Union[float, Dict[str, float]]] = Field({}, title="Accuracy numbers")
+    speed: Dict[str, Union[float, int]] = Field({}, title="Speed evaluation numbers")
+    spacy_git_version: StrictStr = Field("", title="Commit of spaCy version used")
     # fmt: on

View File

@@ -1,6 +1,7 @@
+from typing import Callable, Iterable
 import pytest

-from spacy.kb import KnowledgeBase
+from spacy.kb import KnowledgeBase, get_candidates, Candidate
 from spacy import util, registry
 from spacy.gold import Example
@@ -21,8 +22,7 @@ def assert_almost_equal(a, b):
 def test_kb_valid_entities(nlp):
     """Test the valid construction of a KB with 3 entities and two aliases"""
-    mykb = KnowledgeBase(entity_vector_length=3)
-    mykb.initialize(nlp.vocab)
+    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)

     # adding entities
     mykb.add_entity(entity="Q1", freq=19, entity_vector=[8, 4, 3])
@@ -51,8 +51,7 @@ def test_kb_valid_entities(nlp):
 def test_kb_invalid_entities(nlp):
     """Test the invalid construction of a KB with an alias linked to a non-existing entity"""
-    mykb = KnowledgeBase(entity_vector_length=1)
-    mykb.initialize(nlp.vocab)
+    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)

     # adding entities
     mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
@@ -68,8 +67,7 @@ def test_kb_invalid_entities(nlp):
 def test_kb_invalid_probabilities(nlp):
     """Test the invalid construction of a KB with wrong prior probabilities"""
-    mykb = KnowledgeBase(entity_vector_length=1)
-    mykb.initialize(nlp.vocab)
+    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)

     # adding entities
     mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
@@ -83,8 +81,7 @@ def test_kb_invalid_combination(nlp):
 def test_kb_invalid_combination(nlp):
     """Test the invalid construction of a KB with non-matching entity and probability lists"""
-    mykb = KnowledgeBase(entity_vector_length=1)
-    mykb.initialize(nlp.vocab)
+    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)

     # adding entities
     mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
@@ -100,8 +97,7 @@ def test_kb_invalid_combination(nlp):
 def test_kb_invalid_entity_vector(nlp):
     """Test the invalid construction of a KB with non-matching entity vector lengths"""
-    mykb = KnowledgeBase(entity_vector_length=3)
-    mykb.initialize(nlp.vocab)
+    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)

     # adding entities
     mykb.add_entity(entity="Q1", freq=19, entity_vector=[1, 2, 3])
@@ -117,14 +113,14 @@ def test_kb_default(nlp):
     assert len(entity_linker.kb) == 0
     assert entity_linker.kb.get_size_entities() == 0
     assert entity_linker.kb.get_size_aliases() == 0
-    # default value from pipeline.entity_linker
+    # 64 is the default value from pipeline.entity_linker
     assert entity_linker.kb.entity_vector_length == 64


 def test_kb_custom_length(nlp):
     """Test that the default (empty) KB can be configured with a custom entity length"""
     entity_linker = nlp.add_pipe(
-        "entity_linker", config={"kb": {"entity_vector_length": 35}}
+        "entity_linker", config={"kb_loader": {"entity_vector_length": 35}}
     )
     assert len(entity_linker.kb) == 0
     assert entity_linker.kb.get_size_entities() == 0
@@ -141,7 +137,7 @@ def test_kb_undefined(nlp):
 def test_kb_empty(nlp):
     """Test that the EL can't train with an empty KB"""
-    config = {"kb": {"@assets": "spacy.EmptyKB.v1", "entity_vector_length": 342}}
+    config = {"kb_loader": {"@assets": "spacy.EmptyKB.v1", "entity_vector_length": 342}}
     entity_linker = nlp.add_pipe("entity_linker", config=config)
     assert len(entity_linker.kb) == 0
     with pytest.raises(ValueError):
@@ -150,8 +146,13 @@ def test_kb_empty(nlp):
 def test_candidate_generation(nlp):
     """Test correct candidate generation"""
-    mykb = KnowledgeBase(entity_vector_length=1)
-    mykb.initialize(nlp.vocab)
+    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+    doc = nlp("douglas adam Adam shrubbery")
+    douglas_ent = doc[0:1]
+    adam_ent = doc[1:2]
+    Adam_ent = doc[2:3]
+    shrubbery_ent = doc[3:4]

     # adding entities
     mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
@@ -163,21 +164,76 @@ def test_candidate_generation(nlp):
     mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])

     # test the size of the relevant candidates
-    assert len(mykb.get_candidates("douglas")) == 2
-    assert len(mykb.get_candidates("adam")) == 1
-    assert len(mykb.get_candidates("shrubbery")) == 0
+    assert len(get_candidates(mykb, douglas_ent)) == 2
+    assert len(get_candidates(mykb, adam_ent)) == 1
+    assert len(get_candidates(mykb, Adam_ent)) == 0  # default case sensitive
+    assert len(get_candidates(mykb, shrubbery_ent)) == 0

     # test the content of the candidates
-    assert mykb.get_candidates("adam")[0].entity_ == "Q2"
-    assert mykb.get_candidates("adam")[0].alias_ == "adam"
-    assert_almost_equal(mykb.get_candidates("adam")[0].entity_freq, 12)
-    assert_almost_equal(mykb.get_candidates("adam")[0].prior_prob, 0.9)
+    assert get_candidates(mykb, adam_ent)[0].entity_ == "Q2"
+    assert get_candidates(mykb, adam_ent)[0].alias_ == "adam"
+    assert_almost_equal(get_candidates(mykb, adam_ent)[0].entity_freq, 12)
+    assert_almost_equal(get_candidates(mykb, adam_ent)[0].prior_prob, 0.9)
+
+
+def test_el_pipe_configuration(nlp):
+    """Test correct candidate generation as part of the EL pipe"""
+    nlp.add_pipe("sentencizer")
+    pattern = {"label": "PERSON", "pattern": [{"LOWER": "douglas"}]}
+    ruler = nlp.add_pipe("entity_ruler")
+    ruler.add_patterns([pattern])
+
+    @registry.assets.register("myAdamKB.v1")
+    def mykb() -> Callable[["Vocab"], KnowledgeBase]:
+        def create_kb(vocab):
+            kb = KnowledgeBase(vocab, entity_vector_length=1)
+            kb.add_entity(entity="Q2", freq=12, entity_vector=[2])
+            kb.add_entity(entity="Q3", freq=5, entity_vector=[3])
+            kb.add_alias(
+                alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1]
+            )
+            return kb
+        return create_kb
+
+    # run an EL pipe without a trained context encoder, to check the candidate generation step only
+    nlp.add_pipe(
+        "entity_linker",
+        config={"kb_loader": {"@assets": "myAdamKB.v1"}, "incl_context": False},
+    )
+    # With the default get_candidates function, matching is case-sensitive
+    text = "Douglas and douglas are not the same."
+    doc = nlp(text)
+    assert doc[0].ent_kb_id_ == "NIL"
+    assert doc[1].ent_kb_id_ == ""
+    assert doc[2].ent_kb_id_ == "Q2"
+
+    def get_lowercased_candidates(kb, span):
+        return kb.get_alias_candidates(span.text.lower())
+
+    @registry.assets.register("spacy.LowercaseCandidateGenerator.v1")
+    def create_candidates() -> Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]:
+        return get_lowercased_candidates
+
+    # replace the pipe with a new one with a different candidate generator
+    nlp.replace_pipe(
+        "entity_linker",
+        "entity_linker",
+        config={
+            "kb_loader": {"@assets": "myAdamKB.v1"},
+            "incl_context": False,
+            "get_candidates": {"@assets": "spacy.LowercaseCandidateGenerator.v1"},
+        },
+    )
+    doc = nlp(text)
+    assert doc[0].ent_kb_id_ == "Q2"
+    assert doc[1].ent_kb_id_ == ""
+    assert doc[2].ent_kb_id_ == "Q2"


 def test_append_alias(nlp):
     """Test that we can append additional alias-entity pairs"""
-    mykb = KnowledgeBase(entity_vector_length=1)
-    mykb.initialize(nlp.vocab)
+    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)

     # adding entities
     mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
@@ -189,26 +245,25 @@ def test_append_alias(nlp):
     mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])

     # test the size of the relevant candidates
-    assert len(mykb.get_candidates("douglas")) == 2
+    assert len(mykb.get_alias_candidates("douglas")) == 2

     # append an alias
     mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.2)

     # test the size of the relevant candidates has been incremented
-    assert len(mykb.get_candidates("douglas")) == 3
+    assert len(mykb.get_alias_candidates("douglas")) == 3

     # append the same alias-entity pair again should not work (will throw a warning)
     with pytest.warns(UserWarning):
         mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.3)

     # test the size of the relevant candidates remained unchanged
-    assert len(mykb.get_candidates("douglas")) == 3
+    assert len(mykb.get_alias_candidates("douglas")) == 3


 def test_append_invalid_alias(nlp):
     """Test that append an alias will throw an error if prior probs are exceeding 1"""
-    mykb = KnowledgeBase(entity_vector_length=1)
-    mykb.initialize(nlp.vocab)
+    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)

     # adding entities
     mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
@@ -228,16 +283,18 @@ def test_preserving_links_asdoc(nlp):
     """Test that Span.as_doc preserves the existing entity links"""

     @registry.assets.register("myLocationsKB.v1")
-    def dummy_kb() -> KnowledgeBase:
-        mykb = KnowledgeBase(entity_vector_length=1)
-        mykb.initialize(nlp.vocab)
+    def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]:
+        def create_kb(vocab):
+            mykb = KnowledgeBase(vocab, entity_vector_length=1)
             # adding entities
             mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
             mykb.add_entity(entity="Q2", freq=8, entity_vector=[1])
             # adding aliases
             mykb.add_alias(alias="Boston", entities=["Q1"], probabilities=[0.7])
             mykb.add_alias(alias="Denver", entities=["Q2"], probabilities=[0.6])
             return mykb
+        return create_kb

     # set up pipeline with NER (Entity Ruler) and NEL (prior probability only, model not trained)
     nlp.add_pipe("sentencizer")
@@ -247,7 +304,7 @@ def test_preserving_links_asdoc(nlp):
     ]
     ruler = nlp.add_pipe("entity_ruler")
     ruler.add_patterns(patterns)
-    el_config = {"kb": {"@assets": "myLocationsKB.v1"}, "incl_prior": False}
+    el_config = {"kb_loader": {"@assets": "myLocationsKB.v1"}, "incl_prior": False}
     el_pipe = nlp.add_pipe("entity_linker", config=el_config, last=True)
     el_pipe.begin_training(lambda: [])
     el_pipe.incl_context = False
@@ -331,24 +388,28 @@ def test_overfitting_IO():
         train_examples.append(Example.from_dict(doc, annotation))

     @registry.assets.register("myOverfittingKB.v1")
-    def dummy_kb() -> KnowledgeBase:
-        # create artificial KB - assign same prior weight to the two russ cochran's
-        # Q2146908 (Russ Cochran): American golfer
-        # Q7381115 (Russ Cochran): publisher
-        mykb = KnowledgeBase(entity_vector_length=3)
-        mykb.initialize(nlp.vocab)
+    def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]:
+        def create_kb(vocab):
+            # create artificial KB - assign same prior weight to the two russ cochran's
+            # Q2146908 (Russ Cochran): American golfer
+            # Q7381115 (Russ Cochran): publisher
+            mykb = KnowledgeBase(vocab, entity_vector_length=3)
             mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
             mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7])
             mykb.add_alias(
                 alias="Russ Cochran",
                 entities=["Q2146908", "Q7381115"],
                 probabilities=[0.5, 0.5],
             )
             return mykb
+        return create_kb

     # Create the Entity Linker component and add it to the pipeline
     nlp.add_pipe(
-        "entity_linker", config={"kb": {"@assets": "myOverfittingKB.v1"}}, last=True
+        "entity_linker",
+        config={"kb_loader": {"@assets": "myOverfittingKB.v1"}},
+        last=True,
     )

     # train the NEL pipe

View File

@@ -78,6 +78,14 @@ def test_replace_last_pipe(nlp):
     assert nlp.pipe_names == ["sentencizer", "ner"]


+def test_replace_pipe_config(nlp):
+    nlp.add_pipe("entity_linker")
+    nlp.add_pipe("sentencizer")
+    assert nlp.get_pipe("entity_linker").cfg["incl_prior"] == True
+    nlp.replace_pipe("entity_linker", "entity_linker", config={"incl_prior": False})
+    assert nlp.get_pipe("entity_linker").cfg["incl_prior"] == False
+
+
 @pytest.mark.parametrize("old_name,new_name", [("old_pipe", "new_pipe")])
 def test_rename_pipe(nlp, old_name, new_name):
     with pytest.raises(ValueError):

View File

@@ -139,8 +139,7 @@ def test_issue4665():
 def test_issue4674():
     """Test that setting entities with overlapping identifiers does not mess up IO"""
     nlp = English()
-    kb = KnowledgeBase(entity_vector_length=3)
-    kb.initialize(nlp.vocab)
+    kb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
     vector1 = [0.9, 1.1, 1.01]
     vector2 = [1.8, 2.25, 2.01]
     with pytest.warns(UserWarning):
@@ -156,10 +155,9 @@ def test_issue4674():
         if not dir_path.exists():
             dir_path.mkdir()
         file_path = dir_path / "kb"
-        kb.dump(str(file_path))
-        kb2 = KnowledgeBase(entity_vector_length=3)
-        kb2.initialize(nlp.vocab)
-        kb2.load_bulk(str(file_path))
+        kb.to_disk(str(file_path))
+        kb2 = KnowledgeBase(nlp.vocab, entity_vector_length=3)
+        kb2.from_disk(str(file_path))
         assert kb2.get_size_entities() == 1

View File

@@ -1,3 +1,4 @@
+from typing import Callable
 import warnings
 from unittest import TestCase
 import pytest
@@ -70,13 +71,14 @@ def entity_linker():
     nlp = Language()

     @registry.assets.register("TestIssue5230KB.v1")
-    def dummy_kb() -> KnowledgeBase:
-        kb = KnowledgeBase(entity_vector_length=1)
-        kb.initialize(nlp.vocab)
+    def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]:
+        def create_kb(vocab):
+            kb = KnowledgeBase(vocab, entity_vector_length=1)
             kb.add_entity("test", 0.0, zeros((1, 1), dtype="f"))
             return kb
+        return create_kb

-    config = {"kb": {"@assets": "TestIssue5230KB.v1"}}
+    config = {"kb_loader": {"@assets": "TestIssue5230KB.v1"}}
     entity_linker = nlp.add_pipe("entity_linker", config=config)
     # need to add model for two reasons:
     # 1. no model leads to error in serialization,
@@ -121,19 +123,17 @@ def test_writer_with_path_py35():
 def test_save_and_load_knowledge_base():
     nlp = Language()
-    kb = KnowledgeBase(entity_vector_length=1)
-    kb.initialize(nlp.vocab)
+    kb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
     with make_tempdir() as d:
         path = d / "kb"
         try:
-            kb.dump(path)
+            kb.to_disk(path)
         except Exception as e:
             pytest.fail(str(e))

         try:
-            kb_loaded = KnowledgeBase(entity_vector_length=1)
-            kb_loaded.initialize(nlp.vocab)
-            kb_loaded.load_bulk(path)
+            kb_loaded = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+            kb_loaded.from_disk(path)
         except Exception as e:
             pytest.fail(str(e))

View File

@@ -1,4 +1,8 @@
-from spacy.util import ensure_path
+from typing import Callable
+
+from spacy import util
+from spacy.lang.en import English
+from spacy.util import ensure_path, registry
 from spacy.kb import KnowledgeBase

 from ..util import make_tempdir
@@ -15,20 +19,16 @@ def test_serialize_kb_disk(en_vocab):
         if not dir_path.exists():
             dir_path.mkdir()
         file_path = dir_path / "kb"
-        kb1.dump(str(file_path))
-
-        kb2 = KnowledgeBase(entity_vector_length=3)
-        kb2.initialize(en_vocab)
-        kb2.load_bulk(str(file_path))
+        kb1.to_disk(str(file_path))
+        kb2 = KnowledgeBase(vocab=en_vocab, entity_vector_length=3)
+        kb2.from_disk(str(file_path))

     # final assertions
     _check_kb(kb2)


 def _get_dummy_kb(vocab):
-    kb = KnowledgeBase(entity_vector_length=3)
-    kb.initialize(vocab)
+    kb = KnowledgeBase(vocab, entity_vector_length=3)
     kb.add_entity(entity="Q53", freq=33, entity_vector=[0, 5, 3])
     kb.add_entity(entity="Q17", freq=2, entity_vector=[7, 1, 0])
     kb.add_entity(entity="Q007", freq=7, entity_vector=[0, 0, 7])
@ -61,7 +61,7 @@ def _check_kb(kb):
assert alias_string not in kb.get_alias_strings() assert alias_string not in kb.get_alias_strings()
# check candidates & probabilities # check candidates & probabilities
candidates = sorted(kb.get_candidates("double07"), key=lambda x: x.entity_) candidates = sorted(kb.get_alias_candidates("double07"), key=lambda x: x.entity_)
assert len(candidates) == 2 assert len(candidates) == 2
assert candidates[0].entity_ == "Q007" assert candidates[0].entity_ == "Q007"
@ -75,3 +75,47 @@ def _check_kb(kb):
assert candidates[1].entity_vector == [7, 1, 0] assert candidates[1].entity_vector == [7, 1, 0]
assert candidates[1].alias_ == "double07" assert candidates[1].alias_ == "double07"
assert 0.099 < candidates[1].prior_prob < 0.101 assert 0.099 < candidates[1].prior_prob < 0.101
def test_serialize_subclassed_kb():
    """Check that IO of a custom KB works fine as part of an EL pipe."""

    class SubKnowledgeBase(KnowledgeBase):
        def __init__(self, vocab, entity_vector_length, custom_field):
            super().__init__(vocab, entity_vector_length)
            self.custom_field = custom_field

    @registry.assets.register("spacy.CustomKB.v1")
    def custom_kb(
        entity_vector_length: int, custom_field: int
    ) -> Callable[["Vocab"], KnowledgeBase]:
        def custom_kb_factory(vocab):
            return SubKnowledgeBase(
                vocab=vocab,
                entity_vector_length=entity_vector_length,
                custom_field=custom_field,
            )

        return custom_kb_factory

    nlp = English()
    config = {
        "kb_loader": {
            "@assets": "spacy.CustomKB.v1",
            "entity_vector_length": 342,
            "custom_field": 666,
        }
    }
    entity_linker = nlp.add_pipe("entity_linker", config=config)
    assert type(entity_linker.kb) == SubKnowledgeBase
    assert entity_linker.kb.entity_vector_length == 342
    assert entity_linker.kb.custom_field == 666

    # Make sure the custom KB is serialized correctly
    with make_tempdir() as tmp_dir:
        nlp.to_disk(tmp_dir)
        nlp2 = util.load_model_from_path(tmp_dir)
        entity_linker2 = nlp2.get_pipe("entity_linker")
        assert type(entity_linker2.kb) == SubKnowledgeBase
        assert entity_linker2.kb.entity_vector_length == 342
        assert entity_linker2.kb.custom_field == 666
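For reference, the `kb_loader` override used in this test can also be expressed in a training config. The following is a minimal sketch assuming the usual v3 `[components]` layout (the `factory` key and section paths are assumptions based on that layout, and the registered `spacy.CustomKB.v1` function has to be importable, e.g. via `--code`, when the config is loaded):

```ini
[components.entity_linker]
factory = "entity_linker"

[components.entity_linker.kb_loader]
@assets = "spacy.CustomKB.v1"
entity_vector_length = 342
custom_field = 666
```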


@ -33,7 +33,7 @@ TODO: intro and how architectures work, link to
> subword_features = true > subword_features = true
> ``` > ```
Build spaCy's "standard" tok2vec layer, which uses hash embedding with subword Build spaCy's "standard" embedding layer, which uses hash embedding with subword
features and a CNN with layer-normalized maxout. features and a CNN with layer-normalized maxout.
| Name | Description | | Name | Description |
@ -45,6 +45,7 @@ features and a CNN with layer-normalized maxout.
| `maxout_pieces` | The number of pieces to use in the maxout non-linearity. If `1`, the [`Mish`](https://thinc.ai/docs/api-layers#mish) non-linearity is used instead. Recommended values are `1`-`3`. ~~int~~ | | `maxout_pieces` | The number of pieces to use in the maxout non-linearity. If `1`, the [`Mish`](https://thinc.ai/docs/api-layers#mish) non-linearity is used instead. Recommended values are `1`-`3`. ~~int~~ |
| `subword_features` | Whether to also embed subword features, specifically the prefix, suffix and word shape. This is recommended for alphabetic languages like English, but not if single-character tokens are used for a language such as Chinese. ~~bool~~ | | `subword_features` | Whether to also embed subword features, specifically the prefix, suffix and word shape. This is recommended for alphabetic languages like English, but not if single-character tokens are used for a language such as Chinese. ~~bool~~ |
| `pretrained_vectors` | Whether to also use static vectors. ~~bool~~ | | `pretrained_vectors` | Whether to also use static vectors. ~~bool~~ |
| **CREATES** | The model using the architecture. ~~Model[List[Doc], List[Floats2d]]~~ |
### spacy.Tok2Vec.v1 {#Tok2Vec} ### spacy.Tok2Vec.v1 {#Tok2Vec}
@ -67,10 +68,11 @@ Construct a tok2vec model out of embedding and encoding subnetworks. See the
["Embed, Encode, Attend, Predict"](https://explosion.ai/blog/deep-learning-formula-nlp) ["Embed, Encode, Attend, Predict"](https://explosion.ai/blog/deep-learning-formula-nlp)
blog post for background. blog post for background.
| Name | Description | | Name | Description |
| -------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | ----------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `embed` | Embed tokens into context-independent word vector representations. For example, [CharacterEmbed](/api/architectures#CharacterEmbed) or [MultiHashEmbed](/api/architectures#MultiHashEmbed). ~~Model[List[Doc], List[Floats2d]]~~ | | `embed` | Embed tokens into context-independent word vector representations. For example, [CharacterEmbed](/api/architectures#CharacterEmbed) or [MultiHashEmbed](/api/architectures#MultiHashEmbed). ~~Model[List[Doc], List[Floats2d]]~~ |
| `encode` | Encode context into the embeddings, using an architecture such as a CNN, BiLSTM or transformer. For example, [MaxoutWindowEncoder](/api/architectures#MaxoutWindowEncoder). ~~Model[List[Floats2d], List[Floats2d]]~~ | | `encode` | Encode context into the embeddings, using an architecture such as a CNN, BiLSTM or transformer. For example, [MaxoutWindowEncoder](/api/architectures#MaxoutWindowEncoder). ~~Model[List[Floats2d], List[Floats2d]]~~ |
| **CREATES** | The model using the architecture. ~~Model[List[Doc], List[Floats2d]]~~ |
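To make the composition concrete, here is a minimal sketch of a config that plugs an embedding subnetwork into an encoding subnetwork via this architecture. It is an illustration rather than a shipped default: the values are the recommended ones from the parameter tables on this page, and the embedding width is chosen to match the encoder width.

```ini
[model]
@architectures = "spacy.Tok2Vec.v1"

[model.embed]
@architectures = "spacy.MultiHashEmbed.v1"
width = 96
rows = 2000
also_embed_subwords = true
also_use_static_vectors = false

[model.encode]
@architectures = "spacy.MaxoutWindowEncoder.v1"
width = 96
depth = 4
window_size = 1
maxout_pieces = 3
```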
### spacy.Tok2VecListener.v1 {#Tok2VecListener} ### spacy.Tok2VecListener.v1 {#Tok2VecListener}
@ -108,10 +110,13 @@ Instead of defining its own `Tok2Vec` instance, a model architecture like
[Tagger](/api/architectures#tagger) can define a listener as its `tok2vec` [Tagger](/api/architectures#tagger) can define a listener as its `tok2vec`
argument that connects to the shared `tok2vec` component in the pipeline. argument that connects to the shared `tok2vec` component in the pipeline.
| Name | Description | <!-- TODO: return type -->
| ---------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `width` | The width of the vectors produced by the "upstream" [`Tok2Vec`](/api/tok2vec) component. ~~int~~ | | Name | Description |
| `upstream` | A string to identify the "upstream" `Tok2Vec` component to communicate with. The upstream name should either be the wildcard string `"*"`, or the name of the `Tok2Vec` component. You'll almost never have multiple upstream `Tok2Vec` components, so the wildcard string will almost always be fine. ~~str~~ | | ----------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `width` | The width of the vectors produced by the "upstream" [`Tok2Vec`](/api/tok2vec) component. ~~int~~ |
| `upstream` | A string to identify the "upstream" `Tok2Vec` component to communicate with. The upstream name should either be the wildcard string `"*"`, or the name of the `Tok2Vec` component. You'll almost never have multiple upstream `Tok2Vec` components, so the wildcard string will almost always be fine. ~~str~~ |
| **CREATES** | The model using the architecture. ~~Model~~ |
### spacy.MultiHashEmbed.v1 {#MultiHashEmbed} ### spacy.MultiHashEmbed.v1 {#MultiHashEmbed}
@ -134,12 +139,15 @@ definitions depending on the `Vocab` of the `Doc` object passed in. Vectors from
pretrained static vectors can also be incorporated into the concatenated pretrained static vectors can also be incorporated into the concatenated
representation. representation.
<!-- TODO: model return type -->
| Name | Description | | Name | Description |
| ------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | ------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `width` | The output width. Also used as the width of the embedding tables. Recommended values are between `64` and `300`. ~~int~~ | | `width` | The output width. Also used as the width of the embedding tables. Recommended values are between `64` and `300`. ~~int~~ |
| `rows` | The number of rows for the embedding tables. Can be low, due to the hashing trick. Embeddings for prefix, suffix and word shape use half as many rows. Recommended values are between `2000` and `10000`. ~~int~~ | | `rows` | The number of rows for the embedding tables. Can be low, due to the hashing trick. Embeddings for prefix, suffix and word shape use half as many rows. Recommended values are between `2000` and `10000`. ~~int~~ |
| `also_embed_subwords` | Whether to use the `PREFIX`, `SUFFIX` and `SHAPE` features in the embeddings. If not using these, you may need more rows in your hash embeddings, as there will be increased chance of collisions. ~~bool~~ | | `also_embed_subwords` | Whether to use the `PREFIX`, `SUFFIX` and `SHAPE` features in the embeddings. If not using these, you may need more rows in your hash embeddings, as there will be increased chance of collisions. ~~bool~~ |
| `also_use_static_vectors` | Whether to also use static word vectors. Requires a vectors table to be loaded in the [Doc](/api/doc) objects' vocab. ~~bool~~ | | `also_use_static_vectors` | Whether to also use static word vectors. Requires a vectors table to be loaded in the [Doc](/api/doc) objects' vocab. ~~bool~~ |
| **CREATES** | The model using the architecture. ~~Model~~ |
### spacy.CharacterEmbed.v1 {#CharacterEmbed} ### spacy.CharacterEmbed.v1 {#CharacterEmbed}
@ -170,12 +178,15 @@ concatenated. A hash-embedded vector of the `NORM` of the word is also
concatenated on, and the result is then passed through a feed-forward network to concatenated on, and the result is then passed through a feed-forward network to
construct a single vector to represent the information. construct a single vector to represent the information.
| Name | Description | <!-- TODO: model return type -->
| ------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `width` | The width of the output vector and the `NORM` hash embedding. ~~int~~ | | Name | Description |
| `rows` | The number of rows in the `NORM` hash embedding table. ~~int~~ | | ----------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `nM` | The dimensionality of the character embeddings. Recommended values are between `16` and `64`. ~~int~~ | | `width` | The width of the output vector and the `NORM` hash embedding. ~~int~~ |
| `nC` | The number of UTF-8 bytes to embed per word. Recommended values are between `3` and `8`, although it may depend on the length of words in the language. ~~int~~ | | `rows` | The number of rows in the `NORM` hash embedding table. ~~int~~ |
| `nM` | The dimensionality of the character embeddings. Recommended values are between `16` and `64`. ~~int~~ |
| `nC` | The number of UTF-8 bytes to embed per word. Recommended values are between `3` and `8`, although it may depend on the length of words in the language. ~~int~~ |
| **CREATES** | The model using the architecture. ~~Model~~ |
### spacy.MaxoutWindowEncoder.v1 {#MaxoutWindowEncoder} ### spacy.MaxoutWindowEncoder.v1 {#MaxoutWindowEncoder}
@ -199,6 +210,7 @@ and residual connections.
| `window_size` | The number of words to concatenate around each token to construct the convolution. Recommended value is `1`. ~~int~~ | | `window_size` | The number of words to concatenate around each token to construct the convolution. Recommended value is `1`. ~~int~~ |
| `maxout_pieces` | The number of maxout pieces to use. Recommended values are `2` or `3`. ~~int~~ | | `maxout_pieces` | The number of maxout pieces to use. Recommended values are `2` or `3`. ~~int~~ |
| `depth` | The number of convolutional layers. Recommended value is `4`. ~~int~~ | | `depth` | The number of convolutional layers. Recommended value is `4`. ~~int~~ |
| **CREATES** | The model using the architecture. ~~Model[List[Floats2d], List[Floats2d]]~~ |
### spacy.MishWindowEncoder.v1 {#MishWindowEncoder} ### spacy.MishWindowEncoder.v1 {#MishWindowEncoder}
@ -221,6 +233,7 @@ and residual connections.
| `width` | The input and output width. These are required to be the same, to allow residual connections. This value will be determined by the width of the inputs. Recommended values are between `64` and `300`. ~~int~~ | | `width` | The input and output width. These are required to be the same, to allow residual connections. This value will be determined by the width of the inputs. Recommended values are between `64` and `300`. ~~int~~ |
| `window_size` | The number of words to concatenate around each token to construct the convolution. Recommended value is `1`. ~~int~~ | | `window_size` | The number of words to concatenate around each token to construct the convolution. Recommended value is `1`. ~~int~~ |
| `depth` | The number of convolutional layers. Recommended value is `4`. ~~int~~ | | `depth` | The number of convolutional layers. Recommended value is `4`. ~~int~~ |
| **CREATES** | The model using the architecture. ~~Model[List[Floats2d], List[Floats2d]]~~ |
### spacy.TorchBiLSTMEncoder.v1 {#TorchBiLSTMEncoder} ### spacy.TorchBiLSTMEncoder.v1 {#TorchBiLSTMEncoder}
@ -242,10 +255,38 @@ Encode context using bidirectional LSTM layers. Requires
| `width` | The input and output width. These are required to be the same, to allow residual connections. This value will be determined by the width of the inputs. Recommended values are between `64` and `300`. ~~int~~ | | `width` | The input and output width. These are required to be the same, to allow residual connections. This value will be determined by the width of the inputs. Recommended values are between `64` and `300`. ~~int~~ |
| `window_size` | The number of words to concatenate around each token to construct the convolution. Recommended value is `1`. ~~int~~ | | `window_size` | The number of words to concatenate around each token to construct the convolution. Recommended value is `1`. ~~int~~ |
| `depth` | The number of convolutional layers. Recommended value is `4`. ~~int~~ | | `depth` | The number of convolutional layers. Recommended value is `4`. ~~int~~ |
| **CREATES** | The model using the architecture. ~~Model[List[Floats2d], List[Floats2d]]~~ |
### spacy.StaticVectors.v1 {#StaticVectors} ### spacy.StaticVectors.v1 {#StaticVectors}
<!-- TODO: --> > #### Example config
>
> ```ini
> [model]
> @architectures = "spacy.StaticVectors.v1"
> nO = null
> nM = null
> dropout = 0.2
> key_attr = "ORTH"
>
> [model.init_W]
> @initializers = "glorot_uniform_init.v1"
> ```
Embed [`Doc`](/api/doc) objects with their vocab's vectors table, applying a
learned linear projection to control the dimensionality. See the documentation
on [static vectors](/usage/embeddings-transformers#static-vectors) for details.
<!-- TODO: document argument descriptions -->
| Name |  Description |
| ----------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `nO` | Defaults to `None`. ~~Optional[int]~~ |
| `nM` | Defaults to `None`. ~~Optional[int]~~ |
| `dropout` | Optional dropout rate. If set, it's applied per dimension over the whole batch. Defaults to `None`. ~~Optional[float]~~ |
| `init_W` | The [initialization function](https://thinc.ai/docs/api-initializers). Defaults to [`glorot_uniform_init`](https://thinc.ai/docs/api-initializers#glorot_uniform_init). ~~Callable[[Ops, Tuple[int, ...]], FloatsXd]~~ |
| `key_attr` | Defaults to `"ORTH"`. ~~str~~ |
| **CREATES** | The model using the architecture. ~~Model[List[Doc], Ragged]~~ |
## Transformer architectures {#transformers source="github.com/explosion/spacy-transformers/blob/master/spacy_transformers/architectures.py"} ## Transformer architectures {#transformers source="github.com/explosion/spacy-transformers/blob/master/spacy_transformers/architectures.py"}
@ -277,6 +318,7 @@ architectures into your training config.
| `name` | Any model name that can be loaded by [`transformers.AutoModel`](https://huggingface.co/transformers/model_doc/auto.html#transformers.AutoModel). ~~str~~ | | `name` | Any model name that can be loaded by [`transformers.AutoModel`](https://huggingface.co/transformers/model_doc/auto.html#transformers.AutoModel). ~~str~~ |
| `get_spans` | Function that takes a batch of [`Doc`](/api/doc) objects and returns lists of [`Span`](/api) objects for the transformer to process. [See here](/api/transformer#span_getters) for built-in options and examples. ~~Callable[[List[Doc]], List[Span]]~~ | | `get_spans` | Function that takes a batch of [`Doc`](/api/doc) objects and returns lists of [`Span`](/api) objects for the transformer to process. [See here](/api/transformer#span_getters) for built-in options and examples. ~~Callable[[List[Doc]], List[Span]]~~ |
| `tokenizer_config` | Tokenizer settings passed to [`transformers.AutoTokenizer`](https://huggingface.co/transformers/model_doc/auto.html#transformers.AutoTokenizer). ~~Dict[str, Any]~~ | | `tokenizer_config` | Tokenizer settings passed to [`transformers.AutoTokenizer`](https://huggingface.co/transformers/model_doc/auto.html#transformers.AutoTokenizer). ~~Dict[str, Any]~~ |
| **CREATES** | The model using the architecture. ~~Model[List[Doc], FullTransformerBatch]~~ |
### spacy-transformers.Tok2VecListener.v1 {#transformers-Tok2VecListener} ### spacy-transformers.Tok2VecListener.v1 {#transformers-Tok2VecListener}
@ -305,6 +347,7 @@ a single token vector given zero or more wordpiece vectors.
| ------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | ------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `pooling` | A reduction layer used to calculate the token vectors based on zero or more wordpiece vectors. If in doubt, mean pooling (see [`reduce_mean`](https://thinc.ai/docs/api-layers#reduce_mean)) is usually a good choice. ~~Model[Ragged, Floats2d]~~ | | `pooling` | A reduction layer used to calculate the token vectors based on zero or more wordpiece vectors. If in doubt, mean pooling (see [`reduce_mean`](https://thinc.ai/docs/api-layers#reduce_mean)) is usually a good choice. ~~Model[Ragged, Floats2d]~~ |
| `grad_factor` | Reweight gradients from the component before passing them upstream. You can set this to `0` to "freeze" the transformer weights with respect to the component, or use it to make some components more significant than others. Leaving it at `1.0` is usually fine. ~~float~~ | | `grad_factor` | Reweight gradients from the component before passing them upstream. You can set this to `0` to "freeze" the transformer weights with respect to the component, or use it to make some components more significant than others. Leaving it at `1.0` is usually fine. ~~float~~ |
| **CREATES** | The model using the architecture. ~~Model[List[Doc], List[Floats2d]]~~ |
### spacy-transformers.Tok2VecTransformer.v1 {#Tok2VecTransformer} ### spacy-transformers.Tok2VecTransformer.v1 {#Tok2VecTransformer}
@ -330,6 +373,7 @@ one component.
| `tokenizer_config` | Tokenizer settings passed to [`transformers.AutoTokenizer`](https://huggingface.co/transformers/model_doc/auto.html#transformers.AutoTokenizer). ~~Dict[str, Any]~~ | | `tokenizer_config` | Tokenizer settings passed to [`transformers.AutoTokenizer`](https://huggingface.co/transformers/model_doc/auto.html#transformers.AutoTokenizer). ~~Dict[str, Any]~~ |
| `pooling` | A reduction layer used to calculate the token vectors based on zero or more wordpiece vectors. If in doubt, mean pooling (see [`reduce_mean`](https://thinc.ai/docs/api-layers#reduce_mean)) is usually a good choice. ~~Model[Ragged, Floats2d]~~ | | `pooling` | A reduction layer used to calculate the token vectors based on zero or more wordpiece vectors. If in doubt, mean pooling (see [`reduce_mean`](https://thinc.ai/docs/api-layers#reduce_mean)) is usually a good choice. ~~Model[Ragged, Floats2d]~~ |
| `grad_factor` | Reweight gradients from the component before passing them upstream. You can set this to `0` to "freeze" the transformer weights with respect to the component, or use it to make some components more significant than others. Leaving it at `1.0` is usually fine. ~~float~~ | | `grad_factor` | Reweight gradients from the component before passing them upstream. You can set this to `0` to "freeze" the transformer weights with respect to the component, or use it to make some components more significant than others. Leaving it at `1.0` is usually fine. ~~float~~ |
| **CREATES** | The model using the architecture. ~~Model[List[Doc], List[Floats2d]]~~ |
## Parser & NER architectures {#parser} ## Parser & NER architectures {#parser}
@ -372,6 +416,8 @@ consists of either two or three subnetworks:
state representation. If not present, the output from the lower model is used state representation. If not present, the output from the lower model is used
as action scores directly. as action scores directly.
<!-- TODO: model return type -->
| Name | Description | | Name | Description |
| ------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | ------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `tok2vec` | Subnetwork to map tokens into vector representations. ~~Model[List[Doc], List[Floats2d]]~~ | | `tok2vec` | Subnetwork to map tokens into vector representations. ~~Model[List[Doc], List[Floats2d]]~~ |
@ -380,6 +426,7 @@ consists of either two or three subnetworks:
| `maxout_pieces` | How many pieces to use in the state prediction layer. Recommended values are `1`, `2` or `3`. If `1`, the maxout non-linearity is replaced with a [`Relu`](https://thinc.ai/docs/api-layers#relu) non-linearity if `use_upper` is `True`, and no non-linearity if `False`. ~~int~~ | | `maxout_pieces` | How many pieces to use in the state prediction layer. Recommended values are `1`, `2` or `3`. If `1`, the maxout non-linearity is replaced with a [`Relu`](https://thinc.ai/docs/api-layers#relu) non-linearity if `use_upper` is `True`, and no non-linearity if `False`. ~~int~~ |
| `use_upper` | Whether to use an additional hidden layer after the state vector in order to predict the action scores. It is recommended to set this to `False` for large pretrained models such as transformers, and `True` for smaller networks. The upper layer is computed on CPU, which becomes a bottleneck on larger GPU-based models, where it's also less necessary. ~~bool~~ | | `use_upper` | Whether to use an additional hidden layer after the state vector in order to predict the action scores. It is recommended to set this to `False` for large pretrained models such as transformers, and `True` for smaller networks. The upper layer is computed on CPU, which becomes a bottleneck on larger GPU-based models, where it's also less necessary. ~~bool~~ |
| `nO` | The number of actions the model will predict between. Usually inferred from data at the beginning of training, or loaded from disk. ~~int~~ | | `nO` | The number of actions the model will predict between. Usually inferred from data at the beginning of training, or loaded from disk. ~~int~~ |
| **CREATES** | The model using the architecture. ~~Model~~ |
### spacy.BILUOTagger.v1 {#BILUOTagger source="spacy/ml/models/simple_ner.py"} ### spacy.BILUOTagger.v1 {#BILUOTagger source="spacy/ml/models/simple_ner.py"}
@ -406,9 +453,10 @@ generally results in better linear separation between classes, especially for
non-CRF models, because there are more distinct classes for the different non-CRF models, because there are more distinct classes for the different
situations ([Ratinov et al., 2009](https://www.aclweb.org/anthology/W09-1119/)). situations ([Ratinov et al., 2009](https://www.aclweb.org/anthology/W09-1119/)).
| Name | Description | | Name | Description |
| --------- | ------------------------------------------------------------------------------------------ | | ----------- | ------------------------------------------------------------------------------------------ |
| `tok2vec` | Subnetwork to map tokens into vector representations. ~~Model[List[Doc], List[Floats2d]]~~ | | `tok2vec` | Subnetwork to map tokens into vector representations. ~~Model[List[Doc], List[Floats2d]]~~ |
| **CREATES** | The model using the architecture. ~~Model[List[Doc], List[Floats2d]]~~ |
### spacy.IOBTagger.v1 {#IOBTagger source="spacy/ml/models/simple_ner.py"} ### spacy.IOBTagger.v1 {#IOBTagger source="spacy/ml/models/simple_ner.py"}
@ -431,9 +479,10 @@ spans into tags assigned to each token. The first token of a span is given the
tag B-LABEL, and subsequent tokens are given the tag I-LABEL. All other tokens tag B-LABEL, and subsequent tokens are given the tag I-LABEL. All other tokens
are assigned the tag O. are assigned the tag O.
| Name | Description | | Name | Description |
| --------- | ------------------------------------------------------------------------------------------ | | ----------- | ------------------------------------------------------------------------------------------ |
| `tok2vec` | Subnetwork to map tokens into vector representations. ~~Model[List[Doc], List[Floats2d]]~~ | | `tok2vec` | Subnetwork to map tokens into vector representations. ~~Model[List[Doc], List[Floats2d]]~~ |
| **CREATES** | The model using the architecture. ~~Model[List[Doc], List[Floats2d]]~~ |
## Tagging architectures {#tagger source="spacy/ml/models/tagger.py"} ## Tagging architectures {#tagger source="spacy/ml/models/tagger.py"}
@ -454,10 +503,11 @@ Build a tagger model, using a provided token-to-vector component. The tagger
model simply adds a linear layer with softmax activation to predict scores given model simply adds a linear layer with softmax activation to predict scores given
the token vectors. the token vectors.
| Name | Description | | Name | Description |
| --------- | ------------------------------------------------------------------------------------------ | | ----------- | ------------------------------------------------------------------------------------------ |
| `tok2vec` | Subnetwork to map tokens into vector representations. ~~Model[List[Doc], List[Floats2d]]~~ | | `tok2vec` | Subnetwork to map tokens into vector representations. ~~Model[List[Doc], List[Floats2d]]~~ |
| `nO` | The number of tags to output. Inferred from the data if `None`. ~~Optional[int]~~ | | `nO` | The number of tags to output. Inferred from the data if `None`. ~~Optional[int]~~ |
| **CREATES** | The model using the architecture. ~~Model[List[Doc], List[Floats2d]]~~ |
## Text classification architectures {#textcat source="spacy/ml/models/textcat.py"} ## Text classification architectures {#textcat source="spacy/ml/models/textcat.py"}
@ -474,9 +524,6 @@ specific data and challenge.
### spacy.TextCatEnsemble.v1 {#TextCatEnsemble} ### spacy.TextCatEnsemble.v1 {#TextCatEnsemble}
Stacked ensemble of a bag-of-words model and a neural network model. The neural
network has an internal CNN Tok2Vec layer and uses attention.
> #### Example Config > #### Example Config
> >
> ```ini > ```ini
@ -493,17 +540,23 @@ network has an internal CNN Tok2Vec layer and uses attention.
> nO = null > nO = null
> ``` > ```
| Name | Description | Stacked ensemble of a bag-of-words model and a neural network model. The neural
| -------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | network has an internal CNN Tok2Vec layer and uses attention.
| `exclusive_classes` | Whether or not categories are mutually exclusive. ~~bool~~ |
| `pretrained_vectors` | Whether or not pretrained vectors will be used in addition to the feature vectors. ~~bool~~ | <!-- TODO: model return type -->
| `width` | Output dimension of the feature encoding step. ~~int~~ |
| `embed_size` | Input dimension of the feature encoding step. ~~int~~ | | Name | Description |
| `conv_depth` | Depth of the tok2vec layer. ~~int~~ | | -------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `window_size` | The number of contextual vectors to [concatenate](https://thinc.ai/docs/api-layers#expand_window) from the left and from the right. ~~int~~ | | `exclusive_classes` | Whether or not categories are mutually exclusive. ~~bool~~ |
| `ngram_size` | Determines the maximum length of the n-grams in the BOW model. For instance, `ngram_size=3`would give unigram, trigram and bigram features. ~~int~~ | | `pretrained_vectors` | Whether or not pretrained vectors will be used in addition to the feature vectors. ~~bool~~ |
| `dropout` | The dropout rate. ~~float~~ | | `width` | Output dimension of the feature encoding step. ~~int~~ |
| `embed_size` | Input dimension of the feature encoding step. ~~int~~ |
| `conv_depth` | Depth of the tok2vec layer. ~~int~~ |
| `window_size` | The number of contextual vectors to [concatenate](https://thinc.ai/docs/api-layers#expand_window) from the left and from the right. ~~int~~ |
| `ngram_size` | Determines the maximum length of the n-grams in the BOW model. For instance, `ngram_size=3` would give unigram, bigram and trigram features. ~~int~~ |
| `dropout` | The dropout rate. ~~float~~ |
| `nO` | Output dimension, determined by the number of different labels. If not set, the [`TextCategorizer`](/api/textcategorizer) component will set it when `begin_training` is called. ~~Optional[int]~~ | | `nO` | Output dimension, determined by the number of different labels. If not set, the [`TextCategorizer`](/api/textcategorizer) component will set it when `begin_training` is called. ~~Optional[int]~~ |
| **CREATES** | The model using the architecture. ~~Model~~ |
### spacy.TextCatCNN.v1 {#TextCatCNN} ### spacy.TextCatCNN.v1 {#TextCatCNN}
@ -530,11 +583,14 @@ A neural network model where token vectors are calculated using a CNN. The
vectors are mean pooled and used as features in a feed-forward network. This vectors are mean pooled and used as features in a feed-forward network. This
architecture is usually less accurate than the ensemble, but runs faster. architecture is usually less accurate than the ensemble, but runs faster.
| Name | Description | <!-- TODO: model return type -->
| ------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `exclusive_classes` | Whether or not categories are mutually exclusive. ~~bool~~ | | Name | Description |
| `tok2vec` | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~ | | ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `exclusive_classes` | Whether or not categories are mutually exclusive. ~~bool~~ |
| `tok2vec` | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~ |
| `nO` | Output dimension, determined by the number of different labels. If not set, the [`TextCategorizer`](/api/textcategorizer) component will set it when `begin_training` is called. ~~Optional[int]~~ | | `nO` | Output dimension, determined by the number of different labels. If not set, the [`TextCategorizer`](/api/textcategorizer) component will set it when `begin_training` is called. ~~Optional[int]~~ |
| **CREATES** | The model using the architecture. ~~Model~~ |
### spacy.TextCatBOW.v1 {#TextCatBOW} ### spacy.TextCatBOW.v1 {#TextCatBOW}
@ -552,12 +608,15 @@ architecture is usually less accurate than the ensemble, but runs faster.
An ngram "bag-of-words" model. This architecture should run much faster than the An ngram "bag-of-words" model. This architecture should run much faster than the
others, but may not be as accurate, especially if texts are short. others, but may not be as accurate, especially if texts are short.
| Name | Description | <!-- TODO: model return type -->
| ------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `exclusive_classes` | Whether or not categories are mutually exclusive. ~~bool~~ | | Name | Description |
| `ngram_size` | Determines the maximum length of the n-grams in the BOW model. For instance, `ngram_size=3`would give unigram, trigram and bigram features. ~~int~~ | | ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `no_output_layer` | Whether or not to add an output layer to the model (`Softmax` activation if `exclusive_classes` is `True`, else `Logistic`. ~~bool~~ | | `ngram_size` | Determines the maximum length of the n-grams in the BOW model. For instance, `ngram_size=3` would give unigram, bigram and trigram features. ~~int~~ |
| `ngram_size` | Determines the maximum length of the n-grams in the BOW model. For instance, `ngram_size=3`would give unigram, trigram and bigram features. ~~int~~ |
| `no_output_layer` | Whether or not to add an output layer to the model (`Softmax` activation if `exclusive_classes` is `True`, else `Logistic`. ~~bool~~ |
| `nO` | Output dimension, determined by the number of different labels. If not set, the [`TextCategorizer`](/api/textcategorizer) component will set it when `begin_training` is called. ~~Optional[int]~~ | | `nO` | Output dimension, determined by the number of different labels. If not set, the [`TextCategorizer`](/api/textcategorizer) component will set it when `begin_training` is called. ~~Optional[int]~~ |
| **CREATES** | The model using the architecture. ~~Model~~ |
## Entity linking architectures {#entitylinker source="spacy/ml/models/entity_linker.py"} ## Entity linking architectures {#entitylinker source="spacy/ml/models/entity_linker.py"}
@ -574,9 +633,6 @@ into the "real world". This requires 3 main component
### spacy.EntityLinker.v1 {#EntityLinker} ### spacy.EntityLinker.v1 {#EntityLinker}
The `EntityLinker` model architecture is a Thinc `Model` with a
[`Linear`](https://thinc.ai/api-layers#linear) output layer.
> #### Example Config > #### Example Config
> >
> ```ini > ```ini
@ -602,10 +658,16 @@ The `EntityLinker` model architecture is a Thinc `Model` with a
> @assets = "spacy.CandidateGenerator.v1" > @assets = "spacy.CandidateGenerator.v1"
> ``` > ```
| Name | Description | The `EntityLinker` model architecture is a Thinc `Model` with a
| --------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | [`Linear`](https://thinc.ai/api-layers#linear) output layer.
| `tok2vec` | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~ |
| `nO` | Output dimension, determined by the length of the vectors encoding each entity in the KB. If the `nO` dimension is not set, the entity linking component will set it when `begin_training` is called. ~~Optional[int]~~ | <!-- TODO: model return type -->
| Name | Description |
| ----------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `tok2vec` | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~ |
| `nO` | Output dimension, determined by the length of the vectors encoding each entity in the KB. If the `nO` dimension is not set, the entity linking component will set it when `begin_training` is called. ~~Optional[int]~~ |
| **CREATES** | The model using the architecture. ~~Model~~ |
### spacy.EmptyKB.v1 {#EmptyKB} ### spacy.EmptyKB.v1 {#EmptyKB}


@ -719,11 +719,11 @@ $ python -m spacy evaluate [model] [data_path] [--output] [--gold-preproc]
Generate an installable Generate an installable
[model Python package](/usage/training#models-generating) from an existing model [model Python package](/usage/training#models-generating) from an existing model
data directory. All data files are copied over. If the path to a `meta.json` is data directory. All data files are copied over. If the path to a
supplied, or a `meta.json` is found in the input directory, this file is used. [`meta.json`](/api/data-formats#meta) is supplied, or a `meta.json` is found in
Otherwise, the data can be entered directly from the command line. spaCy will the input directory, this file is used. Otherwise, the data can be entered
then create a `.tar.gz` archive file that you can distribute and install with directly from the command line. spaCy will then create a `.tar.gz` archive file
`pip install`. that you can distribute and install with `pip install`.
<Infobox title="New in v3.0" variant="warning"> <Infobox title="New in v3.0" variant="warning">
@ -750,7 +750,7 @@ $ python -m spacy package [input_dir] [output_dir] [--meta-path] [--create-meta]
| ------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | ------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `input_dir` | Path to directory containing model data. ~~Path (positional)~~ | | `input_dir` | Path to directory containing model data. ~~Path (positional)~~ |
| `output_dir` | Directory to create package folder in. ~~Path (positional)~~ | | `output_dir` | Directory to create package folder in. ~~Path (positional)~~ |
| `--meta-path`, `-m` <Tag variant="new">2</Tag> | Path to `meta.json` file (optional). ~~Optional[Path] \(option)~~ | | `--meta-path`, `-m` <Tag variant="new">2</Tag> | Path to [`meta.json`](/api/data-formats#meta) file (optional). ~~Optional[Path] \(option)~~ |
| `--create-meta`, `-C` <Tag variant="new">2</Tag> | Create a `meta.json` file on the command line, even if one already exists in the directory. If an existing file is found, its entries will be shown as the defaults in the command line prompt. ~~bool (flag)~~ | | `--create-meta`, `-C` <Tag variant="new">2</Tag> | Create a `meta.json` file on the command line, even if one already exists in the directory. If an existing file is found, its entries will be shown as the defaults in the command line prompt. ~~bool (flag)~~ |
| `--no-sdist`, `-NS` | Don't build the `.tar.gz` sdist automatically. Can be set if you want to run this step manually. ~~bool (flag)~~ | | `--no-sdist`, `-NS` | Don't build the `.tar.gz` sdist automatically. Can be set if you want to run this step manually. ~~bool (flag)~~ |
| `--version`, `-v` <Tag variant="new">3</Tag> | Package version to override in meta. Useful when training new versions, as it doesn't require editing the meta template. ~~Optional[str] \(option)~~ | | `--version`, `-v` <Tag variant="new">3</Tag> | Package version to override in meta. Useful when training new versions, as it doesn't require editing the meta template. ~~Optional[str] \(option)~~ |


@ -6,6 +6,7 @@ menu:
- ['Training Data', 'training'] - ['Training Data', 'training']
- ['Pretraining Data', 'pretraining'] - ['Pretraining Data', 'pretraining']
- ['Vocabulary', 'vocab'] - ['Vocabulary', 'vocab']
- ['Model Meta', 'meta']
--- ---
This section documents input and output formats of data used by spaCy, including This section documents input and output formats of data used by spaCy, including
@ -73,15 +74,15 @@ your config and check that it's valid, you can run the
Defines the `nlp` object, its tokenizer and Defines the `nlp` object, its tokenizer and
[processing pipeline](/usage/processing-pipelines) component names. [processing pipeline](/usage/processing-pipelines) component names.
| Name | Description | Default | | Name | Description |
| ------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------- | | ------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `lang` | The language code to use. ~~str~~ | `null` | | `lang` | Model language [ISO code](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes). Defaults to `null`. ~~str~~ |
| `pipeline` | Names of pipeline components in order. Should correspond to sections in the `[components]` block, e.g. `[components.ner]`. See docs on [defining components](/usage/training#config-components). ~~List[str]~~ | `[]` | | `pipeline` | Names of pipeline components in order. Should correspond to sections in the `[components]` block, e.g. `[components.ner]`. See docs on [defining components](/usage/training#config-components). Defaults to `[]`. ~~List[str]~~ |
| `load_vocab_data` | Whether to load additional lexeme and vocab data from [`spacy-lookups-data`](https://github.com/explosion/spacy-lookups-data) if available. ~~bool~~ | `true` | | `load_vocab_data` | Whether to load additional lexeme and vocab data from [`spacy-lookups-data`](https://github.com/explosion/spacy-lookups-data) if available. Defaults to `true`. ~~bool~~ |
| `before_creation` | Optional [callback](/usage/training#custom-code-nlp-callbacks) to modify `Language` subclass before it's initialized. ~~Optional[Callable[[Type[Language]], Type[Language]]]~~ | `null` | | `before_creation` | Optional [callback](/usage/training#custom-code-nlp-callbacks) to modify `Language` subclass before it's initialized. Defaults to `null`. ~~Optional[Callable[[Type[Language]], Type[Language]]]~~ |
| `after_creation` | Optional [callback](/usage/training#custom-code-nlp-callbacks) to modify `nlp` object right after it's initialized. ~~Optional[Callable[[Language], Language]]~~ | `null` | | `after_creation` | Optional [callback](/usage/training#custom-code-nlp-callbacks) to modify `nlp` object right after it's initialized. Defaults to `null`. ~~Optional[Callable[[Language], Language]]~~ |
| `after_pipeline_creation` | Optional [callback](/usage/training#custom-code-nlp-callbacks) to modify `nlp` object after the pipeline components have been added. ~~Optional[Callable[[Language], Language]]~~ | `null` | | `after_pipeline_creation` | Optional [callback](/usage/training#custom-code-nlp-callbacks) to modify `nlp` object after the pipeline components have been added. Defaults to `null`. ~~Optional[Callable[[Language], Language]]~~ |
| `tokenizer` | The tokenizer to use. ~~Callable[[str], Doc]~~ | [`Tokenizer`](/api/tokenizer) | | `tokenizer` | The tokenizer to use. Defaults to [`Tokenizer`](/api/tokenizer). ~~Callable[[str], Doc]~~ |
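As a quick illustration, a minimal `[nlp]` block using these settings could look like the sketch below. The values are illustrative rather than the defaults listed above, and the tokenizer would normally be configured in its own `[nlp.tokenizer]` sub-block:

```ini
[nlp]
lang = "en"
pipeline = ["tok2vec", "ner"]
load_vocab_data = true
before_creation = null
after_creation = null
after_pipeline_creation = null
```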
### components {#config-components tag="section"} ### components {#config-components tag="section"}
@ -128,24 +129,24 @@ process that are used when you run [`spacy train`](/api/cli#train).
<!-- TODO: complete --> <!-- TODO: complete -->
| Name | Description | Default | | Name | Description |
| --------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------- | | --------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `seed` | The random seed. ~~int~~ | `${system:seed}` | | `seed` | The random seed. Defaults to variable `${system:seed}`. ~~int~~ |
| `dropout` | The dropout rate. ~~float~~ | `0.1` | | `dropout` | The dropout rate. Defaults to `0.1`. ~~float~~ |
| `accumulate_gradient` | Whether to divide the batch up into substeps. ~~int~~ | `1` | | `accumulate_gradient` | Whether to divide the batch up into substeps. Defaults to `1`. ~~int~~ |
| `init_tok2vec` | Optional path to pretrained tok2vec weights created with [`spacy pretrain`](/api/cli#pretrain). ~~Optional[str]~~ | `${paths:init_tok2vec}` | | `init_tok2vec` | Optional path to pretrained tok2vec weights created with [`spacy pretrain`](/api/cli#pretrain). Defaults to variable `${paths:init_tok2vec}`. ~~Optional[str]~~ |
| `raw_text` | ~~Optional[str]~~ | `${paths:raw}` | | `raw_text` | TODO: ... Defaults to variable `${paths:raw}`. ~~Optional[str]~~ |
| `vectors` | ~~Optional[str]~~ | `null` | | `vectors` | Model name or path to model containing pretrained word vectors to use, e.g. created with [`init model`](/api/cli#init-model). Defaults to `null`. ~~Optional[str]~~ |
| `patience` | How many steps to continue without improvement in evaluation score. ~~int~~ | `1600` | | `patience` | How many steps to continue without improvement in evaluation score. Defaults to `1600`. ~~int~~ |
| `max_epochs` | Maximum number of epochs to train for. ~~int~~ | `0` | | `max_epochs` | Maximum number of epochs to train for. Defaults to `0`. ~~int~~ |
| `max_steps` | Maximum number of update steps to train for. ~~int~~ | `20000` | | `max_steps` | Maximum number of update steps to train for. Defaults to `20000`. ~~int~~ |
| `eval_frequency` | How often to evaluate during training (steps). ~~int~~ | `200` | | `eval_frequency` | How often to evaluate during training (steps). Defaults to `200`. ~~int~~ |
| `score_weights` | Score names shown in metrics mapped to their weight towards the final weighted score. See [here](/usage/training#metrics) for details. ~~Dict[str, float]~~ | `{}` | | `score_weights` | Score names shown in metrics mapped to their weight towards the final weighted score. See [here](/usage/training#metrics) for details. Defaults to `{}`. ~~Dict[str, float]~~ |
| `frozen_components` | Pipeline component names that are "frozen" and shouldn't be updated during training. See [here](/usage/training#config-components) for details. ~~List[str]~~ | `[]` | | `frozen_components` | Pipeline component names that are "frozen" and shouldn't be updated during training. See [here](/usage/training#config-components) for details. Defaults to `[]`. ~~List[str]~~ |
| `train_corpus` | Callable that takes the current `nlp` object and yields [`Example`](/api/example) objects. ~~Callable[[Language], Iterator[Example]]~~ | [`Corpus`](/api/corpus) | | `train_corpus` | Callable that takes the current `nlp` object and yields [`Example`](/api/example) objects. Defaults to [`Corpus`](/api/corpus). ~~Callable[[Language], Iterator[Example]]~~ |
| `dev_corpus` | Callable that takes the current `nlp` object and yields [`Example`](/api/example) objects. ~~Callable[[Language], Iterator[Example]]~~ | [`Corpus`](/api/corpus) | | `dev_corpus` | Callable that takes the current `nlp` object and yields [`Example`](/api/example) objects. Defaults to [`Corpus`](/api/corpus). ~~Callable[[Language], Iterator[Example]]~~ |
| `batcher` | Callable that takes an iterator of [`Doc`](/api/doc) objects and yields batches of `Doc`s. ~~Callable[[Iterator[Doc]], Iterator[List[Doc]]]~~ | [`batch_by_words`](/api/top-level#batch_by_words) | | `batcher` | Callable that takes an iterator of [`Doc`](/api/doc) objects and yields batches of `Doc`s. Defaults to [`batch_by_words`](/api/top-level#batch_by_words). ~~Callable[[Iterator[Doc]], Iterator[List[Doc]]]~~ |
| `optimizer` | The optimizer. The learning rate schedule and other settings can be configured as part of the optimizer. ~~Optimizer~~ | [`Adam`](https://thinc.ai/docs/api-optimizers#adam) | | `optimizer` | The optimizer. The learning rate schedule and other settings can be configured as part of the optimizer. Defaults to [`Adam`](https://thinc.ai/docs/api-optimizers#adam). ~~Optimizer~~ |
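For orientation, a pared-down `[training]` block that spells out these settings might look like the following sketch. The values mirror the defaults in the table above, the corpus readers and batcher blocks are omitted, and the optimizer block assumes Thinc's registered `Adam.v1` optimizer:

```ini
[training]
seed = ${system:seed}
dropout = 0.1
accumulate_gradient = 1
init_tok2vec = ${paths:init_tok2vec}
raw_text = ${paths:raw}
vectors = null
patience = 1600
max_epochs = 0
max_steps = 20000
eval_frequency = 200
frozen_components = []

[training.optimizer]
@optimizers = "Adam.v1"
```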
### pretraining {#config-pretraining tag="section,optional"} ### pretraining {#config-pretraining tag="section,optional"}
@ -153,19 +154,19 @@ This section is optional and defines settings and controls for
[language model pretraining](/usage/training#pretraining). It's used when you [language model pretraining](/usage/training#pretraining). It's used when you
run [`spacy pretrain`](/api/cli#pretrain). run [`spacy pretrain`](/api/cli#pretrain).
| Name | Description | Default | | Name | Description |
| ---------------------------- | ----------------------------------------------------------------------------------------------------------- | --------------------------------------------------- | | ---------------------------- | ------------------------------------------------------------------------------------------------------------------------------- |
| `max_epochs` | Maximum number of epochs. ~~int~~ | `1000` | | `max_epochs` | Maximum number of epochs. Defaults to `1000`. ~~int~~ |
| `min_length` | Minimum length of examples. ~~int~~ | `5` | | `min_length` | Minimum length of examples. Defaults to `5`. ~~int~~ |
| `max_length` | Maximum length of examples. ~~int~~ | `500` | | `max_length` | Maximum length of examples. Defaults to `500`. ~~int~~ |
| `dropout` | The dropout rate. ~~float~~ | `0.2` | | `dropout` | The dropout rate. Defaults to `0.2`. ~~float~~ |
| `n_save_every` | Saving frequency. ~~int~~ | `null` | | `n_save_every` | Saving frequency. Defaults to `null`. ~~Optional[int]~~ |
| `batch_size` | The batch size or batch size [schedule](https://thinc.ai/docs/api-schedules). ~~Union[int, Sequence[int]]~~ | `3000` | | `batch_size` | The batch size or batch size [schedule](https://thinc.ai/docs/api-schedules). Defaults to `3000`. ~~Union[int, Sequence[int]]~~ |
| `seed` | The random seed. ~~int~~ | `${system:seed}` | | `seed` | The random seed. Defaults to variable `${system:seed}`. ~~int~~ |
| `use_pytorch_for_gpu_memory` | Allocate memory via PyTorch. ~~bool~~ | `${system:use_pytorch_for_gpu_memory}` | | `use_pytorch_for_gpu_memory` | Allocate memory via PyTorch. Defaults to variable `${system:use_pytorch_for_gpu_memory}`. ~~bool~~ |
| `tok2vec_model` | tok2vec model section in the config. ~~str~~ | `"components.tok2vec.model"` | | `tok2vec_model` | The model section of the embedding component in the config. Defaults to `"components.tok2vec.model"`. ~~str~~ |
| `objective` | The pretraining objective. ~~Dict[str, Any]~~ | `{"type": "characters", "n_characters": 4}` | | `objective` | The pretraining objective. Defaults to `{"type": "characters", "n_characters": 4}`. ~~Dict[str, Any]~~ |
| `optimizer` | The optimizer. ~~Optimizer~~ | [`Adam`](https://thinc.ai/docs/api-optimizers#adam) | | `optimizer` | The optimizer. Defaults to [`Adam`](https://thinc.ai/docs/api-optimizers#adam). ~~Optimizer~~ |
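Similarly, here is a sketch of a `[pretraining]` block with the table's defaults written out. Expressing `objective` as a sub-block and using Thinc's registered `Adam.v1` optimizer are assumptions made for the sake of the example:

```ini
[pretraining]
max_epochs = 1000
min_length = 5
max_length = 500
dropout = 0.2
n_save_every = null
batch_size = 3000
seed = ${system:seed}
use_pytorch_for_gpu_memory = ${system:use_pytorch_for_gpu_memory}
tok2vec_model = "components.tok2vec.model"

[pretraining.objective]
type = "characters"
n_characters = 4

[pretraining.optimizer]
@optimizers = "Adam.v1"
```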
## Training data {#training} ## Training data {#training}
@ -372,11 +373,11 @@ example = Example.from_dict(doc, gold_dict)
## Pretraining data {#pretraining} ## Pretraining data {#pretraining}
The [`spacy pretrain`](/api/cli#pretrain) command lets you pretrain the tok2vec The [`spacy pretrain`](/api/cli#pretrain) command lets you pretrain the
layer of pipeline components from raw text. Raw text can be provided as a "token-to-vector" embedding layer of pipeline components from raw text. Raw text
`.jsonl` (newline-delimited JSON) file containing one input text per line can be provided as a `.jsonl` (newline-delimited JSON) file containing one input
(roughly paragraph length is good). Optionally, custom tokenization can be text per line (roughly paragraph length is good). Optionally, custom
provided. tokenization can be provided.
> #### Tip: Writing JSONL > #### Tip: Writing JSONL
> >
@ -457,3 +458,75 @@ Here's an example of the 20 most frequent lexemes in the English training data:
```json ```json
https://github.com/explosion/spaCy/tree/master/examples/training/vocab-data.jsonl https://github.com/explosion/spaCy/tree/master/examples/training/vocab-data.jsonl
``` ```
## Model meta {#meta}
The model meta is available as the file `meta.json` and exported automatically
when you save an `nlp` object to disk. Its contents are available as
[`nlp.meta`](/api/language#meta).
<Infobox variant="warning" title="Changed in v3.0">
As of spaCy v3.0, the `meta.json` **isn't** used to construct the language class
and pipeline anymore and only contains meta information for reference and for
creating a Python package with [`spacy package`](/api/cli#package). How to set
up the `nlp` object is now defined in the
[`config.cfg`](/api/data-formats#config), which includes detailed information
about the pipeline components and their model architectures, and all other
settings and hyperparameters used to train the model. It's the **single source
of truth** used for loading a model.
</Infobox>
> #### Example
>
> ```json
> {
> "name": "example_model",
> "lang": "en",
> "version": "1.0.0",
> "spacy_version": ">=3.0.0,<3.1.0",
> "parent_package": "spacy",
> "description": "Example model for spaCy",
> "author": "You",
> "email": "you@example.com",
> "url": "https://example.com",
> "license": "CC BY-SA 3.0",
> "sources": [{ "name": "My Corpus", "license": "MIT" }],
> "vectors": { "width": 0, "vectors": 0, "keys": 0, "name": null },
> "pipeline": ["tok2vec", "ner", "textcat"],
> "labels": {
> "ner": ["PERSON", "ORG", "PRODUCT"],
> "textcat": ["POSITIVE", "NEGATIVE"]
> },
> "accuracy": {
> "ents_f": 82.7300930714,
> "ents_p": 82.135523614,
> "ents_r": 83.3333333333,
> "textcat_score": 88.364323811
> },
> "speed": { "cpu": 7667.8, "gpu": null, "nwords": 10329 },
> "spacy_git_version": "61dfdd9fb"
> }
> ```
| Name | Description |
| ---------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `lang` | Model language [ISO code](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes). Defaults to `"en"`. ~~str~~ |
| `name` | Model name, e.g. `"core_web_sm"`. The final model package name will be `{lang}_{name}`. Defaults to `"model"`. ~~str~~ |
| `version` | Model version. Will be used to version a Python package created with [`spacy package`](/api/cli#package). Defaults to `"0.0.0"`. ~~str~~ |
| `spacy_version` | spaCy version range the model is compatible with. Defaults to spaCy version used to create the model, up to next minor version, which is the default compatibility for the available [pretrained models](/models). For instance, a model trained with v3.0.0 will have the version range `">=3.0.0,<3.1.0"`. ~~str~~ |
| `parent_package` | Name of the spaCy package. Typically `"spacy"` or `"spacy_nightly"`. Defaults to `"spacy"`. ~~str~~ |
| `description` | Model description. Also used for Python package. Defaults to `""`. ~~str~~ |
| `author` | Model author name. Also used for Python package. Defaults to `""`. ~~str~~ |
| `email` | Model author email. Also used for Python package. Defaults to `""`. ~~str~~ |
| `url` | Model author URL. Also used for Python package. Defaults to `""`. ~~str~~ |
| `license` | Model license. Also used for Python package. Defaults to `""`. ~~str~~ |
| `sources` | Data sources used to train the model. Typically a list of dicts with the keys `"name"`, `"url"`, `"author"` and `"license"`. [See here](https://github.com/explosion/spacy-models/tree/master/meta) for examples. Defaults to `None`. ~~Optional[List[Dict[str, str]]]~~ |
| `vectors` | Information about the word vectors included with the model. Typically a dict with the keys `"width"`, `"vectors"` (number of vectors), `"keys"` and `"name"`. ~~Dict[str, Any]~~ |
| `pipeline` | Names of pipeline components in the model, in order. Corresponds to [`nlp.pipe_names`](/api/language#pipe_names). Only exists for reference and is not used to create the components. This information is defined in the [`config.cfg`](/api/data-formats#config). Defaults to `[]`. ~~List[str]~~ |
| `labels` | Label schemes of the trained pipeline components, keyed by component name. Corresponds to [`nlp.pipe_labels`](/api/language#pipe_labels). [See here](https://github.com/explosion/spacy-models/tree/master/meta) for examples. Defaults to `{}`. ~~Dict[str, Dict[str, List[str]]]~~ |
| `accuracy` | Training accuracy, added automatically by [`spacy train`](/api/cli#train). Dictionary of [score names](/usage/training#metrics) mapped to scores. Defaults to `{}`. ~~Dict[str, Union[float, Dict[str, float]]]~~ |
| `speed` | Model speed, added automatically by [`spacy train`](/api/cli#train). Typically a dictionary with the keys `"cpu"`, `"gpu"` and `"nwords"` (words per second). Defaults to `{}`. ~~Dict[str, Optional[Union[float, str]]]~~ |
| `spacy_git_version` <Tag variant="new">3</Tag> | Git commit of [`spacy`](https://github.com/explosion/spaCy) used to create the model. ~~str~~ |
| other | Any other custom meta information you want to add. The data is preserved in [`nlp.meta`](/api/language#meta). ~~Any~~ |
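For reference, here's a minimal sketch of how this meta data surfaces at runtime via [`nlp.meta`](/api/language#meta). The pipeline name and the custom `"note"` entry below are illustrative; any installed model package will do:

```python
import spacy

nlp = spacy.load("en_core_web_sm")  # any installed model package
print(nlp.meta["lang"], nlp.meta["name"], nlp.meta["version"])
print(nlp.meta["pipeline"])  # mirrors nlp.pipe_names
# Custom entries are preserved in nlp.meta and written to meta.json on save
nlp.meta["note"] = "internal build"
nlp.to_disk("/path/to/model")
```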

View File

@ -200,21 +200,21 @@ probability of the fact that the mention links to the entity ID.
| `alias` | The textual mention or alias. ~~str~~ | | `alias` | The textual mention or alias. ~~str~~ |
| **RETURNS** | The prior probability of the `alias` referring to the `entity`. ~~float~~ | | **RETURNS** | The prior probability of the `alias` referring to the `entity`. ~~float~~ |
## KnowledgeBase.dump {#dump tag="method"} ## KnowledgeBase.to_disk {#to_disk tag="method"}
Save the current state of the knowledge base to a directory. Save the current state of the knowledge base to a directory.
> #### Example > #### Example
> >
> ```python > ```python
> kb.dump(loc) > kb.to_disk(loc)
> ``` > ```
| Name | Description | | Name | Description |
| ----- | ------------------------------------------------------------------------------------------------------------------------------------------ | | ----- | ------------------------------------------------------------------------------------------------------------------------------------------ |
| `loc` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. ~~Union[str, Path]~~ | | `loc` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. ~~Union[str, Path]~~ |
## KnowledgeBase.load_bulk {#load_bulk tag="method"} ## KnowledgeBase.from_disk {#from_disk tag="method"}
Restore the state of the knowledge base from a given directory. Note that the Restore the state of the knowledge base from a given directory. Note that the
[`Vocab`](/api/vocab) should also be the same as the one used to create the KB. [`Vocab`](/api/vocab) should also be the same as the one used to create the KB.
@ -226,7 +226,7 @@ Restore the state of the knowledge base from a given directory. Note that the
> from spacy.vocab import Vocab > from spacy.vocab import Vocab
> vocab = Vocab().from_disk("/path/to/vocab") > vocab = Vocab().from_disk("/path/to/vocab")
> kb = KnowledgeBase(vocab=vocab, entity_vector_length=64) > kb = KnowledgeBase(vocab=vocab, entity_vector_length=64)
> kb.load_bulk("/path/to/kb") > kb.from_disk("/path/to/kb")
> ``` > ```
| Name | Description | | Name | Description |

View File

@ -742,7 +742,7 @@ token.ent_iob, token.ent_type
Custom meta data for the Language class. If a model is loaded, contains meta Custom meta data for the Language class. If a model is loaded, contains meta
data of the model. The `Language.meta` is also what's serialized as the data of the model. The `Language.meta` is also what's serialized as the
`meta.json` when you save an `nlp` object to disk. [`meta.json`](/api/data-formats#meta) when you save an `nlp` object to disk.
> #### Example > #### Example
> >
@ -954,12 +954,12 @@ serialization by passing in the string names via the `exclude` argument.
> nlp.from_disk("./model-data", exclude=["ner"]) > nlp.from_disk("./model-data", exclude=["ner"])
> ``` > ```
| Name | Description | | Name | Description |
| ----------- | -------------------------------------------------- | | ----------- | ------------------------------------------------------------------ |
| `vocab` | The shared [`Vocab`](/api/vocab). | | `vocab` | The shared [`Vocab`](/api/vocab). |
| `tokenizer` | Tokenization rules and exceptions. | | `tokenizer` | Tokenization rules and exceptions. |
| `meta` | The meta data, available as `Language.meta`. | | `meta` | The meta data, available as [`Language.meta`](/api/language#meta). |
| ... | String names of pipeline components, e.g. `"ner"`. | | ... | String names of pipeline components, e.g. `"ner"`. |
## FactoryMeta {#factorymeta new="3" tag="dataclass"} ## FactoryMeta {#factorymeta new="3" tag="dataclass"}

View File

@ -15,7 +15,7 @@ multiple components, e.g. to have one embedding and CNN network shared between a
[`EntityRecognizer`](/api/entityrecognizer). [`EntityRecognizer`](/api/entityrecognizer).
In order to use the `Tok2Vec` predictions, subsequent components should use the In order to use the `Tok2Vec` predictions, subsequent components should use the
[Tok2VecListener](/api/architectures#Tok2VecListener) layer as the tok2vec [Tok2VecListener](/api/architectures#Tok2VecListener) layer as the `tok2vec`
subnetwork of their model. This layer will read data from the `doc.tensor` subnetwork of their model. This layer will read data from the `doc.tensor`
attribute during prediction. During training, the `Tok2Vec` component will save attribute during prediction. During training, the `Tok2Vec` component will save
its prediction and backprop callback for each batch, so that the subsequent its prediction and backprop callback for each batch, so that the subsequent

View File

@ -18,9 +18,10 @@ Load a model using the name of an installed
`Path`-like object. spaCy will try resolving the load argument in this order. If `Path`-like object. spaCy will try resolving the load argument in this order. If
a model is loaded from a model name, spaCy will assume it's a Python package and a model is loaded from a model name, spaCy will assume it's a Python package and
import it and call the model's own `load()` method. If a model is loaded from a import it and call the model's own `load()` method. If a model is loaded from a
path, spaCy will assume it's a data directory, read the language and pipeline path, spaCy will assume it's a data directory, load its
settings off the meta.json and initialize the `Language` class. The data will be [`config.cfg`](/api/data-formats#config) and use the language and pipeline
loaded in via [`Language.from_disk`](/api/language#from_disk). information to construct the `Language` class. The data will be loaded in via
[`Language.from_disk`](/api/language#from_disk).
> #### Example > #### Example
> >
@ -40,9 +41,10 @@ loaded in via [`Language.from_disk`](/api/language#from_disk).
| `config` <Tag variant="new">3</Tag> | Optional config overrides, either as nested dict or dict keyed by section value in dot notation, e.g. `"components.name.value"`. ~~Union[Dict[str, Any], Config]~~ | | `config` <Tag variant="new">3</Tag> | Optional config overrides, either as nested dict or dict keyed by section value in dot notation, e.g. `"components.name.value"`. ~~Union[Dict[str, Any], Config]~~ |
| **RETURNS** | A `Language` object with the loaded model. ~~Language~~ | | **RETURNS** | A `Language` object with the loaded model. ~~Language~~ |
Essentially, `spacy.load()` is a convenience wrapper that reads the language ID Essentially, `spacy.load()` is a convenience wrapper that reads the model's
and pipeline components from a model's `meta.json`, initializes the `Language` [`config.cfg`](/api/data-formats#config), uses the language and pipeline
class, loads in the model data and returns it. information to construct a `Language` object, loads in the model data and
returns it.
```python ```python
### Abstract example ### Abstract example
@ -543,8 +545,8 @@ loaded lazily, to avoid expensive setup code associated with the language data.
Load a model from a package or data path. If called with a package name, spaCy Load a model from a package or data path. If called with a package name, spaCy
will assume the model is a Python package and import and call its `load()` will assume the model is a Python package and import and call its `load()`
method. If called with a path, spaCy will assume it's a data directory, read the method. If called with a path, spaCy will assume it's a data directory, read the
language and pipeline settings from the meta.json and initialize a `Language` language and pipeline settings from the [`config.cfg`](/api/data-formats#config)
class. The model data will then be loaded in via and create a `Language` object. The model data will then be loaded in via
[`Language.from_disk`](/api/language#from_disk). [`Language.from_disk`](/api/language#from_disk).
> #### Example > #### Example
@ -607,7 +609,8 @@ components are created, as well as all training settings and hyperparameters.
### util.load_meta {#util.load_meta tag="function" new="3"} ### util.load_meta {#util.load_meta tag="function" new="3"}
Get a model's `meta.json` from a file path and validate its contents. Get a model's [`meta.json`](/api/data-formats#meta) from a file path and
validate its contents.
> #### Example > #### Example
> >

[Image file added (17 KiB); diff suppressed]

[Image file added (18 KiB); diff suppressed]

View File

@ -9,11 +9,7 @@ menu:
next: /usage/training next: /usage/training
--- ---
<!-- TODO: intro, short explanation of embeddings/transformers, point user to processing pipelines docs for intro --> <!-- TODO: intro, short explanation of embeddings/transformers, Tok2Vec and Transformer components, point user to processing pipelines docs for more general info that user should know first -->
## Shared embedding layers {#embedding-layers}
<!-- TODO: write: `Tok2Vec` and `Transformer` components -->
<Accordion title="Whats the difference between word vectors and language models?" id="vectors-vs-language-models"> <Accordion title="Whats the difference between word vectors and language models?" id="vectors-vs-language-models">
@ -55,6 +51,22 @@ of performance.
</Accordion> </Accordion>
## Shared embedding layers {#embedding-layers}
<!-- TODO: write -->
![Pipeline components using a shared embedding component vs. independent embedding layers](../images/tok2vec.svg)
| Shared | Independent |
| ------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------- |
| ✅ **smaller:** models only need to include a single copy of the embeddings | ❌ **larger:** models need to include the embeddings for each component |
| ✅ **faster:** the embeddings only need to be computed once per document | ❌ **slower:** each component re-embeds the text for itself |
| ❌ **less composable:** all components require the same embedding component in the pipeline | ✅ **modular:** components can be moved and swapped freely |
![Pipeline components listening to shared embedding component](../images/tok2vec-listener.svg)
<!-- TODO: explain the listener concept, how it works etc. -->
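As a rough illustration of how this looks in practice: the shared embedding layer is itself a pipeline component (e.g. `tok2vec` or `transformer`) that runs first, and the components after it connect to its output through their listener layers. A small sketch, assuming a pipeline that was trained with a shared `tok2vec` component:

```python
import spacy

nlp = spacy.load("en_core_web_sm")  # assumes a pipeline built with a shared tok2vec
print(nlp.pipe_names)
# e.g. ['tok2vec', 'tagger', 'parser', 'ner'] - the tok2vec component runs first
# and the downstream components reuse its output instead of re-embedding the text
```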
## Using transformer models {#transformers} ## Using transformer models {#transformers}
Transformers are a family of neural network architectures that compute **dense, Transformers are a family of neural network architectures that compute **dense,
@ -295,18 +307,6 @@ is, a Thinc model that takes a list of [`Doc`](/api/doc) objects, and returns a
[`FullTransformerBatch`](/api/transformer#fulltransformerbatch) object with the [`FullTransformerBatch`](/api/transformer#fulltransformerbatch) object with the
transformer data. transformer data.
> #### Model type annotations
>
> In the documentation and code base, you may come across type annotations and
> descriptions of [Thinc](https://thinc.ai) model types, like ~~Model[List[Doc],
> List[Floats2d]]~~. This so-called generic type describes the layer and its
> input and output type: in this case, it takes a list of `Doc` objects as the
> input and a list of 2-dimensional arrays of floats as the output. You can read
> more about defining Thinc models [here](https://thinc.ai/docs/usage-models).
> Also see the docs on [type checking](https://thinc.ai/docs/usage-type-checking)
> for how to enable linting in your editor and get live feedback when your inputs
> and outputs don't match.
The same idea applies to task models that power the **downstream components**. The same idea applies to task models that power the **downstream components**.
Most of spaCy's built-in model creation functions support a `tok2vec` argument, Most of spaCy's built-in model creation functions support a `tok2vec` argument,
which should be a Thinc layer of type ~~Model[List[Doc], List[Floats2d]]~~. This which should be a Thinc layer of type ~~Model[List[Doc], List[Floats2d]]~~. This

View File

@ -70,8 +70,7 @@ import Languages from 'widgets/languages.js'
> nlp = MultiLanguage() > nlp = MultiLanguage()
> >
> # With lazy-loading > # With lazy-loading
> from spacy.util import get_lang_class > nlp = spacy.blank("xx")
> nlp = get_lang_class('xx')
> ``` > ```
spaCy also supports models trained on more than one language. This is especially spaCy also supports models trained on more than one language. This is especially
@ -80,10 +79,10 @@ language-neutral models is `xx`. The language class, a generic subclass
containing only the base language data, can be found in containing only the base language data, can be found in
[`lang/xx`](https://github.com/explosion/spaCy/tree/master/spacy/lang/xx). [`lang/xx`](https://github.com/explosion/spaCy/tree/master/spacy/lang/xx).
To load your model with the neutral, multi-language class, simply set To train a model using the neutral multi-language class, you can set
`"language": "xx"` in your [model package](/usage/training#models-generating)'s `lang = "xx"` in your [training config](/usage/training#config). You can also
`meta.json`. You can also import the class directly, or call import the `MultiLanguage` class directly, or call
[`util.get_lang_class()`](/api/top-level#util.get_lang_class) for lazy-loading. [`spacy.blank("xx")`](/api/top-level#spacy.blank) for lazy-loading.
### Chinese language support {#chinese new=2.3} ### Chinese language support {#chinese new=2.3}
@ -308,12 +307,14 @@ model data.
```yaml ```yaml
### Directory structure {highlight="7"} ### Directory structure {highlight="7"}
└── en_core_web_md-3.0.0.tar.gz # downloaded archive └── en_core_web_md-3.0.0.tar.gz # downloaded archive
├── meta.json # model meta data
├── setup.py # setup file for pip installation ├── setup.py # setup file for pip installation
├── meta.json # copy of model meta
└── en_core_web_md # 📦 model package └── en_core_web_md # 📦 model package
├── __init__.py # init for pip installation ├── __init__.py # init for pip installation
├── meta.json # model meta data
└── en_core_web_md-3.0.0 # model data └── en_core_web_md-3.0.0 # model data
├── config.cfg # model config
├── meta.json # model meta
└── ... # directories with component data
``` ```
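If you've extracted the archive manually, you can also point [`spacy.load`](/api/top-level#spacy.load) directly at the directory that contains the `config.cfg`, i.e. the innermost data directory in the structure above. The path here is illustrative:

```python
import spacy

# Load from the extracted data directory (contains config.cfg and meta.json)
nlp = spacy.load("./en_core_web_md/en_core_web_md-3.0.0")
```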
You can place the **model package directory** anywhere on your local file You can place the **model package directory** anywhere on your local file

View File

@ -232,7 +232,7 @@ available pipeline components and component functions.
| `morphologizer` | [`Morphologizer`](/api/morphologizer) | Assign morphological features and coarse-grained POS tags. | | `morphologizer` | [`Morphologizer`](/api/morphologizer) | Assign morphological features and coarse-grained POS tags. |
| `senter` | [`SentenceRecognizer`](/api/sentencerecognizer) | Assign sentence boundaries. | | `senter` | [`SentenceRecognizer`](/api/sentencerecognizer) | Assign sentence boundaries. |
| `sentencizer` | [`Sentencizer`](/api/sentencizer) | Add rule-based sentence segmentation without the dependency parse. | | `sentencizer` | [`Sentencizer`](/api/sentencizer) | Add rule-based sentence segmentation without the dependency parse. |
| `tok2vec` | [`Tok2Vec`](/api/tok2vec) | | | `tok2vec` | [`Tok2Vec`](/api/tok2vec) | Assign token-to-vector embeddings. |
| `transformer` | [`Transformer`](/api/transformer) | Assign the tokens and outputs of a transformer model. | | `transformer` | [`Transformer`](/api/transformer) | Assign the tokens and outputs of a transformer model. |
### Disabling and modifying pipeline components {#disabling} ### Disabling and modifying pipeline components {#disabling}

View File

@ -1096,11 +1096,12 @@ ruler.add_patterns([{"label": "ORG", "pattern": "Apple"}])
nlp.to_disk("/path/to/model") nlp.to_disk("/path/to/model")
``` ```
The saved model now includes the `"entity_ruler"` in its `"pipeline"` setting in The saved model now includes the `"entity_ruler"` in its
the `meta.json`, and the model directory contains a file `entityruler.jsonl` [`config.cfg`](/api/data-formats#config) and the model directory contains a file
with the patterns. When you load the model back in, all pipeline components will `entityruler.jsonl` with the patterns. When you load the model back in, all
be restored and deserialized including the entity ruler. This lets you ship pipeline components will be restored and deserialized including the entity
powerful model packages with binary weights _and_ rules included! ruler. This lets you ship powerful model packages with binary weights _and_
rules included!
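A quick sketch of loading the saved model back in, using the path from the snippet above: the restored `entity_ruler` comes back with its patterns, so the rule-based matches work straight away.

```python
import spacy

nlp = spacy.load("/path/to/model")  # model saved above
print(nlp.pipe_names)  # includes 'entity_ruler'
doc = nlp("Apple is opening a new office.")
print([(ent.text, ent.label_) for ent in doc.ents])  # e.g. [('Apple', 'ORG')]
```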
### Using a large number of phrase patterns {#entityruler-large-phrase-patterns new="2.2.4"} ### Using a large number of phrase patterns {#entityruler-large-phrase-patterns new="2.2.4"}

View File

@ -569,9 +569,32 @@ back later. You can do this with the
nlp.to_disk('/home/me/data/en_example_model') nlp.to_disk('/home/me/data/en_example_model')
``` ```
The directory will be created if it doesn't exist, and the whole pipeline will The directory will be created if it doesn't exist, and the whole pipeline data,
be written out. To make the model more convenient to deploy, we recommend model meta and model configuration will be written out. To make the model more
wrapping it as a Python package. convenient to deploy, we recommend wrapping it as a
[Python package](/api/cli#package).
<Accordion title="Whats the difference between the config.cfg and meta.json?" spaced id="models-meta-vs-config">
When you save a model in spaCy v3.0+, two files will be exported: a
[`config.cfg`](/api/data-formats#config) based on
[`nlp.config`](/api/language#config) and a [`meta.json`](/api/data-formats#meta)
based on [`nlp.meta`](/api/language#meta).
- **config**: Configuration used to create the current `nlp` object, its
pipeline components and models, as well as training settings and
hyperparameters. Can include references to registered functions like
[pipeline components](/usage/processing-pipelines#custom-components) or
[model architectures](/api/architectures). Given a config, spaCy is able to
reconstruct the whole tree of objects and the `nlp` object. An exported config
can also be used to [train a model](/usage/training#config) with the same
settings.
- **meta**: Meta information about the model and the Python package, such as the
author information, license, version, data sources and label scheme. This is
mostly used for documentation purposes and for packaging models. It has no
impact on the functionality of the `nlp` object.
</Accordion>
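Both files are also exposed on the `nlp` object at runtime, which is an easy way to see the difference for yourself. A small sketch (the loaded pipeline is illustrative):

```python
import spacy

nlp = spacy.load("en_core_web_sm")
print(nlp.config["nlp"]["pipeline"])   # how the pipeline is constructed
print(list(nlp.config["components"]))  # per-component settings and architectures
print(nlp.meta["author"], nlp.meta["license"])  # documentation and packaging info
```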
### Generating a model package {#models-generating} ### Generating a model package {#models-generating}
@ -623,6 +646,9 @@ model package that can be installed using `pip install`.
├── en_example_model # model directory ├── en_example_model # model directory
│ ├── __init__.py # init for pip installation │ ├── __init__.py # init for pip installation
│ └── en_example_model-1.0.0 # model data │ └── en_example_model-1.0.0 # model data
│ ├── config.cfg # model config
│ ├── meta.json # model meta
│ └── ... # directories with component data
└── dist └── dist
└── en_example_model-1.0.0.tar.gz # installable package └── en_example_model-1.0.0.tar.gz # installable package
``` ```
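After building the package, you can install the archive from the `dist` directory with pip and load the model by name like any other package. A minimal sketch using the example package name from above:

```python
import spacy

# Assumes the package was installed first, e.g.:
# pip install dist/en_example_model-1.0.0.tar.gz
nlp = spacy.load("en_example_model")
doc = nlp("This is a sentence.")
```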
@ -644,13 +670,25 @@ you can also **ship the code with your model** and include it in the
[pipeline components](/usage/processing-pipelines#custom-components) before the [pipeline components](/usage/processing-pipelines#custom-components) before the
`nlp` object is created. `nlp` object is created.
<Infobox variant="warning" title="Important note on making manual edits">
While it's no problem to edit the package code or meta information, avoid making
edits to the `config.cfg` **after** training, as this can easily lead to data
incompatibility. For instance, changing an architecture or hyperparameter can
mean that the trained weights are now incompatible. If you want to make
adjustments, you can do so before training. Otherwise, you should always trust
spaCy to export the current state of its `nlp` objects via
[`nlp.config`](/api/language#config).
</Infobox>
### Loading a custom model package {#loading} ### Loading a custom model package {#loading}
To load a model from a data directory, you can use To load a model from a data directory, you can use
[`spacy.load()`](/api/top-level#spacy.load) with the local path. This will look [`spacy.load()`](/api/top-level#spacy.load) with the local path. This will look
for a meta.json in the directory and use the `lang` and `pipeline` settings to for a `config.cfg` in the directory and use the `lang` and `pipeline` settings
initialize a `Language` class with a processing pipeline and load in the model to initialize a `Language` class with a processing pipeline and load in the
data. model data.
```python ```python
nlp = spacy.load("/path/to/model") nlp = spacy.load("/path/to/model")

View File

@ -384,7 +384,40 @@ that reference this variable.
### Model architectures {#model-architectures} ### Model architectures {#model-architectures}
<!-- TODO: refer to architectures API: /api/architectures --> > #### 💡 Model type annotations
>
> In the documentation and code base, you may come across type annotations and
> descriptions of [Thinc](https://thinc.ai) model types, like ~~Model[List[Doc],
> List[Floats2d]]~~. This so-called generic type describes the layer and its
> input and output type: in this case, it takes a list of `Doc` objects as the
> input and a list of 2-dimensional arrays of floats as the output. You can read
> more about defining Thinc models [here](https://thinc.ai/docs/usage-models).
> Also see the docs on [type checking](https://thinc.ai/docs/usage-type-checking)
> for how to enable linting in your editor and get live feedback when your inputs
> and outputs don't match.
A **model architecture** is a function that wires up a Thinc
[`Model`](https://thinc.ai/docs/api-model) instance, which you can then use in a
component or as a layer of a larger network. You can use Thinc as a thin
[wrapper around frameworks](https://thinc.ai/docs/usage-frameworks) such as
PyTorch, TensorFlow or MXNet, or you can implement your logic in Thinc
[directly](https://thinc.ai/docs/usage-models).
spaCy's built-in components will never construct their `Model` instances
themselves, so you won't have to subclass the component to change its model
architecture. You can just **update the config** so that it refers to a
different registered function. Once the component has been created, its `Model`
instance has already been assigned, so you cannot change its model architecture.
The architecture is like a recipe for the network, and you can't change the
recipe once the dish has already been prepared. You have to make a new one.
spaCy includes a variety of built-in [architectures](/api/architectures) for
different tasks. For example:
<!-- TODO: -->
| Architecture | Description |
| ----------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [HashEmbedCNN](/api/architectures#HashEmbedCNN) | Build spaCy's “standard” embedding layer, which uses hash embedding with subword features and a CNN with layer-normalized maxout. ~~Model[List[Doc], List[Floats2d]]~~ |
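Under the hood, each architecture is just a registered function, and the string in the config is its name in spaCy's [`registry`](/api/top-level#registry). A small sketch of looking one up from Python; the exact version suffix depends on the spaCy version you have installed:

```python
import spacy

# Resolve the registered architecture function by the name used in config.cfg
build_hash_embed_cnn = spacy.registry.architectures.get("spacy.HashEmbedCNN.v1")
print(build_hash_embed_cnn)  # the function the config system resolves and calls
```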
### Metrics, training output and weighted scores {#metrics} ### Metrics, training output and weighted scores {#metrics}
@ -433,14 +466,12 @@ components are weighted equally.
| Name | Description | | Name | Description |
| -------------------------- | ----------------------------------------------------------------------------------------------------------------------- | | -------------------------- | ----------------------------------------------------------------------------------------------------------------------- |
| **Loss** | The training loss representing the amount of work left for the optimizer. Should decrease, but usually not to `0`. | | **Loss** | The training loss representing the amount of work left for the optimizer. Should decrease, but usually not to `0`. |
| **Precision** (P) | The percentage of generated predictions that are correct. Should increase. | | **Precision** (P) | Percentage of predicted annotations that were correct. Should increase. |
| **Recall** (R) | The percentage of gold-standard annotations that are in fact predicted. Should increase. | | **Recall** (R) | Percentage of reference annotations recovered. Should increase. |
| **F-Score** (F) | The weighted average of precision and recall. Should increase. | | **F-Score** (F) | Harmonic mean of precision and recall. Should increase. |
| **UAS** / **LAS** | Unlabeled and labeled attachment score for the dependency parser, i.e. the percentage of correct arcs. Should increase. | | **UAS** / **LAS** | Unlabeled and labeled attachment score for the dependency parser, i.e. the percentage of correct arcs. Should increase. |
| **Words per second** (WPS) | Prediction speed in words per second. Should stay stable. | | **Words per second** (WPS) | Prediction speed in words per second. Should stay stable. |
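To make the relationship between these scores concrete, here's a tiny worked example with made-up counts; it's purely illustrative and not how the [`Scorer`](/api/scorer) is implemented:

```python
# Made-up counts: 80 predicted entities, 75 of them correct, 90 entities in the reference
tp, n_pred, n_gold = 75, 80, 90
precision = tp / n_pred                                   # 0.9375
recall = tp / n_gold                                      # ~0.833
f_score = 2 * precision * recall / (precision + recall)   # ~0.882
print(f"P={precision:.3f} R={recall:.3f} F={f_score:.3f}")
```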
<!-- TODO: is this still relevant? -->
Note that if the development data has raw text, some of the gold-standard Note that if the development data has raw text, some of the gold-standard
entities might not align to the predicted tokenization. These tokenization entities might not align to the predicted tokenization. These tokenization
errors are **excluded from the NER evaluation**. If your tokenization makes it errors are **excluded from the NER evaluation**. If your tokenization makes it

View File

@ -1,305 +0,0 @@
---
title: Transformers
teaser: Using transformer models like BERT in spaCy
menu:
- ['Installation', 'install']
- ['Runtime Usage', 'runtime']
- ['Training Usage', 'training']
next: /usage/training
---
## Installation {#install hidden="true"}
Transformers are a family of neural network architectures that compute **dense,
context-sensitive representations** for the tokens in your documents. Downstream
models in your pipeline can then use these representations as input features to
**improve their predictions**. You can connect multiple components to a single
transformer model, with any or all of those components giving feedback to the
transformer to fine-tune it to your tasks. spaCy's transformer support
interoperates with [PyTorch](https://pytorch.org) and the
[HuggingFace `transformers`](https://huggingface.co/transformers/) library,
giving you access to thousands of pretrained models for your pipelines. There
are many [great guides](http://jalammar.github.io/illustrated-transformer/) to
transformer models, but for practical purposes, you can simply think of them as
a drop-in replacement that lets you achieve **higher accuracy** in exchange for
**higher training and runtime costs**.
### System requirements
We recommend an NVIDIA GPU with at least 10GB of memory in order to work with
transformer models. The exact requirements will depend on the transformer model
you choose and whether you're training the pipeline or simply running it.
Training a transformer-based model without a GPU will be too slow for most
practical purposes. You'll also need to make sure your GPU drivers are
up-to-date and v9+ of the CUDA runtime is installed.
Once you have CUDA installed, you'll need to install two pip packages,
[`cupy`](https://docs.cupy.dev/en/stable/install.html) and
[`spacy-transformers`](https://github.com/explosion/spacy-transformers). `cupy`
is just like `numpy`, but for GPU. The best way to install it is to choose a
wheel that matches the version of CUDA you're using. You may also need to set
the `CUDA_PATH` environment variable if your CUDA runtime is installed in a
non-standard location. Putting it all together, if you had installed CUDA 10.2
in `/opt/nvidia/cuda`, you would run:
```bash
### Installation with CUDA
export CUDA_PATH="/opt/nvidia/cuda"
pip install cupy-cuda102
pip install spacy-transformers
```
Provisioning a new machine will require about 5GB of data to be downloaded in
total: 3GB for the CUDA runtime, 800MB for PyTorch, 400MB for CuPy, 500MB for
the transformer weights, and about 200MB for spaCy and its various requirements.
## Runtime usage {#runtime}
Transformer models can be used as **drop-in replacements** for other types of
neural networks, so your spaCy pipeline can include them in a way that's
completely invisible to the user. Users will download, load and use the model in
the standard way, like any other spaCy pipeline. Instead of using the
transformers as subnetworks directly, you can also use them via the
[`Transformer`](/api/transformer) pipeline component.
![The processing pipeline with the transformer component](../images/pipeline_transformer.svg)
The `Transformer` component sets the
[`Doc._.trf_data`](/api/transformer#custom_attributes) extension attribute,
which lets you access the transformer's outputs at runtime.
```bash
$ python -m spacy download en_core_trf_lg
```
```python
### Example
import spacy
from thinc.api import use_pytorch_for_gpu_memory, require_gpu
# Use the GPU, with memory allocations directed via PyTorch.
# This prevents out-of-memory errors that would otherwise occur from competing
# memory pools.
use_pytorch_for_gpu_memory()
require_gpu(0)
nlp = spacy.load("en_core_trf_lg")
for doc in nlp.pipe(["some text", "some other text"]):
tokvecs = doc._.trf_data.tensors[-1]
```
You can also customize how the [`Transformer`](/api/transformer) component sets
annotations onto the [`Doc`](/api/doc), by customizing the `annotation_setter`.
This callback will be called with the raw input and output data for the whole
batch, along with the batch of `Doc` objects, allowing you to implement whatever
you need. The annotation setter is called with a batch of [`Doc`](/api/doc)
objects and a [`FullTransformerBatch`](/api/transformer#fulltransformerbatch)
containing the transformers data for the batch.
```python
def custom_annotation_setter(docs, trf_data):
# TODO:
...
nlp = spacy.load("en_core_trf_lg")
nlp.get_pipe("transformer").annotation_setter = custom_annotation_setter
doc = nlp("This is a text")
print() # TODO:
```
## Training usage {#training}
The recommended workflow for training is to use spaCy's
[config system](/usage/training#config), usually via the
[`spacy train`](/api/cli#train) command. The training config defines all
component settings and hyperparameters in one place and lets you describe a tree
of objects by referring to creation functions, including functions you register
yourself. For details on how to get started with training your own model, check
out the [training quickstart](/usage/training#quickstart).
<Project id="en_core_bert">
The easiest way to get started is to clone a transformers-based project
template. Swap in your data, edit the settings and hyperparameters and train,
evaluate, package and visualize your model.
</Project>
The `[components]` section in the [`config.cfg`](/api/data-formats#config)
describes the pipeline components and the settings used to construct them,
including their model implementation. Here's a config snippet for the
[`Transformer`](/api/transformer) component, along with matching Python code. In
this case, the `[components.transformer]` block describes the `transformer`
component:
> #### Python equivalent
>
> ```python
> from spacy_transformers import Transformer, TransformerModel
> from spacy_transformers.annotation_setters import null_annotation_setter
> from spacy_transformers.span_getters import get_doc_spans
>
> trf = Transformer(
> nlp.vocab,
> TransformerModel(
> "bert-base-cased",
> get_spans=get_doc_spans,
> tokenizer_config={"use_fast": True},
> ),
> annotation_setter=null_annotation_setter,
> max_batch_items=4096,
> )
> ```
```ini
### config.cfg (excerpt)
[components.transformer]
factory = "transformer"
max_batch_items = 4096
[components.transformer.model]
@architectures = "spacy-transformers.TransformerModel.v1"
name = "bert-base-cased"
tokenizer_config = {"use_fast": true}
[components.transformer.model.get_spans]
@span_getters = "doc_spans.v1"
[components.transformer.annotation_setter]
@annotation_setters = "spacy-transformers.null_annotation_setter.v1"
```
The `[components.transformer.model]` block describes the `model` argument passed
to the transformer component. It's a Thinc
[`Model`](https://thinc.ai/docs/api-model) object that will be passed into the
component. Here, it references the function
[spacy-transformers.TransformerModel.v1](/api/architectures#TransformerModel)
registered in the [`architectures` registry](/api/top-level#registry). If a key
in a block starts with `@`, it's **resolved to a function** and all other
settings are passed to the function as arguments. In this case, `name`,
`tokenizer_config` and `get_spans`.
`get_spans` is a function that takes a batch of `Doc` objects and returns lists
of potentially overlapping `Span` objects for the transformer to process. Several
[built-in functions](/api/transformer#span-getters) are available, for example
to process the whole document or individual sentences. When the config is
resolved, the function is created and passed into the model as an argument.
<Infobox variant="warning">
Remember that the `config.cfg` used for training should contain **no missing
values** and requires all settings to be defined. You don't want any hidden
defaults creeping in and changing your results! spaCy will tell you if settings
are missing, and you can run
[`spacy init fill-config`](/api/cli#init-fill-config) to automatically fill in
all defaults.
</Infobox>
### Customizing the settings {#training-custom-settings}
To change any of the settings, you can edit the `config.cfg` and re-run the
training. To change any of the functions, like the span getter, you can replace
the name of the referenced function, e.g. `@span_getters = "sent_spans.v1"` to
process sentences. You can also register your own functions using the
`span_getters` registry:
> #### config.cfg
>
> ```ini
> [components.transformer.model.get_spans]
> @span_getters = "custom_sent_spans"
> ```
```python
### code.py
import spacy_transformers
@spacy_transformers.registry.span_getters("custom_sent_spans")
def configure_custom_sent_spans():
# TODO: write custom example
def get_sent_spans(docs):
return [list(doc.sents) for doc in docs]
return get_sent_spans
```
To resolve the config during training, spaCy needs to know about your custom
function. You can make it available via the `--code` argument that can point to
a Python file. For more details on training with custom code, see the
[training documentation](/usage/training#custom-code).
```bash
$ python -m spacy train ./config.cfg --code ./code.py
```
### Customizing the model implementations {#training-custom-model}
The [`Transformer`](/api/transformer) component expects a Thinc
[`Model`](https://thinc.ai/docs/api-model) object to be passed in as its `model`
argument. You're not limited to the implementation provided by
`spacy-transformers` the only requirement is that your registered function
must return an object of type ~~Model[List[Doc], FullTransformerBatch]~~: that
is, a Thinc model that takes a list of [`Doc`](/api/doc) objects, and returns a
[`FullTransformerBatch`](/api/transformer#fulltransformerbatch) object with the
transformer data.
> #### Model type annotations
>
> In the documentation and code base, you may come across type annotations and
> descriptions of [Thinc](https://thinc.ai) model types, like ~~Model[List[Doc],
> List[Floats2d]]~~. This so-called generic type describes the layer and its
> input and output type: in this case, it takes a list of `Doc` objects as the
> input and a list of 2-dimensional arrays of floats as the output. You can read
> more about defining Thinc models [here](https://thinc.ai/docs/usage-models).
> Also see the docs on [type checking](https://thinc.ai/docs/usage-type-checking)
> for how to enable linting in your editor and get live feedback when your inputs
> and outputs don't match.
The same idea applies to task models that power the **downstream components**.
Most of spaCy's built-in model creation functions support a `tok2vec` argument,
which should be a Thinc layer of type `Model[List[Doc], List[Floats2d]]`. This
is where we'll plug in our transformer model, using the
[Tok2VecListener](/api/architectures#Tok2VecListener) layer, which sneakily
delegates to the `Transformer` pipeline component.
```ini
### config.cfg (excerpt) {highlight="12"}
[components.ner]
factory = "ner"
[components.ner.model]
@architectures = "spacy.TransitionBasedParser.v1"
nr_feature_tokens = 3
hidden_width = 128
maxout_pieces = 3
use_upper = false
[components.ner.model.tok2vec]
@architectures = "spacy-transformers.Tok2VecListener.v1"
grad_factor = 1.0
[components.ner.model.tok2vec.pooling]
@layers = "reduce_mean.v1"
```
The [Tok2VecListener](/api/architectures#Tok2VecListener) layer expects a
[pooling layer](https://thinc.ai/docs/api-layers#reduction-ops) as the argument
`pooling`, which needs to be of type `Model[Ragged, Floats2d]`. This layer
determines how the vector for each spaCy token will be computed from the zero or
more source rows the token is aligned against. Here we use the
[`reduce_mean`](https://thinc.ai/docs/api-layers#reduce_mean) layer, which
averages the wordpiece rows. We could instead use
[`reduce_max`](https://thinc.ai/docs/api-layers#reduce_max), or a custom
function you write yourself.
You can have multiple components all listening to the same transformer model,
and all passing gradients back to it. By default, all of the gradients will be
**equally weighted**. You can control this with the `grad_factor` setting, which
lets you reweight the gradients from the different listeners. For instance,
setting `grad_factor = 0` would disable gradients from one of the listeners,
while `grad_factor = 2.0` would multiply them by 2. This is similar to having a
custom learning rate for each component. Instead of a constant, you can also
provide a schedule, allowing you to freeze the shared parameters at the start of
training.

View File

@ -152,6 +152,7 @@ The following methods, attributes and commands are new in spaCy v3.0.
| [`Language.config`](/api/language#config) | The [config](/usage/training#config) used to create the current `nlp` object. An instance of [`Config`](https://thinc.ai/docs/api-config#config) and can be saved to disk and used for training. | | [`Language.config`](/api/language#config) | The [config](/usage/training#config) used to create the current `nlp` object. An instance of [`Config`](https://thinc.ai/docs/api-config#config) and can be saved to disk and used for training. |
| [`Pipe.score`](/api/pipe#score) | Method on trainable pipeline components that returns a dictionary of evaluation scores. | | [`Pipe.score`](/api/pipe#score) | Method on trainable pipeline components that returns a dictionary of evaluation scores. |
| [`registry`](/api/top-level#registry) | Function registry to map functions to string names that can be referenced in [configs](/usage/training#config). | | [`registry`](/api/top-level#registry) | Function registry to map functions to string names that can be referenced in [configs](/usage/training#config). |
| [`util.load_meta`](/api/top-level#util.load_meta) [`util.load_config`](/api/top-level#util.load_config) | Updated helpers for loading a model's [`meta.json`](/api/data-formats#meta) and [`config.cfg`](/api/data-formats#config). |
| [`init config`](/api/cli#init-config) [`init fill-config`](/api/cli#init-fill-config) [`debug config`](/api/cli#debug-config) | CLI commands for initializing, auto-filling and debugging [training configs](/usage/training). | | [`init config`](/api/cli#init-config) [`init fill-config`](/api/cli#init-fill-config) [`debug config`](/api/cli#debug-config) | CLI commands for initializing, auto-filling and debugging [training configs](/usage/training). |
| [`project`](/api/cli#project) | Suite of CLI commands for cloning, running and managing [spaCy projects](/usage/projects). | | [`project`](/api/cli#project) | Suite of CLI commands for cloning, running and managing [spaCy projects](/usage/projects). |
@ -175,6 +176,11 @@ Note that spaCy v3.0 now requires **Python 3.6+**.
There can be many [different models](/models) and not just one "English There can be many [different models](/models) and not just one "English
model", so you should always use the full model name like model", so you should always use the full model name like
[`en_core_web_sm`](/models/en) explicitly. [`en_core_web_sm`](/models/en) explicitly.
- A model's [`meta.json`](/api/data-formats#meta) is now only used to provide
meta information like the model name, author, license and labels. It's **not**
used to construct the processing pipeline anymore. This is all defined in the
[`config.cfg`](/api/data-formats#config), which also includes all settings
used to train the model.
- The [`train`](/api/cli#train) and [`pretrain`](/api/cli#pretrain) commands now - The [`train`](/api/cli#train) and [`pretrain`](/api/cli#pretrain) commands now
only take a `config.cfg` file containing the full only take a `config.cfg` file containing the full
[training config](/usage/training#config). [training config](/usage/training#config).

View File

@ -32,6 +32,7 @@
"Floats2d": "https://thinc.ai/docs/api-types#types", "Floats2d": "https://thinc.ai/docs/api-types#types",
"Floats3d": "https://thinc.ai/docs/api-types#types", "Floats3d": "https://thinc.ai/docs/api-types#types",
"FloatsXd": "https://thinc.ai/docs/api-types#types", "FloatsXd": "https://thinc.ai/docs/api-types#types",
"Ops": "https://thinc.ai/docs/api-backends#ops",
"cymem.Pool": "https://github.com/explosion/cymem", "cymem.Pool": "https://github.com/explosion/cymem",
"preshed.BloomFilter": "https://github.com/explosion/preshed", "preshed.BloomFilter": "https://github.com/explosion/preshed",
"transformers.BatchEncoding": "https://huggingface.co/transformers/main_classes/tokenizer.html#transformers.BatchEncoding", "transformers.BatchEncoding": "https://huggingface.co/transformers/main_classes/tokenizer.html#transformers.BatchEncoding",

View File

@ -53,7 +53,15 @@ const icons = {
package: PackageIcon, package: PackageIcon,
} }
export default function Icon({ name, width = 20, height, inline = false, variant, className }) { export default function Icon({
name,
width = 20,
height,
inline = false,
variant,
className,
...props
}) {
const IconComponent = icons[name] const IconComponent = icons[name]
const iconClassNames = classNames(classes.root, className, { const iconClassNames = classNames(classes.root, className, {
[classes.inline]: inline, [classes.inline]: inline,
@ -67,6 +75,7 @@ export default function Icon({ name, width = 20, height, inline = false, variant
aria-hidden="true" aria-hidden="true"
width={width} width={width}
height={height || width} height={height || width}
{...props}
/> />
) )
} }

View File

@ -9,19 +9,25 @@ function isNum(children) {
return isString(children) && /^\d+[.,]?[\dx]+?(|x|ms|mb|gb|k|m)?$/i.test(children) return isString(children) && /^\d+[.,]?[\dx]+?(|x|ms|mb|gb|k|m)?$/i.test(children)
} }
function getCellContent(children) { function getCellContent(cellChildren) {
const icons = { const icons = {
'✅': { name: 'yes', variant: 'success' }, '✅': { name: 'yes', variant: 'success', 'aria-label': 'positive' },
'❌': { name: 'no', variant: 'error' }, '❌': { name: 'no', variant: 'error', 'aria-label': 'negative' },
} }
let children = isString(cellChildren) ? [cellChildren] : cellChildren
if (isString(children) && icons[children.trim()]) { if (Array.isArray(children)) {
const iconProps = icons[children.trim()] return children.map((child, i) => {
return <Icon {...iconProps} /> if (isString(child)) {
} const icon = icons[child.trim()]
// Work around prettier auto-escape if (icon) {
if (isString(children) && children.startsWith('\\')) { const props = { ...icon, inline: i < children.length, 'aria-hidden': undefined }
return children.slice(1) return <Icon {...props} key={i} />
}
// Work around prettier auto-escape
if (child.startsWith('\\')) return child.slice(1)
}
return child
})
} }
return children return children
} }

View File

@ -38,7 +38,8 @@ const DATA = [
{ {
id: 'optimize', id: 'optimize',
title: 'Optimize for', title: 'Optimize for',
help: '...', help:
'Optimize for efficiency (faster inference, smaller model, lower memory consumption) or higher accuracy (potentially larger & slower model). Will impact the choice of architecture, pretrained weights and hyperparameters.',
options: [ options: [
{ id: 'efficiency', title: 'efficiency', checked: DEFAULT_OPT === 'efficiency' }, { id: 'efficiency', title: 'efficiency', checked: DEFAULT_OPT === 'efficiency' },
{ id: 'accuracy', title: 'accuracy', checked: DEFAULT_OPT === 'accuracy' }, { id: 'accuracy', title: 'accuracy', checked: DEFAULT_OPT === 'accuracy' },
@ -84,10 +85,12 @@ export default function QuickstartTraining({ id, title, download = 'config.cfg'
query={query} query={query}
render={({ site }) => { render={({ site }) => {
const langs = site.siteMetadata.languages const langs = site.siteMetadata.languages
DATA[0].dropdown = langs.map(({ name, code }) => ({ DATA[0].dropdown = langs
id: code, .map(({ name, code }) => ({
title: name, id: code,
})) title: name,
}))
.sort((a, b) => a.id.localeCompare(b.id))
return ( return (
<Quickstart <Quickstart
download={download} download={download}