spaCy/bin/parser/train_ud.py

152 lines
5.0 KiB
Python

import plac
import json
from os import path
import shutil
import os
import random
from spacy.syntax.util import Config
from spacy.gold import GoldParse
from spacy.tokenizer import Tokenizer
from spacy.vocab import Vocab
from spacy.tagger import Tagger
from spacy.syntax.parser import Parser
from spacy.syntax.arc_eager import ArcEager
from spacy.syntax.parser import get_templates
from spacy.scorer import Scorer
from spacy.language import Language
from spacy.tagger import W_orth
# Feature templates handed to Tagger.blank: a single template made of the
# W_orth feature only.
TAGGER_TEMPLATES = ((W_orth,),)
try:
from codecs import open
except ImportError:
pass
class TreebankParser(object):
    """Tagging and dependency-parsing pipeline trained from treebank sentences.

    Holds a vocabulary, a tokenizer, a tagger and a parser. ``train`` performs
    one update from a single annotated sentence, ``__call__`` annotates a
    pre-tokenized sentence, and ``end_training`` writes the models to disk.
    """

    @staticmethod
    def setup_model_dir(model_dir, labels, templates, feat_set='basic', seed=0):
        """Create empty ``deps`` and ``pos`` directories and write the parser config.

        Existing ``deps`` / ``pos`` directories under ``model_dir`` are deleted
        first, so any previously stored content in them is lost.

        Args:
            model_dir: Root directory of the model. It is created if missing.
            labels: Forwarded to ``Config.write`` as ``labels``.
            templates: Accepted for interface compatibility; not used here.
            feat_set: Forwarded to ``Config.write`` as ``features``.
            seed: Forwarded to ``Config.write`` as ``seed``.
        """
        dep_model_dir = path.join(model_dir, 'deps')
        pos_model_dir = path.join(model_dir, 'pos')
        for sub_dir in (dep_model_dir, pos_model_dir):
            if path.exists(sub_dir):
                shutil.rmtree(sub_dir)
            # makedirs rather than mkdir: also creates model_dir itself when
            # the caller passes a directory that does not exist yet.
            os.makedirs(sub_dir)
        Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
                     labels=labels)

    @classmethod
    def from_dir(cls, tag_map, model_dir):
        """Build a pipeline with a blank tagger and a parser loaded from ``model_dir``.

        Args:
            tag_map: Forwarded to ``Vocab`` as ``tag_map``.
            model_dir: Root model directory; the parser is loaded from its
                ``deps`` sub-directory.

        Returns:
            A new instance of ``cls``.
        """
        vocab = Vocab(tag_map=tag_map, get_lex_attr=Language.default_lex_attrs())
        tokenizer = Tokenizer(vocab, {}, None, None, None)
        tagger = Tagger.blank(vocab, TAGGER_TEMPLATES)
        # The returned config is not used here; the call is kept so its
        # behaviour on a missing config is unchanged.
        # NOTE(review): confirm whether Parser.from_dir already reads it.
        Config.read(path.join(model_dir, 'deps'), 'config')
        parser = Parser.from_dir(path.join(model_dir, 'deps'), vocab.strings, ArcEager)
        return cls(vocab, tokenizer, tagger, parser)

    def __init__(self, vocab, tokenizer, tagger, parser):
        """Store the pipeline components."""
        self.vocab = vocab
        self.tokenizer = tokenizer
        self.tagger = tagger
        self.parser = parser

    def train(self, words, tags, heads, deps):
        """Update the tagger and the parser from one annotated sentence.

        The parser is only updated when the gold parse reports itself as
        projective; non-projective sentences still update the tagger.

        Args:
            words: Sequence of token strings.
            tags: Gold tags, parallel to ``words``.
            heads: Gold head indices, parallel to ``words``.
            deps: Gold dependency labels, parallel to ``words``.
        """
        tokens = self.tokenizer.tokens_from_list(list(words))
        self.tagger.train(tokens, tags)
        # A fresh token sequence is built and tagged with the tagger's own
        # predictions before the parser update (presumably so the parser
        # trains on predicted rather than gold tags -- confirm).
        tokens = self.tokenizer.tokens_from_list(list(words))
        # A real list (not a lazy range object) so it behaves the same on
        # Python 2 and Python 3.
        ids = list(range(len(words)))
        ner = ['O'] * len(words)
        gold = GoldParse(tokens, (ids, words, tags, heads, deps, ner),
                         make_projective=False)
        self.tagger(tokens)
        if gold.is_projective:
            try:
                self.parser.train(tokens, gold)
            except Exception:
                # Dump the offending sentence to help debugging, then re-raise.
                for id_, word, head, dep in zip(ids, words, heads, deps):
                    print(id_, word, head, dep)
                raise

    def __call__(self, words, tags=None):
        """Tag and parse a pre-tokenized sentence.

        Args:
            words: Sequence of token strings.
            tags: Optional tags to assign instead of running the tagger.

        Returns:
            The token sequence after tagging and parsing.
        """
        tokens = self.tokenizer.tokens_from_list(list(words))
        if tags is None:
            self.tagger(tokens)
        else:
            self.tagger.tag_from_strings(tokens, tags)
        self.parser(tokens)
        return tokens

    def end_training(self, data_dir):
        """Finish training and write the models and string store under ``data_dir``.

        Writes to ``deps/model``, ``pos/model`` and ``vocab/strings.txt``.

        Args:
            data_dir: Root model directory.
        """
        self.parser.model.end_training(path.join(data_dir, 'deps', 'model'))
        self.tagger.model.end_training(path.join(data_dir, 'pos', 'model'))
        # setup_model_dir only creates 'deps' and 'pos'; make sure 'vocab'
        # exists so the dump does not fail on a fresh model directory.
        vocab_dir = path.join(data_dir, 'vocab')
        if not path.exists(vocab_dir):
            os.makedirs(vocab_dir)
        self.vocab.strings.dump(path.join(vocab_dir, 'strings.txt'))
