From 29b77fd0eba684eb6d6f54fee6f4b79345819235 Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Sun, 1 Apr 2018 17:26:37 +0200
Subject: [PATCH] Add tests for gold alignment and parser state

---
Reviewer notes. This free-form area sits between the three-dash line and
the diffstat; it is commentary only and the diff payload below is unchanged.

* spacy/tests/gold/test_misaligned_gold.py
  - test_flatten_fused_heads: parametrized over two (fused, flat) pairs;
    asserts _flatten_fused_heads(fused) == flat for each pair.
  - test_over_segmented: the Doc has three words ('a', 'b', 'c') while the
    gold words are ('ab', 'c') with heads [1, 1]; asserts
    gold.heads == [1, 2, 2] and gold.labels == ['subtok', None, None].
  - test_under_segmented: the Doc has two words ('ab', 'c') while the gold
    words are ('a', 'b', 'c') with heads [2, 2, 2]; asserts the nested
    results gold.heads == [[1, 1], 1] and
    gold.labels == [[None, None], None].

* spacy/tests/parser/test_state.py
  - get_doc(words, vocab=None): helper that returns
    Doc(vocab, words=list(words)), creating a new Vocab() when none is
    given. The tests pass plain strings such as 'abcd', so list() yields
    one entry per character.
  - test_push: get_B(0) is 0 on a new state and 1 after one push_stack().
  - test_pop: after two push_stack() calls get_S(0) == 1 and
    get_S(1) == 0; after pop_stack(), get_S(0) == 0.
  - test_unshift: after two pushes and unshift(), get_B(0) == 1 and
    get_S(0) == 0.
  - test_break: after two pushes get_B(0) == 2; after set_break(0) it is
    -1; after unshift() get_B(0) == 1 and get_S(0) == 0; the next push
    gives get_B(0) == -1 / get_S(0) == 1, and one more push gives
    get_B(0) == 3 / get_S(0) == 2.
    NOTE(review): presumably -1 is a "no word / break" marker; confirm
    against StateClass.
  - test_cant_push_empty_buffer: one-word doc, after one push can_push()
    is false.
  - test_cant_pop_empty_stack: can_pop() is false on a new state, true
    after a push, and false again after pop_stack().
  - test_can_pop_empty_buffer: two-word doc, after two pushes can_pop()
    is true.
  - test_cant_arc_empty_buffer: two-word doc, after two pushes can_arc()
    is false.
  - test_cant_arc_empty_stack: can_arc() is false on a new state, true
    after one push, and false after push, pop, pop.
  - test_cant_break_empty_buffer: two-word doc, after two pushes
    can_break() is false.

* NOTE(review): test_state.py imports pytest but no test in this patch
  uses it, and unlike test_misaligned_gold.py it has no
  "from __future__ import unicode_literals" line. Confirm whether Doc
  needs unicode words when these tests run under Python 2.

 spacy/tests/gold/test_misaligned_gold.py |  28 ++++++
 spacy/tests/parser/test_state.py         | 107 +++++++++++++++++++++++
 2 files changed, 135 insertions(+)
 create mode 100644 spacy/tests/gold/test_misaligned_gold.py
 create mode 100644 spacy/tests/parser/test_state.py

