import contextlib
import json
import random
import sqlite3
from io import BytesIO
from typing import List, Optional, Tuple

import numpy as np
from geojson_pydantic import Polygon
from geojson_pydantic.features import Feature
from PIL import Image
from pydantic import BaseModel, Field

# NOTE(review): `rasterio`, `rasterio.vrt`, `Resampling` and `raster_to_polygon`
# are referenced below but are not imported or defined in this part of the
# module — confirm they are brought into scope.


def get_lats_lons(geometry) -> Tuple[List[float], List[float]]:
    """Collect the latitude and longitude of every relevant vertex of a geometry.

    The caller in this module (``Sample.prepare_row``) takes ``min``/``max`` of
    the result to build a bounding box, so every part of a multi-part geometry
    is visited. For polygons, only exterior rings are read: per RFC 7946
    §3.1.6 the interior rings bound holes *within* the exterior ring, so they
    cannot extend the bounds.

    GeoJSON positions are ``[lon, lat, ...]`` (RFC 7946 §3.1.1), so index 1 is
    the latitude and index 0 the longitude.

    Args:
        geometry: a GeoJSON geometry exposing ``.type`` and ``.coordinates``,
            or ``.geometries`` for a ``GeometryCollection``.

    Returns:
        ``(lats, lons)``, two lists of equal length.
    """
    gtype = geometry.type

    # A GeometryCollection has no coordinates of its own; merge its members.
    if gtype == 'GeometryCollection':
        lats: List[float] = []
        lons: List[float] = []
        for member in geometry.geometries:
            member_lats, member_lons = get_lats_lons(member)
            lats.extend(member_lats)
            lons.extend(member_lons)
        return lats, lons

    coords = geometry.coordinates
    if gtype == 'Point':
        # A single position.
        positions = [coords]
    elif gtype in ('LineString', 'MultiPoint'):
        # An array of positions.
        positions = list(coords)
    elif gtype == 'Polygon':
        # An array of rings; ring 0 is the exterior.
        positions = list(coords[0])
    elif gtype == 'MultiLineString':
        # An array of LineStrings: every line contributes to the bounds.
        positions = [p for line in coords for p in line]
    else:  # MultiPolygon
        # An array of Polygons: every polygon's exterior ring contributes.
        positions = [p for polygon in coords if polygon for p in polygon[0]]

    lats = [p[1] for p in positions]
    lons = [p[0] for p in positions]
    return lats, lons



class TagGroup(BaseModel):
    """A single ``tag_name`` = ``tag_value`` pair attached to a source class.

    Presumably matched against OSM feature tags by the loader — confirm with
    the consuming code. The meaning of ``splittable`` is likewise defined by
    that consumer; it defaults to ``False``.
    """

    tag_name: str
    tag_value: str
    splittable: bool = False


class SourceClass(BaseModel):
    """A named output class together with the tag groups associated with it."""

    # Tag groups for this class. Sample.deserialize rebuilds this list from the
    # `Tags` table rows whose `source_class` column equals `class_name`.
    tags: List[TagGroup]
    # Class label; stored in the `source_class` column of the `Samples` table.
    class_name: str


class LoaderConfig(BaseModel):
    """Loader section of the configuration.

    The semantics of these fields are defined by the loader that consumes the
    model, which is not visible here. The notes below are hedged accordingly.
    """

    # Presumably an optional cap on how many rasters are loaded — TODO confirm.
    # The explicit `= None` keeps the field optional on pydantic v2 as well;
    # pydantic v1 already gave a bare `Optional[int]` an implicit None default.
    raster_count: Optional[int] = None
    sample_raft: float
    sample_context_raft: float
    # Presumably the gamma later handed to raster_array_to_image — confirm.
    gamma: float
    max_patch_size: float
    dilation_size: float


class Configuration(BaseModel):
    """Top-level configuration: the output classes plus loader settings."""

    # Each entry pairs a class name with its tag groups.
    classes: List[SourceClass]
    loader: LoaderConfig


