"""Train for CONLL 2017 UD treebank evaluation. Takes .conllu files, writes
.conllu format for development data, allowing the official scorer to be used.
"""
from __future__ import unicode_literals
import plac
import attr
from pathlib import Path
import re
import json
import tqdm

import spacy
import spacy.util
from spacy.tokens import Token, Doc
from spacy.gold import Example
from spacy.pipeline._parser_internals.nonproj import projectivize
from collections import defaultdict
from spacy.matcher import Matcher

import itertools
import random
import numpy.random

from bin.ud import conll17_ud_eval

import spacy.lang.zh
import spacy.lang.ja

# Switch off the use_jieba / use_janome flags on the Chinese and Japanese
# language defaults (presumably so the external word segmenters are not used
# when tokenizing -- TODO confirm against the spaCy language classes).
spacy.lang.zh.Chinese.Defaults.use_jieba = False
spacy.lang.ja.Japanese.Defaults.use_janome = False

# Seed both the stdlib and the NumPy random number generators with a fixed
# value at import time.
random.seed(0)
numpy.random.seed(0)


################
# Data reading #
################

# Matches any run of whitespace. Written as a raw string so that "\s" reaches
# the regex engine untouched instead of being an invalid string escape.
space_re = re.compile(r"\s+")


def split_text(text):
    """Split raw text into paragraphs with normalized whitespace.

    Paragraphs are separated by a blank line ("\\n\\n"). Each paragraph is
    stripped, and every internal run of whitespace (including single
    newlines) is collapsed to one space.

    text (str): The raw text.
    RETURNS (list): One normalized string per paragraph. An empty input
        yields a list containing a single empty string.
    """
    return [space_re.sub(" ", par.strip()) for par in text.split("\n\n")]


def read_data(
    nlp,
    conllu_file,
    text_file,
    raw_text=True,
    oracle_segments=False,
    max_doc_length=None,
    limit=None,
):
    """Convert CoNLL-U annotations into a list of Example objects.

    With raw_text=True, sentences are grouped (up to max_doc_length sentences
    per group, or one group per CoNLL-U document when max_doc_length is
    falsy), their text is rebuilt from the gold words and spaces, and the Doc
    is produced by nlp.make_doc. With oracle_segments=True, every sentence
    additionally yields a Doc built directly from its gold words and spaces.
    At least one of the two flags must be set, otherwise ValueError is raised.

    Reading stops early once `limit` (if truthy) Docs have been collected.
    """
    if not (raw_text or oracle_segments):
        raise ValueError("At least one of raw_text or oracle_segments must be True")
    paragraphs = split_text(text_file.read())
    conllu = read_conllu(conllu_file)
    docs = []
    golds = []
    # The paragraph text itself is not consumed below; zipping only pairs each
    # CoNLL-U document with a paragraph and stops at the shorter sequence.
    for _paragraph, conllu_doc in zip(paragraphs, conllu):
        pending = []
        for conllu_sent in conllu_doc:
            annot = defaultdict(list)
            for row in conllu_sent:
                (
                    tok_id,
                    form,
                    _lemma,
                    _upos,
                    xpos,
                    _feats,
                    head_id,
                    deprel,
                    _deps,
                    misc,
                ) = row
                # Skip IDs containing "." or "-" (non-integer token IDs).
                if "." in tok_id or "-" in tok_id:
                    continue
                index = int(tok_id) - 1
                # A head of "0" marks the root, which is stored as pointing
                # at itself; other heads become 0-based sentence indices.
                if head_id == "0":
                    head_index = index
                else:
                    head_index = int(head_id) - 1
                annot["words"].append(form)
                annot["tags"].append(xpos)
                annot["heads"].append(head_index)
                annot["deps"].append("ROOT" if deprel == "root" else deprel)
                # The token counts as followed by a space when the last
                # column is exactly "_".
                annot["spaces"].append(misc == "_")
            annot["entities"] = ["-"] * len(annot["words"])
            annot["heads"], annot["deps"] = projectivize(
                annot["heads"], annot["deps"]
            )
            if oracle_segments:
                gold_doc = Doc(nlp.vocab, words=annot["words"], spaces=annot["spaces"])
                docs.append(gold_doc)
                golds.append(annot)

            pending.append(annot)
            if raw_text and max_doc_length and len(pending) >= max_doc_length:
                # Enough sentences collected: emit them as one document.
                doc, gold = _make_gold(nlp, None, pending)
                pending = []
                docs.append(doc)
                golds.append(gold)
                if limit and len(docs) >= limit:
                    return golds_to_gold_data(docs, golds)

        # Emit whatever sentences remain for this CoNLL-U document.
        if raw_text and pending:
            doc, gold = _make_gold(nlp, None, pending)
            docs.append(doc)
            golds.append(gold)
        if limit and len(docs) >= limit:
            return golds_to_gold_data(docs, golds)
    return golds_to_gold_data(docs, golds)


