spaCy/bin/parser/train.py

281 lines
9.0 KiB
Python
Raw Normal View History

2015-01-09 20:53:26 +03:00
#!/usr/bin/env python
from __future__ import division
from __future__ import unicode_literals
import os
from os import path
import shutil
import codecs
import random
import time
import gzip
import plac
import cProfile
import pstats
import spacy.util
from spacy.en import English
from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir
from spacy.syntax.parser import GreedyParser
from spacy.syntax.parser import OracleError
2015-01-09 20:53:26 +03:00
from spacy.syntax.util import Config
from spacy.syntax.conll import read_docparse_file
from spacy.syntax.conll import GoldParse
2015-01-09 20:53:26 +03:00
from spacy.scorer import Scorer
2015-01-09 20:53:26 +03:00
def is_punct_label(label):
    """Return True when *label* is a punctuation dependency label.

    Accepts the MALT-style 'P' (exact match) or 'punct' in any letter case.
    """
    if label == 'P':
        return True
    return label.lower() == 'punct'
2015-01-09 20:53:26 +03:00
def read_tokenized_gold(file_):
    """Read a standard CoNLL/MALT-style format.

    Sentences are separated by blank lines. Each non-blank line is one token,
    in whichever column layout ``_parse_line`` accepts. A head of -1 is
    replaced by the token's own position within the sentence.

    Returns a list with one tuple per sentence:
    ``(text, [words], ids, words, tags, heads, labels)``, where *text* is the
    words joined by single spaces. An empty file, or extra blank lines between
    sentences, produce no entries instead of crashing.
    """
    sents = []
    for sent_str in file_.read().strip().split('\n\n'):
        # Drop whitespace-only lines: runs of 3+ newlines (or an empty file)
        # used to hand an empty line to _parse_line, which raised IndexError.
        lines = [line for line in sent_str.split('\n') if line.strip()]
        if not lines:
            continue
        ids = []
        words = []
        heads = []
        labels = []
        tags = []
        for i, line in enumerate(lines):
            id_, word, pos_string, head_idx, label = _parse_line(line)
            words.append(word)
            if head_idx == -1:
                # Root token: make it its own head (by sentence position).
                head_idx = i
            ids.append(id_)
            heads.append(head_idx)
            labels.append(label)
            tags.append(pos_string)
        text = ' '.join(words)
        sents.append((text, [words], ids, words, tags, heads, labels))
    return sents
def read_docparse_gold(file_):
    """Read paragraphs in the "docparse" gold format.

    Paragraphs are separated by blank lines. In each paragraph:

    * line 1 is the raw text;
    * line 2 is the tokenized text, with sentences separated by ``<SENT>``
      and tokens separated by ``<SEP>`` or spaces;
    * every further line is one token annotation, parsed by ``_parse_line``.

    A label spelled ``root`` is normalised to ``ROOT``. A negative head is
    replaced by the token's own id.

    Returns a list of ``(raw_text, tokenized, ids, words, tags, heads,
    labels)`` tuples, where *tokenized* is a list of per-sentence word lists.
    """
    paragraphs = []
    for chunk in file_.read().strip().split('\n\n'):
        if not chunk.strip():
            continue
        rows = chunk.strip().split('\n')
        raw_text = rows[0].strip()
        tok_text = rows[1].strip()
        ids, words, tags, heads, labels = [], [], [], [], []
        for row in rows[2:]:
            id_, word, pos_string, head_idx, label = _parse_line(row)
            ids.append(id_)
            words.append(word)
            tags.append(pos_string)
            heads.append(id_ if head_idx < 0 else head_idx)
            labels.append('ROOT' if label == 'root' else label)
        tokenized = [sent.replace('<SEP>', ' ').split(' ')
                     for sent in tok_text.split('<SENT>')]
        paragraphs.append((raw_text, tokenized, ids, words, tags, heads, labels))
    return paragraphs
