mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-11 17:10:36 +03:00
fix parser tests to work with example (most still failing)
This commit is contained in:
parent
9f43ba839a
commit
cd790aaa2a
|
@ -497,7 +497,7 @@ cdef class ArcEager(TransitionSystem):
|
||||||
def has_gold(self, gold, start=0, end=None):
|
def has_gold(self, gold, start=0, end=None):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def preprocess_gold(self, gold):
|
def preprocess_gold(self, example):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
cdef Transition lookup_transition(self, object name_or_id) except *:
|
cdef Transition lookup_transition(self, object name_or_id) except *:
|
||||||
|
|
|
@ -94,7 +94,7 @@ cdef class BiluoPushDown(TransitionSystem):
|
||||||
def has_gold(self, gold, start=0, end=None):
|
def has_gold(self, gold, start=0, end=None):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def preprocess_gold(self, gold):
|
def preprocess_gold(self, example):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
cdef Transition lookup_transition(self, object name) except *:
|
cdef Transition lookup_transition(self, object name) except *:
|
||||||
|
|
|
@ -97,7 +97,7 @@ cdef class TransitionSystem:
|
||||||
def finalize_doc(self, doc):
|
def finalize_doc(self, doc):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def preprocess_gold(self, gold):
|
def preprocess_gold(self, example):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def is_gold_parse(self, StateClass state, example):
|
def is_gold_parse(self, StateClass state, example):
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import pytest
|
import pytest
|
||||||
from spacy.vocab import Vocab
|
from spacy.vocab import Vocab
|
||||||
|
|
||||||
|
from spacy.gold import Example
|
||||||
from spacy.pipeline.defaults import default_parser
|
from spacy.pipeline.defaults import default_parser
|
||||||
from spacy.pipeline import DependencyParser
|
from spacy.pipeline import DependencyParser
|
||||||
from spacy.tokens import Doc
|
from spacy.tokens import Doc
|
||||||
|
@ -11,9 +12,9 @@ from spacy.syntax.arc_eager import ArcEager
|
||||||
|
|
||||||
def get_sequence_costs(M, words, heads, deps, transitions):
|
def get_sequence_costs(M, words, heads, deps, transitions):
|
||||||
doc = Doc(Vocab(), words=words)
|
doc = Doc(Vocab(), words=words)
|
||||||
gold = GoldParse(doc, heads=heads, deps=deps)
|
example = Example.from_dict(doc, {"heads": heads, "deps": deps})
|
||||||
state = StateClass(doc)
|
state = StateClass(doc)
|
||||||
M.preprocess_gold(gold)
|
M.preprocess_gold(example)
|
||||||
cost_history = []
|
cost_history = []
|
||||||
for gold_action in transitions:
|
for gold_action in transitions:
|
||||||
state_costs = {}
|
state_costs = {}
|
||||||
|
@ -149,6 +150,6 @@ def test_get_oracle_actions():
|
||||||
elif head < i:
|
elif head < i:
|
||||||
parser.moves.add_action(3, dep)
|
parser.moves.add_action(3, dep)
|
||||||
heads, deps = projectivize(heads, deps)
|
heads, deps = projectivize(heads, deps)
|
||||||
gold = GoldParse(doc, words=words, tags=tags, heads=heads, deps=deps)
|
example = Example.from_dict(doc, {"words": words, "tags": tags, "heads": heads, "deps": deps})
|
||||||
parser.moves.preprocess_gold(gold)
|
parser.moves.preprocess_gold(example)
|
||||||
parser.moves.get_oracle_sequence(doc, gold)
|
parser.moves.get_oracle_sequence(example)
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
from spacy.attrs import ENT_IOB
|
||||||
|
|
||||||
from spacy import util
|
from spacy import util
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
from spacy.pipeline.defaults import default_ner
|
from spacy.pipeline.defaults import default_ner
|
||||||
|
@ -8,7 +10,7 @@ from spacy.syntax.ner import BiluoPushDown
|
||||||
from spacy.tokens import Doc
|
from spacy.tokens import Doc
|
||||||
|
|
||||||
from ..util import make_tempdir
|
from ..util import make_tempdir
|
||||||
|
from ...gold import Example
|
||||||
|
|
||||||
TRAIN_DATA = [
|
TRAIN_DATA = [
|
||||||
("Who is Shaka Khan?", {"entities": [(7, 17, "PERSON")]}),
|
("Who is Shaka Khan?", {"entities": [(7, 17, "PERSON")]}),
|
||||||
|
@ -48,41 +50,45 @@ def tsys(vocab, entity_types):
|
||||||
|
|
||||||
|
|
||||||
def test_get_oracle_moves(tsys, doc, entity_annots):
|
def test_get_oracle_moves(tsys, doc, entity_annots):
|
||||||
gold = GoldParse(doc, entities=entity_annots)
|
example = Example.from_dict(doc, {"entities": entity_annots})
|
||||||
tsys.preprocess_gold(gold)
|
tsys.preprocess_gold(example)
|
||||||
act_classes = tsys.get_oracle_sequence(doc, gold)
|
act_classes = tsys.get_oracle_sequence(example)
|
||||||
names = [tsys.get_class_name(act) for act in act_classes]
|
names = [tsys.get_class_name(act) for act in act_classes]
|
||||||
assert names == ["U-PERSON", "O", "O", "B-GPE", "L-GPE", "O"]
|
assert names == ["U-PERSON", "O", "O", "B-GPE", "L-GPE", "O"]
|
||||||
|
|
||||||
|
|
||||||
def test_get_oracle_moves_negative_entities(tsys, doc, entity_annots):
|
def test_get_oracle_moves_negative_entities(tsys, doc, entity_annots):
|
||||||
entity_annots = [(s, e, "!" + label) for s, e, label in entity_annots]
|
entity_annots = [(s, e, "!" + label) for s, e, label in entity_annots]
|
||||||
gold = GoldParse(doc, entities=entity_annots)
|
example = Example.from_dict(doc, {"entities": entity_annots})
|
||||||
for i, tag in enumerate(gold.ner):
|
ex_dict = example.to_dict()
|
||||||
|
|
||||||
|
for i, tag in enumerate(ex_dict["doc_annotation"]["entities"]):
|
||||||
if tag == "L-!GPE":
|
if tag == "L-!GPE":
|
||||||
gold.ner[i] = "-"
|
ex_dict["doc_annotation"]["entities"][i] = "-"
|
||||||
tsys.preprocess_gold(gold)
|
example = Example.from_dict(doc, ex_dict)
|
||||||
act_classes = tsys.get_oracle_sequence(doc, gold)
|
|
||||||
|
tsys.preprocess_gold(example)
|
||||||
|
act_classes = tsys.get_oracle_sequence(example)
|
||||||
names = [tsys.get_class_name(act) for act in act_classes]
|
names = [tsys.get_class_name(act) for act in act_classes]
|
||||||
assert names
|
assert names
|
||||||
|
|
||||||
|
|
||||||
def test_get_oracle_moves_negative_entities2(tsys, vocab):
|
def test_get_oracle_moves_negative_entities2(tsys, vocab):
|
||||||
doc = Doc(vocab, words=["A", "B", "C", "D"])
|
doc = Doc(vocab, words=["A", "B", "C", "D"])
|
||||||
gold = GoldParse(doc, entities=[])
|
entity_annots = ["B-!PERSON", "L-!PERSON", "B-!PERSON", "L-!PERSON"]
|
||||||
gold.ner = ["B-!PERSON", "L-!PERSON", "B-!PERSON", "L-!PERSON"]
|
example = Example.from_dict(doc, {"entities": entity_annots})
|
||||||
tsys.preprocess_gold(gold)
|
tsys.preprocess_gold(example)
|
||||||
act_classes = tsys.get_oracle_sequence(doc, gold)
|
act_classes = tsys.get_oracle_sequence(example)
|
||||||
names = [tsys.get_class_name(act) for act in act_classes]
|
names = [tsys.get_class_name(act) for act in act_classes]
|
||||||
assert names
|
assert names
|
||||||
|
|
||||||
|
|
||||||
def test_get_oracle_moves_negative_O(tsys, vocab):
|
def test_get_oracle_moves_negative_O(tsys, vocab):
|
||||||
doc = Doc(vocab, words=["A", "B", "C", "D"])
|
doc = Doc(vocab, words=["A", "B", "C", "D"])
|
||||||
gold = GoldParse(doc, entities=[])
|
entity_annots = ["O", "!O", "O", "!O"]
|
||||||
gold.ner = ["O", "!O", "O", "!O"]
|
example = Example.from_dict(doc, {"entities": []})
|
||||||
tsys.preprocess_gold(gold)
|
tsys.preprocess_gold(example)
|
||||||
act_classes = tsys.get_oracle_sequence(doc, gold)
|
act_classes = tsys.get_oracle_sequence(example)
|
||||||
names = [tsys.get_class_name(act) for act in act_classes]
|
names = [tsys.get_class_name(act) for act in act_classes]
|
||||||
assert names
|
assert names
|
||||||
|
|
||||||
|
@ -92,7 +98,7 @@ def test_oracle_moves_missing_B(en_vocab):
|
||||||
biluo_tags = [None, None, "L-PRODUCT"]
|
biluo_tags = [None, None, "L-PRODUCT"]
|
||||||
|
|
||||||
doc = Doc(en_vocab, words=words)
|
doc = Doc(en_vocab, words=words)
|
||||||
gold = GoldParse(doc, words=words, entities=biluo_tags)
|
example = Example.from_dict(doc, {"words": words, "entities": biluo_tags})
|
||||||
|
|
||||||
moves = BiluoPushDown(en_vocab.strings)
|
moves = BiluoPushDown(en_vocab.strings)
|
||||||
move_types = ("M", "B", "I", "L", "U", "O")
|
move_types = ("M", "B", "I", "L", "U", "O")
|
||||||
|
@ -107,8 +113,8 @@ def test_oracle_moves_missing_B(en_vocab):
|
||||||
moves.add_action(move_types.index("I"), label)
|
moves.add_action(move_types.index("I"), label)
|
||||||
moves.add_action(move_types.index("L"), label)
|
moves.add_action(move_types.index("L"), label)
|
||||||
moves.add_action(move_types.index("U"), label)
|
moves.add_action(move_types.index("U"), label)
|
||||||
moves.preprocess_gold(gold)
|
moves.preprocess_gold(example)
|
||||||
moves.get_oracle_sequence(doc, gold)
|
moves.get_oracle_sequence(example)
|
||||||
|
|
||||||
|
|
||||||
def test_oracle_moves_whitespace(en_vocab):
|
def test_oracle_moves_whitespace(en_vocab):
|
||||||
|
@ -116,7 +122,7 @@ def test_oracle_moves_whitespace(en_vocab):
|
||||||
biluo_tags = ["O", "O", "O", "B-ORG", None, "I-ORG", "L-ORG", "O", "O"]
|
biluo_tags = ["O", "O", "O", "B-ORG", None, "I-ORG", "L-ORG", "O", "O"]
|
||||||
|
|
||||||
doc = Doc(en_vocab, words=words)
|
doc = Doc(en_vocab, words=words)
|
||||||
gold = GoldParse(doc, words=words, entities=biluo_tags)
|
example = Example.from_dict(doc, {"entities": biluo_tags})
|
||||||
|
|
||||||
moves = BiluoPushDown(en_vocab.strings)
|
moves = BiluoPushDown(en_vocab.strings)
|
||||||
move_types = ("M", "B", "I", "L", "U", "O")
|
move_types = ("M", "B", "I", "L", "U", "O")
|
||||||
|
@ -128,8 +134,8 @@ def test_oracle_moves_whitespace(en_vocab):
|
||||||
else:
|
else:
|
||||||
action, label = tag.split("-")
|
action, label = tag.split("-")
|
||||||
moves.add_action(move_types.index(action), label)
|
moves.add_action(move_types.index(action), label)
|
||||||
moves.preprocess_gold(gold)
|
moves.preprocess_gold(example)
|
||||||
moves.get_oracle_sequence(doc, gold)
|
moves.get_oracle_sequence(example)
|
||||||
|
|
||||||
|
|
||||||
def test_accept_blocked_token():
|
def test_accept_blocked_token():
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from spacy.gold import Example
|
||||||
from spacy.pipeline.defaults import default_parser, default_tok2vec
|
from spacy.pipeline.defaults import default_parser, default_tok2vec
|
||||||
from spacy.vocab import Vocab
|
from spacy.vocab import Vocab
|
||||||
from spacy.syntax.arc_eager import ArcEager
|
from spacy.syntax.arc_eager import ArcEager
|
||||||
|
@ -71,7 +73,8 @@ def test_update_doc(parser, model, doc, gold):
|
||||||
weights -= 0.001 * gradient
|
weights -= 0.001 * gradient
|
||||||
return weights, gradient
|
return weights, gradient
|
||||||
|
|
||||||
parser.update((doc, gold), sgd=optimize)
|
example = Example.from_dict(doc, gold)
|
||||||
|
parser.update([example], sgd=optimize)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.xfail
|
@pytest.mark.xfail
|
||||||
|
|
|
@ -42,11 +42,6 @@ def tokvecs(docs, vector_size):
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def golds(docs):
|
|
||||||
return [GoldParse(doc) for doc in docs]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def batch_size(docs):
|
def batch_size(docs):
|
||||||
return len(docs)
|
return len(docs)
|
||||||
|
@ -77,19 +72,24 @@ def scores(moves, batch_size, beam_width):
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# All tests below are skipped after removing Beam stuff during the Example/GoldParse refactor
|
||||||
|
@pytest.mark.skip
|
||||||
def test_create_beam(beam):
|
def test_create_beam(beam):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip
|
||||||
def test_beam_advance(beam, scores):
|
def test_beam_advance(beam, scores):
|
||||||
beam.advance(scores)
|
beam.advance(scores)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip
|
||||||
def test_beam_advance_too_few_scores(beam, scores):
|
def test_beam_advance_too_few_scores(beam, scores):
|
||||||
with pytest.raises(IndexError):
|
with pytest.raises(IndexError):
|
||||||
beam.advance(scores[:-1])
|
beam.advance(scores[:-1])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip
|
||||||
def test_beam_parse():
|
def test_beam_parse():
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
config = {"learn_tokens": False, "min_action_freq": 30, "beam_width": 1, "beam_update_prob": 1.0}
|
config = {"learn_tokens": False, "min_action_freq": 30, "beam_width": 1, "beam_update_prob": 1.0}
|
||||||
|
|
|
@ -3,6 +3,7 @@ from thinc.api import Adam
|
||||||
from spacy.attrs import NORM
|
from spacy.attrs import NORM
|
||||||
from spacy.vocab import Vocab
|
from spacy.vocab import Vocab
|
||||||
|
|
||||||
|
from spacy.gold import Example
|
||||||
from spacy.pipeline.defaults import default_parser
|
from spacy.pipeline.defaults import default_parser
|
||||||
from spacy.tokens import Doc
|
from spacy.tokens import Doc
|
||||||
from spacy.pipeline import DependencyParser
|
from spacy.pipeline import DependencyParser
|
||||||
|
@ -27,8 +28,8 @@ def parser(vocab):
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
losses = {}
|
losses = {}
|
||||||
doc = Doc(vocab, words=["a", "b", "c", "d"])
|
doc = Doc(vocab, words=["a", "b", "c", "d"])
|
||||||
gold = dict(heads=[1, 1, 3, 3], deps=["left", "ROOT", "left", "ROOT"])
|
example = Example.from_dict(doc, {"heads": [1, 1, 3, 3], "deps": ["left", "ROOT", "left", "ROOT"]})
|
||||||
parser.update((doc, gold), sgd=sgd, losses=losses)
|
parser.update([example], sgd=sgd, losses=losses)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user