def read_conllu(file_):
    """Parse CoNLL-U formatted lines into nested lists.

    file_: An iterable of text lines, e.g. an open file.
    RETURNS (list): A list of documents. Each document is a list of
        sentences, each sentence a list of tokens, and each token the list
        of its 10 tab-separated column strings.
    RAISES (ValueError): If a token line does not have exactly 10 columns.
        The message carries the line number and the offending line.
    """
    docs = []
    doc = []
    sent = []
    for line_number, line in enumerate(file_, start=1):
        if line.startswith("# newdoc"):
            # A sentence that was not closed by a blank line belongs to the
            # document that is ending, not to the one that is starting.
            if sent:
                doc.append(sent)
                sent = []
            if doc:
                docs.append(doc)
            doc = []
        elif line.startswith("#"):
            # Any other comment line carries no token data.
            continue
        elif not line.strip():
            # A blank line ends the current sentence.
            if sent:
                doc.append(sent)
            sent = []
        else:
            columns = line.strip().split("\t")
            if len(columns) != 10:
                raise ValueError(
                    "Expected 10 tab-separated columns on line {n}, "
                    "found {k}: {line!r}".format(
                        n=line_number, k=len(columns), line=line
                    )
                )
            sent.append(columns)
    # Flush trailing content when the input does not end with a blank line.
    if sent:
        doc.append(sent)
    if doc:
        docs.append(doc)
    return docs


def _make_gold(nlp, text, sent_annots):
    # Flatten the conll annotations, and adjust the head indices
    gold = defaultdict(list)
    for sent in sent_annots:
        gold["heads"].extend(len(gold["words"]) + head for head in sent["heads"])
        for field in ["words", "tags", "deps", "entities", "spaces"]:
            gold[field].extend(sent[field])
    # Construct text if necessary
    assert len(gold["words"]) == len(gold["spaces"])
    if text is None:
        text = "".join(
            word + " " * space for word, space in zip(gold["words"], gold["spaces"])
        )
    doc = nlp.make_doc(text)
    gold.pop("spaces")
    return doc, gold


#############################
# Data transforms for spaCy #
#############################


def golds_to_gold_data(docs, golds):
    """Pair each Doc with its gold annotations as an Example.

    Produces the training data format used by begin_training: one
    Example.from_dict(doc, gold) per (doc, gold) pair, in input order.
    """
    return [Example.from_dict(doc, gold) for doc, gold in zip(docs, golds)]


##############
# Evaluation #
##############


def evaluate(nlp, text_loc, gold_loc, sys_loc, limit=None):
    """Parse the raw text, write the parses to disk and score them.

    The text at text_loc is split into paragraphs and run through nlp.pipe;
    the resulting Docs are written to sys_loc in CoNLL-U format. Both the
    gold file and the freshly written system file are then loaded with
    conll17_ud_eval.load_conllu and compared.

    RETURNS: The scores produced by conll17_ud_eval.evaluate.

    Note: `limit` is accepted but not used in this function.
    """
    with text_loc.open("r", encoding="utf8") as text_file:
        raw = text_file.read()
    parsed_docs = list(nlp.pipe(split_text(raw)))
    with sys_loc.open("w", encoding="utf8") as out_file:
        write_conllu(parsed_docs, out_file)
    with gold_loc.open("r", encoding="utf8") as gold_file:
        gold_ud = conll17_ud_eval.load_conllu(gold_file)
        with sys_loc.open("r", encoding="utf8") as sys_file:
            sys_ud = conll17_ud_eval.load_conllu(sys_file)
        return conll17_ud_eval.evaluate(gold_ud, sys_ud)


