mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-13 01:32:32 +03:00
* Work on alignment, for evaluation with non-gold preprocessing
This commit is contained in:
parent
ebf7d2fab1
commit
11ed65b93c
|
@ -45,7 +45,7 @@ def read_tokenized_gold(file_):
|
||||||
|
|
||||||
|
|
||||||
def read_docparse_gold(file_):
|
def read_docparse_gold(file_):
|
||||||
sents = []
|
paragraphs = []
|
||||||
for sent_str in file_.read().strip().split('\n\n'):
|
for sent_str in file_.read().strip().split('\n\n'):
|
||||||
words = []
|
words = []
|
||||||
heads = []
|
heads = []
|
||||||
|
@ -59,10 +59,6 @@ def read_docparse_gold(file_):
|
||||||
id_, word, pos_string, head_idx, label = _parse_line(line)
|
id_, word, pos_string, head_idx, label = _parse_line(line)
|
||||||
if label == 'root':
|
if label == 'root':
|
||||||
label = 'ROOT'
|
label = 'ROOT'
|
||||||
if pos_string == "``":
|
|
||||||
word = "``"
|
|
||||||
elif pos_string == "''":
|
|
||||||
word = "''"
|
|
||||||
words.append(word)
|
words.append(word)
|
||||||
if head_idx < 0:
|
if head_idx < 0:
|
||||||
head_idx = id_
|
head_idx = id_
|
||||||
|
@ -70,30 +66,20 @@ def read_docparse_gold(file_):
|
||||||
heads.append(head_idx)
|
heads.append(head_idx)
|
||||||
labels.append(label)
|
labels.append(label)
|
||||||
tags.append(pos_string)
|
tags.append(pos_string)
|
||||||
heads = _map_indices_to_tokens(ids, heads)
|
tokenized = [sent_str.replace('<SEP>', ' ').split(' ')
|
||||||
words = tok_text.replace('<SENT>', ' ').replace('<SEP>', ' ').split()
|
for sent_str in tok_text.split('<SENT>')]
|
||||||
#print words
|
paragraphs.append((raw_text, tokenized, ids, words, tags, heads, labels))
|
||||||
#print heads
|
return paragraphs
|
||||||
sents.append((words, heads, labels, tags))
|
|
||||||
#sent_strings = tok_text.split('<SENT>')
|
|
||||||
#for sent in sent_strings:
|
|
||||||
# sent_words = sent.replace('<SEP>', ' ').split(' ')
|
|
||||||
# sent_heads = []
|
|
||||||
# sent_labels = []
|
|
||||||
# sent_tags = []
|
|
||||||
# sent_ids = []
|
|
||||||
# while len(sent_heads) < len(sent_words):
|
|
||||||
# sent_heads.append(heads.pop(0))
|
|
||||||
# sent_labels.append(labels.pop(0))
|
|
||||||
# sent_tags.append(tags.pop(0))
|
|
||||||
# sent_ids.append(ids.pop(0))
|
|
||||||
# sent_heads = _map_indices_to_tokens(sent_ids, sent_heads)
|
|
||||||
# sents.append((sent_words, sent_heads, sent_labels, sent_tags))
|
|
||||||
return sents
|
|
||||||
|
|
||||||
def _map_indices_to_tokens(ids, heads):
|
def _map_indices_to_tokens(ids, heads):
|
||||||
return [ids.index(head) for head in heads]
|
mapped = []
|
||||||
|
for head in heads:
|
||||||
|
if head not in ids:
|
||||||
|
mapped.append(None)
|
||||||
|
else:
|
||||||
|
mapped.append(ids.index(head))
|
||||||
|
return mapped
|
||||||
|
|
||||||
|
|
||||||
def _parse_line(line):
|
def _parse_line(line):
|
||||||
|
@ -108,10 +94,71 @@ def _parse_line(line):
|
||||||
label = pieces[7]
|
label = pieces[7]
|
||||||
return id_, word, pos, head_idx, label
|
return id_, word, pos, head_idx, label
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _align_annotations_to_non_gold_tokens(tokens, words, annot):
|
||||||
|
tags = []
|
||||||
|
heads = []
|
||||||
|
labels = []
|
||||||
|
loss = 0
|
||||||
|
print [t.orth_ for t in tokens]
|
||||||
|
print words
|
||||||
|
for token in tokens:
|
||||||
|
print token.orth_, words[0]
|
||||||
|
while annot and token.idx > annot[0][0]:
|
||||||
|
annot.pop(0)
|
||||||
|
words.pop(0)
|
||||||
|
loss += 1
|
||||||
|
if not annot:
|
||||||
|
tags.append(None)
|
||||||
|
heads.append(None)
|
||||||
|
labels.append(None)
|
||||||
|
continue
|
||||||
|
id_, tag, head, label = annot[0]
|
||||||
|
if token.idx == id_:
|
||||||
|
tags.append(tag)
|
||||||
|
heads.append(head)
|
||||||
|
labels.append(label)
|
||||||
|
annot.pop(0)
|
||||||
|
words.pop(0)
|
||||||
|
elif token.idx < id_:
|
||||||
|
tags.append(None)
|
||||||
|
heads.append(None)
|
||||||
|
labels.append(None)
|
||||||
|
else:
|
||||||
|
raise StandardError
|
||||||
|
return loss, tags, heads, labels
|
||||||
|
|
||||||
|
|
||||||
|
def iter_data(paragraphs, tokenizer, gold_preproc=False):
|
||||||
|
for raw, tokenized, ids, words, tags, heads, labels in paragraphs:
|
||||||
|
if not gold_preproc:
|
||||||
|
tokens = tokenizer(raw)
|
||||||
|
loss, tags, heads, labels = _align_annotations_to_non_gold_tokens(
|
||||||
|
tokens, words, zip(ids, tags, heads, labels))
|
||||||
|
ids = [t.idx for t in tokens]
|
||||||
|
heads = _map_indices_to_tokens(ids, heads)
|
||||||
|
yield tokens, tags, heads, labels
|
||||||
|
else:
|
||||||
|
assert len(words) == len(heads)
|
||||||
|
for words in tokenized:
|
||||||
|
sent_ids = ids[:len(words)]
|
||||||
|
sent_tags = tags[:len(words)]
|
||||||
|
sent_heads = heads[:len(words)]
|
||||||
|
sent_labels = labels[:len(words)]
|
||||||
|
sent_heads = _map_indices_to_tokens(sent_ids, sent_heads)
|
||||||
|
tokens = tokenizer.tokens_from_list(words)
|
||||||
|
yield tokens, sent_tags, sent_heads, sent_labels
|
||||||
|
ids = ids[len(words):]
|
||||||
|
tags = tags[len(words):]
|
||||||
|
heads = heads[len(words):]
|
||||||
|
labels = labels[len(words):]
|
||||||
|
|
||||||
|
|
||||||
def get_labels(sents):
|
def get_labels(sents):
|
||||||
left_labels = set()
|
left_labels = set()
|
||||||
right_labels = set()
|
right_labels = set()
|
||||||
for _, heads, labels, _ in sents:
|
for raw, tokenized, ids, words, tags, heads, labels in sents:
|
||||||
for child, (head, label) in enumerate(zip(heads, labels)):
|
for child, (head, label) in enumerate(zip(heads, labels)):
|
||||||
if head > child:
|
if head > child:
|
||||||
left_labels.add(label)
|
left_labels.add(label)
|
||||||
|
@ -120,7 +167,8 @@ def get_labels(sents):
|
||||||
return list(sorted(left_labels)), list(sorted(right_labels))
|
return list(sorted(left_labels)), list(sorted(right_labels))
|
||||||
|
|
||||||
|
|
||||||
def train(Language, sents, model_dir, n_iter=15, feat_set=u'basic', seed=0):
|
def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
|
||||||
|
gold_preproc=True):
|
||||||
dep_model_dir = path.join(model_dir, 'deps')
|
dep_model_dir = path.join(model_dir, 'deps')
|
||||||
pos_model_dir = path.join(model_dir, 'pos')
|
pos_model_dir = path.join(model_dir, 'pos')
|
||||||
if path.exists(dep_model_dir):
|
if path.exists(dep_model_dir):
|
||||||
|
@ -132,7 +180,7 @@ def train(Language, sents, model_dir, n_iter=15, feat_set=u'basic', seed=0):
|
||||||
setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES,
|
setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES,
|
||||||
pos_model_dir)
|
pos_model_dir)
|
||||||
|
|
||||||
left_labels, right_labels = get_labels(sents)
|
left_labels, right_labels = get_labels(paragraphs)
|
||||||
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
|
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
|
||||||
left_labels=left_labels, right_labels=right_labels)
|
left_labels=left_labels, right_labels=right_labels)
|
||||||
|
|
||||||
|
@ -142,62 +190,50 @@ def train(Language, sents, model_dir, n_iter=15, feat_set=u'basic', seed=0):
|
||||||
heads_corr = 0
|
heads_corr = 0
|
||||||
pos_corr = 0
|
pos_corr = 0
|
||||||
n_tokens = 0
|
n_tokens = 0
|
||||||
for words, heads, labels, tags in sents:
|
for tokens, tag_strs, heads, labels in iter_data(paragraphs, nlp.tokenizer,
|
||||||
tags = [nlp.tagger.tag_names.index(tag) for tag in tags]
|
gold_preproc=gold_preproc):
|
||||||
tokens = nlp.tokenizer.tokens_from_list(words)
|
tags = [nlp.tagger.tag_names.index(tag) for tag in tag_strs]
|
||||||
nlp.tagger(tokens)
|
nlp.tagger(tokens)
|
||||||
try:
|
heads_corr += nlp.parser.train_sent(tokens, heads, labels, force_gold=False)
|
||||||
heads_corr += nlp.parser.train_sent(tokens, heads, labels, force_gold=False)
|
|
||||||
except:
|
|
||||||
print heads
|
|
||||||
raise
|
|
||||||
pos_corr += nlp.tagger.train(tokens, tags)
|
pos_corr += nlp.tagger.train(tokens, tags)
|
||||||
n_tokens += len(tokens)
|
n_tokens += len(tokens)
|
||||||
acc = float(heads_corr) / n_tokens
|
acc = float(heads_corr) / n_tokens
|
||||||
pos_acc = float(pos_corr) / n_tokens
|
pos_acc = float(pos_corr) / n_tokens
|
||||||
print '%d: ' % itn, '%.3f' % acc, '%.3f' % pos_acc
|
print '%d: ' % itn, '%.3f' % acc, '%.3f' % pos_acc
|
||||||
random.shuffle(sents)
|
random.shuffle(paragraphs)
|
||||||
nlp.parser.model.end_training()
|
nlp.parser.model.end_training()
|
||||||
nlp.tagger.model.end_training()
|
nlp.tagger.model.end_training()
|
||||||
return acc
|
return acc
|
||||||
|
|
||||||
|
|
||||||
def evaluate(Language, dev_loc, model_dir):
|
def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
n_corr = 0
|
n_corr = 0
|
||||||
total = 0
|
total = 0
|
||||||
|
skipped = 0
|
||||||
with codecs.open(dev_loc, 'r', 'utf8') as file_:
|
with codecs.open(dev_loc, 'r', 'utf8') as file_:
|
||||||
sents = read_docparse_gold(file_)
|
paragraphs = read_docparse_gold(file_)
|
||||||
for words, heads, labels, tags in sents:
|
for tokens, tag_strs, heads, labels in iter_data(paragraphs, nlp.tokenizer,
|
||||||
tokens = nlp.tokenizer.tokens_from_list(words)
|
gold_preproc=gold_preproc):
|
||||||
|
assert len(tokens) == len(labels)
|
||||||
nlp.tagger(tokens)
|
nlp.tagger(tokens)
|
||||||
nlp.parser(tokens)
|
nlp.parser(tokens)
|
||||||
for i, token in enumerate(tokens):
|
for i, token in enumerate(tokens):
|
||||||
#print i, token.orth_, token.head.orth_, tokens[heads[i]].orth_, labels[i], token.head.i == heads[i]
|
if heads[i] is None:
|
||||||
|
skipped += 1
|
||||||
if labels[i] == 'P' or labels[i] == 'punct':
|
if labels[i] == 'P' or labels[i] == 'punct':
|
||||||
continue
|
continue
|
||||||
n_corr += token.head.i == heads[i]
|
n_corr += token.head.i == heads[i]
|
||||||
total += 1
|
total += 1
|
||||||
|
print skipped
|
||||||
return float(n_corr) / total
|
return float(n_corr) / total
|
||||||
|
|
||||||
|
|
||||||
PROFILE = False
|
|
||||||
|
|
||||||
|
|
||||||
def main(train_loc, dev_loc, model_dir):
|
def main(train_loc, dev_loc, model_dir):
|
||||||
with codecs.open(train_loc, 'r', 'utf8') as file_:
|
with codecs.open(train_loc, 'r', 'utf8') as file_:
|
||||||
train_sents = read_docparse_gold(file_)
|
train_sents = read_docparse_gold(file_)
|
||||||
train_sents = train_sents
|
#train(English, train_sents, model_dir, gold_preproc=False)
|
||||||
if PROFILE:
|
print evaluate(English, dev_loc, model_dir, gold_preproc=False)
|
||||||
import cProfile
|
|
||||||
import pstats
|
|
||||||
cmd = "train(EN, train_sents, tag_names, model_dir, n_iter=2)"
|
|
||||||
cProfile.runctx(cmd, globals(), locals(), "Profile.prof")
|
|
||||||
s = pstats.Stats("Profile.prof")
|
|
||||||
s.strip_dirs().sort_stats("time").print_stats()
|
|
||||||
else:
|
|
||||||
train(English, train_sents, model_dir)
|
|
||||||
print evaluate(English, dev_loc, model_dir)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
Loading…
Reference in New Issue
Block a user