mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-28 19:06:33 +03:00
e27c60a702
* Improve the correctness of _parse_patch * If there are no more actions, do not attempt to make further transitions, even if not all states are final. * Assert that the number of actions for a step is the same as the number of states. * Reimplement distillation with oracle cut size The code for distillation with an oracle cut size was not reimplemented after the parser refactor. We did not notice, because we did not have tests for this functionality. This change brings back the functionality and adds this to the parser tests. * Rename states2actions to _states_to_actions for consistency * Test distillation max cuts in NER * Mark parser/NER tests as slow * Typo * Fix invariant in _states_diff_to_actions * Rename _init_batch -> _init_batch_from_teacher * Ninja edit the ninja edit * Check that we raise an exception when we pass the incorrect number of actions * Remove unnecessary get Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com> * Write out condition more explicitly --------- Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
62 lines
1.6 KiB
Python
62 lines
1.6 KiB
Python
import numpy
|
|
import pytest
|
|
|
|
from spacy.lang.en import English
|
|
from spacy.ml.tb_framework import TransitionModelInputs
|
|
from spacy.training import Example
|
|
|
|
# Two short sentences with gold dependency annotations. Each entry is a
# (text, annotations) pair whose annotations hold a "heads" list and an
# equally long "deps" list of dependency labels.
TRAIN_DATA = [
    (
        "They trade mortgage-backed securities.",
        dict(
            heads=[1, 1, 4, 4, 5, 1, 1],
            deps=["nsubj", "ROOT", "compound", "punct", "nmod", "dobj", "punct"],
        ),
    ),
    (
        "I like London and Berlin.",
        dict(
            heads=[1, 1, 1, 2, 2, 1],
            deps=["nsubj", "ROOT", "dobj", "cc", "conj", "punct"],
        ),
    ),
]
|
|
|
|
|
|
@pytest.fixture
def nlp_parser():
    """Build a blank English pipeline with an initialized dependency parser.

    Every dependency label found in ``TRAIN_DATA`` is added to the parser,
    and the ``Example`` objects built from ``TRAIN_DATA`` are passed to
    ``nlp.initialize`` so initialization sees the real training docs.

    Returns:
        A ``(nlp, parser)`` tuple: the ``English`` pipeline and the parser
        component added to it.
    """
    nlp = English()
    parser = nlp.add_pipe("parser")

    train_examples = []
    for text, annotations in TRAIN_DATA:
        # Example.from_dict also checks that the annotations line up with
        # the tokenized text.
        train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
        for dep in annotations["deps"]:
            parser.add_label(dep)
    # The examples were previously collected and then discarded; hand them to
    # initialize so the pipeline is initialized from the sample docs.
    nlp.initialize(get_examples=lambda: train_examples)

    return nlp, parser
|
|
|
|
|
|
def test_incorrect_number_of_actions(nlp_parser):
    """Predicting with a mismatched number of actions must raise AssertionError.

    Each case pairs a list of docs with a per-step action array whose length
    does not match the number of docs: one case has too many actions and the
    other too few.
    """
    nlp, parser = nlp_parser
    doc = nlp.make_doc("test")

    mismatched_cases = [
        # Too many actions for the number of docs
        ([doc], [numpy.array([0, 0], dtype="i")]),
        # Too few actions for the number of docs
        ([doc, doc], [numpy.array([0], dtype="i")]),
    ]

    for case_docs, case_actions in mismatched_cases:
        with pytest.raises(AssertionError):
            parser.model.predict(
                TransitionModelInputs(
                    docs=case_docs,
                    moves=parser.moves,
                    actions=case_actions,
                )
            )
|