mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-11 09:00:36 +03:00
fix parser tests to work with example (most still failing)
This commit is contained in:
parent
9f43ba839a
commit
cd790aaa2a
|
@ -497,7 +497,7 @@ cdef class ArcEager(TransitionSystem):
|
|||
def has_gold(self, gold, start=0, end=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def preprocess_gold(self, gold):
|
||||
def preprocess_gold(self, example):
|
||||
raise NotImplementedError
|
||||
|
||||
cdef Transition lookup_transition(self, object name_or_id) except *:
|
||||
|
|
|
@ -94,7 +94,7 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
def has_gold(self, gold, start=0, end=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def preprocess_gold(self, gold):
|
||||
def preprocess_gold(self, example):
|
||||
raise NotImplementedError
|
||||
|
||||
cdef Transition lookup_transition(self, object name) except *:
|
||||
|
|
|
@ -97,7 +97,7 @@ cdef class TransitionSystem:
|
|||
def finalize_doc(self, doc):
|
||||
pass
|
||||
|
||||
def preprocess_gold(self, gold):
|
||||
def preprocess_gold(self, example):
|
||||
raise NotImplementedError
|
||||
|
||||
def is_gold_parse(self, StateClass state, example):
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import pytest
|
||||
from spacy.vocab import Vocab
|
||||
|
||||
from spacy.gold import Example
|
||||
from spacy.pipeline.defaults import default_parser
|
||||
from spacy.pipeline import DependencyParser
|
||||
from spacy.tokens import Doc
|
||||
|
@ -11,9 +12,9 @@ from spacy.syntax.arc_eager import ArcEager
|
|||
|
||||
def get_sequence_costs(M, words, heads, deps, transitions):
|
||||
doc = Doc(Vocab(), words=words)
|
||||
gold = GoldParse(doc, heads=heads, deps=deps)
|
||||
example = Example.from_dict(doc, {"heads": heads, "deps": deps})
|
||||
state = StateClass(doc)
|
||||
M.preprocess_gold(gold)
|
||||
M.preprocess_gold(example)
|
||||
cost_history = []
|
||||
for gold_action in transitions:
|
||||
state_costs = {}
|
||||
|
@ -149,6 +150,6 @@ def test_get_oracle_actions():
|
|||
elif head < i:
|
||||
parser.moves.add_action(3, dep)
|
||||
heads, deps = projectivize(heads, deps)
|
||||
gold = GoldParse(doc, words=words, tags=tags, heads=heads, deps=deps)
|
||||
parser.moves.preprocess_gold(gold)
|
||||
parser.moves.get_oracle_sequence(doc, gold)
|
||||
example = Example.from_dict(doc, {"words": words, "tags": tags, "heads": heads, "deps": deps})
|
||||
parser.moves.preprocess_gold(example)
|
||||
parser.moves.get_oracle_sequence(example)
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
import pytest
|
||||
from spacy.attrs import ENT_IOB
|
||||
|
||||
from spacy import util
|
||||
from spacy.lang.en import English
|
||||
from spacy.pipeline.defaults import default_ner
|
||||
|
@ -8,7 +10,7 @@ from spacy.syntax.ner import BiluoPushDown
|
|||
from spacy.tokens import Doc
|
||||
|
||||
from ..util import make_tempdir
|
||||
|
||||
from ...gold import Example
|
||||
|
||||
TRAIN_DATA = [
|
||||
("Who is Shaka Khan?", {"entities": [(7, 17, "PERSON")]}),
|
||||
|
@ -48,41 +50,45 @@ def tsys(vocab, entity_types):
|
|||
|
||||
|
||||
def test_get_oracle_moves(tsys, doc, entity_annots):
|
||||
gold = GoldParse(doc, entities=entity_annots)
|
||||
tsys.preprocess_gold(gold)
|
||||
act_classes = tsys.get_oracle_sequence(doc, gold)
|
||||
example = Example.from_dict(doc, {"entities": entity_annots})
|
||||
tsys.preprocess_gold(example)
|
||||
act_classes = tsys.get_oracle_sequence(example)
|
||||
names = [tsys.get_class_name(act) for act in act_classes]
|
||||
assert names == ["U-PERSON", "O", "O", "B-GPE", "L-GPE", "O"]
|
||||
|
||||
|
||||
def test_get_oracle_moves_negative_entities(tsys, doc, entity_annots):
|
||||
entity_annots = [(s, e, "!" + label) for s, e, label in entity_annots]
|
||||
gold = GoldParse(doc, entities=entity_annots)
|
||||
for i, tag in enumerate(gold.ner):
|
||||
example = Example.from_dict(doc, {"entities": entity_annots})
|
||||
ex_dict = example.to_dict()
|
||||
|
||||
for i, tag in enumerate(ex_dict["doc_annotation"]["entities"]):
|
||||
if tag == "L-!GPE":
|
||||
gold.ner[i] = "-"
|
||||
tsys.preprocess_gold(gold)
|
||||
act_classes = tsys.get_oracle_sequence(doc, gold)
|
||||
ex_dict["doc_annotation"]["entities"][i] = "-"
|
||||
example = Example.from_dict(doc, ex_dict)
|
||||
|
||||
tsys.preprocess_gold(example)
|
||||
act_classes = tsys.get_oracle_sequence(example)
|
||||
names = [tsys.get_class_name(act) for act in act_classes]
|
||||
assert names
|
||||
|
||||
|
||||
def test_get_oracle_moves_negative_entities2(tsys, vocab):
|
||||
doc = Doc(vocab, words=["A", "B", "C", "D"])
|
||||
gold = GoldParse(doc, entities=[])
|
||||
gold.ner = ["B-!PERSON", "L-!PERSON", "B-!PERSON", "L-!PERSON"]
|
||||
tsys.preprocess_gold(gold)
|
||||
act_classes = tsys.get_oracle_sequence(doc, gold)
|
||||
entity_annots = ["B-!PERSON", "L-!PERSON", "B-!PERSON", "L-!PERSON"]
|
||||
example = Example.from_dict(doc, {"entities": entity_annots})
|
||||
tsys.preprocess_gold(example)
|
||||
act_classes = tsys.get_oracle_sequence(example)
|
||||
names = [tsys.get_class_name(act) for act in act_classes]
|
||||
assert names
|
||||
|
||||
|
||||
def test_get_oracle_moves_negative_O(tsys, vocab):
|
||||
doc = Doc(vocab, words=["A", "B", "C", "D"])
|
||||
gold = GoldParse(doc, entities=[])
|
||||
gold.ner = ["O", "!O", "O", "!O"]
|
||||
tsys.preprocess_gold(gold)
|
||||
act_classes = tsys.get_oracle_sequence(doc, gold)
|
||||
entity_annots = ["O", "!O", "O", "!O"]
|
||||
example = Example.from_dict(doc, {"entities": []})
|
||||
tsys.preprocess_gold(example)
|
||||
act_classes = tsys.get_oracle_sequence(example)
|
||||
names = [tsys.get_class_name(act) for act in act_classes]
|
||||
assert names
|
||||
|
||||
|
@ -92,7 +98,7 @@ def test_oracle_moves_missing_B(en_vocab):
|
|||
biluo_tags = [None, None, "L-PRODUCT"]
|
||||
|
||||
doc = Doc(en_vocab, words=words)
|
||||
gold = GoldParse(doc, words=words, entities=biluo_tags)
|
||||
example = Example.from_dict(doc, {"words": words, "entities": biluo_tags})
|
||||
|
||||
moves = BiluoPushDown(en_vocab.strings)
|
||||
move_types = ("M", "B", "I", "L", "U", "O")
|
||||
|
@ -107,8 +113,8 @@ def test_oracle_moves_missing_B(en_vocab):
|
|||
moves.add_action(move_types.index("I"), label)
|
||||
moves.add_action(move_types.index("L"), label)
|
||||
moves.add_action(move_types.index("U"), label)
|
||||
moves.preprocess_gold(gold)
|
||||
moves.get_oracle_sequence(doc, gold)
|
||||
moves.preprocess_gold(example)
|
||||
moves.get_oracle_sequence(example)
|
||||
|
||||
|
||||
def test_oracle_moves_whitespace(en_vocab):
|
||||
|
@ -116,7 +122,7 @@ def test_oracle_moves_whitespace(en_vocab):
|
|||
biluo_tags = ["O", "O", "O", "B-ORG", None, "I-ORG", "L-ORG", "O", "O"]
|
||||
|
||||
doc = Doc(en_vocab, words=words)
|
||||
gold = GoldParse(doc, words=words, entities=biluo_tags)
|
||||
example = Example.from_dict(doc, {"entities": biluo_tags})
|
||||
|
||||
moves = BiluoPushDown(en_vocab.strings)
|
||||
move_types = ("M", "B", "I", "L", "U", "O")
|
||||
|
@ -128,8 +134,8 @@ def test_oracle_moves_whitespace(en_vocab):
|
|||
else:
|
||||
action, label = tag.split("-")
|
||||
moves.add_action(move_types.index(action), label)
|
||||
moves.preprocess_gold(gold)
|
||||
moves.get_oracle_sequence(doc, gold)
|
||||
moves.preprocess_gold(example)
|
||||
moves.get_oracle_sequence(example)
|
||||
|
||||
|
||||
def test_accept_blocked_token():
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
import pytest
|
||||
|
||||
from spacy.gold import Example
|
||||
from spacy.pipeline.defaults import default_parser, default_tok2vec
|
||||
from spacy.vocab import Vocab
|
||||
from spacy.syntax.arc_eager import ArcEager
|
||||
|
@ -71,7 +73,8 @@ def test_update_doc(parser, model, doc, gold):
|
|||
weights -= 0.001 * gradient
|
||||
return weights, gradient
|
||||
|
||||
parser.update((doc, gold), sgd=optimize)
|
||||
example = Example.from_dict(doc, gold)
|
||||
parser.update([example], sgd=optimize)
|
||||
|
||||
|
||||
@pytest.mark.xfail
|
||||
|
|
|
@ -42,11 +42,6 @@ def tokvecs(docs, vector_size):
|
|||
return output
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def golds(docs):
|
||||
return [GoldParse(doc) for doc in docs]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def batch_size(docs):
|
||||
return len(docs)
|
||||
|
@ -77,19 +72,24 @@ def scores(moves, batch_size, beam_width):
|
|||
]
|
||||
|
||||
|
||||
# All tests below are skipped after removing Beam stuff during the Example/GoldParse refactor
|
||||
@pytest.mark.skip
|
||||
def test_create_beam(beam):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
def test_beam_advance(beam, scores):
|
||||
beam.advance(scores)
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
def test_beam_advance_too_few_scores(beam, scores):
|
||||
with pytest.raises(IndexError):
|
||||
beam.advance(scores[:-1])
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
def test_beam_parse():
|
||||
nlp = Language()
|
||||
config = {"learn_tokens": False, "min_action_freq": 30, "beam_width": 1, "beam_update_prob": 1.0}
|
||||
|
|
|
@ -3,6 +3,7 @@ from thinc.api import Adam
|
|||
from spacy.attrs import NORM
|
||||
from spacy.vocab import Vocab
|
||||
|
||||
from spacy.gold import Example
|
||||
from spacy.pipeline.defaults import default_parser
|
||||
from spacy.tokens import Doc
|
||||
from spacy.pipeline import DependencyParser
|
||||
|
@ -27,8 +28,8 @@ def parser(vocab):
|
|||
for i in range(10):
|
||||
losses = {}
|
||||
doc = Doc(vocab, words=["a", "b", "c", "d"])
|
||||
gold = dict(heads=[1, 1, 3, 3], deps=["left", "ROOT", "left", "ROOT"])
|
||||
parser.update((doc, gold), sgd=sgd, losses=losses)
|
||||
example = Example.from_dict(doc, {"heads": [1, 1, 3, 3], "deps": ["left", "ROOT", "left", "ROOT"]})
|
||||
parser.update([example], sgd=sgd, losses=losses)
|
||||
return parser
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user