mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-27 09:44:36 +03:00
Merge remote-tracking branch 'refs/remotes/honnibal/master'
This commit is contained in:
commit
90c6c5fabf
|
@ -384,16 +384,16 @@ cdef class ArcEager(TransitionSystem):
|
||||||
for i in range(st.length):
|
for i in range(st.length):
|
||||||
# Always attach spaces to the previous word
|
# Always attach spaces to the previous word
|
||||||
if Lexeme.c_check_flag(st._sent[i].lex, IS_SPACE):
|
if Lexeme.c_check_flag(st._sent[i].lex, IS_SPACE):
|
||||||
if i >= 1:
|
|
||||||
st.add_arc(i-1, i, st._sent[i].dep)
|
|
||||||
else:
|
|
||||||
st.add_arc(i+1, i, st._sent[i].dep)
|
|
||||||
if st._sent[i].sent_start and st._sent[i].head == -1:
|
if st._sent[i].sent_start and st._sent[i].head == -1:
|
||||||
st._sent[i].sent_start = False
|
st._sent[i].sent_start = False
|
||||||
# If we had this space token as the start of a sentence,
|
# If we had this space token as the start of a sentence,
|
||||||
# move that sentence start forward one
|
# move that sentence start forward one
|
||||||
if (i + 1) < st.length and not st._sent[i+1].sent_start:
|
if (i + 1) < st.length and not st._sent[i+1].sent_start:
|
||||||
st._sent[i+1].sent_start = True
|
st._sent[i+1].sent_start = True
|
||||||
|
if i >= 1:
|
||||||
|
st.add_arc(i-1, i, st._sent[i].dep)
|
||||||
|
else:
|
||||||
|
st.add_arc(i+1, i, st._sent[i].dep)
|
||||||
elif st._sent[i].head == 0 and st._sent[i].dep == 0:
|
elif st._sent[i].head == 0 and st._sent[i].dep == 0:
|
||||||
st._sent[i].dep = self.root_label
|
st._sent[i].dep = self.root_label
|
||||||
# If we're not using the Break transition, we segment via root-labelled
|
# If we're not using the Break transition, we segment via root-labelled
|
||||||
|
|
|
@ -25,12 +25,14 @@ from thinc.learner import LinearModel
|
||||||
|
|
||||||
class TestLoadVocab(unittest.TestCase):
|
class TestLoadVocab(unittest.TestCase):
|
||||||
def test_load(self):
|
def test_load(self):
|
||||||
|
if path.exists(path.join(English.default_data_dir(), 'vocab')):
|
||||||
vocab = Vocab.from_dir(path.join(English.default_data_dir(), 'vocab'))
|
vocab = Vocab.from_dir(path.join(English.default_data_dir(), 'vocab'))
|
||||||
|
|
||||||
|
|
||||||
class TestLoadTokenizer(unittest.TestCase):
|
class TestLoadTokenizer(unittest.TestCase):
|
||||||
def test_load(self):
|
def test_load(self):
|
||||||
data_dir = English.default_data_dir()
|
data_dir = English.default_data_dir()
|
||||||
|
if path.exists(path.join(data_dir, 'vocab')):
|
||||||
vocab = Vocab.from_dir(path.join(data_dir, 'vocab'))
|
vocab = Vocab.from_dir(path.join(data_dir, 'vocab'))
|
||||||
tokenizer = Tokenizer.from_dir(vocab, path.join(data_dir, 'tokenizer'))
|
tokenizer = Tokenizer.from_dir(vocab, path.join(data_dir, 'tokenizer'))
|
||||||
|
|
||||||
|
@ -38,6 +40,8 @@ class TestLoadTokenizer(unittest.TestCase):
|
||||||
class TestLoadTagger(unittest.TestCase):
|
class TestLoadTagger(unittest.TestCase):
|
||||||
def test_load(self):
|
def test_load(self):
|
||||||
data_dir = English.default_data_dir()
|
data_dir = English.default_data_dir()
|
||||||
|
|
||||||
|
if path.exists(path.join(data_dir, 'vocab')):
|
||||||
vocab = Vocab.from_dir(path.join(data_dir, 'vocab'))
|
vocab = Vocab.from_dir(path.join(data_dir, 'vocab'))
|
||||||
tagger = Tagger.from_dir(path.join(data_dir, 'tagger'), vocab)
|
tagger = Tagger.from_dir(path.join(data_dir, 'tagger'), vocab)
|
||||||
|
|
||||||
|
@ -45,7 +49,9 @@ class TestLoadTagger(unittest.TestCase):
|
||||||
class TestLoadParser(unittest.TestCase):
|
class TestLoadParser(unittest.TestCase):
|
||||||
def test_load(self):
|
def test_load(self):
|
||||||
data_dir = English.default_data_dir()
|
data_dir = English.default_data_dir()
|
||||||
|
if path.exists(path.join(data_dir, 'vocab')):
|
||||||
vocab = Vocab.from_dir(path.join(data_dir, 'vocab'))
|
vocab = Vocab.from_dir(path.join(data_dir, 'vocab'))
|
||||||
|
if path.exists(path.join(data_dir, 'deps')):
|
||||||
parser = Parser.from_dir(path.join(data_dir, 'deps'), vocab.strings, ArcEager)
|
parser = Parser.from_dir(path.join(data_dir, 'deps'), vocab.strings, ArcEager)
|
||||||
|
|
||||||
def test_load_careful(self):
|
def test_load_careful(self):
|
||||||
|
@ -67,6 +73,7 @@ class TestLoadParser(unittest.TestCase):
|
||||||
|
|
||||||
# n classes. moves.n_moves above
|
# n classes. moves.n_moves above
|
||||||
# n features. len(templates) + 1 above
|
# n features. len(templates) + 1 above
|
||||||
|
if path.exists(model_loc):
|
||||||
model = LinearModel(92, 116)
|
model = LinearModel(92, 116)
|
||||||
model.load(model_loc)
|
model.load(model_loc)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user