import json
import random
import sqlite3
from io import BytesIO
from typing import List, Optional, Tuple

import numpy as np
from geojson_pydantic import Polygon
from geojson_pydantic.features import Feature
from PIL import Image
from pydantic import BaseModel

class TagGroup(BaseModel):
    """A single tag name/value pair belonging to a :class:`SourceClass`.

    ``Sample.deserialize`` builds these from rows of the ``Tags`` table.
    """
    # NOTE(review): presumably an OpenStreetMap tag key/value (samples store
    # ``osm_tags``) — confirm.
    tag_name: str
    tag_value: str

class SourceClass(BaseModel):
    """A named sample class together with its associated tags.

    ``class_name`` is written to the ``source_class`` column of ``Samples`` and
    is the key used to look up this class's rows in the ``Tags`` table.
    """
    # tag name/value pairs associated with this class
    tags: List[TagGroup]
    class_name: str


class Configuration(BaseModel):
    """Top-level configuration: the source classes plus an optional raster count."""

    # the source classes this configuration covers
    classes: List[SourceClass]
    # Explicit ``= None`` keeps the field optional under pydantic v2 as well;
    # pydantic v1 already treated a bare ``Optional[...]`` as defaulting to None.
    # NOTE(review): presumably limits how many rasters are processed — confirm.
    raster_count: Optional[int] = None


def make_id(prefix: str) -> str:
    """Return ``prefix`` followed by ``_`` and 10 random lowercase ASCII letters.

    Draws from the module-level :mod:`random` generator, so ids are neither
    cryptographically secure nor guaranteed unique.
    """
    letters = 'abcdefghijklmnopqrstuvwxyz'
    suffix = ''.join(random.choice(letters) for _ in range(10))
    return f'{prefix}_{suffix}'


def get_lats_lons(geometry) -> Tuple[List[float], List[float]]:
    """Return the latitudes and longitudes of the positions in ``geometry``.

    Positions are GeoJSON ``[lon, lat, ...]`` sequences (RFC 7946), so index 1
    is latitude and index 0 is longitude. The result is meant for computing a
    bounding box, so for (Multi)Polygons only exterior rings are used: holes
    lie inside their exterior ring and cannot widen the box.

    Per RFC 7946, a geometry's ``coordinates`` is a single position (Point),
    an array of positions (LineString, MultiPoint), an array of position
    arrays (Polygon, MultiLineString), or an array of polygons (MultiPolygon).

    Args:
        geometry: object with a ``type`` attribute and either ``coordinates``
            or, for a GeometryCollection, ``geometries``.

    Returns:
        ``(lats, lons)``, two parallel lists.

    Raises:
        ValueError: if ``geometry.type`` is not a GeoJSON geometry type.
    """
    geometry_type = geometry.type

    if geometry_type == 'GeometryCollection':
        # Concatenate the positions of every member geometry.
        lats: List[float] = []
        lons: List[float] = []
        for member in geometry.geometries:
            member_lats, member_lons = get_lats_lons(member)
            lats.extend(member_lats)
            lons.extend(member_lons)
        return lats, lons

    coords = geometry.coordinates
    if geometry_type == 'Point':
        positions = [coords]
    elif geometry_type in ('LineString', 'MultiPoint'):
        positions = list(coords)
    elif geometry_type == 'Polygon':
        # exterior ring only (see docstring)
        positions = list(coords[0])
    elif geometry_type == 'MultiLineString':
        # every line, not just the first one
        positions = [p for line in coords for p in line]
    elif geometry_type == 'MultiPolygon':
        # exterior ring of every polygon, not just the first one
        positions = [p for polygon in coords for p in polygon[0]]
    else:
        raise ValueError(f'unsupported geometry type: {geometry_type!r}')

    lats = [p[1] for p in positions]
    lons = [p[0] for p in positions]
    return lats, lons



class Raster:
    """A source raster file, persisted as a row of the ``Rasters`` table."""

    filename:     str
    # Footprint geometry written by serialize(); must be set before serializing.
    polygon:      Optional[Polygon]
    # Bounding-box geometry; kept on the object but not written by serialize().
    polygon_bbox: Optional[Polygon]
    raster_id:    str

    def __init__(self, filename, polygon: Optional[Polygon] = None, polygon_bbox: Optional[Polygon] = None):
        """Create a raster record with a fresh random ``raster_id``.

        Args:
            filename: raster file name, stored verbatim in the ``filename`` column.
            polygon: footprint geometry. It may instead be assigned to the
                attribute after construction, but it must be set before
                :meth:`serialize` is called.
            polygon_bbox: optional bounding-box geometry.
        """
        self.raster_id = make_id('raster')
        self.filename = filename
        # TODO: derive these from the raster itself (earlier, commented-out
        # code called raster_to_polygon(..., bbox_only=False/True)).
        self.polygon = polygon
        self.polygon_bbox = polygon_bbox

    def serialize(self, cursor):
        """Insert this raster as a new row of the ``Rasters`` table.

        Args:
            cursor: database cursor using ``?`` placeholders, e.g. ``sqlite3.Cursor``.

        Raises:
            ValueError: if ``polygon`` has not been set.
        """
        # getattr also covers instances created without running __init__.
        polygon = getattr(self, 'polygon', None)
        if polygon is None:
            raise ValueError(
                f'Raster {self.raster_id!r} has no polygon; set it before calling serialize()'
            )

        cursor.execute('''
            INSERT INTO Rasters (
                raster_id,
                polygon,
                filename
            ) VALUES (
                ?,
                ?,
                ?
            )
        ''', [
            self.raster_id,
            polygon.json(),
            self.filename
        ])