def make_id(prefix: str) -> str:
    """Return ``'<prefix>_'`` followed by 10 random lowercase ASCII letters.

    Draws from the module-level ``random`` generator, so ids are reproducible
    under ``random.seed`` and are not meant to be secret.
    """
    letters = 'abcdefghijklmnopqrstuvwxyz'
    suffix = ''.join(random.choice(letters) for _ in range(10))
    return f'{prefix}_{suffix}'

def raster_array_to_image(
    raw_array: np.ndarray,
    gamma: float = 1.0,
    bands: Tuple[int, ...] = (0, 0, 0),
    max_value: float = 2048.0,
):
    """Render a band-first raster array as an 8-bit PIL image.

    The bands listed in ``bands`` become the image channels, in order. Each
    value is divided by ``max_value``, clipped to ``[0, 1]``, gamma-corrected
    with exponent ``1 / gamma``, scaled to ``[0, 255]`` and truncated to
    ``uint8``.

    The defaults reproduce the original behaviour: band 0 is replicated into
    all three channels (a grayscale-looking RGB image), normalised by 2048.

    Args:
        raw_array: array indexed as ``(band, row, col)``.
        gamma: gamma-correction factor. Must be > 0.
        bands: band indices mapped to the output channels (R, G, B).
        max_value: raw value that maps to full intensity. Must be > 0.

    Returns:
        A ``PIL.Image.Image`` built from the ``(rows, cols, len(bands))``
        uint8 array.

    Raises:
        ValueError: if ``gamma`` or ``max_value`` is not positive.
    """
    # TODO: replace this with something smarter.
    if gamma <= 0:
        raise ValueError(f'gamma must be positive, got {gamma!r}')
    if max_value <= 0:
        raise ValueError(f'max_value must be positive, got {max_value!r}')

    # 8-Band Multispectral (COASTAL, BLUE, GREEN, YELLOW, RED, RED EDGE, NIR1, NIR2)
    #                      (0      , 1   , 2    , 3     , 4  , 5       , 6   , 7   )
    # NOTE(review): with that layout a true-colour render would presumably be
    # bands=(4, 2, 1); the default keeps COASTAL x3 for compatibility.

    # Select the bands, then move the band axis last: (C, H, W) -> (H, W, C).
    stack = np.moveaxis(raw_array[list(bands), :, :], 0, -1).astype(np.float64)

    # Normalise and clip to [0, 1]. The lower bound matters: a negative value
    # raised to a fractional power below gives NaN, and casting NaN to uint8
    # is undefined.
    stack = np.clip(stack / max_value, 0.0, 1.0)

    # Gamma correction.
    stack = np.power(stack, 1.0 / gamma)

    return Image.fromarray((stack * 255.0).astype(np.uint8))


