"""Train for CONLL 2017 UD treebank evaluation. Takes .conllu files, writes
.conllu format for development data, allowing the official scorer to be used.
"""
from __future__ import unicode_literals
import plac
import attr
from pathlib import Path
import re
import json
import tqdm

import spacy
import spacy.util
from spacy.tokens import Token, Doc
from spacy.gold import Example
from spacy.syntax.nonproj import projectivize
from collections import defaultdict
from spacy.matcher import Matcher

import itertools
import random
import numpy.random

from bin.ud import conll17_ud_eval

import spacy.lang.zh
import spacy.lang.ja

# Set the use_jieba / use_janome flags on the Chinese and Japanese language
# defaults to False.
# NOTE(review): presumably this keeps tokenization independent of those
# optional third-party segmenters -- confirm against the spaCy version in use.
spacy.lang.zh.Chinese.Defaults.use_jieba = False
spacy.lang.ja.Japanese.Defaults.use_janome = False

# Seed both the stdlib and the numpy random generators with a fixed value so
# that anything drawing from them behaves the same way on every run.
random.seed(0)
numpy.random.seed(0)


################
# Data reading #
################

# Matches any run of whitespace (spaces, tabs, single newlines, ...). Written
# as a raw string so the backslash reaches the regex engine untouched instead
# of relying on Python passing through an invalid string escape.
space_re = re.compile(r"\s+")


def split_text(text):
    """Split raw text into paragraphs with normalized whitespace.

    Paragraphs are separated by a blank line ("\\n\\n"). Each paragraph is
    stripped and every internal run of whitespace is collapsed to one space.

    text (str): The raw text to split.
    RETURNS (list): One normalized string per paragraph; an empty input
        yields a list containing one empty string.
    """
    return [space_re.sub(" ", par.strip()) for par in text.split("\n\n")]


def read_data(
    nlp,
    conllu_file,
    text_file,
    raw_text=True,
    oracle_segments=False,
    max_doc_length=None,
    limit=None,
):
    """Turn a CoNLL-U file into a list of Example objects.

    With raw_text=True, gold sentences are grouped (up to max_doc_length
    sentences per group, never across document boundaries) and paired with a
    Doc produced by nlp.make_doc. With oracle_segments=True, every gold
    sentence is additionally paired with a Doc built directly from its gold
    words and spaces. At least one of the two flags must be set.

    nlp: Pipeline object providing `vocab` and `make_doc`.
    conllu_file: Open file with the CoNLL-U annotations.
    text_file: Open file with the raw text, paragraphs split by blank lines.
    raw_text (bool): Emit examples tokenized by nlp.make_doc.
    oracle_segments (bool): Emit one gold-tokenized example per sentence.
    max_doc_length (int): Sentences per raw-text example; falsy = no cap.
    limit (int): Stop once this many examples were collected; falsy = no cap.
    RETURNS (list): The examples, as built by golds_to_gold_data.
    RAISES (ValueError): If both raw_text and oracle_segments are false.
    """
    if not (raw_text or oracle_segments):
        raise ValueError("At least one of raw_text or oracle_segments must be True")
    paragraphs = split_text(text_file.read())
    conllu_docs = read_conllu(conllu_file)
    docs, golds = [], []
    # The paragraphs only pace the loop (zip stops at the shorter input); the
    # text given to the tokenizer is rebuilt from gold tokens in _make_gold.
    for _paragraph, conllu_doc in zip(paragraphs, conllu_docs):
        pending = []
        for conllu_sent in conllu_doc:
            annots = defaultdict(list)
            for row in conllu_sent:
                id_, word, _lemma, _pos, tag, _morph, head, dep, _, space_after = row
                # IDs with "." or "-" are not plain word rows (empty nodes and
                # multiword-token ranges in CoNLL-U), so they are skipped.
                if "." in id_ or "-" in id_:
                    continue
                index = int(id_) - 1
                if head == "0":
                    # The root is encoded as a token that heads itself.
                    head_index = index
                else:
                    head_index = int(head) - 1
                annots["words"].append(word)
                annots["tags"].append(tag)
                annots["heads"].append(head_index)
                annots["deps"].append("ROOT" if dep == "root" else dep)
                annots["spaces"].append(space_after == "_")
            annots["entities"] = ["-"] * len(annots["words"])
            annots["heads"], annots["deps"] = projectivize(
                annots["heads"], annots["deps"]
            )
            if oracle_segments:
                gold_doc = Doc(
                    nlp.vocab, words=annots["words"], spaces=annots["spaces"]
                )
                docs.append(gold_doc)
                golds.append(annots)

            pending.append(annots)
            if raw_text and max_doc_length and len(pending) >= max_doc_length:
                doc, gold = _make_gold(nlp, None, pending)
                pending = []
                docs.append(doc)
                golds.append(gold)
                if limit and len(docs) >= limit:
                    return golds_to_gold_data(docs, golds)

        # Flush whatever is left of this document as a final, shorter group.
        if raw_text and pending:
            doc, gold = _make_gold(nlp, None, pending)
            docs.append(doc)
            golds.append(gold)
        if limit and len(docs) >= limit:
            break
    return golds_to_gold_data(docs, golds)


