Add tests for gold alignment and parser state

This commit is contained in:
Matthew Honnibal 2018-04-01 17:26:37 +02:00
parent 3d182fbc43
commit 29b77fd0eb
2 changed files with 135 additions and 0 deletions

View File

@ -0,0 +1,28 @@
'''Test logic for mapping annotations when the gold parse and the doc don't
align in tokenization.'''
from __future__ import unicode_literals
import pytest
from ...tokens import Doc
from ...gold import GoldParse
from ...gold import _flatten_fused_heads
from ...vocab import Vocab
@pytest.mark.parametrize('fused,flat', [
    ([[(0, 1), 1], 1], [1, 2, 2]),
    ([1, 1, [1, 3], 1], [1, 1, 1, 4, 1]),
])
def test_flatten_fused_heads(fused, flat):
    '''_flatten_fused_heads(fused) should equal the expected flat list for
    each parametrized case.'''
    flattened = _flatten_fused_heads(fused)
    assert flattened == flat
def test_over_segmented():
    '''The doc has three tokens ('a', 'b', 'c') where the gold annotation
    has two words ('ab', 'c'); check the heads and labels GoldParse
    reports for the doc's tokenization.'''
    tokens = Doc(Vocab(), words=['a', 'b', 'c'])
    aligned = GoldParse(tokens, words=['ab', 'c'], heads=[1, 1])
    assert aligned.heads == [1, 2, 2]
    assert aligned.labels == ['subtok', None, None]
def test_under_segmented():
    '''The doc has two tokens ('ab', 'c') where the gold annotation has
    three words ('a', 'b', 'c'); check the heads and labels GoldParse
    reports for the doc's tokenization.'''
    tokens = Doc(Vocab(), words=['ab', 'c'])
    aligned = GoldParse(tokens, words=['a', 'b', 'c'], heads=[2, 2, 2])
    assert aligned.heads == [[1, 1], 1]
    assert aligned.labels == [[None, None], None]

View File

@ -0,0 +1,107 @@
import pytest
from ...tokens.doc import Doc
from ...vocab import Vocab
from ...syntax.stateclass import StateClass
def get_doc(words, vocab=None):
    '''Build a Doc from an iterable of words.

    `words` is passed through list(), so a plain string yields one word per
    character. A new Vocab is created when `vocab` is None.'''
    doc_vocab = Vocab() if vocab is None else vocab
    return Doc(doc_vocab, words=list(words))
def test_push():
    '''push_stack() moves the word at the front of the queue (aka buffer)
    onto the stack, so the following word becomes the new buffer front.'''
    tokens = get_doc('abcd')
    parser_state = StateClass(tokens)
    # Before any transition the buffer front is word 0.
    assert parser_state.get_B(0) == 0
    parser_state.push_stack()
    # After one push the buffer front is word 1.
    assert parser_state.get_B(0) == 1
def test_pop():
    '''pop_stack() discards the word on top of the stack, leaving the word
    beneath it as the new top.'''
    tokens = get_doc('abcd')
    parser_state = StateClass(tokens)
    assert parser_state.get_B(0) == 0
    for _ in range(2):
        parser_state.push_stack()
    # Two pushes: word 1 is on top of the stack, word 0 is below it.
    assert parser_state.get_S(0) == 1
    assert parser_state.get_S(1) == 0
    parser_state.pop_stack()
    assert parser_state.get_S(0) == 0
def test_unshift():
    '''After two pushes, unshift() should leave word 1 at the front of the
    buffer and word 0 on top of the stack.'''
    tokens = get_doc('abcd')
    parser_state = StateClass(tokens)
    for _ in range(2):
        parser_state.push_stack()
    parser_state.unshift()
    assert parser_state.get_B(0) == 1
    assert parser_state.get_S(0) == 0
def test_break():
    '''Call set_break(0) after two pushes, then check the buffer front and
    stack top reported after each following unshift and push.'''
    tokens = get_doc('abcd')
    parser_state = StateClass(tokens)
    for _ in range(2):
        parser_state.push_stack()
    assert parser_state.get_B(0) == 2
    parser_state.set_break(0)
    # With the break set, the buffer front is reported as -1.
    assert parser_state.get_B(0) == -1
    parser_state.unshift()
    # After unshift: buffer front is word 1, stack top is word 0.
    assert parser_state.get_B(0) == 1
    assert parser_state.get_S(0) == 0
    parser_state.push_stack()
    # After the next push: buffer front is -1 again, stack top is word 1.
    assert parser_state.get_B(0) == -1
    assert parser_state.get_S(0) == 1
    parser_state.push_stack()
    # After one more push: buffer front is word 3, stack top is word 2.
    assert parser_state.get_B(0) == 3
    assert parser_state.get_S(0) == 2
def test_cant_push_empty_buffer():
    '''can_push() should be false once the only word of a one-word doc has
    been pushed.'''
    tokens = get_doc('a')
    parser_state = StateClass(tokens)
    parser_state.push_stack()
    assert not parser_state.can_push()
def test_cant_pop_empty_stack():
    '''can_pop() should be false before anything is pushed, true after a
    push, and false again once that word has been popped.'''
    tokens = get_doc('a')
    parser_state = StateClass(tokens)
    # Nothing has been pushed yet.
    assert not parser_state.can_pop()
    parser_state.push_stack()
    assert parser_state.can_pop()
    parser_state.pop_stack()
    # The single pushed word has been popped again.
    assert not parser_state.can_pop()
def test_can_pop_empty_buffer():
    '''can_pop() should still be true after both words of a two-word doc
    have been pushed.'''
    tokens = get_doc('ab')
    parser_state = StateClass(tokens)
    for _ in range(2):
        parser_state.push_stack()
    assert parser_state.can_pop()
def test_cant_arc_empty_buffer():
    '''can_arc() should be false after both words of a two-word doc have
    been pushed.'''
    tokens = get_doc('ab')
    parser_state = StateClass(tokens)
    for _ in range(2):
        parser_state.push_stack()
    assert not parser_state.can_arc()
def test_cant_arc_empty_stack():
    '''can_arc() should be false before anything is pushed, true after one
    push, and false again after a second push followed by two pops.'''
    tokens = get_doc('ab')
    parser_state = StateClass(tokens)
    # Nothing has been pushed yet.
    assert not parser_state.can_arc()
    parser_state.push_stack()
    assert parser_state.can_arc()
    parser_state.push_stack()
    for _ in range(2):
        parser_state.pop_stack()
    # Both pushed words have been popped.
    assert not parser_state.can_arc()
def test_cant_break_empty_buffer():
    '''can_break() should be false after both words of a two-word doc have
    been pushed.'''
    tokens = get_doc('ab')
    parser_state = StateClass(tokens)
    for _ in range(2):
        parser_state.push_stack()
    assert not parser_state.can_break()