class Raster:
    """A raster file on disk plus lazily computed footprint polygons.

    Attributes:
        filename: path of the raster, passed to ``rasterio.open``.
        raster_id: random identifier created with ``make_id('raster')``.
    """

    # NOTE(review): `rasterio`, `rasterio.vrt`, `Resampling` and
    # `raster_to_polygon` are used below but are not imported/defined in the
    # visible part of this module — confirm they are in scope.

    filename: str
    _polygon: Optional['Polygon']
    _polygon_bbox: Optional['Polygon']
    raster_id: str

    def __init__(self, filename):
        self.raster_id = make_id('raster')
        self.filename = filename

        # Both footprints are computed on first access (see the properties).
        self._polygon_bbox = None
        self._polygon = None

    @property
    def polygon(self) -> 'Polygon':
        """Footprint of the raster, computed once via ``raster_to_polygon``."""
        if self._polygon is None:
            self._polygon = raster_to_polygon(self.filename, bbox_only=False)
        return self._polygon

    @polygon.setter
    def polygon(self, newpoly: 'Polygon'):
        self._polygon = newpoly

    @polygon.deleter
    def polygon(self):
        # Reset the cache rather than removing the attribute: removing it made
        # every later access raise AttributeError instead of recomputing.
        self._polygon = None

    @property
    def polygon_bbox(self) -> 'Polygon':
        """Bounding-box footprint, computed once via ``raster_to_polygon``."""
        if self._polygon_bbox is None:
            self._polygon_bbox = raster_to_polygon(self.filename, bbox_only=True)
        return self._polygon_bbox

    @polygon_bbox.setter
    def polygon_bbox(self, newpoly: 'Polygon'):
        self._polygon_bbox = newpoly

    @polygon_bbox.deleter
    def polygon_bbox(self):
        # Same cache-reset semantics as the `polygon` deleter.
        self._polygon_bbox = None

    def serialize(self, cursor):
        """Insert this raster into the ``Rasters`` table.

        Reading ``self.polygon`` may trigger the footprint computation.
        """
        cursor.execute('''
            INSERT INTO Rasters (
                raster_id,
                polygon,
                filename
            ) VALUES (
                ?,
                ?,
                ?
            )
        ''', [
            self.raster_id,
            self.polygon.json(),
            self.filename
        ])

    def get_rio_handle(self):
        """Open the raster with rasterio. The caller is responsible for closing it."""
        return rasterio.open(self.filename)

    def get_raw_bytes(self):
        """Read all bands of the raster (the array returned by ``src.read()``).

        Despite the name, the result is passed to ``raster_array_to_image`` as
        an array, not raw bytes.
        """
        with rasterio.open(self.filename) as src:
            return src.read()

    def create_image(self, gamma: float = 1.0):
        """Render the whole raster with ``raster_array_to_image``."""
        return raster_array_to_image(
            self.get_raw_bytes(),
            gamma=gamma
        )

    @contextlib.contextmanager
    def get_vrt(self):
        """Context manager yielding the raster warped on the fly to EPSG:4326.

        Usage: ``with raster.get_vrt() as vrt: ...``. The source dataset and
        the warped view are both closed when the ``with`` block exits.
        """
        with rasterio.open(self.filename) as src:
            # EPSG codes: 3857 -> Web Mercator, 4326 -> WGS 84.
            with rasterio.vrt.WarpedVRT(
                    src,
                    crs='EPSG:4326',  # WGS 84
                    resampling=Resampling.bilinear
            ) as vrt:
                yield vrt


