* Allow parsers and taggers to be trained on text without gold pre-processing.

This commit is contained in:
Matthew Honnibal 2015-01-30 16:36:24 +11:00
parent 67d6e53a69
commit ca7577d8a9

View File

@ -9,6 +9,7 @@ import codecs
import random import random
import time import time
import gzip import gzip
import nltk
import plac import plac
import cProfile import cProfile
@ -22,6 +23,10 @@ from spacy.syntax.parser import GreedyParser
from spacy.syntax.util import Config from spacy.syntax.util import Config
def is_punct_label(label):
    """Return True if `label` is a punctuation dependency label.

    The label 'P' is matched exactly (case-sensitive), while 'punct' is
    matched case-insensitively.  A label of None is treated as "not
    punctuation" rather than raising AttributeError on `.lower()`.
    """
    if label is None:
        return False
    return label == 'P' or label.lower() == 'punct'
def read_tokenized_gold(file_): def read_tokenized_gold(file_):
"""Read a standard CoNLL/MALT-style format""" """Read a standard CoNLL/MALT-style format"""
sents = [] sents = []
@ -96,20 +101,20 @@ def _parse_line(line):
loss = 0
def _align_annotations_to_non_gold_tokens(tokens, words, annot): def _align_annotations_to_non_gold_tokens(tokens, words, annot):
global loss
tags = [] tags = []
heads = [] heads = []
labels = [] labels = []
loss = 0 orig_words = list(words)
print [t.orth_ for t in tokens] missed = []
print words
for token in tokens: for token in tokens:
print token.orth_, words[0]
print token.idx, annot[0][0]
while annot and token.idx > annot[0][0]: while annot and token.idx > annot[0][0]:
print 'pop', token.idx, annot[0][0] miss_id, miss_tag, miss_head, miss_label = annot.pop(0)
annot.pop(0) miss_w = words.pop(0)
words.pop(0) if not is_punct_label(miss_label):
missed.append(miss_w)
loss += 1 loss += 1
if not annot: if not annot:
tags.append(None) tags.append(None)
@ -129,6 +134,11 @@ def _align_annotations_to_non_gold_tokens(tokens, words, annot):
labels.append(None) labels.append(None)
else: else:
raise StandardError raise StandardError
#if missed:
# print orig_words
# print missed
# for t in tokens:
# print t.idx, t.orth_
return loss, tags, heads, labels return loss, tags, heads, labels
@ -137,7 +147,8 @@ def iter_data(paragraphs, tokenizer, gold_preproc=False):
if not gold_preproc: if not gold_preproc:
tokens = tokenizer(raw) tokens = tokenizer(raw)
loss, tags, heads, labels = _align_annotations_to_non_gold_tokens( loss, tags, heads, labels = _align_annotations_to_non_gold_tokens(
tokens, words, zip(ids, tags, heads, labels)) tokens, list(words),
zip(ids, tags, heads, labels))
ids = [t.idx for t in tokens] ids = [t.idx for t in tokens]
heads = _map_indices_to_tokens(ids, heads) heads = _map_indices_to_tokens(ids, heads)
yield tokens, tags, heads, labels yield tokens, tags, heads, labels
@ -170,7 +181,7 @@ def get_labels(sents):
def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0, def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
gold_preproc=True): gold_preproc=False):
dep_model_dir = path.join(model_dir, 'deps') dep_model_dir = path.join(model_dir, 'deps')
pos_model_dir = path.join(model_dir, 'pos') pos_model_dir = path.join(model_dir, 'pos')
if path.exists(dep_model_dir): if path.exists(dep_model_dir):
@ -194,10 +205,9 @@ def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
n_tokens = 0 n_tokens = 0
for tokens, tag_strs, heads, labels in iter_data(paragraphs, nlp.tokenizer, for tokens, tag_strs, heads, labels in iter_data(paragraphs, nlp.tokenizer,
gold_preproc=gold_preproc): gold_preproc=gold_preproc):
tags = [nlp.tagger.tag_names.index(tag) for tag in tag_strs]
nlp.tagger(tokens) nlp.tagger(tokens)
heads_corr += nlp.parser.train_sent(tokens, heads, labels, force_gold=False) heads_corr += nlp.parser.train_sent(tokens, heads, labels, force_gold=False)
pos_corr += nlp.tagger.train(tokens, tags) pos_corr += nlp.tagger.train(tokens, tag_strs)
n_tokens += len(tokens) n_tokens += len(tokens)
acc = float(heads_corr) / n_tokens acc = float(heads_corr) / n_tokens
pos_acc = float(pos_corr) / n_tokens pos_acc = float(pos_corr) / n_tokens
@ -223,12 +233,13 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
for i, token in enumerate(tokens): for i, token in enumerate(tokens):
if heads[i] is None: if heads[i] is None:
skipped += 1 skipped += 1
if labels[i] == 'P' or labels[i] == 'punct': continue
if is_punct_label(labels[i]):
continue continue
n_corr += token.head.i == heads[i] n_corr += token.head.i == heads[i]
total += 1 total += 1
print skipped print loss, skipped, (loss+skipped + total)
return float(n_corr) / total return float(n_corr) / (total + loss)
def main(train_loc, dev_loc, model_dir): def main(train_loc, dev_loc, model_dir):