def write_conllu(docs, file_):
    """Write parsed Docs to file_ in CoNLL-U format.

    Before writing, runs of tokens whose dependency label is "subtok" are
    merged. Each match is extended by one token past its end (presumably the
    token the subtokens attach to -- confirm against the parser's "subtok"
    action). Every Doc becomes a "# newdoc" section, every sentence gets
    "# sent_id" and "# text" comments, and each token's lines come from its
    get_conllu_lines extension.
    """
    subtok_matcher = Matcher(docs[0].vocab)
    subtok_matcher.add("SUBTOK", None, [{"DEP": "subtok", "op": "+"}])
    for doc_index, doc in enumerate(docs):
        # Collect character offsets first: merging changes token indices,
        # while character offsets remain valid.
        char_ranges = []
        for _, start, end in subtok_matcher(doc):
            span = doc[start : end + 1]
            char_ranges.append((span.start_char, span.end_char))
        for start_char, end_char in char_ranges:
            doc.merge(start_char, end_char)
        file_.write("# newdoc id = {i}\n".format(i=doc_index))
        for sent_index, sent in enumerate(doc.sents):
            file_.write("# sent_id = {i}.{j}\n".format(i=doc_index, j=sent_index))
            file_.write("# text = {text}\n".format(text=sent.text))
            for token_index, token in enumerate(sent):
                file_.write(token._.get_conllu_lines(token_index) + "\n")
            file_.write("\n")


def print_progress(itn, losses, ud_scores):
    """Print one tab-separated row of training progress.

    itn (int): The epoch number; the column header is printed when it is 0.
    losses (dict): Loss values; the "parser" entry is shown (0.0 if absent).
    ud_scores (dict): Scores keyed by metric name; the f1 of "LAS", "UAS",
        "XPOS", "Sentences" and "Words" is shown as a percentage.
    """
    values = {
        "dep_loss": losses.get("parser", 0.0),
        # Looked up for completeness; the row template does not display it.
        "tag_loss": losses.get("tagger", 0.0),
    }
    metric_names = (
        ("words", "Words"),
        ("sents", "Sentences"),
        ("tags", "XPOS"),
        ("uas", "UAS"),
        ("las", "LAS"),
    )
    for key, metric in metric_names:
        values[key] = ud_scores[metric].f1 * 100
    if itn == 0:
        print("\t".join(["Epoch", "Loss", "LAS", "UAS", "TAG", "SENT", "WORD"]))
    columns = [
        "{:d}",
        "{dep_loss:.1f}",
        "{las:.1f}",
        "{uas:.1f}",
        "{tags:.1f}",
        "{sents:.1f}",
        "{words:.1f}",
    ]
    row_template = "\t".join(columns)
    print(row_template.format(itn, **values))


# def get_sent_conllu(sent, sent_id):
#    lines = ["# sent_id = {sent_id}".format(sent_id=sent_id)]


def get_token_conllu(token, i):
    """Render a token as CoNLL-U text.

    token: The token to render. It is expected to provide the begins_fused
        and inside_fused extension attributes.
    i (int): The token's 0-based position within its sentence; the CoNLL-U
        ID written out is i + 1.
    RETURNS (str): The token's 10-column, tab-separated line. If the token
        begins a fused (multiword) token, that line is preceded by a range
        line ("<first>-<last>") followed by a newline.
    """
    lines = []
    if token._.begins_fused:
        # Count the tokens in the fused group: this one plus every
        # directly following token flagged as inside_fused.
        n = 1
        while token.nbor(n)._.inside_fused:
            n += 1
        # IDs are 1-based, so the group covers IDs i + 1 through i + n.
        range_id = "%d-%d" % (i + 1, i + n)
        range_fields = [range_id, token.text] + ["_"] * 8
        # The range line is a single tab-separated line, like any token line.
        lines.append("\t".join(range_fields))
    if token.head.i == token.i:
        # A token that is its own head is the root, written as head 0.
        head = 0
    else:
        # Convert the document-level head offset to a 1-based sentence ID.
        head = i + (token.head.i - token.i) + 1
    fields = [
        str(i + 1),
        token.text,
        token.lemma_,
        token.pos_,
        token.tag_,
        "_",
        str(head),
        token.dep_.lower(),
        "_",
        "_",
    ]
    lines.append("\t".join(fields))
    return "\n".join(lines)


