Mirror of https://github.com/explosion/spaCy.git (synced 2025-03-03 19:08:06 +03:00)

Messily use unsegmented sentences to train the parser

Parent: 320b045daa
Commit: b4348ce1c3
@@ -26,6 +26,7 @@ def read_tokenized_gold(file_):
     """Read a standard CoNLL/MALT-style format"""
     sents = []
     for sent_str in file_.read().strip().split('\n\n'):
+        ids = []
         words = []
         heads = []
         labels = []
@@ -35,10 +36,11 @@ def read_tokenized_gold(file_):
             words.append(word)
             if head_idx == -1:
                 head_idx = i
+            ids.append(id_)
             heads.append(head_idx)
             labels.append(label)
             tags.append(pos_string)
-        sents.append((words, heads, labels, tags))
+        sents.append((ids_, words, heads, labels, tags))
     return sents


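For orientation, here is a minimal standalone sketch of the kind of tuples this reader now builds, assuming a 4-column word/tag/head/label layout like the one `_parse_line` (below) handles. This is hypothetical illustration code, not the file's own:

    # Hypothetical sketch: blank-line-separated CoNLL blocks, 4 columns
    # (word, tag, 1-based head, label), mirroring the tuples built above.
    def read_conll_blocks(text):
        sents = []
        for sent_str in text.strip().split('\n\n'):
            ids, words, heads, labels, tags = [], [], [], [], []
            for i, line in enumerate(sent_str.split('\n')):
                word, tag, head, label = line.split()
                head_idx = int(head) - 1  # 1-based head -> 0-based index
                if head_idx == -1:        # the root attaches to itself
                    head_idx = i
                ids.append(i)
                words.append(word)
                heads.append(head_idx)
                labels.append(label)
                tags.append(tag)
            sents.append((ids, words, heads, labels, tags))
        return sents

    block = "Dogs NNS 2 nsubj\nbark VBP 0 ROOT"
    print(read_conll_blocks(block))
    # [([0, 1], ['Dogs', 'bark'], [1, 1], ['nsubj', 'ROOT'], ['NNS', 'VBP'])]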
@@ -49,31 +51,62 @@ def read_docparse_gold(file_):
         heads = []
         labels = []
         tags = []
+        ids = []
         lines = sent_str.strip().split('\n')
         raw_text = lines[0]
         tok_text = lines[1]
         for i, line in enumerate(lines[2:]):
-            word, pos_string, head_idx, label = _parse_line(line)
+            id_, word, pos_string, head_idx, label = _parse_line(line)
+            if label == 'root':
+                label = 'ROOT'
+            if pos_string == "``":
+                word = "``"
+            elif pos_string == "''":
+                word = "''"
             words.append(word)
-            if head_idx == -1:
-                head_idx = i
+            if head_idx < 0:
+                head_idx = id_
+            ids.append(id_)
             heads.append(head_idx)
             labels.append(label)
             tags.append(pos_string)
-        words = tok_text.replace('<SEP>', ' ').replace('<SENT>', ' ').split(' ')
+        heads = _map_indices_to_tokens(ids, heads)
+        words = tok_text.replace('<SENT>', ' ').replace('<SEP>', ' ').split()
+        #print words
+        #print heads
         sents.append((words, heads, labels, tags))
+        #sent_strings = tok_text.split('<SENT>')
+        #for sent in sent_strings:
+        #    sent_words = sent.replace('<SEP>', ' ').split(' ')
+        #    sent_heads = []
+        #    sent_labels = []
+        #    sent_tags = []
+        #    sent_ids = []
+        #    while len(sent_heads) < len(sent_words):
+        #        sent_heads.append(heads.pop(0))
+        #        sent_labels.append(labels.pop(0))
+        #        sent_tags.append(tags.pop(0))
+        #        sent_ids.append(ids.pop(0))
+        #    sent_heads = _map_indices_to_tokens(sent_ids, sent_heads)
+        #    sents.append((sent_words, sent_heads, sent_labels, sent_tags))
     return sents


+def _map_indices_to_tokens(ids, heads):
+    return [ids.index(head) for head in heads]
+
+
 def _parse_line(line):
     pieces = line.split()
     if len(pieces) == 4:
-        return pieces[0], pieces[1], int(pieces[2]) - 1, pieces[3]
+        return 0, pieces[0], pieces[1], int(pieces[2]) - 1, pieces[3]
     else:
+        id_ = int(pieces[0])
         word = pieces[1]
         pos = pieces[3]
-        head_idx = int(pieces[6]) - 1
+        head_idx = int(pieces[6])
         label = pieces[7]
-        return word, pos, head_idx, label
+        return id_, word, pos, head_idx, label


 def get_labels(sents):
     left_labels = set()
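A worked example of the two functions this hunk changes and adds: `_parse_line` now returns the raw token id and leaves docparse head ids unconverted, and `_map_indices_to_tokens` resolves those ids to positions within the sentence. The input lines are hypothetical; underscores pad the unused CoNLL columns:

    def _parse_line(line):
        pieces = line.split()
        if len(pieces) == 4:
            return 0, pieces[0], pieces[1], int(pieces[2]) - 1, pieces[3]
        else:
            id_ = int(pieces[0])
            word = pieces[1]
            pos = pieces[3]
            head_idx = int(pieces[6])
            label = pieces[7]
            return id_, word, pos, head_idx, label

    def _map_indices_to_tokens(ids, heads):
        return [ids.index(head) for head in heads]

    # Heads in the 10-column format refer to token ids, not positions; the
    # root token (id 2) attaches to itself, as read_docparse_gold arranges.
    lines = ["1 Dogs _ NNS _ _ 2 nsubj _ _",
             "2 bark _ VBP _ _ 2 ROOT _ _"]
    ids, heads = [], []
    for line in lines:
        id_, word, pos, head_id, label = _parse_line(line)
        ids.append(id_)
        heads.append(head_id)
    print(_map_indices_to_tokens(ids, heads))  # [1, 1]: both attach to 'bark'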
@@ -113,7 +146,11 @@ def train(Language, sents, model_dir, n_iter=15, feat_set=u'basic', seed=0):
             tags = [nlp.tagger.tag_names.index(tag) for tag in tags]
             tokens = nlp.tokenizer.tokens_from_list(words)
             nlp.tagger(tokens)
-            heads_corr += nlp.parser.train_sent(tokens, heads, labels)
+            try:
+                heads_corr += nlp.parser.train_sent(tokens, heads, labels, force_gold=False)
+            except:
+                print heads
+                raise
             pos_corr += nlp.tagger.train(tokens, tags)
             n_tokens += len(tokens)
         acc = float(heads_corr) / n_tokens
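The try/except here is a debugging guard: if an unsegmented sentence crashes the parser update, the offending gold heads are printed before the exception propagates. The same pattern as a small generic sketch, where `train_sent` is a stand-in callable, not spaCy's API:

    def guarded_update(train_sent, tokens, heads, labels):
        # Print the gold heads that broke the update, then re-raise,
        # so the failing sentence can be inspected.
        try:
            return train_sent(tokens, heads, labels)
        except Exception:
            print('failed on heads: %r' % (heads,))
            raise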
@@ -122,7 +159,6 @@ def train(Language, sents, model_dir, n_iter=15, feat_set=u'basic', seed=0):
         random.shuffle(sents)
     nlp.parser.model.end_training()
     nlp.tagger.model.end_training()
-    #nlp.parser.model.dump(path.join(dep_model_dir, 'model'), freq_thresh=0)
     return acc


@@ -131,13 +167,13 @@ def evaluate(Language, dev_loc, model_dir):
     n_corr = 0
     total = 0
     with codecs.open(dev_loc, 'r', 'utf8') as file_:
-        sents = read_tokenized_gold(file_)
+        sents = read_docparse_gold(file_)
     for words, heads, labels, tags in sents:
         tokens = nlp.tokenizer.tokens_from_list(words)
         nlp.tagger(tokens)
         nlp.parser(tokens)
         for i, token in enumerate(tokens):
-            #print i, token.string, i + token.head, heads[i], labels[i]
+            #print i, token.orth_, token.head.orth_, tokens[heads[i]].orth_, labels[i], token.head.i == heads[i]
             if labels[i] == 'P' or labels[i] == 'punct':
                 continue
             n_corr += token.head.i == heads[i]
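The evaluation counts a token as correct when its predicted head's index matches the gold head, skipping punctuation. The same accuracy over plain lists, as a standalone sketch with hypothetical data:

    def attachment_score(pred_heads, gold_heads, gold_labels):
        # Unlabelled attachment score, excluding punctuation ('P'/'punct').
        n_corr = total = 0
        for pred, gold, label in zip(pred_heads, gold_heads, gold_labels):
            if label in ('P', 'punct'):
                continue
            n_corr += pred == gold
            total += 1
        return float(n_corr) / total

    print(attachment_score([1, 1, 1], [1, 1, 0], ['nsubj', 'ROOT', 'punct']))  # 1.0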
@@ -150,7 +186,8 @@ PROFILE = False

 def main(train_loc, dev_loc, model_dir):
     with codecs.open(train_loc, 'r', 'utf8') as file_:
-        train_sents = read_tokenized_gold(file_)
+        train_sents = read_docparse_gold(file_)
+    train_sents = train_sents
     if PROFILE:
         import cProfile
         import pstats
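The hunk is cut off inside the PROFILE branch; a typical continuation using the two modules imported there looks like the following. This is a sketch under that assumption, with a stand-in workload, not the file's actual code:

    import cProfile
    import pstats

    def work():
        return sum(i * i for i in range(100000))

    # Dump profile data to a file, then print the 20 most expensive
    # calls by cumulative time.
    cProfile.runctx('work()', globals(), locals(), 'parser.profile')
    pstats.Stats('parser.profile').sort_stats('cumulative').print_stats(20)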