class Sample:
    """One extracted sample: a GeoJSON feature plus the raster pixels under it.

    A sample is written to the ``Samples`` table by :meth:`serialize` and read
    back with :meth:`deserialize`.
    """

    # Raw pixel stack. create_png() indexes axis 0 as the band axis, so this is
    # assumed to be (band, row, col) — TODO confirm against the producer.
    raw_mask: np.ndarray
    # GeoJSON feature; serialize() reads properties 'type', 'id' and 'tags'.
    feature:  Feature
    # Encoded PNG preview; set by serialize() or deserialize().
    png:      bytes
    sample_id:       str
    # None for samples produced by deserialize().
    source_raster:   Optional[Raster]
    source_class:    SourceClass

    def __init__(self, feature: Feature, raw_mask: np.ndarray, source_raster: Optional[Raster], source_class: SourceClass):
        """Create a sample with a fresh random ``sample_id``.

        Args:
            feature: the GeoJSON feature this sample was built from.
            raw_mask: raw pixel array (see the ``raw_mask`` attribute).
            source_raster: the raster the pixels came from, or ``None`` when unknown.
            source_class: the :class:`SourceClass` this sample belongs to.
        """
        self.sample_id = make_id('sample')
        self.source_raster = source_raster
        self.source_class = source_class
        self.feature = feature
        self.raw_mask = raw_mask

    def create_png(self):
        """Render ``raw_mask`` as an 8-bit RGB PNG and return the encoded bytes.

        Band 0 is copied into all three channels. Values are divided by 2048,
        clipped to [0, 1], and scaled to [0, 255].
        """
        # TODO : replace this with something smarter

        # 8-Band Multispectral (COASTAL, BLUE, GREEN, YELLOW, RED, RED EDGE, NIR1, NIR2)
        #                      (0      , 1   , 2    , 3     , 4  , 5       , 6   , 7   )

        bands = self.raw_mask[(0, 0, 0), :, :]               # (3, rows, cols)
        rgb = np.moveaxis(bands, 0, -1).astype(np.float64)   # (rows, cols, 3) for PIL

        # 2048 is treated as full intensity (presumably the sensor's 11-bit
        # range — confirm). Clip the low end as well so negative values cannot
        # wrap around when cast to uint8.
        rgb = np.clip(rgb / 2048.0, 0.0, 1.0) * 255.0

        image = Image.fromarray(rgb.astype(np.uint8))
        buffer = BytesIO()
        image.save(buffer, format='PNG')
        return buffer.getvalue()

    def serialize(self, cursor):
        """Insert this sample as a new row of the ``Samples`` table.

        Regenerates the PNG preview (also stored in ``self.png``), stores the
        raw array as an ``.npy`` blob, and records the bounding box of the
        feature's geometry.

        Args:
            cursor: database cursor using ``?`` placeholders, e.g. ``sqlite3.Cursor``.
        """
        self.png = self.create_png()

        # Encode the raw array in .npy format. Pickling is disabled so the blob
        # can always be loaded back with allow_pickle=False.
        raw_mask_bytes = BytesIO()
        np.save(raw_mask_bytes, self.raw_mask, allow_pickle=False)

        # Bounding box of the geometry; GeoJSON positions are [lon, lat].
        lats, lons = get_lats_lons(self.feature.geometry)
        bbox_north = max(lats)
        bbox_south = min(lats)
        bbox_east = max(lons)
        bbox_west = min(lons)

        properties = self.feature.properties

        cursor.execute('''
            INSERT INTO Samples (
                sample_id,
                source_class,
                source_raster,

                osm_type,
                osm_id,
                osm_tags,
                feature,

                img_png,
                img_raw,

                bbox_east,
                bbox_west,
                bbox_north,
                bbox_south
            ) VALUES (
                ?,
                ?,
                ?,

                ?,
                ?,
                ?,
                ?,

                ?,
                ?,

                ?,
                ?,
                ?,
                ?
            )
        ''', [
            self.sample_id,
            self.source_class.class_name,
            self.source_raster.raster_id,

            properties['type'],
            properties['id'],
            json.dumps(properties['tags']),
            self.feature.json(),

            self.png,
            raw_mask_bytes.getvalue(),

            bbox_east,
            bbox_west,
            bbox_north,
            bbox_south
        ])

    @staticmethod
    def deserialize(sample_id: str, cursor: sqlite3.Cursor) -> 'Sample':
        """Load the sample with id ``sample_id`` from the database.

        Columns are read by name, so the cursor must return mapping-like rows
        (e.g. the connection's ``row_factory`` is ``sqlite3.Row``). The returned
        sample has ``source_raster`` set to ``None``. Its ``source_class`` is
        rebuilt from the ``Tags`` rows with a matching ``source_class``.

        Raises:
            KeyError: if no row in ``Samples`` has this ``sample_id``.
        """
        cursor.execute('''
            SELECT * FROM Samples
                WHERE sample_id=?
        ''', [sample_id])

        row = cursor.fetchone()
        if row is None:
            raise KeyError(sample_id)

        # Data comes from the database: never allow unpickling arbitrary objects.
        raw_mask = np.load(BytesIO(row['img_raw']), allow_pickle=False)

        cursor.execute('''
            SELECT * FROM Tags
                WHERE source_class=?
        ''', [row['source_class']])

        tags = [TagGroup.parse_obj(dict(r)) for r in cursor.fetchall()]

        sample = Sample(
            Feature.parse_raw(row['feature']),
            raw_mask,
            None,
            SourceClass(tags=tags, class_name=row['source_class'])
        )

        # Restore the stored preview and id instead of the freshly generated ones.
        sample.png = row['img_png']
        sample.sample_id = sample_id

        return sample

    def get_image(self):
        """Decode ``self.png`` into a PIL image (``png`` must already be set)."""
        return Image.open(BytesIO(self.png))