class Sample:
    """One extracted sample: a feature, its image patches and persistence helpers.

    A Sample is either built from freshly cut arrays (``raw_mask`` /
    ``raw_context``) or rebuilt from the ``Samples`` table via
    :meth:`deserialize`. In the second case only the stored PNG bytes are
    available.
    """

    raw_mask: np.ndarray
    raw_context: np.ndarray
    feature: 'Feature'
    feature_bbox: 'Feature'
    feature_context: 'Feature'
    png: bytes
    context_png: bytes
    sample_id: str
    source_raster: 'Raster'
    source_class: 'SourceClass'
    gamma: float

    def __init__(
            self,
            feature: 'Feature',
            feature_bbox: 'Feature',
            feature_context: 'Feature',
            raw_mask: np.ndarray,
            raw_context: np.ndarray,
            source_raster: 'Raster',
            source_class: 'SourceClass',
            gamma: float
        ):
        """Store the inputs and stamp a fresh ``sample_id`` onto the features.

        The three ``Feature`` objects are mutated in place: their ``id`` is
        overwritten.
        """
        self.sample_id = make_id('sample')
        self.source_raster = source_raster
        self.source_class = source_class

        self.feature = feature
        self.feature_bbox = feature_bbox
        self.feature_context = feature_context

        self.raw_mask = raw_mask
        self.raw_context = raw_context
        self.gamma = gamma

        self._patch_feature_ids()

    def _patch_feature_ids(self):
        """Copy ``self.sample_id`` into the ids of the three features."""
        self.feature.id = self.sample_id
        self.feature_bbox.id = f'{self.sample_id}-bbox'
        self.feature_context.id = f'{self.sample_id}-context'

    @staticmethod
    def _encode_png(image) -> bytes:
        """Encode a PIL image as PNG and return the bytes."""
        buffer = BytesIO()
        image.save(buffer, format='PNG')
        return buffer.getvalue()

    def create_image(self):
        """Render ``raw_mask`` with this sample's gamma."""
        return raster_array_to_image(self.raw_mask, gamma=self.gamma)

    def create_png(self):
        """Return :meth:`create_image` encoded as PNG bytes."""
        return self._encode_png(self.create_image())

    def create_context_image(self):
        """Render ``raw_context`` with this sample's gamma."""
        return raster_array_to_image(self.raw_context, gamma=self.gamma)

    def create_context_png(self):
        """Return :meth:`create_context_image` encoded as PNG bytes."""
        return self._encode_png(self.create_context_image())

    def prepare_row(self):
        """Build this sample's ``Samples`` row, in the column order of :meth:`write_rows`.

        Side effect: caches the encoded PNGs on ``self.png`` and
        ``self.context_png``.

        Assumes ``self.feature.properties`` has ``'type'``, ``'id'`` and
        ``'tags'`` keys (written to the osm_* columns — presumably the OSM
        element type/id/tags; confirm with the producer).
        """
        self.png = self.create_png()
        self.context_png = self.create_context_png()

        # Bounding box of the feature geometry's vertices.
        lats, lons = get_lats_lons(self.feature.geometry)
        bbox_north, bbox_south = max(lats), min(lats)
        bbox_east, bbox_west = max(lons), min(lons)

        properties = self.feature.properties
        return [
            self.sample_id,
            self.source_class.class_name,
            self.source_raster.raster_id,

            properties['type'],
            properties['id'],
            json.dumps(properties['tags']),
            self.feature.json(),
            self.feature_bbox.json(),
            self.feature_context.json(),

            self.png,
            self.context_png,

            bbox_east,
            bbox_west,
            bbox_north,
            bbox_south
        ]

    @staticmethod
    def write_rows(cursor, rows):
        """Bulk-insert rows produced by :meth:`prepare_row` into ``Samples``."""
        cursor.executemany('''
            INSERT INTO Samples (
                sample_id,
                source_class,
                source_raster,

                osm_type,
                osm_id,
                osm_tags,
                feature,
                feature_bbox,
                feature_context,

                img_png,
                img_context,

                bbox_east,
                bbox_west,
                bbox_north,
                bbox_south
            ) VALUES (
                ?,
                ?,
                ?,

                ?,
                ?,
                ?,
                ?,
                ?,
                ?,

                ?,
                ?,

                ?,
                ?,
                ?,
                ?
            )
        ''', rows)

    def serialize(self, cursor):
        """Build this sample's row and insert it through ``cursor``."""
        row = self.prepare_row()
        # write_rows expects (cursor, rows).
        return self.write_rows(cursor, [row])

    @staticmethod
    def deserialize(sample_id: str, cursor: sqlite3.Cursor):
        """Rebuild a Sample from the ``Samples`` and ``Tags`` tables.

        Rows are indexed by column name, so the cursor's connection must use a
        name-indexable ``row_factory`` (e.g. ``sqlite3.Row``).

        The result has ``raw_mask``, ``raw_context`` and ``source_raster`` set
        to ``None`` and ``gamma`` set to ``1.0``. ``png`` and ``context_png``
        hold the stored bytes.

        Raises:
            KeyError: if no sample with ``sample_id`` exists.
        """
        cursor.execute('''
            SELECT * FROM Samples
                WHERE sample_id=?
        ''', [sample_id])

        row = cursor.fetchone()
        if row is None:
            raise KeyError(f'no sample with id {sample_id!r}')

        cursor.execute('''
            SELECT * FROM Tags
                WHERE source_class=?
        ''', [row['source_class']])

        tags = [TagGroup.parse_obj(r) for r in cursor.fetchall()]

        sample = Sample(
            Feature.parse_raw(row['feature']),
            Feature.parse_raw(row['feature_bbox']),
            Feature.parse_raw(row['feature_context']),
            None,
            None,
            None,
            SourceClass(tags=tags, class_name=row['source_class']),
            1.0
        )

        sample.png = row['img_png']
        sample.context_png = row['img_context']

        # __init__ stamped a freshly generated id onto the features. Restore the
        # stored id on both the sample and its features so they agree.
        sample.sample_id = sample_id
        sample._patch_feature_ids()

        return sample

    def get_image(self):
        """Decode the stored ``png`` bytes into a PIL image."""
        return Image.open(BytesIO(self.png))

    def get_context_image(self):
        """Decode the stored ``context_png`` bytes into a PIL image."""
        return Image.open(BytesIO(self.context_png))
