import sqlite3
from typing import Optional, List
from .models import Sample, SourceClass, TagGroup
from pathlib import Path

class DatasetLoader:
    """Load samples and source classes from a SQLite dataset file.

    Can be used as a context manager: the underlying connection is closed on
    exit. (``sqlite3.Connection``'s own context manager only commits or rolls
    back a transaction; it does not close the connection.)
    """

    # Open connection to the dataset; rows are returned as ``sqlite3.Row`` so
    # columns can be accessed by name.
    conn: sqlite3.Connection

    def __init__(self, filename: Path):
        """Open ``filename`` as a SQLite database with name-indexable rows."""
        self.conn = sqlite3.connect(filename)
        self.conn.row_factory = sqlite3.Row

    def close(self) -> None:
        """Close the underlying database connection."""
        self.conn.close()

    def __enter__(self) -> 'DatasetLoader':
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def get_sample_loader(self):
        """Return a ``SampleLoader`` that reads through a new cursor on this connection.

        The returned loader has no limit set.
        """
        return SampleLoader(self.conn.cursor())

    def get_classes(self) -> List[SourceClass]:
        """Return every row of ``SourceClasses`` together with its tags.

        For each class, the rows of ``Tags`` whose ``source_class`` equals the
        class's ``class_name`` are parsed with ``TagGroup.parse_obj``.
        """
        cursor = self.conn.cursor()
        try:
            cursor.execute('''
                SELECT * FROM SourceClasses
            ''')
            # Materialise all class rows first: the same cursor is reused
            # below for the per-class Tags query, which would discard any
            # unread SourceClasses rows.
            class_rows = cursor.fetchall()

            classes = []
            for row in class_rows:
                class_name = row['class_name']
                cursor.execute('''
                    SELECT * FROM Tags
                        WHERE source_class=?
                ''', [class_name])
                tags = [TagGroup.parse_obj(dict(r)) for r in cursor.fetchall()]
                classes.append(SourceClass(class_name=class_name, tags=tags))

            return classes
        finally:
            cursor.close()

class ResumableSampleLoader:
    """Placeholder for a sample loader backed by a queue file.

    Presumably intended to resume sampling from a persisted queue (judging by
    the name); NOTE(review): confirm intended design. Only the constructor
    and ``set_limit`` are functional. Iteration raises ``NotImplementedError``
    rather than yielding ``None`` forever, which would make a plain ``for``
    loop over the loader never terminate.
    """

    # Maximum number of samples to yield (stored but not used yet).
    limit: int
    # Name of the queue file; defaults to 'sample_queue.txt'.
    queue_file_name: str

    def __init__(
        self,
        cursor: sqlite3.Cursor,
        limit: int,
        queue_file: Optional[str] = None,
    ):
        self.cursor = cursor
        self.limit = limit
        self.queue_file_name = queue_file if queue_file is not None else 'sample_queue.txt'

    def set_limit(self, limit):
        """Replace the maximum number of samples to yield."""
        self.limit = limit

    def init_queue(self):
        """Prepare the sample queue. Currently a no-op."""
        # TODO: implement queue initialisation.
        pass

    def __iter__(self):
        return self

    def __next__(self):
        """Not implemented yet.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError('ResumableSampleLoader is not implemented yet')


class SampleLoader:
    """Iterator yielding randomly chosen samples from the ``Samples`` table.

    Each ``next()`` runs an independent ``ORDER BY RANDOM() LIMIT 1`` query, so
    samples are drawn with replacement: the same sample may be yielded more
    than once. That query touches every row of ``Samples`` on each call, so
    its cost grows with the table size.
    """

    # Remaining number of samples to yield; ``None`` means unbounded.
    limit: Optional[int]

    def __init__(self, cursor: sqlite3.Cursor, limit: Optional[int] = None):
        """Create a loader.

        Args:
            cursor: Cursor used for the sampling query and passed on to
                ``Sample.deserialize``. Its rows must be indexable by column
                name (e.g. a connection with ``row_factory = sqlite3.Row``).
            limit: Maximum number of samples to yield, or ``None`` for an
                unbounded stream.
        """
        self.cursor = cursor
        self.limit = limit

    def set_limit(self, limit):
        """Replace the remaining sample budget (``None`` means unbounded)."""
        self.limit = limit

    def __next__(self):
        """Return the next randomly chosen ``Sample``.

        Raises:
            StopIteration: when the limit is used up or ``Samples`` is empty.
        """
        if self.limit is not None and self.limit <= 0:
            raise StopIteration

        self.cursor.execute('''
            SELECT sample_id FROM Samples
                ORDER BY RANDOM()
                LIMIT 1
        ''')
        row = self.cursor.fetchone()
        if row is None:
            # Empty table: there is nothing to sample.
            raise StopIteration

        # Only count down when a limit was set; ``None`` means unbounded.
        if self.limit is not None:
            self.limit -= 1
        return Sample.deserialize(row['sample_id'], self.cursor)

    def __iter__(self):
        return self
