import hydra
import torch
import json

from os.path import dirname, abspath, join
import src.scripts.model.predict as p
from src.pl_modules import ConsecPLModule
from src.consec_dataset import ConsecSample, ConsecDefinition
from src.disambiguation_corpora import DisambiguationInstance


class SmartDictionary:
    def __init__(self, model_checkpoint_path=None, device=-1, text_encoding_strategy='simple-with-linker',
                 token_batch_size=1024, progress_bar=False):
        if model_checkpoint_path is None:
            # fall back to the released ConSeC checkpoint shipped with the repository
            model_checkpoint_path = join(
                dirname(dirname(abspath(__file__))),
                'experiments/released-ckpts/consec_wngt_best.ckpt'
            )

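        # load the pretrained ConSeC module and put it in inference mode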
        module = ConsecPLModule.load_from_checkpoint(model_checkpoint_path)
        module.to(torch.device(device if device != -1 else "cpu"))
        module.freeze()
        module.sense_extractor.evaluation_mode = True

        # load tokenizer
        tokenizer = hydra.utils.instantiate(module.hparams.tokenizer.consec_tokenizer)

        self.module = module
        self.tokenizer = tokenizer
        self.text_encoding_strategy = text_encoding_strategy
        self.token_batch_size = token_batch_size
        self.progress_bar = progress_bar
        print('SmartDictionary initialized')

    def predict(self, context, target_position, candidate_definitions, parts_of_speech=None, sort_=True):
        """Score each candidate definition for the token at `target_position`.

        `context` is a list of tokens and `candidate_definitions` a list of
        [lemma, definition] pairs. Returns a list of dicts with keys
        `probability`, `lemma`, `definition` and `partOfSpeech`, sorted by
        descending probability when `sort_` is True.
        """
        candidate_definitions = [ConsecDefinition(text=defn, linker=lemma) for lemma, defn in candidate_definitions]
        context_definitions = []  # no senses of neighbouring words are known in this interactive setting

        # run the model on a single interactive sample; p.predict yields
        # (sample, probabilities) pairs
        _, probs = next(
            p.predict(
                self.module,
                self.tokenizer,
                [
                    ConsecSample(
                        sample_id="interactive-d0",
                        position=target_position,
                        disambiguation_context=[
                            DisambiguationInstance("d0", "s0", "i0", t, None, None, None) for t in context
                        ],
                        candidate_definitions=candidate_definitions,
                        gold_definitions=None,
                        context_definitions=context_definitions,
                        in_context_sample_id2position={'interactive-d0': target_position},
                        disambiguation_instance=None,
                        kwargs={},
                    )
                ],
                text_encoding_strategy=self.text_encoding_strategy,
                token_batch_size=self.token_batch_size,
                progress_bar=self.progress_bar,
            )
        )
        if sort_:
            indices = torch.tensor(probs).argsort(descending=True).tolist()
        else:
            indices = range(len(probs))
        predictions = [
            {
                'probability': float(probs[idx]),  # cast so the result is JSON-serializable
                'lemma': candidate_definitions[idx].linker,
                'definition': candidate_definitions[idx].text,
                'partOfSpeech': parts_of_speech[idx] if parts_of_speech is not None else None,
            }
            for idx in indices
        ]
        return predictions
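

# A convenience sketch (not part of the original API): build the
# [lemma, definition] candidate pairs that `predict` expects from WordNet.
# Assumes `nltk` is installed and the WordNet corpus has been downloaded
# via nltk.download('wordnet').
def wordnet_candidates(lemma, pos=None):
    from nltk.corpus import wordnet as wn
    # e.g. pos=wn.NOUN ('n') restricts the lookup to nominal senses
    return [[lemma, synset.definition()] for synset in wn.synsets(lemma, pos=pos)]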


def test_smart_dictionary():
    context = ['I', 'have', 'a', 'beautiful', 'dog']
    target_position = 4
    candidate_definitions = [
        ['dog', 'a member of the genus Canis'],
        ['dog', 'someone who is morally reprehensible'],
    ]

    smart = SmartDictionary()
    predictions = smart.predict(context, target_position, candidate_definitions)
    print(json.dumps(predictions, indent=4))
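    # illustrative output shape (probabilities depend on the checkpoint):
    # [{"probability": 0.98, "lemma": "dog",
    #   "definition": "a member of the genus Canis", "partOfSpeech": null}, ...]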


if __name__ == '__main__':
    test_smart_dictionary()