##################
# Initialization #
##################


def load_nlp(corpus, config):
    """Create a blank pipeline for the corpus language.

    corpus (str): A corpus or language identifier; the part before the first
        "_" is used as the language code.
    config: Configuration object. If config.vectors is set, the vocab is
        loaded from its "vocab" sub-directory.
    RETURNS: The blank nlp object returned by spacy.blank.
    """
    lang = corpus.split("_")[0]
    nlp = spacy.blank(lang)
    if config.vectors:
        # config.vectors is a plain string when the config was read from
        # JSON, so coerce it to a Path before joining.
        nlp.vocab.from_disk(Path(config.vectors) / "vocab")
    return nlp


def initialize_pipeline(nlp, examples, config):
    """Add a parser and a tagger to nlp and begin training.

    The parser optionally gets "tag" and "sent_start" multitask objectives
    (per config.multitask_tag / config.multitask_sent) and a "subtok"
    action. The tagger gets every non-None tag aligned in the examples.
    Gold labels that the parser does not know are truncated at the first
    "||".

    RETURNS: Whatever nlp.begin_training returns (used as the optimizer by
        the caller).
    """
    parser = nlp.create_pipe("parser")
    nlp.add_pipe(parser)
    if config.multitask_tag:
        nlp.parser.add_multitask_objective("tag")
    if config.multitask_sent:
        nlp.parser.add_multitask_objective("sent_start")
    # NOTE(review): 2 is presumably the transition-system ID of the action
    # that takes the "subtok" label -- confirm against the parser's moves.
    nlp.parser.moves.add_action(2, "subtok")
    tagger = nlp.create_pipe("tagger")
    nlp.add_pipe(tagger)
    for eg in examples:
        aligned_tags = eg.get_aligned("TAG", as_string=True)
        for tag in aligned_tags:
            if tag is None:
                continue
            nlp.tagger.add_label(tag)
    # Replace labels that didn't make the frequency cutoff
    known_labels = {
        action.split("-")[1] for action in set(nlp.parser.labels) if "-" in action
    }
    for eg in examples:
        gold = eg.gold
        for index, label in enumerate(gold.labels):
            if label is None or label in known_labels:
                continue
            gold.labels[index] = label.split("||")[0]
    return nlp.begin_training(lambda: examples)


########################
# Command line helpers #
########################


@attr.s
class Config(object):
    """Training settings, optionally loaded from a JSON file."""

    # Location whose "vocab" sub-directory is loaded into the pipeline vocab.
    vectors = attr.ib(default=None)
    # Maximum number of sentences grouped into one training document.
    max_doc_length = attr.ib(default=10)
    # Whether the parser gets the "tag" multitask objective.
    multitask_tag = attr.ib(default=True)
    # Whether the parser gets the "sent_start" multitask objective.
    multitask_sent = attr.ib(default=True)
    # Number of training epochs.
    nr_epoch = attr.ib(default=30)
    # Batch size passed to the word-based minibatcher.
    batch_size = attr.ib(default=1000)
    # Dropout rate passed to nlp.update.
    dropout = attr.ib(default=0.2)

    @classmethod
    def load(cls, loc):
        """Build a Config from the JSON object stored at loc.

        Every key of the JSON object is passed to the constructor as a
        keyword argument.
        """
        config_path = Path(loc)
        with config_path.open("r", encoding="utf8") as handle:
            settings = json.load(handle)
        return cls(**settings)


