mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 09:57:26 +03:00 
			
		
		
		
	fix parser tests to work with example (most still failing)
This commit is contained in:
		
							parent
							
								
									9f43ba839a
								
							
						
					
					
						commit
						cd790aaa2a
					
				| 
						 | 
				
			
			@ -497,7 +497,7 @@ cdef class ArcEager(TransitionSystem):
 | 
			
		|||
    def has_gold(self, gold, start=0, end=None):
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    def preprocess_gold(self, gold):
 | 
			
		||||
    def preprocess_gold(self, example):
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    cdef Transition lookup_transition(self, object name_or_id) except *:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -94,7 +94,7 @@ cdef class BiluoPushDown(TransitionSystem):
 | 
			
		|||
    def has_gold(self, gold, start=0, end=None):
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    def preprocess_gold(self, gold):
 | 
			
		||||
    def preprocess_gold(self, example):
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    cdef Transition lookup_transition(self, object name) except *:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -97,7 +97,7 @@ cdef class TransitionSystem:
 | 
			
		|||
    def finalize_doc(self, doc):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def preprocess_gold(self, gold):
 | 
			
		||||
    def preprocess_gold(self, example):
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    def is_gold_parse(self, StateClass state, example):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,6 +1,7 @@
 | 
			
		|||
import pytest
 | 
			
		||||
from spacy.vocab import Vocab
 | 
			
		||||
 | 
			
		||||
from spacy.gold import Example
 | 
			
		||||
from spacy.pipeline.defaults import default_parser
 | 
			
		||||
from spacy.pipeline import DependencyParser
 | 
			
		||||
from spacy.tokens import Doc
 | 
			
		||||
| 
						 | 
				
			
			@ -11,9 +12,9 @@ from spacy.syntax.arc_eager import ArcEager
 | 
			
		|||
 | 
			
		||||
def get_sequence_costs(M, words, heads, deps, transitions):
 | 
			
		||||
    doc = Doc(Vocab(), words=words)
 | 
			
		||||
    gold = GoldParse(doc, heads=heads, deps=deps)
 | 
			
		||||
    example = Example.from_dict(doc, {"heads": heads, "deps": deps})
 | 
			
		||||
    state = StateClass(doc)
 | 
			
		||||
    M.preprocess_gold(gold)
 | 
			
		||||
    M.preprocess_gold(example)
 | 
			
		||||
    cost_history = []
 | 
			
		||||
    for gold_action in transitions:
 | 
			
		||||
        state_costs = {}
 | 
			
		||||
| 
						 | 
				
			
			@ -149,6 +150,6 @@ def test_get_oracle_actions():
 | 
			
		|||
        elif head < i:
 | 
			
		||||
            parser.moves.add_action(3, dep)
 | 
			
		||||
    heads, deps = projectivize(heads, deps)
 | 
			
		||||
    gold = GoldParse(doc, words=words, tags=tags, heads=heads, deps=deps)
 | 
			
		||||
    parser.moves.preprocess_gold(gold)
 | 
			
		||||
    parser.moves.get_oracle_sequence(doc, gold)
 | 
			
		||||
    example = Example.from_dict(doc, {"words": words, "tags": tags, "heads": heads, "deps": deps})
 | 
			
		||||
    parser.moves.preprocess_gold(example)
 | 
			
		||||
    parser.moves.get_oracle_sequence(example)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,4 +1,6 @@
 | 
			
		|||
import pytest
 | 
			
		||||
from spacy.attrs import ENT_IOB
 | 
			
		||||
 | 
			
		||||
from spacy import util
 | 
			
		||||
from spacy.lang.en import English
 | 
			
		||||
from spacy.pipeline.defaults import default_ner
 | 
			
		||||
| 
						 | 
				
			
			@ -8,7 +10,7 @@ from spacy.syntax.ner import BiluoPushDown
 | 
			
		|||
from spacy.tokens import Doc
 | 
			
		||||
 | 
			
		||||
from ..util import make_tempdir
 | 
			
		||||
 | 
			
		||||
from ...gold import Example
 | 
			
		||||
 | 
			
		||||
