Merge branch 'v4' into feature/docwise-generator-batching

# Conflicts:
#	spacy/kb/kb.pyx
#	spacy/kb/kb_in_memory.pyx
#	spacy/ml/models/entity_linker.py
#	spacy/pipeline/entity_linker.py
#	spacy/tests/pipeline/test_entity_linker.py
#	website/docs/api/inmemorylookupkb.mdx
#	website/docs/api/kb.mdx
Commit e5be5d6092 by Raphael Mitsch, 2023-03-20 10:50:54 +01:00
13 changed files with 264 additions and 65 deletions

View File

@@ -36,13 +36,25 @@ cdef class KnowledgeBase:
        entity's embedding vector. Depending on the KB implementation, further properties - such as the prior
        probability of the specified mention text resolving to that entity - might be included.
        If no candidates are found for a given mention, an empty list is returned.
-       mentions (Iterable[SpangGroup]): Mentions for which to get candidates.
+       mentions (Iterable[SpanGroup]): Mentions for which to get candidates.
        RETURNS (Iterable[Iterable[Candidate]]): Identified candidates.
        """
        raise NotImplementedError(
            Errors.E1045.format(parent="KnowledgeBase", method="get_candidates", name=self.__name__)
        )

    def get_candidates(self, mention: Span) -> Iterable[Candidate]:
        """
        Return candidate entities for the specified text. Each candidate defines the entity, the original alias,
        and the prior probability of that alias resolving to that entity.
        If no candidate is found for a given text, an empty list is returned.
        mention (Span): Mention for which to get candidates.
        RETURNS (Iterable[Candidate]): Identified candidates.
        """
        raise NotImplementedError(
            Errors.E1045.format(parent="KnowledgeBase", method="get_candidates", name=self.__name__)
        )

    def get_vectors(self, entities: Iterable[str]) -> Iterable[Iterable[float]]:
        """
        Return vectors for entities.
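
To illustrate the batched API introduced here, below is a minimal sketch (not part of the diff) of a custom KB subclass. The class `MyKB` and its `_lookup()` helper are hypothetical; the `Iterable[SpanGroup]` signature is the one this branch proposes, with one `SpanGroup` of mentions per `Doc`:

```python
from typing import Iterable, List

from spacy.kb import Candidate, KnowledgeBase
from spacy.tokens import SpanGroup


class MyKB(KnowledgeBase):
    def get_candidates(
        self, mentions: Iterable[SpanGroup]
    ) -> Iterable[Iterable[Iterable[Candidate]]]:
        # One SpanGroup per Doc: yield one list of candidate lists per
        # document, with one inner list per mention.
        for doc_mentions in mentions:
            yield [self._lookup(span.text) for span in doc_mentions]

    def _lookup(self, alias: str) -> List[Candidate]:
        # Hypothetical helper, e.g. a database or REST lookup.
        return []
```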

View File

@@ -5,8 +5,7 @@ from thinc.api import chain, list2ragged, reduce_mean, residual
from thinc.api import Model, Maxout, Linear, tuplify, Ragged

from ...util import registry
-from ...kb import KnowledgeBase, InMemoryLookupKB
-from ...kb import Candidate
+from ...kb import KnowledgeBase, InMemoryLookupKB, Candidate
from ...vocab import Vocab
from ...tokens import Doc, Span, SpanGroup
from ..extract_spans import extract_spans

View File

@@ -1,5 +1,6 @@
-from typing import Sequence, Iterable, Optional, Dict, Callable, List, Any
+from typing import Sequence, Iterable, Optional, Dict, Callable, List, Any, Tuple
from thinc.api import Model, set_dropout_rate, Optimizer, Config
+from thinc.types import Floats2d
from itertools import islice

from .trainable_pipe import TrainablePipe
@@ -157,39 +158,9 @@ class Tok2Vec(TrainablePipe):
        DOCS: https://spacy.io/api/tok2vec#update
        """
-       if losses is None:
-           losses = {}
        validate_examples(examples, "Tok2Vec.update")
        docs = [eg.predicted for eg in examples]
-       set_dropout_rate(self.model, drop)
-       tokvecs, bp_tokvecs = self.model.begin_update(docs)
-       d_tokvecs = [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
-       losses.setdefault(self.name, 0.0)
-
-       def accumulate_gradient(one_d_tokvecs):
-           """Accumulate tok2vec loss and gradient. This is passed as a callback
-           to all but the last listener. Only the last one does the backprop.
-           """
-           nonlocal d_tokvecs
-           for i in range(len(one_d_tokvecs)):
-               d_tokvecs[i] += one_d_tokvecs[i]
-               losses[self.name] += float((one_d_tokvecs[i] ** 2).sum())
-           return [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
-
-       def backprop(one_d_tokvecs):
-           """Callback to actually do the backprop. Passed to last listener."""
-           accumulate_gradient(one_d_tokvecs)
-           d_docs = bp_tokvecs(d_tokvecs)
-           if sgd is not None:
-               self.finish_update(sgd)
-           return d_docs
-
-       batch_id = Tok2VecListener.get_batch_id(docs)
-       for listener in self.listeners[:-1]:
-           listener.receive(batch_id, tokvecs, accumulate_gradient)
-       if self.listeners:
-           self.listeners[-1].receive(batch_id, tokvecs, backprop)
-       return losses
+       return self._update_with_docs(docs, drop=drop, sgd=sgd, losses=losses)

    def get_loss(self, examples, scores) -> None:
        pass
@@ -219,6 +190,96 @@ class Tok2Vec(TrainablePipe):
    def add_label(self, label):
        raise NotImplementedError

    def distill(
        self,
        teacher_pipe: Optional["TrainablePipe"],
        examples: Iterable["Example"],
        *,
        drop: float = 0.0,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None,
    ) -> Dict[str, float]:
        """Performs an update of the student pipe's model using the
        student's distillation examples and sets the annotations
        of the teacher's distillation examples using the teacher pipe.

        teacher_pipe (Optional[TrainablePipe]): The teacher pipe to use
            for prediction.
        examples (Iterable[Example]): Distillation examples. The reference (teacher)
            and predicted (student) docs must have the same number of tokens and the
            same orthography.
        drop (float): dropout rate.
        sgd (Optional[Optimizer]): An optimizer. Will be created via
            create_optimizer if not set.
        losses (Optional[Dict[str, float]]): Optional record of loss during
            distillation.
        RETURNS: The updated losses dictionary.

        DOCS: https://spacy.io/api/tok2vec#distill
        """
        # By default we require a teacher pipe, but there are downstream
        # implementations that don't require a pipe.
        if teacher_pipe is None:
            raise ValueError(Errors.E4002.format(name=self.name))
        teacher_docs = [eg.reference for eg in examples]
        student_docs = [eg.predicted for eg in examples]
        teacher_preds = teacher_pipe.predict(teacher_docs)
        teacher_pipe.set_annotations(teacher_docs, teacher_preds)
        return self._update_with_docs(student_docs, drop=drop, sgd=sgd, losses=losses)

    def _update_with_docs(
        self,
        docs: Iterable[Doc],
        *,
        drop: float = 0.0,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None,
    ):
        if losses is None:
            losses = {}
        losses.setdefault(self.name, 0.0)
        set_dropout_rate(self.model, drop)

        tokvecs, accumulate_gradient, backprop = self._create_backprops(
            docs, losses, sgd=sgd
        )
        batch_id = Tok2VecListener.get_batch_id(docs)
        for listener in self.listeners[:-1]:
            listener.receive(batch_id, tokvecs, accumulate_gradient)
        if self.listeners:
            self.listeners[-1].receive(batch_id, tokvecs, backprop)
        return losses

    def _create_backprops(
        self,
        docs: Iterable[Doc],
        losses: Dict[str, float],
        *,
        sgd: Optional[Optimizer] = None,
    ) -> Tuple[Floats2d, Callable, Callable]:
        tokvecs, bp_tokvecs = self.model.begin_update(docs)
        d_tokvecs = [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]

        def accumulate_gradient(one_d_tokvecs):
            """Accumulate tok2vec loss and gradient. This is passed as a callback
            to all but the last listener. Only the last one does the backprop.
            """
            nonlocal d_tokvecs
            for i in range(len(one_d_tokvecs)):
                d_tokvecs[i] += one_d_tokvecs[i]
                losses[self.name] += float((one_d_tokvecs[i] ** 2).sum())
            return [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]

        def backprop(one_d_tokvecs):
            """Callback to actually do the backprop. Passed to last listener."""
            accumulate_gradient(one_d_tokvecs)
            d_docs = bp_tokvecs(d_tokvecs)
            if sgd is not None:
                self.finish_update(sgd)
            return d_docs

        return tokvecs, accumulate_gradient, backprop
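
For orientation, a hedged usage sketch of the new `distill()` method (hypothetical snippet, not part of the diff; `teacher_nlp`, `student_nlp`, `train_examples`, and `optimizer` are assumed to be set up as in the test added further down in this commit):

```python
student_tok2vec = student_nlp.get_pipe("tok2vec")
teacher_tok2vec = teacher_nlp.get_pipe("tok2vec")

# The teacher annotates the reference docs first, then the student updates.
losses = student_tok2vec.distill(
    teacher_tok2vec, train_examples, sgd=optimizer, losses={}
)
```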
class Tok2VecListener(Model):
"""A layer that gets fed its answers from an upstream connection,

View File

@@ -2,7 +2,7 @@ from typing import List, Optional, Iterable, Iterator, Union, Any, Tuple, overload
from pathlib import Path

class StringStore:
-   def __init__(self, strings: Optional[Iterable[str]]) -> None: ...
+   def __init__(self, strings: Optional[Iterable[str]] = None) -> None: ...
    @overload
    def __getitem__(self, string_or_hash: str) -> int: ...
    @overload
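
The stub change above makes the `strings` argument optional. A small sketch of what the relaxed signature permits, relying only on the existing `StringStore` API:

```python
from spacy.strings import StringStore

store = StringStore()          # now valid: no initial strings required
key = store.add("hello")       # add() returns the string's 64-bit hash
assert store[key] == "hello"   # __getitem__ maps a hash back to the string
assert store["hello"] == key   # ...and a string to its hash
```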

View File

@@ -9,6 +9,7 @@ from spacy.lang.en import English
from spacy.lang.en.syntax_iterators import noun_chunks
from spacy.language import Language
from spacy.pipeline import TrainablePipe
+from spacy.strings import StringStore
from spacy.tokens import Doc
from spacy.training import Example
from spacy.util import SimpleFrozenList, get_arg_names, make_tempdir
@@ -131,7 +132,7 @@ def test_issue5458():
    # Test that the noun chunker does not generate overlapping spans
    # fmt: off
    words = ["In", "an", "era", "where", "markets", "have", "brought", "prosperity", "and", "empowerment", "."]
-   vocab = Vocab(strings=words)
+   vocab = Vocab(strings=StringStore(words))
    deps = ["ROOT", "det", "pobj", "advmod", "nsubj", "aux", "relcl", "dobj", "cc", "conj", "punct"]
    pos = ["ADP", "DET", "NOUN", "ADV", "NOUN", "AUX", "VERB", "NOUN", "CCONJ", "NOUN", "PUNCT"]
    heads = [0, 2, 0, 9, 6, 6, 2, 6, 7, 7, 0]

View File

@@ -540,3 +540,86 @@ def test_tok2vec_listeners_textcat():
    assert cats1["imperative"] < 0.9
    assert [t.tag_ for t in docs[0]] == ["V", "J", "N"]
    assert [t.tag_ for t in docs[1]] == ["N", "V", "J", "N"]

cfg_string_distillation = """
    [nlp]
    lang = "en"
    pipeline = ["tok2vec","tagger"]

    [components]

    [components.tagger]
    factory = "tagger"

    [components.tagger.model]
    @architectures = "spacy.Tagger.v2"
    nO = null

    [components.tagger.model.tok2vec]
    @architectures = "spacy.Tok2VecListener.v1"
    width = ${components.tok2vec.model.encode.width}

    [components.tok2vec]
    factory = "tok2vec"

    [components.tok2vec.model]
    @architectures = "spacy.Tok2Vec.v2"

    [components.tok2vec.model.embed]
    @architectures = "spacy.MultiHashEmbed.v2"
    width = ${components.tok2vec.model.encode.width}
    rows = [2000, 1000, 1000, 1000]
    attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
    include_static_vectors = false

    [components.tok2vec.model.encode]
    @architectures = "spacy.MaxoutWindowEncoder.v2"
    width = 96
    depth = 4
    window_size = 1
    maxout_pieces = 3
    """


def test_tok2vec_distillation_teacher_annotations():
    orig_config = Config().from_str(cfg_string_distillation)
    teacher_nlp = util.load_model_from_config(
        orig_config, auto_fill=True, validate=True
    )
    student_nlp = util.load_model_from_config(
        orig_config, auto_fill=True, validate=True
    )

    train_examples_teacher = []
    train_examples_student = []
    for t in TRAIN_DATA:
        train_examples_teacher.append(
            Example.from_dict(teacher_nlp.make_doc(t[0]), t[1])
        )
        train_examples_student.append(
            Example.from_dict(student_nlp.make_doc(t[0]), t[1])
        )

    optimizer = teacher_nlp.initialize(lambda: train_examples_teacher)
    student_nlp.initialize(lambda: train_examples_student)

    # Since Language.distill creates a copy of the examples to use as
    # its internal teacher/student docs, we'll need to monkey-patch the
    # tok2vec pipe's distill method.
    student_tok2vec = student_nlp.get_pipe("tok2vec")
    student_tok2vec._old_distill = student_tok2vec.distill

    def tok2vec_distill_wrapper(
        self,
        teacher_pipe,
        examples,
        **kwargs,
    ):
        assert all(not eg.reference.tensor.any() for eg in examples)
        out = self._old_distill(teacher_pipe, examples, **kwargs)
        assert all(eg.reference.tensor.any() for eg in examples)
        return out

    student_tok2vec.distill = tok2vec_distill_wrapper.__get__(student_tok2vec, Tok2Vec)
    student_nlp.distill(teacher_nlp, train_examples_student, sgd=optimizer, losses={})

View File

@@ -13,8 +13,11 @@ from spacy.vocab import Vocab

from ..util import make_tempdir

-test_strings = [([], []), (["rats", "are", "cute"], ["i", "like", "rats"])]
-test_strings_attrs = [(["rats", "are", "cute"], "Hello")]
+test_strings = [
+    (StringStore(), StringStore()),
+    (StringStore(["rats", "are", "cute"]), StringStore(["i", "like", "rats"])),
+]
+test_strings_attrs = [(StringStore(["rats", "are", "cute"]), "Hello")]


@pytest.mark.issue(599)
@@ -81,7 +84,7 @@ def test_serialize_vocab_roundtrip_bytes(strings1, strings2):
    vocab2 = Vocab(strings=strings2)
    vocab1_b = vocab1.to_bytes()
    vocab2_b = vocab2.to_bytes()
-   if strings1 == strings2:
+   if strings1.to_bytes() == strings2.to_bytes():
        assert vocab1_b == vocab2_b
    else:
        assert vocab1_b != vocab2_b
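
The switch to comparing serialized bytes is presumably needed because two `StringStore` objects with identical contents do not compare equal: there is no content-based `__eq__`, so comparison falls back to object identity. A quick sketch of the distinction:

```python
from spacy.strings import StringStore

s1 = StringStore(["rats"])
s2 = StringStore(["rats"])
assert s1 != s2                        # distinct objects: identity comparison
assert s1.to_bytes() == s2.to_bytes()  # serialized form compares by content
```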
@@ -117,11 +120,12 @@ def test_serialize_vocab_roundtrip_disk(strings1, strings2):
def test_serialize_vocab_lex_attrs_bytes(strings, lex_attr):
    vocab1 = Vocab(strings=strings)
    vocab2 = Vocab()
-   vocab1[strings[0]].norm_ = lex_attr
-   assert vocab1[strings[0]].norm_ == lex_attr
-   assert vocab2[strings[0]].norm_ != lex_attr
+   s = next(iter(vocab1.strings))
+   vocab1[s].norm_ = lex_attr
+   assert vocab1[s].norm_ == lex_attr
+   assert vocab2[s].norm_ != lex_attr
    vocab2 = vocab2.from_bytes(vocab1.to_bytes())
-   assert vocab2[strings[0]].norm_ == lex_attr
+   assert vocab2[s].norm_ == lex_attr


@pytest.mark.parametrize("strings,lex_attr", test_strings_attrs)
@@ -136,14 +140,15 @@ def test_deserialize_vocab_seen_entries(strings, lex_attr):
def test_serialize_vocab_lex_attrs_disk(strings, lex_attr):
    vocab1 = Vocab(strings=strings)
    vocab2 = Vocab()
-   vocab1[strings[0]].norm_ = lex_attr
-   assert vocab1[strings[0]].norm_ == lex_attr
-   assert vocab2[strings[0]].norm_ != lex_attr
+   s = next(iter(vocab1.strings))
+   vocab1[s].norm_ = lex_attr
+   assert vocab1[s].norm_ == lex_attr
+   assert vocab2[s].norm_ != lex_attr
    with make_tempdir() as d:
        file_path = d / "vocab"
        vocab1.to_disk(file_path)
        vocab2 = vocab2.from_disk(file_path)
-       assert vocab2[strings[0]].norm_ == lex_attr
+       assert vocab2[s].norm_ == lex_attr


@pytest.mark.parametrize("strings1,strings2", test_strings)

View File

@@ -17,7 +17,7 @@ def test_issue361(en_vocab, text1, text2):
@pytest.mark.issue(600)
def test_issue600():
-   vocab = Vocab(tag_map={"NN": {"pos": "NOUN"}})
+   vocab = Vocab()
    doc = Doc(vocab, words=["hello"])
    doc[0].tag_ = "NN"

View File

@@ -26,7 +26,7 @@ class Vocab:
    def __init__(
        self,
        lex_attr_getters: Optional[Dict[str, Callable[[str], Any]]] = ...,
-       strings: Optional[Union[List[str], StringStore]] = ...,
+       strings: Optional[StringStore] = ...,
        lookups: Optional[Lookups] = ...,
        oov_prob: float = ...,
        writing_system: Dict[str, Any] = ...,

View File

@@ -49,9 +49,8 @@ cdef class Vocab:
    DOCS: https://spacy.io/api/vocab
    """
-   def __init__(self, lex_attr_getters=None, strings=tuple(), lookups=None,
-                oov_prob=-20., writing_system={}, get_noun_chunks=None,
-                **deprecated_kwargs):
+   def __init__(self, lex_attr_getters=None, strings=None, lookups=None,
+                oov_prob=-20., writing_system=None, get_noun_chunks=None):
        """Create the vocabulary.

        lex_attr_getters (dict): A dictionary mapping attribute IDs to

@@ -69,16 +68,19 @@ cdef class Vocab:
        self.cfg = {'oov_prob': oov_prob}
        self.mem = Pool()
        self._by_orth = PreshMap()
-       self.strings = StringStore()
-       self.length = 0
-       if strings:
-           for string in strings:
-               _ = self[string]
+       if strings is None:
+           self.strings = StringStore()
+       else:
+           self.strings = strings
        self.lex_attr_getters = lex_attr_getters
        self.morphology = Morphology(self.strings)
        self.vectors = Vectors(strings=self.strings)
        self.lookups = lookups
-       self.writing_system = writing_system
+       if writing_system is None:
+           self.writing_system = {}
+       else:
+           self.writing_system = writing_system
        self.get_noun_chunks = get_noun_chunks

    property vectors:
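
A short migration sketch (hypothetical snippet, not part of the diff): callers that previously passed a list of strings now wrap it in a `StringStore`, which the `Vocab` adopts directly instead of interning the strings one by one:

```python
from spacy.strings import StringStore
from spacy.vocab import Vocab

# Before this change: vocab = Vocab(strings=["hello", "world"])
vocab = Vocab(strings=StringStore(["hello", "world"]))
assert "hello" in vocab.strings
```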

View File

@@ -81,7 +81,7 @@ implementation of `KnowledgeBase.get_candidates()`.

| Name        | Description |
| ----------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `mentions`  | The textual mention or alias. ~~Iterable[SpanGroup]~~ |
-| **RETURNS** | An iterator over iterables of iterables with relevant `Candidate` objects (per mention and doc). ~~Iterator[Iterable[Iterable[Candidate]]]~~ |
+| **RETURNS** | An iterator (per document) over iterables (per mention) of iterables (per candidate for this mention) with relevant `Candidate` objects. ~~Iterator[Iterable[Iterable[Candidate]]]~~ |
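
A sketch of consuming that nested structure (hedged: `kb` is assumed to be a KB implementing this API, and `mentions_per_doc` an iterable with one `SpanGroup` per `Doc`):

```python
for doc_candidates in kb.get_candidates(mentions_per_doc):  # per document
    for mention_candidates in doc_candidates:               # per mention
        for candidate in mention_candidates:                # per candidate
            print(candidate.prior_prob)
```
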
## KnowledgeBase.get_vector {id="get_vector",tag="method"}
@@ -167,13 +167,11 @@ Construct an `InMemoryCandidate` object. Usually this constructor is not called
directly, but instead these objects are returned by the `get_candidates` method
of the [`entity_linker`](/api/entitylinker) pipe.

> #### Example
>
> ```python
> from spacy.kb import InMemoryCandidate
>
> candidate = InMemoryCandidate(
>     kb, entity_hash, entity_freq, entity_vector, alias_hash, prior_prob
> )
> ```

| Name | Description |

View File

@@ -100,6 +100,43 @@ pipeline components are applied to the `Doc` in order. Both

| `doc`       | The document to process. ~~Doc~~ |
| **RETURNS** | The processed document. ~~Doc~~ |
## Tok2Vec.distill {id="distill", tag="method,experimental", version="4"}

Performs an update of the student pipe's model using the student's distillation
examples and sets the annotations of the teacher's distillation examples using
the teacher pipe.

Unlike other trainable pipes, the student pipe doesn't directly learn its
representations from the teacher. However, since downstream pipes that do
perform distillation expect the tok2vec annotations to be present on the
correct distillation examples, we need to ensure that they are set beforehand.

The distillation is performed on ~~Example~~ objects. The `Example.reference`
and `Example.predicted` ~~Doc~~s must have the same number of tokens and the
same orthography. Even though the reference does not have to have gold
annotations, the teacher may add its own annotations when necessary.

This feature is experimental.
> #### Example
>
> ```python
> teacher_pipe = teacher.add_pipe("tok2vec")
> student_pipe = student.add_pipe("tok2vec")
> optimizer = nlp.resume_training()
> losses = student.distill(teacher_pipe, examples, sgd=optimizer)
> ```
| Name | Description |
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------- |
| `teacher_pipe` | The teacher pipe to use for prediction. ~~Optional[TrainablePipe]~~ |
| `examples` | Distillation examples. The reference (teacher) and predicted (student) docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ |
| _keyword-only_ | |
| `drop` | Dropout rate. ~~float~~ |
| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
| `losses` | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
## Tok2Vec.pipe {id="pipe",tag="method"}
Apply the pipe to a stream of documents. This usually happens under the hood

View File

@@ -17,14 +17,15 @@ Create the vocabulary.

> #### Example
>
> ```python
+> from spacy.strings import StringStore
> from spacy.vocab import Vocab
-> vocab = Vocab(strings=["hello", "world"])
+> vocab = Vocab(strings=StringStore(["hello", "world"]))
> ```

| Name               | Description |
| ------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `lex_attr_getters` | A dictionary mapping attribute IDs to functions to compute them. Defaults to `None`. ~~Optional[Dict[str, Callable[[str], Any]]]~~ |
-| `strings`          | A [`StringStore`](/api/stringstore) that maps strings to hash values, and vice versa, or a list of strings. ~~Union[List[str], StringStore]~~ |
+| `strings`          | A [`StringStore`](/api/stringstore) that maps strings to hash values. ~~Optional[StringStore]~~ |
| `lookups`          | A [`Lookups`](/api/lookups) that stores the `lexeme_norm` and other large lookup tables. Defaults to `None`. ~~Optional[Lookups]~~ |
| `oov_prob`         | The default OOV probability. Defaults to `-20.0`. ~~float~~ |
| `writing_system`   | A dictionary describing the language's writing system. Typically provided by [`Language.Defaults`](/api/language#defaults). ~~Dict[str, Any]~~ |