class Dataset(object):
    """Locate the .conllu and .txt files of one treebank section.

    path (Path): Directory containing the treebank files.
    section (str): Section name looked for in the file names, e.g. "train".

    Sets `conllu` and `text` to the matching file paths and `lang` to the
    part of the .conllu file name before the first "-" and "_".

    RAISES (IOError): If no matching .conllu file or no matching .txt file
        exists in the directory.
    """

    def __init__(self, path, section):
        self.path = path
        self.section = section
        self.conllu = None
        self.text = None
        for file_path in self.path.iterdir():
            name = file_path.parts[-1]
            if section not in name:
                continue
            if name.endswith("conllu"):
                self.conllu = file_path
            elif name.endswith("txt"):
                self.text = file_path
        if self.conllu is None:
            msg = "Could not find .conllu file in {path} for {section}"
            raise IOError(msg.format(section=section, path=path))
        if self.text is None:
            # Fail here with a clear message rather than later, when the
            # missing path would be dereferenced as None.
            msg = "Could not find .txt file in {path} for {section}"
            raise IOError(msg.format(section=section, path=path))
        self.lang = self.conllu.parts[-1].split("-")[0].split("_")[0]


class TreebankPaths(object):
    """Hold the train and dev Datasets of one treebank directory.

    `lang` is taken from the train section. Extra keyword arguments are
    accepted but not used.
    """

    def __init__(self, ud_path, treebank, **cfg):
        treebank_dir = ud_path / treebank
        self.train = Dataset(treebank_dir, "train")
        self.dev = Dataset(treebank_dir, "dev")
        self.lang = self.train.lang


@plac.annotations(
    ud_dir=("Path to Universal Dependencies corpus", "positional", None, Path),
    parses_dir=("Directory to write the development parses", "positional", None, Path),
    config=("Path to json formatted config file", "positional", None, Config.load),
    corpus=(
        "UD corpus to train and evaluate on, e.g. UD_Spanish-AnCora",
        "positional",
        None,
        str,
    ),
    limit=("Size limit", "option", "n", int),
)
def main(ud_dir, parses_dir, config, corpus, limit=0):
    """Train a parser and tagger on a UD treebank, evaluating after each epoch.

    The development parses of every epoch are written to
    parses_dir/corpus/epoch-N.conllu and one progress row is printed per
    epoch.
    """
    # Register each custom Token extension exactly once: registering a name
    # that already exists raises an error unless force=True is passed.
    Token.set_extension("get_conllu_lines", method=get_token_conllu)
    Token.set_extension("begins_fused", default=False)
    Token.set_extension("inside_fused", default=False)

    paths = TreebankPaths(ud_dir, corpus)
    # Create the output directory along with any missing parents.
    (parses_dir / corpus).mkdir(parents=True, exist_ok=True)
    print("Train and evaluate", corpus, "using lang", paths.lang)
    nlp = load_nlp(paths.lang, config)

    # Open the training files in context managers so they are closed again
    # once the examples have been read.
    with paths.train.conllu.open(encoding="utf8") as conllu_file:
        with paths.train.text.open(encoding="utf8") as text_file:
            examples = read_data(
                nlp,
                conllu_file,
                text_file,
                max_doc_length=config.max_doc_length,
                limit=limit,
            )

    optimizer = initialize_pipeline(nlp, examples, config)

    # The training set does not change between epochs, so the progress bar
    # total is computed once.
    n_train_words = sum(len(eg.reference.doc) for eg in examples)
    for i in range(config.nr_epoch):
        batches = spacy.minibatch_by_words(examples, size=config.batch_size)
        losses = {}
        with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
            for batch in batches:
                pbar.update(sum(len(eg.reference.doc) for eg in batch))
                nlp.update(
                    examples=batch, sgd=optimizer, drop=config.dropout, losses=losses,
                )

        out_path = parses_dir / corpus / "epoch-{i}.conllu".format(i=i)
        # Evaluate with the optimizer's averaged parameters.
        with nlp.use_params(optimizer.averages):
            scores = evaluate(nlp, paths.dev.text, paths.dev.conllu, out_path)
            print_progress(i, losses, scores)


# When executed as a script, hand control to plac, which calls main with the
# command-line arguments.
if __name__ == "__main__":
    plac.call(main)