2015-01-09 20:53:26 +03:00
def _map_indices_to_tokens(ids, heads):
mapped = []
for head in heads:
if head not in ids:
mapped.append(None)
else:
mapped.append(ids.index(head))
return mapped
2015-01-09 20:53:26 +03:00
def _parse_line(line):
pieces = line.split()
if len(pieces) == 4:
return 0, pieces[0], pieces[1], int(pieces[2]) - 1, pieces[3]
2015-01-09 20:53:26 +03:00
else:
id_ = int(pieces[0])
2015-01-09 20:53:26 +03:00
word = pieces[1]
pos = pieces[3]
head_idx = int(pieces[6])
2015-01-09 20:53:26 +03:00
label = pieces[7]
return id_, word, pos, head_idx, label
2015-01-09 20:53:26 +03:00
loss = 0
def _align_annotations_to_non_gold_tokens(tokens, words, annot):
global loss
tags = []
heads = []
labels = []
orig_words = list(words)
missed = []
for token in tokens:
while annot and token.idx > annot[0][0]:
miss_id, miss_tag, miss_head, miss_label = annot.pop(0)
miss_w = words.pop(0)
if not is_punct_label(miss_label):
missed.append(miss_w)
loss += 1
if not annot:
tags.append(None)
heads.append(None)
labels.append(None)
continue
id_, tag, head, label = annot[0]
if token.idx == id_:
tags.append(tag)
heads.append(head)
labels.append(label)
annot.pop(0)
words.pop(0)
elif token.idx < id_:
tags.append(None)
heads.append(None)
labels.append(None)
else:
raise StandardError
#if missed:
# print orig_words
# print missed
# for t in tokens:
# print t.idx, t.orth_
return loss, tags, heads, labels
def iter_data(paragraphs, tokenizer, gold_preproc=False):
    """Yield ``(tokens, tags, heads, labels)`` training examples.

    *paragraphs* are ``(raw, tokenized, ids, words, tags, heads, labels)``
    tuples, as built by ``read_docparse_gold``.

    * Default: each paragraph's raw text is tokenized with *tokenizer*, the
      gold annotations are aligned onto those tokens, and heads are
      converted to token positions (``None`` where unknown).
    * With *gold_preproc*: each gold sentence in *tokenized* becomes a
      ``Tokens`` via ``tokenizer.tokens_from_list``. The paragraph-level
      annotations are sliced sentence by sentence, and heads that point
      outside the sentence become ``None``.
    """
    for raw, tokenized, ids, words, tags, heads, labels in paragraphs:
        if not gold_preproc:
            tokens = tokenizer(raw)
            # list(): the aligner indexes its annotations; Py3 zip is lazy.
            _, tok_tags, tok_heads, tok_labels = _align_annotations_to_non_gold_tokens(
                tokens, list(words), list(zip(ids, tags, heads, labels)))
            tok_ids = [t.idx for t in tokens]
            yield tokens, tok_tags, _map_indices_to_tokens(tok_ids, tok_heads), tok_labels
        else:
            assert len(words) == len(heads)
            # Walk a running offset rather than re-slicing the remaining
            # lists after every sentence (which copied O(n) per sentence).
            start = 0
            for sent_words in tokenized:
                end = start + len(sent_words)
                sent_heads = _map_indices_to_tokens(ids[start:end], heads[start:end])
                tokens = tokenizer.tokens_from_list(sent_words)
                yield tokens, tags[start:end], sent_heads, labels[start:end]
                start = end
2015-01-09 20:53:26 +03:00
def get_labels(sents):
    """Collect the dependency labels seen on leftward and rightward arcs.

    *sents* are 7-tuples whose last two fields are ``heads`` and ``labels``.
    A token's label goes into the left set when its head is greater than the
    token's position, and into the right set when the head is smaller. A
    token whose head equals its own position (the root) is skipped. Heads are
    compared with the token's position in the list.

    NOTE(review): this assumes heads are token indices. The docparse reader
    stores id-valued heads; confirm before relying on this.

    Returns ``(sorted_left_labels, sorted_right_labels)`` as lists.
    """
    left = set()
    right = set()
    for _raw, _tokenized, _ids, _words, _tags, heads, labels in sents:
        for position, (head, label) in enumerate(zip(heads, labels)):
            if head > position:
                left.add(label)
            elif head < position:
                right.add(label)
    return sorted(left), sorted(right)
