Merge branch 'v4' into refactor/span-group-for-mentions

# Conflicts:
#	spacy/errors.py
#	spacy/kb/__init__.py
#	spacy/kb/candidate.pxd
#	spacy/kb/candidate.pyx
#	spacy/kb/kb.pyx
#	spacy/kb/kb_in_memory.pyx
#	spacy/ml/models/entity_linker.py
#	spacy/pipeline/entity_linker.py
#	spacy/tests/pipeline/test_entity_linker.py
#	spacy/tests/serialize/test_serialize_kb.py
#	website/docs/api/inmemorylookupkb.mdx
#	website/docs/api/kb.mdx
Author: Raphael Mitsch, 2023-03-20 09:29:52 +01:00
Commit: 1620a04d46
21 changed files with 445 additions and 163 deletions

spacy/errors.py
View File

@@ -82,7 +82,7 @@ class Warnings(metaclass=ErrorsWithCodes):
            "ignoring the duplicate entry.")
    W021 = ("Unexpected hash collision in PhraseMatcher. Matches may be "
            "incorrect. Modify PhraseMatcher._terminal_hash to fix.")
-   W024 = ("Entity '{entity}' - Alias '{alias}' combination already exists in "
+   W024 = ("Entity '{entity}' - alias '{alias}' combination already exists in "
            "the Knowledge Base.")
W026 = ("Unable to set all sentence boundaries from dependency parses. If "
"you are constructing a parse tree incrementally by setting "
@@ -209,7 +209,11 @@ class Warnings(metaclass=ErrorsWithCodes):
            "`enabled` ({enabled}). Be aware that this might affect other components in your pipeline.")
    W124 = ("{host}:{port} is already in use, using the nearest available port {serve_port} as an alternative.")

    # v4 warning strings
    W400 = ("`use_upper=False` is ignored, the upper layer is always enabled")
+   W401 = ("`incl_prior is True`, but the selected knowledge base type {kb_type} doesn't support prior probability "
+           "lookups so this setting will be ignored. If your KB does support prior probability lookups, make sure "
+           "to return `True` in `.supports_prior_probs`.")
class Errors(metaclass=ErrorsWithCodes):
@@ -960,7 +964,9 @@ class Errors(metaclass=ErrorsWithCodes):
    E4003 = ("Training examples for distillation must have the exact same tokens in the "
             "reference and predicted docs.")
    E4004 = ("Backprop is not supported when is_train is not set.")
-   E4005 = ("Expected `entity_id` to be of type {should_type}, but is of type {is_type}.")
+   E4005 = ("EntityLinker_v1 is not supported in spaCy v4. Update your configuration.")
+   E4006 = ("Expected `entity_id` to be of type {exp_type}, but is of type {found_type}.")

RENAMED_LANGUAGE_CODES = {"xx": "mul", "is": "isl"}

spacy/kb/candidate.pxd Normal file
View File

@@ -0,0 +1,15 @@
from libcpp.vector cimport vector

from .kb_in_memory cimport InMemoryLookupKB
from ..typedefs cimport hash_t

cdef class Candidate:
    pass


cdef class InMemoryCandidate(Candidate):
    cdef readonly hash_t _entity_hash
    cdef readonly hash_t _alias_hash
    cpdef vector[float] _entity_vector
    cdef float _prior_prob
    cdef readonly InMemoryLookupKB _kb
    cdef float _entity_freq

spacy/kb/candidate.pyx Normal file
View File

@@ -0,0 +1,96 @@
# cython: infer_types=True, profile=True

from .kb_in_memory cimport InMemoryLookupKB
from ..errors import Errors


cdef class Candidate:
    """A `Candidate` object refers to a textual mention that may or may not be resolved
    to a specific entity from a Knowledge Base. This will be used as input for the entity linking
    algorithm which will disambiguate the various candidates to the correct one.
    Each candidate, which represents a possible link between one textual mention and one entity in the knowledge base,
    is assigned a certain prior probability.

    DOCS: https://spacy.io/api/kb/#candidate-init
    """

    def __init__(self):
        # Make sure abstract Candidate is not instantiated.
        if self.__class__ == Candidate:
            raise TypeError(
                Errors.E1046.format(cls_name=self.__class__.__name__)
            )

    @property
    def entity_id(self) -> int:
        """RETURNS (int): Numerical representation of entity ID (if entity ID is numerical, this is just the entity ID,
        otherwise the hash of the entity ID string)."""
        raise NotImplementedError

    @property
    def entity_id_(self) -> str:
        """RETURNS (str): String representation of entity ID."""
        raise NotImplementedError

    @property
    def entity_vector(self) -> vector[float]:
        """RETURNS (vector[float]): Entity vector."""
        raise NotImplementedError


cdef class InMemoryCandidate(Candidate):
    """Candidate for InMemoryLookupKB."""

    def __init__(
        self,
        kb: InMemoryLookupKB,
        entity_hash: int,
        alias_hash: int,
        entity_vector: vector[float],
        prior_prob: float,
        entity_freq: float
    ):
        """
        kb (InMemoryLookupKB): InMemoryLookupKB instance.
        entity_hash (int): Entity ID as hash that can be looked up with InMemoryLookupKB.vocab.strings.__getitem__().
        entity_freq (float): Entity frequency in KB corpus.
        entity_vector (List[float]): Entity embedding.
        alias_hash (int): Alias hash.
        prior_prob (float): Prior probability of entity for this alias, i.e. the probability that, independent of
            the context, this alias - which matches one of this entity's aliases - resolves to this entity.
        """
        super().__init__()

        self._entity_hash = entity_hash
        self._entity_vector = entity_vector
        self._prior_prob = prior_prob
        self._kb = kb
        self._alias_hash = alias_hash
        self._entity_freq = entity_freq

    @property
    def entity_id(self) -> int:
        return self._entity_hash

    @property
    def entity_vector(self) -> vector[float]:
        return self._entity_vector

    @property
    def prior_prob(self) -> float:
        """RETURNS (float): Prior probability that this alias, which matches one of this entity's synonyms, resolves to
        this entity."""
        return self._prior_prob

    @property
    def alias(self) -> str:
        """RETURNS (str): Alias."""
        return self._kb.vocab.strings[self._alias_hash]

    @property
    def entity_id_(self) -> str:
        return self._kb.vocab.strings[self._entity_hash]

    @property
    def entity_freq(self) -> float:
        """RETURNS (float): Entity frequency in KB corpus."""
        return self._entity_freq
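For orientation, a minimal usage sketch of the classes added above (not part of the diff; the entity `Q42`, the alias `Douglas`, and all numbers are made up):

```python
import spacy
from spacy.kb import InMemoryLookupKB

nlp = spacy.blank("en")
kb = InMemoryLookupKB(vocab=nlp.vocab, entity_vector_length=3)
kb.add_entity(entity="Q42", freq=12.0, entity_vector=[1.0, 2.0, 3.0])
kb.add_alias(alias="Douglas", entities=["Q42"], probabilities=[0.9])

doc = nlp("Douglas wrote a book.")
for candidate in kb.get_candidates(doc[0:1]):  # doc[0:1] is the Span "Douglas"
    # entity_id_ and alias resolve the stored hashes back to strings
    print(candidate.entity_id_, candidate.alias, candidate.prior_prob, candidate.entity_freq)
```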

spacy/kb/kb.pyx
View File

@@ -32,9 +32,10 @@ cdef class KnowledgeBase:
    def get_candidates_batch(self, mentions: SpanGroup) -> Iterable[Iterable[Candidate]]:
        """
-       Return candidate entities for specified texts. Each candidate defines the entity, the original alias,
-       and the prior probability of that alias resolving to that entity.
-       If no candidate is found for a given text, an empty list is returned.
+       Return candidate entities for a specified Span mention. Each candidate defines at least the entity and the
+       entity's embedding vector. Depending on the KB implementation, further properties - such as the prior
+       probability of the specified mention text resolving to that entity - might be included.
+       If no candidates are found for a given mention, an empty list is returned.
        mentions (SpanGroup): Mentions for which to get candidates.
        RETURNS (Iterable[Iterable[Candidate]]): Identified candidates.
        """
@@ -42,9 +43,10 @@ cdef class KnowledgeBase:
    def get_candidates(self, mention: Span) -> Iterable[Candidate]:
        """
-       Return candidate entities for specified text. Each candidate defines the entity, the original alias,
-       and the prior probability of that alias resolving to that entity.
-       If the no candidate is found for a given text, an empty list is returned.
+       Return candidate entities for a specific mention. Each candidate defines at least the entity and the
+       entity's embedding vector. Depending on the KB implementation, further properties - such as the prior
+       probability of the specified mention text resolving to that entity - might be included.
+       If no candidate is found for the given mention, an empty list is returned.
        mention (Span): Mention for which to get candidates.
        RETURNS (Iterable[Candidate]): Identified candidates.
        """
@@ -106,3 +108,10 @@ cdef class KnowledgeBase:
        raise NotImplementedError(
            Errors.E1045.format(parent="KnowledgeBase", method="from_disk", name=self.__name__)
        )
+
+   @property
+   def supports_prior_probs(self) -> bool:
+       """RETURNS (bool): Whether this KB type supports looking up prior probabilities for entity mentions."""
+       raise NotImplementedError(
+           Errors.E1045.format(parent="KnowledgeBase", method="supports_prior_probs", name=self.__name__)
+       )
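As a sketch of what this abstract surface now asks of third-party implementations (the class `MyRemoteKB` and its lookup logic are hypothetical, not part of the commit):

```python
from typing import Iterable
from spacy.kb import Candidate, KnowledgeBase
from spacy.tokens import Span, SpanGroup

class MyRemoteKB(KnowledgeBase):
    """Hypothetical KB backed by an external index."""

    def get_candidates(self, mention: Span) -> Iterable[Candidate]:
        # Query your own index here and build Candidate subclass instances.
        return []

    def get_candidates_batch(self, mentions: SpanGroup) -> Iterable[Iterable[Candidate]]:
        return [self.get_candidates(span) for span in mentions]

    @property
    def supports_prior_probs(self) -> bool:
        # No alias statistics available: with incl_prior=True the
        # EntityLinker will emit W401 and ignore the setting.
        return False
```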

spacy/kb/kb_in_memory.pyx
View File

@@ -243,9 +243,9 @@ cdef class InMemoryLookupKB(KnowledgeBase):
        return [
            InMemoryCandidate(
-               hash_to_str=self.vocab.strings.__getitem__,
-               entity_id=self._entries[entry_index].entity_hash,
-               mention=alias,
+               kb=self,
+               entity_hash=self._entries[entry_index].entity_hash,
+               alias_hash=alias_hash,
                entity_vector=self._vectors_table[self._entries[entry_index].vector_index],
                prior_prob=prior_prob,
                entity_freq=self._entries[entry_index].freq
@@ -283,6 +283,9 @@ cdef class InMemoryLookupKB(KnowledgeBase):
        return 0.0

+   def supports_prior_probs(self) -> bool:
+       return True
+
def to_bytes(self, **kwargs):
"""Serialize the current state to a binary string.
"""

spacy/ml/models/entity_linker.py
View File

@@ -124,18 +124,18 @@ def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]:
    Return candidate entities for a given mention, fetching appropriate entries from the index.
    kb (KnowledgeBase): Knowledge base to query.
    mention (Span): Entity mention for which to identify candidates.
-   RETURNS (Iterable[InMemoryCandidate]): Identified candidates.
+   RETURNS (Iterable[Candidate]): Identified candidates.
    """
    return kb.get_candidates(mention)


def get_candidates_batch(
-   kb: KnowledgeBase, mentions: SpanGroup
+   kb: KnowledgeBase, mentions: Iterable[SpanGroup]
) -> Iterable[Iterable[Candidate]]:
    """
    Return candidate entities for the given mentions, fetching appropriate entries from the index.
    kb (KnowledgeBase): Knowledge base to query.
-   mention (SpanGroup): Entity mentions for which to identify candidates.
-   RETURNS (Iterable[Iterable[InMemoryCandidate]]): Identified candidates.
+   mentions (Iterable[SpanGroup]): Entity mentions for which to identify candidates.
+   RETURNS (Iterable[Iterable[Candidate]]): Identified candidates.
    """
    return kb.get_candidates_batch(mentions)
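These two functions back the `spacy.CandidateGenerator.v1` and `spacy.CandidateBatchGenerator.v1` `@misc` entries used in the entity linker's default config. A sketch of a custom generator registered the same way (the name `my.CandidateGenerator.v1` and the pass-through logic are made up):

```python
from typing import Callable, Iterable

from spacy import registry
from spacy.kb import Candidate, KnowledgeBase
from spacy.tokens import Span

@registry.misc("my.CandidateGenerator.v1")  # hypothetical registry name
def create_candidate_getter() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]:
    def get_custom_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]:
        # Delegate to the KB; a real implementation might normalize the
        # mention text first or filter the returned candidates.
        return kb.get_candidates(mention)

    return get_custom_candidates
```

It can then be referenced from the pipe config as `get_candidates = {"@misc": "my.CandidateGenerator.v1"}`.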

spacy/pipeline/entity_linker.py
View File

@@ -1,5 +1,5 @@
-from typing import Optional, Iterable, Callable, Dict, Sequence, Union, List, Any
-from typing import cast
+import warnings
+from typing import Optional, Iterable, Callable, Dict, Sequence, Union, List, Any, cast
from numpy import dtype
from thinc.types import Floats1d, Floats2d, Ints1d, Ragged
from pathlib import Path
@@ -10,6 +10,7 @@ from thinc.api import CosineDistance, Model, Optimizer, Config
from thinc.api import set_dropout_rate
from ..kb import KnowledgeBase, Candidate
-from ..tokens import Doc, Span
+from ..ml import empty_kb
+from ..tokens import Doc, Span, SpanGroup
from .pipe import deserialize_config
@@ -17,7 +18,7 @@ from .trainable_pipe import TrainablePipe
from ..language import Language
from ..vocab import Vocab
from ..training import Example, validate_examples, validate_get_examples
-from ..errors import Errors
+from ..errors import Errors, Warnings
from ..util import SimpleFrozenList, registry
from .. import util
from ..scorer import Scorer
@@ -57,8 +58,8 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
        "entity_vector_length": 64,
        "get_candidates": {"@misc": "spacy.CandidateGenerator.v1"},
        "get_candidates_batch": {"@misc": "spacy.CandidateBatchGenerator.v1"},
+       "generate_empty_kb": {"@misc": "spacy.EmptyKB.v2"},
        "overwrite": False,
-       "generate_empty_kb": {"@misc": "spacy.EmptyKB.v2"},
"scorer": {"@scorers": "spacy.entity_linker_scorer.v1"},
"use_gold_ents": True,
"candidates_batch_size": 1,
@@ -117,28 +118,9 @@ def make_entity_linker(
        prediction is discarded. If None, predictions are not filtered by any threshold.
    save_activations (bool): save model activations in Doc when annotating.
    """
    if not model.attrs.get("include_span_maker", False):
-       try:
-           from spacy_legacy.components.entity_linker import EntityLinker_v1
-       except:
-           raise ImportError(
-               "In order to use v1 of the EntityLinker, you must use spacy-legacy>=3.0.12."
-           )
-       # The only difference in arguments here is that use_gold_ents and threshold aren't available.
-       return EntityLinker_v1(
-           nlp.vocab,
-           model,
-           name,
-           labels_discard=labels_discard,
-           n_sents=n_sents,
-           incl_prior=incl_prior,
-           incl_context=incl_context,
-           entity_vector_length=entity_vector_length,
-           get_candidates=get_candidates,
-           overwrite=overwrite,
-           scorer=scorer,
-       )
+       raise ValueError(Errors.E4005)
return EntityLinker(
nlp.vocab,
model,
@@ -259,6 +241,8 @@ class EntityLinker(TrainablePipe):
        if candidates_batch_size < 1:
            raise ValueError(Errors.E1044)
+       if self.incl_prior and not self.kb.supports_prior_probs:
+           warnings.warn(Warnings.W401)
def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
"""Define the KB of this pipe by providing a function that will
@@ -542,18 +526,19 @@ class EntityLinker(TrainablePipe):
                    )
                elif len(candidates) == 1 and self.threshold is None:
                    # shortcut for efficiency reasons: take the 1 candidate
-                   final_kb_ids.append(candidates[0].entity_id_str)
+                   final_kb_ids.append(candidates[0].entity_id_)
                    self._add_activations(
                        doc_scores=doc_scores,
                        doc_ents=doc_ents,
                        scores=[1.0],
-                       ents=[candidates[0].entity_id_int],
+                       ents=[candidates[0].entity_id],
                    )
                else:
                    random.shuffle(candidates)
                    # set all prior probabilities to 0 if incl_prior=False
-                   prior_probs = xp.asarray([c.prior_prob for c in candidates])
-                   if not self.incl_prior:
+                   if self.incl_prior and self.kb.supports_prior_probs:
+                       prior_probs = xp.asarray([c.prior_prob for c in candidates])  # type: ignore
+                   else:
                        prior_probs = xp.asarray([0.0 for _ in candidates])
                    scores = prior_probs
                    # add in similarity from the context
scores = prior_probs
# add in similarity from the context
@@ -577,7 +562,7 @@ class EntityLinker(TrainablePipe):
                            raise ValueError(Errors.E161)
                        scores = prior_probs + sims - (prior_probs * sims)
                    final_kb_ids.append(
-                       candidates[scores.argmax().item()].entity_id_str
+                       candidates[scores.argmax().item()].entity_id_
                        if self.threshold is None
                        or scores.max() >= self.threshold
                        else EntityLinker.NIL
@@ -586,7 +571,7 @@ class EntityLinker(TrainablePipe):
                    doc_scores=doc_scores,
                    doc_ents=doc_ents,
                    scores=scores,
-                   ents=[c.entity_id_int for c in candidates],
+                   ents=[c.entity_id for c in candidates],
)
self._add_doc_activations(
docs_scores=docs_scores,
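The combination visible above, `scores = prior_probs + sims - (prior_probs * sims)`, is a probabilistic OR: either a strong prior or a strong context similarity can carry a candidate. Illustrative numbers only:

```python
prior_prob, sim = 0.6, 0.5
score = prior_prob + sim - (prior_prob * sim)
print(score)  # 0.8, higher than either signal alone
```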

spacy/pipeline/tok2vec.py
View File

@@ -1,5 +1,6 @@
-from typing import Sequence, Iterable, Optional, Dict, Callable, List, Any
+from typing import Sequence, Iterable, Optional, Dict, Callable, List, Any, Tuple
from thinc.api import Model, set_dropout_rate, Optimizer, Config
+from thinc.types import Floats2d
from itertools import islice
from .trainable_pipe import TrainablePipe
@@ -157,39 +158,9 @@ class Tok2Vec(TrainablePipe):
        DOCS: https://spacy.io/api/tok2vec#update
        """
        if losses is None:
            losses = {}
        validate_examples(examples, "Tok2Vec.update")
        docs = [eg.predicted for eg in examples]
-       set_dropout_rate(self.model, drop)
-       tokvecs, bp_tokvecs = self.model.begin_update(docs)
-       d_tokvecs = [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
-       losses.setdefault(self.name, 0.0)
-
-       def accumulate_gradient(one_d_tokvecs):
-           """Accumulate tok2vec loss and gradient. This is passed as a callback
-           to all but the last listener. Only the last one does the backprop.
-           """
-           nonlocal d_tokvecs
-           for i in range(len(one_d_tokvecs)):
-               d_tokvecs[i] += one_d_tokvecs[i]
-               losses[self.name] += float((one_d_tokvecs[i] ** 2).sum())
-           return [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
-
-       def backprop(one_d_tokvecs):
-           """Callback to actually do the backprop. Passed to last listener."""
-           accumulate_gradient(one_d_tokvecs)
-           d_docs = bp_tokvecs(d_tokvecs)
-           if sgd is not None:
-               self.finish_update(sgd)
-           return d_docs
-
-       batch_id = Tok2VecListener.get_batch_id(docs)
-       for listener in self.listeners[:-1]:
-           listener.receive(batch_id, tokvecs, accumulate_gradient)
-       if self.listeners:
-           self.listeners[-1].receive(batch_id, tokvecs, backprop)
-       return losses
+       return self._update_with_docs(docs, drop=drop, sgd=sgd, losses=losses)
def get_loss(self, examples, scores) -> None:
pass
@@ -219,6 +190,96 @@ class Tok2Vec(TrainablePipe):
    def add_label(self, label):
        raise NotImplementedError

    def distill(
        self,
        teacher_pipe: Optional["TrainablePipe"],
        examples: Iterable["Example"],
        *,
        drop: float = 0.0,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None,
    ) -> Dict[str, float]:
        """Performs an update of the student pipe's model using the
        student's distillation examples and sets the annotations
        of the teacher's distillation examples using the teacher pipe.

        teacher_pipe (Optional[TrainablePipe]): The teacher pipe to use
            for prediction.
        examples (Iterable[Example]): Distillation examples. The reference (teacher)
            and predicted (student) docs must have the same number of tokens and the
            same orthography.
        drop (float): dropout rate.
        sgd (Optional[Optimizer]): An optimizer. Will be created via
            create_optimizer if not set.
        losses (Optional[Dict[str, float]]): Optional record of loss during
            distillation.
        RETURNS: The updated losses dictionary.

        DOCS: https://spacy.io/api/tok2vec#distill
        """
        # By default we require a teacher pipe, but there are downstream
        # implementations that don't require a pipe.
        if teacher_pipe is None:
            raise ValueError(Errors.E4002.format(name=self.name))
        teacher_docs = [eg.reference for eg in examples]
        student_docs = [eg.predicted for eg in examples]
        teacher_preds = teacher_pipe.predict(teacher_docs)
        teacher_pipe.set_annotations(teacher_docs, teacher_preds)
        return self._update_with_docs(student_docs, drop=drop, sgd=sgd, losses=losses)

    def _update_with_docs(
        self,
        docs: Iterable[Doc],
        *,
        drop: float = 0.0,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None,
    ):
        if losses is None:
            losses = {}
        losses.setdefault(self.name, 0.0)
        set_dropout_rate(self.model, drop)

        tokvecs, accumulate_gradient, backprop = self._create_backprops(
            docs, losses, sgd=sgd
        )
        batch_id = Tok2VecListener.get_batch_id(docs)
        for listener in self.listeners[:-1]:
            listener.receive(batch_id, tokvecs, accumulate_gradient)
        if self.listeners:
            self.listeners[-1].receive(batch_id, tokvecs, backprop)
        return losses

    def _create_backprops(
        self,
        docs: Iterable[Doc],
        losses: Dict[str, float],
        *,
        sgd: Optional[Optimizer] = None,
    ) -> Tuple[Floats2d, Callable, Callable]:
        tokvecs, bp_tokvecs = self.model.begin_update(docs)
        d_tokvecs = [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]

        def accumulate_gradient(one_d_tokvecs):
            """Accumulate tok2vec loss and gradient. This is passed as a callback
            to all but the last listener. Only the last one does the backprop.
            """
            nonlocal d_tokvecs
            for i in range(len(one_d_tokvecs)):
                d_tokvecs[i] += one_d_tokvecs[i]
                losses[self.name] += float((one_d_tokvecs[i] ** 2).sum())
            return [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]

        def backprop(one_d_tokvecs):
            """Callback to actually do the backprop. Passed to last listener."""
            accumulate_gradient(one_d_tokvecs)
            d_docs = bp_tokvecs(d_tokvecs)
            if sgd is not None:
                self.finish_update(sgd)
            return d_docs

        return tokvecs, accumulate_gradient, backprop
class Tok2VecListener(Model):
"""A layer that gets fed its answers from an upstream connection,
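A toy stand-in (NumPy arrays instead of the spaCy classes) for the callback contract that `_create_backprops` sets up: every listener but the last only accumulates gradients, and the last one triggers the shared backprop.

```python
import numpy as np

d_tokvecs = [np.zeros((2, 4))]
losses = {"tok2vec": 0.0}

def accumulate_gradient(one_d_tokvecs):
    # Sum gradients arriving from each listener; track the squared-gradient loss.
    for i, d in enumerate(one_d_tokvecs):
        d_tokvecs[i] += d
        losses["tok2vec"] += float((d ** 2).sum())
    return [np.zeros_like(d) for d in one_d_tokvecs]

def backprop(one_d_tokvecs):
    accumulate_gradient(one_d_tokvecs)
    # The real implementation now calls bp_tokvecs(d_tokvecs) and, if an
    # optimizer is given, finish_update(sgd).
    return d_tokvecs

accumulate_gradient([np.ones((2, 4))])  # e.g. gradient from a tagger listener
backprop([np.full((2, 4), 0.5)])        # final listener triggers the backprop
print(losses["tok2vec"])                # 8.0 + 2.0 = 10.0
```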

spacy/strings.pyi
View File

@@ -2,7 +2,7 @@ from typing import List, Optional, Iterable, Iterator, Union, Any, Tuple, overload
from pathlib import Path

class StringStore:
-   def __init__(self, strings: Optional[Iterable[str]]) -> None: ...
+   def __init__(self, strings: Optional[Iterable[str]] = None) -> None: ...
@overload
def __getitem__(self, string_or_hash: str) -> int: ...
@overload
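A quick sketch of what the relaxed signature allows, using the `StringStore` lookup overloads declared above:

```python
from spacy.strings import StringStore

store = StringStore()                          # now valid: no strings argument
store_with_items = StringStore(["rats", "are", "cute"])
rats_hash = store_with_items["rats"]           # str -> hash
assert store_with_items[rats_hash] == "rats"   # hash -> str
```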

spacy/tests/pipeline/test_entity_linker.py
View File

@@ -7,7 +7,7 @@ from thinc.types import Ragged
from spacy import registry, util
from spacy.attrs import ENT_KB_ID
from spacy.compat import pickle
-from spacy.kb import InMemoryCandidate, InMemoryLookupKB, KnowledgeBase
+from spacy.kb import Candidate, InMemoryLookupKB, KnowledgeBase
from spacy.lang.en import English
from spacy.ml import load_kb
from spacy.ml.models.entity_linker import build_span_maker, get_candidates
@@ -465,16 +465,17 @@ def test_candidate_generation(nlp):
    mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])

    # test the size of the relevant candidates
+   adam_ent_cands = get_candidates(mykb, adam_ent)
    assert len(get_candidates(mykb, douglas_ent)) == 2
-   assert len(get_candidates(mykb, adam_ent)) == 1
+   assert len(adam_ent_cands) == 1
    assert len(get_candidates(mykb, Adam_ent)) == 0  # default case sensitive
    assert len(get_candidates(mykb, shrubbery_ent)) == 0

    # test the content of the candidates
-   assert get_candidates(mykb, adam_ent)[0].entity_id_str == "Q2"
-   assert get_candidates(mykb, adam_ent)[0].mention == "adam"
-   assert_almost_equal(get_candidates(mykb, adam_ent)[0].entity_freq, 12)
-   assert_almost_equal(get_candidates(mykb, adam_ent)[0].prior_prob, 0.9)
+   assert adam_ent_cands[0].entity_id_ == "Q2"
+   assert adam_ent_cands[0].alias == "adam"
+   assert_almost_equal(adam_ent_cands[0].entity_freq, 12)
+   assert_almost_equal(adam_ent_cands[0].prior_prob, 0.9)
def test_el_pipe_configuration(nlp):
@@ -563,9 +564,9 @@ def test_vocab_serialization(nlp):
    candidates = mykb._get_alias_candidates("adam")
    assert len(candidates) == 1
-   assert candidates[0].entity_id_int == q2_hash
-   assert candidates[0].entity_id_str == "Q2"
-   assert candidates[0].mention == "adam"
+   assert candidates[0].entity_id == q2_hash
+   assert candidates[0].entity_id_ == "Q2"
+   assert candidates[0].alias == "adam"
with make_tempdir() as d:
mykb.to_disk(d / "kb")
@@ -574,9 +575,9 @@ def test_vocab_serialization(nlp):
    candidates = kb_new_vocab._get_alias_candidates("adam")
    assert len(candidates) == 1
-   assert candidates[0].entity_id_int == q2_hash
-   assert candidates[0].entity_id_str == "Q2"
-   assert candidates[0].mention == "adam"
+   assert candidates[0].entity_id == q2_hash
+   assert candidates[0].entity_id_ == "Q2"
+   assert candidates[0].alias == "adam"
assert kb_new_vocab.get_vector("Q2") == [2]
assert_almost_equal(kb_new_vocab.get_prior_prob("Q2", "douglas"), 0.4)
@@ -991,14 +992,11 @@ def test_scorer_links():
@pytest.mark.parametrize(
    "name,config",
    [
-       ("entity_linker", {"@architectures": "spacy.EntityLinker.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL}),
+       ("entity_linker", {"@architectures": "spacy.EntityLinker.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL}),
    ],
)
# fmt: on
def test_legacy_architectures(name, config):
-   from spacy_legacy.components.entity_linker import EntityLinker_v1
-
    # Ensure that the legacy architectures still work
    vector_length = 3
    nlp = English()
@@ -1020,10 +1018,7 @@ def test_legacy_architectures(name, config):
        return mykb

    entity_linker = nlp.add_pipe(name, config={"model": config})
-   if config["@architectures"] == "spacy.EntityLinker.v1":
-       assert isinstance(entity_linker, EntityLinker_v1)
-   else:
-       assert isinstance(entity_linker, EntityLinker)
+   assert isinstance(entity_linker, EntityLinker)
entity_linker.set_kb(create_kb)
optimizer = nlp.initialize(get_examples=lambda: train_examples)

View File

@@ -9,6 +9,7 @@ from spacy.lang.en import English
from spacy.lang.en.syntax_iterators import noun_chunks
from spacy.language import Language
from spacy.pipeline import TrainablePipe
+from spacy.strings import StringStore
from spacy.tokens import Doc
from spacy.training import Example
from spacy.util import SimpleFrozenList, get_arg_names, make_tempdir
@@ -131,7 +132,7 @@ def test_issue5458():
    # Test that the noun chunker does not generate overlapping spans
    # fmt: off
    words = ["In", "an", "era", "where", "markets", "have", "brought", "prosperity", "and", "empowerment", "."]
-   vocab = Vocab(strings=words)
+   vocab = Vocab(strings=StringStore(words))
deps = ["ROOT", "det", "pobj", "advmod", "nsubj", "aux", "relcl", "dobj", "cc", "conj", "punct"]
pos = ["ADP", "DET", "NOUN", "ADV", "NOUN", "AUX", "VERB", "NOUN", "CCONJ", "NOUN", "PUNCT"]
heads = [0, 2, 0, 9, 6, 6, 2, 6, 7, 7, 0]

spacy/tests/pipeline/test_tok2vec.py
View File

@@ -540,3 +540,86 @@ def test_tok2vec_listeners_textcat():
assert cats1["imperative"] < 0.9
assert [t.tag_ for t in docs[0]] == ["V", "J", "N"]
assert [t.tag_ for t in docs[1]] == ["N", "V", "J", "N"]
cfg_string_distillation = """
[nlp]
lang = "en"
pipeline = ["tok2vec","tagger"]
[components]
[components.tagger]
factory = "tagger"
[components.tagger.model]
@architectures = "spacy.Tagger.v2"
nO = null
[components.tagger.model.tok2vec]
@architectures = "spacy.Tok2VecListener.v1"
width = ${components.tok2vec.model.encode.width}
[components.tok2vec]
factory = "tok2vec"
[components.tok2vec.model]
@architectures = "spacy.Tok2Vec.v2"
[components.tok2vec.model.embed]
@architectures = "spacy.MultiHashEmbed.v2"
width = ${components.tok2vec.model.encode.width}
rows = [2000, 1000, 1000, 1000]
attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
include_static_vectors = false
[components.tok2vec.model.encode]
@architectures = "spacy.MaxoutWindowEncoder.v2"
width = 96
depth = 4
window_size = 1
maxout_pieces = 3
"""
def test_tok2vec_distillation_teacher_annotations():
    orig_config = Config().from_str(cfg_string_distillation)
    teacher_nlp = util.load_model_from_config(
        orig_config, auto_fill=True, validate=True
    )
    student_nlp = util.load_model_from_config(
        orig_config, auto_fill=True, validate=True
    )

    train_examples_teacher = []
    train_examples_student = []
    for t in TRAIN_DATA:
        train_examples_teacher.append(
            Example.from_dict(teacher_nlp.make_doc(t[0]), t[1])
        )
        train_examples_student.append(
            Example.from_dict(student_nlp.make_doc(t[0]), t[1])
        )

    optimizer = teacher_nlp.initialize(lambda: train_examples_teacher)
    student_nlp.initialize(lambda: train_examples_student)

    # Since Language.distill creates a copy of the examples to use as
    # its internal teacher/student docs, we'll need to monkey-patch the
    # tok2vec pipe's distill method.
    student_tok2vec = student_nlp.get_pipe("tok2vec")
    student_tok2vec._old_distill = student_tok2vec.distill

    def tok2vec_distill_wrapper(
        self,
        teacher_pipe,
        examples,
        **kwargs,
    ):
        assert all(not eg.reference.tensor.any() for eg in examples)
        out = self._old_distill(teacher_pipe, examples, **kwargs)
        assert all(eg.reference.tensor.any() for eg in examples)
        return out

    student_tok2vec.distill = tok2vec_distill_wrapper.__get__(student_tok2vec, Tok2Vec)
    student_nlp.distill(teacher_nlp, train_examples_student, sgd=optimizer, losses={})

spacy/tests/serialize/test_serialize_kb.py
View File

@@ -67,20 +67,20 @@ def _check_kb(kb):
    # check candidates & probabilities
    candidates = sorted(
-       kb._get_alias_candidates("double07"), key=lambda x: x.entity_id_str
+       kb._get_alias_candidates("double07"), key=lambda x: x.entity_id_
    )
    assert len(candidates) == 2

-   assert candidates[0].entity_id_str == "Q007"
+   assert candidates[0].entity_id_ == "Q007"
    assert 6.999 < candidates[0].entity_freq < 7.01
    assert candidates[0].entity_vector == [0, 0, 7]
-   assert candidates[0].mention == "double07"
+   assert candidates[0].alias == "double07"
    assert 0.899 < candidates[0].prior_prob < 0.901

-   assert candidates[1].entity_id_str == "Q17"
+   assert candidates[1].entity_id_ == "Q17"
    assert 1.99 < candidates[1].entity_freq < 2.01
    assert candidates[1].entity_vector == [7, 1, 0]
-   assert candidates[1].mention == "double07"
+   assert candidates[1].alias == "double07"
    assert 0.099 < candidates[1].prior_prob < 0.101
assert 0.099 < candidates[1].prior_prob < 0.101

spacy/tests/serialize/test_serialize_vocab_strings.py
View File

@@ -13,8 +13,11 @@ from spacy.vocab import Vocab
from ..util import make_tempdir

-test_strings = [([], []), (["rats", "are", "cute"], ["i", "like", "rats"])]
-test_strings_attrs = [(["rats", "are", "cute"], "Hello")]
+test_strings = [
+    (StringStore(), StringStore()),
+    (StringStore(["rats", "are", "cute"]), StringStore(["i", "like", "rats"])),
+]
+test_strings_attrs = [(StringStore(["rats", "are", "cute"]), "Hello")]
@pytest.mark.issue(599)
@@ -81,7 +84,7 @@ def test_serialize_vocab_roundtrip_bytes(strings1, strings2):
    vocab2 = Vocab(strings=strings2)
    vocab1_b = vocab1.to_bytes()
    vocab2_b = vocab2.to_bytes()
-   if strings1 == strings2:
+   if strings1.to_bytes() == strings2.to_bytes():
assert vocab1_b == vocab2_b
else:
assert vocab1_b != vocab2_b
@@ -117,11 +120,12 @@ def test_serialize_vocab_roundtrip_disk(strings1, strings2):
def test_serialize_vocab_lex_attrs_bytes(strings, lex_attr):
    vocab1 = Vocab(strings=strings)
    vocab2 = Vocab()
-   vocab1[strings[0]].norm_ = lex_attr
-   assert vocab1[strings[0]].norm_ == lex_attr
-   assert vocab2[strings[0]].norm_ != lex_attr
+   s = next(iter(vocab1.strings))
+   vocab1[s].norm_ = lex_attr
+   assert vocab1[s].norm_ == lex_attr
+   assert vocab2[s].norm_ != lex_attr
    vocab2 = vocab2.from_bytes(vocab1.to_bytes())
-   assert vocab2[strings[0]].norm_ == lex_attr
+   assert vocab2[s].norm_ == lex_attr
@pytest.mark.parametrize("strings,lex_attr", test_strings_attrs)
@@ -136,14 +140,15 @@ def test_deserialize_vocab_seen_entries(strings, lex_attr):
def test_serialize_vocab_lex_attrs_disk(strings, lex_attr):
    vocab1 = Vocab(strings=strings)
    vocab2 = Vocab()
-   vocab1[strings[0]].norm_ = lex_attr
-   assert vocab1[strings[0]].norm_ == lex_attr
-   assert vocab2[strings[0]].norm_ != lex_attr
+   s = next(iter(vocab1.strings))
+   vocab1[s].norm_ = lex_attr
+   assert vocab1[s].norm_ == lex_attr
+   assert vocab2[s].norm_ != lex_attr
    with make_tempdir() as d:
        file_path = d / "vocab"
        vocab1.to_disk(file_path)
        vocab2 = vocab2.from_disk(file_path)
-   assert vocab2[strings[0]].norm_ == lex_attr
+   assert vocab2[s].norm_ == lex_attr
@pytest.mark.parametrize("strings1,strings2", test_strings)

View File

@@ -17,7 +17,7 @@ def test_issue361(en_vocab, text1, text2):
@pytest.mark.issue(600)
def test_issue600():
-   vocab = Vocab(tag_map={"NN": {"pos": "NOUN"}})
+   vocab = Vocab()
    doc = Doc(vocab, words=["hello"])
    doc[0].tag_ = "NN"

spacy/vocab.pyi
View File

@@ -26,7 +26,7 @@ class Vocab:
    def __init__(
        self,
        lex_attr_getters: Optional[Dict[str, Callable[[str], Any]]] = ...,
-       strings: Optional[Union[List[str], StringStore]] = ...,
+       strings: Optional[StringStore] = ...,
lookups: Optional[Lookups] = ...,
oov_prob: float = ...,
writing_system: Dict[str, Any] = ...,

spacy/vocab.pyx
View File

@@ -49,9 +49,8 @@ cdef class Vocab:
    DOCS: https://spacy.io/api/vocab
    """

-   def __init__(self, lex_attr_getters=None, strings=tuple(), lookups=None,
-                oov_prob=-20., writing_system={}, get_noun_chunks=None,
-                **deprecated_kwargs):
+   def __init__(self, lex_attr_getters=None, strings=None, lookups=None,
+                oov_prob=-20., writing_system=None, get_noun_chunks=None):
"""Create the vocabulary.
lex_attr_getters (dict): A dictionary mapping attribute IDs to
@@ -69,16 +68,19 @@ cdef class Vocab:
        self.cfg = {'oov_prob': oov_prob}
        self.mem = Pool()
        self._by_orth = PreshMap()
-       self.strings = StringStore()
        self.length = 0
-       if strings:
-           for string in strings:
-               _ = self[string]
+       if strings is None:
+           self.strings = StringStore()
+       else:
+           self.strings = strings
        self.lex_attr_getters = lex_attr_getters
        self.morphology = Morphology(self.strings)
        self.vectors = Vectors(strings=self.strings)
        self.lookups = lookups
-       self.writing_system = writing_system
+       if writing_system is None:
+           self.writing_system = {}
+       else:
+           self.writing_system = writing_system
self.get_noun_chunks = get_noun_chunks
property vectors:
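A short sketch of the resulting constructor behavior (the identity assertion follows from the plain assignment in the new code above):

```python
from spacy.strings import StringStore
from spacy.vocab import Vocab

vocab = Vocab()  # strings=None: a fresh StringStore is created internally
shared = StringStore(["hello", "world"])
vocab_shared = Vocab(strings=shared)  # the store is used as-is, not copied
assert vocab_shared.strings is shared
```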

website/docs/api/inmemorylookupkb.mdx
View File

@@ -192,29 +192,14 @@ to you.
> from spacy.tokens import SpanGroup
> nlp = English()
> doc = nlp("Douglas Adams wrote 'The Hitchhiker's Guide to the Galaxy'.")
-> candidates = kb.get_candidates(SpanGroup(doc, spans=[doc[0:2], doc[3:]]))
+> candidates = kb.get_candidates_batch([SpanGroup(doc, spans=[doc[0:2], doc[3:]])])
> ```
| Name        | Description                                                                                                    |
| ----------- | -------------------------------------------------------------------------------------------------------------- |
-| `mentions`  | The textual mention or alias. ~~Iterable[Span]~~                                                               |
+| `mentions`  | The textual mentions. ~~Iterable[SpanGroup]~~                                                                  |
| **RETURNS** | An iterable of iterable with relevant `InMemoryCandidate` objects. ~~Iterable[Iterable[InMemoryCandidate]]~~   |

-## InMemoryLookupKB.get_alias_candidates {id="get_alias_candidates",tag="method"}
-
-Given a certain textual mention as input, retrieve a list of candidate entities
-of type [`InMemoryCandidate`](/api/kb#candidate).
-
-> #### Example
->
-> ```python
-> candidates = kb.get_alias_candidates("Douglas")
-> ```
-
-| Name        | Description                                                                    |
-| ----------- | ------------------------------------------------------------------------------ |
-| `alias`     | The textual mention or alias. ~~str~~                                          |
-| **RETURNS** | The list of relevant `InMemoryCandidate` objects. ~~List[InMemoryCandidate]~~  |
## InMemoryLookupKB.get_vector {id="get_vector",tag="method"}

website/docs/api/kb.mdx
View File

@@ -96,12 +96,12 @@ to you.
> from spacy.tokens import SpanGroup
> nlp = English()
> doc = nlp("Douglas Adams wrote 'The Hitchhiker's Guide to the Galaxy'.")
-> candidates = kb.get_candidates(SpanGroup(doc, spans=[doc[0:2], doc[3:]]))
+> candidates = kb.get_candidates([SpanGroup(doc, spans=[doc[0:2], doc[3:]])])
> ```

| Name        | Description                                                                                   |
| ----------- | ----------------------------------------------------------------------------------------------|
-| `mentions`  | The textual mention or alias. ~~SpanGroup~~                                                   |
+| `mentions`  | The textual mentions. ~~Iterable[SpanGroup]~~                                                 |
| **RETURNS** | An iterable of iterable with relevant `Candidate` objects. ~~Iterable[Iterable[Candidate]]~~  |
## KnowledgeBase.get_vector {id="get_vector",tag="method"}
@@ -176,11 +176,11 @@ Restore the state of the knowledge base from a given directory. Note that the

## InMemoryCandidate {id="candidate",tag="class"}

-A `InMemoryCandidate` object refers to a textual mention (alias) that may or may
-not be resolved to a specific entity from a `KnowledgeBase`. This will be used
-as input for the entity linking algorithm which will disambiguate the various
-candidates to the correct one. Each candidate `(alias, entity)` pair is assigned
-to a certain prior probability.
+An `InMemoryCandidate` object refers to a textual mention (alias) that may or
+may not be resolved to a specific entity from a `KnowledgeBase`. This will be
+used as input for the entity linking algorithm which will disambiguate the
+various candidates to the correct one. Each candidate `(alias, entity)` pair is
+assigned a certain prior probability.
### InMemoryCandidate.\_\_init\_\_ {id="candidate-init",tag="method"}
@@ -188,22 +188,20 @@ Construct an `InMemoryCandidate` object. Usually this constructor is not called
directly, but instead these objects are returned by the `get_candidates` method
of the [`entity_linker`](/api/entitylinker) pipe.

-> #### Example```python
+> #### Example
>
+> ```python
> from spacy.kb import InMemoryCandidate candidate = InMemoryCandidate(kb,
-> entity_hash, entity_freq, entity_vector, mention_hash, prior_prob)
->
-> ```
->
+> entity_hash, entity_freq, entity_vector, alias_hash, prior_prob)
+> ```
-| Name           | Description                                                                 |
-| -------------- | --------------------------------------------------------------------------- |
-| `kb`           | The knowledge base that defined this candidate. ~~KnowledgeBase~~           |
-| `entity_hash`  | The hash of the entity's KB ID. ~~int~~                                     |
-| `entity_freq`  | The entity frequency as recorded in the KB. ~~float~~                       |
-| `mention_hash` | The hash of the textual mention. ~~int~~                                    |
-| `prior_prob`   | The prior probability of the `alias` referring to the `entity`. ~~float~~   |
+| Name          | Description                                                                  |
+| ------------- | ---------------------------------------------------------------------------- |
+| `kb`          | The knowledge base that defined this candidate. ~~KnowledgeBase~~            |
+| `entity_hash` | The hash of the entity's KB ID. ~~int~~                                      |
+| `entity_freq` | The entity frequency as recorded in the KB. ~~float~~                        |
+| `alias_hash`  | The hash of the entity alias. ~~int~~                                        |
+| `prior_prob`  | The prior probability of the `alias` referring to the `entity`. ~~float~~    |
## InMemoryCandidate attributes {id="candidate-attributes"}

website/docs/api/tok2vec.mdx
View File

@@ -100,6 +100,43 @@ pipeline components are applied to the `Doc` in order. Both
| `doc` | The document to process. ~~Doc~~ |
| **RETURNS** | The processed document. ~~Doc~~ |
## Tok2Vec.distill {id="distill", tag="method,experimental", version="4"}
Performs an update of the student pipe's model using the student's distillation
examples and sets the annotations of the teacher's distillation examples using
the teacher pipe.
Unlike other trainable pipes, the student pipe doesn't directly learn its
representations from the teacher. However, since downstream pipes that do
perform distillation expect the tok2vec annotations to be present on the
correct distillation examples, we need to ensure that they are set beforehand.
The distillation is performed on ~~Example~~ objects. The `Example.reference`
and `Example.predicted` ~~Doc~~s must have the same number of tokens and the
same orthography. Even though the reference does not have to have gold
annotations, the teacher can add its own annotations when necessary.

This feature is experimental.
> #### Example
>
> ```python
> teacher_pipe = teacher.add_pipe("tok2vec")
> student_pipe = student.add_pipe("tok2vec")
> optimizer = nlp.resume_training()
> losses = student.distill(teacher_pipe, examples, sgd=optimizer)
> ```
| Name | Description |
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------- |
| `teacher_pipe` | The teacher pipe to use for prediction. ~~Optional[TrainablePipe]~~ |
| `examples` | Distillation examples. The reference (teacher) and predicted (student) docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ |
| _keyword-only_ | |
| `drop` | Dropout rate. ~~float~~ |
| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
| `losses` | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
## Tok2Vec.pipe {id="pipe",tag="method"}
Apply the pipe to a stream of documents. This usually happens under the hood

website/docs/api/vocab.mdx
View File

@@ -17,14 +17,15 @@ Create the vocabulary.

> #### Example
>
> ```python
+> from spacy.strings import StringStore
> from spacy.vocab import Vocab
-> vocab = Vocab(strings=["hello", "world"])
+> vocab = Vocab(strings=StringStore(["hello", "world"]))
> ```
| Name               | Description |
| ------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `lex_attr_getters` | A dictionary mapping attribute IDs to functions to compute them. Defaults to `None`. ~~Optional[Dict[str, Callable[[str], Any]]]~~ |
-| `strings`          | A [`StringStore`](/api/stringstore) that maps strings to hash values, and vice versa, or a list of strings. ~~Union[List[str], StringStore]~~ |
+| `strings`          | A [`StringStore`](/api/stringstore) that maps strings to hash values. ~~Optional[StringStore]~~ |
| `lookups` | A [`Lookups`](/api/lookups) that stores the `lexeme_norm` and other large lookup tables. Defaults to `None`. ~~Optional[Lookups]~~ |
| `oov_prob` | The default OOV probability. Defaults to `-20.0`. ~~float~~ |
| `writing_system` | A dictionary describing the language's writing system. Typically provided by [`Language.Defaults`](/api/language#defaults). ~~Dict[str, Any]~~ |