TRAIN_DATA = [
 | 
			
		||||
    ("Who is Shaka Khan?", {"entities": [(7, 17, "PERSON")]}),
 | 
			
		||||
| 
						 | 
				
			
			@ -48,41 +50,45 @@ def tsys(vocab, entity_types):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def test_get_oracle_moves(tsys, doc, entity_annots):
 | 
			
		||||
    gold = GoldParse(doc, entities=entity_annots)
 | 
			
		||||
    tsys.preprocess_gold(gold)
 | 
			
		||||
    act_classes = tsys.get_oracle_sequence(doc, gold)
 | 
			
		||||
    example = Example.from_dict(doc, {"entities": entity_annots})
 | 
			
		||||
    tsys.preprocess_gold(example)
 | 
			
		||||
    act_classes = tsys.get_oracle_sequence(example)
 | 
			
		||||
    names = [tsys.get_class_name(act) for act in act_classes]
 | 
			
		||||
    assert names == ["U-PERSON", "O", "O", "B-GPE", "L-GPE", "O"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_get_oracle_moves_negative_entities(tsys, doc, entity_annots):
 | 
			
		||||
    entity_annots = [(s, e, "!" + label) for s, e, label in entity_annots]
 | 
			
		||||
    gold = GoldParse(doc, entities=entity_annots)
 | 
			
		||||
    for i, tag in enumerate(gold.ner):
 | 
			
		||||
    example = Example.from_dict(doc, {"entities": entity_annots})
 | 
			
		||||
    ex_dict = example.to_dict()
 | 
			
		||||
 | 
			
		||||
    for i, tag in enumerate(ex_dict["doc_annotation"]["entities"]):
 | 
			
		||||
        if tag == "L-!GPE":
 | 
			
		||||
            gold.ner[i] = "-"
 | 
			
		||||
    tsys.preprocess_gold(gold)
 | 
			
		||||
    act_classes = tsys.get_oracle_sequence(doc, gold)
 | 
			
		||||
            ex_dict["doc_annotation"]["entities"][i] = "-"
 | 
			
		||||
    example = Example.from_dict(doc, ex_dict)
 | 
			
		||||
 | 
			
		||||
    tsys.preprocess_gold(example)
 | 
			
		||||
    act_classes = tsys.get_oracle_sequence(example)
 | 
			
		||||
    names = [tsys.get_class_name(act) for act in act_classes]
 | 
			
		||||
    assert names
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_get_oracle_moves_negative_entities2(tsys, vocab):
 | 
			
		||||
    doc = Doc(vocab, words=["A", "B", "C", "D"])
 | 
			
		||||
    gold = GoldParse(doc, entities=[])
 | 
			
		||||
    gold.ner = ["B-!PERSON", "L-!PERSON", "B-!PERSON", "L-!PERSON"]
 | 
			
		||||
    tsys.preprocess_gold(gold)
 | 
			
		||||
    act_classes = tsys.get_oracle_sequence(doc, gold)
 | 
			
		||||
    entity_annots = ["B-!PERSON", "L-!PERSON", "B-!PERSON", "L-!PERSON"]
 | 
			
		||||
    example = Example.from_dict(doc, {"entities": entity_annots})
 | 
			
		||||
    tsys.preprocess_gold(example)
 | 
			
		||||
    act_classes = tsys.get_oracle_sequence(example)
 | 
			
		||||
    names = [tsys.get_class_name(act) for act in act_classes]
 | 
			
		||||
    assert names
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_get_oracle_moves_negative_O(tsys, vocab):
 | 
			
		||||
    doc = Doc(vocab, words=["A", "B", "C", "D"])
 | 
			
		||||
    gold = GoldParse(doc, entities=[])
 | 
			
		||||
    gold.ner = ["O", "!O", "O", "!O"]
 | 
			
		||||
    tsys.preprocess_gold(gold)
 | 
			
		||||
    act_classes = tsys.get_oracle_sequence(doc, gold)
 | 
			
		||||
    entity_annots = ["O", "!O", "O", "!O"]
 | 
			
		||||
    example = Example.from_dict(doc, {"entities": []})
 | 
			
		||||
    tsys.preprocess_gold(example)
 | 
			
		||||
    act_classes = tsys.get_oracle_sequence(example)
 | 
			
		||||
    names = [tsys.get_class_name(act) for act in act_classes]
 | 
			
		||||
    assert names
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -92,7 +98,7 @@ def test_oracle_moves_missing_B(en_vocab):
 | 
			
		|||
    biluo_tags = [None, None, "L-PRODUCT"]
 | 
			
		||||
 | 
			
		||||
    doc = Doc(en_vocab, words=words)
 | 
			
		||||
    gold = GoldParse(doc, words=words, entities=biluo_tags)
 | 
			
		||||
    example = Example.from_dict(doc, {"words": words, "entities": biluo_tags})
 | 
			
		||||
 | 
			
		||||
    moves = BiluoPushDown(en_vocab.strings)
 | 
			
		||||
    move_types = ("M", "B", "I", "L", "U", "O")
 | 
			
		||||
| 
						 | 
				
			
			@ -107,8 +113,8 @@ def test_oracle_moves_missing_B(en_vocab):
 | 
			
		|||
            moves.add_action(move_types.index("I"), label)
 | 
			
		||||
            moves.add_action(move_types.index("L"), label)
 | 
			
		||||
            moves.add_action(move_types.index("U"), label)
 | 
			
		||||
    moves.preprocess_gold(gold)
 | 
			
		||||
    moves.get_oracle_sequence(doc, gold)
 | 
			
		||||
    moves.preprocess_gold(example)
 | 
			
		||||
    moves.get_oracle_sequence(example)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_oracle_moves_whitespace(en_vocab):
 | 
			
		||||
| 
						 | 
				
			
			@ -116,7 +122,7 @@ def test_oracle_moves_whitespace(en_vocab):
 | 
			
		|||
    biluo_tags = ["O", "O", "O", "B-ORG", None, "I-ORG", "L-ORG", "O", "O"]
 | 
			
		||||
 | 
			
		||||
    doc = Doc(en_vocab, words=words)
 | 
			
		||||
    gold = GoldParse(doc, words=words, entities=biluo_tags)
 | 
			
		||||
    example = Example.from_dict(doc, {"entities": biluo_tags})
 | 
			
		||||
 | 
			
		||||
    moves = BiluoPushDown(en_vocab.strings)
 | 
			
		||||
    move_types = ("M", "B", "I", "L", "U", "O")
 | 
			
		||||
| 
						 | 
				
			
			@ -128,8 +134,8 @@ def test_oracle_moves_whitespace(en_vocab):
 | 
			
		|||
        else:
 | 
			
		||||
            action, label = tag.split("-")
 | 
			
		||||
            moves.add_action(move_types.index(action), label)
 | 
			
		||||
    moves.preprocess_gold(gold)
 | 
			
		||||
    moves.get_oracle_sequence(doc, gold)
 | 
			
		||||
    moves.preprocess_gold(example)
 | 
			
		||||
    moves.get_oracle_sequence(example)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_accept_blocked_token():
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,4 +1,6 @@
 | 
			
		|||
import pytest
 | 
			
		||||
 | 
			
		||||
from spacy.gold import Example
 | 
			
		||||
from spacy.pipeline.defaults import default_parser, default_tok2vec
 | 
			
		||||
from spacy.vocab import Vocab
 | 
			
		||||
from spacy.syntax.arc_eager import ArcEager
 | 
			
		||||
| 
						 | 
				
			
			@ -71,7 +73,8 @@ def test_update_doc(parser, model, doc, gold):
 | 
			
		|||
        weights -= 0.001 * gradient
 | 
			
		||||
        return weights, gradient
 | 
			
		||||
 | 
			
		||||
    parser.update((doc, gold), sgd=optimize)
 | 
			
		||||
    example = Example.from_dict(doc, gold)
 | 
			
		||||
    parser.update([example], sgd=optimize)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.xfail
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -42,11 +42,6 @@ def tokvecs(docs, vector_size):
 | 
			
		|||
    return output
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
def golds(docs):
 | 
			
		||||
    return [GoldParse(doc) for doc in docs]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
def batch_size(docs):
 | 
			
		||||
    return len(docs)
 | 
			
		||||
| 
						 | 
				
			
			@ -77,19 +72,24 @@ def scores(moves, batch_size, beam_width):
 | 
			
		|||
    ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# All tests below are skipped after removing Beam stuff during the Example/GoldParse refactor
 | 
			
		||||
@pytest.mark.skip
 | 
			
		||||
def test_create_beam(beam):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.skip
 | 
			
		||||
def test_beam_advance(beam, scores):
 | 
			
		||||
    beam.advance(scores)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.skip
 | 
			
		||||
def test_beam_advance_too_few_scores(beam, scores):
 | 
			
		||||
    with pytest.raises(IndexError):
 | 
			
		||||
        beam.advance(scores[:-1])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.skip
 | 
			
		||||
def test_beam_parse():
 | 
			
		||||
    nlp = Language()
 | 
			
		||||
    config = {"learn_tokens": False, "min_action_freq": 30, "beam_width":  1, "beam_update_prob": 1.0}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -3,6 +3,7 @@ from thinc.api import Adam
 | 
			
		|||
from spacy.attrs import NORM
 | 
			
		||||
from spacy.vocab import Vocab
 | 
			
		||||
 | 
			
		||||
from spacy.gold import Example
 | 
			
		||||
from spacy.pipeline.defaults import default_parser
 | 
			
		||||
from spacy.tokens import Doc
 | 
			
		||||
from spacy.pipeline import DependencyParser
 | 
			
		||||
| 
						 | 
				
			
			@ -27,8 +28,8 @@ def parser(vocab):
 | 
			
		|||
    for i in range(10):
 | 
			
		||||
        losses = {}
 | 
			
		||||
        doc = Doc(vocab, words=["a", "b", "c", "d"])
 | 
			
		||||
        gold = dict(heads=[1, 1, 3, 3], deps=["left", "ROOT", "left", "ROOT"])
 | 
			
		||||
        parser.update((doc, gold), sgd=sgd, losses=losses)
 | 
			
		||||
        example = Example.from_dict(doc, {"heads": [1, 1, 3, 3], "deps": ["left", "ROOT", "left", "ROOT"]})
 | 
			
		||||
        parser.update([example], sgd=sgd, losses=losses)
 | 
			
		||||
    return parser
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user