def read_conllu(file_):
    """Parse CoNLL-U lines into nested lists: documents > sentences > tokens.

    A line starting with "# newdoc" begins a new document, any other line
    starting with "#" is ignored, and a blank line ends the current sentence.
    Every remaining line must hold exactly 10 tab-separated fields and becomes
    one token (a list of 10 strings). Empty sentences and empty documents are
    dropped.

    file_: Iterable of text lines, e.g. an open file.
    RETURNS (list): List of documents; each is a list of sentences; each
        sentence is a list of 10-item field lists.
    RAISES (ValueError): If a token line does not have exactly 10 fields. The
        message names the line number and shows the offending line.
    """
    docs = []
    sent = []
    doc = []
    for line_no, line in enumerate(file_, 1):
        if line.startswith("# newdoc"):
            if doc:
                docs.append(doc)
            doc = []
        elif line.startswith("#"):
            continue
        elif not line.strip():
            if sent:
                doc.append(sent)
            sent = []
        else:
            fields = line.strip().split("\t")
            if len(fields) != 10:
                # Put the diagnostics in the exception itself rather than
                # printing to stdout and raising a message-less error.
                msg = (
                    "Malformed CoNLL-U token on line {n}: expected 10 "
                    "tab-separated fields, found {count}: {line!r}"
                )
                raise ValueError(msg.format(n=line_no, count=len(fields), line=line))
            sent.append(fields)
    # The input may end without a trailing blank line / newdoc marker.
    if sent:
        doc.append(sent)
    if doc:
        docs.append(doc)
    return docs


def _make_gold(nlp, text, sent_annots):
    # Flatten the conll annotations, and adjust the head indices
    gold = defaultdict(list)
    for sent in sent_annots:
        gold["heads"].extend(len(gold["words"]) + head for head in sent["heads"])
        for field in ["words", "tags", "deps", "entities", "spaces"]:
            gold[field].extend(sent[field])
    # Construct text if necessary
    assert len(gold["words"]) == len(gold["spaces"])
    if text is None:
        text = "".join(
            word + " " * space for word, space in zip(gold["words"], gold["spaces"])
        )
    doc = nlp.make_doc(text)
    gold.pop("spaces")
    return doc, gold


#############################
# Data transforms for spaCy #
#############################


def golds_to_gold_data(docs, golds):
    """Pair each doc with its gold annotations as an Example.

    docs (list): The docs, in the same order as golds.
    golds (list): Annotation dicts accepted by Example.from_dict.
    RETURNS (list): One Example per (doc, gold) pair; surplus items in the
        longer of the two inputs are ignored.
    """
    return [Example.from_dict(doc, gold) for doc, gold in zip(docs, golds)]


##############
# Evaluation #
##############


def evaluate(nlp, text_loc, gold_loc, sys_loc, limit=None):
    """Parse the raw text, write the parses as CoNLL-U and score them.

    nlp: Pipeline used to process the paragraphs (via nlp.pipe).
    text_loc (Path): Raw text file, paragraphs separated by blank lines.
    gold_loc (Path): Gold-standard CoNLL-U file.
    sys_loc (Path): Where the system output is written (and re-read).
    limit: Accepted for interface compatibility; not used here.
    RETURNS: The result of conll17_ud_eval.evaluate(gold, system).
    """
    with text_loc.open("r", encoding="utf8") as text_file:
        paragraphs = split_text(text_file.read())
        parsed_docs = list(nlp.pipe(paragraphs))
    with sys_loc.open("w", encoding="utf8") as sys_out:
        write_conllu(parsed_docs, sys_out)
    # The scorer works on files, so the freshly written output is read back.
    with gold_loc.open("r", encoding="utf8") as gold_file:
        gold_ud = conll17_ud_eval.load_conllu(gold_file)
        with sys_loc.open("r", encoding="utf8") as sys_in:
            sys_ud = conll17_ud_eval.load_conllu(sys_in)
        return conll17_ud_eval.evaluate(gold_ud, sys_ud)


