from typing import List, Tuple

from ...pipeline import Lemmatizer
from ...tokens import Token


class DutchLemmatizer(Lemmatizer):
    @classmethod
    def get_lookups_config(cls, mode: str) -> Tuple[List[str], List[str]]:
        if mode == "rule":
            required = ["lemma_lookup", "lemma_rules", "lemma_exc", "lemma_index"]
            return (required, [])
        else:
            return super().get_lookups_config(mode)

    def lookup_lemmatize(self, token: Token) -> List[str]:
        """Overrides parent method so that a lowercased version of the string
        is used to search the lookup table. This is necessary because our
        lookup table consists entirely of lowercase keys."""
        lookup_table = self.lookups.get_table("lemma_lookup", {})
        string = token.text.lower()
        return [lookup_table.get(string, string)]

    # Note: CGN does not distinguish AUX verbs, so we treat AUX as VERB.
    def rule_lemmatize(self, token: Token) -> List[str]:
        # Difference 1: self.rules is assumed to be non-None, so no
        # 'is None' check required.
        # String lowercased from the get-go. All lemmatization results in
        # lowercased strings. For most applications, this shouldn't pose
        # any problems, and it keeps the exceptions indexes small. If this
        # creates problems for proper nouns, we can introduce a check for
        # univ_pos == "PROPN".
        cache_key = (token.lower, token.pos)
        if cache_key in self.cache:
            return self.cache[cache_key]
        string = token.text
        univ_pos = token.pos_.lower()
        if univ_pos in ("", "eol", "space"):
            forms = [string.lower()]
            self.cache[cache_key] = forms
            return forms

        index_table = self.lookups.get_table("lemma_index", {})
        exc_table = self.lookups.get_table("lemma_exc", {})
        rules_table = self.lookups.get_table("lemma_rules", {})
        index = index_table.get(univ_pos, {})
        exceptions = exc_table.get(univ_pos, {})
        rules = rules_table.get(univ_pos, {})

        string = string.lower()
        if univ_pos not in (
            "noun",
            "verb",
            "aux",
            "adj",
            "adv",
            "pron",
            "det",
            "adp",
            "num",
        ):
            forms = [string]
            self.cache[cache_key] = forms
            return forms
        lemma_index = index_table.get(univ_pos, {})
        # string is already lemma
        if string in lemma_index:
            forms = [string]
            self.cache[cache_key] = forms
            return forms
        exc_table = self.lookups.get_table("lemma_exc", {})
        exceptions = exc_table.get(univ_pos, {})
        # string is irregular token contained in exceptions index.
        try:
            forms = [exceptions[string][0]]
            self.cache[cache_key] = forms
            return forms
        except KeyError:
            pass
        # string corresponds to key in lookup table
        lookup_table = self.lookups.get_table("lemma_lookup", {})
        looked_up_lemma = lookup_table.get(string)
        if looked_up_lemma and looked_up_lemma in lemma_index:
            forms = [looked_up_lemma]
            self.cache[cache_key] = forms
            return forms
        rules_table = self.lookups.get_table("lemma_rules", {})
        oov_forms = []
        for old, new in rules:
            if string.endswith(old):
                form = string[: len(string) - len(old)] + new
                if not form:
                    pass
                elif form in index:
                    forms = [form]
                    self.cache[cache_key] = forms
                    return forms
                else:
                    oov_forms.append(form)
        forms = list(set(oov_forms))
        # Back-off through remaining return value candidates.
        if forms:
            for form in forms:
                if form in exceptions:
                    forms = [form]
                    self.cache[cache_key] = forms
                    return forms
            if looked_up_lemma:
                forms = [looked_up_lemma]
                self.cache[cache_key] = forms
                return forms
            else:
                self.cache[cache_key] = forms
                return forms
        elif looked_up_lemma:
            forms = [looked_up_lemma]
            self.cache[cache_key] = forms
            return forms
        else:
            forms = [string]
            self.cache[cache_key] = forms
            return forms