* Fix test that was failing on travis

This commit is contained in:
Matthew Honnibal 2015-10-13 18:50:39 +11:00
parent 329ae57520
commit 7673e3a32c

View File

@ -25,12 +25,14 @@ from thinc.learner import LinearModel
class TestLoadVocab(unittest.TestCase): class TestLoadVocab(unittest.TestCase):
def test_load(self): def test_load(self):
if path.exists(path.join(English.default_data_dir(), 'vocab')):
vocab = Vocab.from_dir(path.join(English.default_data_dir(), 'vocab')) vocab = Vocab.from_dir(path.join(English.default_data_dir(), 'vocab'))
class TestLoadTokenizer(unittest.TestCase): class TestLoadTokenizer(unittest.TestCase):
def test_load(self): def test_load(self):
data_dir = English.default_data_dir() data_dir = English.default_data_dir()
if path.exists(path.join(data_dir, 'vocab')):
vocab = Vocab.from_dir(path.join(data_dir, 'vocab')) vocab = Vocab.from_dir(path.join(data_dir, 'vocab'))
tokenizer = Tokenizer.from_dir(vocab, path.join(data_dir, 'tokenizer')) tokenizer = Tokenizer.from_dir(vocab, path.join(data_dir, 'tokenizer'))
@ -38,6 +40,8 @@ class TestLoadTokenizer(unittest.TestCase):
class TestLoadTagger(unittest.TestCase): class TestLoadTagger(unittest.TestCase):
def test_load(self): def test_load(self):
data_dir = English.default_data_dir() data_dir = English.default_data_dir()
if path.exists(path.join(data_dir, 'vocab')):
vocab = Vocab.from_dir(path.join(data_dir, 'vocab')) vocab = Vocab.from_dir(path.join(data_dir, 'vocab'))
tagger = Tagger.from_dir(path.join(data_dir, 'tagger'), vocab) tagger = Tagger.from_dir(path.join(data_dir, 'tagger'), vocab)
@ -45,7 +49,9 @@ class TestLoadTagger(unittest.TestCase):
class TestLoadParser(unittest.TestCase): class TestLoadParser(unittest.TestCase):
def test_load(self): def test_load(self):
data_dir = English.default_data_dir() data_dir = English.default_data_dir()
if path.exists(path.join(data_dir, 'vocab')):
vocab = Vocab.from_dir(path.join(data_dir, 'vocab')) vocab = Vocab.from_dir(path.join(data_dir, 'vocab'))
if path.exists(path.join(data_dir, 'deps')):
parser = Parser.from_dir(path.join(data_dir, 'deps'), vocab.strings, ArcEager) parser = Parser.from_dir(path.join(data_dir, 'deps'), vocab.strings, ArcEager)
def test_load_careful(self): def test_load_careful(self):
@ -67,6 +73,7 @@ class TestLoadParser(unittest.TestCase):
# n classes. moves.n_moves above # n classes. moves.n_moves above
# n features. len(templates) + 1 above # n features. len(templates) + 1 above
if path.exists(model_loc):
model = LinearModel(92, 116) model = LinearModel(92, 116)
model.load(model_loc) model.load(model_loc)