def write_conllu(docs, file_):
    """Write docs to file_ in CoNLL-U format, merging "subtok" runs first.

    Every run of tokens whose dependency label is "subtok" is merged (in
    place, together with the one token following the run) before the doc is
    written. Each doc gets a "# newdoc" header and each sentence "# sent_id"
    and "# text" headers, followed by one entry per token and a blank line.

    docs (list): Parsed docs; must be non-empty, as the first doc's vocab is
        used to build the matcher.
    file_: Writable text file object.
    """
    subtok_matcher = Matcher(docs[0].vocab)
    subtok_matcher.add("SUBTOK", None, [{"DEP": "subtok", "op": "+"}])
    for doc_index, doc in enumerate(docs):
        # Collect all character offsets before merging anything, since
        # merging changes the token indices the matches refer to.
        char_spans = []
        for _, start, end in subtok_matcher(doc):
            span = doc[start : end + 1]
            char_spans.append((span.start_char, span.end_char))
        for start_char, end_char in char_spans:
            doc.merge(start_char, end_char)
        file_.write("# newdoc id = {i}\n".format(i=doc_index))
        for sent_index, sent in enumerate(doc.sents):
            file_.write("# sent_id = {i}.{j}\n".format(i=doc_index, j=sent_index))
            file_.write("# text = {text}\n".format(text=sent.text))
            for token_index, token in enumerate(sent):
                file_.write(token._.get_conllu_lines(token_index) + "\n")
            file_.write("\n")


def print_progress(itn, losses, ud_scores):
    """Print one tab-separated progress row (plus a header on epoch 0).

    itn (int): Epoch number; the column header is printed when it is 0.
    losses (dict): Loss values; "parser" is shown, missing keys count as 0.0.
    ud_scores (dict): Maps "Words", "Sentences", "XPOS", "UAS" and "LAS" to
        objects with an `f1` attribute (shown as percentages).
    """

    def percent(metric):
        return ud_scores[metric].f1 * 100

    values = {
        "dep_loss": losses.get("parser", 0.0),
        "tag_loss": losses.get("tagger", 0.0),
        "words": percent("Words"),
        "sents": percent("Sentences"),
        "tags": percent("XPOS"),
        "uas": percent("UAS"),
        "las": percent("LAS"),
    }
    if itn == 0:
        print("\t".join(["Epoch", "Loss", "LAS", "UAS", "TAG", "SENT", "WORD"]))
    # Column order matches the header above; the tagger loss is not shown.
    shown = ("dep_loss", "las", "uas", "tags", "sents", "words")
    cells = ["{:d}".format(itn)]
    cells.extend("{:.1f}".format(values[key]) for key in shown)
    print("\t".join(cells))


# def get_sent_conllu(sent, sent_id):
#    lines = ["# sent_id = {sent_id}".format(sent_id=sent_id)]


def get_token_conllu(token, i):
    """Format a token as CoNLL-U: one row, or two if it starts a fused group.

    If the token is flagged `begins_fused`, a range row ("<first>-<last>",
    form, then eight "_" columns) is emitted before the token's own row. The
    range covers this token and the directly following tokens flagged
    `inside_fused`.

    token: The token to format; needs the `begins_fused` / `inside_fused`
        extensions.
    i (int): Zero-based position of the token within its sentence; the
        CoNLL-U ID written out is i + 1.
    RETURNS (str): The tab-separated row(s), joined by a newline, without a
        trailing newline.
    """
    lines = []
    if token._.begins_fused:
        n = 1
        # Count the tokens in the fused group, stopping at the end of the doc
        # instead of asking for a neighbour that does not exist.
        while token.i + n < len(token.doc) and token.nbor(n)._.inside_fused:
            n += 1
        # IDs are 1-based, so the group spans i + 1 .. i + n.
        range_id = "%d-%d" % (i + 1, i + n)
        # The range row must be ONE tab-separated line, not ten list items
        # that would each end up on their own line after the final join.
        lines.append("\t".join([range_id, token.text] + ["_"] * 8))
    if token.head.i == token.i:
        # A token that heads itself is the root, which CoNLL-U encodes as 0.
        head = 0
    else:
        # Shift the doc-level head offset into 1-based sentence positions.
        head = i + (token.head.i - token.i) + 1
    fields = [
        str(i + 1),
        token.text,
        token.lemma_,
        token.pos_,
        token.tag_,
        "_",
        str(head),
        token.dep_.lower(),
        "_",
        "_",
    ]
    lines.append("\t".join(fields))
    return "\n".join(lines)