diff --git a/spacy/tests/gold/test_misaligned_gold.py b/spacy/tests/gold/test_misaligned_gold.py
new file mode 100644
index 000000000..f8733e5fa
--- /dev/null
+++ b/spacy/tests/gold/test_misaligned_gold.py
@@ -0,0 +1,28 @@
+'''Test logic for mapping annotations when the gold parse and the doc don't
+align in tokenization.'''
+from __future__ import unicode_literals
+import pytest
+from ...tokens import Doc
+from ...gold import GoldParse
+from ...gold import _flatten_fused_heads
+from ...vocab import Vocab
+
+@pytest.mark.parametrize('fused,flat', [
+    ([ [(0, 1), 1], 1], [1, 2, 2]),
+    ([1, 1, [1, 3], 1], [1, 1, 1, 4, 1])
+])
+def test_flatten_fused_heads(fused, flat):
+    assert _flatten_fused_heads(fused) == flat
+
+
+def test_over_segmented():
+    doc = Doc(Vocab(), words=['a', 'b', 'c'])
+    gold = GoldParse(doc, words=['ab', 'c'], heads=[1,1])
+    assert gold.heads == [1, 2, 2]
+    assert gold.labels == ['subtok', None, None]
+
+def test_under_segmented():
+    doc = Doc(Vocab(), words=['ab', 'c'])
+    gold = GoldParse(doc, words=['a', 'b', 'c'], heads=[2,2,2])
+    assert gold.heads == [[1,1], 1]
+    assert gold.labels == [[None, None], None]
diff --git a/spacy/tests/parser/test_state.py b/spacy/tests/parser/test_state.py
new file mode 100644
index 000000000..18c425cbf
--- /dev/null
+++ b/spacy/tests/parser/test_state.py
@@ -0,0 +1,107 @@
+import pytest
+
+from ...tokens.doc import Doc
+from ...vocab import Vocab
+from ...syntax.stateclass import StateClass
+
+
+def get_doc(words, vocab=None):
+    if vocab is None:
+        vocab = Vocab()
+    return Doc(vocab, words=list(words))
+
+def test_push():
+    '''state.push_stack() should take the first word in the queue (aka buffer)
+    and put it on the stack, popping that word from the queue.'''
+    doc = get_doc('abcd')
+    state = StateClass(doc)
+    assert state.get_B(0) == 0
+    state.push_stack()
+    assert state.get_B(0) == 1
+
+def test_pop():
+    '''state.pop_stack() should remove the top word from the stack.'''
+    doc = get_doc('abcd')
+    state = StateClass(doc)
+    assert state.get_B(0) == 0
+    state.push_stack()
+    state.push_stack()
+    assert state.get_S(0) == 1
+    assert state.get_S(1) == 0
+    state.pop_stack()
+    assert state.get_S(0) == 0
+
+def test_unshift():
+    doc = get_doc('abcd')
+    state = StateClass(doc)
+    state.push_stack()
+    state.push_stack()
+    state.unshift()
+    assert state.get_B(0) == 1
+    assert state.get_S(0) == 0
+
+def test_break():
+    doc = get_doc('abcd')
+    state = StateClass(doc)
+    state.push_stack()
+    state.push_stack()
+    assert state.get_B(0) == 2
+    state.set_break(0)
+    assert state.get_B(0) == -1
+    state.unshift()
+    assert state.get_B(0) == 1
+    assert state.get_S(0) == 0
+    state.push_stack()
+    assert state.get_B(0) == -1
+    assert state.get_S(0) == 1
+    state.push_stack()
+    assert state.get_B(0) == 3
+    assert state.get_S(0) == 2
+
+def test_cant_push_empty_buffer():
+    doc = get_doc('a')
+    state = StateClass(doc)
+    state.push_stack()
+    assert not state.can_push()
+
+def test_cant_pop_empty_stack():
+    doc = get_doc('a')
+    state = StateClass(doc)
+    assert not state.can_pop()
+    state.push_stack()
+    assert state.can_pop()
+    state.pop_stack()
+    assert not state.can_pop()
+
+def test_can_pop_empty_buffer():
+    doc = get_doc('ab')
+    state = StateClass(doc)
+    state.push_stack()
+    state.push_stack()
+    assert state.can_pop()
+
+def test_cant_arc_empty_buffer():
+    doc = get_doc('ab')
+    state = StateClass(doc)
+    state.push_stack()
+    state.push_stack()
+    assert not state.can_arc()
+
+def test_cant_arc_empty_stack():
+    doc = get_doc('ab')
+    state = StateClass(doc)
+    assert not state.can_arc()
+    state.push_stack()
+    assert state.can_arc()
+    state.push_stack()
+    state.pop_stack()
+    state.pop_stack()
+    assert not state.can_arc()
+
+def test_cant_break_empty_buffer():
+    doc = get_doc('ab')
+    state = StateClass(doc)
+    state.push_stack()
+    state.push_stack()
+    assert not state.can_break()
+