mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-11 16:52:21 +03:00
Merge branch 'v4' into feature/docwise-generator-batching
# Conflicts: # spacy/kb/kb.pyx # spacy/kb/kb_in_memory.pyx # spacy/ml/models/entity_linker.py # spacy/pipeline/entity_linker.py # spacy/tests/pipeline/test_entity_linker.py # website/docs/api/inmemorylookupkb.mdx # website/docs/api/kb.mdx
This commit is contained in:
commit
e5be5d6092
|
@ -36,13 +36,25 @@ cdef class KnowledgeBase:
|
||||||
entity's embedding vector. Depending on the KB implementation, further properties - such as the prior
|
entity's embedding vector. Depending on the KB implementation, further properties - such as the prior
|
||||||
probability of the specified mention text resolving to that entity - might be included.
|
probability of the specified mention text resolving to that entity - might be included.
|
||||||
If no candidates are found for a given mention, an empty list is returned.
|
If no candidates are found for a given mention, an empty list is returned.
|
||||||
mentions (Iterable[SpangGroup]): Mentions for which to get candidates.
|
mentions (Iterable[SpanGroup]): Mentions for which to get candidates.
|
||||||
RETURNS (Iterable[Iterable[Candidate]]): Identified candidates.
|
RETURNS (Iterable[Iterable[Candidate]]): Identified candidates.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
Errors.E1045.format(parent="KnowledgeBase", method="get_candidates", name=self.__name__)
|
Errors.E1045.format(parent="KnowledgeBase", method="get_candidates", name=self.__name__)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_candidates(self, mention: Span) -> Iterable[Candidate]:
|
||||||
|
"""
|
||||||
|
Return candidate entities for specified text. Each candidate defines the entity, the original alias,
|
||||||
|
and the prior probability of that alias resolving to that entity.
|
||||||
|
If the no candidate is found for a given text, an empty list is returned.
|
||||||
|
mention (Span): Mention for which to get candidates.
|
||||||
|
RETURNS (Iterable[Candidate]): Identified candidates.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
Errors.E1045.format(parent="KnowledgeBase", method="get_candidates", name=self.__name__)
|
||||||
|
)
|
||||||
|
|
||||||
def get_vectors(self, entities: Iterable[str]) -> Iterable[Iterable[float]]:
|
def get_vectors(self, entities: Iterable[str]) -> Iterable[Iterable[float]]:
|
||||||
"""
|
"""
|
||||||
Return vectors for entities.
|
Return vectors for entities.
|
||||||
|
|
|
@ -5,8 +5,7 @@ from thinc.api import chain, list2ragged, reduce_mean, residual
|
||||||
from thinc.api import Model, Maxout, Linear, tuplify, Ragged
|
from thinc.api import Model, Maxout, Linear, tuplify, Ragged
|
||||||
|
|
||||||
from ...util import registry
|
from ...util import registry
|
||||||
from ...kb import KnowledgeBase, InMemoryLookupKB
|
from ...kb import KnowledgeBase, InMemoryLookupKB, Candidate
|
||||||
from ...kb import Candidate
|
|
||||||
from ...vocab import Vocab
|
from ...vocab import Vocab
|
||||||
from ...tokens import Doc, Span, SpanGroup
|
from ...tokens import Doc, Span, SpanGroup
|
||||||
from ..extract_spans import extract_spans
|
from ..extract_spans import extract_spans
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from typing import Sequence, Iterable, Optional, Dict, Callable, List, Any
|
from typing import Sequence, Iterable, Optional, Dict, Callable, List, Any, Tuple
|
||||||
from thinc.api import Model, set_dropout_rate, Optimizer, Config
|
from thinc.api import Model, set_dropout_rate, Optimizer, Config
|
||||||
|
from thinc.types import Floats2d
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
|
|
||||||
from .trainable_pipe import TrainablePipe
|
from .trainable_pipe import TrainablePipe
|
||||||
|
@ -157,39 +158,9 @@ class Tok2Vec(TrainablePipe):
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/tok2vec#update
|
DOCS: https://spacy.io/api/tok2vec#update
|
||||||
"""
|
"""
|
||||||
if losses is None:
|
|
||||||
losses = {}
|
|
||||||
validate_examples(examples, "Tok2Vec.update")
|
validate_examples(examples, "Tok2Vec.update")
|
||||||
docs = [eg.predicted for eg in examples]
|
docs = [eg.predicted for eg in examples]
|
||||||
set_dropout_rate(self.model, drop)
|
return self._update_with_docs(docs, drop=drop, sgd=sgd, losses=losses)
|
||||||
tokvecs, bp_tokvecs = self.model.begin_update(docs)
|
|
||||||
d_tokvecs = [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
|
|
||||||
losses.setdefault(self.name, 0.0)
|
|
||||||
|
|
||||||
def accumulate_gradient(one_d_tokvecs):
|
|
||||||
"""Accumulate tok2vec loss and gradient. This is passed as a callback
|
|
||||||
to all but the last listener. Only the last one does the backprop.
|
|
||||||
"""
|
|
||||||
nonlocal d_tokvecs
|
|
||||||
for i in range(len(one_d_tokvecs)):
|
|
||||||
d_tokvecs[i] += one_d_tokvecs[i]
|
|
||||||
losses[self.name] += float((one_d_tokvecs[i] ** 2).sum())
|
|
||||||
return [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
|
|
||||||
|
|
||||||
def backprop(one_d_tokvecs):
|
|
||||||
"""Callback to actually do the backprop. Passed to last listener."""
|
|
||||||
accumulate_gradient(one_d_tokvecs)
|
|
||||||
d_docs = bp_tokvecs(d_tokvecs)
|
|
||||||
if sgd is not None:
|
|
||||||
self.finish_update(sgd)
|
|
||||||
return d_docs
|
|
||||||
|
|
||||||
batch_id = Tok2VecListener.get_batch_id(docs)
|
|
||||||
for listener in self.listeners[:-1]:
|
|
||||||
listener.receive(batch_id, tokvecs, accumulate_gradient)
|
|
||||||
if self.listeners:
|
|
||||||
self.listeners[-1].receive(batch_id, tokvecs, backprop)
|
|
||||||
return losses
|
|
||||||
|
|
||||||
def get_loss(self, examples, scores) -> None:
|
def get_loss(self, examples, scores) -> None:
|
||||||
pass
|
pass
|
||||||
|
@ -219,6 +190,96 @@ class Tok2Vec(TrainablePipe):
|
||||||
def add_label(self, label):
|
def add_label(self, label):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def distill(
|
||||||
|
self,
|
||||||
|
teacher_pipe: Optional["TrainablePipe"],
|
||||||
|
examples: Iterable["Example"],
|
||||||
|
*,
|
||||||
|
drop: float = 0.0,
|
||||||
|
sgd: Optional[Optimizer] = None,
|
||||||
|
losses: Optional[Dict[str, float]] = None,
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
"""Performs an update of the student pipe's model using the
|
||||||
|
student's distillation examples and sets the annotations
|
||||||
|
of the teacher's distillation examples using the teacher pipe.
|
||||||
|
|
||||||
|
teacher_pipe (Optional[TrainablePipe]): The teacher pipe to use
|
||||||
|
for prediction.
|
||||||
|
examples (Iterable[Example]): Distillation examples. The reference (teacher)
|
||||||
|
and predicted (student) docs must have the same number of tokens and the
|
||||||
|
same orthography.
|
||||||
|
drop (float): dropout rate.
|
||||||
|
sgd (Optional[Optimizer]): An optimizer. Will be created via
|
||||||
|
create_optimizer if not set.
|
||||||
|
losses (Optional[Dict[str, float]]): Optional record of loss during
|
||||||
|
distillation.
|
||||||
|
RETURNS: The updated losses dictionary.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/tok2vec#distill
|
||||||
|
"""
|
||||||
|
# By default we require a teacher pipe, but there are downstream
|
||||||
|
# implementations that don't require a pipe.
|
||||||
|
if teacher_pipe is None:
|
||||||
|
raise ValueError(Errors.E4002.format(name=self.name))
|
||||||
|
teacher_docs = [eg.reference for eg in examples]
|
||||||
|
student_docs = [eg.predicted for eg in examples]
|
||||||
|
teacher_preds = teacher_pipe.predict(teacher_docs)
|
||||||
|
teacher_pipe.set_annotations(teacher_docs, teacher_preds)
|
||||||
|
return self._update_with_docs(student_docs, drop=drop, sgd=sgd, losses=losses)
|
||||||
|
|
||||||
|
def _update_with_docs(
|
||||||
|
self,
|
||||||
|
docs: Iterable[Doc],
|
||||||
|
*,
|
||||||
|
drop: float = 0.0,
|
||||||
|
sgd: Optional[Optimizer] = None,
|
||||||
|
losses: Optional[Dict[str, float]] = None,
|
||||||
|
):
|
||||||
|
if losses is None:
|
||||||
|
losses = {}
|
||||||
|
losses.setdefault(self.name, 0.0)
|
||||||
|
set_dropout_rate(self.model, drop)
|
||||||
|
|
||||||
|
tokvecs, accumulate_gradient, backprop = self._create_backprops(
|
||||||
|
docs, losses, sgd=sgd
|
||||||
|
)
|
||||||
|
batch_id = Tok2VecListener.get_batch_id(docs)
|
||||||
|
for listener in self.listeners[:-1]:
|
||||||
|
listener.receive(batch_id, tokvecs, accumulate_gradient)
|
||||||
|
if self.listeners:
|
||||||
|
self.listeners[-1].receive(batch_id, tokvecs, backprop)
|
||||||
|
return losses
|
||||||
|
|
||||||
|
def _create_backprops(
|
||||||
|
self,
|
||||||
|
docs: Iterable[Doc],
|
||||||
|
losses: Dict[str, float],
|
||||||
|
*,
|
||||||
|
sgd: Optional[Optimizer] = None,
|
||||||
|
) -> Tuple[Floats2d, Callable, Callable]:
|
||||||
|
tokvecs, bp_tokvecs = self.model.begin_update(docs)
|
||||||
|
d_tokvecs = [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
|
||||||
|
|
||||||
|
def accumulate_gradient(one_d_tokvecs):
|
||||||
|
"""Accumulate tok2vec loss and gradient. This is passed as a callback
|
||||||
|
to all but the last listener. Only the last one does the backprop.
|
||||||
|
"""
|
||||||
|
nonlocal d_tokvecs
|
||||||
|
for i in range(len(one_d_tokvecs)):
|
||||||
|
d_tokvecs[i] += one_d_tokvecs[i]
|
||||||
|
losses[self.name] += float((one_d_tokvecs[i] ** 2).sum())
|
||||||
|
return [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
|
||||||
|
|
||||||
|
def backprop(one_d_tokvecs):
|
||||||
|
"""Callback to actually do the backprop. Passed to last listener."""
|
||||||
|
accumulate_gradient(one_d_tokvecs)
|
||||||
|
d_docs = bp_tokvecs(d_tokvecs)
|
||||||
|
if sgd is not None:
|
||||||
|
self.finish_update(sgd)
|
||||||
|
return d_docs
|
||||||
|
|
||||||
|
return tokvecs, accumulate_gradient, backprop
|
||||||
|
|
||||||
|
|
||||||
class Tok2VecListener(Model):
|
class Tok2VecListener(Model):
|
||||||
"""A layer that gets fed its answers from an upstream connection,
|
"""A layer that gets fed its answers from an upstream connection,
|
||||||
|
|
|
@ -2,7 +2,7 @@ from typing import List, Optional, Iterable, Iterator, Union, Any, Tuple, overlo
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
class StringStore:
|
class StringStore:
|
||||||
def __init__(self, strings: Optional[Iterable[str]]) -> None: ...
|
def __init__(self, strings: Optional[Iterable[str]] = None) -> None: ...
|
||||||
@overload
|
@overload
|
||||||
def __getitem__(self, string_or_hash: str) -> int: ...
|
def __getitem__(self, string_or_hash: str) -> int: ...
|
||||||
@overload
|
@overload
|
||||||
|
|
|
@ -9,6 +9,7 @@ from spacy.lang.en import English
|
||||||
from spacy.lang.en.syntax_iterators import noun_chunks
|
from spacy.lang.en.syntax_iterators import noun_chunks
|
||||||
from spacy.language import Language
|
from spacy.language import Language
|
||||||
from spacy.pipeline import TrainablePipe
|
from spacy.pipeline import TrainablePipe
|
||||||
|
from spacy.strings import StringStore
|
||||||
from spacy.tokens import Doc
|
from spacy.tokens import Doc
|
||||||
from spacy.training import Example
|
from spacy.training import Example
|
||||||
from spacy.util import SimpleFrozenList, get_arg_names, make_tempdir
|
from spacy.util import SimpleFrozenList, get_arg_names, make_tempdir
|
||||||
|
@ -131,7 +132,7 @@ def test_issue5458():
|
||||||
# Test that the noun chuncker does not generate overlapping spans
|
# Test that the noun chuncker does not generate overlapping spans
|
||||||
# fmt: off
|
# fmt: off
|
||||||
words = ["In", "an", "era", "where", "markets", "have", "brought", "prosperity", "and", "empowerment", "."]
|
words = ["In", "an", "era", "where", "markets", "have", "brought", "prosperity", "and", "empowerment", "."]
|
||||||
vocab = Vocab(strings=words)
|
vocab = Vocab(strings=StringStore(words))
|
||||||
deps = ["ROOT", "det", "pobj", "advmod", "nsubj", "aux", "relcl", "dobj", "cc", "conj", "punct"]
|
deps = ["ROOT", "det", "pobj", "advmod", "nsubj", "aux", "relcl", "dobj", "cc", "conj", "punct"]
|
||||||
pos = ["ADP", "DET", "NOUN", "ADV", "NOUN", "AUX", "VERB", "NOUN", "CCONJ", "NOUN", "PUNCT"]
|
pos = ["ADP", "DET", "NOUN", "ADV", "NOUN", "AUX", "VERB", "NOUN", "CCONJ", "NOUN", "PUNCT"]
|
||||||
heads = [0, 2, 0, 9, 6, 6, 2, 6, 7, 7, 0]
|
heads = [0, 2, 0, 9, 6, 6, 2, 6, 7, 7, 0]
|
||||||
|
|
|
@ -540,3 +540,86 @@ def test_tok2vec_listeners_textcat():
|
||||||
assert cats1["imperative"] < 0.9
|
assert cats1["imperative"] < 0.9
|
||||||
assert [t.tag_ for t in docs[0]] == ["V", "J", "N"]
|
assert [t.tag_ for t in docs[0]] == ["V", "J", "N"]
|
||||||
assert [t.tag_ for t in docs[1]] == ["N", "V", "J", "N"]
|
assert [t.tag_ for t in docs[1]] == ["N", "V", "J", "N"]
|
||||||
|
|
||||||
|
|
||||||
|
cfg_string_distillation = """
|
||||||
|
[nlp]
|
||||||
|
lang = "en"
|
||||||
|
pipeline = ["tok2vec","tagger"]
|
||||||
|
|
||||||
|
[components]
|
||||||
|
|
||||||
|
[components.tagger]
|
||||||
|
factory = "tagger"
|
||||||
|
|
||||||
|
[components.tagger.model]
|
||||||
|
@architectures = "spacy.Tagger.v2"
|
||||||
|
nO = null
|
||||||
|
|
||||||
|
[components.tagger.model.tok2vec]
|
||||||
|
@architectures = "spacy.Tok2VecListener.v1"
|
||||||
|
width = ${components.tok2vec.model.encode.width}
|
||||||
|
|
||||||
|
[components.tok2vec]
|
||||||
|
factory = "tok2vec"
|
||||||
|
|
||||||
|
[components.tok2vec.model]
|
||||||
|
@architectures = "spacy.Tok2Vec.v2"
|
||||||
|
|
||||||
|
[components.tok2vec.model.embed]
|
||||||
|
@architectures = "spacy.MultiHashEmbed.v2"
|
||||||
|
width = ${components.tok2vec.model.encode.width}
|
||||||
|
rows = [2000, 1000, 1000, 1000]
|
||||||
|
attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
|
||||||
|
include_static_vectors = false
|
||||||
|
|
||||||
|
[components.tok2vec.model.encode]
|
||||||
|
@architectures = "spacy.MaxoutWindowEncoder.v2"
|
||||||
|
width = 96
|
||||||
|
depth = 4
|
||||||
|
window_size = 1
|
||||||
|
maxout_pieces = 3
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def test_tok2vec_distillation_teacher_annotations():
|
||||||
|
orig_config = Config().from_str(cfg_string_distillation)
|
||||||
|
teacher_nlp = util.load_model_from_config(
|
||||||
|
orig_config, auto_fill=True, validate=True
|
||||||
|
)
|
||||||
|
student_nlp = util.load_model_from_config(
|
||||||
|
orig_config, auto_fill=True, validate=True
|
||||||
|
)
|
||||||
|
|
||||||
|
train_examples_teacher = []
|
||||||
|
train_examples_student = []
|
||||||
|
for t in TRAIN_DATA:
|
||||||
|
train_examples_teacher.append(
|
||||||
|
Example.from_dict(teacher_nlp.make_doc(t[0]), t[1])
|
||||||
|
)
|
||||||
|
train_examples_student.append(
|
||||||
|
Example.from_dict(student_nlp.make_doc(t[0]), t[1])
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer = teacher_nlp.initialize(lambda: train_examples_teacher)
|
||||||
|
student_nlp.initialize(lambda: train_examples_student)
|
||||||
|
|
||||||
|
# Since Language.distill creates a copy of the examples to use as
|
||||||
|
# its internal teacher/student docs, we'll need to monkey-patch the
|
||||||
|
# tok2vec pipe's distill method.
|
||||||
|
student_tok2vec = student_nlp.get_pipe("tok2vec")
|
||||||
|
student_tok2vec._old_distill = student_tok2vec.distill
|
||||||
|
|
||||||
|
def tok2vec_distill_wrapper(
|
||||||
|
self,
|
||||||
|
teacher_pipe,
|
||||||
|
examples,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
assert all(not eg.reference.tensor.any() for eg in examples)
|
||||||
|
out = self._old_distill(teacher_pipe, examples, **kwargs)
|
||||||
|
assert all(eg.reference.tensor.any() for eg in examples)
|
||||||
|
return out
|
||||||
|
|
||||||
|
student_tok2vec.distill = tok2vec_distill_wrapper.__get__(student_tok2vec, Tok2Vec)
|
||||||
|
student_nlp.distill(teacher_nlp, train_examples_student, sgd=optimizer, losses={})
|
||||||
|
|
|
@ -13,8 +13,11 @@ from spacy.vocab import Vocab
|
||||||
|
|
||||||
from ..util import make_tempdir
|
from ..util import make_tempdir
|
||||||
|
|
||||||
test_strings = [([], []), (["rats", "are", "cute"], ["i", "like", "rats"])]
|
test_strings = [
|
||||||
test_strings_attrs = [(["rats", "are", "cute"], "Hello")]
|
(StringStore(), StringStore()),
|
||||||
|
(StringStore(["rats", "are", "cute"]), StringStore(["i", "like", "rats"])),
|
||||||
|
]
|
||||||
|
test_strings_attrs = [(StringStore(["rats", "are", "cute"]), "Hello")]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.issue(599)
|
@pytest.mark.issue(599)
|
||||||
|
@ -81,7 +84,7 @@ def test_serialize_vocab_roundtrip_bytes(strings1, strings2):
|
||||||
vocab2 = Vocab(strings=strings2)
|
vocab2 = Vocab(strings=strings2)
|
||||||
vocab1_b = vocab1.to_bytes()
|
vocab1_b = vocab1.to_bytes()
|
||||||
vocab2_b = vocab2.to_bytes()
|
vocab2_b = vocab2.to_bytes()
|
||||||
if strings1 == strings2:
|
if strings1.to_bytes() == strings2.to_bytes():
|
||||||
assert vocab1_b == vocab2_b
|
assert vocab1_b == vocab2_b
|
||||||
else:
|
else:
|
||||||
assert vocab1_b != vocab2_b
|
assert vocab1_b != vocab2_b
|
||||||
|
@ -117,11 +120,12 @@ def test_serialize_vocab_roundtrip_disk(strings1, strings2):
|
||||||
def test_serialize_vocab_lex_attrs_bytes(strings, lex_attr):
|
def test_serialize_vocab_lex_attrs_bytes(strings, lex_attr):
|
||||||
vocab1 = Vocab(strings=strings)
|
vocab1 = Vocab(strings=strings)
|
||||||
vocab2 = Vocab()
|
vocab2 = Vocab()
|
||||||
vocab1[strings[0]].norm_ = lex_attr
|
s = next(iter(vocab1.strings))
|
||||||
assert vocab1[strings[0]].norm_ == lex_attr
|
vocab1[s].norm_ = lex_attr
|
||||||
assert vocab2[strings[0]].norm_ != lex_attr
|
assert vocab1[s].norm_ == lex_attr
|
||||||
|
assert vocab2[s].norm_ != lex_attr
|
||||||
vocab2 = vocab2.from_bytes(vocab1.to_bytes())
|
vocab2 = vocab2.from_bytes(vocab1.to_bytes())
|
||||||
assert vocab2[strings[0]].norm_ == lex_attr
|
assert vocab2[s].norm_ == lex_attr
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("strings,lex_attr", test_strings_attrs)
|
@pytest.mark.parametrize("strings,lex_attr", test_strings_attrs)
|
||||||
|
@ -136,14 +140,15 @@ def test_deserialize_vocab_seen_entries(strings, lex_attr):
|
||||||
def test_serialize_vocab_lex_attrs_disk(strings, lex_attr):
|
def test_serialize_vocab_lex_attrs_disk(strings, lex_attr):
|
||||||
vocab1 = Vocab(strings=strings)
|
vocab1 = Vocab(strings=strings)
|
||||||
vocab2 = Vocab()
|
vocab2 = Vocab()
|
||||||
vocab1[strings[0]].norm_ = lex_attr
|
s = next(iter(vocab1.strings))
|
||||||
assert vocab1[strings[0]].norm_ == lex_attr
|
vocab1[s].norm_ = lex_attr
|
||||||
assert vocab2[strings[0]].norm_ != lex_attr
|
assert vocab1[s].norm_ == lex_attr
|
||||||
|
assert vocab2[s].norm_ != lex_attr
|
||||||
with make_tempdir() as d:
|
with make_tempdir() as d:
|
||||||
file_path = d / "vocab"
|
file_path = d / "vocab"
|
||||||
vocab1.to_disk(file_path)
|
vocab1.to_disk(file_path)
|
||||||
vocab2 = vocab2.from_disk(file_path)
|
vocab2 = vocab2.from_disk(file_path)
|
||||||
assert vocab2[strings[0]].norm_ == lex_attr
|
assert vocab2[s].norm_ == lex_attr
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("strings1,strings2", test_strings)
|
@pytest.mark.parametrize("strings1,strings2", test_strings)
|
||||||
|
|
|
@ -17,7 +17,7 @@ def test_issue361(en_vocab, text1, text2):
|
||||||
|
|
||||||
@pytest.mark.issue(600)
|
@pytest.mark.issue(600)
|
||||||
def test_issue600():
|
def test_issue600():
|
||||||
vocab = Vocab(tag_map={"NN": {"pos": "NOUN"}})
|
vocab = Vocab()
|
||||||
doc = Doc(vocab, words=["hello"])
|
doc = Doc(vocab, words=["hello"])
|
||||||
doc[0].tag_ = "NN"
|
doc[0].tag_ = "NN"
|
||||||
|
|
||||||
|
|
|
@ -26,7 +26,7 @@ class Vocab:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
lex_attr_getters: Optional[Dict[str, Callable[[str], Any]]] = ...,
|
lex_attr_getters: Optional[Dict[str, Callable[[str], Any]]] = ...,
|
||||||
strings: Optional[Union[List[str], StringStore]] = ...,
|
strings: Optional[StringStore] = ...,
|
||||||
lookups: Optional[Lookups] = ...,
|
lookups: Optional[Lookups] = ...,
|
||||||
oov_prob: float = ...,
|
oov_prob: float = ...,
|
||||||
writing_system: Dict[str, Any] = ...,
|
writing_system: Dict[str, Any] = ...,
|
||||||
|
|
|
@ -49,9 +49,8 @@ cdef class Vocab:
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/vocab
|
DOCS: https://spacy.io/api/vocab
|
||||||
"""
|
"""
|
||||||
def __init__(self, lex_attr_getters=None, strings=tuple(), lookups=None,
|
def __init__(self, lex_attr_getters=None, strings=None, lookups=None,
|
||||||
oov_prob=-20., writing_system={}, get_noun_chunks=None,
|
oov_prob=-20., writing_system=None, get_noun_chunks=None):
|
||||||
**deprecated_kwargs):
|
|
||||||
"""Create the vocabulary.
|
"""Create the vocabulary.
|
||||||
|
|
||||||
lex_attr_getters (dict): A dictionary mapping attribute IDs to
|
lex_attr_getters (dict): A dictionary mapping attribute IDs to
|
||||||
|
@ -69,16 +68,19 @@ cdef class Vocab:
|
||||||
self.cfg = {'oov_prob': oov_prob}
|
self.cfg = {'oov_prob': oov_prob}
|
||||||
self.mem = Pool()
|
self.mem = Pool()
|
||||||
self._by_orth = PreshMap()
|
self._by_orth = PreshMap()
|
||||||
self.strings = StringStore()
|
|
||||||
self.length = 0
|
self.length = 0
|
||||||
if strings:
|
if strings is None:
|
||||||
for string in strings:
|
self.strings = StringStore()
|
||||||
_ = self[string]
|
else:
|
||||||
|
self.strings = strings
|
||||||
self.lex_attr_getters = lex_attr_getters
|
self.lex_attr_getters = lex_attr_getters
|
||||||
self.morphology = Morphology(self.strings)
|
self.morphology = Morphology(self.strings)
|
||||||
self.vectors = Vectors(strings=self.strings)
|
self.vectors = Vectors(strings=self.strings)
|
||||||
self.lookups = lookups
|
self.lookups = lookups
|
||||||
self.writing_system = writing_system
|
if writing_system is None:
|
||||||
|
self.writing_system = {}
|
||||||
|
else:
|
||||||
|
self.writing_system = writing_system
|
||||||
self.get_noun_chunks = get_noun_chunks
|
self.get_noun_chunks = get_noun_chunks
|
||||||
|
|
||||||
property vectors:
|
property vectors:
|
||||||
|
|
|
@ -81,7 +81,7 @@ implementation of `KnowledgeBase.get_candidates()`.
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
| ----------- | -------------------------------------------------------------------------------------------------------------------------------------------- |
|
| ----------- | -------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| `mentions` | The textual mention or alias. ~~Iterable[SpanGroup]~~ |
|
| `mentions` | The textual mention or alias. ~~Iterable[SpanGroup]~~ |
|
||||||
| **RETURNS** | An iterator over iterables of iterables with relevant `Candidate` objects (per mention and doc). ~~Iterator[Iterable[Iterable[Candidate]]]~~ |
|
| **RETURNS** | An iterator (per document) over iterables (per mention) of iterables (per candidate for this mention) with relevant `Candidate` objects. ~~Iterator[Iterable[Iterable[Candidate]]]~~ |
|
||||||
|
|
||||||
## KnowledgeBase.get_vector {id="get_vector",tag="method"}
|
## KnowledgeBase.get_vector {id="get_vector",tag="method"}
|
||||||
|
|
||||||
|
@ -167,13 +167,11 @@ Construct an `InMemoryCandidate` object. Usually this constructor is not called
|
||||||
directly, but instead these objects are returned by the `get_candidates` method
|
directly, but instead these objects are returned by the `get_candidates` method
|
||||||
of the [`entity_linker`](/api/entitylinker) pipe.
|
of the [`entity_linker`](/api/entitylinker) pipe.
|
||||||
|
|
||||||
> #### Example```python
|
> #### Example
|
||||||
>
|
>
|
||||||
|
> ```python
|
||||||
> from spacy.kb import InMemoryCandidate candidate = InMemoryCandidate(kb,
|
> from spacy.kb import InMemoryCandidate candidate = InMemoryCandidate(kb,
|
||||||
> entity_hash, entity_freq, entity_vector, alias_hash, prior_prob)
|
> entity_hash, entity_freq, entity_vector, alias_hash, prior_prob)
|
||||||
>
|
|
||||||
> ```
|
|
||||||
>
|
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
|
|
|
@ -100,6 +100,43 @@ pipeline components are applied to the `Doc` in order. Both
|
||||||
| `doc` | The document to process. ~~Doc~~ |
|
| `doc` | The document to process. ~~Doc~~ |
|
||||||
| **RETURNS** | The processed document. ~~Doc~~ |
|
| **RETURNS** | The processed document. ~~Doc~~ |
|
||||||
|
|
||||||
|
## Tok2Vec.distill {id="distill", tag="method,experimental", version="4"}
|
||||||
|
|
||||||
|
Performs an update of the student pipe's model using the student's distillation
|
||||||
|
examples and sets the annotations of the teacher's distillation examples using
|
||||||
|
the teacher pipe.
|
||||||
|
|
||||||
|
Unlike other trainable pipes, the student pipe doesn't directly learn its
|
||||||
|
representations from the teacher. However, since downstream pipes that do
|
||||||
|
perform distillation expect the tok2vec annotations to be present on the
|
||||||
|
correct distillation examples, we need to ensure that they are set beforehand.
|
||||||
|
|
||||||
|
The distillation is performed on ~~Example~~ objects. The `Example.reference`
|
||||||
|
and `Example.predicted` ~~Doc~~s must have the same number of tokens and the
|
||||||
|
same orthography. Even though the reference does not need have to have gold
|
||||||
|
annotations, the teacher could adds its own annotations when necessary.
|
||||||
|
|
||||||
|
This feature is experimental.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> teacher_pipe = teacher.add_pipe("tok2vec")
|
||||||
|
> student_pipe = student.add_pipe("tok2vec")
|
||||||
|
> optimizer = nlp.resume_training()
|
||||||
|
> losses = student.distill(teacher_pipe, examples, sgd=optimizer)
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| `teacher_pipe` | The teacher pipe to use for prediction. ~~Optional[TrainablePipe]~~ |
|
||||||
|
| `examples` | Distillation examples. The reference (teacher) and predicted (student) docs must have the same number of tokens and the same orthography. ~~Iterable[Example]~~ |
|
||||||
|
| _keyword-only_ | |
|
||||||
|
| `drop` | Dropout rate. ~~float~~ |
|
||||||
|
| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
|
||||||
|
| `losses` | Optional record of the loss during distillation. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
|
||||||
|
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
|
||||||
|
|
||||||
## Tok2Vec.pipe {id="pipe",tag="method"}
|
## Tok2Vec.pipe {id="pipe",tag="method"}
|
||||||
|
|
||||||
Apply the pipe to a stream of documents. This usually happens under the hood
|
Apply the pipe to a stream of documents. This usually happens under the hood
|
||||||
|
|
|
@ -17,14 +17,15 @@ Create the vocabulary.
|
||||||
> #### Example
|
> #### Example
|
||||||
>
|
>
|
||||||
> ```python
|
> ```python
|
||||||
|
> from spacy.strings import StringStore
|
||||||
> from spacy.vocab import Vocab
|
> from spacy.vocab import Vocab
|
||||||
> vocab = Vocab(strings=["hello", "world"])
|
> vocab = Vocab(strings=StringStore(["hello", "world"]))
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
| ------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
| ------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| `lex_attr_getters` | A dictionary mapping attribute IDs to functions to compute them. Defaults to `None`. ~~Optional[Dict[str, Callable[[str], Any]]]~~ |
|
| `lex_attr_getters` | A dictionary mapping attribute IDs to functions to compute them. Defaults to `None`. ~~Optional[Dict[str, Callable[[str], Any]]]~~ |
|
||||||
| `strings` | A [`StringStore`](/api/stringstore) that maps strings to hash values, and vice versa, or a list of strings. ~~Union[List[str], StringStore]~~ |
|
| `strings` | A [`StringStore`](/api/stringstore) that maps strings to hash values. ~~Optional[StringStore]~~ |
|
||||||
| `lookups` | A [`Lookups`](/api/lookups) that stores the `lexeme_norm` and other large lookup tables. Defaults to `None`. ~~Optional[Lookups]~~ |
|
| `lookups` | A [`Lookups`](/api/lookups) that stores the `lexeme_norm` and other large lookup tables. Defaults to `None`. ~~Optional[Lookups]~~ |
|
||||||
| `oov_prob` | The default OOV probability. Defaults to `-20.0`. ~~float~~ |
|
| `oov_prob` | The default OOV probability. Defaults to `-20.0`. ~~float~~ |
|
||||||
| `writing_system` | A dictionary describing the language's writing system. Typically provided by [`Language.Defaults`](/api/language#defaults). ~~Dict[str, Any]~~ |
|
| `writing_system` | A dictionary describing the language's writing system. Typically provided by [`Language.Defaults`](/api/language#defaults). ~~Dict[str, Any]~~ |
|
||||||
|
|
Loading…
Reference in New Issue
Block a user