##################
# Initialization #
##################


def load_nlp(corpus, config):
    """Create a blank pipeline for the corpus language.

    corpus (str): Language code, optionally followed by "_<suffix>"; only the
        part before the first underscore is used.
    config: Object with a `vectors` attribute. When truthy, it is the
        directory containing a "vocab" subdirectory to load into nlp.vocab.
    RETURNS: The blank pipeline returned by spacy.blank.
    """
    lang = corpus.split("_")[0]
    nlp = spacy.blank(lang)
    if config.vectors:
        # A config loaded from JSON carries `vectors` as a plain string, and
        # str / str raises TypeError -- coerce to Path before joining.
        nlp.vocab.from_disk(Path(config.vectors) / "vocab")
    return nlp


def initialize_pipeline(nlp, examples, config):
    """Add parser and tagger to nlp, register labels and begin training.

    nlp: The pipeline to extend (modified in place).
    examples (list): Training examples; tagger labels are collected from
        them and their gold labels may be rewritten (see below).
    config: Object with boolean `multitask_tag` / `multitask_sent` flags.
    RETURNS: Whatever nlp.begin_training returns (used as the optimizer by
        the caller).
    """
    parser = nlp.create_pipe("parser")
    nlp.add_pipe(parser)
    if config.multitask_tag:
        nlp.parser.add_multitask_objective("tag")
    if config.multitask_sent:
        nlp.parser.add_multitask_objective("sent_start")
    nlp.parser.moves.add_action(2, "subtok")
    tagger = nlp.create_pipe("tagger")
    nlp.add_pipe(tagger)
    for eg in examples:
        for tag in eg.get_aligned("TAG", as_string=True):
            if tag is None:
                continue
            nlp.tagger.add_label(tag)
    # Replace labels that didn't make the frequency cutoff
    known_labels = {
        action.split("-")[1] for action in set(nlp.parser.labels) if "-" in action
    }
    for eg in examples:
        gold = eg.gold
        for position, label in enumerate(gold.labels):
            if label is None or label in known_labels:
                continue
            # Fall back to the part of the label before the first "||".
            gold.labels[position] = label.split("||")[0]

    def get_examples():
        return examples

    return nlp.begin_training(get_examples)


########################
# Command line helpers #
########################


@attr.s
class Config(object):
    """Training settings, usually read from a JSON file via Config.load."""

    # Directory with a "vocab" subdirectory to load; skipped when falsy.
    vectors = attr.ib(default=None)
    # Passed to read_data as the per-example sentence cap.
    max_doc_length = attr.ib(default=10)
    # Whether the parser gets the extra "tag" / "sent_start" objectives.
    multitask_tag = attr.ib(default=True)
    multitask_sent = attr.ib(default=True)
    # Number of passes over the training examples.
    nr_epoch = attr.ib(default=30)
    # `size` argument for the word-based minibatching.
    batch_size = attr.ib(default=1000)
    # `drop` argument for nlp.update.
    dropout = attr.ib(default=0.2)

    @classmethod
    def load(cls, loc):
        """Build a Config from the JSON object stored at `loc`.

        loc: Path (or path string) of a UTF-8 JSON file whose keys are
            attribute names of this class.
        RETURNS (Config): The populated config.
        """
        settings = json.loads(Path(loc).read_text(encoding="utf8"))
        return cls(**settings)