def read_conllx(loc):
    """Read a CoNLL-X / CoNLL-U file and yield one document per sentence.

    Sentences are separated by blank lines. Within a sentence, every line
    starting with ``#`` is treated as a comment and skipped, as are blank
    lines, multiword-token ranges (ids such as ``1-2``) and empty nodes (ids
    such as ``8.1``). Sentences without any token line are skipped.

    Token ids and heads are converted to 0-based indices; a head of ``0``
    becomes the token's own index, and the label ``root`` is rewritten to
    ``ROOT``.

    Args:
        loc: Path of the UTF-8 encoded input file.

    Yields:
        ``(None, [(tuples, [])])`` where ``tuples`` is a list of six parallel
        tuples: ids, words, tags (fifth column), heads, dependency labels and
        NER tags (always ``'O'``).

    Raises:
        ValueError: If a token line does not have exactly ten fields.
    """
    # Imported locally so the reader works whether or not the module-level
    # name ``open`` has been rebound to ``codecs.open``.
    import io

    with io.open(loc, 'r', encoding='utf8') as file_:
        text = file_.read()
    for sent in text.strip().split('\n\n'):
        tokens = []
        for line in sent.split('\n'):
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            # Fields are tab-separated in CoNLL-U and may contain spaces;
            # fall back to any-whitespace splitting for files without tabs.
            fields = line.split('\t') if '\t' in line else line.split()
            if len(fields) != 10:
                raise ValueError(
                    'Expected 10 fields, found %d: %r' % (len(fields), line))
            id_, word, lemma, pos, tag, morph, head, dep = fields[:8]
            if '-' in id_ or '.' in id_:
                continue
            id_ = int(id_) - 1
            head = (int(head) - 1) if head != '0' else id_
            dep = 'ROOT' if dep == 'root' else dep
            tokens.append((id_, word, tag, head, dep, 'O'))
        if not tokens:
            continue
        # Materialize the columns: on Python 3 a bare zip() is a one-shot
        # iterator that cannot be indexed or iterated over several epochs.
        tuples = list(zip(*tokens))
        yield (None, [(tuples, [])])
def score_model(nlp, gold_docs, verbose=False):
    """Run ``nlp`` over every gold sentence and collect the results in a Scorer.

    Each sentence is processed from its gold words (index 1 of the annotation
    tuples) together with its gold tags (index 2), and the output is scored
    against a GoldParse built from the same annotations.

    Args:
        nlp: Callable taking a word list and a ``tags`` keyword.
        gold_docs: Iterable of ``(_, sentences)`` pairs, where ``sentences``
            is an iterable of ``(annot_tuples, _)`` pairs.
        verbose: Forwarded to ``Scorer.score``.

    Returns:
        The Scorer holding the accumulated scores.
    """
    result = Scorer()
    for _, sentences in gold_docs:
        for annotations, _ in sentences:
            words = list(annotations[1])
            tags = list(annotations[2])
            parsed = nlp(words, tags=tags)
            reference = GoldParse(parsed, annotations)
            result.score(parsed, reference, verbose=verbose)
    return result
def main(train_loc, dev_loc, model_dir, tag_map_loc, n_iter=15):
    """Train a tagger and parser on a treebank and report dev-set scores.

    After every epoch the training sentences are shuffled and a line with the
    epoch index, ``scorer.uas`` and ``scorer.tags_acc`` on the dev set is
    printed. After the last epoch the models are written to ``model_dir`` and
    a final line that also includes ``scorer.las`` is printed.

    Args:
        train_loc: Path of the training file, read with ``read_conllx``.
        dev_loc: Path of the development file, read with ``read_conllx``.
        model_dir: Directory in which the model is set up and saved.
        tag_map_loc: Path of a JSON file containing the tag map.
        n_iter: Number of training epochs (default 15).

    Raises:
        ValueError: If ``n_iter`` is smaller than 1.
    """
    # Values coming from the command line arrive as strings, so coerce.
    n_iter = int(n_iter)
    if n_iter < 1:
        raise ValueError('n_iter must be at least 1, got %d' % n_iter)
    with open(tag_map_loc) as file_:
        tag_map = json.load(file_)
    train_sents = list(read_conllx(train_loc))
    labels = ArcEager.get_labels(train_sents)
    templates = get_templates('basic')
    TreebankParser.setup_model_dir(model_dir, labels, templates)
    nlp = TreebankParser.from_dir(tag_map, model_dir)
    for itn in range(n_iter):
        for _, doc_sents in train_sents:
            for (ids, words, tags, heads, deps, ner), _ in doc_sents:
                nlp.train(words, tags, heads, deps)
        random.shuffle(train_sents)
        scorer = score_model(nlp, read_conllx(dev_loc))
        print('%d:\t%.3f\t%.3f' % (itn, scorer.uas, scorer.tags_acc))
    nlp.end_training(model_dir)
    # Score once more after end_training has finalized the models.
    scorer = score_model(nlp, read_conllx(dev_loc))
    print('%d:\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.las, scorer.tags_acc))
# Script entry point: main's parameters are supplied from the command line
# through plac.
if __name__ == '__main__':
    plac.call(main)