import plac
import json
from os import path
import shutil
import os
import random

from spacy.syntax.util import Config
from spacy.gold import GoldParse
from spacy.tokenizer import Tokenizer
from spacy.vocab import Vocab
from spacy.tagger import Tagger
from spacy.syntax.parser import Parser
from spacy.syntax.arc_eager import ArcEager
from spacy.syntax.parser import get_templates
from spacy.scorer import Scorer

from spacy.language import Language

from spacy.tagger import W_orth

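# A deliberately tiny tagger feature set: each token is tagged from its
# surface form (orth) alone.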
TAGGER_TEMPLATES = (
    (W_orth,),
)

try:
    from codecs import open
except ImportError:
    pass


class TreebankParser(object):
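    """A minimal pipeline that trains a spaCy tagger and dependency parser
    directly from treebank annotations, without a pre-built language model."""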
    @staticmethod
    def setup_model_dir(model_dir, labels, templates, feat_set='basic', seed=0):
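        """Create fresh 'deps', 'pos' and 'vocab' directories under model_dir,
        removing any existing models. `templates` is currently unused; the
        parser feature set is selected by name via `feat_set`."""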
        dep_model_dir = path.join(model_dir, 'deps')
        pos_model_dir = path.join(model_dir, 'pos')
        vocab_dir = path.join(model_dir, 'vocab')
        if path.exists(dep_model_dir):
            shutil.rmtree(dep_model_dir)
        if path.exists(pos_model_dir):
            shutil.rmtree(pos_model_dir)
        os.mkdir(dep_model_dir)
        os.mkdir(pos_model_dir)
        # end_training() dumps the string table to vocab/strings.txt, so the
        # directory must exist before training finishes.
        if not path.exists(vocab_dir):
            os.mkdir(vocab_dir)

        Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
                     labels=labels)

    @classmethod
    def from_dir(cls, tag_map, model_dir):
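        """Build a blank TreebankParser: a tag-map-driven Vocab, a tokenizer
        with no rules (the class only tokenizes from pre-split word lists),
        a blank tagger and a parser configured by setup_model_dir."""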
        vocab = Vocab(tag_map=tag_map, get_lex_attr=Language.default_lex_attrs())
        tokenizer = Tokenizer(vocab, {}, None, None, None)
        tagger = Tagger.blank(vocab, TAGGER_TEMPLATES)

        parser = Parser.from_dir(path.join(model_dir, 'deps'), vocab.strings, ArcEager)
        return cls(vocab, tokenizer, tagger, parser)

    def __init__(self, vocab, tokenizer, tagger, parser):
        self.vocab = vocab
        self.tokenizer = tokenizer
        self.tagger = tagger
        self.parser = parser

    def train(self, words, tags, heads, deps):
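        """Update the tagger and the parser from a single gold sentence."""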
        tokens = self.tokenizer.tokens_from_list(list(words))
        self.tagger.train(tokens, tags)
        
        tokens = self.tokenizer.tokens_from_list(list(words))
        ids = list(range(len(words)))
        ner = ['O'] * len(words)
        gold = GoldParse(tokens, (ids, words, tags, heads, deps, ner),
                         make_projective=False)
        self.tagger(tokens)
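        # The arc-eager transition system can only learn projective trees,
        # so non-projective sentences are skipped during training.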
        if gold.is_projective:
            try:
                self.parser.train(tokens, gold)
            except Exception:
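                # Print the offending sentence before re-raising, to help
                # track down bad annotations.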
                for id_, word, head, dep in zip(ids, words, heads, deps):
                    print(id_, word, head, dep)
                raise

    def __call__(self, words, tags=None):
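        """Annotate a pre-tokenized sentence: tag it (with predicted tags, or
        gold tags if supplied) and parse it, returning the annotated tokens."""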
        tokens = self.tokenizer.tokens_from_list(list(words))
        if tags is None:
            self.tagger(tokens)
        else:
            self.tagger.tag_from_strings(tokens, tags)
        self.parser(tokens)
        return tokens

    def end_training(self, data_dir):
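        """Finalise the model weights and write the parser, tagger and
        string table to disk."""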
        self.parser.model.end_training(path.join(data_dir, 'deps', 'model'))
        self.tagger.model.end_training(path.join(data_dir, 'pos', 'model'))
        self.vocab.strings.dump(path.join(data_dir, 'vocab', 'strings.txt'))
 

def read_conllx(loc):
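    """Read a CoNLL-X/CoNLL-U file, yielding sentences in the
    (text, [(annot_tuples, brackets)]) shape that GoldParse expects."""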
    with open(loc, 'r', 'utf8') as file_:
        text = file_.read()
    for sent in text.strip().split('\n\n'):
        lines = sent.strip().split('\n')
        if lines:
            while lines and lines[0].startswith('#'):
                lines.pop(0)
            tokens = []
            for line in lines:
                id_, word, lemma, pos, tag, morph, head, dep, _1, _2 = line.split('\t')
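                # Skip multi-word token ranges (ids like '4-5'); only the
                # syntactic words carry head/dep annotation.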
                if '-' in id_:
                    continue
                id_ = int(id_) - 1
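                # spaCy marks the root by attaching the token to itself.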
                head = (int(head) - 1) if head != '0' else id_
                dep = 'ROOT' if dep == 'root' else dep
                tokens.append((id_, word, tag, head, dep, 'O'))
            tuples = list(zip(*tokens))
            yield (None, [(tuples, [])])


def score_model(nlp, gold_docs, verbose=False):
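    """Re-annotate the gold documents (from their gold words, with gold tags
    supplied) and return a Scorer with attachment and tagging accuracies."""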
    scorer = Scorer()
    for _, gold_doc in gold_docs:
        for annot_tuples, _ in gold_doc:
            tokens = nlp(list(annot_tuples[1]), tags=list(annot_tuples[2]))
            gold = GoldParse(tokens, annot_tuples)
            scorer.score(tokens, gold, verbose=verbose)
    return scorer


def main(train_loc, dev_loc, model_dir, tag_map_loc):
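    """Train the tagger and parser for 15 epochs, reporting dev-set accuracy
    after each epoch, then save the final model."""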
    with open(tag_map_loc) as file_:
        tag_map = json.loads(file_.read())
    train_sents = list(read_conllx(train_loc))
    labels = ArcEager.get_labels(train_sents)
    templates = get_templates('basic')

    TreebankParser.setup_model_dir(model_dir, labels, templates)
    
    nlp = TreebankParser.from_dir(tag_map, model_dir)

    for itn in range(15):
        # Shuffle before each epoch, so the first pass is not in corpus order.
        random.shuffle(train_sents)
        for _, doc_sents in train_sents:
            for (ids, words, tags, heads, deps, ner), _ in doc_sents:
                nlp.train(words, tags, heads, deps)
        scorer = score_model(nlp, read_conllx(dev_loc))
        print('%d:\t%.3f\t%.3f' % (itn, scorer.uas, scorer.tags_acc))
    nlp.end_training(model_dir)
    scorer = score_model(nlp, read_conllx(dev_loc))
    print('Final:\t%.3f\t%.3f\t%.3f' % (scorer.uas, scorer.las, scorer.tags_acc))
 

if __name__ == '__main__':
    plac.call(main)
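# Example invocation (file and directory names are illustrative):
#   python train_treebank.py train.conllu dev.conllu /tmp/ud_model tag_map.json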