2022-03-28 12:13:50 +03:00
|
|
|
import pickle
|
2023-06-26 12:41:03 +03:00
|
|
|
from typing import cast
|
|
|
|
|
|
|
|
import hypothesis.strategies as st
|
2022-03-28 12:13:50 +03:00
|
|
|
import pytest
|
|
|
|
from hypothesis import given
|
2023-06-26 12:41:03 +03:00
|
|
|
|
2022-03-28 12:13:50 +03:00
|
|
|
from spacy import util
|
|
|
|
from spacy.lang.en import English
|
|
|
|
from spacy.language import Language
|
|
|
|
from spacy.pipeline._edit_tree_internals.edit_trees import EditTrees
|
Store activations in `Doc`s when `save_activations` is enabled (#11002)
* Store activations in Doc when `store_activations` is enabled
This change adds the new `activations` attribute to `Doc`. This
attribute can be used by trainable pipes to store their activations,
probabilities, and guesses for downstream users.
As an example, this change modifies the `tagger` and `senter` pipes to
add a `store_activations` option. When this option is enabled, the
probabilities and guesses are stored in `set_annotations`.
* Change type of `store_activations` to `Union[bool, List[str]]`
When the value is:
- A bool: all activations are stored when set to `True`.
- A List[str]: the activations named in the list are stored
* Formatting fixes in Tagger
* Support store_activations in spancat and morphologizer
* Make Doc.activations type visible to MyPy
* textcat/textcat_multilabel: add store_activations option
* trainable_lemmatizer/entity_linker: add store_activations option
* parser/ner: do not currently support returning activations
* Extend tagger and senter tests
So that they, like the other tests, also check that we get no
activations if no activations were requested.
* Document `Doc.activations` and `store_activations` in the relevant pipes
* Start errors/warnings at higher numbers to avoid merge conflicts
between the master and v4 branches.
* Add `store_activations` to docstrings.
* Replace store_activations setter by set_store_activations method
Setters that take a different type than what the getter returns are still
problematic for MyPy. Replace the setter by a method, so that type inference
works everywhere.
* Use dict comprehension suggested by @svlandeg
* Revert "Use dict comprehension suggested by @svlandeg"
This reverts commit 6e7b958f7060397965176c69649e5414f1f24988.
* EntityLinker: add type annotations to _add_activations
* _store_activations: make kwarg-only, remove doc_scores_lens arg
* set_annotations: add type annotations
* Apply suggestions from code review
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
* TextCat.predict: return dict
* Make the `TrainablePipe.store_activations` property a bool
This means that we can also bring back `store_activations` setter.
* Remove `TrainablePipe.activations`
We do not need to enumerate the activations anymore since `store_activations` is
`bool`.
* Add type annotations for activations in predict/set_annotations
* Rename `TrainablePipe.store_activations` to `save_activations`
* Error E1400 is not used anymore
This error was used when activations were still `Union[bool, List[str]]`.
* Change wording in API docs after store -> save change
* docs: tag (save_)activations as new in spaCy 4.0
* Fix copied line in morphologizer activations test
* Don't train in any test_save_activations test
* Rename activations
- "probs" -> "probabilities"
- "guesses" -> "label_ids", except in the edit tree lemmatizer, where
"guesses" -> "tree_ids".
* Remove unused W400 warning.
This warning was used when we still allowed the user to specify
which activations to save.
* Formatting fixes
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
* Replace "kb_ids" by a constant
* spancat: replace a cast by an assertion
* Fix EOF spacing
* Fix comments in test_save_activations tests
* Do not set RNG seed in activation saving tests
* Revert "spancat: replace a cast by an assertion"
This reverts commit 0bd5730d16432443a2b247316928d4f789ad8741.
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
2022-09-13 10:51:12 +03:00
|
|
|
from spacy.pipeline.trainable_pipe import TrainablePipe
|
2022-03-28 12:13:50 +03:00
|
|
|
from spacy.strings import StringStore
|
2023-06-26 12:41:03 +03:00
|
|
|
from spacy.training import Example
|
2022-03-28 12:13:50 +03:00
|
|
|
from spacy.util import make_tempdir
|
|
|
|
|
|
|
|
# Fully annotated training examples: (text, annotations) pairs in which every
# token has a gold lemma.
TRAIN_DATA = [
    (
        "She likes green eggs",
        {"lemmas": ["she", "like", "green", "egg"]},
    ),
    (
        "Eat blue ham",
        {"lemmas": ["eat", "blue", "ham"]},
    ),
]
|
|
|
|
|
|
|
|
# Training examples with incomplete lemma annotation. An empty string in
# "lemmas" marks a token that has no gold lemma.
PARTIAL_DATA = [
    # Partial annotation: the first and the last token are unannotated.
    (
        "She likes green eggs",
        {"lemmas": ["", "like", "green", ""]},
    ),
    # Misaligned partial annotation: the gold "words" split "hates" into
    # "hat" + "es", which differs from the text's own tokenization.
    (
        "He hates green eggs",
        {
            "words": ["He", "hat", "es", "green", "eggs"],
            "lemmas": ["", "hat", "e", "green", ""],
        },
    ),
]
|
|
|
|
|
|
|
|
|
|
|
|
def test_initialize_examples():
    """Check how ``Language.initialize`` validates its ``get_examples`` argument.

    A callable that returns a non-empty list of ``Example`` objects is
    accepted. Every other shape of argument exercised below is expected to
    raise ``TypeError``.
    """
    nlp = Language()
    # Only the pipe's presence in the pipeline matters here; the pipe object
    # itself is never inspected, so it is not bound to a name.
    nlp.add_pipe("trainable_lemmatizer")
    train_examples = [
        Example.from_dict(nlp.make_doc(text), annots) for text, annots in TRAIN_DATA
    ]
    # you shouldn't really call this more than once, but for testing it should be fine
    nlp.initialize(get_examples=lambda: train_examples)
    # The callable returns None instead of examples.
    with pytest.raises(TypeError):
        nlp.initialize(get_examples=lambda: None)
    # The callable returns a single Example rather than a collection of them.
    with pytest.raises(TypeError):
        nlp.initialize(get_examples=lambda: train_examples[0])
    # The callable returns an empty list.
    with pytest.raises(TypeError):
        nlp.initialize(get_examples=lambda: [])
    # A list is passed directly instead of a callable.
    with pytest.raises(TypeError):
        nlp.initialize(get_examples=train_examples)
|
|
|
|
|
|
|
|
|
|
|
|
def test_initialize_from_labels():
    """A lemmatizer initialized from another lemmatizer's ``label_data`` gets
    the same tree-to-label mapping and the same serialized edit trees."""
    nlp = Language()
    lemmatizer = nlp.add_pipe("trainable_lemmatizer")
    lemmatizer.min_tree_freq = 1
    train_examples = [
        Example.from_dict(nlp.make_doc(text), annots) for text, annots in TRAIN_DATA
    ]
    nlp.initialize(get_examples=lambda: train_examples)

    nlp2 = Language()
    lemmatizer2 = nlp2.add_pipe("trainable_lemmatizer")
    lemmatizer2.initialize(
        # We want to check that the strings in replacement nodes are
        # added to the string store. Avoid that they get added through
        # the examples.
        get_examples=lambda: train_examples[:1],
        labels=lemmatizer.label_data,
    )
    assert lemmatizer2.tree2label == {1: 0, 3: 1, 4: 2, 6: 3}

    # In the expected data below, 2**32 - 1 (4294967295) appears as
    # prefix_tree/suffix_tree exactly where the matching length is 0.
    no_tree = 2**32 - 1

    def subst(orig, new):
        # Serialized form of a substitution node.
        return {"orig": orig, "subst": new}

    def match(prefix_len, suffix_len, prefix_tree, suffix_tree):
        # Serialized form of a match node.
        return {
            "prefix_len": prefix_len,
            "suffix_len": suffix_len,
            "prefix_tree": prefix_tree,
            "suffix_tree": suffix_tree,
        }

    expected_trees = [
        subst("S", "s"),
        match(1, 0, 0, no_tree),
        subst("s", ""),
        match(0, 1, no_tree, 2),
        match(0, 0, no_tree, no_tree),
        subst("E", "e"),
        match(1, 0, 5, no_tree),
    ]
    assert lemmatizer2.label_data == {"trees": expected_trees, "labels": (1, 3, 4, 6)}
|
2022-03-28 12:13:50 +03:00
|
|
|
|
|
|
|
|
2023-01-20 21:34:11 +03:00
|
|
|
@pytest.mark.parametrize("top_k", (1, 5, 30))
def test_no_data(top_k):
    """Initializing without any lemma annotation must raise ``ValueError``.

    The examples carry only text-classification labels, so the lemmatizer has
    no lemma data / labels to learn from.
    """
    textcat_data = [
        ("I'm so happy.", {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}),
        ("I'm so angry", {"cats": {"POSITIVE": 0.0, "NEGATIVE": 1.0}}),
    ]
    nlp = English()
    nlp.add_pipe("trainable_lemmatizer", config={"top_k": top_k})
    nlp.add_pipe("textcat")

    train_examples = [
        Example.from_dict(nlp.make_doc(text), annots) for text, annots in textcat_data
    ]

    with pytest.raises(ValueError):
        nlp.initialize(get_examples=lambda: train_examples)
|
|
|
|
|
|
|
|
|
2023-01-20 21:34:11 +03:00
|
|
|
@pytest.mark.parametrize("top_k", (1, 5, 30))
def test_incomplete_data(top_k):
    """Train on PARTIAL_DATA and check that unannotated or misaligned tokens
    produce a zero gradient, while annotated tokens are still learned."""
    nlp = English()
    lemmatizer = nlp.add_pipe("trainable_lemmatizer", config={"top_k": top_k})
    lemmatizer.min_tree_freq = 1
    train_examples = [
        Example.from_dict(nlp.make_doc(text), annots) for text, annots in PARTIAL_DATA
    ]
    optimizer = nlp.initialize(get_examples=lambda: train_examples)
    for _ in range(50):
        losses = {}
        nlp.update(train_examples, sgd=optimizer, losses=losses)
    # Only the loss of the final update is checked.
    assert losses["trainable_lemmatizer"] < 0.00001

    # test the trained model
    doc = nlp("She likes blue eggs")
    assert doc[1].lemma_ == "like"
    assert doc[2].lemma_ == "blue"

    # Check that incomplete annotations are ignored.
    scores, _ = lemmatizer.model([eg.predicted for eg in train_examples], is_train=True)
    _, dX = lemmatizer.get_loss(train_examples, scores)
    count_nonzero = lemmatizer.model.ops.xp.count_nonzero

    # Missing annotations: tokens whose gold lemma is "".
    for doc_idx, token_idx in ((0, 0), (0, 3), (1, 0), (1, 3)):
        assert count_nonzero(dX[doc_idx][token_idx]) == 0

    # Misaligned annotations.
    assert count_nonzero(dX[1][1]) == 0
|
2022-03-28 12:13:50 +03:00
|
|
|
|
|
|
|
|
2023-01-20 21:34:11 +03:00
|
|
|
@pytest.mark.parametrize("top_k", (1, 5, 30))
def test_overfitting_IO(top_k):
    """Overfit on TRAIN_DATA, then check that the predictions survive
    to_disk/from_disk, to_bytes/from_bytes and pickle round trips."""
    expected_lemmas = ["she", "like", "blue", "egg"]

    def check_lemmas(doc):
        # Token-by-token comparison over the first four tokens.
        for idx, lemma in enumerate(expected_lemmas):
            assert doc[idx].lemma_ == lemma

    nlp = English()
    lemmatizer = nlp.add_pipe("trainable_lemmatizer", config={"top_k": top_k})
    lemmatizer.min_tree_freq = 1
    train_examples = [
        Example.from_dict(nlp.make_doc(text), annots) for text, annots in TRAIN_DATA
    ]

    optimizer = nlp.initialize(get_examples=lambda: train_examples)

    for _ in range(50):
        losses = {}
        nlp.update(train_examples, sgd=optimizer, losses=losses)
    assert losses["trainable_lemmatizer"] < 0.00001

    test_text = "She likes blue eggs"
    check_lemmas(nlp(test_text))

    # Check model after a {to,from}_disk roundtrip
    with util.make_tempdir() as tmp_dir:
        nlp.to_disk(tmp_dir)
        nlp2 = util.load_model_from_path(tmp_dir)
        check_lemmas(nlp2(test_text))

    # Check model after a {to,from}_bytes roundtrip
    nlp_bytes = nlp.to_bytes()
    nlp3 = English()
    nlp3.add_pipe("trainable_lemmatizer", config={"top_k": top_k})
    nlp3.from_bytes(nlp_bytes)
    check_lemmas(nlp3(test_text))

    # Check model after a pickle roundtrip.
    nlp4 = pickle.loads(pickle.dumps(nlp))
    check_lemmas(nlp4(test_text))
|
|
|
|
|
|
|
|
|
2023-01-16 12:25:53 +03:00
|
|
|
def test_is_distillable():
    """The trainable lemmatizer reports that it supports distillation."""
    pipeline = English()
    pipe = pipeline.add_pipe("trainable_lemmatizer")
    assert pipe.is_distillable
|
|
|
|
|
|
|
|
|
|
|
|
def test_distill():
    """Overfit a teacher lemmatizer, distill it into a fresh student, and
    check the student's predictions."""
    expected_lemmas = ["she", "like", "blue", "egg"]

    teacher = English()
    teacher_lemmatizer = teacher.add_pipe("trainable_lemmatizer")
    teacher_lemmatizer.min_tree_freq = 1
    train_examples = [
        Example.from_dict(teacher.make_doc(text), annots)
        for text, annots in TRAIN_DATA
    ]

    optimizer = teacher.initialize(get_examples=lambda: train_examples)

    for _ in range(50):
        losses = {}
        teacher.update(train_examples, sgd=optimizer, losses=losses)
    assert losses["trainable_lemmatizer"] < 0.00001

    student = English()
    student_lemmatizer = student.add_pipe("trainable_lemmatizer")
    student_lemmatizer.min_tree_freq = 1
    # Initialize the student with the teacher's label data.
    student_lemmatizer.initialize(
        get_examples=lambda: train_examples, labels=teacher_lemmatizer.label_data
    )

    # The distillation examples carry no annotations of their own.
    distill_examples = [
        Example.from_dict(teacher.make_doc(text), {})
        for text, _annotations in TRAIN_DATA
    ]

    for _ in range(50):
        losses = {}
        student_lemmatizer.distill(
            teacher_lemmatizer, distill_examples, sgd=optimizer, losses=losses
        )
    assert losses["trainable_lemmatizer"] < 0.00001

    doc = student("She likes blue eggs")
    for idx, lemma in enumerate(expected_lemmas):
        assert doc[idx].lemma_ == lemma
|
|
|
|
|
|
|
|
|
2022-03-28 12:13:50 +03:00
|
|
|
def test_lemmatizer_requires_labels():
    """Initializing with neither examples nor labels raises ``ValueError``."""
    pipeline = English()
    pipeline.add_pipe("trainable_lemmatizer")
    # No get_examples and no labels are supplied.
    with pytest.raises(ValueError):
        pipeline.initialize()
|
|
|
|
|
|
|
|
|
|
|
|
def test_lemmatizer_label_data():
    """A lemmatizer initialized from another one's ``label_data`` ends up with
    identical labels and identically serialized edit trees."""
    source = English()
    source_lemmatizer = source.add_pipe("trainable_lemmatizer")
    source_lemmatizer.min_tree_freq = 1
    train_examples = [
        Example.from_dict(source.make_doc(text), annots)
        for text, annots in TRAIN_DATA
    ]
    source.initialize(get_examples=lambda: train_examples)

    target = English()
    target_lemmatizer = target.add_pipe("trainable_lemmatizer")
    target_lemmatizer.initialize(
        get_examples=lambda: train_examples, labels=source_lemmatizer.label_data
    )

    # Verify that the labels and trees are the same.
    assert source_lemmatizer.labels == target_lemmatizer.labels
    assert source_lemmatizer.trees.to_bytes() == target_lemmatizer.trees.to_bytes()
|
|
|
|
|
|
|
|
|
|
|
|
def test_dutch():
    """Check the string form of the edit trees for two Dutch verb forms that
    share the lemma "delen"."""
    strings = StringStore()
    trees = EditTrees(strings)
    cases = [
        ("deelt", "(m 0 3 () (m 0 2 (s '' 'l') (s 'lt' 'n')))"),
        ("gedeeld", "(m 2 3 (s 'ge' '') (m 0 2 (s '' 'l') (s 'ld' 'n')))"),
    ]
    for form, expected in cases:
        tree = trees.add(form, "delen")
        assert trees.tree_to_str(tree) == expected
|
|
|
|
|
|
|
|
|
|
|
|
def test_from_to_bytes():
    """EditTrees are preserved by a to_bytes/from_bytes round trip."""
    forms = ("deelt", "gedeeld")
    strings = StringStore()
    original = EditTrees(strings)
    for form in forms:
        original.add(form, "delen")

    serialized = original.to_bytes()

    restored = EditTrees(strings)
    restored.from_bytes(serialized)

    # Verify that the nodes did not change.
    assert len(original) == len(restored)
    for tree_id in range(len(original)):
        assert original.tree_to_str(tree_id) == restored.tree_to_str(tree_id)

    # Reinserting the same trees should not add new nodes.
    for form in forms:
        restored.add(form, "delen")
    assert len(original) == len(restored)
|
|
|
|
|
|
|
|
|
|
|
|
def test_from_to_disk():
    """EditTrees are preserved by a to_disk/from_disk round trip."""
    forms = ("deelt", "gedeeld")
    strings = StringStore()
    original = EditTrees(strings)
    for form in forms:
        original.add(form, "delen")

    restored = EditTrees(strings)
    with make_tempdir() as temp_dir:
        trees_file = temp_dir / "edit_trees.bin"
        original.to_disk(trees_file)
        restored = restored.from_disk(trees_file)

    # Verify that the nodes did not change.
    assert len(original) == len(restored)
    for tree_id in range(len(original)):
        assert original.tree_to_str(tree_id) == restored.tree_to_str(tree_id)

    # Reinserting the same trees should not add new nodes.
    for form in forms:
        restored.add(form, "delen")
    assert len(original) == len(restored)
|
|
|
|
|
|
|
|
|
|
|
|
@given(st.text(), st.text())
def test_roundtrip(form, lemma):
    """Applying the tree built from (form, lemma) to ``form`` gives ``lemma``."""
    strings = StringStore()
    trees = EditTrees(strings)
    tree_id = trees.add(form, lemma)
    reconstructed = trees.apply(tree_id, form)
    assert reconstructed == lemma
|
|
|
|
|
|
|
|
|
|
|
|
@given(st.text(alphabet="ab"), st.text(alphabet="ab"))
def test_roundtrip_small_alphabet(form, lemma):
    """Same round-trip property as ``test_roundtrip``. A two-letter alphabet
    makes form and lemma overlap much more often."""
    strings = StringStore()
    trees = EditTrees(strings)
    tree_id = trees.add(form, lemma)
    reconstructed = trees.apply(tree_id, form)
    assert reconstructed == lemma
|
|
|
|
|
|
|
|
|
|
|
|
def test_unapplicable_trees():
    """``EditTrees.apply`` returns ``None`` when a tree cannot be applied to
    the given form."""
    strings = StringStore()
    trees = EditTrees(strings)
    # The tree for "deelt" -> "delen" replaces the suffix "lt".
    tree = trees.add("deelt", "delen")

    # Replacement fails: "deeld" does not end in the substring being replaced.
    assert trees.apply(tree, "deeld") is None

    # Suffix + prefix are too large for the two-character form "de".
    assert trees.apply(tree, "de") is None
|
|
|
|
|
|
|
|
|
|
|
|
def test_empty_strings():
    """An identity edit ("xyz" -> "xyz") and the empty edit ("" -> "") map to
    the same tree id."""
    strings = StringStore()
    trees = EditTrees(strings)
    assert trees.add("xyz", "xyz") == trees.add("", "")
|
Store activations in `Doc`s when `save_activations` is enabled (#11002)
* Store activations in Doc when `store_activations` is enabled
This change adds the new `activations` attribute to `Doc`. This
attribute can be used by trainable pipes to store their activations,
probabilities, and guesses for downstream users.
As an example, this change modifies the `tagger` and `senter` pipes to
add a `store_activations` option. When this option is enabled, the
probabilities and guesses are stored in `set_annotations`.
* Change type of `store_activations` to `Union[bool, List[str]]`
When the value is:
- A bool: all activations are stored when set to `True`.
- A List[str]: the activations named in the list are stored
* Formatting fixes in Tagger
* Support store_activations in spancat and morphologizer
* Make Doc.activations type visible to MyPy
* textcat/textcat_multilabel: add store_activations option
* trainable_lemmatizer/entity_linker: add store_activations option
* parser/ner: do not currently support returning activations
* Extend tagger and senter tests
So that they, like the other tests, also check that we get no
activations if no activations were requested.
* Document `Doc.activations` and `store_activations` in the relevant pipes
* Start errors/warnings at higher numbers to avoid merge conflicts
between the master and v4 branches.
* Add `store_activations` to docstrings.
* Replace store_activations setter by set_store_activations method
Setters that take a different type than what the getter returns are still
problematic for MyPy. Replace the setter by a method, so that type inference
works everywhere.
* Use dict comprehension suggested by @svlandeg
* Revert "Use dict comprehension suggested by @svlandeg"
This reverts commit 6e7b958f7060397965176c69649e5414f1f24988.
* EntityLinker: add type annotations to _add_activations
* _store_activations: make kwarg-only, remove doc_scores_lens arg
* set_annotations: add type annotations
* Apply suggestions from code review
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
* TextCat.predict: return dict
* Make the `TrainablePipe.store_activations` property a bool
This means that we can also bring back `store_activations` setter.
* Remove `TrainablePipe.activations`
We do not need to enumerate the activations anymore since `store_activations` is
`bool`.
* Add type annotations for activations in predict/set_annotations
* Rename `TrainablePipe.store_activations` to `save_activations`
* Error E1400 is not used anymore
This error was used when activations were still `Union[bool, List[str]]`.
* Change wording in API docs after store -> save change
* docs: tag (save_)activations as new in spaCy 4.0
* Fix copied line in morphologizer activations test
* Don't train in any test_save_activations test
* Rename activations
- "probs" -> "probabilities"
- "guesses" -> "label_ids", except in the edit tree lemmatizer, where
"guesses" -> "tree_ids".
* Remove unused W400 warning.
This warning was used when we still allowed the user to specify
which activations to save.
* Formatting fixes
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
* Replace "kb_ids" by a constant
* spancat: replace a cast by an assertion
* Fix EOF spacing
* Fix comments in test_save_activations tests
* Do not set RNG seed in activation saving tests
* Revert "spancat: replace a cast by an assertion"
This reverts commit 0bd5730d16432443a2b247316928d4f789ad8741.
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
2022-09-13 10:51:12 +03:00
|
|
|
|
|
|
|
|
|
|
|
def test_save_activations():
    """``save_activations`` controls whether the lemmatizer stores its
    probabilities and tree ids in ``Doc.activations``."""
    nlp = English()
    lemmatizer = cast(TrainablePipe, nlp.add_pipe("trainable_lemmatizer"))
    lemmatizer.min_tree_freq = 1
    train_examples = [
        Example.from_dict(nlp.make_doc(text), annots) for text, annots in TRAIN_DATA
    ]
    nlp.initialize(get_examples=lambda: train_examples)
    # Output width of the model, i.e. the number of columns per probability row.
    n_out = lemmatizer.model.get_dim("nO")
    text = "This is a test."

    # Nothing is stored while save_activations is off.
    doc = nlp(text)
    assert "trainable_lemmatizer" not in doc.activations

    lemmatizer.save_activations = True
    doc = nlp(text)
    activations = doc.activations["trainable_lemmatizer"]
    assert list(activations.keys()) == ["probabilities", "tree_ids"]
    # The text has 5 tokens, hence 5 rows.
    assert activations["probabilities"].shape == (5, n_out)
    assert activations["tree_ids"].shape == (5,)
|