from argparse import ArgumentParser
from collections import namedtuple
from pathlib import Path
import re
import sys
from typing import List, Optional

from rich.pretty import pprint
from rich.progress import track

from . import console
from .loader import DatasetLoader


# One parsed sample report file:
#   sample_id         -- base name of the report file up to its first '.'
#   output            -- everything after the first blank line of the file
#   true_class        -- class name read from the metadata header
#   predicted_classes -- known class names found in `output`, ordered by where
#                        they first appear in the text
SampleReport = namedtuple('SampleReport', ['sample_id', 'output', 'true_class', 'predicted_classes'])


def process_sample_report(filename: str, class_list: List[str]) -> SampleReport:
    """Parse one sample report file into a ``SampleReport``.

    The file is split at its first blank line into a metadata header and the
    output text. The true class is read from a ``class_name='...'`` entry in
    the header. Each name in ``class_list`` is looked up in the output
    (case-insensitively, with underscores treated as spaces) and the names
    that occur are ranked by the position of their first occurrence.

    Args:
        filename: Path of the report file; any number of directory components
            is accepted.
        class_list: Candidate class names to look for in the output.

    Returns:
        The parsed report. ``predicted_classes`` holds entries of
        ``class_list`` exactly as given, so they compare equal to
        ``true_class`` when the prediction is right.

    Raises:
        ValueError: If the file has no blank line separating header from
            output, or the header has no ``class_name='...'`` entry.
        OSError: If the file cannot be read.
    """
    # Use the base name so nested, absolute, or bare paths all give the same id.
    sample_id = Path(filename).name.split('.')[0]

    with open(filename, 'r') as fp:
        text = fp.read()

    meta, separator, output = text.partition('\n\n')
    if not separator:
        raise ValueError(f'{filename}: no blank line separating metadata from output')

    class_match = re.search(r"class_name='(\w+)'", meta)
    if class_match is None:
        raise ValueError(f'{filename}: no class_name entry in metadata')
    true_class = class_match.group(1)

    # Lower-case once; class names are matched as literal text, not as regexes.
    lowered_output = output.lower()
    hits = []
    for cls in class_list:
        needle = cls.replace('_', ' ').lower()
        position = lowered_output.find(needle)
        if position != -1:
            hits.append((position, cls))

    # Rank classes by the order in which they first appear in the output text.
    # The sort is stable, so ties keep their class_list order.
    hits.sort(key=lambda hit: hit[0])
    predicted_classes = [cls for _, cls in hits]

    return SampleReport(
        sample_id=sample_id,
        true_class=true_class,
        output=output,
        predicted_classes=predicted_classes,
    )

def score_top_k(all_reports: List['SampleReport'], k: int, class_focus: Optional[str] = None):
    """Return the fraction of reports whose true class is among the top ``k`` predictions.

    Args:
        all_reports: Reports to score.
        k: Number of leading entries of ``predicted_classes`` to consider.
        class_focus: If given, only reports whose ``true_class`` equals this
            name are scored.

    Returns:
        Accuracy in the range 0.0 to 1.0. Returns 0.0 when there is nothing to
        score; if that is because ``class_focus`` matched no report, a warning
        is printed to the console first.
    """
    if class_focus is not None:
        all_reports = [r for r in all_reports if r.true_class == class_focus]
        if not all_reports:
            console.print(f'[yellow]no samples for class {class_focus}[/]')
            return 0.0

    # Without this guard an empty input would divide by zero below.
    if not all_reports:
        return 0.0

    correct = sum(1 for rep in all_reports if rep.true_class in rep.predicted_classes[:k])
    return correct / len(all_reports)

def _print_top_k(all_reports, class_focus=None):
    """Print top-1 through top-5 accuracy, optionally restricted to one true class."""
    for k in range(1, 6):
        top_k = score_top_k(all_reports, k, class_focus=class_focus)
        console.print(f'    top {k}: {top_k * 100:0.02f}%')


def collect(dataset: Path, samples: Path):
    """Parse every ``*.txt`` report in ``samples`` and print accuracy statistics.

    Prints the class list, overall and per-class top-1..5 accuracy, the number
    of samples with no prediction, and a histogram of top-ranked predictions.
    The process exits with status 1 at the first report that fails to parse.

    Args:
        dataset: Dataset location passed to ``DatasetLoader``; its classes
            define the names searched for in each report.
        samples: Directory containing the sample report files.
    """
    dl = DatasetLoader(dataset)

    class_list = [c.class_name for c in dl.get_classes()]
    pprint(class_list, console=console)

    # Sorting materialises the glob, which gives a deterministic processing
    # order and lets the progress bar know the total.
    all_sample_files = sorted(samples.glob('*.txt'))
    all_reports = []

    for fn in track(all_sample_files, description='Processing samples', console=console):
        try:
            report = process_sample_report(str(fn), class_list)
        except Exception as exc:
            console.print(f'broke at: [blue]{fn}[/]')
            # markup=False: the error text may contain square brackets.
            console.print(f'{type(exc).__name__}: {exc}', markup=False)
            sys.exit(1)
        all_reports.append(report)

    console.print(f'munged {len(all_reports)} reports.')

    # Every statistic below divides by the number of reports.
    if not all_reports:
        console.print('[yellow]no sample reports found; nothing to score[/]')
        return

    console.print('overall accuracy')
    _print_top_k(all_reports)

    for cl in class_list:
        console.print(f'accuracy on class [blue]{cl}[/]:')
        _print_top_k(all_reports, class_focus=cl)

    no_classes = sum(1 for r in all_reports if not r.predicted_classes)
    console.print(f'samples with no predictions: {no_classes} ({(100 * no_classes / len(all_reports)):0.02f}%)' )

    # Count how often each known class was the top-ranked prediction, in one
    # pass over the reports; classes that never rank first stay at zero.
    class_hist = {cl: 0 for cl in class_list}
    for r in all_reports:
        if r.predicted_classes and r.predicted_classes[0] in class_hist:
            class_hist[r.predicted_classes[0]] += 1

    console.print('class histogram:')
    for cls, count in sorted(class_hist.items(), key=lambda p: p[0]):
        console.print(f'    [blue]{cls}[/]: {count}')


def cli_entrypoint():
    """Command-line entry point: parse arguments and run ``collect``.

    Takes a positional ``dataset`` path and an optional ``-s/--samples``
    directory (default ``samples``). If either path does not exist, a usage
    error is reported through argparse instead of running the collection.
    """
    parser = ArgumentParser(
        prog='fisi-collect',
        description='munge all the sample reports from a fisi run',
        epilog=''
    )

    parser.add_argument('dataset')
    parser.add_argument('-s', '--samples', default='samples')

    args = parser.parse_args()

    pprint(args, console=console)

    dataset = Path(args.dataset)
    samples = Path(args.samples)

    # parser.error() prints usage plus the message and exits with status 2;
    # unlike assert, it is not stripped when Python runs with -O.
    if not dataset.exists():
        parser.error(f'dataset path does not exist: {dataset}')
    if not samples.exists():
        parser.error(f'samples path does not exist: {samples}')

    collect(dataset, samples)
