#!/usr/bin/env python
"""Train and evaluate spaCy's part-of-speech tagger and dependency parser
on CoNLL-formatted data."""
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os
from os import path
import shutil
import io
import random
import time
import gzip
import cProfile
import pstats

import plac

import spacy.util
from spacy.en import English
from spacy.gold import GoldParse
from spacy.syntax.util import Config
from spacy.syntax.arc_eager import ArcEager
from spacy.syntax.parser import Parser
from spacy.scorer import Scorer
from spacy.tagger import Tagger

# Last updated for spaCy v0.97


def read_conll(file_):
    """Read a standard CoNLL/MALT-style format."""
    sents = []
    # Sentences are separated by blank lines, one token per line.
    for sent_str in file_.read().strip().split('\n\n'):
        ids = []
        words = []
        heads = []
        labels = []
        tags = []
        for i, line in enumerate(sent_str.split('\n')):
            word, pos_string, head_idx, label = _parse_line(line)
            words.append(word)
            # Tokens with a negative head index are attached to themselves.
            if head_idx < 0:
                head_idx = i
            ids.append(i)
            heads.append(head_idx)
            labels.append(label)
            tags.append(pos_string)
        text = ' '.join(words)
        annot = (ids, words, tags, heads, labels, ['O'] * len(ids))
        sents.append((None, [(annot, [])]))
    return sents


def _parse_line(line):
    pieces = line.split()
    if len(pieces) == 4:
        # MALT-style: word, POS tag, head index, dependency label.
        word, pos, head_idx, label = pieces
        head_idx = int(head_idx)
    elif len(pieces) == 15:
        # CoNLL-09 columns: FORM=1, POS=4, HEAD=8, DEPREL=10.
        id_ = int(pieces[0].split('_')[-1])
        word = pieces[1]
        pos = pieces[4]
        head_idx = int(pieces[8]) - 1
        label = pieces[10]
    else:
        # Assume CoNLL-X-style columns: FORM=1, POS=4, HEAD=6, DEPREL=7.
        id_ = int(pieces[0].split('_')[-1])
        word = pieces[1]
        pos = pieces[4]
        head_idx = int(pieces[6]) - 1
        label = pieces[7]
    if head_idx == 0:
        label = 'ROOT'
    return word, pos, head_idx, label


def score_model(scorer, nlp, raw_text, annot_tuples, verbose=False):
    """Tag and parse a gold-tokenized sentence with the current model and
    score it against the gold annotations, ignoring punctuation labels."""
    tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
    nlp.tagger(tokens)
    nlp.parser(tokens)
    gold = GoldParse(tokens, annot_tuples, make_projective=False)
    scorer.score(tokens, gold, verbose=verbose,
                 punct_labels=('--', 'p', 'punct'))


def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
          seed=0, gold_preproc=False, force_gold=False):
    dep_model_dir = path.join(model_dir, 'deps')
    pos_model_dir = path.join(model_dir, 'pos')
    # Start from empty model directories.
    if path.exists(dep_model_dir):
        shutil.rmtree(dep_model_dir)
    if path.exists(pos_model_dir):
        shutil.rmtree(pos_model_dir)
    os.mkdir(dep_model_dir)
    os.mkdir(pos_model_dir)

    Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
                 labels=ArcEager.get_labels(gold_tuples))

    nlp = Language(data_dir=model_dir, tagger=False, parser=False, entity=False)
    nlp.tagger = Tagger.blank(nlp.vocab, Tagger.default_templates())
    nlp.parser = Parser.from_dir(dep_model_dir, nlp.vocab.strings, ArcEager)

    print("Itn.\tP.Loss\tUAS\tTag %\tToken %")
    for itn in range(n_iter):
        scorer = Scorer()
        loss = 0
        for _, sents in gold_tuples:
            for annot_tuples, _ in sents:
                # Skip single-token sentences.
                if len(annot_tuples[1]) == 1:
                    continue
                # Score with the current weights before updating on the example.
                score_model(scorer, nlp, None, annot_tuples, verbose=False)
                tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
                nlp.tagger(tokens)
                gold = GoldParse(tokens, annot_tuples, make_projective=True)
                if not gold.is_projective:
                    raise Exception(
                        "Non-projective sentence in training, after we should "
                        "have enforced projectivity: %s" % annot_tuples
                    )
                loss += nlp.parser.train(tokens, gold)
                nlp.tagger.train(tokens, gold.tags)
        random.shuffle(gold_tuples)
        print('%d:\t%d\t%.3f\t%.3f\t%.3f' % (itn, loss, scorer.uas,
                                             scorer.tags_acc, scorer.token_acc))
    print('end training')
    nlp.end_training(model_dir)
    print('done')


@plac.annotations(
    train_loc=("Location of CoNLL 09 formatted training file"),
    dev_loc=("Location of CoNLL 09 formatted development file"),
    model_dir=("Location of output model directory"),
    eval_only=("Skip training, and only evaluate", "flag", "e", bool),
    n_iter=("Number of training iterations", "option", "i", int),
)
def main(train_loc, dev_loc, model_dir, n_iter=15, eval_only=False):
    with io.open(train_loc, 'r', encoding='utf8') as file_:
        train_sents = read_conll(file_)
    if not eval_only:
        train(English, train_sents, model_dir, n_iter=n_iter)
    nlp = English(data_dir=model_dir)
    with io.open(dev_loc, 'r', encoding='utf8') as file_:
        dev_sents = read_conll(file_)
    scorer = Scorer()
    for _, sents in dev_sents:
        for annot_tuples, _ in sents:
            score_model(scorer, nlp, None, annot_tuples)
    print('TOK', 100 - scorer.token_acc)
    print('POS', scorer.tags_acc)
    print('UAS', scorer.uas)
    print('LAS', scorer.las)


if __name__ == '__main__':
    plac.call(main)
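

# Example invocation (a sketch: the script name and all paths below are
# placeholders, and this assumes spaCy v0.97 and plac are installed and the
# CoNLL files exist):
#
#     python conll_train.py train.conll dev.conll /tmp/parser_model -i 15
#
# Passing the -e flag skips training and only evaluates the model already
# saved in the model directory:
#
#     python conll_train.py train.conll dev.conll /tmp/parser_model -e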