class Dataset(object):
    """Locate the .conllu and .txt files of one treebank section.

    path (Path): Directory to search (not recursively).
    section (str): Substring identifying the section, e.g. "train" or "dev".
    RAISES (IOError): If no matching .conllu file or no matching .txt file
        is found in the directory.

    After construction, `conllu` and `text` hold the matching file paths and
    `lang` the language code taken from the .conllu file name (the text
    before the first "-", cut at the first "_").
    """

    def __init__(self, path, section):
        self.path = path
        self.section = section
        self.conllu = None
        self.text = None
        for file_path in self.path.iterdir():
            name = file_path.parts[-1]
            if section not in name:
                continue
            if name.endswith("conllu"):
                self.conllu = file_path
            elif name.endswith("txt"):
                self.text = file_path
        if self.conllu is None:
            msg = "Could not find .conllu file in {path} for {section}"
            raise IOError(msg.format(section=section, path=path))
        if self.text is None:
            # Fail here with a clear message instead of leaving `text` as
            # None and crashing later when the file is opened.
            msg = "Could not find .txt file in {path} for {section}"
            raise IOError(msg.format(section=section, path=path))
        self.lang = self.conllu.parts[-1].split("-")[0].split("_")[0]


class TreebankPaths(object):
    """The train and dev datasets of one treebank directory.

    ud_path (Path): Root directory containing the treebanks.
    treebank (str): Name of the treebank subdirectory.
    **cfg: Accepted and ignored.
    """

    def __init__(self, ud_path, treebank, **cfg):
        self.train = Dataset(ud_path / treebank, "train")
        self.dev = Dataset(ud_path / treebank, "dev")
        # The language code is taken from the training file name.
        self.lang = self.train.lang


@plac.annotations(
    ud_dir=("Path to Universal Dependencies corpus", "positional", None, Path),
    parses_dir=("Directory to write the development parses", "positional", None, Path),
    config=("Path to json formatted config file", "positional", None, Config.load),
    corpus=(
        "UD corpus to train and evaluate on, e.g. UD_Spanish-AnCora",
        "positional",
        None,
        str,
    ),
    limit=("Size limit", "option", "n", int),
)
def main(ud_dir, parses_dir, config, corpus, limit=0):
    """Train on a UD treebank and score the dev set after every epoch.

    ud_dir (Path): Root directory of the UD corpora.
    parses_dir (Path): Directory under which "<corpus>/epoch-<i>.conllu" dev
        parses are written.
    config (Config): Training settings.
    corpus (str): Name of the treebank directory inside ud_dir.
    limit (int): Maximum number of training examples; 0 means no limit.
    """
    # Register each extension exactly once: calling set_extension again for
    # a name that already exists raises unless force=True is passed.
    Token.set_extension("get_conllu_lines", method=get_token_conllu)
    Token.set_extension("begins_fused", default=False)
    Token.set_extension("inside_fused", default=False)

    paths = TreebankPaths(ud_dir, corpus)
    out_dir = parses_dir / corpus
    if not out_dir.exists():
        out_dir.mkdir(parents=True)
    print("Train and evaluate", corpus, "using lang", paths.lang)
    nlp = load_nlp(paths.lang, config)

    # read_data consumes both files completely before returning, so they can
    # be closed as soon as the examples have been built.
    with paths.train.conllu.open(encoding="utf8") as conllu_file:
        with paths.train.text.open(encoding="utf8") as text_file:
            examples = read_data(
                nlp,
                conllu_file,
                text_file,
                max_doc_length=config.max_doc_length,
                limit=limit,
            )

    optimizer = initialize_pipeline(nlp, examples, config)

    # Progress-bar total; the examples do not change between epochs.
    n_train_words = sum(len(eg.reference.doc) for eg in examples)
    for i in range(config.nr_epoch):
        batches = spacy.minibatch_by_words(examples, size=config.batch_size)
        losses = {}
        with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
            for batch in batches:
                pbar.update(sum(len(eg.reference.doc) for eg in batch))
                nlp.update(
                    examples=batch, sgd=optimizer, drop=config.dropout, losses=losses,
                )

        out_path = out_dir / "epoch-{i}.conllu".format(i=i)
        # Evaluate with the optimizer's averaged parameters swapped in.
        with nlp.use_params(optimizer.averages):
            scores = evaluate(nlp, paths.dev.text, paths.dev.conllu, out_path)
            print_progress(i, losses, scores)


# Script entry point: plac builds the command-line interface from the
# annotations on main and calls it with the parsed arguments.
if __name__ == "__main__":
    plac.call(main)