# coding: utf-8
from __future__ import unicode_literals

import pytest
from spacy.attrs import ORTH, LENGTH
from spacy.tokens import Doc, Span
from spacy.vocab import Vocab
from spacy.errors import ModelsWarning

from ..util import get_doc


@pytest.fixture
def doc(en_tokenizer):
    """Three-sentence Doc with hand-specified heads and dependency labels."""
    # fmt: off
    sentence_text = "This is a sentence. This is another sentence. And a third."
    # One relative head offset per token; 0 marks each sentence's ROOT token.
    head_offsets = [1, 0, 1, -2, -3, 1, 0, 1, -2, -3, 0, 1, -2, -1]
    dep_labels = ["nsubj", "ROOT", "det", "attr", "punct", "nsubj", "ROOT", "det",
                  "attr", "punct", "ROOT", "det", "npadvmod", "punct"]
    # fmt: on
    tokenized = en_tokenizer(sentence_text)
    words = [token.text for token in tokenized]
    return get_doc(tokenized.vocab, words=words, heads=head_offsets, deps=dep_labels)


@pytest.fixture
def doc_not_parsed(en_tokenizer):
    """Doc over the same text as the ``doc`` fixture, with no parse applied."""
    raw_text = "This is a sentence. This is another sentence. And a third."
    tokenized = en_tokenizer(raw_text)
    unparsed = Doc(tokenized.vocab, words=[tok.text for tok in tokenized])
    # Explicitly flag the Doc as unparsed so sentence boundaries must be set by hand.
    unparsed.is_parsed = False
    return unparsed


def test_spans_sent_spans(doc):
    """The parsed fixture splits into three sentences that cover every token."""
    sentences = list(doc.sents)
    opening = sentences[0]
    assert opening.start == 0
    assert opening.end == 5
    assert len(sentences) == 3
    assert sum(map(len, sentences)) == len(doc)


def test_spans_root(doc):
    """The span "a sentence" is rooted at "sentence", whose head is "is"."""
    noun_phrase = doc[2:4]
    assert len(noun_phrase) == 2
    assert noun_phrase.text == "a sentence"
    phrase_root = noun_phrase.root
    assert phrase_root.text == "sentence"
    assert phrase_root.head.text == "is"


def test_spans_string_fn(doc):
    """Check the text, upper_ and lower_ string views of a four-token span."""
    first_four = doc[0:4]
    assert len(first_four) == 4
    expectations = [
        (first_four.text, "This is a sentence"),
        (first_four.upper_, "THIS IS A SENTENCE"),
        (first_four.lower_, "this is a sentence"),
    ]
    for actual, wanted in expectations:
        assert actual == wanted


def test_spans_root2(en_tokenizer):
    """The root of the trailing span "South Carolina" is "Carolina"."""
    tokenized = en_tokenizer("through North and South Carolina")
    # Relative head offsets, one per token.
    head_offsets = [0, 3, -1, -2, -4]
    parsed = get_doc(
        tokenized.vocab, words=[tok.text for tok in tokenized], heads=head_offsets
    )
    last_two = parsed[-2:]
    assert last_two.root.text == "Carolina"


def test_spans_span_sent(doc, doc_not_parsed):
    """Test span.sent property"""
    # Parsed fixture.
    assert len(list(doc.sents))
    opening = doc[:2]
    assert opening.sent.root.text == "is"
    assert opening.sent.text == "This is a sentence ."
    assert doc[6:7].sent.root.left_edge.text == "This"
    # Unparsed fixture: mark sentence starts manually before querying .sent.
    for boundary in (0, 5):
        doc_not_parsed[boundary].is_sent_start = True
    assert doc_not_parsed[1:3].sent == doc_not_parsed[0:5]
    assert doc_not_parsed[10:14].sent == doc_not_parsed[5:]


def test_spans_lca_matrix(en_tokenizer):
    """Test span's lca matrix generation.

    Entry [i, j] is the span-relative index of the lowest common ancestor of
    span tokens i and j, or -1 when that ancestor lies outside the span.
    """
    tokenized = en_tokenizer("the lazy dog slept")
    # Relative heads: the -> dog, lazy -> dog, dog -> slept, slept is the root.
    parsed = get_doc(
        tokenized.vocab, words=[tok.text for tok in tokenized], heads=[2, 1, 1, 0]
    )

    # Span "the lazy": both tokens attach to "dog", which is outside the span.
    matrix = parsed[:2].get_lca_matrix()
    assert matrix.shape == (2, 2)
    for row, col, wanted in [(0, 0, 0), (0, 1, -1), (1, 0, -1), (1, 1, 1)]:
        assert matrix[row, col] == wanted

    # Span "lazy dog slept": only the first row is checked.
    matrix = parsed[1:].get_lca_matrix()
    assert matrix.shape == (3, 3)
    for row, col, wanted in [(0, 0, 0), (0, 1, 1), (0, 2, 2)]:
        assert matrix[row, col] == wanted

    # Span "dog slept": "slept" (index 1) dominates "dog".
    matrix = parsed[2:].get_lca_matrix()
    assert matrix.shape == (2, 2)
    for row, col, wanted in [(0, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 1)]:
        assert matrix[row, col] == wanted