def train(Language, train_loc, model_dir, n_iter=15, feat_set=u'basic', seed=0,
          gold_preproc=False, force_gold=False, n_sents=0):
    """Train the tagger and entity recogniser from a docparse file.

    Steps:

    1. Delete and recreate the ``deps``, ``pos`` and ``ner`` sub-directories
       of *model_dir*, and set up the POS model directory.
    2. Write the dependency and NER configs, with labels taken from the whole
       training file.
    3. Run *n_iter* passes over the data (at most *n_sents* paragraphs when
       *n_sents* > 0), printing UAS, NER F and tag accuracy after each pass.
    4. Finish training of the parser, entity and tagger models.

    Parser training is currently disabled (see the note in the loop). The
    parser is still run, so its UAS is reported.
    """
    dep_model_dir, pos_model_dir, ner_model_dir = [
        path.join(model_dir, sub) for sub in ('deps', 'pos', 'ner')]
    model_dirs = (dep_model_dir, pos_model_dir, ner_model_dir)
    # Start from empty directories so no stale model files survive.
    for model_subdir in model_dirs:
        if path.exists(model_subdir):
            shutil.rmtree(model_subdir)
    for model_subdir in model_dirs:
        os.mkdir(model_subdir)
    setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES, pos_model_dir)

    gold_tuples = read_docparse_file(train_loc)
    # Label inventories are computed before the n_sents truncation below.
    Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
                 labels=Language.ParserTransitionSystem.get_labels(gold_tuples))
    Config.write(ner_model_dir, 'config', features='ner', seed=seed,
                 labels=Language.EntityTransitionSystem.get_labels(gold_tuples))
    if n_sents > 0:
        gold_tuples = gold_tuples[:n_sents]

    nlp = Language()
    # Invert the entity label -> id mapping into an id-indexed tuple.
    label_ids = nlp.entity.moves.label_ids
    ent_names = [None] * (max(label_ids.values()) + 1)
    for label, label_id in label_ids.items():
        ent_names[label_id] = label
    ent_names = tuple(ent_names)

    print("Itn.\tUAS\tNER F.\tTag %")
    for itn in range(n_iter):
        scorer = Scorer()
        for raw_text, segmented_text, annot_tuples in gold_tuples:
            if gold_preproc:
                sents = [nlp.tokenizer.tokens_from_list(s) for s in segmented_text]
            else:
                sents = [nlp.tokenizer(raw_text)]
            for tokens in sents:
                gold = GoldParse(tokens, annot_tuples)
                nlp.tagger(tokens)
                nlp.entity.train(tokens, gold, force_gold=force_gold)
                # NOTE: parser training is disabled here:
                # nlp.parser.train(tokens, gold, force_gold=force_gold)
                nlp.tagger.train(tokens, gold.tags)
                nlp.entity(tokens)
                tokens._ent_strings = ent_names
                nlp.parser(tokens)
                scorer.score(tokens, gold, verbose=False)
        print('%d:\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.ents_f, scorer.tags_acc))
        random.shuffle(gold_tuples)

    nlp.parser.model.end_training()
    nlp.entity.model.end_training()
    nlp.tagger.model.end_training()
def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
    """Run the full pipeline over a dev file and return the filled Scorer.

    Each paragraph's raw text goes through ``nlp(raw_text)`` and is scored
    against its gold annotations. Gold pre-processing is not supported, so
    *gold_preproc* must be false.

    *model_dir* is not used here. NOTE(review): presumably ``Language()``
    loads its models from its own default location; confirm.
    """
    assert not gold_preproc
    nlp = Language()
    gold_tuples = read_docparse_file(dev_loc)
    scorer = Scorer()
    for raw_text, _segmented_text, annot_tuples in gold_tuples:
        tokens = nlp(raw_text)
        scorer.score(tokens, GoldParse(tokens, annot_tuples), verbose=False)
    return scorer
2015-01-09 20:53:26 +03:00
2015-02-23 22:05:04 +03:00
@plac.annotations(
    train_loc=("Training file location",),
    dev_loc=("Dev. file location",),
    model_dir=("Location of output model directory",),
    n_sents=("Number of training sentences", "option", "n", int)
)
def main(train_loc, dev_loc, model_dir, n_sents=0):
    """Train English models into *model_dir*, then score them on *dev_loc*.

    Prints POS accuracy, UAS, LAS and NER precision/recall/F, one per line.
    """
    train(English, train_loc, model_dir,
          gold_preproc=False, force_gold=False, n_sents=n_sents)
    scorer = evaluate(English, dev_loc, model_dir, gold_preproc=False)
    # (printed name, Scorer attribute). Each attribute is read just before
    # its line is printed.
    report = (
        ('POS', 'tags_acc'),
        ('UAS', 'uas'),
        ('LAS', 'las'),
        ('NER P', 'ents_p'),
        ('NER R', 'ents_r'),
        ('NER F', 'ents_f'),
    )
    for name, attr in report:
        print('%s %s' % (name, getattr(scorer, attr)))
2015-01-09 20:53:26 +03:00
# Script entry point: plac parses the command line according to main's
# @plac.annotations and calls main with the resulting arguments.
if __name__ == '__main__':
    plac.call(main)