import json
import os
import time
from pathlib import Path
from typing import Dict, Optional

from rich.progress import track

from . import console
from .download import download_dataset
from .llm import BaseLLM
from .loader import DatasetLoader

# Default target number of samples evaluate_model aims to have in its output directory.
SAMPLE_SIZE = 1_500

def evaluate_model(
        llm: BaseLLM,
        dataset_name: str,
        prompt: str,
        sample_size: int = SAMPLE_SIZE,
        output_dir: Path = Path('./samples'),
        use_context_image: bool = False,
        infer_kwargs: Optional[Dict] = None
    ):
    """Run ``llm`` over dataset samples until ``output_dir`` holds ``sample_size`` results.

    The run is resumable. Each finished sample leaves ``<sample_id>.png`` and
    ``<sample_id>.txt`` in ``output_dir``. Existing ``.txt`` files count toward
    ``sample_size``, and samples whose ``.txt`` already exists are skipped.

    Args:
        llm: Model wrapper. ``load_model()`` is called once, then ``infer()`` per sample.
        dataset_name: Dataset identifier passed to ``download_dataset``. Its basename
            is used as the local dataset path, relative to the current directory.
        prompt: Prompt passed to ``llm.infer`` together with each image.
        sample_size: Total number of samples wanted in ``output_dir``.
        output_dir: Destination for the ``.png``/``.txt`` outputs. Created if missing.
        use_context_image: If True, use ``sample.get_context_image()`` instead of
            ``sample.get_image()``.
        infer_kwargs: Extra keyword arguments forwarded to ``llm.infer``.
    """
    # Avoid a shared mutable default; a fresh dict per call.
    if infer_kwargs is None:
        infer_kwargs = {}

    console.log(f'will write samples to {output_dir}')
    if not output_dir.exists():
        output_dir.mkdir(parents=True, exist_ok=True)
        console.log(f'created directory {output_dir}')

    # Ensure the dataset is available locally.
    console.log('ensuring dataset is downloaded...')
    dataset_path = Path(os.path.basename(dataset_name))
    download_dataset(dataset_name, dataset_path)

    dl = DatasetLoader(dataset_path)
    console.log(dl.get_classes())

    loader = dl.get_sample_loader()

    # Every completed sample leaves a .txt file, so those count toward the target.
    sample_len = len(list(output_dir.glob('*.txt')))
    remaining = max(sample_size - sample_len, 0)
    console.log(f'detected {sample_len} extant samples, {remaining} remaining...')

    if remaining == 0:
        # Nothing left to do: skip loading the model rather than handing the
        # loader a zero/negative limit.
        return

    loader.set_limit(remaining)

    with console.status('Loading model...'):
        llm.load_model()

    for sample in track(loader, description='Running sample inference...', total=sample_size, console=console):
        txt_path = output_dir / f'{sample.sample_id}.txt'
        if txt_path.exists():
            # Done in an earlier run. Raise the limit so this skipped sample
            # does not consume one of the remaining slots.
            if loader.limit is not None:
                loader.limit += 1
            continue

        console.log(f'eval {sample.sample_id}...')

        img = sample.get_context_image() if use_context_image else sample.get_image()

        # perf_counter is monotonic and intended for measuring elapsed time.
        time_start = time.perf_counter()
        output = llm.infer(img, prompt, **infer_kwargs)
        time_total = time.perf_counter() - time_start

        metadata_obj = {
            'sample_id': sample.sample_id,
            'source_class': sample.source_class,
            'elapsed_sec': time_total
        }

        # The first line of the .txt file is a '## '-prefixed JSON metadata header.
        output = f'## {json.dumps(metadata_obj)}' + '\n\n' + output

        console.log(output)

        img.save(output_dir / f'{sample.sample_id}.png')

        # The .txt file marks the sample as done, so write it atomically:
        # a crash mid-write must not leave a truncated file that later runs
        # would count and skip. The '.tmp' name does not match '*.txt'.
        tmp_path = txt_path.with_name(txt_path.name + '.tmp')
        with open(tmp_path, 'w', encoding='utf-8') as fp:
            fp.write(output)
        os.replace(tmp_path, txt_path)