def test_span_similarity_match():
    """Span.similarity on a bare Vocab() emits ModelsWarning; check its values."""
    repeated = Doc(Vocab(), words=["a", "b", "a", "b"])
    left, right = repeated[:2], repeated[2:]
    with pytest.warns(ModelsWarning):
        assert left.similarity(right) == 1.0
        assert left.similarity(repeated) == 0.0
        assert left[:1].similarity(repeated.vocab["a"]) == 1.0


def test_spans_default_sentiment(en_tokenizer):
    """Test span.sentiment property's default averaging behaviour"""
    tokenized = en_tokenizer("good stuff bad stuff")
    good_score, bad_score = 3.0, -2.0
    # Scores are set on the vocab entries for "good" and "bad"; "stuff" is left unset.
    tokenized.vocab[tokenized[0].text].sentiment = good_score
    tokenized.vocab[tokenized[2].text].sentiment = bad_score
    scored = Doc(tokenized.vocab, words=[tok.text for tok in tokenized])
    # Each span's sentiment is expected to be the mean over its tokens.
    assert scored[:2].sentiment == good_score / 2
    assert scored[-2:].sentiment == bad_score / 2
    assert scored[:-1].sentiment == (good_score + bad_score) / 3.0


def test_spans_override_sentiment(en_tokenizer):
    """Test that a user_span_hooks["sentiment"] hook overrides span.sentiment.

    Lexeme sentiment scores are set so that the default per-span average would
    differ from 10.0. Every span must nevertheless report the hook's 10.0.
    """
    text = "good stuff bad stuff"
    tokens = en_tokenizer(text)
    tokens.vocab[tokens[0].text].sentiment = 3.0
    tokens.vocab[tokens[2].text].sentiment = -2.0
    doc = Doc(tokens.vocab, words=[t.text for t in tokens])
    # Register the override; it ignores its argument and returns a constant.
    doc.user_span_hooks["sentiment"] = lambda span: 10.0
    assert doc[:2].sentiment == 10.0
    assert doc[-2:].sentiment == 10.0
    assert doc[:-1].sentiment == 10.0


def test_spans_are_hashable(en_tokenizer):
    """Test spans can be hashed."""
    tokenized = en_tokenizer("good stuff bad stuff")
    front, back = tokenized[:2], tokenized[2:4]
    assert hash(front) != hash(back)
    # The same slice taken a second time must hash identically.
    front_again = tokenized[0:2]
    assert hash(front_again) == hash(front)


def test_spans_by_character(doc):
    """doc.char_span over a token span's character offsets keeps those offsets."""
    token_span = doc[1:-2]
    start, end = token_span.start_char, token_span.end_char
    char_span = doc.char_span(start, end, label="GPE")
    assert token_span.start_char == char_span.start_char
    assert token_span.end_char == char_span.end_char
    assert char_span.label_ == "GPE"


def test_span_to_array(doc):
    """Span.to_array yields one row per token and one column per requested attr."""
    window = doc[1:-2]
    requested = [ORTH, LENGTH]
    table = window.to_array(requested)
    assert table.shape == (len(window), len(requested))
    first_token = window[0]
    assert table[0, 0] == first_token.orth
    assert table[0, 1] == len(first_token)


def test_span_as_doc(doc):
    """Span.as_doc returns a separate Doc of the same class, indexed from 0."""
    window = doc[4:10]
    standalone = window.as_doc()
    assert window.text == standalone.text.strip()
    assert isinstance(standalone, type(doc))
    assert standalone is not doc
    # Character offsets restart at the beginning of the new Doc.
    assert standalone[0].idx == 0


def test_span_string_label(doc):
    """A string label passed to the Span constructor is readable as text and ID."""
    labelled = Span(doc, 0, 1, label="hello")
    assert labelled.label_ == "hello"
    expected_id = doc.vocab.strings["hello"]
    assert labelled.label == expected_id


def test_span_string_set_label(doc):
    """Assigning Span.label_ after construction updates both label views."""
    unlabelled = Span(doc, 0, 1)
    unlabelled.label_ = "hello"
    assert unlabelled.label_ == "hello"
    expected_id = doc.vocab.strings["hello"]
    assert unlabelled.label == expected_id


def test_span_ents_property(doc):
    """Test span.ents for the sentences of a parsed doc.

    One PRODUCT entity is placed in each sentence: at the first sentence's
    start, in the middle of the second, and at the third sentence's end.
    Each sentence span must expose exactly its own entity.
    """
    product = doc.vocab.strings["PRODUCT"]
    doc.ents = [
        (product, 0, 1),
        (product, 7, 8),
        (product, 11, 14),
    ]
    assert len(list(doc.ents)) == 3
    sentences = list(doc.sents)
    assert len(sentences) == 3
    # (text, start, end) of the single entity expected in each sentence.
    expected = [
        ("This", 0, 1),  # also tests start of sentence
        ("another", 7, 8),
        ("a third .", 11, 14),  # also tests end of sentence
    ]
    for sentence, (text, start, end) in zip(sentences, expected):
        # Previously the third sentence's entity count was never checked.
        assert len(sentence.ents) == 1
        ent = sentence.ents[0]
        assert ent.text == text
        assert ent.label_ == "PRODUCT"
        assert ent.start == start
        assert ent.end == end